LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKFuseProductLoopsPass.cpp
Go to the documentation of this file.
1//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
20#include "llzk/Util/Constants.h"
21
22#include <mlir/Dialect/SCF/Utils/Utils.h>
23
24#include <llvm/Support/Debug.h>
25#include <llvm/Support/SMTAPI.h>
26
27#include <memory>
28
29// Include the generated base pass class definitions.
30namespace llzk {
31#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
33} // namespace llzk
34
35namespace {
36
37using namespace mlir;
38using namespace llzk;
39
// Width, in bits, of the SMT bitvectors used to model MLIR `index` values
// (bounds, steps and struct parameters) when querying the solver.
constexpr int INDEX_WIDTH {64};
42
43static inline bool isConstOrStructParam(Value val) {
44 // TODO: doing arithmetic over constants should also be fine?
45 return val.getDefiningOp<arith::ConstantIndexOp>() ||
46 val.getDefiningOp<polymorphic::ConstReadOp>();
47}
48
49static llvm::SMTExprRef mkExpr(Value value, llvm::SMTSolver *solver) {
50 if (auto constOp = value.getDefiningOp<arith::ConstantIndexOp>()) {
51 return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH);
52 } else if (auto polyReadOp = value.getDefiningOp<polymorphic::ConstReadOp>()) {
53
54 return solver->mkSymbol(
55 std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH)
56 );
57 }
58 assert(false && "unsupported: checking non-constant trip counts");
59 return nullptr; // Unreachable
60}
61
62static llvm::SMTExprRef tripCount(scf::ForOp op, llvm::SMTSolver *solver) {
63 const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
64 return solver->mkBVSDiv(
65 solver->mkBVAdd(
66 one,
67 solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver))
68 ),
69 mkExpr(op.getStep(), solver)
70 );
71}
72
73static inline bool canLoopsBeFused(scf::ForOp a, scf::ForOp b) {
74 // A priori, two loops can be fused if:
75 // 1. They live in the same parent region,
76 // 2. One comes from witgen and the other comes from constraint gen, and
77 // 3. They have the same trip count
78
79 // Check 1.
80 if (a->getParentRegion() != b->getParentRegion()) {
81 return false;
82 }
83
84 // Check 2.
85 if (!a->hasAttrOfType<StringAttr>(PRODUCT_SOURCE) ||
86 !b->hasAttrOfType<StringAttr>(PRODUCT_SOURCE)) {
87 // Ideally this should never happen, since the pass only runs on fused @product functions, but
88 // check anyway just to be safe
89 return false;
90 }
91 if (a->getAttrOfType<StringAttr>(PRODUCT_SOURCE) ==
92 b->getAttrOfType<StringAttr>(PRODUCT_SOURCE)) {
93 return false;
94 }
95
96 // Check 3.
97 // Easy case: both have a constant trip-count. If the trip counts are not "constant up to a struct
98 // param", we definitely can't tell if they're equal. If the trip counts are only "constant up to
99 // a struct param" but not actually constant, we can ask a solver if the equations are guaranteed
100 // to be the same
101 auto tripCountA = constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
102 auto tripCountB = constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
103 if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
104 return true;
105 }
106
107 if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
108 !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
109 !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
110 return false;
111 }
112
113 llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
114 solver->addConstraint(/* (actually ask if they "can't be different") */ solver->mkNot(
115 solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get()))
116 ));
117
118 return !*solver->check();
119}
120
// Fuse pairs of witness-generation / constraint-generation `scf.for` loops found in `body`.
//
// Loops are collected from `body` without descending into any `scf.for`; nested loops are only
// considered later, by the recursive call on each fused loop's body. A loop whose PRODUCT_SOURCE
// attribute equals FUNC_NAME_COMPUTE is a witness loop, and one whose attribute equals
// FUNC_NAME_CONSTRAIN is a constraint loop; any other `scf.for` is ignored. Candidate pairs are
// matched using `canLoopsBeFused`. Each pair is then fused with `fuseIndependentSiblingForLoops`,
// the fused loop is tagged with PRODUCT_SOURCE = "fused", and the function recurses into the
// fused loop's body.
//
// Returns failure() if computing the candidate pairs fails or any recursive call fails;
// success() otherwise.
//
// NOTE(review): canLoopsBeFused only checks region, origin and trip count. It does not check
// that `w` and `c` are data-independent (e.g. that results of `w` are not used before `c`), nor
// that their bounds/steps agree, which fuseIndependentSiblingForLoops presumably relies on —
// confirm.
121static LogicalResult fuseMatchingLoopPairs(Region &body, MLIRContext *context) {
122 // Start by collecting all possible loops
123 llvm::SmallVector<scf::ForOp> witnessLoops, constraintLoops;
124 body.walk<WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](scf::ForOp forOp) {
// Untagged loops (and everything nested in them) are not candidates.
125 if (!forOp->hasAttrOfType<StringAttr>(PRODUCT_SOURCE)) {
126 return WalkResult::skip();
127 }
128 auto productSource = forOp->getAttrOfType<StringAttr>(PRODUCT_SOURCE);
129 if (productSource == FUNC_NAME_COMPUTE) {
130 witnessLoops.push_back(forOp);
131 } else if (productSource == FUNC_NAME_CONSTRAIN) {
132 constraintLoops.push_back(forOp);
133 }
134 // Skipping here, because any nested loops can't possibly be fused at this stage
135 return WalkResult::skip();
136 });
137
138 // A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2)
139 // neither can be fused with anything else (so there's no ambiguity)
141 witnessLoops, constraintLoops, canLoopsBeFused
142 );
143
144 // This shouldn't happen, since we allow partial matches
145 if (failed(fusionCandidates)) {
146 return failure();
147 }
148
149 // Finally, fuse all the marked loops...
150 IRRewriter rewriter {context};
151 for (auto [w, c] : *fusionCandidates) {
// `w` is the witness loop, `c` the constraint loop of the matched pair.
152 auto fusedLoop = fuseIndependentSiblingForLoops(w, c, rewriter);
153 fusedLoop->setAttr(PRODUCT_SOURCE, rewriter.getAttr<StringAttr>("fused"));
154 // ...and recurse to fuse nested loops
155 if (failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) {
156 return failure();
157 }
158 }
159 return success();
160}
161
162class PassImpl : public llzk::impl::FuseProductLoopsPassBase<PassImpl> {
163 using Base = FuseProductLoopsPassBase<PassImpl>;
164 using Base::Base;
165
166 void runOnOperation() override {
167 ModuleOp mod = getOperation();
168 mod.walk([this](function::FuncDefOp funcDef) {
169 if (funcDef.isStructProduct()) {
170 if (failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) {
171 signalPassFailure();
172 }
173 }
174 });
175 }
176};
177
178} // namespace
bool isStructProduct()
Return true iff the function is within a StructDefOp and named FUNC_NAME_PRODUCT.
Definition Ops.h.inc:885
llvm::FailureOr< llvm::SetVector< std::pair< ValueT, ValueT > > > getMatchingPairs(llvm::ArrayRef< ValueT > as, llvm::ArrayRef< ValueT > bs, FnT doesMatch, bool allowPartial=true)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
Definition Constants.h:40
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17