LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
PredecessorAnalysisPass.cpp
Go to the documentation of this file.
1//===-- PredecessorAnalysisPass.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
18
19#include <mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h>
20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
22#include <mlir/Analysis/DataFlowFramework.h>
23
24#include <llvm/ADT/MapVector.h>
25#include <llvm/ADT/SetVector.h>
26#include <llvm/Support/ErrorHandling.h>
27
28using namespace mlir;
29
30namespace llzk {
31
32#define GEN_PASS_DEF_PREDECESSORPRINTERPASS
34
36raw_ostream &printRegionless(raw_ostream &os, Operation *op, bool withParent = false) {
37 std::string s;
38 llvm::raw_string_ostream ss(s);
39 if (withParent) {
40 if (auto fnOp = op->getParentOfType<function::FuncDefOp>()) {
41 os << '<' << fnOp.getFullyQualifiedName() << ">:";
42 } else {
43 os << "<(no parent function op)>:";
44 }
45 }
46 op->print(ss, OpPrintingFlags().skipRegions());
47 ss.flush();
48 // Skipping regions inserts a new line we don't want, so trim it here.
49 llvm::StringRef r(s);
50 os << r.rtrim();
51 return os;
52}
53
54class PredecessorLattice : public mlir::dataflow::AbstractDenseLattice {
56 llvm::MapVector<Operation *, llvm::SmallSetVector<Operation *, 4>> predecessors;
57
58public:
59 using AbstractDenseLattice::AbstractDenseLattice;
60
61 ChangeResult visit(Operation *op, Operation *pred) {
62 bool newlyInserted = predecessors[op].insert(pred);
63 return newlyInserted ? ChangeResult::Change : ChangeResult::NoChange;
64 }
65
66 ChangeResult join(const AbstractDenseLattice &rhs) override {
67 const auto *other = dynamic_cast<const PredecessorLattice *>(&rhs);
68 if (!other) {
69 llvm::report_fatal_error("wrong lattice type provided for join");
70 }
71 ChangeResult r = ChangeResult::NoChange;
72 for (const auto &[op, preds] : other->predecessors) {
73 for (auto *pred : preds) {
74 r |= visit(op, pred);
75 }
76 }
77 return r;
78 }
79
80 ChangeResult meet(const AbstractDenseLattice & /*rhs*/) override {
81 llvm::report_fatal_error("meet operation is not supported for PredecessorLattice");
82 return ChangeResult::NoChange;
83 }
84
85 void print(raw_ostream &os) const override {
86 if (predecessors.empty()) {
87 os << "(empty)\n";
88 return;
89 }
90 for (const auto &[k, v] : predecessors) {
91 os.indent(2);
92 printRegionless(os, k, true) << " predecessors:";
93 llvm::interleave(v, [&os](Operation *p) {
94 os << '\n';
95 printRegionless(os.indent(6), p, true);
96 }, []() {});
97 os << '\n';
98 }
99 }
100};
101
103 : public mlir::dataflow::DenseForwardDataFlowAnalysis<PredecessorLattice> {
106 [[maybe_unused]]
107 raw_ostream &os;
108
109 ProgramPoint *getPoint(const PredecessorLattice &l) const {
110 return dyn_cast<ProgramPoint *>(l.getAnchor());
111 }
112
113 ChangeResult
114 updateLattice(Operation *op, const PredecessorLattice &before, PredecessorLattice *after) {
115 ChangeResult result = after->join(before);
116 ProgramPoint *pointBefore = getProgramPointBefore(op);
117 auto *predState = getOrCreate<mlir::dataflow::PredecessorState>(pointBefore);
118 if (!predState->getKnownPredecessors().empty()) {
119 for (Operation *pred : predState->getKnownPredecessors()) {
120 result |= after->visit(op, pred);
121 }
122 } else {
123 // Predecessor is just the prior or parent op
124 Operation *pred = pointBefore->isBlockStart() ? op->getParentOp() : pointBefore->getPrevOp();
125 result |= after->visit(op, pred);
126 }
127 return result;
128 }
129
130public:
131 using Base = DenseForwardDataFlowAnalysis<PredecessorLattice>;
132
133 PredecessorAnalysis(DataFlowSolver &s, raw_ostream &ro) : Base(s), os(ro) {}
134
135 LogicalResult visitOperation(
136 Operation *op, const PredecessorLattice &before, PredecessorLattice *after
137 ) override {
138 ChangeResult result = updateLattice(op, before, after);
139 propagateIfChanged(after, result);
140 return success();
141 }
142
144 CallOpInterface call, mlir::dataflow::CallControlFlowAction action,
145 const PredecessorLattice &before, PredecessorLattice *after
146 ) override {
150 if (action == mlir::dataflow::CallControlFlowAction::EnterCallee) {
151 // We skip updating the incoming lattice for function calls to avoid a
152 // non-convergence scenario, as calling a function from other contexts
153 // can cause the lattice values to oscillate and constantly change.
154 setToEntryState(after);
155 }
159 else if (action == mlir::dataflow::CallControlFlowAction::ExitCallee) {
160 // Get the argument values of the lattice by getting the state as it would
161 // have been for the callsite.
162 const PredecessorLattice *beforeCall = getLattice(getProgramPointBefore(call));
163 ensure(beforeCall, "could not get prior lattice");
164 ChangeResult r = after->join(before);
165 // Perform a visit so that we see the call op in our lattice
166 r |= updateLattice(call, *beforeCall, after);
167 propagateIfChanged(after, r);
168 }
173 else if (action == mlir::dataflow::CallControlFlowAction::ExternalCallee) {
174 // For external calls, we propagate what information we already have from
175 // before the call to after the call, since the external call won't invalidate
176 // any of that information. It also, conservatively, makes no assumptions about
177 // external calls and their computation, so CDG edges will not be computed over
178 // input arguments to external functions.
179 join(after, before);
180 }
181 }
182
184 RegionBranchOpInterface branch, std::optional<unsigned> /*regionFrom*/,
185 std::optional<unsigned> /*regionTo*/, const PredecessorLattice &before,
186 PredecessorLattice *after
187 ) override {
188 // The default implementation is `join(after, before)`, but we want to
189 // show the predecessor logic for branch operations as well.
190 (void)visitOperation(branch, before, after);
191 }
192
193protected:
194 void setToEntryState(PredecessorLattice *lattice) override {}
195};
196
197} // namespace llzk
198
199namespace {
200
201using namespace llzk;
202
203class PassImpl : public llzk::impl::PredecessorPrinterPassBase<PassImpl> {
204 using Base = PredecessorPrinterPassBase<PassImpl>;
205 using Base::Base;
206
207 void runOnOperation() override {
208 markAllAnalysesPreserved();
209 // Note: options like `outputStream` are safe to read here, but not in the
210 // pass constructor.
211 raw_ostream &os = toStream(outputStream);
212
213 DataFlowSolver solver;
214 if (preRunRequiredAnalyses) {
215 ensure(
216 llzk::dataflow::loadAndRunRequiredAnalyses(solver, getOperation()).succeeded(),
217 "failed to pre-run!"
218 );
219 } else {
221 }
222 solver.load<PredecessorAnalysis>(os);
223 LogicalResult res = solver.initializeAndRun(getOperation());
224
225 if (res.failed()) {
226 llvm::report_fatal_error("PredecessorAnalysis failed.");
227 }
228
229 getOperation()->walk<WalkOrder::PreOrder>([&](function::FuncDefOp fnOp) {
230 Region &fnBody = fnOp.getFunctionBody();
231 if (fnBody.empty()) {
232 return WalkResult::skip();
233 }
234
235 ProgramPoint *point = solver.getProgramPointAfter(fnBody.back().getTerminator());
236 PredecessorLattice *finalLattice = solver.getOrCreateState<PredecessorLattice>(point);
237
238 printRegionless(os, fnOp.getOperation()) << ":\n" << *finalLattice << '\n';
239
240 return WalkResult::skip();
241 });
242 }
243};
244
245} // namespace
void setToEntryState(PredecessorLattice *lattice) override
void visitCallControlFlowTransfer(CallOpInterface call, mlir::dataflow::CallControlFlowAction action, const PredecessorLattice &before, PredecessorLattice *after) override
DenseForwardDataFlowAnalysis< PredecessorLattice > Base
LogicalResult visitOperation(Operation *op, const PredecessorLattice &before, PredecessorLattice *after) override
PredecessorAnalysis(DataFlowSolver &s, raw_ostream &ro)
void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, std::optional< unsigned >, std::optional< unsigned >, const PredecessorLattice &before, PredecessorLattice *after) override
ChangeResult visit(Operation *op, Operation *pred)
ChangeResult join(const AbstractDenseLattice &rhs) override
ChangeResult meet(const AbstractDenseLattice &) override
void print(raw_ostream &os) const override
void loadRequiredAnalyses(DataFlowSolver &solver)
LogicalResult loadAndRunRequiredAnalyses(DataFlowSolver &solver, Operation *op)
llvm::raw_ostream & toStream(OutputStream val)
void ensure(bool condition, const llvm::Twine &errMsg)
raw_ostream & printRegionless(raw_ostream &os, Operation *op, bool withParent=false)
Prints op without region.