LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKLoweringUtils.cpp
Go to the documentation of this file.
1//===-- LLZKLoweringUtils.cpp --------------------------------*- C++ -*----===//
2//
3// Shared utility function implementations for LLZK lowering passes.
4//
5//===----------------------------------------------------------------------===//
6
8
10
11#include <mlir/IR/Block.h>
12#include <mlir/IR/Builders.h>
13#include <mlir/IR/BuiltinOps.h>
14#include <mlir/IR/Operation.h>
15#include <mlir/Support/LogicalResult.h>
16
17#include <llvm/ADT/SmallVector.h>
18#include <llvm/Support/raw_ostream.h>
19
20using namespace mlir;
21using namespace llzk;
22using namespace llzk::felt;
23using namespace llzk::function;
24using namespace llzk::component;
25using namespace llzk::constrain;
26
27namespace llzk {
28
30 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &memo
31) {
32 if (auto it = memo.find(val); it != memo.end()) {
33 return it->second;
34 }
35
36 if (auto barg = llvm::dyn_cast<BlockArgument>(val)) {
37 unsigned index = barg.getArgNumber();
38 Value mapped = computeFunc.getArgument(index - 1);
39 return memo[val] = mapped;
40 }
41
42 if (auto readOp = val.getDefiningOp<MemberReadOp>()) {
43 Value self = computeFunc.getSelfValueFromCompute();
44 Value rebuilt = builder.create<MemberReadOp>(
45 readOp.getLoc(), readOp.getType(), self, readOp.getMemberNameAttr().getAttr()
46 );
47 return memo[val] = rebuilt;
48 }
49
50 if (auto add = val.getDefiningOp<AddFeltOp>()) {
51 Value lhs = rebuildExprInCompute(add.getLhs(), computeFunc, builder, memo);
52 Value rhs = rebuildExprInCompute(add.getRhs(), computeFunc, builder, memo);
53 return memo[val] = builder.create<AddFeltOp>(add.getLoc(), add.getType(), lhs, rhs);
54 }
55
56 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
57 Value lhs = rebuildExprInCompute(sub.getLhs(), computeFunc, builder, memo);
58 Value rhs = rebuildExprInCompute(sub.getRhs(), computeFunc, builder, memo);
59 return memo[val] = builder.create<SubFeltOp>(sub.getLoc(), sub.getType(), lhs, rhs);
60 }
61
62 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
63 Value lhs = rebuildExprInCompute(mul.getLhs(), computeFunc, builder, memo);
64 Value rhs = rebuildExprInCompute(mul.getRhs(), computeFunc, builder, memo);
65 return memo[val] = builder.create<MulFeltOp>(mul.getLoc(), mul.getType(), lhs, rhs);
66 }
67
68 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
69 Value operand = rebuildExprInCompute(neg.getOperand(), computeFunc, builder, memo);
70 return memo[val] = builder.create<NegFeltOp>(neg.getLoc(), neg.getType(), operand);
71 }
72
73 if (auto div = val.getDefiningOp<DivFeltOp>()) {
74 Value lhs = rebuildExprInCompute(div.getLhs(), computeFunc, builder, memo);
75 Value rhs = rebuildExprInCompute(div.getRhs(), computeFunc, builder, memo);
76 return memo[val] = builder.create<DivFeltOp>(div.getLoc(), div.getType(), lhs, rhs);
77 }
78
79 if (auto c = val.getDefiningOp<FeltConstantOp>()) {
80 return memo[val] = builder.create<FeltConstantOp>(c.getLoc(), c.getValueAttr());
81 }
82
83 llvm::errs() << "Unhandled op in rebuildExprInCompute: " << val << '\n';
84 llvm_unreachable("Unsupported op kind");
85}
86
87LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix) {
88 bool conflictFound = false;
89
90 structDef.walk([&conflictFound, &prefix](MemberDefOp memberDefOp) {
91 if (memberDefOp.getName().starts_with(prefix)) {
92 (memberDefOp.emitError() << "Member name '" << memberDefOp.getName()
93 << "' conflicts with reserved prefix '" << prefix << '\'')
94 .report();
95 conflictFound = true;
96 }
97 });
98
99 return failure(conflictFound);
100}
101
102void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
103 assert(afterOp && "afterOp must be a valid Operation*");
104
105 for (auto &use : llvm::make_early_inc_range(oldVal.getUses())) {
106 Operation *user = use.getOwner();
107
108 // Skip uses that are:
109 // - Before afterOp in the same block.
110 // - Inside afterOp itself.
111 if ((user->getBlock() == afterOp->getBlock()) &&
112 (user == afterOp || user->isBeforeInBlock(afterOp))) {
113 continue;
114 }
115
116 // Replace this use of oldVal with newVal.
117 use.set(newVal);
118 }
119}
120
121MemberDefOp addAuxMember(StructDefOp structDef, StringRef name) {
122 OpBuilder builder(structDef);
123 builder.setInsertionPointToEnd(structDef.getBody());
124 return builder.create<MemberDefOp>(
125 structDef.getLoc(), builder.getStringAttr(name), builder.getType<FeltType>()
126 );
127}
128
129unsigned getFeltDegree(Value val, DenseMap<Value, unsigned> &memo) {
130 if (auto it = memo.find(val); it != memo.end()) {
131 return it->second;
132 }
133
134 if (isa<FeltConstantOp>(val.getDefiningOp())) {
135 return memo[val] = 0;
136 }
137 if (isa<NonDetOp, MemberReadOp>(val.getDefiningOp()) || isa<BlockArgument>(val)) {
138 return memo[val] = 1;
139 }
140 if (auto add = val.getDefiningOp<AddFeltOp>()) {
141 return memo[val] =
142 std::max(getFeltDegree(add.getLhs(), memo), getFeltDegree(add.getRhs(), memo));
143 }
144 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
145 return memo[val] =
146 std::max(getFeltDegree(sub.getLhs(), memo), getFeltDegree(sub.getRhs(), memo));
147 }
148 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
149 return memo[val] = getFeltDegree(mul.getLhs(), memo) + getFeltDegree(mul.getRhs(), memo);
150 }
151 if (auto div = val.getDefiningOp<DivFeltOp>()) {
152 return memo[val] = getFeltDegree(div.getLhs(), memo) + getFeltDegree(div.getRhs(), memo);
153 }
154 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
155 return memo[val] = getFeltDegree(neg.getOperand(), memo);
156 }
157
158 llvm::errs() << "Unhandled felt op in degree computation: " << val << '\n';
159 llvm_unreachable("Unhandled op in getFeltDegree");
160}
161
162} // namespace llzk
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:363
ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
unsigned getFeltDegree(Value val, DenseMap< Value, unsigned > &memo)
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)