23#include <mlir/IR/BuiltinOps.h>
25#include <llvm/ADT/DenseMap.h>
26#include <llvm/ADT/DenseMapInfo.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/Support/Debug.h>
35#define GEN_PASS_DEF_POLYLOWERINGPASS
46#define DEBUG_TYPE "llzk-poly-lowering-pass"
47#define AUXILIARY_MEMBER_PREFIX "__llzk_poly_lowering_pass_aux_member_"
52 std::string auxMemberName;
57 using Base = PolyLoweringPassBase<PassImpl>;
60 unsigned auxCounter = 0;
62 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
64 structDefs.push_back(structDef);
65 return WalkResult::skip();
70 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
71 if (
auto it = memo.find(val); it != memo.end()) {
75 if (llvm::isa<BlockArgument>(val)) {
88 if (
auto addOp = val.getDefiningOp<
AddFeltOp>()) {
89 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
91 if (
auto subOp = val.getDefiningOp<
SubFeltOp>()) {
92 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
94 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
95 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
97 if (
auto divOp = val.getDefiningOp<
DivFeltOp>()) {
98 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
100 if (
auto negOp = val.getDefiningOp<
NegFeltOp>()) {
101 return memo[val] = getDegree(negOp.getOperand(), memo);
104 llvm_unreachable(
"Unhandled Felt SSA value in degree computation");
107 Value lowerExpression(
109 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
110 SmallVector<AuxAssignment> &auxAssignments
112 if (rewrites.count(val)) {
113 return rewrites[val];
116 unsigned degree = getDegree(val, degreeMemo);
117 if (degree <= maxDegree) {
123 auto lowerBinaryRoot = [&](
auto op) -> Value {
124 Value lhs = lowerExpression(
125 op.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
127 Value rhs = lowerExpression(
128 op.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
131 if (lhs != op.getLhs()) {
132 op.getLhsMutable().set(lhs);
134 if (rhs != op.getRhs()) {
135 op.getRhsMutable().set(rhs);
137 degreeMemo[val] = std::max(getDegree(lhs, degreeMemo), getDegree(rhs, degreeMemo));
142 if (
auto addOp = val.getDefiningOp<
AddFeltOp>()) {
143 return lowerBinaryRoot(addOp);
146 if (
auto subOp = val.getDefiningOp<
SubFeltOp>()) {
147 return lowerBinaryRoot(subOp);
150 if (
auto negOp = val.getDefiningOp<
NegFeltOp>()) {
151 Value operand = lowerExpression(
152 negOp.getOperand(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
155 if (operand != negOp.getOperand()) {
156 negOp.getOperandMutable().set(operand);
158 degreeMemo[val] = getDegree(operand, degreeMemo);
163 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
165 Value lhs = lowerExpression(
166 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
168 Value rhs = lowerExpression(
169 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
172 unsigned lhsDeg = getDegree(lhs, degreeMemo);
173 unsigned rhsDeg = getDegree(rhs, degreeMemo);
175 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
177 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
179 if (lhs == rhs && eraseMul) {
184 lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr()
186 auxAssignments.push_back({auxName, lhs});
187 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
191 degreeMemo[auxVal] = 1;
192 rewrites[lhs] = auxVal;
193 rewrites[rhs] = auxVal;
204 while (lhsDeg + rhsDeg > maxDegree) {
205 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
213 toFactor.getLoc(), toFactor.getType(), selfVal, auxMember.getNameAttr()
217 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
218 auto eqOp = builder.create<
EmitEqualityOp>(loc, auxVal, toFactor);
219 auxAssignments.push_back({auxName, toFactor});
221 rewrites[toFactor] = auxVal;
222 degreeMemo[auxVal] = 1;
230 lhsDeg = getDegree(lhs, degreeMemo);
231 rhsDeg = getDegree(rhs, degreeMemo);
235 auto mulVal = builder.create<
MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
237 mulOp->replaceAllUsesWith(mulVal);
242 degreeMemo[mulVal] = lhsDeg + rhsDeg;
243 rewrites[val] = mulVal;
253 Value materializeCallArgument(
255 DenseMap<Value, unsigned> &degreeMemo, DenseMap<Value, Value> &rewrites,
256 SmallVector<AuxAssignment> &auxAssignments
259 lowerExpression(val, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments);
260 DenseMap<Value, unsigned> checkMemo;
261 if (getDegree(loweredVal, checkMemo) <= 1) {
270 OpBuilder builder(callOp);
273 loweredVal.getLoc(), loweredVal.getType(), selfVal, auxMember.getNameAttr()
276 Location loc = builder.getFusedLoc({auxVal.getLoc(), loweredVal.getLoc()});
278 auxAssignments.push_back({auxName, loweredVal});
280 degreeMemo[auxVal] = 1;
281 rewrites[loweredVal] = auxVal;
282 rewrites[val] = auxVal;
286 LogicalResult checkEqualityDegrees(
FuncDefOp constrainFunc) {
287 bool failedCheck =
false;
290 DenseMap<Value, unsigned> checkMemo;
291 unsigned lhsDegree = getDegree(eqOp.
getLhs(), checkMemo);
292 unsigned rhsDegree = getDegree(eqOp.
getRhs(), checkMemo);
294 if (lhsDegree > maxDegree || rhsDegree > maxDegree) {
295 auto diag = eqOp.emitOpError();
296 diag <<
"poly lowering postcondition failed: equality operand degree exceeds max-degree "
297 << maxDegree.getValue() <<
" (lhs degree " << lhsDegree <<
", rhs degree " << rhsDegree
304 return failure(failedCheck);
307 LogicalResult checkStructConstrainCallArguments(
FuncDefOp constrainFunc) {
308 bool failedCheck =
false;
310 constrainFunc.walk([&](
CallOp callOp) {
316 if (!llvm::isa<FeltType>(arg.getType())) {
320 DenseMap<Value, unsigned> checkMemo;
321 unsigned argDegree = getDegree(arg, checkMemo);
323 auto diag = callOp.emitOpError();
324 diag <<
"poly lowering postcondition failed: struct constrain call argument degree "
325 "exceeds 1 (argument degree "
333 return failure(failedCheck);
336 void runOnOperation()
override {
337 ModuleOp moduleOp = getOperation();
341 auto diag = moduleOp.emitError();
342 diag <<
"Invalid max degree: " << maxDegree.getValue() <<
". Must be >= 2.";
348 moduleOp.walk([
this, &moduleOp](
StructDefOp structDef) {
351 if (!constrainFunc) {
352 auto diag = structDef.emitOpError();
361 auto diag = structDef.emitOpError();
362 diag <<
'"' << structDef.getName() <<
"\" doesn't have a \"@" <<
FUNC_NAME_COMPUTE
374 DenseMap<Value, unsigned> degreeMemo;
375 DenseMap<Value, Value> rewrites;
376 SmallVector<AuxAssignment> auxAssignments;
382 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
383 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
385 if (degreeLhs > maxDegree) {
386 Value loweredExpr = lowerExpression(
387 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
389 lhsOperand.set(loweredExpr);
391 if (degreeRhs > maxDegree) {
392 Value loweredExpr = lowerExpression(
393 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
395 rhsOperand.set(loweredExpr);
402 auto diag = moduleOp.emitError();
403 diag <<
"EmitContainmentOp is unsupported for now in the lowering pass";
410 constrainFunc.walk([&](
CallOp callOp) {
412 SmallVector<Value> newOperands = llvm::to_vector(callOp.
getArgOperands());
413 bool modified =
false;
415 for (Value &arg : newOperands) {
416 if (!llvm::isa<FeltType>(arg.getType())) {
420 DenseMap<Value, unsigned> callMemo;
421 unsigned deg = getDegree(arg, callMemo);
424 arg = materializeCallArgument(
425 arg, structDef, constrainFunc, callOp, degreeMemo, rewrites, auxAssignments
432 OpBuilder builder(callOp);
434 callOp.getLoc(), callOp.getResultTypes(), callOp.
getCallee(),
443 if (failed(checkEqualityDegrees(constrainFunc))) {
448 if (failed(checkStructConstrainCallArguments(constrainFunc))) {
453 DenseMap<Value, Value> rebuildMemo;
454 Block &computeBlock = computeFunc.
getBody().front();
455 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
458 for (
const auto &assign : auxAssignments) {
462 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxMemberName),
#define AUXILIARY_MEMBER_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
::mlir::OpOperand & getRhsMutable()
::mlir::TypedValue<::mlir::Type > getLhs()
::mlir::OpOperand & getLhsMutable()
::mlir::TypedValue<::mlir::Type > getRhs()
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::mlir::Operation::operand_range getArgOperands()
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::Region & getBody()
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char FUNC_NAME_CONSTRAIN[]
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)