LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKLoweringUtils.cpp
Go to the documentation of this file.
1//===-- LLZKLoweringUtils.cpp --------------------------------*- C++ -*----===//
2//
3// Shared utility function implementations for LLZK lowering passes.
4//
5//===----------------------------------------------------------------------===//
6
9
10#include <mlir/IR/Block.h>
11#include <mlir/IR/Builders.h>
12#include <mlir/IR/BuiltinOps.h>
13#include <mlir/IR/Operation.h>
14#include <mlir/Support/LogicalResult.h>
15
16#include <llvm/ADT/SmallVector.h>
17#include <llvm/Support/raw_ostream.h>
18
19using namespace mlir;
20using namespace llzk;
21using namespace llzk::felt;
22using namespace llzk::function;
23using namespace llzk::component;
24using namespace llzk::constrain;
25
26namespace llzk {
27
29 Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap<Value, Value> &memo
30) {
31 if (auto it = memo.find(val); it != memo.end()) {
32 return it->second;
33 }
34
35 if (auto barg = llvm::dyn_cast<BlockArgument>(val)) {
36 unsigned index = barg.getArgNumber();
37 Value mapped = computeFunc.getArgument(index - 1);
38 return memo[val] = mapped;
39 }
40
41 if (auto readOp = val.getDefiningOp<MemberReadOp>()) {
42 Value self = computeFunc.getSelfValueFromCompute();
43 Value rebuilt = builder.create<MemberReadOp>(
44 readOp.getLoc(), readOp.getType(), self, readOp.getMemberNameAttr().getAttr()
45 );
46 return memo[val] = rebuilt;
47 }
48
49 if (auto add = val.getDefiningOp<AddFeltOp>()) {
50 Value lhs = rebuildExprInCompute(add.getLhs(), computeFunc, builder, memo);
51 Value rhs = rebuildExprInCompute(add.getRhs(), computeFunc, builder, memo);
52 return memo[val] = builder.create<AddFeltOp>(add.getLoc(), add.getType(), lhs, rhs);
53 }
54
55 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
56 Value lhs = rebuildExprInCompute(sub.getLhs(), computeFunc, builder, memo);
57 Value rhs = rebuildExprInCompute(sub.getRhs(), computeFunc, builder, memo);
58 return memo[val] = builder.create<SubFeltOp>(sub.getLoc(), sub.getType(), lhs, rhs);
59 }
60
61 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
62 Value lhs = rebuildExprInCompute(mul.getLhs(), computeFunc, builder, memo);
63 Value rhs = rebuildExprInCompute(mul.getRhs(), computeFunc, builder, memo);
64 return memo[val] = builder.create<MulFeltOp>(mul.getLoc(), mul.getType(), lhs, rhs);
65 }
66
67 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
68 Value operand = rebuildExprInCompute(neg.getOperand(), computeFunc, builder, memo);
69 return memo[val] = builder.create<NegFeltOp>(neg.getLoc(), neg.getType(), operand);
70 }
71
72 if (auto div = val.getDefiningOp<DivFeltOp>()) {
73 Value lhs = rebuildExprInCompute(div.getLhs(), computeFunc, builder, memo);
74 Value rhs = rebuildExprInCompute(div.getRhs(), computeFunc, builder, memo);
75 return memo[val] = builder.create<DivFeltOp>(div.getLoc(), div.getType(), lhs, rhs);
76 }
77
78 if (auto c = val.getDefiningOp<FeltConstantOp>()) {
79 return memo[val] = builder.create<FeltConstantOp>(c.getLoc(), c.getValueAttr());
80 }
81
82 llvm::errs() << "Unhandled op in rebuildExprInCompute: " << val << '\n';
83 llvm_unreachable("Unsupported op kind");
84}
85
86LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix) {
87 bool conflictFound = false;
88
89 structDef.walk([&conflictFound, &prefix](MemberDefOp memberDefOp) {
90 if (memberDefOp.getName().starts_with(prefix)) {
91 (memberDefOp.emitError() << "Member name '" << memberDefOp.getName()
92 << "' conflicts with reserved prefix '" << prefix << '\'')
93 .report();
94 conflictFound = true;
95 }
96 });
97
98 return failure(conflictFound);
99}
100
101void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp) {
102 assert(afterOp && "afterOp must be a valid Operation*");
103
104 for (auto &use : llvm::make_early_inc_range(oldVal.getUses())) {
105 Operation *user = use.getOwner();
106
107 // Skip uses that are:
108 // - Before afterOp in the same block.
109 // - Inside afterOp itself.
110 if ((user->getBlock() == afterOp->getBlock()) &&
111 (user == afterOp || user->isBeforeInBlock(afterOp))) {
112 continue;
113 }
114
115 // Replace this use of oldVal with newVal.
116 use.set(newVal);
117 }
118}
119
120MemberDefOp addAuxMember(StructDefOp structDef, StringRef name) {
121 OpBuilder builder(structDef);
122 builder.setInsertionPointToEnd(structDef.getBody());
123 return builder.create<MemberDefOp>(
124 structDef.getLoc(), builder.getStringAttr(name), builder.getType<FeltType>()
125 );
126}
127
128unsigned getFeltDegree(Value val, DenseMap<Value, unsigned> &memo) {
129 if (auto it = memo.find(val); it != memo.end()) {
130 return it->second;
131 }
132
133 if (isa<FeltConstantOp>(val.getDefiningOp())) {
134 return memo[val] = 0;
135 }
136 if (isa<NonDetOp, MemberReadOp>(val.getDefiningOp()) || isa<BlockArgument>(val)) {
137 return memo[val] = 1;
138 }
139 if (auto add = val.getDefiningOp<AddFeltOp>()) {
140 return memo[val] =
141 std::max(getFeltDegree(add.getLhs(), memo), getFeltDegree(add.getRhs(), memo));
142 }
143 if (auto sub = val.getDefiningOp<SubFeltOp>()) {
144 return memo[val] =
145 std::max(getFeltDegree(sub.getLhs(), memo), getFeltDegree(sub.getRhs(), memo));
146 }
147 if (auto mul = val.getDefiningOp<MulFeltOp>()) {
148 return memo[val] = getFeltDegree(mul.getLhs(), memo) + getFeltDegree(mul.getRhs(), memo);
149 }
150 if (auto div = val.getDefiningOp<DivFeltOp>()) {
151 return memo[val] = getFeltDegree(div.getLhs(), memo) + getFeltDegree(div.getRhs(), memo);
152 }
153 if (auto neg = val.getDefiningOp<NegFeltOp>()) {
154 return memo[val] = getFeltDegree(neg.getOperand(), memo);
155 }
156
157 llvm::errs() << "Unhandled felt op in degree computation: " << val << '\n';
158 llvm_unreachable("Unhandled op in getFeltDegree");
159}
160
161} // namespace llzk
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:354
ExpressionValue div(llvm::SMTSolverRef solver, DivFeltOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue sub(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Value rebuildExprInCompute(Value val, FuncDefOp computeFunc, OpBuilder &builder, DenseMap< Value, Value > &memo)
void replaceSubsequentUsesWith(Value oldVal, Value newVal, Operation *afterOp)
ExpressionValue add(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
MemberDefOp addAuxMember(StructDefOp structDef, StringRef name)
ExpressionValue neg(llvm::SMTSolverRef solver, const ExpressionValue &val)
unsigned getFeltDegree(Value val, DenseMap< Value, unsigned > &memo)
LogicalResult checkForAuxMemberConflicts(StructDefOp structDef, StringRef prefix)