LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SparseAnalysis.cpp
Go to the documentation of this file.
1//===- SparseAnalysis.cpp - Sparse data-flow analysis ---------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp.
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//
14//===----------------------------------------------------------------------===//
15
17
21
22#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
23#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
24#include <mlir/Analysis/DataFlowFramework.h>
25#include <mlir/Dialect/SCF/IR/SCF.h>
26#include <mlir/IR/Attributes.h>
27#include <mlir/IR/Operation.h>
28#include <mlir/IR/Region.h>
29#include <mlir/IR/SymbolTable.h>
30#include <mlir/IR/Value.h>
31#include <mlir/IR/ValueRange.h>
32#include <mlir/Interfaces/CallInterfaces.h>
33#include <mlir/Interfaces/ControlFlowInterfaces.h>
34#include <mlir/Support/LLVM.h>
35
36#include <llvm/ADT/STLExtras.h>
37#include <llvm/Support/Casting.h>
38
39#include <cassert>
40#include <optional>
41
42using namespace mlir;
43using namespace mlir::dataflow;
44using namespace llzk::function;
45
46namespace llzk::dataflow {
47
48//===----------------------------------------------------------------------===//
49// AbstractSparseForwardDataFlowAnalysis
50//===----------------------------------------------------------------------===//
51
53 : DataFlowAnalysis(s) {
54 registerAnchorKind<CFGEdge>();
55}
56
58 // Mark the entry block arguments as having reached their pessimistic
59 // fixpoints.
60 for (Region &region : top->getRegions()) {
61 if (region.empty()) {
62 continue;
63 }
64 for (Value argument : region.front().getArguments()) {
66 }
67 }
68
69 return initializeRecursively(top);
70}
71
72LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
73 // Initialize the analysis by visiting every owner of an SSA value (all
74 // operations and blocks).
75 if (failed(visitOperation(op))) {
76 return failure();
77 }
78
79 for (Region &region : op->getRegions()) {
80 for (Block &block : region) {
81 getOrCreate<Executable>(getProgramPointBefore(&block))->blockContentSubscribe(this);
82 visitBlock(&block);
83 // LLZK: Renamed "op" -> "containedOp" to avoid shadowing.
84 for (Operation &containedOp : block) {
85 if (failed(initializeRecursively(&containedOp))) {
86 return failure();
87 }
88 }
89 }
90 }
91
92 return success();
93}
94
95LogicalResult AbstractSparseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
96 if (!point->isBlockStart()) {
97 return visitOperation(point->getPrevOp());
98 }
99 visitBlock(point->getBlock());
100 return success();
101}
102
103LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
107
108 // If the containing block is not executable, bail out.
109 if (op->getBlock() != nullptr &&
110 !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) {
111 return success();
112 }
113
114 // Get the result lattices.
115 SmallVector<AbstractSparseLattice *> resultLattices;
116 resultLattices.reserve(op->getNumResults());
117 for (Value result : op->getResults()) {
118 AbstractSparseLattice *resultLattice = getLatticeElement(result);
119 resultLattices.push_back(resultLattice);
120 }
121
122 // The results of a region branch operation are determined by control-flow.
123 if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
124 visitRegionSuccessors(
125 getProgramPointAfter(branch), branch,
126 /*successor=*/RegionBranchPoint::parent(), resultLattices
127 );
128 return success();
129 }
130
131 // Grab the lattice elements of the operands.
132 SmallVector<const AbstractSparseLattice *> operandLattices;
133 operandLattices.reserve(op->getNumOperands());
134 for (Value operand : op->getOperands()) {
135 AbstractSparseLattice *operandLattice = getLatticeElement(operand);
136 operandLattice->useDefSubscribe(this);
137 operandLattices.push_back(operandLattice);
138 }
139
140 if (auto call = dyn_cast<CallOpInterface>(op)) {
141 // If the call operation is to an external function, attempt to infer the
142 // results from the call arguments.
143 auto callable = dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
144 if (!getSolverConfig().isInterprocedural() || (callable && !callable.getCallableRegion())) {
145 visitExternalCallImpl(call, operandLattices, resultLattices);
146 return success();
147 }
148
149 // Otherwise, the results of a call operation are determined by the
150 // callgraph.
151 const auto *predecessors =
152 getOrCreateFor<PredecessorState>(getProgramPointAfter(op), getProgramPointAfter(call));
153 // If not all return sites are known, then conservatively assume we can't
154 // reason about the data-flow.
155 if (!predecessors->allPredecessorsKnown()) {
156 setAllToEntryStates(resultLattices);
157 return success();
158 }
159 for (Operation *predecessor : predecessors->getKnownPredecessors()) {
160 for (auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) {
161 join(resLattice, *getLatticeElementFor(getProgramPointAfter(op), operand));
162 }
163 }
164 return success();
165 }
166
167 // Invoke the operation transfer function.
168 return visitOperationImpl(op, operandLattices, resultLattices);
169}
170
171void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
172 // Exit early on blocks with no arguments.
173 if (block->getNumArguments() == 0) {
174 return;
175 }
176
177 // If the block is not executable, bail out.
178 if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) {
179 return;
180 }
181
182 // Get the argument lattices.
183 SmallVector<AbstractSparseLattice *> argLattices;
184 argLattices.reserve(block->getNumArguments());
185 for (BlockArgument argument : block->getArguments()) {
186 AbstractSparseLattice *argLattice = getLatticeElement(argument);
187 argLattices.push_back(argLattice);
188 }
189
190 // The argument lattices of entry blocks are set by region control-flow or the
191 // callgraph.
192 if (block->isEntryBlock()) {
193 // Check if this block is the entry block of a callable region.
194 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
195 if (callable && callable.getCallableRegion() == block->getParent()) {
196 const auto *callsites = getOrCreateFor<PredecessorState>(
197 getProgramPointBefore(block), getProgramPointAfter(callable)
198 );
199 // If not all callsites are known, conservatively mark all lattices as
200 // having reached their pessimistic fixpoints.
201 if (!callsites->allPredecessorsKnown() || !getSolverConfig().isInterprocedural()) {
202 return setAllToEntryStates(argLattices);
203 }
204 for (Operation *callsite : callsites->getKnownPredecessors()) {
205 auto call = cast<CallOpInterface>(callsite);
206 for (auto it : llvm::zip(call.getArgOperands(), argLattices)) {
207 join(
208 std::get<1>(it), *getLatticeElementFor(getProgramPointBefore(block), std::get<0>(it))
209 );
210 }
211 }
212 return;
213 }
214
215 // Check if the lattices can be determined from region control flow.
216 if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
217 return visitRegionSuccessors(
218 getProgramPointBefore(block), branch, block->getParent(), argLattices
219 );
220 }
221
222 // Otherwise, we can't reason about the data-flow.
224 block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, /*firstIndex=*/0
225 );
226 }
227
228 // Iterate over the predecessors of the non-entry block.
229 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
230 Block *predecessor = *it;
231
232 // If the edge from the predecessor block to the current block is not live,
233 // bail out.
234 auto *edgeExecutable = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
235 edgeExecutable->blockContentSubscribe(this);
236 if (!edgeExecutable->isLive()) {
237 continue;
238 }
239
240 // Check if we can reason about the data-flow from the predecessor.
241 if (auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
242 SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex());
243 for (auto [idx, lattice] : llvm::enumerate(argLattices)) {
244 if (Value operand = operands[idx]) {
245 join(lattice, *getLatticeElementFor(getProgramPointBefore(block), operand));
246 } else {
247 // Conservatively consider internally produced arguments as entry
248 // points.
249 setAllToEntryStates(lattice);
250 }
251 }
252 } else {
253 return setAllToEntryStates(argLattices);
254 }
255 }
256}
257
258void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
259 ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor,
260 ArrayRef<AbstractSparseLattice *> lattices
261) {
262 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
263 assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors");
264
265 for (Operation *op : predecessors->getKnownPredecessors()) {
266 // Get the incoming successor operands.
267 std::optional<OperandRange> operands;
268
269 // Check if the predecessor is the parent op.
270 if (op == branch) {
271 operands = branch.getEntrySuccessorOperands(successor);
272 // Otherwise, try to deduce the operands from a region return-like op.
273 } else if (auto regionTerminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
274 operands = regionTerminator.getSuccessorOperands(successor);
275 }
276
277 if (!operands) {
278 // We can't reason about the data-flow.
279 return setAllToEntryStates(lattices);
280 }
281
282 ValueRange inputs = predecessors->getSuccessorInputs(op);
283 assert(
284 inputs.size() == operands->size() &&
285 "expected the same number of successor inputs as operands"
286 );
287
288 unsigned firstIndex = 0;
289 if (inputs.size() != lattices.size()) {
290 if (!point->isBlockStart()) {
291 if (!inputs.empty()) {
292 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
293 }
295 branch, RegionSuccessor(branch->getResults().slice(firstIndex, inputs.size())),
296 lattices, firstIndex
297 );
298 } else {
299 if (!inputs.empty()) {
300 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
301 }
302 Region *region = point->getBlock()->getParent();
304 branch,
305 RegionSuccessor(region, region->getArguments().slice(firstIndex, inputs.size())),
306 lattices, firstIndex
307 );
308 }
309 }
310
311 for (auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) {
312 join(std::get<1>(it), *getLatticeElementFor(point, std::get<0>(it)));
313 }
314 }
315}
316
320 addDependency(state, point);
321 return state;
322}
323
325 ArrayRef<AbstractSparseLattice *> lattices
326) {
327 for (AbstractSparseLattice *lattice : lattices) {
328 setToEntryState(lattice);
329 }
330}
331
334) {
335 propagateIfChanged(lhs, lhs->join(rhs));
336}
337
338} // namespace llzk::dataflow
This file implements sparse data-flow analysis using the data-flow analysis framework.
mlir::LogicalResult visit(mlir::ProgramPoint *point) override
Visit a program point.
mlir::LogicalResult initialize(mlir::Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual mlir::LogicalResult visitOperationImpl(mlir::Operation *op, mlir::ArrayRef< const AbstractSparseLattice * > operandLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(mlir::Value value)=0
Get the lattice element of a value.
AbstractSparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver)
void setAllToEntryStates(mlir::ArrayRef< AbstractSparseLattice * > lattices)
const AbstractSparseLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join `rhs` into the lattice element `lhs`, and propagate the update if `lhs` changed.
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(mlir::Operation *op, const mlir::RegionSuccessor &successor, mlir::ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
Given an operation with region control-flow, the lattices of the operands, and a region successor,...
virtual void visitExternalCallImpl(mlir::CallOpInterface call, mlir::ArrayRef< const AbstractSparseLattice * > argumentLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice