LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
MemberOverwriteAnalysis.cpp
Go to the documentation of this file.
1//===-- MemberOverwriteAnalysis.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
15
16#include <llvm/ADT/TypeSwitch.h>
17#include <llvm/Support/Debug.h>
18
19#define DEBUG_TYPE "member-overwrite-analysis"
20
21using namespace mlir;
22using namespace mlir::dataflow;
23
24namespace llzk {
25using namespace component;
26
27llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const MemberOverwriteLattice &lat) {
28 os << lat.mustWrites;
29 return os;
30}
31
33 auto name = write.getMemberName();
34
35 bool changed = false;
36
37 if (auto it = mayWrites.find(name); it != mayWrites.end() && it->second != write) {
38 // .insert(...) returns true if an insertion was performed (i.e., it wasn't present before),
39 // meaning there was a change
40 changed |= overwrites.insert({mayWrites.at(name), write});
41 } else {
42 mayWrites.insert({name, write});
43 changed = true;
44 }
45
46 changed |= mustWrites.insert(name);
47 return ChangeResult {changed};
48}
49
50ChangeResult MemberOverwriteLattice::join(const AbstractDenseLattice &other) {
51 const auto *rhs = dynamic_cast<const MemberOverwriteLattice *>(&other);
52 ensure(rhs, "cannot join incomparable lattices");
53
54 LLVM_DEBUG(
55 llvm::dbgs() << "Joining " << *dyn_cast<ProgramPoint *>(getAnchor()) << "(" << *this
56 << ") with " << *dyn_cast<ProgramPoint *>(rhs->getAnchor()) << "(" << *rhs
57 << ")\n"
58 );
59 bool changed = false;
60
61 // Union the mayWrites
62 for (auto [name, write] : rhs->mayWrites) {
63 auto it = mayWrites.find(name);
64 changed |= it == mayWrites.end() || it->second != write;
65 mayWrites[name] = write;
66 }
67 changed |= overwrites.set_union(rhs->overwrites);
68
69 // "Intersect" the mustWrites
70 changed |= mustWrites.intersect(rhs->mustWrites);
71
72 return ChangeResult {changed};
73}
74
75bool MemberOverwriteLattice::hasOverwrites() const { return !overwrites.empty(); }
76
77llvm::SetVector<Overwrite> MemberOverwriteLattice::getOverwrites() const { return overwrites; }
78
80 return mustWrites.contains(memberDef.getSymName());
81}
82
83void MemberOverwriteLattice::print(llvm::raw_ostream &os) const { os << *this << '\n'; }
84
86 Operation *op, const MemberOverwriteLattice &before, MemberOverwriteLattice *after
87) {
88 ChangeResult result = after->join(before);
89
90 LLVM_DEBUG(llvm::dbgs() << "Visiting operation: " << *op << ": " << before << "\n");
91
92 if (auto write = dyn_cast<MemberWriteOp>(op)) {
93 result |= after->record(write);
94 }
95
96 propagateIfChanged(after, result);
97 return success();
98}
99
/// Runs MemberOverwriteAnalysis (non-interprocedurally) over the compute function
/// of `structDef`, falling back to its product function when there is no compute
/// function.
///
/// Returns the recorded overwrite pairs together with the must-write set taken at
/// the program point after the terminator of the function body's last block.
/// An empty function body yields a default result (no overwrites, default
/// must-write set). Returns failure if the solver fails to initialize/run.
100llvm::FailureOr<std::pair<llvm::SetVector<Overwrite>, FuzzySet>>
102 function::FuncDefOp computeOrProductFunc = structDef.getComputeFuncOp();
103 if (!computeOrProductFunc) {
104 computeOrProductFunc = structDef.getProductFuncOp();
105 }
 // NOTE(review): if the struct has neither a compute nor a product function,
 // computeOrProductFunc is still null here and is used unguarded below —
 // confirm callers guarantee one of them exists.
106
 // Intraprocedural solver: calls are not followed into callees.
107 DataFlowSolver solver {DataFlowConfig {}.setInterprocedural(false)};
109 solver.load<MemberOverwriteAnalysis>();
110 if (failed(solver.initializeAndRun(computeOrProductFunc))) {
111 return llvm::failure();
112 }
113
114 auto &funcBody = computeOrProductFunc.getBody();
115 if (funcBody.empty()) {
116 // If there's nothing, just build a default lattice element (no overwrites, everything is
117 // unwritten)
118 return {{{}, {}}};
119 }
120
 // NOTE(review): assumes the function's exit is the terminator of the last block;
 // confirm for multi-block bodies.
121 auto *returnOp = funcBody.back().getTerminator();
122 const auto *lattice =
123 solver.lookupState<MemberOverwriteLattice>(solver.getProgramPointAfter(returnOp));
 // NOTE(review): lookupState returns null when no state exists for the point
 // (e.g. if the terminator was never reached); `lattice` is dereferenced
 // without a check here.
124 return {{lattice->getOverwrites(), lattice->mustWrites}};
125}
126
127} // namespace llzk
Represents a set where the membership predicate can take three values: true, false,...
mlir::LogicalResult visitOperation(mlir::Operation *op, const MemberOverwriteLattice &before, MemberOverwriteLattice *after) override
void print(llvm::raw_ostream &os) const override
llvm::SetVector< Overwrite > getOverwrites() const
bool checkWritten(component::MemberDefOp) const
mlir::ChangeResult join(const mlir::dataflow::AbstractDenseLattice &other) override
mlir::ChangeResult record(component::MemberWriteOp write)
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:528
::llvm::StringRef getMemberName()
Definition Ops.cpp.inc:1315
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
Definition Ops.cpp:474
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:466
::mlir::Region & getBody()
Definition Ops.h.inc:690
void loadRequiredAnalyses(DataFlowSolver &solver)
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
llvm::FailureOr< std::pair< llvm::SetVector< Overwrite >, FuzzySet > > analyzeStruct(component::StructDefOp structDef)