LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
PredecessorAnalysisPass.cpp
Go to the documentation of this file.
1//===-- PredecessorAnalysisPass.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
18
19#include <mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h>
20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
22#include <mlir/Analysis/DataFlowFramework.h>
23
24#include <llvm/ADT/MapVector.h>
25#include <llvm/ADT/SetVector.h>
26#include <llvm/Support/ErrorHandling.h>
27
28using namespace mlir;
29
30namespace llzk {
31
32using namespace function;
33
34#define GEN_PASS_DECL_PREDECESSORPRINTERPASS
35#define GEN_PASS_DEF_PREDECESSORPRINTERPASS
37
39raw_ostream &printRegionless(raw_ostream &os, Operation *op, bool withParent = false) {
40 std::string s;
41 llvm::raw_string_ostream ss(s);
42 if (withParent) {
43 if (auto fnOp = op->getParentOfType<FuncDefOp>()) {
44 os << '<' << fnOp.getFullyQualifiedName() << ">:";
45 } else {
46 os << "<(no parent function op)>:";
47 }
48 }
49 op->print(ss, mlir::OpPrintingFlags().skipRegions());
50 ss.flush();
51 // Skipping regions inserts a new line we don't want, so trim it here.
52 llvm::StringRef r(s);
53 os << r.rtrim();
54 return os;
55}
56
57class PredecessorLattice : public mlir::dataflow::AbstractDenseLattice {
59 llvm::MapVector<Operation *, llvm::SmallSetVector<Operation *, 4>> predecessors;
60
61public:
62 using AbstractDenseLattice::AbstractDenseLattice;
63
64 ChangeResult visit(Operation *op, Operation *pred) {
65 bool newlyInserted = predecessors[op].insert(pred);
66 return newlyInserted ? ChangeResult::Change : ChangeResult::NoChange;
67 }
68
69 ChangeResult join(const AbstractDenseLattice &rhs) override {
70 const auto *other = dynamic_cast<const PredecessorLattice *>(&rhs);
71 if (!other) {
72 llvm::report_fatal_error("wrong lattice type provided for join");
73 }
74 ChangeResult r = ChangeResult::NoChange;
75 for (const auto &[op, preds] : other->predecessors) {
76 for (auto *pred : preds) {
77 r |= visit(op, pred);
78 }
79 }
80 return r;
81 }
82
83 ChangeResult meet(const AbstractDenseLattice & /*rhs*/) override {
84 llvm::report_fatal_error("meet operation is not supported for PredecessorLattice");
85 return ChangeResult::NoChange;
86 }
87
88 void print(raw_ostream &os) const override {
89 if (predecessors.empty()) {
90 os << "(empty)\n";
91 return;
92 }
93 for (const auto &[k, v] : predecessors) {
94 os.indent(2);
95 printRegionless(os, k, true) << " predecessors:";
96 llvm::interleave(v, [&os](Operation *p) {
97 os << '\n';
98 printRegionless(os.indent(6), p, true);
99 }, []() {});
100 os << '\n';
101 }
102 }
103};
104
106 : public mlir::dataflow::DenseForwardDataFlowAnalysis<PredecessorLattice> {
109 [[maybe_unused]]
110 raw_ostream &os;
111
112 ProgramPoint *getPoint(const PredecessorLattice &l) const {
113 return dyn_cast<ProgramPoint *>(l.getAnchor());
114 }
115
116 ChangeResult
117 updateLattice(Operation *op, const PredecessorLattice &before, PredecessorLattice *after) {
118 ChangeResult result = after->join(before);
119 ProgramPoint *pointBefore = getProgramPointBefore(op);
120 auto *predState = getOrCreate<mlir::dataflow::PredecessorState>(pointBefore);
121 if (!predState->getKnownPredecessors().empty()) {
122 for (Operation *pred : predState->getKnownPredecessors()) {
123 result |= after->visit(op, pred);
124 }
125 } else {
126 // Predecessor is just the prior or parent op
127 Operation *pred = pointBefore->isBlockStart() ? op->getParentOp() : pointBefore->getPrevOp();
128 result |= after->visit(op, pred);
129 }
130 return result;
131 }
132
133public:
134 using Base = DenseForwardDataFlowAnalysis<PredecessorLattice>;
135
136 PredecessorAnalysis(DataFlowSolver &s, raw_ostream &ro) : Base(s), os(ro) {}
137
138 LogicalResult visitOperation(
139 Operation *op, const PredecessorLattice &before, PredecessorLattice *after
140 ) override {
141 ChangeResult result = updateLattice(op, before, after);
142 propagateIfChanged(after, result);
143 return success();
144 }
145
147 CallOpInterface call, mlir::dataflow::CallControlFlowAction action,
148 const PredecessorLattice &before, PredecessorLattice *after
149 ) override {
153 if (action == mlir::dataflow::CallControlFlowAction::EnterCallee) {
154 // We skip updating the incoming lattice for function calls to avoid a
155 // non-convergence scenario, as calling a function from other contexts
156 // can cause the lattice values to oscillate and constantly change.
157 setToEntryState(after);
158 }
162 else if (action == mlir::dataflow::CallControlFlowAction::ExitCallee) {
163 // Get the argument values of the lattice by getting the state as it would
164 // have been for the callsite.
165 const PredecessorLattice *beforeCall = getLattice(getProgramPointBefore(call));
166 ensure(beforeCall, "could not get prior lattice");
167 ChangeResult r = after->join(before);
168 // Perform a visit so that we see the call op in our lattice
169 r |= updateLattice(call, *beforeCall, after);
170 propagateIfChanged(after, r);
171 }
176 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
177 // For external calls, we propagate what information we already have from
178 // before the call to after the call, since the external call won't invalidate
179 // any of that information. It also, conservatively, makes no assumptions about
180 // external calls and their computation, so CDG edges will not be computed over
181 // input arguments to external functions.
182 join(after, before);
183 }
184 }
185
187 RegionBranchOpInterface branch, std::optional<unsigned> /*regionFrom*/,
188 std::optional<unsigned> /*regionTo*/, const PredecessorLattice &before,
189 PredecessorLattice *after
190 ) override {
191 // The default implementation is `join(after, before)`, but we want to
192 // show the predecessor logic for branch operations as well.
193 (void)visitOperation(branch, before, after);
194 }
195
196protected:
197 void setToEntryState(PredecessorLattice *lattice) override {}
198};
199
// Pass that loads and runs PredecessorAnalysis on the pass's root operation
// and then, for every FuncDefOp with a non-empty body, prints the function
// (regions skipped) followed by the PredecessorLattice found at the program
// point after the terminator of the body's last block.
200class PredecessorPrinterPass : public impl::PredecessorPrinterPassBase<PredecessorPrinterPass> {
201
202public:
204
205protected:
206 void runOnOperation() override {
 // Marks all analyses as preserved; nothing in the visible body modifies the IR.
207 markAllAnalysesPreserved();
208 // Note: options like `outputStream` are safe to read here, but not in the
209 // pass constructor.
210 raw_ostream &os = toStream(outputStream);
211
212 DataFlowSolver solver;
 // NOTE(review): the `if (...) {` line that should open this branch (matching
 // the `} else {` below) is not present in this view, so the condition under
 // which the required analyses are pre-run is unknown -- confirm against the
 // original file.
214 ensure(
215 dataflow::loadAndRunRequiredAnalyses(solver, getOperation()).succeeded(),
216 "failed to pre-run!"
217 );
 // NOTE(review): the else branch is empty in this view; a statement may be
 // missing here -- confirm against the original file.
218 } else {
220 }
 // Register the predecessor analysis with the solver, then run the solver on
 // the root operation.
221 solver.load<PredecessorAnalysis>(os);
222 LogicalResult res = solver.initializeAndRun(getOperation());
223
 // A failed solver run is fatal.
224 if (res.failed()) {
225 llvm::report_fatal_error("PredecessorAnalysis failed.");
226 }
227
 // Pre-order walk over every FuncDefOp. Both return paths yield
 // WalkResult::skip(), which presumably stops the walk from descending into
 // the function's own regions -- TODO confirm.
228 getOperation()->walk<WalkOrder::PreOrder>([&](FuncDefOp fnOp) {
229 Region &fnBody = fnOp.getFunctionBody();
 // Functions without a body have nothing to report.
230 if (fnBody.empty()) {
231 return WalkResult::skip();
232 }
233
 // The reported lattice is the one at the point after the terminator of the
 // last block in the function body.
234 ProgramPoint *point = solver.getProgramPointAfter(fnBody.back().getTerminator());
235 PredecessorLattice *finalLattice = solver.getOrCreateState<PredecessorLattice>(point);
236
237 printRegionless(os, fnOp.getOperation()) << ":\n" << *finalLattice << '\n';
238
239 return WalkResult::skip();
240 });
241 }
242};
243
244std::unique_ptr<mlir::Pass> createPredecessorPrinterPass() {
245 return std::make_unique<PredecessorPrinterPass>();
246}
247
248} // namespace llzk
void setToEntryState(PredecessorLattice *lattice) override
void visitCallControlFlowTransfer(CallOpInterface call, mlir::dataflow::CallControlFlowAction action, const PredecessorLattice &before, PredecessorLattice *after) override
DenseForwardDataFlowAnalysis< PredecessorLattice > Base
LogicalResult visitOperation(Operation *op, const PredecessorLattice &before, PredecessorLattice *after) override
PredecessorAnalysis(DataFlowSolver &s, raw_ostream &ro)
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, std::optional< unsigned >, std::optional< unsigned >, const PredecessorLattice &before, PredecessorLattice *after) override
ChangeResult visit(Operation *op, Operation *pred)
ChangeResult join(const AbstractDenseLattice &rhs) override
ChangeResult meet(const AbstractDenseLattice &) override
void print(raw_ostream &os) const override
void loadRequiredAnalyses(DataFlowSolver &solver)
LogicalResult loadAndRunRequiredAnalyses(DataFlowSolver &solver, Operation *op)
std::unique_ptr< mlir::Pass > createPredecessorPrinterPass()
llvm::raw_ostream & toStream(OutputStream val)
void ensure(bool condition, const llvm::Twine &errMsg)
raw_ostream & printRegionless(raw_ostream &os, Operation *op, bool withParent=false)
Prints op without region.