LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKWhileToForPass.cpp
Go to the documentation of this file.
1//===-- LLZKWhileToForPass.cpp ----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
24//===----------------------------------------------------------------------===//
25
32
33#include <llvm/Support/Debug.h>
34
35#include <algorithm>
36
37namespace llzk {
38#define GEN_PASS_DEF_WHILETOFORPASS
40} // namespace llzk
41
42#define DEBUG_TYPE "while-to-for"
43
44using namespace mlir;
45using namespace scf;
46using namespace llzk::boolean;
47using namespace llzk::felt;
48
49namespace {
50
51struct ForOpInfo {
52 // SSA values holding the loop bounds
53 std::optional<Value> lb, ub, step;
54
55 // Block argument index of the induction variable in the *before* block
56 std::optional<size_t> ivarIndexBefore;
57 // Block argument index of the induction variable in the *after* block
58 std::optional<size_t> ivarIndexAfter;
59
60 bool success() const {
61 return lb.has_value() && ub.has_value() && step.has_value() && ivarIndexBefore.has_value() &&
62 ivarIndexAfter.has_value();
63 }
64};
65
66static inline ForOpInfo parseInfo(WhileOp op) {
67 auto reportFailureReason = [&op](Twine reason) {
68 return op.emitWarning() << "failed to transform op: " << reason;
69 };
70
71 ForOpInfo info;
72
73 if (!std::equal(
74 op.getBeforeArguments().begin(), op.getBeforeArguments().end(),
75 op.getConditionOp().getArgs().begin()
76 )) {
77 reportFailureReason("block arguments not passed through from preamble to body");
78 return info;
79 }
80
81 auto condition = op.getConditionOp().getCondition();
82 Value ivarBefore;
83 if (auto cmp = condition.getDefiningOp<CmpOp>();
84 cmp && cmp.getPredicate() == FeltCmpPredicate::LT) {
85 // We found the ivar and the ub
86 ivarBefore = cmp.getLhs();
87 info.ub = cmp.getRhs();
88 } else {
89 reportFailureReason("could not identify an upper bound");
90 return info;
91 }
92
93 auto getBlockArgIndex = [](Value v, ValueRange argList, std::optional<size_t> &index) {
94 for (auto [i, arg] : llvm::enumerate(argList)) {
95 if (arg == v) {
96 index.emplace(i);
97 break;
98 }
99 }
100 };
101
102 // Find which # block arg the ivar is in the before and after blocks
103 getBlockArgIndex(ivarBefore, op.getBeforeArguments(), info.ivarIndexBefore);
104 getBlockArgIndex(ivarBefore, op.getConditionOp().getArgs(), info.ivarIndexAfter);
105
106 if (!info.ivarIndexBefore.has_value() || !info.ivarIndexAfter.has_value()) {
107 reportFailureReason("could not identify an induction variable");
108 return info;
109 }
110
111 // If the yielded final value of the induction variable has any uses, we can't cleanly transform
112 // this to an scf.for (which doesn't explicitly yield its induction var) without doing some extra
113 // computation. Lets just conservatively bail out in that case
114 auto yieldedIVar = op->getResults().drop_front(*info.ivarIndexAfter).front();
115 if (!yieldedIVar.use_empty()) {
116 auto report = reportFailureReason("final ivar value unsafe to drop");
117 for (auto *use : yieldedIVar.getUsers()) {
118 report.attachNote(use->getLoc()) << "used here";
119 }
120 return info;
121 }
122
123 // We need an induction variable anyway, but if the loop has {llzk.loopbounds} we can skip trying
124 // to parse the rest of the bounds and just materialize constants
125 if (op->hasAttr(llzk::LoopBoundsAttr::name)) {
126 auto bounds = op->getAttrOfType<llzk::LoopBoundsAttr>(llzk::LoopBoundsAttr::name);
127
128 OpBuilder builder {op->getContext()};
129 builder.setInsertionPoint(op);
130
131 // Make these constant felts for now; the actual for op builder will later clean it up
132 info.lb = builder
133 .create<FeltConstantOp>(
134 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getLower())
135 )
136 .getResult();
137 info.ub = builder
138 .create<FeltConstantOp>(
139 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getUpper())
140 )
141 .getResult();
142 info.step = builder
143 .create<FeltConstantOp>(
144 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getStep())
145 )
146 .getResult();
147 return info;
148 }
149
150 Value ivarAfter = op.getAfterArguments().drop_front(*info.ivarIndexAfter).front();
151
152 // Now, look for the lb as the corresponding init arg
153 info.lb = *op.getInits().drop_front(*info.ivarIndexBefore).begin();
154
155 // Finally, look for the step
156 auto nextIvar = *op.getYieldedValues().drop_front(*info.ivarIndexBefore).begin();
157 if (auto incOp = nextIvar.getDefiningOp<AddFeltOp>()) {
158 if (incOp.getRhs() == ivarAfter) {
159 info.step = incOp.getLhs();
160 } else if (incOp.getLhs() == ivarAfter) {
161 info.step = incOp.getRhs();
162 }
163 }
164
165 if (!info.step.has_value()) {
166 reportFailureReason("could not identify step");
167 return info;
168 }
169
170 // Make sure the bounds aren't loop-carried, making this not a legal for loop
171 // We don't actually need to check the LB because its always passed in via an init block arg
172 std::function<bool(Value, InFlightDiagnostic &)> isRuntimeConstant;
173 isRuntimeConstant = [&op, &isRuntimeConstant](Value val, InFlightDiagnostic &reporter) -> bool {
174 // The value can't come from a block argument owned by the while loop
175 if (auto blockArg = dyn_cast<BlockArgument>(val)) {
176 reporter.attachNote(blockArg.getLoc()) << "depends on loop-carried value";
177 return blockArg.getParentBlock()->getParentOp() != op;
178 }
179
180 // The value also can't depend on any value that comes from a block argument owned by the while
181 // loop
182 return !llvm::any_of(val.getDefiningOp()->getOperands(), [&](Value operand) {
183 return !isRuntimeConstant(operand, reporter);
184 });
185 };
186
187 auto ubReport = reportFailureReason("upper bound may not be constant");
188 if (!isRuntimeConstant(*info.ub, ubReport)) {
189 return ForOpInfo {};
190 }
191 ubReport.abandon();
192
193 auto stepReport = reportFailureReason("step may not be constant");
194 if (!isRuntimeConstant(*info.step, stepReport)) {
195 return ForOpInfo {};
196 }
197 stepReport.abandon();
198
199 return info;
200}
201
202static inline FailureOr<scf::ForOp>
203transformWhileToFor(scf::WhileOp op, ForOpInfo info, RewriterBase &rewriter) {
204 llzk::ensure(info.success(), "attempting to convert non-constant while loop");
205
206 rewriter.setInsertionPointAfter(op);
207 IRMapping mapping;
208
209 auto copyIfNeeded = [op, &rewriter, &mapping](Value val) -> Value {
210 if (auto *definingOp = val.getDefiningOp();
211 definingOp && definingOp->getParentOfType<scf::WhileOp>() == op) {
212 return rewriter.clone(*definingOp, mapping)->getResult(0);
213 }
214 return val;
215 };
216
217 // scf.for bounds have to be `index`, so we might have to cast them here felt -> index before
218 // building the op, and cast them back index -> felt inside the body
219 auto toIndex = [&rewriter](Value val) -> Value {
220 if (!isa<FeltType>(val.getType())) {
221 return val;
222 }
223 return rewriter.create<llzk::cast::FeltToIndexOp>(val.getLoc(), val).getResult();
224 };
225
226 // Emit a prelude setting up the loop bounds
227 auto lb = copyIfNeeded(*info.lb);
228 auto ub = copyIfNeeded(*info.ub);
229 auto step = copyIfNeeded(*info.step);
230
231 // Store the original type of the scf.while's induction var so we can cast back if necessary
233 lb.getType() == ub.getType() && lb.getType() == step.getType(),
234 "cannot have differing types for loop bounds"
235 );
236 Type ivarType = lb.getType();
237 if (isa<FeltType>(ivarType)) {
238 lb = toIndex(lb);
239 ub = toIndex(ub);
240 step = toIndex(step);
241 }
242
243 SmallVector<Value> inits;
244 for (auto [i, init] : llvm::enumerate(op.getInits())) {
245 if (i == info.ivarIndexBefore) {
246 continue;
247 }
248 inits.push_back(init);
249 }
250
251 // Build the skeleton of the for loop
252 auto forOp = rewriter.create<scf::ForOp>(op->getLoc(), lb, ub, step, inits);
253 rewriter.setInsertionPointToStart(forOp.getBody());
254
255 auto inductionVar = forOp.getInductionVar();
256 if (isa<FeltType>(ivarType)) {
257 // If the induction var was a felt, we need to cast it back to felt in the scf.for body
258 // Note that this means the body of the scf.for might cast it back to index again anyway, but
259 // --canonicalize should fix that
260 inductionVar = rewriter.create<llzk::cast::IntToFeltOp>(forOp.getLoc(), ivarType, inductionVar)
261 .getResult();
262 }
263
264 // Start by mapping the `before` block to the loop body
265 // Each block arg of the before block should get mapped to the corresponding iter_arg, with the
266 // exception of the induction var which should get mapped to the induction var
267 auto *whilePreamble = op.getBeforeBody();
268 for (size_t i = 0; i < whilePreamble->getNumArguments(); i++) {
269 if (i == info.ivarIndexBefore) {
270 mapping.map(whilePreamble->getArgument(i), inductionVar);
271 continue;
272 }
273 mapping.map(
274 whilePreamble->getArgument(i), forOp.getRegionIterArg(i > info.ivarIndexBefore ? i - 1 : i)
275 );
276 }
277
278 // Emit the preamble into the for loop body
279 for (auto &preambleOp : *whilePreamble) {
280 // Don't emit the scf.condition; rather, use it to update the mapping in preparation for the
281 // loop body
282 if (auto condOp = dyn_cast<scf::ConditionOp>(&preambleOp)) {
283 for (auto [value, blockArg] : llvm::zip(condOp.getArgs(), op.getAfterArguments())) {
284 // TODO: maybe this isn't transitive?
285 mapping.map(blockArg, mapping.lookupOrDefault(value));
286 }
287 continue;
288 }
289 rewriter.clone(preambleOp, mapping);
290 }
291
292 auto *whileBody = op.getAfterBody();
293
294 for (auto &bodyOp : *whileBody) {
295 // scf.yield is special here because we don't want to yield the induction var to the next
296 // iteration...
297 if (auto yieldOp = dyn_cast<scf::YieldOp>(&bodyOp)) {
298 SmallVector<Value> valuesToYield;
299 for (auto [i, val] : llvm::enumerate(yieldOp.getResults())) {
300 if (i == info.ivarIndexAfter) {
301 continue;
302 }
303 // ...but the other yielded values should point to the correct iter_arg
304 valuesToYield.push_back(mapping.lookupOrDefault(val));
305 }
306 if (!valuesToYield.empty()) {
307 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), valuesToYield);
308 }
309 continue;
310 }
311 rewriter.clone(bodyOp, mapping);
312 }
313
314 // scf.for doesn't explicitly yield its induction var from the final iteration, so we need to
315 // reconstruct it
316 SmallVector<Value> replacedValues;
317 for (auto [i, result] : llvm::enumerate(op.getResults())) {
318 if (i == info.ivarIndexBefore) {
319 // Note that the final value of the induction variable might not actually be the upper bound
320 // (e.g. if `step` doesn't divide `(ub - lb)`), but we've already guaranteed earlier that this
321 // value isn't being used so it doesn't matter what gets yielded here (the canonicalizer can
322 // clean it up). But we still have to yield something to preserve the shape of the op.
323 replacedValues.push_back(*info.ub);
324 continue;
325 }
326
327 replacedValues.push_back(forOp.getResult(i > info.ivarIndexBefore ? i - 1 : i));
328 }
329
330 rewriter.replaceOp(op, replacedValues);
331 return forOp;
332}
333
334class PassImpl : public llzk::impl::WhileToForPassBase<PassImpl> {
335 using Base = WhileToForPassBase<PassImpl>;
336 using Base::Base;
337
338 void getDependentDialects(DialectRegistry &registry) const override {
339 registry.insert<llzk::cast::CastDialect>();
340 }
341 void runOnOperation() override {
342 IRRewriter rewriter {getOperation().getContext()};
343 auto result = getOperation()->walk([&rewriter](scf::WhileOp op) {
344 ForOpInfo info = parseInfo(op);
345 if (!info.success()) {
346 // Ignore loops we can't prove have constant bounds
347 return WalkResult::advance();
348 }
349 if (failed(transformWhileToFor(op, info, rewriter))) {
350 return WalkResult::interrupt();
351 }
352 return WalkResult::advance();
353 });
354 if (result.wasInterrupted()) {
355 signalPassFailure();
356 }
357 }
358};
359
360} // namespace
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
::mlir::TypedValue<::mlir::IndexType > getResult()
Definition Ops.h.inc:180
::mlir::TypedValue<::llzk::felt::FeltType > getResult()
Definition Ops.h.inc:430
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:812
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)