33#include <llvm/Support/Debug.h>
38#define GEN_PASS_DEF_WHILETOFORPASS
42#define DEBUG_TYPE "while-to-for"
// Loop parameters recovered by parseInfo; each optional stays empty until the
// corresponding value has been identified.
53 std::optional<Value> lb, ub, step;
// Position of the induction variable among the while op's "before" arguments
// (filled in by parseInfo from op.getBeforeArguments()).
56 std::optional<size_t> ivarIndexBefore;
// Position of the induction variable among the values forwarded by the
// condition op (filled in by parseInfo from op.getConditionOp().getArgs()).
58 std::optional<size_t> ivarIndexAfter;
// True only when lb, ub, step and both induction-variable indices are all set.
60 bool success()
const {
61 return lb.has_value() && ub.has_value() && step.has_value() && ivarIndexBefore.has_value() &&
62 ivarIndexAfter.has_value();
// Analyses a while op and collects what is needed to rebuild it as a for loop:
// lower bound, upper bound, step, and the induction variable's position in the
// before/after argument lists. Problems are reported as warnings on the op;
// callers check the returned ForOpInfo with success().
66static inline ForOpInfo parseInfo(WhileOp op) {
// Emits a warning on `op` with a fixed prefix and hands back the in-flight
// diagnostic so that callers can attach notes to it.
67 auto reportFailureReason = [&op](Twine reason) {
68 return op.emitWarning() <<
"failed to transform op: " << reason;
74 op.getBeforeArguments().begin(), op.getBeforeArguments().end(),
75 op.getConditionOp().getArgs().begin()
// Diagnostic for loops whose before-region arguments are not forwarded
// unchanged by the condition op.
77 reportFailureReason(
"block arguments not passed through from preamble to body");
// The loop condition is the value consumed by the condition op.
81 auto condition = op.getConditionOp().getCondition();
// When the condition is produced by a CmpOp, its lhs becomes the induction
// variable candidate and its rhs the upper bound.
// NOTE(review): no check of the comparison predicate is visible here — confirm
// that the comparison kind is validated somewhere.
83 if (
auto cmp = condition.getDefiningOp<
CmpOp>();
86 ivarBefore =
cmp.getLhs();
87 info.ub =
cmp.getRhs();
89 reportFailureReason(
"could not identify an upper bound");
// Helper that walks `argList` together with element positions; presumably it
// stores the position of `v` into `index` when a match is found — TODO confirm.
93 auto getBlockArgIndex = [](Value v, ValueRange argList, std::optional<size_t> &index) {
94 for (
auto [i, arg] : llvm::enumerate(argList)) {
// Look the candidate up both among the before-region arguments and among the
// values forwarded by the condition op.
103 getBlockArgIndex(ivarBefore, op.getBeforeArguments(), info.ivarIndexBefore);
104 getBlockArgIndex(ivarBefore, op.getConditionOp().getArgs(), info.ivarIndexAfter);
106 if (!info.ivarIndexBefore.has_value() || !info.ivarIndexAfter.has_value()) {
107 reportFailureReason(
"could not identify an induction variable");
// Result of the while op at the induction variable's position. If it still has
// uses, a warning is raised and every user is listed as a note.
114 auto yieldedIVar = op->getResults().drop_front(*info.ivarIndexAfter).front();
115 if (!yieldedIVar.use_empty()) {
116 auto report = reportFailureReason(
"final ivar value unsafe to drop");
117 for (
auto *
use : yieldedIVar.getUsers()) {
118 report.attachNote(
use->getLoc()) <<
"used here";
// An explicit LoopBoundsAttr on the op is read here; FeltConstAttr values are
// built from its lower, upper and step entries, with the builder positioned
// at the while op.
125 if (op->hasAttr(llzk::LoopBoundsAttr::name)) {
126 auto bounds = op->getAttrOfType<llzk::LoopBoundsAttr>(llzk::LoopBoundsAttr::name);
128 OpBuilder builder {op->getContext()};
129 builder.setInsertionPoint(op);
134 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getLower())
139 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getUpper())
144 op->getLoc(), FeltConstAttr::get(op->getContext(), bounds.getStep())
// After-region block argument that carries the induction variable.
150 Value ivarAfter = op.getAfterArguments().drop_front(*info.ivarIndexAfter).front();
// Lower bound: the init operand at the induction variable's position.
153 info.lb = *op.getInits().drop_front(*info.ivarIndexBefore).begin();
// Value yielded back for the induction variable. If it comes from an AddFeltOp
// with the after-region ivar as one operand, the other operand is the step.
156 auto nextIvar = *op.getYieldedValues().drop_front(*info.ivarIndexBefore).begin();
157 if (
auto incOp = nextIvar.getDefiningOp<
AddFeltOp>()) {
158 if (incOp.getRhs() == ivarAfter) {
159 info.step = incOp.getLhs();
160 }
else if (incOp.getLhs() == ivarAfter) {
161 info.step = incOp.getRhs();
165 if (!info.step.has_value()) {
166 reportFailureReason(
"could not identify step");
// Recursive predicate (declared as std::function so the lambda can call
// itself). It decides whether a value can be treated as constant with respect
// to this loop and attaches explanatory notes to `reporter`.
172 std::function<bool(Value, InFlightDiagnostic &)> isRuntimeConstant;
173 isRuntimeConstant = [&op, &isRuntimeConstant](Value val, InFlightDiagnostic &reporter) ->
bool {
// Block arguments: accepted only when the op owning their block is not this
// loop. The note is attached before that test, i.e. for every block argument
// that is reached.
175 if (
auto blockArg = dyn_cast<BlockArgument>(val)) {
176 reporter.attachNote(blockArg.getLoc()) <<
"depends on loop-carried value";
177 return blockArg.getParentBlock()->getParentOp() != op;
// Other values: accepted when no operand of the defining op fails the check.
182 return !llvm::any_of(val.getDefiningOp()->getOperands(), [&](Value operand) {
183 return !isRuntimeConstant(operand, reporter);
// The diagnostics are created before the checks run so that the predicate can
// attach notes to them.
187 auto ubReport = reportFailureReason(
"upper bound may not be constant");
188 if (!isRuntimeConstant(*info.ub, ubReport)) {
193 auto stepReport = reportFailureReason(
"step may not be constant");
194 if (!isRuntimeConstant(*info.step, stepReport)) {
// Discards the pending step diagnostic; presumably reached when the step
// passed the check — TODO confirm.
197 stepReport.abandon();
// Rebuilds `op` as an scf::ForOp from the previously parsed `info` and replaces
// the while op with the new loop's results. `info.success()` must hold; this is
// enforced through llzk::ensure.
202static inline FailureOr<scf::ForOp>
203transformWhileToFor(scf::WhileOp op, ForOpInfo info, RewriterBase &rewriter) {
204 llzk::ensure(info.success(),
"attempting to convert non-constant while loop");
// New operations are created right after the while op.
206 rewriter.setInsertionPointAfter(op);
// Clones the defining op of a value when the closest enclosing scf::WhileOp of
// that defining op is `op`.
// NOTE(review): the clone's result 0 is returned unconditionally — confirm that
// bounds never originate from a multi-result op.
209 auto copyIfNeeded = [op, &rewriter, &mapping](Value val) -> Value {
210 if (
auto *definingOp = val.getDefiningOp();
211 definingOp && definingOp->getParentOfType<scf::WhileOp>() == op) {
212 return rewriter.clone(*definingOp, mapping)->getResult(0);
// Helper applied to loop bounds; values that are not of FeltType take the
// guarded branch below (presumably returned unchanged — TODO confirm).
219 auto toIndex = [&rewriter](Value val) -> Value {
220 if (!isa<FeltType>(val.getType())) {
227 auto lb = copyIfNeeded(*info.lb);
228 auto ub = copyIfNeeded(*info.ub);
229 auto step = copyIfNeeded(*info.step);
// All three loop parameters are required to share a single type.
233 lb.getType() == ub.getType() && lb.getType() == step.getType(),
234 "cannot have differing types for loop bounds"
// The original bound type is remembered because felt-typed bounds are passed
// through toIndex (step is shown here; presumably lb and ub as well).
236 Type ivarType = lb.getType();
237 if (isa<FeltType>(ivarType)) {
240 step = toIndex(step);
// Collect the init operands for the for loop. The entry at the induction
// variable's position gets special handling (presumably it is skipped, given
// the index shifting applied to iter args further down).
243 SmallVector<Value> inits;
244 for (
auto [i, init] : llvm::enumerate(op.getInits())) {
245 if (i == info.ivarIndexBefore) {
248 inits.push_back(init);
// Create the for loop and continue building at the start of its body.
252 auto forOp = rewriter.create<scf::ForOp>(op->getLoc(), lb, ub, step, inits);
253 rewriter.setInsertionPointToStart(forOp.getBody());
255 auto inductionVar = forOp.getInductionVar();
256 if (isa<FeltType>(ivarType)) {
// Map the before-region block arguments: the induction variable slot maps to
// the for loop's induction variable, every other slot to an iter arg, with
// indices above the induction variable shifted down by one.
267 auto *whilePreamble = op.getBeforeBody();
268 for (
size_t i = 0; i < whilePreamble->getNumArguments(); i++) {
269 if (i == info.ivarIndexBefore) {
270 mapping.map(whilePreamble->getArgument(i), inductionVar);
274 whilePreamble->getArgument(i), forOp.getRegionIterArg(i > info.ivarIndexBefore ? i - 1 : i)
// Copy the before-region ops into the for body. For the condition op, each
// forwarded value is mapped onto the matching after-region block argument.
279 for (
auto &preambleOp : *whilePreamble) {
282 if (
auto condOp = dyn_cast<scf::ConditionOp>(&preambleOp)) {
283 for (
auto [value, blockArg] : llvm::zip(condOp.getArgs(), op.getAfterArguments())) {
285 mapping.map(blockArg, mapping.lookupOrDefault(value));
289 rewriter.clone(preambleOp, mapping);
// Copy the after-region ops. The terminating yield is rebuilt from the mapped
// values, with special handling at the induction variable's position; a new
// yield is only created when there is something to yield.
292 auto *whileBody = op.getAfterBody();
294 for (
auto &bodyOp : *whileBody) {
297 if (
auto yieldOp = dyn_cast<scf::YieldOp>(&bodyOp)) {
298 SmallVector<Value> valuesToYield;
299 for (
auto [i, val] : llvm::enumerate(yieldOp.getResults())) {
300 if (i == info.ivarIndexAfter) {
304 valuesToYield.push_back(mapping.lookupOrDefault(val));
306 if (!valuesToYield.empty()) {
307 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), valuesToYield);
311 rewriter.clone(bodyOp, mapping);
// Replacement values for the while op's results: the upper bound stands in at
// the induction variable's position, the remaining results come from the for
// loop with indices above the induction variable shifted down by one.
// NOTE(review): result positions follow the condition op's forwarded values, so
// ivarIndexAfter looks like the matching index — confirm the two are always equal.
316 SmallVector<Value> replacedValues;
317 for (
auto [i, result] : llvm::enumerate(op.getResults())) {
318 if (i == info.ivarIndexBefore) {
323 replacedValues.push_back(*info.ub);
327 replacedValues.push_back(forOp.getResult(i > info.ivarIndexBefore ? i - 1 : i));
330 rewriter.replaceOp(op, replacedValues);
// Shorthand for the base class this pass implementation derives from.
335 using Base = WhileToForPassBase<PassImpl>;
// Hook for declaring the dialects this pass depends on.
338 void getDependentDialects(DialectRegistry ®istry)
const override {
// Visits every scf::WhileOp under the current operation. Loops whose analysis
// does not succeed are left alone and the walk continues; a failed
// transformation interrupts the walk, which is checked afterwards.
341 void runOnOperation()
override {
342 IRRewriter rewriter {getOperation().getContext()};
343 auto result = getOperation()->walk([&rewriter](scf::WhileOp op) {
344 ForOpInfo info = parseInfo(op);
345 if (!info.success()) {
347 return WalkResult::advance();
349 if (failed(transformWhileToFor(op, info, rewriter))) {
350 return WalkResult::interrupt();
352 return WalkResult::advance();
354 if (result.wasInterrupted()) {
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
::mlir::TypedValue<::mlir::IndexType > getResult()
::mlir::TypedValue<::llzk::felt::FeltType > getResult()
::mlir::TypedValue<::mlir::Type > getResult()
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)