LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKComputeConstrainToProductPass.cpp
Go to the documentation of this file.
1//===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
20#include "llzk/Util/Constants.h"
23#include <mlir/IR/Builders.h>
24#include <mlir/Transforms/InliningUtils.h>
25
26#include <llvm/Support/Debug.h>
27
28#include <memory>
29#include <ranges>
30
31namespace llzk {
32#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
33#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
35} // namespace llzk
37#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
39using namespace llzk::component;
40using namespace llzk::function;
41using namespace mlir;
42
43using std::make_unique;
44
45namespace llzk {
48 FuncDefOp computeFunc = root.getComputeFuncOp();
49 FuncDefOp constrainFunc = root.getConstrainFuncOp();
50
51 if (!computeFunc || !constrainFunc) {
52 root->emitError() << "no " << FUNC_NAME_COMPUTE << "/" << FUNC_NAME_CONSTRAIN << " to align";
53 return false;
54 }
55
57 /// root to start aligning from (issue #241)
58
59 return true;
60}
61
62LogicalResult alignStartingAt(
63 component::StructDefOp root, SymbolTableCollection &tables,
65) {
66 if (!isValidRoot(root)) {
67 return failure();
68 }
69
70 ProductAligner aligner {tables, equivalence};
71 if (!aligner.alignFuncs(root, root.getComputeFuncOp(), root.getConstrainFuncOp())) {
72 return failure();
73 }
74
75 for (auto s : aligner.alignedStructs) {
76 s.getComputeFuncOp()->erase();
77 s.getConstrainFuncOp()->erase();
78 }
79
80 return success();
81}
82
84 : public llzk::impl::ComputeConstrainToProductPassBase<ComputeConstrainToProductPass> {
85
86public:
87 void runOnOperation() override {
88 ModuleOp mod = getOperation();
89 StructDefOp root;
90
91 SymbolTableCollection tables;
93 getAnalysis<LightweightSignalEquivalenceAnalysis>()
94 };
95
96 // Find the indicated root struct and make sure its a valid place to start aligning
97 mod.walk([&root, this](StructDefOp structDef) {
98 if (structDef.getSymName() == rootStruct) {
99 root = structDef;
100 }
101 });
102
103 if (failed(alignStartingAt(root, tables, equivalence))) {
104 signalPassFailure();
105 }
106 }
107};
108
110 OpBuilder funcBuilder(compute);
111
112 // Add compute/constrain attributes
113 compute.walk([&funcBuilder](Operation *op) {
114 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
115 });
116
117 constrain.walk([&funcBuilder](Operation *op) {
118 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
119 });
120
121 // Create an empty @product func...
122 FuncDefOp productFunc = funcBuilder.create<FuncDefOp>(
123 funcBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), FUNC_NAME_PRODUCT,
124 compute.getFunctionType()
125 );
126 Block *entryBlock = productFunc.addEntryBlock();
127 funcBuilder.setInsertionPointToStart(entryBlock);
128
129 // ...with the right arguments
130 llvm::SmallVector<Value> args {productFunc.getArguments()};
131
132 // Add calls to @compute and @constrain...
133 CallOp computeCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), compute, args);
134 args.insert(args.begin(), computeCall->getResult(0));
135 CallOp constrainCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), constrain, args);
136 funcBuilder.create<ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
137
138 // ..and inline them
139 InlinerInterface inliner(productFunc.getContext());
140 if (failed(inlineCall(inliner, computeCall, compute, &compute.getBody(), true))) {
141 root->emitError() << "failed to inline " << FUNC_NAME_COMPUTE;
142 return nullptr;
143 }
144 if (failed(inlineCall(inliner, constrainCall, constrain, &constrain.getBody(), true))) {
145 root->emitError() << "failed to inline " << FUNC_NAME_CONSTRAIN;
146 return nullptr;
147 }
148 computeCall->erase();
149 constrainCall->erase();
150
151 // Mark the compute/constrain for deletion
152 alignedStructs.push_back(root);
153
154 // Make sure we can align sub-calls to @compute and @constrain
155 if (failed(alignCalls(productFunc))) {
156 return nullptr;
157 }
158 return productFunc;
159}
160
161LogicalResult ProductAligner::alignCalls(FuncDefOp product) {
162 // Gather up all the remaining calls to @compute and @constrain
163 llvm::SetVector<CallOp> computeCalls, constrainCalls;
164 product.walk([&](CallOp callOp) {
165 if (callOp.calleeIsStructCompute()) {
166 computeCalls.insert(callOp);
167 } else if (callOp.calleeIsStructConstrain()) {
168 constrainCalls.insert(callOp);
169 }
170 });
171
172 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
173
174 // A @compute matches a @constrain if they belong to the same struct and all their input signals
175 // are pairwise equivalent
176 auto doCallsMatch = [&](CallOp compute, CallOp constrain) -> bool {
177 LLVM_DEBUG({
178 llvm::outs() << "Asking for equivalence between calls\n"
179 << compute << "\nand\n"
180 << constrain << "\n\n";
181 llvm::outs() << "In block:\n\n" << *compute->getBlock() << "\n";
182 });
183
184 auto computeStruct = getPrefixAsSymbolRefAttr(compute.getCallee());
185 auto constrainStruct = getPrefixAsSymbolRefAttr(constrain.getCallee());
186 if (computeStruct != constrainStruct) {
187 return false;
188 }
189 for (unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
190 if (!equivalence.areSignalsEquivalent(compute->getOperand(i), constrain->getOperand(i + 1))) {
191 return false;
192 }
193 }
194
195 return true;
196 };
197
198 for (auto compute : computeCalls) {
199 // If there is exactly one @compute that matches a given @constrain, we can align them
200 auto matches = llvm::filter_to_vector(constrainCalls, [&](CallOp constrain) {
201 return doCallsMatch(compute, constrain);
202 });
203
204 if (matches.size() == 1) {
205 alignedCalls.insert({compute, matches[0]});
206 computeCalls.remove(compute);
207 constrainCalls.remove(matches[0]);
208 }
209 }
210
211 // TODO: If unaligned calls remain, fully inline their structs and continue instead of failing
212 if (!computeCalls.empty() && constrainCalls.empty()) {
213 product->emitError() << "failed to align some @" << FUNC_NAME_COMPUTE << " and @"
215 return failure();
216 }
217
218 for (auto [compute, constrain] : alignedCalls) {
219 // If @A::@compute matches @A::@constrain, recursively align the functions in @A...
220 auto newRoot = compute.getCalleeTarget(tables)->get()->getParentOfType<StructDefOp>();
221 assert(newRoot);
222 FuncDefOp newProduct =
223 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
224 if (!newProduct) {
225 return failure();
226 }
227
228 // ...and replace the two calls with a single call to @A::@product
229 OpBuilder callBuilder(compute);
230 CallOp newCall = callBuilder.create<CallOp>(
231 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
232 compute.getOperands()
233 );
234 compute->replaceAllUsesWith(newCall.getResults());
235 compute->erase();
236 constrain->erase();
237 }
238
239 return success();
240}
241
242std::unique_ptr<Pass> createComputeConstrainToProductPass() {
243 return make_unique<ComputeConstrainToProductPass>();
244}
245
246} // namespace llzk
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1676
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:430
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:426
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:772
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:766
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:788
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
Definition Ops.h.inc:607
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
Definition Constants.h:31
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
LogicalResult alignStartingAt(component::StructDefOp root, SymbolTableCollection &tables, LightweightSignalEquivalenceAnalysis &equivalence)
std::unique_ptr< Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
bool isValidRoot(StructDefOp root)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)