LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKPolyLoweringPass.cpp
Go to the documentation of this file.
1//===-- LLZKPolyLoweringPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
22
23#include <mlir/IR/BuiltinOps.h>
24
25#include <llvm/ADT/DenseMap.h>
26#include <llvm/ADT/DenseMapInfo.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/Support/Debug.h>
29
30#include <deque>
31#include <memory>
32
33// Include the generated base pass class definitions.
34namespace llzk {
35#define GEN_PASS_DEF_POLYLOWERINGPASS
37} // namespace llzk
38
39using namespace mlir;
40using namespace llzk;
41using namespace llzk::felt;
42using namespace llzk::function;
43using namespace llzk::component;
44using namespace llzk::constrain;
45
46#define DEBUG_TYPE "llzk-poly-lowering-pass"
47#define AUXILIARY_MEMBER_PREFIX "__llzk_poly_lowering_pass_aux_member_"
48
49namespace {
50
51struct AuxAssignment {
52 std::string auxMemberName;
53 Value computedValue;
54};
55
56class PassImpl : public llzk::impl::PolyLoweringPassBase<PassImpl> {
57 using Base = PolyLoweringPassBase<PassImpl>;
58 using Base::Base;
59
60 unsigned auxCounter = 0;
61
62 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
63 modOp.walk([&structDefs](StructDefOp structDef) {
64 structDefs.push_back(structDef);
65 return WalkResult::skip();
66 });
67 }
68
69 // Recursively compute degree of FeltOps SSA values
70 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
71 if (auto it = memo.find(val); it != memo.end()) {
72 return it->second;
73 }
74 // Handle function parameters (BlockArguments)
75 if (llvm::isa<BlockArgument>(val)) {
76 memo[val] = 1;
77 return 1;
78 }
79 if (val.getDefiningOp<FeltConstantOp>()) {
80 return memo[val] = 0;
81 }
82 if (val.getDefiningOp<NonDetOp>()) {
83 return memo[val] = 1;
84 }
85 if (val.getDefiningOp<MemberReadOp>()) {
86 return memo[val] = 1;
87 }
88 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
89 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
90 }
91 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
92 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
93 }
94 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
95 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
96 }
97 if (auto divOp = val.getDefiningOp<DivFeltOp>()) {
98 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
99 }
100 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
101 return memo[val] = getDegree(negOp.getOperand(), memo);
102 }
103
104 llvm_unreachable("Unhandled Felt SSA value in degree computation");
105 }
106
107 Value lowerExpression(
108 Value val, StructDefOp structDef, FuncDefOp constrainFunc,
109 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
110 SmallVector<AuxAssignment> &auxAssignments
111 ) {
112 if (rewrites.count(val)) {
113 return rewrites[val];
114 }
115
116 unsigned degree = getDegree(val, degreeMemo);
117 if (degree <= maxDegree) {
118 rewrites[val] = val;
119 return val;
120 }
121
122 // Degree-neutral roots can still contain over-degree operands.
123 auto lowerBinaryRoot = [&](auto op) -> Value {
124 Value lhs = lowerExpression(
125 op.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
126 );
127 Value rhs = lowerExpression(
128 op.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
129 );
130
131 if (lhs != op.getLhs()) {
132 op.getLhsMutable().set(lhs);
133 }
134 if (rhs != op.getRhs()) {
135 op.getRhsMutable().set(rhs);
136 }
137 degreeMemo[val] = std::max(getDegree(lhs, degreeMemo), getDegree(rhs, degreeMemo));
138 rewrites[val] = val;
139 return val;
140 };
141
142 if (auto addOp = val.getDefiningOp<AddFeltOp>()) {
143 return lowerBinaryRoot(addOp);
144 }
145
146 if (auto subOp = val.getDefiningOp<SubFeltOp>()) {
147 return lowerBinaryRoot(subOp);
148 }
149
150 if (auto negOp = val.getDefiningOp<NegFeltOp>()) {
151 Value operand = lowerExpression(
152 negOp.getOperand(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
153 );
154
155 if (operand != negOp.getOperand()) {
156 negOp.getOperandMutable().set(operand);
157 }
158 degreeMemo[val] = getDegree(operand, degreeMemo);
159 rewrites[val] = val;
160 return val;
161 }
162
163 if (auto mulOp = val.getDefiningOp<MulFeltOp>()) {
164 // Recursively lower operands first
165 Value lhs = lowerExpression(
166 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
167 );
168 Value rhs = lowerExpression(
169 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
170 );
171
172 unsigned lhsDeg = getDegree(lhs, degreeMemo);
173 unsigned rhsDeg = getDegree(rhs, degreeMemo);
174
175 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
176 Value selfVal = constrainFunc.getSelfValueFromConstrain();
177 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
178 // Optimization: If lhs == rhs, factor it only once
179 if (lhs == rhs && eraseMul) {
180 std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++);
181 MemberDefOp auxMember = addAuxMember(structDef, auxName);
182
183 auto auxVal = builder.create<MemberReadOp>(
184 lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr()
185 );
186 auxAssignments.push_back({auxName, lhs});
187 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
188 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, lhs);
189
190 // Memoize auxVal as degree 1
191 degreeMemo[auxVal] = 1;
192 rewrites[lhs] = auxVal;
193 rewrites[rhs] = auxVal;
194 // Now selectively replace subsequent uses of lhs with auxVal
195 replaceSubsequentUsesWith(lhs, auxVal, eqOp);
196
197 // Update lhs and rhs to use auxVal
198 lhs = auxVal;
199 rhs = auxVal;
200
201 lhsDeg = rhsDeg = 1;
202 }
203 // While their product exceeds maxDegree, factor out one side
204 while (lhsDeg + rhsDeg > maxDegree) {
205 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
206
207 // Create auxiliary member for toFactor
208 std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++);
209 MemberDefOp auxMember = addAuxMember(structDef, auxName);
210
211 // Read back as MemberReadOp (new SSA value)
212 auto auxVal = builder.create<MemberReadOp>(
213 toFactor.getLoc(), toFactor.getType(), selfVal, auxMember.getNameAttr()
214 );
215
216 // Emit constraint: auxVal == toFactor
217 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
218 auto eqOp = builder.create<EmitEqualityOp>(loc, auxVal, toFactor);
219 auxAssignments.push_back({auxName, toFactor});
220 // Update memoization
221 rewrites[toFactor] = auxVal;
222 degreeMemo[auxVal] = 1; // stays same
223 // replace the term with auxVal.
224 replaceSubsequentUsesWith(toFactor, auxVal, eqOp);
225
226 // Remap toFactor to auxVal for next iterations
227 toFactor = auxVal;
228
229 // Recompute degrees
230 lhsDeg = getDegree(lhs, degreeMemo);
231 rhsDeg = getDegree(rhs, degreeMemo);
232 }
233
234 // Now lhs * rhs fits within degree bound
235 auto mulVal = builder.create<MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
236 if (eraseMul) {
237 mulOp->replaceAllUsesWith(mulVal);
238 mulOp->erase();
239 }
240
241 // Result of this multiply has degree lhsDeg + rhsDeg
242 degreeMemo[mulVal] = lhsDeg + rhsDeg;
243 rewrites[val] = mulVal;
244
245 return mulVal;
246 }
247
248 // Unsupported roots are left unchanged.
249 rewrites[val] = val;
250 return val;
251 }
252
253 Value materializeCallArgument(
254 Value val, StructDefOp structDef, FuncDefOp constrainFunc, CallOp callOp,
255 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
256 SmallVector<AuxAssignment> &auxAssignments
257 ) {
258 Value loweredVal =
259 lowerExpression(val, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments);
260 DenseMap<Value, unsigned> checkMemo;
261 if (getDegree(loweredVal, checkMemo) <= 1) {
262 return loweredVal;
263 }
264
265 // Callees only receive SSA values, not the caller expression tree, so nonlinear
266 // call arguments must be represented by an auxiliary member read.
267 std::string auxName = AUXILIARY_MEMBER_PREFIX + std::to_string(this->auxCounter++);
268 MemberDefOp auxMember = addAuxMember(structDef, auxName);
269
270 OpBuilder builder(callOp);
271 Value selfVal = constrainFunc.getSelfValueFromConstrain();
272 auto auxVal = builder.create<MemberReadOp>(
273 loweredVal.getLoc(), loweredVal.getType(), selfVal, auxMember.getNameAttr()
274 );
275
276 Location loc = builder.getFusedLoc({auxVal.getLoc(), loweredVal.getLoc()});
277 builder.create<EmitEqualityOp>(loc, auxVal, loweredVal);
278 auxAssignments.push_back({auxName, loweredVal});
279
280 degreeMemo[auxVal] = 1;
281 rewrites[loweredVal] = auxVal;
282 rewrites[val] = auxVal;
283 return auxVal;
284 }
285
286 LogicalResult checkEqualityDegrees(FuncDefOp constrainFunc) {
287 bool failedCheck = false;
288
289 constrainFunc.walk([&](EmitEqualityOp eqOp) {
290 DenseMap<Value, unsigned> checkMemo;
291 unsigned lhsDegree = getDegree(eqOp.getLhs(), checkMemo);
292 unsigned rhsDegree = getDegree(eqOp.getRhs(), checkMemo);
293
294 if (lhsDegree > maxDegree || rhsDegree > maxDegree) {
295 auto diag = eqOp.emitOpError();
296 diag << "poly lowering postcondition failed: equality operand degree exceeds max-degree "
297 << maxDegree.getValue() << " (lhs degree " << lhsDegree << ", rhs degree " << rhsDegree
298 << ")";
299 diag.report();
300 failedCheck = true;
301 }
302 });
303
304 return failure(failedCheck);
305 }
306
307 LogicalResult checkStructConstrainCallArguments(FuncDefOp constrainFunc) {
308 bool failedCheck = false;
309
310 constrainFunc.walk([&](CallOp callOp) {
311 if (!callOp.calleeIsStructConstrain()) {
312 return;
313 }
314
315 for (Value arg : callOp.getArgOperands()) {
316 if (!llvm::isa<FeltType>(arg.getType())) {
317 continue;
318 }
319
320 DenseMap<Value, unsigned> checkMemo;
321 unsigned argDegree = getDegree(arg, checkMemo);
322 if (argDegree > 1) {
323 auto diag = callOp.emitOpError();
324 diag << "poly lowering postcondition failed: struct constrain call argument degree "
325 "exceeds 1 (argument degree "
326 << argDegree << ")";
327 diag.report();
328 failedCheck = true;
329 }
330 }
331 });
332
333 return failure(failedCheck);
334 }
335
336 void runOnOperation() override {
337 ModuleOp moduleOp = getOperation();
338
339 // Validate degree parameter
340 if (maxDegree < 2) {
341 auto diag = moduleOp.emitError();
342 diag << "Invalid max degree: " << maxDegree.getValue() << ". Must be >= 2.";
343 diag.report();
344 signalPassFailure();
345 return;
346 }
347
348 moduleOp.walk([this, &moduleOp](StructDefOp structDef) {
349 FuncDefOp constrainFunc = structDef.getConstrainFuncOp();
350 FuncDefOp computeFunc = structDef.getComputeFuncOp();
351 if (!constrainFunc) {
352 auto diag = structDef.emitOpError();
353 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_CONSTRAIN
354 << "\" function";
355 diag.report();
356 signalPassFailure();
357 return;
358 }
359
360 if (!computeFunc) {
361 auto diag = structDef.emitOpError();
362 diag << '"' << structDef.getName() << "\" doesn't have a \"@" << FUNC_NAME_COMPUTE
363 << "\" function";
364 diag.report();
365 signalPassFailure();
366 return;
367 }
368
369 if (failed(checkForAuxMemberConflicts(structDef, AUXILIARY_MEMBER_PREFIX))) {
370 signalPassFailure();
371 return;
372 }
373
374 DenseMap<Value, unsigned> degreeMemo;
375 DenseMap<Value, Value> rewrites;
376 SmallVector<AuxAssignment> auxAssignments;
377
378 // Lower equality constraints
379 constrainFunc.walk([&](EmitEqualityOp constraintOp) {
380 auto &lhsOperand = constraintOp.getLhsMutable();
381 auto &rhsOperand = constraintOp.getRhsMutable();
382 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
383 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
384
385 if (degreeLhs > maxDegree) {
386 Value loweredExpr = lowerExpression(
387 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
388 );
389 lhsOperand.set(loweredExpr);
390 }
391 if (degreeRhs > maxDegree) {
392 Value loweredExpr = lowerExpression(
393 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
394 );
395 rhsOperand.set(loweredExpr);
396 }
397 });
398
399 // The pass doesn't currently support EmitContainmentOp.
400 // See https://github.com/project-llzk/llzk-lib/issues/261
401 constrainFunc.walk([this, &moduleOp](EmitContainmentOp /*containOp*/) {
402 auto diag = moduleOp.emitError();
403 diag << "EmitContainmentOp is unsupported for now in the lowering pass";
404 diag.report();
405 signalPassFailure();
406 return;
407 });
408
409 // Lower function call arguments
410 constrainFunc.walk([&](CallOp callOp) {
411 if (callOp.calleeIsStructConstrain()) {
412 SmallVector<Value> newOperands = llvm::to_vector(callOp.getArgOperands());
413 bool modified = false;
414
415 for (Value &arg : newOperands) {
416 if (!llvm::isa<FeltType>(arg.getType())) {
417 continue;
418 }
419
420 DenseMap<Value, unsigned> callMemo;
421 unsigned deg = getDegree(arg, callMemo);
422
423 if (deg > 1) {
424 arg = materializeCallArgument(
425 arg, structDef, constrainFunc, callOp, degreeMemo, rewrites, auxAssignments
426 );
427 modified = true;
428 }
429 }
431 if (modified) {
432 OpBuilder builder(callOp);
433 builder.create<CallOp>(
434 callOp.getLoc(), callOp.getResultTypes(), callOp.getCallee(),
436 newOperands
437 );
438 callOp->erase();
440 }
441 });
443 if (failed(checkEqualityDegrees(constrainFunc))) {
444 signalPassFailure();
445 return;
446 }
448 if (failed(checkStructConstrainCallArguments(constrainFunc))) {
449 signalPassFailure();
450 return;
451 }
452
453 DenseMap<Value, Value> rebuildMemo;
454 Block &computeBlock = computeFunc.getBody().front();
455 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
456 Value selfVal = computeFunc.getSelfValueFromCompute();
457
458 for (const auto &assign : auxAssignments) {
459 Value rebuiltExpr =
460 rebuildExprInCompute(assign.computedValue, computeFunc, builder, rebuildMemo);
461 builder.create<MemberWriteOp>(
462 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxMemberName),
463 rebuiltExpr
464 );
465 }
466 });
467 }
468};
469
470} // namespace
#define AUXILIARY_MEMBER_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:472
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:468
::mlir::OpOperand & getRhsMutable()
Definition Ops.h.inc:285
::mlir::TypedValue<::mlir::Type > getLhs()
Definition Ops.h.inc:272
::mlir::OpOperand & getLhsMutable()
Definition Ops.h.inc:280
::mlir::TypedValue<::mlir::Type > getRhs()
Definition Ops.h.inc:276
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1145
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:480
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:1186
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:457
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:476
::mlir::Region & getBody()
Definition Ops.h.inc:690
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)