LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKFuseProductLoopsPass.cpp
Go to the documentation of this file.
1//===-- LLZKFuseProductLoopsPass.cpp -----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16
22#include "llzk/Util/Constants.h"
23
24#include <mlir/Dialect/SCF/Utils/Utils.h>
25
26#include <llvm/Support/Debug.h>
27#include <llvm/Support/SMTAPI.h>
28
29#include <memory>
30
31namespace llzk {
32
33#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS
34#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
36
37using namespace llzk::function;
38
// Width, in bits, of the SMT bitvectors used to model MLIR `index` values (loop bounds, steps and
// struct parameters) when trip counts are compared symbolically.
constexpr int INDEX_WIDTH {64};
41
42class FuseProductLoopsPass : public impl::FuseProductLoopsPassBase<FuseProductLoopsPass> {
43
44public:
45 void runOnOperation() override {
46 mlir::ModuleOp mod = getOperation();
47 mod.walk([this](FuncDefOp funcDef) {
48 if (funcDef.isStructProduct()) {
49 if (mlir::failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) {
50 signalPassFailure();
51 }
52 }
53 });
54 }
55};
56
57static inline bool isConstOrStructParam(mlir::Value val) {
58 // TODO: doing arithmetic over constants should also be fine?
59 return val.getDefiningOp<mlir::arith::ConstantIndexOp>() ||
60 val.getDefiningOp<llzk::polymorphic::ConstReadOp>();
61}
62
63llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver) {
64 if (auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
65 return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH);
66 } else if (auto polyReadOp = value.getDefiningOp<llzk::polymorphic::ConstReadOp>()) {
67
68 return solver->mkSymbol(
69 std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH)
70 );
71 }
72 assert(false && "unsupported: checking non-constant trip counts");
73 return nullptr; // Unreachable
74}
75
76llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) {
77 const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
78 return solver->mkBVSDiv(
79 solver->mkBVAdd(
80 one,
81 solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver))
82 ),
83 mkExpr(op.getStep(), solver)
84 );
85}
86
87static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) {
88 // A priori, two loops can be fused if:
89 // 1. They live in the same parent region,
90 // 2. One comes from witgen and the other comes from constraint gen, and
91 // 3. They have the same trip count
92
93 // Check 1.
94 if (a->getParentRegion() != b->getParentRegion()) {
95 return false;
96 }
97
98 // Check 2.
99 if (!a->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) ||
100 !b->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
101 // Ideally this should never happen, since the pass only runs on fused @product functions, but
102 // check anyway just to be safe
103 return false;
104 }
105 if (a->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE) ==
106 b->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
107 return false;
108 }
109
110 // Check 3.
111 // Easy case: both have a constant trip-count. If the trip counts are not "constant up to a struct
112 // param", we definitely can't tell if they're equal. If the trip counts are only "constant up to
113 // a struct param" but not actually constant, we can ask a solver if the equations are guaranteed
114 // to be the same
115 auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
116 auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
117 if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
118 return true;
119 }
120
121 if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
122 !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
123 !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
124 return false;
125 }
126
127 llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
128 solver->addConstraint(/* (actually ask if they "can't be different") */ solver->mkNot(
129 solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get()))
130 ));
131
132 return !*solver->check();
133}
134
135mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context) {
136 // Start by collecting all possible loops
137 llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops;
138 body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) {
139 if (!forOp->hasAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE)) {
140 return mlir::WalkResult::skip();
141 }
142 auto productSource = forOp->getAttrOfType<mlir::StringAttr>(PRODUCT_SOURCE);
143 if (productSource == FUNC_NAME_COMPUTE) {
144 witnessLoops.push_back(forOp);
145 } else if (productSource == FUNC_NAME_CONSTRAIN) {
146 constraintLoops.push_back(forOp);
147 }
148 // Skipping here, because any nested loops can't possibly be fused at this stage
149 return mlir::WalkResult::skip();
150 });
151
152 // A pair of loops will be fused iff (1) they can be fused according to the rules above, and (2)
153 // neither can be fused with anything else (so there's no ambiguity)
155 witnessLoops, constraintLoops, canLoopsBeFused
156 );
157
158 // This shouldn't happen, since we allow partial matches
159 if (mlir::failed(fusionCandidates)) {
160 return mlir::failure();
161 }
162
163 // Finally, fuse all the marked loops...
164 mlir::IRRewriter rewriter {context};
165 for (auto [w, c] : *fusionCandidates) {
166 auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter);
167 fusedLoop->setAttr(PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>("fused"));
168 // ...and recurse to fuse nested loops
169 if (mlir::failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) {
170 return mlir::failure();
171 }
172 }
173 return mlir::success();
174}
175
176std::unique_ptr<mlir::Pass> createFuseProductLoopsPass() {
177 return std::make_unique<FuseProductLoopsPass>();
178}
179} // namespace llzk
bool isStructProduct()
Return true iff the function is within a StructDefOp and named FUNC_NAME_PRODUCT.
Definition Ops.h.inc:873
llvm::FailureOr< llvm::SetVector< std::pair< ValueT, ValueT > > > getMatchingPairs(llvm::ArrayRef< ValueT > as, llvm::ArrayRef< ValueT > bs, FnT doesMatch, bool allowPartial=true)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver)
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
Definition Constants.h:40
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context)
Identify pairs of scf.for loops that can be fused, fuse them, and then recurse to fuse nested loops.
llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver)
constexpr int INDEX_WIDTH
std::unique_ptr< mlir::Pass > createFuseProductLoopsPass()