LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKComputeConstrainToProductPass.cpp
Go to the documentation of this file.
1//===-- LLZKComputeConstrainToProductPass.cpp -------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16
22#include "llzk/Util/Constants.h"
24
25#include <mlir/IR/Builders.h>
26#include <mlir/IR/SymbolTable.h>
27#include <mlir/Transforms/InliningUtils.h>
28
29#include <llvm/Support/Debug.h>
30
31#include <memory>
32#include <ranges>
33
34namespace llzk {
35#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
37} // namespace llzk
38
39#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
40
41using namespace mlir;
42using namespace llzk;
43using namespace llzk::component;
44using namespace llzk::function;
45
// Builds the @product function for `root` from its @compute and @constrain functions:
// a fresh @product calls @compute, passes its result plus the same arguments to @constrain,
// returns the @compute result, and then both calls are inlined and erased. Nested calls to
// other structs' @compute/@constrain are then aligned via alignCalls().
// Returns the @product FuncDefOp, or a null op (after reporting an error) on failure.
// Reuse an existing @product function if `root` already has one.
48 if (auto prod = root.getProductFuncOp()) {
49 return prod;
50 }
// New ops are initially inserted immediately before @compute.
52 OpBuilder funcBuilder(compute);
53
54 // Add compute/constrain attributes
// Every op walked in each function is tagged with PRODUCT_SOURCE naming the function it came
// from, so ops inlined into @product can be traced back to their origin.
55 compute.walk([&funcBuilder](Operation *op) {
56 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_COMPUTE));
57 });
58
59 constrain.walk([&funcBuilder](Operation *op) {
60 op->setAttr(PRODUCT_SOURCE, funcBuilder.getStringAttr(FUNC_NAME_CONSTRAIN));
61 });
63 // Create an empty @product func...
// @product takes the same function type as @compute and a location fused from both sources.
64 FuncDefOp productFunc = funcBuilder.create<FuncDefOp>(
65 funcBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), FUNC_NAME_PRODUCT,
66 compute.getFunctionType()
67 );
// Mark @product as automatically derived from @compute + @constrain.
68 productFunc->setAttr(DERIVED_ATTR_NAME, UnitAttr::get(funcBuilder.getContext()));
69 Block *entryBlock = productFunc.addEntryBlock();
70 funcBuilder.setInsertionPointToStart(entryBlock);
// The product allows non-native field ops if either source function does (the setter call
// that consumes the expression below is presumably setAllowNonNativeFieldOpsAttr — confirm).
71
73 compute.hasAllowNonNativeFieldOpsAttr() || constrain.hasAllowNonNativeFieldOpsAttr()
74 );
75
76 // ...with the right arguments
77 llvm::SmallVector<Value> args {productFunc.getArguments()};
79 // Add calls to @compute and @constrain...
// @constrain receives the @compute result as an extra leading argument, followed by the same
// arguments that were passed to @compute; @product returns the @compute result.
80 CallOp computeCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), compute, args);
81 args.insert(args.begin(), computeCall->getResult(0));
82 CallOp constrainCall = funcBuilder.create<CallOp>(funcBuilder.getUnknownLoc(), constrain, args);
83 funcBuilder.create<ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
84
85 // ..and inline them
// The trailing `true` passed to inlineCall presumably requests cloning the callee region so
// @compute/@constrain keep their bodies — confirm against mlir::inlineCall. On failure an error
// is reported on `root` and a null op is returned.
86 InlinerInterface inliner(productFunc.getContext());
87 if (failed(inlineCall(inliner, computeCall, compute, &compute.getBody(), true))) {
88 root->emitError().append("failed to inline ", FUNC_NAME_COMPUTE).report();
89 return nullptr;
90 }
91 if (failed(inlineCall(inliner, constrainCall, constrain, &constrain.getBody(), true))) {
92 root->emitError().append("failed to inline ", FUNC_NAME_CONSTRAIN).report();
93 return nullptr;
94 }
// The now-inlined call ops are no longer needed.
95 computeCall->erase();
96 constrainCall->erase();
97
98 // Mark the compute/constrain for deletion
// (only records `root` in alignedStructs here; no deletion happens in this function)
99 alignedStructs.push_back(root);
100
101 // Make sure we can align sub-calls to @compute and @constrain
102 if (failed(alignCalls(productFunc))) {
103 return nullptr;
104 }
105 return productFunc;
107
108LogicalResult ProductAligner::alignCalls(FuncDefOp product) {
109 // Gather up all the remaining calls to @compute and @constrain
110 llvm::SetVector<CallOp> computeCalls, constrainCalls;
111 product.walk([&](CallOp callOp) {
112 if (callOp.calleeIsStructCompute()) {
113 computeCalls.insert(callOp);
114 } else if (callOp.calleeIsStructConstrain()) {
115 constrainCalls.insert(callOp);
117 });
118
119 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
121 // A @compute matches a @constrain if they belong to the same struct and all their input signals
122 // are pairwise equivalent
123 auto doCallsMatch = [&](CallOp compute, CallOp constrain) -> bool {
124 LLVM_DEBUG({
125 llvm::outs() << "Asking for equivalence between calls\n"
126 << compute << "\nand\n"
127 << constrain << "\n\n";
128 llvm::outs() << "In block:\n\n" << *compute->getBlock() << '\n';
129 });
130
131 auto computeStruct = getPrefixAsSymbolRefAttr(compute.getCallee());
132 auto constrainStruct = getPrefixAsSymbolRefAttr(constrain.getCallee());
133 if (computeStruct != constrainStruct) {
134 return false;
135 }
136 if (compute.getNumOperands() == 0) {
137 return true;
138 }
139 for (unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
140 if (!equivalence.areSignalsEquivalent(compute->getOperand(i), constrain->getOperand(i + 1))) {
141 return false;
142 }
143 }
144
145 return true;
146 };
147
148 for (auto compute : computeCalls) {
149 // If there is exactly one @compute that matches a given @constrain, we can align them
150 auto matches = llvm::filter_to_vector(constrainCalls, [&](CallOp constrain) {
151 return doCallsMatch(compute, constrain);
152 });
153
154 if (matches.size() == 1) {
155 alignedCalls.insert({compute, matches[0]});
156 computeCalls.remove(compute);
157 constrainCalls.remove(matches[0]);
158 }
159 }
160
161 if (!computeCalls.empty() && constrainCalls.empty()) {
162 product.emitWarning()
163 .append("failed to align some @", FUNC_NAME_COMPUTE, " and @", FUNC_NAME_CONSTRAIN)
164 .report();
165 }
166
167 for (auto [compute, constrain] : alignedCalls) {
168 // If @A::@compute matches @A::@constrain, recursively align the functions in @A...
169 auto calleeTgt = compute.getCalleeTarget(tables);
170 if (failed(calleeTgt)) {
171 return failure();
172 }
173 auto newRoot = calleeTgt->get()->getParentOfType<StructDefOp>();
174 assert(newRoot);
175 FuncDefOp newProduct =
176 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
177 if (!newProduct) {
178 return failure();
179 }
180
181 // ...and replace the two calls with a single call to @A::@product
182 OpBuilder callBuilder(compute);
183 CallOp newCall = callBuilder.create<CallOp>(
184 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
185 compute.getOperands()
186 );
187 compute->replaceAllUsesWith(newCall.getResults());
188 }
189
190 return success();
191}
192
193namespace {
194
/// Check whether `root` can serve as the starting point for alignment.
/// Reports an error on `root` and returns false when it lacks either a @compute or a
/// @constrain function. Assumes `root` is a non-null struct (it is dereferenced unconditionally).
195bool isValidRoot(StructDefOp root) {
196 FuncDefOp computeFunc = root.getComputeFuncOp();
197 FuncDefOp constrainFunc = root.getConstrainFuncOp();
198
// Both halves are required: @product is built by inlining @compute and then @constrain.
199 if (!computeFunc || !constrainFunc) {
200 root->emitError()
201 .append("no ", FUNC_NAME_COMPUTE, "/", FUNC_NAME_CONSTRAIN, " to align")
202 .report();
203 return false;
204 }
205
208
209 return true;
210}
211
/// Validate `root` with isValidRoot(), then build its @product function with a ProductAligner
/// constructed from `tables` and `equivalence`. ProductAligner::alignFuncs in turn aligns the
/// @compute/@constrain calls it finds, recursing into the called structs.
/// Returns failure if validation fails or alignFuncs yields a null @product.
212LogicalResult alignStartingAt(
213 component::StructDefOp root, SymbolTableCollection &tables,
215) {
216 if (!isValidRoot(root)) {
217 return failure();
218 }
219
220 ProductAligner aligner {tables, equivalence};
// alignFuncs signals failure by returning a null FuncDefOp.
221 if (!aligner.alignFuncs(root, root.getComputeFuncOp(), root.getConstrainFuncOp())) {
222 return failure();
223 }
224
225 return success();
226}
227
228class PassImpl : public llzk::impl::ComputeConstrainToProductPassBase<PassImpl> {
229 using Base = ComputeConstrainToProductPassBase<PassImpl>;
230 using Base::Base;
231
232 void runOnOperation() override {
233 ModuleOp mod = getOperation();
234 StructDefOp root;
235
236 SymbolTableCollection tables;
237 LightweightSignalEquivalenceAnalysis equivalence {
238 getAnalysis<LightweightSignalEquivalenceAnalysis>()
239 };
240
241 // Find the indicated root struct and make sure its a valid place to start aligning
242 mod.walk([&root, this](StructDefOp structDef) {
243 if (structDef.getSymName() == rootStruct) {
244 root = structDef;
245 }
246 });
247
248 if (failed(alignStartingAt(root, tables, equivalence))) {
249 signalPassFailure();
250 }
251 }
252};
253
254} // namespace
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
Definition Ops.cpp:476
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:472
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:468
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1145
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1139
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1161
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
Definition Ops.cpp:277
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
Definition Ops.h.inc:820
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
Definition Ops.h.inc:690
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
Definition Constants.h:40
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
Definition Constants.h:28