LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKComputeConstrainToProductPass.cpp
Go to the documentation of this file.
1//===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16
24
25#include <mlir/IR/Builders.h>
26#include <mlir/IR/SymbolTable.h>
27#include <mlir/Transforms/InliningUtils.h>
28
29#include <llvm/Support/Debug.h>
31#include <memory>
32#include <ranges>
33
34namespace llzk {
35#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
36#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
38} // namespace llzk
40#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
41
42using namespace llzk::component;
43using namespace llzk::function;
44using namespace mlir;
46using std::make_unique;
48namespace llzk {
49
51 FuncDefOp computeFunc = root.getComputeFuncOp();
52 FuncDefOp constrainFunc = root.getConstrainFuncOp();
54 if (!computeFunc || !constrainFunc) {
55 root->emitError()
56 .append("no ", FUNC_NAME_COMPUTE, "/", FUNC_NAME_CONSTRAIN, " to align")
57 .report();
58 return false;
59 }
60
61 /// TODO: If root::@compute and root::@constrain are called anywhere else, this is not a valid
62 /// root to start aligning from (issue #241)
63
64 return true;
65}
67LogicalResult alignStartingAt(
68 component::StructDefOp root, SymbolTableCollection &tables,
70) {
71 if (!isValidRoot(root)) {
72 return failure();
73 }
74
75 ProductAligner aligner {tables, equivalence};
76 if (!aligner.alignFuncs(root, root.getComputeFuncOp(), root.getConstrainFuncOp())) {
77 return failure();
78 }
79
80 return success();
81}
82
84 : public llzk::impl::ComputeConstrainToProductPassBase<ComputeConstrainToProductPass> {
85
86public:
87 void runOnOperation() override {
88 ModuleOp mod = getOperation();
89 StructDefOp root;
90
91 SymbolTableCollection tables;
93 getAnalysis<LightweightSignalEquivalenceAnalysis>()
94 };
95
96 // Find the indicated root struct and make sure its a valid place to start aligning
97 mod.walk([&root, this](StructDefOp structDef) {
98 if (structDef.getSymName() == rootStruct) {
99 root = structDef;
100 }
101 });
102
103 if (failed(alignStartingAt(root, tables, equivalence))) {
104 signalPassFailure();
105 }
106 }
107};
108
110
111 if (auto prod = root.getProductFuncOp()) {
112 return prod;
113 }
114
115 OpBuilder funcBuilder(compute);
116
117 // Add compute/constrain attributes
118 compute.walk([&funcBuilder](Operation *op) {
119 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
120 });
121
122 constrain.walk([&funcBuilder](Operation *op) {
123 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
124 });
125
126 // Create an empty @product func...
127 FuncDefOp productFunc = funcBuilder.create<FuncDefOp>(
128 funcBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), FUNC_NAME_PRODUCT,
129 compute.getFunctionType()
130 );
131 productFunc->setAttr(DERIVED_ATTR_NAME, UnitAttr::get(funcBuilder.getContext()));
132 Block *entryBlock = productFunc.addEntryBlock();
133 funcBuilder.setInsertionPointToStart(entryBlock);
134
136 compute.hasAllowNonNativeFieldOpsAttr() || constrain.hasAllowNonNativeFieldOpsAttr()
137 );
138
139 // ...with the right arguments
140 llvm::SmallVector<Value> args {productFunc.getArguments()};
141
142 // Add calls to @compute and @constrain...
143 CallOp computeCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), compute, args);
144 args.insert(args.begin(), computeCall->getResult(0));
145 CallOp constrainCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), constrain, args);
146 funcBuilder.create<ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
147
148 // ..and inline them
149 InlinerInterface inliner(productFunc.getContext());
150 if (failed(inlineCall(inliner, computeCall, compute, &compute.getBody(), true))) {
151 root->emitError().append("failed to inline ", FUNC_NAME_COMPUTE).report();
152 return nullptr;
153 }
154 if (failed(inlineCall(inliner, constrainCall, constrain, &constrain.getBody(), true))) {
155 root->emitError().append("failed to inline ", FUNC_NAME_CONSTRAIN).report();
156 return nullptr;
157 }
158 computeCall->erase();
159 constrainCall->erase();
160
161 // Mark the compute/constrain for deletion
162 alignedStructs.push_back(root);
163
164 // Make sure we can align sub-calls to @compute and @constrain
165 if (failed(alignCalls(productFunc))) {
166 return nullptr;
167 }
168 return productFunc;
169}
170
171LogicalResult ProductAligner::alignCalls(FuncDefOp product) {
172 // Gather up all the remaining calls to @compute and @constrain
173 llvm::SetVector<CallOp> computeCalls, constrainCalls;
174 product.walk([&](CallOp callOp) {
175 if (callOp.calleeIsStructCompute()) {
176 computeCalls.insert(callOp);
177 } else if (callOp.calleeIsStructConstrain()) {
178 constrainCalls.insert(callOp);
179 }
180 });
181
182 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
183
184 // A @compute matches a @constrain if they belong to the same struct and all their input signals
185 // are pairwise equivalent
186 auto doCallsMatch = [&](CallOp compute, CallOp constrain) -> bool {
187 LLVM_DEBUG({
188 llvm::outs() << "Asking for equivalence between calls\n"
189 << compute << "\nand\n"
190 << constrain << "\n\n";
191 llvm::outs() << "In block:\n\n" << *compute->getBlock() << "\n";
192 });
193
194 auto computeStruct = getPrefixAsSymbolRefAttr(compute.getCallee());
195 auto constrainStruct = getPrefixAsSymbolRefAttr(constrain.getCallee());
196 if (computeStruct != constrainStruct) {
197 return false;
198 }
199 if (compute.getNumOperands() == 0) {
200 return true;
201 }
202 for (unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
203 if (!equivalence.areSignalsEquivalent(compute->getOperand(i), constrain->getOperand(i + 1))) {
204 return false;
205 }
206 }
207
208 return true;
209 };
210
211 for (auto compute : computeCalls) {
212 // If there is exactly one @compute that matches a given @constrain, we can align them
213 auto matches = llvm::filter_to_vector(constrainCalls, [&](CallOp constrain) {
214 return doCallsMatch(compute, constrain);
215 });
216
217 if (matches.size() == 1) {
218 alignedCalls.insert({compute, matches[0]});
219 computeCalls.remove(compute);
220 constrainCalls.remove(matches[0]);
221 }
222 }
223
224 if (!computeCalls.empty() && constrainCalls.empty()) {
225 product.emitWarning()
226 .append("failed to align some @", FUNC_NAME_COMPUTE, " and @", FUNC_NAME_CONSTRAIN)
227 .report();
228 }
229
230 for (auto [compute, constrain] : alignedCalls) {
231 // If @A::@compute matches @A::@constrain, recursively align the functions in @A...
232 auto calleeTgt = compute.getCalleeTarget(tables);
233 if (failed(calleeTgt)) {
234 return failure();
235 }
236 auto newRoot = calleeTgt->get()->getParentOfType<StructDefOp>();
237 assert(newRoot);
238 FuncDefOp newProduct =
239 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
240 if (!newProduct) {
241 return failure();
242 }
243
244 // ...and replace the two calls with a single call to @A::@product
245 OpBuilder callBuilder(compute);
246 CallOp newCall = callBuilder.create<CallOp>(
247 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
248 compute.getOperands()
249 );
250 compute->replaceAllUsesWith(newCall.getResults());
251 }
252
253 return success();
254}
255
256std::unique_ptr<Pass> createComputeConstrainToProductPass() {
257 return make_unique<ComputeConstrainToProductPass>();
258}
259
260} // namespace llzk
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
Definition Ops.cpp:474
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:470
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:466
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1038
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1032
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1054
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
Definition Ops.cpp:226
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
Definition Ops.h.inc:820
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
Definition Ops.h.inc:690
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
Definition Constants.h:40
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
LogicalResult alignStartingAt(component::StructDefOp root, SymbolTableCollection &tables, LightweightSignalEquivalenceAnalysis &equivalence)
std::unique_ptr< Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
Definition Constants.h:28
bool isValidRoot(StructDefOp root)