LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKPolyLoweringPass.cpp
Go to the documentation of this file.
1//===-- LLZKPolyLoweringPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
22
23#include <mlir/IR/BuiltinOps.h>
24
25#include <llvm/ADT/DenseMap.h>
26#include <llvm/ADT/DenseMapInfo.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/Support/Debug.h>
29
30#include <deque>
31#include <memory>
32
33// Include the generated base pass class definitions.
34namespace llzk {
35#define GEN_PASS_DECL_POLYLOWERINGPASS
36#define GEN_PASS_DEF_POLYLOWERINGPASS
38} // namespace llzk
39
40using namespace mlir;
41using namespace llzk;
42using namespace llzk::felt;
43using namespace llzk::function;
44using namespace llzk::component;
45using namespace llzk::constrain;
46
47#define DEBUG_TYPE "llzk-poly-lowering-pass"
48#define AUXILIARY_MEMBER_PREFIX "__llzk_poly_lowering_pass_aux_member_"
49
50namespace {
51
52struct AuxAssignment {
53 std::string auxMemberName;
54 Value computedValue;
55};
56
57class PolyLoweringPass : public llzk::impl::PolyLoweringPassBase<PolyLoweringPass> {
58public:
59 void setMaxDegree(unsigned degree) { this->maxDegree = degree; }
60
61private:
62 unsigned auxCounter = 0;
63
64 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
65 modOp.walk([&structDefs](StructDefOp structDef) {
66 structDefs.push_back(structDef);
67 return WalkResult::skip();
68 });
69 }
70
71 // Recursively compute degree of FeltOps SSA values
72 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
73 if (auto it = memo.find(val); it != memo.end()) {
74 return it->second;
75 }
76 // Handle function parameters (BlockArguments)
77 if (llvm::isa<BlockArgument>(val)) {
78 memo[val] = 1;
79 return 1;
80 }
81 if (val.getDefiningOp<FeltConstantOp>()) {
82 return memo[val] = 0;
83 }
84 if (val.getDefiningOp<NonDetOp>()) {
85 return memo[val] = 1;
86 }
87 if (val.getDefiningOp<MemberReadOp>()) {
88 return memo[val] = 1;
89 }
90 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
91 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
92 }
93 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
94 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
95 }
96 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
97 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
98 }
99 if (auto divOp = val.getDefiningOp<DivFeltOp>()) {
100 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
101 }
102 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
103 return memo[val] = getDegree(negOp.getOperand(), memo);
104 }
105
106 llvm_unreachable("Unhandled Felt SSA value in degree computation");
107 }
108
109 Value lowerExpression(
110 Value val, StructDefOp structDef, FuncDefOp constrainFunc,
111 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
112 SmallVector<AuxAssignment> &auxAssignments
113 ) {
114 if (rewrites.count(val)) {
115 return rewrites[val];
116 }
117
118 unsigned degree = getDegree(val, degreeMemo);
119 if (degree <= maxDegree) {
120 rewrites[val] = val;
121 return val;
122 }
123
124 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
125 // Recursively lower operands first
126 Value lhs = lowerExpression(
127 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
128 );
129 Value rhs = lowerExpression(
130 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
131 );
132
133 unsigned lhsDeg = getDegree(lhs, degreeMemo);
134 unsigned rhsDeg = getDegree(rhs, degreeMemo);
135
136 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
137 Value selfVal = constrainFunc.getSelfValueFromConstrain();
138 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
139 // Optimization: If lhs == rhs, factor it only once
140 if (lhs == rhs && eraseMul) {
141 std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++);
142 MemberDefOp auxMember = addAuxMember(structDef, auxName);
143
144 auto auxVal = builder.create<MemberReadOp>(
145 lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr()
146 );
147 auxAssignments.push_back({auxName, lhs});
148 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
149 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
150
151 // Memoize auxVal as degree 1
152 degreeMemo[auxVal] = 1;
153 rewrites[lhs] = auxVal;
154 rewrites[rhs] = auxVal;
155 // Now selectively replace subsequent uses of lhs with auxVal
156 replaceSubsequentUsesWith(lhs, auxVal, eqOp);
157
158 // Update lhs and rhs to use auxVal
159 lhs = auxVal;
160 rhs = auxVal;
161
162 lhsDeg = rhsDeg = 1;
163 }
164 // While their product exceeds maxDegree, factor out one side
165 while (lhsDeg + rhsDeg > maxDegree) {
166 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
167
168 // Create auxiliary member for toFactor
169 std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++);
170 MemberDefOp auxMember = addAuxMember(structDef, auxName);
171
172 // Read back as MemberReadOp (new SSA value)
173 auto auxVal = builder.create<MemberReadOp>(
174 toFactor.getLoc(), toFactor.getType(), selfVal, auxMember.getNameAttr()
175 );
176
177 // Emit constraint: auxVal == toFactor
178 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
179 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
180 auxAssignments.push_back({auxName, toFactor});
181 // Update memoization
182 rewrites[toFactor] = auxVal;
183 degreeMemo[auxVal] = 1; // stays same
184 // replace the term with auxVal.
185 replaceSubsequentUsesWith(toFactor, auxVal, eqOp);
186
187 // Remap toFactor to auxVal for next iterations
188 toFactor = auxVal;
189
190 // Recompute degrees
191 lhsDeg = getDegree(lhs, degreeMemo);
192 rhsDeg = getDegree(rhs, degreeMemo);
193 }
194
195 // Now lhs * rhs fits within degree bound
196 auto mulVal = builder.create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
197 if (eraseMul) {
198 mulOp->replaceAllUsesWith(mulVal);
199 mulOp->erase();
200 }
201
202 // Result of this multiply has degree lhsDeg + rhsDeg
203 degreeMemo[mulVal] = lhsDeg + rhsDeg;
204 rewrites[val] = mulVal;
205
206 return mulVal;
207 }
208
209 // For non-mul ops, leave untouched (they're degree-1 safe)
210 rewrites[val] = val;
211 return val;
212 }
213
214 void runOnOperation() override {
215 ModuleOp moduleOp = getOperation();
216
217 // Validate degree parameter
218 if (maxDegree < 2) {
219 auto diag = moduleOp.emitError();
220 diag << "Invalid max degree: " << maxDegree.getValue() << ". Must be >= 2.";
221 diag.report();
222 signalPassFailure();
223 return;
224 }
225
226 moduleOp.walk([this, &moduleOp](StructDefOp structDef) {
227 FuncDefOp constrainFunc = structDef.getConstrainFuncOp();
228 FuncDefOp computeFunc = structDef.getComputeFuncOp();
229 if (!constrainFunc) {
230 auto diag = structDef.emitOpError();
231 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_CONSTRAIN
232 << "\" function";
233 diag.report();
234 signalPassFailure();
235 return;
236 }
237
238 if (!computeFunc) {
239 auto diag = structDef.emitOpError();
240 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_COMPUTE
241 << "\" function";
242 diag.report();
243 signalPassFailure();
244 return;
245 }
246
247 if (failed(checkForAuxMemberConflicts(structDef, AUXILIARY_MEMBER_PREFIX))) {
248 signalPassFailure();
249 return;
250 }
251
252 DenseMap<Value, unsigned> degreeMemo;
253 DenseMap<Value, Value> rewrites;
254 SmallVector<AuxAssignment> auxAssignments;
255
256 // Lower equality constraints
257 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
258 auto &lhsOperand = constraintOp.getLhsMutable();
259 auto &rhsOperand = constraintOp.getRhsMutable();
260 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
261 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
262
263 if (degreeLhs > maxDegree) {
264 Value loweredExpr = lowerExpression(
265 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
266 );
267 lhsOperand.set(loweredExpr);
268 }
269 if (degreeRhs > maxDegree) {
270 Value loweredExpr = lowerExpression(
271 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
272 );
273 rhsOperand.set(loweredExpr);
274 }
275 });
276
277 // The pass doesn't currently support EmitContainmentOp.
278 // See https://github.com/project-llzk/llzk-lib/issues/261
279 constrainFunc.walk([this, &moduleOp](EmitContainmentOp /*containOp*/) {
280 auto diag = moduleOp.emitError();
281 diag << "EmitContainmentOp is unsupported for now in the lowering pass";
282 diag.report();
283 signalPassFailure();
284 return;
285 });
286
287 // Lower function call arguments
288 constrainFunc.walk([&](CallOp callOp) {
289 if (callOp.calleeIsStructConstrain()) {
290 SmallVector<Value> newOperands = llvm::to_vector(callOp.getArgOperands());
291 bool modified = false;
292
293 for (Value &arg : newOperands) {
294 unsigned deg = getDegree(arg, degreeMemo);
295
296 if (deg > 1) {
297 Value loweredArg = lowerExpression(
298 arg, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
299 );
300 arg = loweredArg;
301 modified = true;
302 }
303 }
304
305 if (modified) {
306 OpBuilder builder(callOp);
307 builder.create<CallOp>(
308 callOp.getLoc(), callOp.getResultTypes(), callOp.getCallee(),
310 newOperands
311 );
312 callOp->erase();
313 }
314 }
315 });
316
317 DenseMap<Value, Value> rebuildMemo;
318 Block &computeBlock = computeFunc.getBody().front();
319 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
320 Value selfVal = computeFunc.getSelfValueFromCompute();
321
322 for (const auto &assign : auxAssignments) {
323 Value rebuiltExpr =
324 rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo);
325 builder.create<MemberWriteOp>(
326 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxMemberName),
327 rebuiltExpr
328 );
329 }
330 });
331 }
332};
333} // namespace
334
335std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass() {
336 return std::make_unique<PolyLoweringPass>();
337};
338
339std::unique_ptr<mlir::Pass> llzk::createPolyLoweringPass(unsigned maxDegree) {
340 auto pass = std::make_unique<PolyLoweringPass>();
341 static_cast<PolyLoweringPass *>(pass.get())->setMaxDegree(maxDegree);
342 return pass;
343}
#define AUXILIARY_MEMBER_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:430
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:426
::mlir::OpOperand & getRhsMutable()
Definition Ops.h.inc:285
::mlir::OpOperand & getLhsMutable()
Definition Ops.h.inc:280
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:772
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:472
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:813
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:354
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:373
::mlir::Region & getBody()
Definition Ops.h.inc:607
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
std::unique_ptr< mlir::Pass > createPolyLoweringPass()
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)