23#include <mlir/IR/BuiltinOps.h>
25#include <llvm/ADT/DenseMap.h>
26#include <llvm/ADT/DenseMapInfo.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/Support/Debug.h>
// Ask the tablegen-generated pass include to emit both the declaration and the
// definition of the PolyLoweringPass base (NOTE(review): the corresponding
// #include of the generated .inc file is not visible in this listing).
35#define GEN_PASS_DECL_POLYLOWERINGPASS
36#define GEN_PASS_DEF_POLYLOWERINGPASS
// Debug category for LLVM_DEBUG output (llvm/Support/Debug.h); select it with
// -debug-only=llzk-poly-lowering-pass in builds with assertions enabled.
47#define DEBUG_TYPE "llzk-poly-lowering-pass"
// Name prefix for auxiliary struct members created by this pass.
// NOTE(review): presumably combined with auxCounter to form unique member names
// and checked against existing members by checkForAuxMemberConflicts — confirm.
48#define AUXILIARY_MEMBER_PREFIX "__llzk_poly_lowering_pass_aux_member_"
53 std::string auxMemberName;
59 void setMaxDegree(
unsigned degree) { this->maxDegree = degree; }
62 unsigned auxCounter = 0;
64 void collectStructDefs(ModuleOp modOp, SmallVectorImpl<StructDefOp> &structDefs) {
66 structDefs.push_back(structDef);
67 return WalkResult::skip();
// Returns the polynomial degree of the felt SSA value `val`, caching every
// computed degree in `memo` so shared subexpressions are evaluated once.
//   - AddFeltOp / SubFeltOp: max of the operand degrees
//   - MulFeltOp / DivFeltOp: sum of the operand degrees
//   - NegFeltOp: degree of its operand
//   - values not handled by any branch: llvm_unreachable
// NOTE(review): this listing omits source lines 74-76 and 78-89 (the cache-hit
// return and the BlockArgument / leaf cases), so their results are not shown.
72 unsigned getDegree(Value val, DenseMap<Value, unsigned> &memo) {
// Cache lookup; the early return on a hit is on source lines not shown.
73 if (
auto it = memo.find(val); it != memo.end()) {
77 if (llvm::isa<BlockArgument>(val)) {
// The `return memo[val] = ...` statements below are safe under C++17: the right
// operand of the built-in `=` is sequenced before the left, so the recursive
// calls finish (possibly growing `memo`) before `memo[val]` takes a reference.
90 if (
auto addOp = val.getDefiningOp<
AddFeltOp>()) {
91 return memo[val] = std::max(getDegree(addOp.getLhs(), memo), getDegree(addOp.getRhs(), memo));
93 if (
auto subOp = val.getDefiningOp<
SubFeltOp>()) {
94 return memo[val] = std::max(getDegree(subOp.getLhs(), memo), getDegree(subOp.getRhs(), memo));
96 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
97 return memo[val] = getDegree(mulOp.getLhs(), memo) + getDegree(mulOp.getRhs(), memo);
// NOTE(review): division is costed like multiplication (sum of degrees),
// presumably because c = a / b is constrained as a = b * c — confirm.
99 if (
auto divOp = val.getDefiningOp<
DivFeltOp>()) {
100 return memo[val] = getDegree(divOp.getLhs(), memo) + getDegree(divOp.getRhs(), memo);
102 if (
auto negOp = val.getDefiningOp<
NegFeltOp>()) {
103 return memo[val] = getDegree(negOp.getOperand(), memo);
106 llvm_unreachable(
"Unhandled Felt SSA value in degree computation");
// Lowers the felt expression `val` so that multiplications in it fit within
// `maxDegree`. Oversized operands are factored into auxiliary struct members,
// each constrained equal to the sub-expression it replaces. Lowered values are
// cached in `rewrites`, degrees in `degreeMemo`, and every auxiliary member
// created is recorded in `auxAssignments`.
// NOTE(review): this listing omits many source lines (e.g. 110, 113, 120-123,
// 155-164, 184-190), so several branches are only partly visible.
109 Value lowerExpression(
// NOTE(review): `°reeMemo` on the next line is an HTML-entity rendering artifact
// (`&deg` shown as `°`). The parameter is `DenseMap<Value, unsigned> &degreeMemo`,
// as every use of `degreeMemo` below requires; restore it before compiling.
111 DenseMap<Value, unsigned> °reeMemo, DenseMap<Value, Value> &rewrites,
112 SmallVector<AuxAssignment> &auxAssignments
// Already lowered: return the cached rewrite.
// NOTE(review): count() followed by operator[] does two lookups; find() would do one.
114 if (rewrites.count(val)) {
115 return rewrites[val];
118 unsigned degree = getDegree(val, degreeMemo);
// Within budget; the handling of this case is on source lines not shown.
119 if (degree <= maxDegree) {
// Multiplication: lower both operands recursively, then re-check the product.
124 if (
auto mulOp = val.getDefiningOp<
MulFeltOp>()) {
126 Value lhs = lowerExpression(
127 mulOp.getLhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
129 Value rhs = lowerExpression(
130 mulOp.getRhs(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
133 unsigned lhsDeg = getDegree(lhs, degreeMemo);
134 unsigned rhsDeg = getDegree(rhs, degreeMemo);
// New ops are inserted immediately after the original multiplication.
136 OpBuilder builder(mulOp.getOperation()->getBlock(), ++Block::iterator(mulOp));
// True when the product is still over budget even with lowered operands.
138 bool eraseMul = lhsDeg + rhsDeg > maxDegree;
// Squaring case (x * x): a single auxiliary member stands in for both operands.
140 if (lhs == rhs && eraseMul) {
145 lhs.getLoc(), lhs.getType(), selfVal, auxMember.getNameAttr()
147 auxAssignments.push_back({auxName, lhs});
148 Location loc = builder.getFusedLoc({auxVal.getLoc(), lhs.getLoc()});
// An auxiliary member value is a fresh variable of degree 1.
152 degreeMemo[auxVal] = 1;
153 rewrites[lhs] = auxVal;
// NOTE(review): lhs == rhs in this branch, so this second store is redundant.
154 rewrites[rhs] = auxVal;
// General case: repeatedly replace the higher-degree operand (lhs on ties) with
// a degree-1 auxiliary value until the product fits within maxDegree.
165 while (lhsDeg + rhsDeg > maxDegree) {
166 Value &toFactor = (lhsDeg >= rhsDeg) ? lhs : rhs;
174 toFactor.getLoc(), toFactor.getType(), selfVal, auxMember.getNameAttr()
178 Location loc = builder.getFusedLoc({auxVal.getLoc(), toFactor.getLoc()});
// Constrain the auxiliary value to equal the sub-expression it replaces.
179 auto eqOp = builder.create<
EmitEqualityOp>(loc, auxVal, toFactor);
// Remember the replaced sub-expression; runOnOperation later iterates
// auxAssignments while building in the compute function's block.
180 auxAssignments.push_back({auxName, toFactor});
182 rewrites[toFactor] = auxVal;
183 degreeMemo[auxVal] = 1;
// Refresh operand degrees. NOTE(review): this relies on lhs/rhs being updated
// through the `toFactor` reference on source lines not shown.
191 lhsDeg = getDegree(lhs, degreeMemo);
192 rhsDeg = getDegree(rhs, degreeMemo);
// Rebuild the multiplication over the lowered operands and redirect all users
// of the original op to it.
196 auto mulVal = builder.create<
MulFeltOp>(lhs.getLoc(), lhs.getType(), lhs, rhs);
198 mulOp->replaceAllUsesWith(mulVal);
203 degreeMemo[mulVal] = lhsDeg + rhsDeg;
204 rewrites[val] = mulVal;
214 void runOnOperation()
override {
215 ModuleOp moduleOp = getOperation();
219 auto diag = moduleOp.emitError();
220 diag <<
"Invalid max degree: " << maxDegree.getValue() <<
". Must be >= 2.";
226 moduleOp.walk([
this, &moduleOp](
StructDefOp structDef) {
229 if (!constrainFunc) {
230 auto diag = structDef.emitOpError();
239 auto diag = structDef.emitOpError();
240 diag <<
'"' << structDef.getName() <<
"\" doesn't have a \"@" <<
FUNC_NAME_COMPUTE
252 DenseMap<Value, unsigned> degreeMemo;
253 DenseMap<Value, Value> rewrites;
254 SmallVector<AuxAssignment> auxAssignments;
260 unsigned degreeLhs = getDegree(lhsOperand.get(), degreeMemo);
261 unsigned degreeRhs = getDegree(rhsOperand.get(), degreeMemo);
263 if (degreeLhs > maxDegree) {
264 Value loweredExpr = lowerExpression(
265 lhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
267 lhsOperand.set(loweredExpr);
269 if (degreeRhs > maxDegree) {
270 Value loweredExpr = lowerExpression(
271 rhsOperand.get(), structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
273 rhsOperand.set(loweredExpr);
280 auto diag = moduleOp.emitError();
281 diag <<
"EmitContainmentOp is unsupported for now in the lowering pass";
288 constrainFunc.walk([&](
CallOp callOp) {
290 SmallVector<Value> newOperands = llvm::to_vector(callOp.
getArgOperands());
291 bool modified =
false;
293 for (Value &arg : newOperands) {
294 unsigned deg = getDegree(arg, degreeMemo);
297 Value loweredArg = lowerExpression(
298 arg, structDef, constrainFunc, degreeMemo, rewrites, auxAssignments
306 OpBuilder builder(callOp);
308 callOp.getLoc(), callOp.getResultTypes(), callOp.
getCallee(),
317 DenseMap<Value, Value> rebuildMemo;
318 Block &computeBlock = computeFunc.
getBody().front();
319 OpBuilder builder(&computeBlock, computeBlock.getTerminator()->getIterator());
322 for (
const auto &assign : auxAssignments) {
326 assign.computedValue.getLoc(), selfVal, builder.getStringAttr(assign.auxMemberName),
336 return std::make_unique<PolyLoweringPass>();
340 auto pass = std::make_unique<PolyLoweringPass>();
341 static_cast<PolyLoweringPass *
>(pass.get())->setMaxDegree(
maxDegree);
#define AUXILIARY_MEMBER_PREFIX
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
::mlir::OpOperand & getRhsMutable()
::mlir::OpOperand & getLhsMutable()
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::mlir::Operation::operand_range getArgOperands()
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::Region & getBody()
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
constexpr char FUNC_NAME_CONSTRAIN[]
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
std::unique_ptr< mlir::Pass > createPolyLoweringPass()
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)