LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
16#include "llzk/Util/Debug.h"
19
20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Dialect/SCF/IR/SCF.h>
22
23#include <llvm/ADT/EquivalenceClasses.h>
24#include <llvm/ADT/TypeSwitch.h>
25
26#include <functional>
27
28using namespace mlir;
29
30namespace llzk {
31
32using namespace array;
33using namespace boolean;
34using namespace cast;
35using namespace component;
36using namespace constrain;
37using namespace felt;
38using namespace function;
39
40namespace {
41
42std::optional<UnreducedInterval> mergeUnreducedIntervals(
43 const std::optional<UnreducedInterval> &lhs, const std::optional<UnreducedInterval> &rhs
44) {
45 if (!lhs.has_value() || !rhs.has_value()) {
46 return std::nullopt;
47 }
48 return lhs->doUnion(*rhs);
49}
50
51template <typename Fn>
52std::optional<UnreducedInterval>
53combineUnreducedIntervals(const ExpressionValue &lhs, const ExpressionValue &rhs, Fn &&fn) {
54 if (!lhs.hasUnreducedInterval() || !rhs.hasUnreducedInterval()) {
55 return std::nullopt;
56 }
57 return fn(lhs.getUnreducedInterval(), rhs.getUnreducedInterval());
58}
59
60ExpressionValue refineReducedInterval(const ExpressionValue &expr, const Interval &newInterval) {
61 ExpressionValue refined = expr.withInterval(newInterval);
62 if (expr.getInterval() != newInterval) {
63 refined = refined.dropUnreducedInterval();
64 }
65 return refined;
66}
67
68bool isInMaybeSkippedScfRegion(Operation *op) {
69 for (Operation *parent = op->getParentOp(); parent != nullptr; parent = parent->getParentOp()) {
70 if (llvm::isa<FuncDefOp>(parent)) {
71 return false;
72 }
73
74 // `writeResults` is a storage side-channel, not a path-sensitive lattice.
75 // Writes nested under branch/loop control may not be absolute on every path through the
76 // enclosing op, so keep the prior state instead of treating the nested write as unconditional.
77 if (llvm::isa<scf::ForOp, scf::IfOp, scf::WhileOp>(parent)) {
78 return true;
79 }
80 }
81 return false;
82}
83
84std::optional<UnreducedInterval> getBooleanUnreducedInterval(const Interval &interval) {
85 return interval.isBoolean() ? std::optional<UnreducedInterval>(interval.firstUnreduced())
86 : std::nullopt;
87}
88
89FailureOr<std::vector<SourceRef>>
90translateRef(const SourceRef &ref, const SourceRefRemappings &translations) {
91 std::vector<SourceRef> refs;
92 for (const auto &[prefix, vals] : translations) {
93 if (!ref.isValidPrefix(prefix)) {
94 continue;
95 }
96
97 if (vals.isArray()) {
98 auto suffix = ref.getSuffix(prefix);
99 ensure(succeeded(suffix), "prefix checked before SourceRef suffix extraction");
100
101 std::vector<SourceRefIndex> arraySuffix, remainingSuffix;
102 bool suffixIsPastArray = false;
103 for (const SourceRefIndex &idx : *suffix) {
104 if (!suffixIsPastArray && arraySuffix.size() < vals.getNumArrayDims() &&
105 (idx.isIndex() || idx.isIndexRange())) {
106 arraySuffix.push_back(idx);
107 continue;
108 }
109 suffixIsPastArray = true;
110 remainingSuffix.push_back(idx);
111 }
112
113 auto resolvedValsRes = vals.extract(arraySuffix);
114 ensure(succeeded(resolvedValsRes), "could not resolve translated SourceRef array child");
115 SourceRefSet folded = resolvedValsRes->first.foldToScalar();
116 if (remainingSuffix.empty()) {
117 refs.insert(refs.end(), folded.begin(), folded.end());
118 continue;
119 }
120
121 for (const SourceRef &baseRef : folded) {
122 auto translatedRef = mlir::FailureOr<SourceRef>(baseRef);
123 for (const SourceRefIndex &idx : remainingSuffix) {
124 if (failed(translatedRef)) {
125 break;
126 }
127 translatedRef = translatedRef->createChild(idx);
128 }
129 if (succeeded(translatedRef)) {
130 refs.push_back(*translatedRef);
131 }
132 }
133 } else {
134 for (const SourceRef &replacement : vals.getScalarValue()) {
135 auto translated = ref.translate(prefix, replacement);
136 if (succeeded(translated)) {
137 refs.push_back(*translated);
138 }
139 }
140 }
141 }
142
143 if (refs.empty()) {
144 return failure();
145 }
146 return refs;
147}
148
149bool isDirectSourceRefValue(Value value) {
150 if (llvm::isa<BlockArgument>(value)) {
151 return true;
152 }
153
154 Operation *definingOp = value.getDefiningOp();
155 return llvm::isa_and_present<MemberReadOp, ReadArrayOp, polymorphic::ConstReadOp>(definingOp);
156}
157
158std::optional<SourceRefLatticeValue>
159getIdentitySourceRefState(DataFlowSolver &solver, Value value) {
160 if (isDirectSourceRefValue(value)) {
162 if (val.isScalar()) {
163 return val;
164 }
165 return std::nullopt;
166 }
167
168 auto createArray = llvm::dyn_cast_if_present<CreateArrayOp>(value.getDefiningOp());
169 if (!createArray) {
170 return std::nullopt;
171 }
172
173 SourceRefLatticeValue arrayVal(createArray.getType().getShape());
174 for (auto [idx, element] : llvm::enumerate(createArray.getElements())) {
175 std::optional<SourceRefLatticeValue> elementVal = getIdentitySourceRefState(solver, element);
176 if (!elementVal.has_value()) {
177 return std::nullopt;
178 }
179 (void)arrayVal.getElemFlatIdx(idx).setValue(*elementVal);
180 }
181 return arrayVal;
182}
183
184llvm::EquivalenceClasses<SourceRef>
185collectDirectEqualityRefs(DataFlowSolver &solver, FuncDefOp fn) {
186 llvm::EquivalenceClasses<SourceRef> eqRefs;
187 fn.walk([&](EmitEqualityOp eqOp) {
188 Operation *op = eqOp.getOperation();
189 if (!dataflow::isOperationLive(solver, op)) {
190 return;
191 }
192
193 Value lhs = eqOp.getLhs();
194 Value rhs = eqOp.getRhs();
195 if (!isDirectSourceRefValue(lhs) || !isDirectSourceRefValue(rhs)) {
196 return;
197 }
198
201 if (!lhsState.isScalar() || !rhsState.isScalar() || !lhsState.isSingleValue() ||
202 !rhsState.isSingleValue()) {
203 return;
204 }
205
206 const SourceRef &lhsRef = lhsState.getSingleValue();
207 const SourceRef &rhsRef = rhsState.getSingleValue();
208 if (lhsRef.isConstant() || rhsRef.isConstant()) {
209 return;
210 }
211 eqRefs.unionSets(lhsRef, rhsRef);
212 });
213 return eqRefs;
214}
215
216} // namespace
217
218/* ExpressionValue */
219
220llvm::SMTExprRef createFieldInverseExpr(
221 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val,
222 StringRef suffix = ""
223) {
224 const Field &field = val.getField();
225 const Interval &iv = val.getInterval();
226 if (iv.isDegenerate() && iv.lhs() != field.zero()) {
227 DynamicAPInt invVal = field.inv(iv.lhs());
228 return solver->mkBitvector(toAPSInt(invVal), field.bitWidth());
229 }
230
231 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
232 // To create this expression, we create a new symbol for Y and add the
233 // XY % prime = 1 constraint to the solver.
234 std::string symName = buildStringViaInsertionOp(*op);
235 if (!suffix.empty()) {
236 symName += suffix.str();
237 }
238 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
239 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.bitWidth());
240 llvm::SMTExprRef prime = solver->mkBitvector(toAPSInt(field.prime()), field.bitWidth());
241 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
242 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
243 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
244 solver->addConstraint(constraint);
245 return invSym;
246}
247
249 if (expr == nullptr && rhs.expr == nullptr) {
250 return i == rhs.i && unreduced == rhs.unreduced;
251 }
252 if (expr == nullptr || rhs.expr == nullptr) {
253 return false;
254 }
255 return i == rhs.i && unreduced == rhs.unreduced && *expr == *rhs.expr;
256}
257
259boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth) {
260 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
261 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
262 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.getExpr(), one, zero);
263 return expr.withExpression(boolToFeltConv);
264}
265
267 const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal,
268 const ExpressionValue &falseVal
269) {
270 const Field &f = trueVal.getField();
271 const Interval &condInterval = cond.getInterval();
272 Interval resultInterval;
273 if (condInterval.isEmpty()) {
274 resultInterval = Interval::Empty(f);
275 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
276 resultInterval = trueVal.getInterval();
277 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
278 resultInterval = falseVal.getInterval();
279 } else {
280 resultInterval = trueVal.getInterval().join(falseVal.getInterval());
281 }
282 llvm::SMTExprRef resultExpr =
283 solver->mkIte(cond.getExpr(), trueVal.getExpr(), falseVal.getExpr());
284 std::optional<UnreducedInterval> resultUnreduced;
285 if (condInterval.isEmpty()) {
286 resultUnreduced = resultInterval.firstUnreduced();
287 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
288 resultUnreduced = trueVal.getOptionalUnreducedInterval();
289 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
290 resultUnreduced = falseVal.getOptionalUnreducedInterval();
291 } else {
292 resultUnreduced = mergeUnreducedIntervals(
294 );
295 }
296 return ExpressionValue(resultExpr, resultInterval, std::move(resultUnreduced));
297}
298
300 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
301) {
302 Interval res = lhs.i.intersect(rhs.i);
303 const auto *exprEq = solver->mkEqual(lhs.expr, rhs.expr);
304 return ExpressionValue(exprEq, res);
305}
306
308add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
309 ExpressionValue res;
310 res.i = lhs.i + rhs.i;
311 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
312 res = res.withOptionalUnreducedInterval(combineUnreducedIntervals(lhs, rhs, std::plus {}));
313 return res;
314}
315
317sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
318 ExpressionValue res;
319 res.i = lhs.i - rhs.i;
320 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
321 res = res.withOptionalUnreducedInterval(combineUnreducedIntervals(lhs, rhs, std::minus {}));
322 return res;
323}
324
326mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
327 ExpressionValue res;
328 res.i = lhs.i * rhs.i;
329 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
330 res = res.withOptionalUnreducedInterval(combineUnreducedIntervals(lhs, rhs, std::multiplies {}));
331 return res;
332}
333
335div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
336 const ExpressionValue &rhs) {
337 ExpressionValue res;
338 auto divRes = feltDiv(lhs.i, rhs.i);
339 if (failed(divRes)) {
340 const Field &field = lhs.getField();
341 const Interval &rhsInterval = rhs.getInterval();
342 Interval zero = Interval::Degenerate(field, field.zero());
343 if (!rhsInterval.isDegenerate()) {
344 if (rhsInterval.intersect(zero).isNotEmpty()) {
345 op->emitWarning(
346 "non-degenerate felt.div divisors are not tracked precisely, and the divisor may "
347 "contain zero. Range of division result will be treated as unbounded."
348 )
349 .report();
350 } else {
351 op->emitWarning(
352 "non-degenerate felt.div divisors are not tracked precisely because precise field "
353 "division over intervals would require enumerating divisor inverses. Range of "
354 "division result will be treated as unbounded."
355 )
356 .report();
357 }
358 } else {
359 op->emitWarning(
360 "divisor is zero, leading to a divide-by-zero error. Range of division result will "
361 "be treated as unbounded."
362 )
363 .report();
364 }
365 res.i = Interval::Entire(lhs.getField());
366 } else {
367 res.i = *divRes;
368 }
369 llvm::SMTExprRef invExpr = createFieldInverseExpr(solver, op, rhs, ".div_inv");
370 res.expr = solver->mkBVMul(lhs.expr, invExpr);
371 return res;
372}
373
375 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
376 const ExpressionValue &rhs
377) {
378 ExpressionValue res;
379 auto divRes = unsignedIntDiv(lhs.i, rhs.i);
380 if (failed(divRes)) {
381 op->emitWarning(
382 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
383 " Range of division result will be treated as unbounded."
384 )
385 .report();
386 res.i = Interval::Entire(lhs.getField());
387 } else {
388 res.i = *divRes;
389 }
390 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
391 return res;
392}
393
395 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
396 const ExpressionValue &rhs
397) {
398 ExpressionValue res;
399 auto divRes = signedIntDiv(lhs.i, rhs.i);
400 if (failed(divRes)) {
401 op->emitWarning(
402 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
403 " Range of division result will be treated as unbounded."
404 )
405 .report();
406 res.i = Interval::Entire(lhs.getField());
407 } else {
408 res.i = *divRes;
409 }
410 res.expr = solver->mkBVSDiv(lhs.expr, rhs.expr);
411 return res;
412}
413
414ExpressionValue
415mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
416 ExpressionValue res;
417 res.i = lhs.i % rhs.i;
418 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
419 return res;
420}
421
423sintMod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
424 return ExpressionValue(
425 solver->mkBVSRem(lhs.getExpr(), rhs.getExpr()),
426 signedMod(lhs.getInterval(), rhs.getInterval())
427 );
428}
429
430ExpressionValue
431bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
432 ExpressionValue res;
433 res.i = lhs.i & rhs.i;
434 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
435 return res;
436}
437
439bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
440 ExpressionValue res;
441 res.i = lhs.i | rhs.i;
442 res.expr = solver->mkBVOr(lhs.expr, rhs.expr);
443 return res;
444}
445
447bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
448 if (lhs.isBoolSort(solver) && rhs.isBoolSort(solver)) {
449 return boolXor(solver, lhs, rhs);
450 }
451
452 ExpressionValue res;
453 res.i = lhs.i ^ rhs.i;
454 res.expr = solver->mkBVXor(lhs.expr, rhs.expr);
455 return res;
456}
457
459 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
460) {
461 ExpressionValue res;
462 res.i = lhs.i << rhs.i;
463 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
464 return res;
465}
466
468 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
469) {
470 ExpressionValue res;
471 res.i = lhs.i >> rhs.i;
472 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
473 return res;
474}
475
477cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs,
478 const ExpressionValue &rhs) {
479 ExpressionValue res;
480 const Field &f = lhs.getField();
481 // Default result is any boolean output for when we are unsure about the comparison result.
482 res.i = Interval::Boolean(f);
483 switch (op.getPredicate()) {
484 case FeltCmpPredicate::EQ:
485 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
486 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
487 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
488 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
489 res.i = Interval::False(f);
490 }
491 break;
492 case FeltCmpPredicate::NE:
493 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
494 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
495 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
496 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
497 res.i = Interval::True(f);
498 }
499 break;
500 case FeltCmpPredicate::LT:
501 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
502 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
503 res.i = Interval::True(f);
504 }
505 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
506 res.i = Interval::False(f);
507 }
508 break;
509 case FeltCmpPredicate::LE:
510 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
511 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
512 res.i = Interval::True(f);
513 }
514 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
515 res.i = Interval::False(f);
516 }
517 break;
518 case FeltCmpPredicate::GT:
519 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
520 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
521 res.i = Interval::True(f);
522 }
523 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
524 res.i = Interval::False(f);
525 }
526 break;
527 case FeltCmpPredicate::GE:
528 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
529 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
530 res.i = Interval::True(f);
531 }
532 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
533 res.i = Interval::False(f);
534 }
535 break;
536 }
537 res = res.withOptionalUnreducedInterval(getBooleanUnreducedInterval(res.i));
538 return res;
539}
540
542boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
543 ExpressionValue res;
544 res.i = boolAnd(lhs.i, rhs.i);
545 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
546 res = res.withOptionalUnreducedInterval(getBooleanUnreducedInterval(res.i));
547 return res;
548}
549
551boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
552 ExpressionValue res;
553 res.i = boolOr(lhs.i, rhs.i);
554 res.expr = solver->mkOr(lhs.expr, rhs.expr);
555 res = res.withOptionalUnreducedInterval(getBooleanUnreducedInterval(res.i));
556 return res;
557}
558
560boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
561 ExpressionValue res;
562 res.i = boolXor(lhs.i, rhs.i);
563 // There's no Xor, so we do (L || R) && !(L && R)
564 res.expr = solver->mkAnd(
565 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
566 );
567 res = res.withOptionalUnreducedInterval(getBooleanUnreducedInterval(res.i));
568 return res;
569}
570
571ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
572 ExpressionValue res;
573 res.i = -val.i;
574 res.expr = solver->mkBVNeg(val.expr);
575 if (val.hasUnreducedInterval()) {
577 }
578 return res;
579}
580
581ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
582 ExpressionValue res;
583 res.i = ~val.i;
584 res.expr = solver->mkBVNot(val.expr);
585 return res;
586}
587
588ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
589 ExpressionValue res;
590 res.i = boolNot(val.i);
591 res.expr = solver->mkNot(val.expr);
592 res = res.withOptionalUnreducedInterval(getBooleanUnreducedInterval(res.i));
593 return res;
594}
595
597fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val) {
598 const Field &field = val.getField();
599 ExpressionValue res;
600 res.i = Interval::Entire(field);
601 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
602 .Case<InvFeltOp>([&](auto) {
603 return createFieldInverseExpr(solver, op, val);
604 }).Default([](Operation *unsupported) {
605 llvm::report_fatal_error(
606 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
607 );
608 return nullptr;
609 });
610
611 if (llvm::isa<InvFeltOp>(op)) {
612 // We have the inverse's unreduced range to be [0, p-1] because for any integer z we can always
613 // choose a conical element x \in [0, p-1] such that 1) (z * x) %p = 0 if z = 0, 2) (z * x) % p
614 // = 1
615 res = res.withOptionalUnreducedInterval(UnreducedInterval(field.zero(), field.maxVal()));
616 }
617
618 return res;
619}
620
621void ExpressionValue::print(mlir::raw_ostream &os) const {
622 if (expr) {
623 expr->print(os);
624 } else {
625 os << "<null expression>";
626 }
627
628 os << " ( interval: " << i << " )";
629 if (unreduced.has_value()) {
630 os << " ( unreduced: " << *unreduced << " )";
631 }
632}
633
634/* IntervalAnalysisLattice */
635
636ChangeResult IntervalAnalysisLattice::join(const AbstractSparseLattice & /*other*/) {
637 // The update logic is handled in visitOperation; we don't support a generic
638 // join operation, as it may override valid intervals.
639 return ChangeResult::NoChange;
640}
641
642ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice & /*other*/) {
643 // The update logic is handled in visitOperation; we don't support a generic
644 // meet operation, as it may override valid intervals.
645 return ChangeResult::NoChange;
646}
647
648void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
649 os << "IntervalAnalysisLattice { " << val << " }";
650}
651
653 if (val == newVal) {
654 return ChangeResult::NoChange;
655 }
656 val = newVal;
657 return ChangeResult::Change;
658}
659
661 LatticeValue newVal(e);
662 return setValue(newVal);
663}
664
666 if (!constraints.contains(e)) {
667 constraints.insert(e);
668 return ChangeResult::Change;
669 }
670 return ChangeResult::NoChange;
671}
672
673/* IntervalDataFlowAnalysis */
674
675SourceRefLatticeValue IntervalDataFlowAnalysis::getSourceRefState(Value val) {
676 return SourceRefAnalysis::getValueState(_dataflowSolver, val);
677}
678
679std::vector<SourceRefIndex> IntervalDataFlowAnalysis::getArrayAccessIndices(
680 Operation * /*baseOp*/, ArrayAccessOpInterface arrayAccessOp
681) {
682 std::vector<SourceRefIndex> indices;
683 ArrayType arrayType = arrayAccessOp.getArrRefType();
684 size_t numIndices = arrayAccessOp.getIndices().size();
685 indices.reserve(numIndices);
686
687 for (size_t i = 0; i < numIndices; ++i) {
688 Value idxOperand = arrayAccessOp.getIndices()[i];
689 SourceRefLatticeValue idxVals = getSourceRefState(idxOperand);
690
691 // Only exact constant indices get tracked precisely.
692 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
693 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
694 } else {
695 auto lower = APInt::getZero(64);
696 APInt upper(64, arrayType.getDimSize(i));
697 indices.emplace_back(lower, upper);
698 }
699 }
700
701 return indices;
702}
703
704mlir::FailureOr<SourceRef> IntervalDataFlowAnalysis::getArrayAccessRef(
705 Operation *baseOp, ArrayAccessOpInterface arrayAccessOp
706) {
707 std::vector<SourceRefIndex> indices = getArrayAccessIndices(baseOp, arrayAccessOp);
708 Value arrayVal = arrayAccessOp.getArrRef();
709 if (auto blockArg = llvm::dyn_cast<BlockArgument>(arrayVal)) {
710 return SourceRef(blockArg, std::move(indices));
711 }
712 if (auto result = llvm::dyn_cast<OpResult>(arrayVal)) {
713 return SourceRef(result, std::move(indices));
714 }
715 return failure();
716}
717
718Interval IntervalDataFlowAnalysis::getRefInterval(const SourceRef &ref) {
719 if (auto it = writeResults.find(ref); it != writeResults.end()) {
720 return it->second.getInterval();
721 }
722
723 if (ref.isConstantInt()) {
724 auto constVal = ref.getConstantValue();
725 if (succeeded(constVal)) {
726 return Interval::Degenerate(field.get(), *constVal);
727 }
728 }
729
730 if (ref.isRooted() && ref.getPath().empty()) {
731 auto rootVal = ref.getRoot();
732 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
733 const ExpressionValue &rootExpr = getLatticeElement(*rootVal)->getValue().getScalarValue();
734 if (rootExpr.getExpr() != nullptr) {
735 return rootExpr.getInterval();
736 }
737 }
738 }
739
740 return getDefaultIntervalForType(ref.getType());
741}
742
743std::optional<UnreducedInterval>
744IntervalDataFlowAnalysis::getDefaultUnreducedIntervalForType(mlir::Type ty) const {
745 if (!trackUnreducedIntervals) {
746 return std::nullopt;
747 }
748 if (isBooleanType(ty)) {
749 return UnreducedInterval(0, 1);
750 }
751 return UnreducedInterval(field.get().zero(), field.get().maxVal());
752}
753
754std::optional<UnreducedInterval>
755IntervalDataFlowAnalysis::getRefUnreducedInterval(const SourceRef &ref) {
756 if (!trackUnreducedIntervals) {
757 return std::nullopt;
758 }
759
760 if (auto it = writeResults.find(ref); it != writeResults.end()) {
761 return it->second.getOptionalUnreducedInterval();
762 }
763
764 if (ref.isConstantInt()) {
765 auto constVal = ref.getConstantValue();
766 if (succeeded(constVal)) {
767 return UnreducedInterval(*constVal, *constVal);
768 }
769 }
770
771 if (ref.isRooted() && ref.getPath().empty()) {
772 auto rootVal = ref.getRoot();
773 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
774 const ExpressionValue &rootExpr = getLatticeElement(*rootVal)->getValue().getScalarValue();
775 if (rootExpr.hasUnreducedInterval()) {
776 return rootExpr.getUnreducedInterval();
777 }
778 }
779 }
780
781 return getRefInterval(ref).firstUnreduced();
782}
783
784ExpressionValue IntervalDataFlowAnalysis::getRefValue(const SourceRef &ref, Value val) {
785 if (auto it = writeResults.find(ref); it != writeResults.end()) {
786 return it->second;
787 }
788 return createUnknownValue(val)
789 .withInterval(getRefInterval(ref))
790 .withOptionalUnreducedInterval(getRefUnreducedInterval(ref));
791}
792
793void IntervalDataFlowAnalysis::recordRefWrite(
794 const SourceRef &writtenRef, const ExpressionValue &writeVal, bool mayBeSkipped
795) {
796 auto joinStoredWrite = [this, &writtenRef](
797 const ExpressionValue &old, const ExpressionValue &next
798 ) -> ExpressionValue {
799 Interval combinedWrite = old.getInterval().join(next.getInterval());
800 auto combinedUnreduced = mergeUnreducedIntervals(
801 old.getOptionalUnreducedInterval(), next.getOptionalUnreducedInterval()
802 );
803 if (old.getExpr() != nullptr && next.getExpr() != nullptr &&
804 *old.getExpr() == *next.getExpr()) {
805 return old.withInterval(combinedWrite).withOptionalUnreducedInterval(combinedUnreduced);
806 }
807
808 return ExpressionValue(
809 getOrCreateSymbol(writtenRef), combinedWrite, std::move(combinedUnreduced)
810 );
811 };
812
813 if (auto it = writeResults.find(writtenRef); it != writeResults.end()) {
814 it->second = joinStoredWrite(it->second, writeVal);
815 } else if (mayBeSkipped) {
816 ExpressionValue noWrite(
817 getOrCreateSymbol(writtenRef), getRefInterval(writtenRef),
818 getRefUnreducedInterval(writtenRef)
819 );
820 writeResults[writtenRef] = joinStoredWrite(noWrite, writeVal);
821 } else {
822 writeResults[writtenRef] = writeVal;
823 }
824
825 const ExpressionValue &readerUpdate = mayBeSkipped ? writeResults[writtenRef] : writeVal;
826 for (Lattice *readerLattice : readResults[writtenRef]) {
827 ExpressionValue prior = readerLattice->getValue().getScalarValue();
828 Interval intersection = prior.getInterval().intersect(readerUpdate.getInterval());
829 ExpressionValue newVal = prior.withInterval(intersection);
830 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
831 }
832}
833
835 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
836) {
837 // We only perform the visitation on operations within functions
838 FuncDefOp fn = op->getParentOfType<FuncDefOp>();
839 if (!fn) {
840 return success();
841 }
842
843 // If there are no operands or results, skip.
844 if (operands.empty() && results.empty()) {
845 return success();
846 }
847
848 // Get the values or defaults from the operand lattices
849 llvm::SmallVector<LatticeValue> operandVals;
850 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
851 auto resolveRefStateValue =
852 [&](Value value, const SourceRefLatticeValue &refSet) -> std::optional<LatticeValue> {
853 ensure(refSet.isScalar(), "should have ruled out array values already");
854
855 if (refSet.getScalarValue().empty()) {
856 // If we can't compute the reference, then there must be some unsupported
857 // op the reference analysis cannot handle. We emit a warning and return
858 // early, since there's no meaningful computation we can do for this op.
859 op->emitWarning()
860 .append(
861 "state of ", value,
862 " is empty; defining operation is unsupported by SourceRef analysis"
863 )
864 .report();
865 return std::nullopt;
866 }
867
868 if (!refSet.isSingleValue()) {
869 Interval joinedInterval = Interval::Empty(field.get());
870 std::optional<UnreducedInterval> joinedUnreduced = std::nullopt;
871 bool sawFirst = false;
872 for (const SourceRef &ref : refSet.getScalarValue()) {
873 joinedInterval = joinedInterval.join(getRefInterval(ref));
874 auto refUnreduced = getRefUnreducedInterval(ref);
875 if (!sawFirst) {
876 joinedUnreduced = refUnreduced;
877 sawFirst = true;
878 } else {
879 joinedUnreduced = mergeUnreducedIntervals(joinedUnreduced, refUnreduced);
880 }
881 }
882 ExpressionValue anyVal = createUnknownValue(value)
883 .withInterval(joinedInterval)
884 .withOptionalUnreducedInterval(joinedUnreduced);
885 return LatticeValue(anyVal);
886 }
887
888 return LatticeValue(getRefValue(refSet.getSingleValue(), value));
889 };
890 for (unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
891 Value val = op->getOperand(opNum);
892 SourceRefLatticeValue refSet = getSourceRefState(val);
893 if (refSet.isSingleValue()) {
894 operandRefs.push_back(refSet.getSingleValue());
895 } else {
896 operandRefs.push_back(std::nullopt);
897 }
898 // First, lookup the operand value after it is initialized
899 auto priorState = operands[opNum]->getValue();
900 if (priorState.getScalarValue().getExpr() != nullptr) {
901 operandVals.push_back(priorState);
902 continue;
903 }
904
905 if (auto readArr = llvm::dyn_cast_if_present<ReadArrayOp>(val.getDefiningOp())) {
906 auto arrayRef = getArrayAccessRef(op, readArr);
907 if (succeeded(arrayRef)) {
908 if (auto it = writeResults.find(*arrayRef); it != writeResults.end()) {
909 operandVals.emplace_back(it->second);
910 Lattice *operandLattice = getLatticeElement(val);
911 (void)operandLattice->setValue(it->second);
912 continue;
913 }
914 }
915 }
916
917 // Else, look up the stored value by `SourceRef`.
918 // We only care about scalar type values, so we ignore composite types, which
919 // are currently limited to structs and arrays.
920 Type valTy = val.getType();
921 if (llvm::isa<ArrayType, StructType>(valTy)) {
922 ExpressionValue anyVal(field.get(), createSymbol(valTy, buildStringViaPrint(val).c_str()));
923 operandVals.emplace_back(anyVal);
924 continue;
925 }
926
927 auto resolvedValue = resolveRefStateValue(val, refSet);
928 if (!resolvedValue.has_value()) {
929 // We still return success so we can return overapproximated and partial
930 // results to the user.
931 return success();
932 }
933 operandVals.push_back(*resolvedValue);
934
935 // Since we initialized a value that was not found in the before lattice,
936 // update that value in the lattice so we can find it later, but we don't
937 // need to propagate the changes, since we already have what we need.
938 Lattice *operandLattice = getLatticeElement(val);
939 (void)operandLattice->setValue(operandVals[opNum]);
940 }
941
942 if (isReadOp(op) && op->getNumResults() == 1) {
943 Value resultVal = op->getResult(0);
944 if (!llvm::isa<ArrayType, StructType>(resultVal.getType())) {
945 auto resolvedValue = resolveRefStateValue(resultVal, getSourceRefState(resultVal));
946 if (resolvedValue.has_value()) {
947 propagateIfChanged(results[0], results[0]->setValue(*resolvedValue));
948 }
949 }
950 return success();
951 }
952
953 // Now, the way we update is dependent on the type of the operation.
954 if (isConstOp(op)) {
955 llvm::DynamicAPInt constVal = getConst(op);
956 llvm::SMTExprRef expr;
957 if (isBoolConstOp(op)) {
958 expr = createConstBoolExpr(constVal != 0);
959 } else {
960 expr = createConstBitvectorExpr(constVal);
961 }
962
963 ExpressionValue latticeVal(field.get(), expr, constVal);
964 if (trackUnreducedIntervals) {
965 latticeVal = latticeVal.withUnreducedInterval(UnreducedInterval(constVal, constVal));
966 }
967 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
968 } else if (isArithmeticOp(op)) {
969 ExpressionValue result;
970 if (operands.size() == 2) {
971 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
972 } else {
973 result = performUnaryArithmetic(op, operandVals[0]);
974 }
975
976 // Also intersect with prior interval, if it's initialized
977 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
978 if (prior.getExpr()) {
979 result = refineReducedInterval(result, result.getInterval().intersect(prior.getInterval()));
980 }
981 propagateIfChanged(results[0], results[0]->setValue(result));
982 } else if (auto selectOp = llvm::dyn_cast<arith::SelectOp>(op)) {
984 smtSolver, operandVals[0].getScalarValue(), operandVals[1].getScalarValue(),
985 operandVals[2].getScalarValue()
986 );
987 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
988 if (prior.getExpr()) {
989 result = refineReducedInterval(result, result.getInterval().intersect(prior.getInterval()));
990 }
991 propagateIfChanged(results[0], results[0]->setValue(result));
992 } else if (EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
993 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
994 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
995 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
996
997 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
998 // These patterns enforce that s is one of c0, ..., cN.
999 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
1000 if (succeeded(res)) {
1001 for (Value signalVal : res->first) {
1002 applyInterval(emitEq, signalVal, res->second);
1003 }
1004 }
1005
1006 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
1007 // Update the LHS and RHS to the same value, but restricted intervals
1008 // based on the constraints.
1009 const Interval &constrainInterval = constraint.getInterval();
1010 applyInterval(emitEq, lhsVal, constrainInterval);
1011 applyInterval(emitEq, rhsVal, constrainInterval);
1012 } else if (auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
1013 // assert enforces that the operand is true. So we apply an interval of [1, 1]
1014 // to the operand.
1015 Value cond = assertOp.getCondition();
1016 applyInterval(assertOp, cond, Interval::True(field.get()));
1017 // Also add the solver constraint that the expression must be true.
1018 auto assertExpr = operandVals[0].getScalarValue();
1019 // No need to propagate the constraint
1020 (void)getLatticeElement(cond)->addSolverConstraint(assertExpr);
1021 } else if (auto writem = llvm::dyn_cast<MemberWriteOp>(op)) {
1022 const bool maySkipWrite = isInMaybeSkippedScfRegion(op);
1023 // Update values stored in a member
1024 ExpressionValue writeVal = operandVals[1].getScalarValue();
1025 auto cmp = writem.getComponent();
1026 // We also need to update the interval on the assigned symbol
1027 SourceRefLatticeValue refSet = getSourceRefState(cmp);
1028 if (refSet.isSingleValue()) {
1029 auto memberDefRes = writem.getMemberDefOp(tables);
1030 if (succeeded(memberDefRes)) {
1031 SourceRefIndex idx(memberDefRes.value());
1032 auto memberRefRes = refSet.getSingleValue().createChild(idx);
1033 ensure(succeeded(memberRefRes), "could not create SourceRef child for member write");
1034 const SourceRef &memberRef = *memberRefRes;
1035 Type memberTy = writem.getVal().getType();
1036 if (!llvm::isa<ArrayType, StructType>(memberTy)) {
1037 // Simple scalar update
1038 recordRefWrite(memberRef, writeVal, maySkipWrite);
1039 } else {
1040 // Map the intervals of aggregates to the written member
1041 std::optional<SourceRef> rhsPrefix;
1042 if (operandRefs[1].has_value() && operandRefs[1]->isRooted()) {
1043 rhsPrefix = operandRefs[1];
1044 } else if (auto blockArg = llvm::dyn_cast<BlockArgument>(writem.getVal())) {
1045 rhsPrefix = SourceRef(blockArg);
1046 } else if (auto result = llvm::dyn_cast<OpResult>(writem.getVal())) {
1047 rhsPrefix = SourceRef(result);
1048 }
1049
1050 if (rhsPrefix.has_value()) {
1051 llvm::SmallVector<std::pair<SourceRef, ExpressionValue>> remappedWrites;
1052 for (const auto &[writtenRef, writtenVal] : writeResults) {
1053 if (!writtenRef.isValidPrefix(*rhsPrefix)) {
1054 continue;
1055 }
1056
1057 auto translatedRef = writtenRef.translate(*rhsPrefix, memberRef);
1058 ensure(succeeded(translatedRef), "could not translate composite member write");
1059 remappedWrites.emplace_back(*translatedRef, writtenVal);
1060 }
1061
1062 for (const auto &[translatedRef, translatedVal] : remappedWrites) {
1063 recordRefWrite(translatedRef, translatedVal, maySkipWrite);
1064 }
1065 }
1066 }
1067 }
1068 }
1069 } else if (auto writeArr = llvm::dyn_cast<WriteArrayOp>(op)) {
1070 const bool maySkipWrite = isInMaybeSkippedScfRegion(op);
1071 ExpressionValue writeVal = operandVals.back().getScalarValue();
1072 auto arrayRef = getArrayAccessRef(op, writeArr);
1073 if (succeeded(arrayRef)) {
1074 recordRefWrite(*arrayRef, writeVal, maySkipWrite);
1075 }
1076
1077 SourceRefLatticeValue arrayVals = getSourceRefState(writeArr.getArrRef());
1078 if (arrayVals.isScalar()) {
1079 std::vector<SourceRefIndex> indices = getArrayAccessIndices(op, writeArr);
1080 auto targetRefsRes = arrayVals.extract(indices);
1081 ensure(succeeded(targetRefsRes), "could not create SourceRef child for array write");
1082 auto [targetRefs, _] = *targetRefsRes;
1083 ensure(targetRefs.isScalar(), "array write must resolve to scalar references");
1084 for (const SourceRef &ref : targetRefs.getScalarValue()) {
1085 recordRefWrite(ref, writeVal, maySkipWrite);
1086 }
1087 }
1088 } else if (auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
1089 const auto &elements = createArray.getElements();
1090 ArrayType arrayTy = createArray.getType();
1091 Type elemTy = arrayTy.getElementType();
1092
1093 if (!elements.empty() && !llvm::isa<ArrayType, StructType>(elemTy)) {
1094 ensure(arrayTy.hasStaticShape(), "array.new with explicit elements must have static shape");
1095 ensure(
1096 std::cmp_equal(elements.size(), arrayTy.getNumElements()),
1097 "array.new explicit initializer length must match array shape"
1098 );
1099
1101 auto arrayRes = llvm::cast<OpResult>(createArray->getResult(0));
1102 for (unsigned i = 0; i < elements.size(); ++i) {
1103 auto maybeIndices = indexGen.delinearize(i, op->getContext());
1104 ensure(maybeIndices.has_value(), "could not delinearize array.new element index");
1105
1106 SourceRef::Path path;
1107 path.reserve(maybeIndices->size());
1108 for (Attribute attr : *maybeIndices) {
1109 auto idxAttr = llvm::dyn_cast<IntegerAttr>(attr);
1110 ensure(idxAttr != nullptr, "array.new delinearize should produce integer attributes");
1111 path.emplace_back(idxAttr.getValue());
1112 }
1113
1114 recordRefWrite(SourceRef(arrayRes, std::move(path)), operandVals[i].getScalarValue());
1115 }
1116 }
1117 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
1118 // Casts don't modify the intervals, but they do modify the SMT types.
1119 ExpressionValue expr = operandVals[0].getScalarValue();
1120 // We treat all ints and indexes as felts with the exception of comparison
1121 // results, which are bools. So if `expr` is a bool, this cast needs to
1122 // upcast to a felt.
1123 if (expr.isBoolSort(smtSolver)) {
1124 expr = boolToFelt(smtSolver, expr, field.get().bitWidth());
1125 }
1126 propagateIfChanged(results[0], results[0]->setValue(expr));
1127 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1128 // Fetch the lattice for after the parent operation so we can propagate
1129 // the yielded value to subsequent operations.
1130 Operation *parent = op->getParentOp();
1131 ensure(parent, "yield operation must have parent operation");
1132 // Bind the operand values to the result values of the parent
1133 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
1134 Value parentRes = parent->getResult(idx);
1135 Lattice *resLattice = getLatticeElement(parentRes);
1136 // Merge with the existing value, if present (e.g., another branch)
1137 // has possible value that must be merged.
1138 ExpressionValue exprVal = resLattice->getValue().getScalarValue();
1139 ExpressionValue newResVal = operandVals[idx].getScalarValue();
1140 if (auto loopOp = llvm::dyn_cast<LoopLikeOpInterface>(parent)) {
1141 // We overapproximate for loops because we aren't going to try to track trip count.
1142 newResVal = ExpressionValue(createSymbol(parentRes), Interval::Entire(field.get()));
1143 }
1144 if (exprVal.getExpr() != nullptr) {
1145 newResVal =
1146 exprVal.withInterval(exprVal.getInterval().join(newResVal.getInterval()))
1147 .withOptionalUnreducedInterval(mergeUnreducedIntervals(
1149 ));
1150 } else {
1151 newResVal = ExpressionValue(
1152 createSymbol(parentRes), newResVal.getInterval(),
1154 );
1155 }
1156 propagateIfChanged(resLattice, resLattice->setValue(newResVal));
1157 }
1158 } else if (
1159 // We do not need to explicitly handle read ops since they are resolved at the operand value
1160 // step where `SourceRef`s are queries.
1161 !isReadOp(op)
1162 // We do not currently handle return ops as the analysis is currently limited to constrain
1163 // functions, which return no value.
1164 && !isReturnOp(op)
1165 // The analysis ignores definition ops.
1166 && !isDefinitionOp(op)
1167 // We do not need to analyze storage creation directly.
1168 && !llvm::isa<CreateArrayOp, CreateStructOp, NonDetOp>(op)
1169 ) {
1170 op->emitWarning("unhandled operation, analysis may be incomplete").report();
1171 }
1172
1173 return success();
1174}
1175
1177 auto it = refSymbols.find(r);
1178 if (it != refSymbols.end()) {
1179 return it->second;
1180 }
1181 llvm::SMTExprRef sym = createSymbol(r);
1182 refSymbols[r] = sym;
1183 return sym;
1184}
1185
1186llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(mlir::Type ty, const char *name) const {
1187 if (isBooleanType(ty)) {
1188 return smtSolver->mkSymbol(name, smtSolver->getBoolSort());
1189 }
1190 return field.get().createSymbol(smtSolver, name);
1191}
1192
1193llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(const SourceRef &r) const {
1194 std::string name = buildStringViaPrint(r);
1195 return createSymbol(r.getType(), name.c_str());
1196}
1197
1198llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(Value v) const {
1199 std::string name = buildStringViaPrint(v);
1200 return createSymbol(v.getType(), name.c_str());
1201}
1202
1203llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
1204 ensure(isConstOp(op), "op is not a const op");
1205
1206 // NOTE: I think clang-format makes these hard to read by default
1207 // clang-format off
1208 llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
1209 .Case<FeltConstantOp>([&](auto feltConst) {
1210 llvm::APSInt constOpVal(feltConst.getValue());
1211 return field.get().reduce(constOpVal);
1212 })
1213 .Case<arith::ConstantIndexOp>([&](auto indexConst) {
1214 return DynamicAPInt(indexConst.value());
1215 })
1216 .Case<arith::ConstantIntOp>([&](auto intConst) {
1217 auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
1218 ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
1219 return toDynamicAPInt(valAttr.getValue());
1220 })
1221 .Default([](auto *illegalOp) {
1222 std::string err;
1223 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
1224 llvm::report_fatal_error(Twine(err));
1225 return llvm::DynamicAPInt();
1226 });
1227 // clang-format on
1228 return fieldConst;
1229}
1230
1231ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
1232 Operation *op, const LatticeValue &a, const LatticeValue &b
1233) {
1234 ensure(isArithmeticOp(op), "is not arithmetic op");
1235
1236 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
1237 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
1238 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
1239
1240 // clang-format off
1241 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
1242 .Case<AddFeltOp>([&](auto) { return add(smtSolver, lhs, rhs); })
1243 .Case<SubFeltOp>([&](auto) { return sub(smtSolver, lhs, rhs); })
1244 .Case<MulFeltOp>([&](auto) { return mul(smtSolver, lhs, rhs); })
1245 .Case<DivFeltOp>([&](auto) {return div(smtSolver, op, lhs, rhs); })
1246 .Case<UnsignedIntDivFeltOp>([&](auto) {return uintDiv(smtSolver, op, lhs, rhs); })
1247 .Case<SignedIntDivFeltOp>([&](auto) {return sintDiv(smtSolver, op, lhs, rhs); })
1248 .Case<UnsignedModFeltOp>([&](auto) { return mod(smtSolver, lhs, rhs); })
1249 .Case<SignedModFeltOp>([&](auto) { return sintMod(smtSolver, lhs, rhs); })
1250 .Case<AndFeltOp>([&](auto) { return bitAnd(smtSolver, lhs, rhs); })
1251 .Case<OrFeltOp>([&](auto) { return bitOr(smtSolver, lhs, rhs); })
1252 .Case<XorFeltOp, arith::XOrIOp>([&](auto) { return bitXor(smtSolver, lhs, rhs); })
1253 .Case<ShlFeltOp>([&](auto) { return shiftLeft(smtSolver, lhs, rhs); })
1254 .Case<ShrFeltOp>([&](auto) { return shiftRight(smtSolver, lhs, rhs); })
1255 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
1256 .Case<AndBoolOp>([&](auto) { return boolAnd(smtSolver, lhs, rhs); })
1257 .Case<OrBoolOp>([&](auto) { return boolOr(smtSolver, lhs, rhs); })
1258 .Case<XorBoolOp>([&](auto) { return boolXor(smtSolver, lhs, rhs); })
1259 .Default([&](auto *unsupported) {
1260 unsupported
1261 ->emitError(
1262 "unsupported binary arithmetic operation"
1263 )
1264 .report();
1265 return ExpressionValue();
1266 });
1267 // clang-format on
1268
1269 ensure(res.getExpr(), "arithmetic produced null smt expr");
1270 return res;
1271}
1272
1274IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
1275 ensure(isArithmeticOp(op), "is not arithmetic op");
1276
1277 auto val = a.getScalarValue();
1278 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
1279
1280 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
1281 .Case<NegFeltOp>([&](auto) { return neg(smtSolver, val); })
1282 .Case<NotFeltOp>([&](auto) { return notOp(smtSolver, val); })
1283 .Case<NotBoolOp>([&](auto) { return boolNot(smtSolver, val); })
1284 // The inverse op is currently overapproximated
1285 .Case<InvFeltOp>([&](auto inv) {
1286 return fallbackUnaryOp(smtSolver, inv, val);
1287 }).Default([&](auto *unsupported) {
1288 unsupported
1289 ->emitWarning(
1290 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
1291 )
1292 .report();
1293 return fallbackUnaryOp(smtSolver, unsupported, val);
1294 });
1295
1296 ensure(res.getExpr(), "arithmetic produced null smt expr");
1297 return res;
1298}
1299
// Restricts `val` to (its current interval) intersected with `newInterval`, then pushes
// that restriction backwards through `val`'s defining operation using per-op rules
// (cmp, add, sub, mul, select, member reads, array reads, IntToFelt/FeltToIndex casts).
// Defining ops without a rule are left unchanged.
//   valUser     - the op whose constraint triggered this update. Recursive calls pass the
//                 op being descended through; the array-read rule uses it to resolve
//                 the accessed element.
//   val         - the value being restricted.
//   newInterval - the interval `val` is now known to lie in.
// The lattice change for `val` is propagated only after the recursive updates (see end).
1300void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Interval newInterval) {
1301 Lattice *valLattice = getLatticeElement(val);
1302 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
1303 // Intersect with the current value to accumulate restrictions across constraints.
1304 Interval intersection = oldLatticeVal.getInterval().intersect(newInterval);
// refineReducedInterval drops the tracked unreduced interval whenever the reduced
// interval actually changed, since the old unreduced bound may no longer be accurate.
1305 ExpressionValue newLatticeVal = refineReducedInterval(oldLatticeVal, intersection);
1306 ChangeResult changed = valLattice->setValue(newLatticeVal);
1307
1308 if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
// NOTE(review): llvm::dyn_cast requires a non-null input. If the owning block could
// ever be detached (no parent op), dyn_cast_if_present would be the defensive choice.
1309 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
1310
1311 // Apply the interval from the constrain function inputs to the compute function inputs
// Constrain argument i (i > 0) maps to compute argument i - 1. Argument 0 of the
// constrain function is skipped (presumably the struct instance; confirm). Only
// informative (non-entire) intervals are mirrored.
// NOTE(review): structOp / computeFn are not null-checked here. The compute argument's
// lattice is overwritten with `newInterval` itself, not intersected with its current
// value or with `intersection`.
1312 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
1313 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
1314 auto structOp = fnOp->getParentOfType<StructDefOp>();
1315 FuncDefOp computeFn = structOp.getComputeFuncOp();
1316 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
1317 Lattice *computeEntryLattice = getLatticeElement(computeArg);
1318
1319 SourceRef ref(computeArg);
1320 ExpressionValue newArgVal(
1321 getOrCreateSymbol(ref), newInterval,
1322 trackUnreducedIntervals ? std::optional<UnreducedInterval>(newInterval.firstUnreduced())
1323 : std::nullopt
1324 );
1325 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
1326 }
1327 }
1328
1329 // Now we descend into val's operands, if it has any.
1330 Operation *definingOp = val.getDefiningOp();
// Block arguments have no defining op, so there is nothing further to refine.
1331 if (!definingOp) {
1332 propagateIfChanged(valLattice, changed);
1333 return;
1334 }
1335
1336 const Field &f = field.get();
1337
1338 // This is a rules-based operation. If we have a rule for a given operation,
1339 // then we can make some kind of update, otherwise we leave the intervals
1340 // as is.
1341 // - First we'll define all the rules so the type switch can be less messy
1342
1343 // cmp.<pred> restricts each side of the comparison if the result is known.
1344 auto cmpCase = [&](CmpOp cmpOp) {
1345 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
1346 // either "true" (1) or "false" (0).
1347 // -- In the case of a contradictory circuit, however, the cmp result is allowed
1348 // to be empty.
1349 ensure(
1350 newInterval.isBoolean() || newInterval.isEmpty(),
1351 "new interval for CmpOp is not boolean or empty"
1352 );
1353 if (!newInterval.isDegenerate()) {
1354 // The comparison result is unknown, so we can't update the operand ranges
1355 return;
1356 }
1357
1358 bool cmpTrue = newInterval.rhs() == f.one();
1359
1360 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
1361 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
1362 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1363 rhsExpr = rhsLat->getValue().getScalarValue();
1364
1365 Interval newLhsInterval, newRhsInterval;
1366 const Interval &lhsInterval = lhsExpr.getInterval();
1367 const Interval &rhsInterval = rhsExpr.getInterval();
1368
1369 FeltCmpPredicate pred = cmpOp.getPredicate();
1370 // predicate cases
// Each helper folds a predicate together with the known result into the relation that
// actually holds between lhs and rhs (e.g. "NE is false" is treated as "EQ").
1371 auto eqCase = [&]() {
1372 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1373 (pred == FeltCmpPredicate::NE && !cmpTrue);
1374 };
1375 auto neCase = [&]() {
1376 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1377 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1378 };
1379 auto ltCase = [&]() {
1380 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1381 (pred == FeltCmpPredicate::GE && !cmpTrue);
1382 };
1383 auto leCase = [&]() {
1384 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1385 (pred == FeltCmpPredicate::GT && !cmpTrue);
1386 };
1387 auto gtCase = [&]() {
1388 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1389 (pred == FeltCmpPredicate::LE && !cmpTrue);
1390 };
1391 auto geCase = [&]() {
1392 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1393 (pred == FeltCmpPredicate::LT && !cmpTrue);
1394 };
1395
1396 // new intervals based on case
1397 if (eqCase()) {
1398 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1399 } else if (neCase()) {
// Interval difference can only remove a known single point, so inequality is only
// exploited when at least one side is a single value.
1400 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1401 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
1402 // an empty value range.
1403 newLhsInterval = newRhsInterval = Interval::Empty(f);
1404 } else if (lhsInterval.isDegenerate()) {
1405 // rhs must not overlap with lhs
1406 newLhsInterval = lhsInterval;
1407 newRhsInterval = rhsInterval.difference(lhsInterval);
1408 } else if (rhsInterval.isDegenerate()) {
1409 // lhs must not overlap with rhs
1410 newLhsInterval = lhsInterval.difference(rhsInterval);
1411 newRhsInterval = rhsInterval;
1412 } else {
1413 // Leave unchanged
1414 newLhsInterval = lhsInterval;
1415 newRhsInterval = rhsInterval;
1416 }
// Ordering relations are applied on the unreduced view and the result is reduced back
// into the field.
// NOTE(review): rhs is narrowed with GEPart for LT, GTPart for LE, LEPart for GT and
// LTPart for GE. If `x.computeXXPart(y)` keeps the part of x satisfying `x XX y`, then
// the LE/GE pairings are stricter than the predicate (e.g. lhs <= rhs still permits
// rhs == lhs), and the LT/GT pairings are looser. Verify against the UnreducedInterval
// definitions before relying on this.
1417 } else if (ltCase()) {
1418 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1419 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1420 } else if (leCase()) {
1421 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1422 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1423 } else if (gtCase()) {
1424 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1425 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1426 } else if (geCase()) {
1427 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1428 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1429 } else {
1430 cmpOp->emitWarning("unhandled cmp predicate").report();
1431 return;
1432 }
1433
1434 // Now we recurse to each operand
1435 applyInterval(cmpOp, lhs, newLhsInterval);
1436 applyInterval(cmpOp, rhs, newRhsInterval);
1437 };
1438
1439 // Multiplication cases:
1440 // - If the result of a multiplication is non-zero, then both operands must be
1441 // non-zero.
1442 // - If one operand is a constant, we can propagate the new interval when multiplied
1443 // by the multiplicative inverse of the constant.
1444 auto mulCase = [&](MulFeltOp mulOp) {
1445 // We check for the constant case first.
// NOTE(review): `latVal` below is never used.
// NOTE(review): isZero() checks the raw attribute value. getConst reduces felt
// constants into the field, which suggests they may be stored unreduced. If so, a
// multiple of the modulus would pass this check; confirm f.inv handles that case.
1446 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
1447 auto latVal = getLatticeElement(multiplicand)->getValue().getScalarValue();
1448 APInt constVal = constOperand.getValue();
1449 if (constVal.isZero()) {
1450 // There's no inverse for zero, so we do nothing.
1451 return;
1452 }
// c * x in I  implies  x in I * c^-1.
1453 Interval updatedInterval = newInterval * Interval::Degenerate(f, f.inv(constVal));
1454 applyInterval(mulOp, multiplicand, updatedInterval);
1455 };
1456
1457 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
1458
1459 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
1460 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
1461 // If both are consts, we don't need to do anything
1462 if (lhsConstOp && rhsConstOp) {
1463 return;
1464 } else if (lhsConstOp) {
1465 constCase(lhsConstOp, rhs);
1466 return;
1467 } else if (rhsConstOp) {
1468 constCase(rhsConstOp, lhs);
1469 return;
1470 }
1471
1472 // Otherwise, try to propagate non-zero information.
1473 auto zeroInt = Interval::Degenerate(f, f.zero());
1474 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1475 // The multiplication may be zero, so we can't reduce the operands to be non-zero
1476 return;
1477 }
1478
1479 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
1480 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1481 rhsExpr = rhsLat->getValue().getScalarValue();
1482 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1483 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1484 applyInterval(mulOp, lhs, newLhsInterval);
1485 applyInterval(mulOp, rhs, newRhsInterval);
1486 };
1487
// Addition: lhs + rhs in N  implies  lhs in N - rhs  and  rhs in N - lhs, each
// intersected with the operand's current interval.
1488 auto addCase = [&](AddFeltOp addOp) {
1489 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
1490 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
1491 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1492 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1493
1494 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1495
1496 Interval derivedLhsInt = newInterval - currRhsInt;
1497 Interval derivedRhsInt = newInterval - currLhsInt;
1498
1499 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1500 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1501
1502 applyInterval(addOp, lhs, finalLhsInt);
1503 applyInterval(addOp, rhs, finalRhsInt);
1504 };
1505
// Subtraction: lhs - rhs in N  implies  lhs in N + rhs  and  rhs in lhs - N, each
// intersected with the operand's current interval.
1506 auto subCase = [&](SubFeltOp subOp) {
1507 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
1508 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
1509 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1510 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1511
1512 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1513
1514 Interval derivedLhsInt = newInterval + currRhsInt;
1515 Interval derivedRhsInt = currLhsInt - newInterval;
1516
1517 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1518 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1519
1520 applyInterval(subOp, lhs, finalLhsInt);
1521 applyInterval(subOp, rhs, finalRhsInt);
1522 };
1523
// Select:
// - If the condition is known (degenerate 1 or 0), only the chosen branch is restricted.
// - Otherwise, if exactly one branch can still produce a value in `newInterval`, the
//   condition is pinned to select that branch and the branch is restricted.
// - If neither branch can, the condition becomes empty (contradiction).
1524 auto selectCase = [&](arith::SelectOp selectOp) {
1525 Value cond = selectOp.getCondition();
1526 Value trueVal = selectOp.getTrueValue();
1527 Value falseVal = selectOp.getFalseValue();
1528
1529 ExpressionValue condExpr = getLatticeElement(cond)->getValue().getScalarValue();
1530 ExpressionValue trueExpr = getLatticeElement(trueVal)->getValue().getScalarValue();
1531 ExpressionValue falseExpr = getLatticeElement(falseVal)->getValue().getScalarValue();
1532
1533 const Interval &condInterval = condExpr.getInterval();
1534 if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
1535 applyInterval(selectOp, trueVal, newInterval);
1536 return;
1537 }
1538 if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
1539 applyInterval(selectOp, falseVal, newInterval);
1540 return;
1541 }
1542
1543 Interval trueOverlap = trueExpr.getInterval().intersect(newInterval);
1544 Interval falseOverlap = falseExpr.getInterval().intersect(newInterval);
1545 bool truePossible = trueOverlap.isNotEmpty();
1546 bool falsePossible = falseOverlap.isNotEmpty();
1547
1548 if (truePossible && !falsePossible) {
1549 applyInterval(selectOp, cond, Interval::True(f));
1550 applyInterval(selectOp, trueVal, newInterval);
1551 return;
1552 }
1553 if (!truePossible && falsePossible) {
1554 applyInterval(selectOp, cond, Interval::False(f));
1555 applyInterval(selectOp, falseVal, newInterval);
1556 return;
1557 }
1558 if (!truePossible && !falsePossible) {
1559 applyInterval(selectOp, cond, Interval::Empty(f));
1560 }
1561 };
1562
// Member read: register this lattice as a reader of the (single) referenced SourceRef,
// and copy the refined value into every other lattice recorded as reading that ref.
1563 auto readmCase = [&](MemberReadOp) {
1564 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1565
1566 if (sourceRefVal.isSingleValue()) {
1567 const SourceRef &ref = sourceRefVal.getSingleValue();
1568 readResults[ref].insert(valLattice);
1569
1570 // Also propagate to all other member read results for this member
1571 for (Lattice *l : readResults[ref]) {
1572 if (l != valLattice) {
1573 propagateIfChanged(l, l->setValue(newLatticeVal));
1574 }
1575 }
1576 }
1577 };
1578
// Array read: same synchronization as member reads, keyed twice. It is keyed first by
// the element reference resolved from the access (using `valUser`), then by the
// SourceRef state of the read result itself.
1579 auto readArrCase = [&](ReadArrayOp) {
1580 auto arrayRef = getArrayAccessRef(valUser, llvm::cast<ReadArrayOp>(definingOp));
1581 if (succeeded(arrayRef)) {
1582 readResults[*arrayRef].insert(valLattice);
1583
1584 for (Lattice *l : readResults[*arrayRef]) {
1585 if (l != valLattice) {
1586 propagateIfChanged(l, l->setValue(newLatticeVal));
1587 }
1588 }
1589 }
1590
1591 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1592
1593 if (sourceRefVal.isSingleValue()) {
1594 const SourceRef &ref = sourceRefVal.getSingleValue();
1595 readResults[ref].insert(valLattice);
1596
1597 // Also propagate to all other read results for this same reference
1598 for (Lattice *l : readResults[ref]) {
1599 if (l != valLattice) {
1600 propagateIfChanged(l, l->setValue(newLatticeVal));
1601 }
1602 }
1603 }
1604 };
1605
1606 // For casts, just pass the interval along to the cast's operand.
1607 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
1608
1609 // - Apply the rules given the op.
1610 // NOTE: disabling clang-format for this because it makes the last case statement
1611 // look ugly.
1612 // clang-format off
1613 TypeSwitch<Operation *>(definingOp)
1614 .Case<CmpOp>([&](auto op) { cmpCase(op); })
1615 .Case<AddFeltOp>([&](auto op) { return addCase(op); })
1616 .Case<SubFeltOp>([&](auto op) { return subCase(op); })
1617 .Case<MulFeltOp>([&](auto op) { mulCase(op); })
1618 .Case<arith::SelectOp>([&](auto op) { selectCase(op); })
1619 .Case<MemberReadOp>([&](auto op){ readmCase(op); })
1620 .Case<ReadArrayOp>([&](auto op){ readArrCase(op); })
1621 .Case<IntToFeltOp, FeltToIndexOp>([&](auto op) { castCase(op); })
1622 .Default([&](Operation *) { });
1623 // clang-format on
1624
1625 // Propagate after recursion to avoid having recursive calls unset the value.
1626 propagateIfChanged(valLattice, changed);
1627}
1628
1629FailureOr<std::pair<DenseSet<Value>, Interval>>
1630IntervalDataFlowAnalysis::getGeneralizedDecompInterval(
1631 Operation * /*baseOp*/, Value lhs, Value rhs
1632) {
1633 auto isZeroConst = [this](Value v) {
1634 Operation *op = v.getDefiningOp();
1635 if (!op) {
1636 return false;
1637 }
1638 if (!isConstOp(op)) {
1639 return false;
1640 }
1641 return getConst(op) == field.get().zero();
1642 };
1643 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1644 Value exprTree = nullptr;
1645 if (lhsIsZero && !rhsIsZero) {
1646 exprTree = rhs;
1647 } else if (!lhsIsZero && rhsIsZero) {
1648 exprTree = lhs;
1649 } else {
1650 return failure();
1651 }
1652
1653 // We now explore the expression tree for multiplications of subtractions/signal values.
1654 std::optional<SourceRef> signalRef = std::nullopt;
1655 DenseSet<Value> signalVals;
1656 SmallVector<DynamicAPInt> consts;
1657 SmallVector<Value> frontier {exprTree};
1658 while (!frontier.empty()) {
1659 Value v = frontier.back();
1660 frontier.pop_back();
1661 Operation *op = v.getDefiningOp();
1662
1663 FeltConstantOp c;
1664 Value signalVal;
1665 auto handleRefValue = [this, &signalRef, &signalVal, &signalVals]() {
1666 SourceRefLatticeValue refSet = getSourceRefState(signalVal);
1667 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1668 return failure();
1669 }
1670 SourceRef r = refSet.getSingleValue();
1671 if (signalRef.has_value() && signalRef.value() != r) {
1672 return failure();
1673 } else if (!signalRef.has_value()) {
1674 signalRef = r;
1675 }
1676 signalVals.insert(signalVal);
1677 return success();
1678 };
1679
1680 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
1681 if (op && matchPattern(op, subPattern)) {
1682 if (failed(handleRefValue())) {
1683 return failure();
1684 }
1685 auto constInt = APSInt(c.getValue());
1686 consts.push_back(field.get().reduce(constInt));
1687 continue;
1688 } else if (m_RefValue(&signalVal).match(v)) {
1689 if (failed(handleRefValue())) {
1690 return failure();
1691 }
1692 consts.push_back(field.get().zero());
1693 continue;
1694 }
1695
1696 Value a, b;
1697 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1698 if (op && matchPattern(op, mulPattern)) {
1699 frontier.push_back(a);
1700 frontier.push_back(b);
1701 continue;
1702 }
1703
1704 return failure();
1705 }
1706
1707 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1708 // we will create a larger range of [0, 4], since we don't support multiple intervals.
1709 std::sort(consts.begin(), consts.end());
1710 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1711 return std::make_pair(std::move(signalVals), iv);
1712}
1713
1714/* StructIntervals */
1715
1717 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx
1718) {
1719 SymbolTableCollection tables;
1720
1721 auto computeIntervalsImpl =
1722 [&solver, &am, &ctx, &tables, this](
1723 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &memberRanges,
1724 llvm::MapVector<SourceRef, UnreducedInterval> &memberUnreducedRanges,
1725 llvm::SetVector<ExpressionValue> & /*solverConstraints*/
1726 ) {
1727 auto setUnreducedRange =
1728 [&memberUnreducedRanges](const SourceRef &ref, const UnreducedInterval &interval) {
1729 memberUnreducedRanges.erase(ref);
1730 memberUnreducedRanges.insert({ref, interval});
1731 };
1732 // Since every lattice value does not contain every value, we will traverse
1733 // the function backwards (from most up-to-date to least-up-to-date lattices)
1734 // searching for the source refs. Once a source ref is found, we remove it
1735 // from the search set.
1736
1737 SourceRefSet searchSet;
1738 for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) {
1739 // We only want to compute intervals for field elements and not composite types.
1740 if (!ref.isScalar()) {
1741 continue;
1742 }
1743 searchSet.insert(ref);
1744 }
1745 SourceRefSet functionRefs = searchSet;
1746
1747 auto mergeInterval = [&memberRanges, &memberUnreducedRanges](
1748 const SourceRef &ref, const Interval &interval,
1749 std::optional<UnreducedInterval> unreducedInterval = std::nullopt
1750 ) {
1751 auto *existing = memberRanges.find(ref);
1752 if (existing != memberRanges.end()) {
1753 Interval mergedInterval = existing->second.intersect(interval);
1754 bool intervalChanged = mergedInterval != existing->second;
1755 existing->second = mergedInterval;
1756
1757 if (unreducedInterval.has_value()) {
1758 auto *existingUnreduced = memberUnreducedRanges.find(ref);
1759 if (existingUnreduced != memberUnreducedRanges.end()) {
1760 existingUnreduced->second = existingUnreduced->second.intersect(*unreducedInterval);
1761 } else {
1762 memberUnreducedRanges.insert({ref, *unreducedInterval});
1763 }
1764 } else if (intervalChanged) {
1765 memberUnreducedRanges.erase(ref);
1766 }
1767 return;
1768 }
1769
1770 memberRanges[ref] = interval;
1771 if (unreducedInterval.has_value()) {
1772 memberUnreducedRanges.insert({ref, *unreducedInterval});
1773 }
1774 };
1775
1776 // Iterate over arguments
1777 for (BlockArgument arg : fn.getArguments()) {
1778 SourceRef ref {arg};
1779 if (searchSet.erase(ref)) {
1780 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(arg);
1781 // If we never referenced this argument, use a default value
1782 ExpressionValue expr = lattice->getValue().getScalarValue();
1783 if (!expr.getExpr()) {
1784 expr = expr.withInterval(Interval::Entire(ctx.getField()));
1785 if (ctx.doTrackUnreducedIntervals()) {
1786 expr = expr.withUnreducedInterval(expr.getInterval().firstUnreduced());
1787 }
1788 }
1789 memberRanges[ref] = expr.getInterval();
1790 if (expr.hasUnreducedInterval()) {
1791 setUnreducedRange(ref, expr.getUnreducedInterval());
1792 }
1793 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1794 }
1795 }
1796
1797 // Aggregate all read intervals for a ref. A single ref may be read at multiple program
1798 // points with different precision, so picking an arbitrary lattice from the DenseSet is
1799 // nondeterministic. Joining preserves the overapproximation regardless of iteration order.
1800 for (const auto &[ref, lattices] : ctx.intervalDFA->getReadResults()) {
1801 if (!lattices.empty() && searchSet.erase(ref)) {
1802 Interval joinedInterval = Interval::Empty(ctx.getField());
1803 std::optional<UnreducedInterval> joinedUnreduced = std::nullopt;
1804 bool sawFirst = false;
1805 for (const IntervalAnalysisLattice *lattice : lattices) {
1806 const ExpressionValue &expr = lattice->getValue().getScalarValue();
1807 joinedInterval = joinedInterval.join(expr.getInterval());
1808 if (!sawFirst) {
1809 joinedUnreduced = expr.getOptionalUnreducedInterval();
1810 sawFirst = true;
1811 } else {
1812 joinedUnreduced =
1813 mergeUnreducedIntervals(joinedUnreduced, expr.getOptionalUnreducedInterval());
1814 }
1815 }
1816 memberRanges[ref] = joinedInterval;
1817 if (joinedUnreduced.has_value()) {
1818 setUnreducedRange(ref, *joinedUnreduced);
1819 }
1820 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1821 }
1822 }
1823
1824 for (const auto &[ref, val] : ctx.intervalDFA->getWriteResults()) {
1825 if (searchSet.erase(ref)) {
1826 memberRanges[ref] = val.getInterval();
1827 if (val.hasUnreducedInterval()) {
1828 setUnreducedRange(ref, val.getUnreducedInterval());
1829 }
1830 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1831 }
1832 }
1833
1834 // Child constrain calls refine parent-visible storage, but only after the callee
1835 // summary is translated through the call operands. If translation cannot prove
1836 // which parent ref owns a child interval, the local overapproximation remains.
1837 if (fn.isStructConstrain()) {
1838 auto mergeChildConstrainIntervals = [&](CallOp fnCall) {
1839 if (!dataflow::isOperationLive(solver, fnCall.getOperation())) {
1840 return;
1841 }
1842
1843 auto res = resolveCallableSilently<FuncDefOp>(tables, fnCall);
1844 if (failed(res)) {
1845 return;
1846 }
1847
1848 FuncDefOp calledFn = res->get();
1849 if (!calledFn.isStructConstrain()) {
1850 return;
1851 }
1852
1853 auto calledStruct = calledFn->getParentOfType<StructDefOp>();
1854 if (calledStruct == structDef) {
1855 return;
1856 }
1857
1858 auto &childAnalysis = am.getChildAnalysis<StructIntervalAnalysis>(calledStruct);
1859 if (childAnalysis.inProgress(ctx)) {
1860 return;
1861 }
1862 if (!childAnalysis.constructed(ctx)) {
1863 ensure(
1864 succeeded(childAnalysis.runAnalysis(solver, am, ctx)),
1865 "could not construct interval analysis for child struct"
1866 );
1867 }
1868
1869 // Translate callee argument refs into parent refs and capture scalar call-site intervals
1870 // that can refine direct equality groups inside the child constrain function.
1871 SourceRefRemappings identityTranslations;
1872 llvm::MapVector<SourceRef, Interval> callOperandIntervals;
1873 for (unsigned i = 0; i < calledFn.getNumArguments(); i++) {
1874 SourceRef prefix(calledFn.getArgument(i));
1875 Value operand = fnCall.getOperand(i);
1876 std::optional<SourceRefLatticeValue> identityVal =
1877 getIdentitySourceRefState(solver, operand);
1878 if (identityVal.has_value()) {
1879 identityTranslations.push_back({prefix, *identityVal});
1880 }
1881
1882 if (!llvm::isa<ArrayType, StructType>(operand.getType())) {
1883 const IntervalAnalysisLattice *lattice =
1884 solver.lookupState<IntervalAnalysisLattice>(operand);
1885 if (lattice != nullptr) {
1886 const ExpressionValue &expr = lattice->getValue().getScalarValue();
1887 callOperandIntervals[prefix] = expr.getInterval();
1888 }
1889 }
1890 }
1891
1892 const StructIntervals &childIntervals = childAnalysis.getResult(ctx);
1893 const auto &constrainIntervals = childIntervals.getConstrainIntervals();
1894 const auto &constrainUnreducedIntervals = childIntervals.getConstrainUnreducedIntervals();
1895 for (const auto &[childRef, childInterval] : constrainIntervals) {
1896 auto translatedRefs = translateRef(childRef, identityTranslations);
1897 if (failed(translatedRefs)) {
1898 continue;
1899 }
1900
1901 std::optional<UnreducedInterval> childUnreduced = std::nullopt;
1902 if (auto *childUnreducedIt = constrainUnreducedIntervals.find(childRef);
1903 childUnreducedIt != constrainUnreducedIntervals.end()) {
1904 childUnreduced = childUnreducedIt->second;
1905 }
1906
1907 SourceRefSet uniqueTranslatedRefs;
1908 for (const SourceRef &translatedRef : *translatedRefs) {
1909 uniqueTranslatedRefs.insert(translatedRef);
1910 }
1911 if (uniqueTranslatedRefs.size() != 1) {
1912 continue;
1913 }
1914
1915 const SourceRef &translatedRef = *uniqueTranslatedRefs.begin();
1916 if (functionRefs.contains(translatedRef)) {
1917 mergeInterval(translatedRef, childInterval, childUnreduced);
1918 searchSet.erase(translatedRef);
1919 }
1920 }
1921
1922 // Direct equalities in the child can combine child summaries, call operand intervals, and
1923 // existing parent intervals before merging back into each translated parent ref.
1924 llvm::EquivalenceClasses<SourceRef> directEqRefs =
1925 collectDirectEqualityRefs(solver, calledFn);
1926 for (auto leaderIt = directEqRefs.begin(); leaderIt != directEqRefs.end(); ++leaderIt) {
1927 if (!leaderIt->isLeader()) {
1928 continue;
1929 }
1930
1931 llvm::MapVector<SourceRef, Interval> translatedEqRefs;
1932 Interval contextualInterval = Interval::Entire(ctx.getField());
1933 bool hasInterval = false;
1934 bool ambiguousTranslation = false;
1935
1936 for (auto memberIt = directEqRefs.member_begin(leaderIt);
1937 memberIt != directEqRefs.member_end(); ++memberIt) {
1938 Interval memberInterval = Interval::Entire(ctx.getField());
1939 if (const auto *childIntervalIt = constrainIntervals.find(*memberIt);
1940 childIntervalIt != constrainIntervals.end()) {
1941 memberInterval = memberInterval.intersect(childIntervalIt->second);
1942 }
1943 if (auto *callOperandIt = callOperandIntervals.find(*memberIt);
1944 callOperandIt != callOperandIntervals.end()) {
1945 memberInterval = memberInterval.intersect(callOperandIt->second);
1946 contextualInterval = contextualInterval.intersect(memberInterval);
1947 hasInterval = true;
1948 }
1949
1950 auto translatedRefs = translateRef(*memberIt, identityTranslations);
1951 if (failed(translatedRefs)) {
1952 continue;
1953 }
1954
1955 SourceRefSet uniqueTranslatedRefs;
1956 for (const SourceRef &translatedRef : *translatedRefs) {
1957 uniqueTranslatedRefs.insert(translatedRef);
1958 }
1959 if (uniqueTranslatedRefs.size() != 1) {
1960 ambiguousTranslation = true;
1961 break;
1962 }
1963
1964 const SourceRef &translatedRef = *uniqueTranslatedRefs.begin();
1965 if (!functionRefs.contains(translatedRef)) {
1966 continue;
1967 }
1968
1969 if (auto *parentIntervalIt = memberRanges.find(translatedRef);
1970 parentIntervalIt != memberRanges.end()) {
1971 memberInterval = memberInterval.intersect(parentIntervalIt->second);
1972 }
1973
1974 translatedEqRefs[translatedRef] = memberInterval;
1975 contextualInterval = contextualInterval.intersect(memberInterval);
1976 hasInterval = true;
1977 }
1978
1979 if (ambiguousTranslation || !hasInterval || translatedEqRefs.empty()) {
1980 continue;
1981 }
1982
1983 for (const auto &[translatedRef, _] : translatedEqRefs) {
1984 mergeInterval(translatedRef, contextualInterval);
1985 searchSet.erase(translatedRef);
1986 }
1987 }
1988 };
1989
1990 fn.walk(mergeChildConstrainIntervals);
1991 }
1992
1993 // For all unfound refs, default to the entire range.
1994 for (const auto &ref : searchSet) {
1995 memberRanges[ref] = Interval::Entire(ctx.getField());
1996 if (ctx.doTrackUnreducedIntervals()) {
1997 setUnreducedRange(ref, memberRanges[ref].firstUnreduced());
1998 }
1999 }
2000
2001 // Sort the outputs since we assembled things out of order.
2002 //
2003 // `llvm::MapVector` maintains an internal key -> index map. Sorting it in
2004 // place corrupts lookup semantics because the backing vector is reordered
2005 // without rebuilding that map. Reinsert into a fresh MapVector instead.
2006 llvm::SmallVector<std::pair<SourceRef, Interval>> sortedRanges;
2007 sortedRanges.reserve(memberRanges.size());
2008 for (const auto &[ref, interval] : memberRanges) {
2009 sortedRanges.emplace_back(ref, interval);
2010 }
2011 llvm::sort(sortedRanges, [](const auto &a, const auto &b) { return a.first < b.first; });
2012 llvm::SmallVector<std::pair<SourceRef, UnreducedInterval>> sortedUnreducedRanges;
2013 sortedUnreducedRanges.reserve(memberUnreducedRanges.size());
2014 for (const auto &[ref, interval] : memberUnreducedRanges) {
2015 sortedUnreducedRanges.emplace_back(ref, interval);
2016 }
2017 llvm::sort(sortedUnreducedRanges, [](const auto &a, const auto &b) {
2018 return a.first < b.first;
2019 });
2020 memberRanges.clear();
2021 memberUnreducedRanges.clear();
2022 for (auto &[ref, interval] : sortedRanges) {
2023 memberRanges[ref] = interval;
2024 }
2025 for (auto &[ref, interval] : sortedUnreducedRanges) {
2026 memberUnreducedRanges.insert({ref, interval});
2027 }
2028 };
2029
2030 if (auto computeFn = structDef.getComputeFuncOp()) {
2031 computeIntervalsImpl(
2032 computeFn, computeMemberRanges, computeMemberUnreducedRanges, computeSolverConstraints
2033 );
2034 }
2035 if (auto constrainFn = structDef.getConstrainFuncOp()) {
2036 computeIntervalsImpl(
2037 constrainFn, constrainMemberRanges, constrainMemberUnreducedRanges,
2038 constrainSolverConstraints
2039 );
2040 }
2041
2042 return success();
2043}
2044
2046 mlir::raw_ostream &os, bool withConstraints, bool printCompute, bool printUnreduced
2047) const {
2048 auto writeIntervals =
2049 [&os, &withConstraints, &printUnreduced](
2050 const char *fnName, const llvm::MapVector<SourceRef, Interval> &memberRanges,
2051 const llvm::MapVector<SourceRef, UnreducedInterval> &memberUnreducedRanges,
2052 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
2053 ) {
2054 int indent = 4;
2055 if (printName) {
2056 os << '\n';
2057 os.indent(indent) << fnName << " {";
2058 indent += 4;
2059 }
2060
2061 if (memberRanges.empty()) {
2062 os << "}\n";
2063 return;
2064 }
2065
2066 for (const auto &[ref, interval] : memberRanges) {
2067 os << '\n';
2068 os.indent(indent) << ref << " in " << interval;
2069 if (printUnreduced) {
2070 const auto *unreducedIt = memberUnreducedRanges.find(ref);
2071 if (unreducedIt != memberUnreducedRanges.end()) {
2072 os << " ( " << unreducedIt->second << " )";
2073 }
2074 }
2075 }
2076
2077 if (withConstraints) {
2078 os << "\n\n";
2079 os.indent(indent) << "Solver Constraints { ";
2080 if (solverConstraints.empty()) {
2081 os << "}\n";
2082 } else {
2083 for (const auto &e : solverConstraints) {
2084 os << '\n';
2085 os.indent(indent + 4);
2086 e.getExpr()->print(os);
2087 }
2088 os << '\n';
2089 os.indent(indent) << '}';
2090 }
2091 }
2092
2093 if (printName) {
2094 os << '\n';
2095 os.indent(indent - 4) << '}';
2096 }
2097 };
2098
2099 os << "StructIntervals { ";
2100 if (constrainMemberRanges.empty() && (!printCompute || computeMemberRanges.empty())) {
2101 os << "}\n";
2102 return;
2103 }
2104
2105 if (printCompute) {
2106 writeIntervals(
2107 FUNC_NAME_COMPUTE, computeMemberRanges, computeMemberUnreducedRanges,
2108 computeSolverConstraints, printCompute
2109 );
2110 }
2111 writeIntervals(
2112 FUNC_NAME_CONSTRAIN, constrainMemberRanges, constrainMemberUnreducedRanges,
2113 constrainSolverConstraints, printCompute
2114 );
2115
2116 os << "\n}\n";
2117}
2118
2119} // namespace llzk
Tracks a solver expression and an interval range for that expression.
ExpressionValue withUnreducedInterval(const UnreducedInterval &newUnreducedInterval) const
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
const std::optional< UnreducedInterval > & getOptionalUnreducedInterval() const
ExpressionValue withOptionalUnreducedInterval(std::optional< UnreducedInterval > newUnreducedInterval) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
bool hasUnreducedInterval() const
const Field & getField() const
const UnreducedInterval & getUnreducedInterval() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:81
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:72
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:84
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
Definition Field.h:107
llvm::SMTExprRef createSymbol(const llvm::SMTSolverRef &solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:112
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
Definition Field.h:87
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
Intervals over a finite field.
Definition Intervals.h:206
bool isEmpty() const
Definition Intervals.h:314
static Interval True(const Field &f)
Definition Intervals.h:225
llvm::DynamicAPInt rhs() const
Definition Intervals.h:339
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:227
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
Definition Intervals.h:229
bool isDegenerate() const
Definition Intervals.h:316
bool isNotEmpty() const
Definition Intervals.h:315
static Interval False(const Field &f)
Definition Intervals.h:223
llvm::DynamicAPInt lhs() const
Definition Intervals.h:338
Interval join(const Interval &rhs) const
Union.
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
Defines an index into an LLZK object.
Definition SourceRef.h:42
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:132
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
Definition SourceRef.h:340
bool isScalar() const
Definition SourceRef.h:239
std::vector< SourceRefIndex > Path
Definition SourceRef.h:134
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::Type getType() const
const llvm::MapVector< SourceRef, Interval > & getConstrainIntervals() const
const llvm::MapVector< SourceRef, UnreducedInterval > & getConstrainUnreducedIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false, bool printUnreduced=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
Definition Intervals.h:26
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:63
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:86
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:78
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:23
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:71
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::Type getElementType() const
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
Definition Ops.h.inc:882
bool isOperationLive(DataFlowSolver &solver, Operation *op)
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
ExpressionValue sintMod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
RefValueCapture m_RefValue()
Definition Matchers.h:69
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
std::vector< std::pair< SourceRef, SourceRefLatticeValue > > SourceRefRemappings
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
Interval signedMod(const Interval &lhs, const Interval &rhs)
Computes signed integer remainder with possibly non-Degenerate divisors.
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
DynamicAPInt toDynamicAPInt(StringRef str)
llvm::SMTExprRef createFieldInverseExpr(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val, StringRef suffix="")
ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-Degenerate divisors.
ExpressionValue div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallableSilently(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Resolve a callable without emitting a diagnostic for missing top-level symbols.
ConstantCapture m_Constant()
Definition Matchers.h:89
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divis...
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue selectValue(const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal, const ExpressionValue &falseVal)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA