LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
IntervalAnalysis.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysis.cpp - Interval analysis implementation -*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
15#include "llzk/Util/Debug.h"
17
18#include <mlir/Dialect/SCF/IR/SCF.h>
19
20#include <llvm/ADT/TypeSwitch.h>
21
22using namespace mlir;
23
24namespace llzk {
25
26using namespace array;
27using namespace boolean;
28using namespace cast;
29using namespace component;
30using namespace constrain;
31using namespace felt;
32using namespace function;
33
34/* ExpressionValue */
35
36llvm::SMTExprRef createFieldInverseExpr(
37 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val,
38 StringRef suffix = ""
39) {
40 const Field &field = val.getField();
41 const Interval &iv = val.getInterval();
42 if (iv.isDegenerate() && iv.lhs() != field.zero()) {
43 DynamicAPInt invVal = field.inv(iv.lhs());
44 return solver->mkBitvector(toAPSInt(invVal), field.bitWidth());
45 }
46
47 // The definition of an inverse X^-1 is Y s.t. XY % prime = 1.
48 // To create this expression, we create a new symbol for Y and add the
49 // XY % prime = 1 constraint to the solver.
50 std::string symName = buildStringViaInsertionOp(*op);
51 if (!suffix.empty()) {
52 symName += suffix.str();
53 }
54 llvm::SMTExprRef invSym = field.createSymbol(solver, symName.c_str());
55 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.bitWidth());
56 llvm::SMTExprRef prime = solver->mkBitvector(toAPSInt(field.prime()), field.bitWidth());
57 llvm::SMTExprRef mult = solver->mkBVMul(val.getExpr(), invSym);
58 llvm::SMTExprRef mod = solver->mkBVURem(mult, prime);
59 llvm::SMTExprRef constraint = solver->mkEqual(mod, one);
60 solver->addConstraint(constraint);
61 return invSym;
62}
63
65 if (expr == nullptr && rhs.expr == nullptr) {
66 return i == rhs.i;
67 }
68 if (expr == nullptr || rhs.expr == nullptr) {
69 return false;
70 }
71 return i == rhs.i && *expr == *rhs.expr;
72}
73
75boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth) {
76 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
77 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
78 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.getExpr(), one, zero);
79 return expr.withExpression(boolToFeltConv);
80}
81
83 const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal,
84 const ExpressionValue &falseVal
85) {
86 const Field &f = trueVal.getField();
87 const Interval &condInterval = cond.getInterval();
88 Interval resultInterval;
89 if (condInterval.isEmpty()) {
90 resultInterval = Interval::Empty(f);
91 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
92 resultInterval = trueVal.getInterval();
93 } else if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
94 resultInterval = falseVal.getInterval();
95 } else {
96 resultInterval = trueVal.getInterval().join(falseVal.getInterval());
97 }
98 llvm::SMTExprRef resultExpr =
99 solver->mkIte(cond.getExpr(), trueVal.getExpr(), falseVal.getExpr());
100 return ExpressionValue(resultExpr, resultInterval);
101}
102
104 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
105) {
106 Interval res = lhs.i.intersect(rhs.i);
107 const auto *exprEq = solver->mkEqual(lhs.expr, rhs.expr);
108 return ExpressionValue(exprEq, res);
109}
110
112add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
113 ExpressionValue res;
114 res.i = lhs.i + rhs.i;
115 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
116 return res;
117}
118
120sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
121 ExpressionValue res;
122 res.i = lhs.i - rhs.i;
123 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
124 return res;
125}
126
128mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
129 ExpressionValue res;
130 res.i = lhs.i * rhs.i;
131 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
132 return res;
133}
134
136div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
137 const ExpressionValue &rhs) {
138 ExpressionValue res;
139 auto divRes = feltDiv(lhs.i, rhs.i);
140 if (failed(divRes)) {
141 const Field &field = lhs.getField();
142 const Interval &rhsInterval = rhs.getInterval();
143 Interval zero = Interval::Degenerate(field, field.zero());
144 if (!rhsInterval.isDegenerate()) {
145 if (rhsInterval.intersect(zero).isNotEmpty()) {
146 op->emitWarning(
147 "non-degenerate felt.div divisors are not tracked precisely, and the divisor may "
148 "contain zero. Range of division result will be treated as unbounded."
149 )
150 .report();
151 } else {
152 op->emitWarning(
153 "non-degenerate felt.div divisors are not tracked precisely because precise field "
154 "division over intervals would require enumerating divisor inverses. Range of "
155 "division result will be treated as unbounded."
156 )
157 .report();
158 }
159 } else {
160 op->emitWarning(
161 "divisor is zero, leading to a divide-by-zero error. Range of division result will "
162 "be treated as unbounded."
163 )
164 .report();
165 }
166 res.i = Interval::Entire(lhs.getField());
167 } else {
168 res.i = *divRes;
169 }
170 llvm::SMTExprRef invExpr = createFieldInverseExpr(solver, op, rhs, ".div_inv");
171 res.expr = solver->mkBVMul(lhs.expr, invExpr);
172 return res;
173}
174
176 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
177 const ExpressionValue &rhs
178) {
179 ExpressionValue res;
180 auto divRes = unsignedIntDiv(lhs.i, rhs.i);
181 if (failed(divRes)) {
182 op->emitWarning(
183 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
184 " Range of division result will be treated as unbounded."
185 )
186 .report();
187 res.i = Interval::Entire(lhs.getField());
188 } else {
189 res.i = *divRes;
190 }
191 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
192 return res;
193}
194
196 const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs,
197 const ExpressionValue &rhs
198) {
199 ExpressionValue res;
200 auto divRes = signedIntDiv(lhs.i, rhs.i);
201 if (failed(divRes)) {
202 op->emitWarning(
203 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
204 " Range of division result will be treated as unbounded."
205 )
206 .report();
207 res.i = Interval::Entire(lhs.getField());
208 } else {
209 res.i = *divRes;
210 }
211 res.expr = solver->mkBVSDiv(lhs.expr, rhs.expr);
212 return res;
213}
214
215ExpressionValue
216mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
217 ExpressionValue res;
218 res.i = lhs.i % rhs.i;
219 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
220 return res;
221}
222
224bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
225 ExpressionValue res;
226 res.i = lhs.i & rhs.i;
227 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
228 return res;
229}
230
232bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
233 ExpressionValue res;
234 res.i = lhs.i | rhs.i;
235 res.expr = solver->mkBVOr(lhs.expr, rhs.expr);
236 return res;
237}
238
240bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
241 if (lhs.isBoolSort(solver) && rhs.isBoolSort(solver)) {
242 return boolXor(solver, lhs, rhs);
243 }
244
245 ExpressionValue res;
246 res.i = lhs.i ^ rhs.i;
247 res.expr = solver->mkBVXor(lhs.expr, rhs.expr);
248 return res;
249}
250
252 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
253) {
254 ExpressionValue res;
255 res.i = lhs.i << rhs.i;
256 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
257 return res;
258}
259
261 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
262) {
263 ExpressionValue res;
264 res.i = lhs.i >> rhs.i;
265 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
266 return res;
267}
268
270cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs,
271 const ExpressionValue &rhs) {
272 ExpressionValue res;
273 const Field &f = lhs.getField();
274 // Default result is any boolean output for when we are unsure about the comparison result.
275 res.i = Interval::Boolean(f);
276 switch (op.getPredicate()) {
277 case FeltCmpPredicate::EQ:
278 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
279 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
280 res.i = lhs.i == rhs.i ? Interval::True(f) : Interval::False(f);
281 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
282 res.i = Interval::False(f);
283 }
284 break;
285 case FeltCmpPredicate::NE:
286 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
287 if (lhs.i.isDegenerate() && rhs.i.isDegenerate()) {
288 res.i = lhs.i != rhs.i ? Interval::True(f) : Interval::False(f);
289 } else if (lhs.i.intersect(rhs.i).isEmpty()) {
290 res.i = Interval::True(f);
291 }
292 break;
293 case FeltCmpPredicate::LT:
294 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
295 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
296 res.i = Interval::True(f);
297 }
298 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
299 res.i = Interval::False(f);
300 }
301 break;
302 case FeltCmpPredicate::LE:
303 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
304 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
305 res.i = Interval::True(f);
306 }
307 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
308 res.i = Interval::False(f);
309 }
310 break;
311 case FeltCmpPredicate::GT:
312 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
313 if (lhs.i.toUnreduced().computeLEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
314 res.i = Interval::True(f);
315 }
316 if (lhs.i.toUnreduced().computeGTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
317 res.i = Interval::False(f);
318 }
319 break;
320 case FeltCmpPredicate::GE:
321 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
322 if (lhs.i.toUnreduced().computeLTPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
323 res.i = Interval::True(f);
324 }
325 if (lhs.i.toUnreduced().computeGEPart(rhs.i.toUnreduced()).reduce(f).isEmpty()) {
326 res.i = Interval::False(f);
327 }
328 break;
329 }
330 return res;
331}
332
334boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
335 ExpressionValue res;
336 res.i = boolAnd(lhs.i, rhs.i);
337 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
338 return res;
339}
340
342boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
343 ExpressionValue res;
344 res.i = boolOr(lhs.i, rhs.i);
345 res.expr = solver->mkOr(lhs.expr, rhs.expr);
346 return res;
347}
348
350boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs) {
351 ExpressionValue res;
352 res.i = boolXor(lhs.i, rhs.i);
353 // There's no Xor, so we do (L || R) && !(L && R)
354 res.expr = solver->mkAnd(
355 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
356 );
357 return res;
358}
359
360ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
361 ExpressionValue res;
362 res.i = -val.i;
363 res.expr = solver->mkBVNeg(val.expr);
364 return res;
365}
366
367ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
368 ExpressionValue res;
369 res.i = ~val.i;
370 res.expr = solver->mkBVNot(val.expr);
371 return res;
372}
373
374ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val) {
375 ExpressionValue res;
376 res.i = boolNot(val.i);
377 res.expr = solver->mkNot(val.expr);
378 return res;
379}
380
382fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val) {
383 const Field &field = val.getField();
384 ExpressionValue res;
385 res.i = Interval::Entire(field);
386 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
387 .Case<InvFeltOp>([&](auto) {
388 return createFieldInverseExpr(solver, op, val);
389 }).Default([](Operation *unsupported) {
390 llvm::report_fatal_error(
391 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
392 );
393 return nullptr;
394 });
395
396 return res;
397}
398
399void ExpressionValue::print(mlir::raw_ostream &os) const {
400 if (expr) {
401 expr->print(os);
402 } else {
403 os << "<null expression>";
404 }
405
406 os << " ( interval: " << i << " )";
407}
408
409/* IntervalAnalysisLattice */
410
411ChangeResult IntervalAnalysisLattice::join(const AbstractSparseLattice & /*other*/) {
412 // The update logic is handled in visitOperation; we don't support a generic
413 // join operation, as it may override valid intervals.
414 return ChangeResult::NoChange;
415}
416
417ChangeResult IntervalAnalysisLattice::meet(const AbstractSparseLattice & /*other*/) {
418 // The update logic is handled in visitOperation; we don't support a generic
419 // meet operation, as it may override valid intervals.
420 return ChangeResult::NoChange;
421}
422
423void IntervalAnalysisLattice::print(mlir::raw_ostream &os) const {
424 os << "IntervalAnalysisLattice { " << val << " }";
425}
426
428 if (val == newVal) {
429 return ChangeResult::NoChange;
430 }
431 val = newVal;
432 return ChangeResult::Change;
433}
434
436 LatticeValue newVal(e);
437 return setValue(newVal);
438}
439
441 if (!constraints.contains(e)) {
442 constraints.insert(e);
443 return ChangeResult::Change;
444 }
445 return ChangeResult::NoChange;
446}
447
448/* IntervalDataFlowAnalysis */
449
450SourceRefLatticeValue IntervalDataFlowAnalysis::getSourceRefState(Value val) {
451 return SourceRefAnalysis::getValueState(_dataflowSolver, val);
452}
453
454std::vector<SourceRefIndex> IntervalDataFlowAnalysis::getArrayAccessIndices(
455 Operation *baseOp, ArrayAccessOpInterface arrayAccessOp
456) {
457 std::vector<SourceRefIndex> indices;
458 ArrayType arrayType = arrayAccessOp.getArrRefType();
459 size_t numIndices = arrayAccessOp.getIndices().size();
460 indices.reserve(numIndices);
461
462 for (size_t i = 0; i < numIndices; ++i) {
463 Value idxOperand = arrayAccessOp.getIndices()[i];
464 SourceRefLatticeValue idxVals = getSourceRefState(idxOperand);
465
466 // Only exact constant indices get tracked precisely.
467 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
468 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
469 } else {
470 auto lower = APInt::getZero(64);
471 APInt upper(64, arrayType.getDimSize(i));
472 indices.emplace_back(lower, upper);
473 }
474 }
475
476 return indices;
477}
478
479mlir::FailureOr<SourceRef> IntervalDataFlowAnalysis::getArrayAccessRef(
480 Operation *baseOp, ArrayAccessOpInterface arrayAccessOp
481) {
482 std::vector<SourceRefIndex> indices = getArrayAccessIndices(baseOp, arrayAccessOp);
483 Value arrayVal = arrayAccessOp.getArrRef();
484 if (auto blockArg = llvm::dyn_cast<BlockArgument>(arrayVal)) {
485 return SourceRef(blockArg, std::move(indices));
486 }
487 if (auto result = llvm::dyn_cast<OpResult>(arrayVal)) {
488 return SourceRef(result, std::move(indices));
489 }
490 return failure();
491}
492
493Interval IntervalDataFlowAnalysis::getRefInterval(const SourceRef &ref) {
494 if (auto it = writeResults.find(ref); it != writeResults.end()) {
495 return it->second.getInterval();
496 }
497
498 if (ref.isConstantInt()) {
499 auto constVal = ref.getConstantValue();
500 if (succeeded(constVal)) {
501 return Interval::Degenerate(field.get(), *constVal);
502 }
503 }
504
505 if (ref.isRooted() && ref.getPath().empty()) {
506 auto rootVal = ref.getRoot();
507 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
508 const ExpressionValue &rootExpr = getLatticeElement(*rootVal)->getValue().getScalarValue();
509 if (rootExpr.getExpr() != nullptr) {
510 return rootExpr.getInterval();
511 }
512 }
513 }
514
515 return getDefaultIntervalForType(ref.getType());
516}
517
518ExpressionValue IntervalDataFlowAnalysis::getRefValue(const SourceRef &ref, Value val) {
519 if (auto it = writeResults.find(ref); it != writeResults.end()) {
520 return it->second;
521 }
522 return createUnknownValue(val).withInterval(getRefInterval(ref));
523}
524
525void IntervalDataFlowAnalysis::recordRefWrite(
526 const SourceRef &writtenRef, const ExpressionValue &writeVal
527) {
528 ExpressionValue written = writeVal;
529
530 if (auto it = writeResults.find(writtenRef); it != writeResults.end()) {
531 const ExpressionValue &old = it->second;
532 Interval combinedWrite = old.getInterval().join(written.getInterval());
533 if (old.getExpr() != nullptr && written.getExpr() != nullptr &&
534 *old.getExpr() == *written.getExpr()) {
535 writeResults[writtenRef] = old.withInterval(combinedWrite);
536 } else {
537 llvm::SMTExprRef expr = getOrCreateSymbol(writtenRef);
538 writeResults[writtenRef] = ExpressionValue(expr, combinedWrite);
539 }
540 } else {
541 writeResults[writtenRef] = written;
542 }
543
544 for (Lattice *readerLattice : readResults[writtenRef]) {
545 ExpressionValue prior = readerLattice->getValue().getScalarValue();
546 Interval intersection = prior.getInterval().intersect(written.getInterval());
547 ExpressionValue newVal = prior.withInterval(intersection);
548 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
549 }
550}
551
553 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
554) {
555 // We only perform the visitation on operations within functions
556 FuncDefOp fn = op->getParentOfType<FuncDefOp>();
557 if (!fn) {
558 return success();
559 }
560
561 // If there are no operands or results, skip.
562 if (operands.empty() && results.empty()) {
563 return success();
564 }
565
566 // Get the values or defaults from the operand lattices
567 llvm::SmallVector<LatticeValue> operandVals;
568 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
569 for (unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
570 Value val = op->getOperand(opNum);
571 SourceRefLatticeValue refSet = getSourceRefState(val);
572 if (refSet.isSingleValue()) {
573 operandRefs.push_back(refSet.getSingleValue());
574 } else {
575 operandRefs.push_back(std::nullopt);
576 }
577 // First, lookup the operand value after it is initialized
578 auto priorState = operands[opNum]->getValue();
579 if (priorState.getScalarValue().getExpr() != nullptr) {
580 operandVals.push_back(priorState);
581 continue;
582 }
583
584 if (auto readArr = llvm::dyn_cast_if_present<ReadArrayOp>(val.getDefiningOp())) {
585 auto arrayRef = getArrayAccessRef(op, readArr);
586 if (succeeded(arrayRef)) {
587 if (auto it = writeResults.find(*arrayRef); it != writeResults.end()) {
588 operandVals.emplace_back(it->second);
589 Lattice *operandLattice = getLatticeElement(val);
590 (void)operandLattice->setValue(it->second);
591 continue;
592 }
593 }
594 }
595
596 // Else, look up the stored value by `SourceRef`.
597 // We only care about scalar type values, so we ignore composite types, which
598 // are currently limited to structs and arrays.
599 Type valTy = val.getType();
600 if (llvm::isa<ArrayType, StructType>(valTy)) {
601 ExpressionValue anyVal(field.get(), createSymbol(valTy, buildStringViaPrint(val).c_str()));
602 operandVals.emplace_back(anyVal);
603 continue;
604 }
605
606 ensure(refSet.isScalar(), "should have ruled out array values already");
607
608 if (refSet.getScalarValue().empty()) {
609 // If we can't compute the reference, then there must be some unsupported
610 // op the reference analysis cannot handle. We emit a warning and return
611 // early, since there's no meaningful computation we can do for this op.
612 op->emitWarning()
613 .append(
614 "state of ", val, " is empty; defining operation is unsupported by SourceRef analysis"
615 )
616 .report();
617 // We still return success so we can return overapproximated and partial
618 // results to the user.
619 return success();
620 } else if (!refSet.isSingleValue()) {
621 Interval joinedInterval = Interval::Empty(field.get());
622 for (const SourceRef &ref : refSet.getScalarValue()) {
623 joinedInterval = joinedInterval.join(getRefInterval(ref));
624 }
625 ExpressionValue anyVal = createUnknownValue(val).withInterval(joinedInterval);
626 operandVals.emplace_back(anyVal);
627 } else {
628 const SourceRef &ref = refSet.getSingleValue();
629 operandVals.emplace_back(getRefValue(ref, val));
630 }
631
632 // Since we initialized a value that was not found in the before lattice,
633 // update that value in the lattice so we can find it later, but we don't
634 // need to propagate the changes, since we already have what we need.
635 Lattice *operandLattice = getLatticeElement(val);
636 (void)operandLattice->setValue(operandVals[opNum]);
637 }
638
639 // Now, the way we update is dependent on the type of the operation.
640 if (isConstOp(op)) {
641 llvm::DynamicAPInt constVal = getConst(op);
642 llvm::SMTExprRef expr;
643 if (isBoolConstOp(op)) {
644 expr = createConstBoolExpr(constVal != 0);
645 } else {
646 expr = createConstBitvectorExpr(constVal);
647 }
648
649 ExpressionValue latticeVal(field.get(), expr, constVal);
650 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
651 } else if (isArithmeticOp(op)) {
652 ExpressionValue result;
653 if (operands.size() == 2) {
654 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
655 } else {
656 result = performUnaryArithmetic(op, operandVals[0]);
657 }
658
659 // Also intersect with prior interval, if it's initialized
660 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
661 if (prior.getExpr()) {
662 result = result.withInterval(result.getInterval().intersect(prior.getInterval()));
663 }
664 propagateIfChanged(results[0], results[0]->setValue(result));
665 } else if (auto selectOp = llvm::dyn_cast<arith::SelectOp>(op)) {
667 smtSolver, operandVals[0].getScalarValue(), operandVals[1].getScalarValue(),
668 operandVals[2].getScalarValue()
669 );
670 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
671 if (prior.getExpr()) {
672 result = result.withInterval(result.getInterval().intersect(prior.getInterval()));
673 }
674 propagateIfChanged(results[0], results[0]->setValue(result));
675 } else if (EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
676 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
677 ExpressionValue lhsExpr = operandVals[0].getScalarValue();
678 ExpressionValue rhsExpr = operandVals[1].getScalarValue();
679
680 // Special handling for generalized (s - c0) * (s - c1) * ... * (s - cN) = 0 patterns.
681 // These patterns enforce that s is one of c0, ..., cN.
682 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
683 if (succeeded(res)) {
684 for (Value signalVal : res->first) {
685 applyInterval(emitEq, signalVal, res->second);
686 }
687 }
688
689 ExpressionValue constraint = intersection(smtSolver, lhsExpr, rhsExpr);
690 // Update the LHS and RHS to the same value, but restricted intervals
691 // based on the constraints.
692 const Interval &constrainInterval = constraint.getInterval();
693 applyInterval(emitEq, lhsVal, constrainInterval);
694 applyInterval(emitEq, rhsVal, constrainInterval);
695 } else if (auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
696 // assert enforces that the operand is true. So we apply an interval of [1, 1]
697 // to the operand.
698 Value cond = assertOp.getCondition();
699 applyInterval(assertOp, cond, Interval::True(field.get()));
700 // Also add the solver constraint that the expression must be true.
701 auto assertExpr = operandVals[0].getScalarValue();
702 // No need to propagate the constraint
703 (void)getLatticeElement(cond)->addSolverConstraint(assertExpr);
704 } else if (auto writem = llvm::dyn_cast<MemberWriteOp>(op)) {
705 // Update values stored in a member
706 ExpressionValue writeVal = operandVals[1].getScalarValue();
707 auto cmp = writem.getComponent();
708 // We also need to update the interval on the assigned symbol
709 SourceRefLatticeValue refSet = getSourceRefState(cmp);
710 if (refSet.isSingleValue()) {
711 auto memberDefRes = writem.getMemberDefOp(tables);
712 if (succeeded(memberDefRes)) {
713 SourceRefIndex idx(memberDefRes.value());
714 auto memberRefRes = refSet.getSingleValue().createChild(idx);
715 ensure(succeeded(memberRefRes), "could not create SourceRef child for member write");
716 SourceRef memberRef = *memberRefRes;
717 Type memberTy = writem.getVal().getType();
718 if (!llvm::isa<ArrayType, StructType>(memberTy)) {
719 // Simple scalar update
720 recordRefWrite(memberRef, writeVal);
721 } else {
722 // Map the intervals of aggregates to the written member
723 std::optional<SourceRef> rhsPrefix;
724 if (operandRefs[1].has_value() && operandRefs[1]->isRooted()) {
725 rhsPrefix = operandRefs[1];
726 } else if (auto blockArg = llvm::dyn_cast<BlockArgument>(writem.getVal())) {
727 rhsPrefix = SourceRef(blockArg);
728 } else if (auto result = llvm::dyn_cast<OpResult>(writem.getVal())) {
729 rhsPrefix = SourceRef(result);
730 }
731
732 if (rhsPrefix.has_value()) {
733 llvm::SmallVector<std::pair<SourceRef, ExpressionValue>> remappedWrites;
734 for (const auto &[writtenRef, writtenVal] : writeResults) {
735 if (!writtenRef.isValidPrefix(*rhsPrefix)) {
736 continue;
737 }
738
739 auto translatedRef = writtenRef.translate(*rhsPrefix, memberRef);
740 ensure(succeeded(translatedRef), "could not translate composite member write");
741 remappedWrites.emplace_back(*translatedRef, writtenVal);
742 }
743
744 for (const auto &[translatedRef, translatedVal] : remappedWrites) {
745 recordRefWrite(translatedRef, translatedVal);
746 }
747 }
748 }
749 }
750 }
751 } else if (auto writeArr = llvm::dyn_cast<WriteArrayOp>(op)) {
752 ExpressionValue writeVal = operandVals.back().getScalarValue();
753 auto arrayRef = getArrayAccessRef(op, writeArr);
754 if (succeeded(arrayRef)) {
755 recordRefWrite(*arrayRef, writeVal);
756 }
757
758 SourceRefLatticeValue arrayVals = getSourceRefState(writeArr.getArrRef());
759 if (arrayVals.isScalar()) {
760 std::vector<SourceRefIndex> indices = getArrayAccessIndices(op, writeArr);
761 auto targetRefsRes = arrayVals.extract(indices);
762 ensure(succeeded(targetRefsRes), "could not create SourceRef child for array write");
763 auto [targetRefs, _] = *targetRefsRes;
764 ensure(targetRefs.isScalar(), "array write must resolve to scalar references");
765 for (const SourceRef &ref : targetRefs.getScalarValue()) {
766 recordRefWrite(ref, writeVal);
767 }
768 }
769 } else if (auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
770 const auto &elements = createArray.getElements();
771 ArrayType arrayTy = createArray.getType();
772 Type elemTy = arrayTy.getElementType();
773
774 if (!elements.empty() && !llvm::isa<ArrayType, StructType>(elemTy)) {
775 ensure(arrayTy.hasStaticShape(), "array.new with explicit elements must have static shape");
776 ensure(
777 std::cmp_equal(elements.size(), arrayTy.getNumElements()),
778 "array.new explicit initializer length must match array shape"
779 );
780
782 auto arrayRes = llvm::cast<OpResult>(createArray->getResult(0));
783 for (unsigned i = 0; i < elements.size(); ++i) {
784 auto maybeIndices = indexGen.delinearize(i, op->getContext());
785 ensure(maybeIndices.has_value(), "could not delinearize array.new element index");
786
787 SourceRef::Path path;
788 path.reserve(maybeIndices->size());
789 for (Attribute attr : *maybeIndices) {
790 auto idxAttr = llvm::dyn_cast<IntegerAttr>(attr);
791 ensure(idxAttr != nullptr, "array.new delinearize should produce integer attributes");
792 path.emplace_back(idxAttr.getValue());
793 }
794
795 recordRefWrite(SourceRef(arrayRes, std::move(path)), operandVals[i].getScalarValue());
796 }
797 }
798 } else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
799 // Casts don't modify the intervals, but they do modify the SMT types.
800 ExpressionValue expr = operandVals[0].getScalarValue();
801 // We treat all ints and indexes as felts with the exception of comparison
802 // results, which are bools. So if `expr` is a bool, this cast needs to
803 // upcast to a felt.
804 if (expr.isBoolSort(smtSolver)) {
805 expr = boolToFelt(smtSolver, expr, field.get().bitWidth());
806 }
807 propagateIfChanged(results[0], results[0]->setValue(expr));
808 } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
809 // Fetch the lattice for after the parent operation so we can propagate
810 // the yielded value to subsequent operations.
811 Operation *parent = op->getParentOp();
812 ensure(parent, "yield operation must have parent operation");
813 // Bind the operand values to the result values of the parent
814 for (unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
815 Value parentRes = parent->getResult(idx);
816 Lattice *resLattice = getLatticeElement(parentRes);
817 // Merge with the existing value, if present (e.g., another branch)
818 // has possible value that must be merged.
819 ExpressionValue exprVal = resLattice->getValue().getScalarValue();
820 ExpressionValue newResVal = operandVals[idx].getScalarValue();
821 if (auto loopOp = llvm::dyn_cast<LoopLikeOpInterface>(parent)) {
822 // We overapproximate for loops because we aren't going to try to track trip count.
823 newResVal = ExpressionValue(createSymbol(parentRes), Interval::Entire(field.get()));
824 }
825 if (exprVal.getExpr() != nullptr) {
826 newResVal = exprVal.withInterval(exprVal.getInterval().join(newResVal.getInterval()));
827 } else {
828 newResVal = ExpressionValue(createSymbol(parentRes), newResVal.getInterval());
829 }
830 propagateIfChanged(resLattice, resLattice->setValue(newResVal));
831 }
832 } else if (
833 // We do not need to explicitly handle read ops since they are resolved at the operand value
834 // step where `SourceRef`s are queries.
835 !isReadOp(op)
836 // We do not currently handle return ops as the analysis is currently limited to constrain
837 // functions, which return no value.
838 && !isReturnOp(op)
839 // The analysis ignores definition ops.
840 && !isDefinitionOp(op)
841 // We do not need to analyze storage creation directly.
842 && !llvm::isa<CreateArrayOp, CreateStructOp, NonDetOp>(op)
843 ) {
844 op->emitWarning("unhandled operation, analysis may be incomplete").report();
845 }
846
847 return success();
848}
849
851 auto it = refSymbols.find(r);
852 if (it != refSymbols.end()) {
853 return it->second;
854 }
855 llvm::SMTExprRef sym = createSymbol(r);
856 refSymbols[r] = sym;
857 return sym;
858}
859
860llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(mlir::Type ty, const char *name) const {
861 if (isBooleanType(ty)) {
862 return smtSolver->mkSymbol(name, smtSolver->getBoolSort());
863 }
864 return field.get().createSymbol(smtSolver, name);
865}
866
867llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(const SourceRef &r) const {
868 std::string name = buildStringViaPrint(r);
869 return createSymbol(r.getType(), name.c_str());
870}
871
872llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(Value v) const {
873 std::string name = buildStringViaPrint(v);
874 return createSymbol(v.getType(), name.c_str());
875}
876
877llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op) const {
878 ensure(isConstOp(op), "op is not a const op");
879
880 // NOTE: I think clang-format makes these hard to read by default
881 // clang-format off
882 llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
883 .Case<FeltConstantOp>([&](auto feltConst) {
884 llvm::APSInt constOpVal(feltConst.getValue());
885 return field.get().reduce(constOpVal);
886 })
887 .Case<arith::ConstantIndexOp>([&](auto indexConst) {
888 return DynamicAPInt(indexConst.value());
889 })
890 .Case<arith::ConstantIntOp>([&](auto intConst) {
891 auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
892 ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
893 return toDynamicAPInt(valAttr.getValue());
894 })
895 .Default([](auto *illegalOp) {
896 std::string err;
897 debug::Appender(err) << "unhandled getConst case: " << *illegalOp;
898 llvm::report_fatal_error(Twine(err));
899 return llvm::DynamicAPInt();
900 });
901 // clang-format on
902 return fieldConst;
903}
904
905ExpressionValue IntervalDataFlowAnalysis::performBinaryArithmetic(
906 Operation *op, const LatticeValue &a, const LatticeValue &b
907) {
908 ensure(isArithmeticOp(op), "is not arithmetic op");
909
910 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
911 ensure(lhs.getExpr(), "cannot perform arithmetic over null lhs smt expr");
912 ensure(rhs.getExpr(), "cannot perform arithmetic over null rhs smt expr");
913
914 // clang-format off
915 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
916 .Case<AddFeltOp>([&](auto) { return add(smtSolver, lhs, rhs); })
917 .Case<SubFeltOp>([&](auto) { return sub(smtSolver, lhs, rhs); })
918 .Case<MulFeltOp>([&](auto) { return mul(smtSolver, lhs, rhs); })
919 .Case<DivFeltOp>([&](auto) {return div(smtSolver, op, lhs, rhs); })
920 .Case<UnsignedIntDivFeltOp>([&](auto) {return uintDiv(smtSolver, op, lhs, rhs); })
921 .Case<SignedIntDivFeltOp>([&](auto) {return sintDiv(smtSolver, op, lhs, rhs); })
922 .Case<UnsignedModFeltOp>([&](auto) { return mod(smtSolver, lhs, rhs); })
923 .Case<AndFeltOp>([&](auto) { return bitAnd(smtSolver, lhs, rhs); })
924 .Case<OrFeltOp>([&](auto) { return bitOr(smtSolver, lhs, rhs); })
925 .Case<XorFeltOp, arith::XOrIOp>([&](auto) { return bitXor(smtSolver, lhs, rhs); })
926 .Case<ShlFeltOp>([&](auto) { return shiftLeft(smtSolver, lhs, rhs); })
927 .Case<ShrFeltOp>([&](auto) { return shiftRight(smtSolver, lhs, rhs); })
928 .Case<CmpOp>([&](auto cmpOp) { return cmp(smtSolver, cmpOp, lhs, rhs); })
929 .Case<AndBoolOp>([&](auto) { return boolAnd(smtSolver, lhs, rhs); })
930 .Case<OrBoolOp>([&](auto) { return boolOr(smtSolver, lhs, rhs); })
931 .Case<XorBoolOp>([&](auto) { return boolXor(smtSolver, lhs, rhs); })
932 .Default([&](auto *unsupported) {
933 unsupported
934 ->emitError(
935 "unsupported binary arithmetic operation"
936 )
937 .report();
938 return ExpressionValue();
939 });
940 // clang-format on
941
942 ensure(res.getExpr(), "arithmetic produced null smt expr");
943 return res;
944}
945
947IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op, const LatticeValue &a) {
948 ensure(isArithmeticOp(op), "is not arithmetic op");
949
950 auto val = a.getScalarValue();
951 ensure(val.getExpr(), "cannot perform arithmetic over null smt expr");
952
953 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
954 .Case<NegFeltOp>([&](auto) { return neg(smtSolver, val); })
955 .Case<NotFeltOp>([&](auto) { return notOp(smtSolver, val); })
956 .Case<NotBoolOp>([&](auto) { return boolNot(smtSolver, val); })
957 // The inverse op is currently overapproximated
958 .Case<InvFeltOp>([&](auto inv) {
959 return fallbackUnaryOp(smtSolver, inv, val);
960 }).Default([&](auto *unsupported) {
961 unsupported
962 ->emitWarning(
963 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
964 )
965 .report();
966 return fallbackUnaryOp(smtSolver, unsupported, val);
967 });
968
969 ensure(res.getExpr(), "arithmetic produced null smt expr");
970 return res;
971}
972
973void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val, Interval newInterval) {
974 Lattice *valLattice = getLatticeElement(val);
975 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
976 // Intersect with the current value to accumulate restrictions across constraints.
977 Interval intersection = oldLatticeVal.getInterval().intersect(newInterval);
978 ExpressionValue newLatticeVal = oldLatticeVal.withInterval(intersection);
979 ChangeResult changed = valLattice->setValue(newLatticeVal);
980
981 if (auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
982 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
983
984 // Apply the interval from the constrain function inputs to the compute function inputs
985 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
986 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
987 auto structOp = fnOp->getParentOfType<StructDefOp>();
988 FuncDefOp computeFn = structOp.getComputeFuncOp();
989 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
990 Lattice *computeEntryLattice = getLatticeElement(computeArg);
991
992 SourceRef ref(computeArg);
993 ExpressionValue newArgVal(getOrCreateSymbol(ref), newInterval);
994 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
995 }
996 }
997
998 // Now we descend into val's operands, if it has any.
999 Operation *definingOp = val.getDefiningOp();
1000 if (!definingOp) {
1001 propagateIfChanged(valLattice, changed);
1002 return;
1003 }
1004
1005 const Field &f = field.get();
1006
1007 // This is a rules-based operation. If we have a rule for a given operation,
1008 // then we can make some kind of update, otherwise we leave the intervals
1009 // as is.
1010 // - First we'll define all the rules so the type switch can be less messy
1011
1012 // cmp.<pred> restricts each side of the comparison if the result is known.
1013 auto cmpCase = [&](CmpOp cmpOp) {
1014 // Cmp output range is [0, 1], so in order to do something, we must have newInterval
1015 // either "true" (1) or "false" (0).
1016 // -- In the case of a contradictory circuit, however, the cmp result is allowed
1017 // to be empty.
1018 ensure(
1019 newInterval.isBoolean() || newInterval.isEmpty(),
1020 "new interval for CmpOp is not boolean or empty"
1021 );
1022 if (!newInterval.isDegenerate()) {
1023 // The comparison result is unknown, so we can't update the operand ranges
1024 return;
1025 }
1026
1027 bool cmpTrue = newInterval.rhs() == f.one();
1028
1029 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
1030 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
1031 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1032 rhsExpr = rhsLat->getValue().getScalarValue();
1033
1034 Interval newLhsInterval, newRhsInterval;
1035 const Interval &lhsInterval = lhsExpr.getInterval();
1036 const Interval &rhsInterval = rhsExpr.getInterval();
1037
1038 FeltCmpPredicate pred = cmpOp.getPredicate();
1039 // predicate cases
1040 auto eqCase = [&]() {
1041 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1042 (pred == FeltCmpPredicate::NE && !cmpTrue);
1043 };
1044 auto neCase = [&]() {
1045 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1046 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1047 };
1048 auto ltCase = [&]() {
1049 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1050 (pred == FeltCmpPredicate::GE && !cmpTrue);
1051 };
1052 auto leCase = [&]() {
1053 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1054 (pred == FeltCmpPredicate::GT && !cmpTrue);
1055 };
1056 auto gtCase = [&]() {
1057 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1058 (pred == FeltCmpPredicate::LE && !cmpTrue);
1059 };
1060 auto geCase = [&]() {
1061 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1062 (pred == FeltCmpPredicate::LT && !cmpTrue);
1063 };
1064
1065 // new intervals based on case
1066 if (eqCase()) {
1067 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1068 } else if (neCase()) {
1069 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1070 // In this case, we know lhs and rhs cannot satisfy this assertion, so they have
1071 // an empty value range.
1072 newLhsInterval = newRhsInterval = Interval::Empty(f);
1073 } else if (lhsInterval.isDegenerate()) {
1074 // rhs must not overlap with lhs
1075 newLhsInterval = lhsInterval;
1076 newRhsInterval = rhsInterval.difference(lhsInterval);
1077 } else if (rhsInterval.isDegenerate()) {
1078 // lhs must not overlap with rhs
1079 newLhsInterval = lhsInterval.difference(rhsInterval);
1080 newRhsInterval = rhsInterval;
1081 } else {
1082 // Leave unchanged
1083 newLhsInterval = lhsInterval;
1084 newRhsInterval = rhsInterval;
1085 }
1086 } else if (ltCase()) {
1087 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1088 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1089 } else if (leCase()) {
1090 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1091 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1092 } else if (gtCase()) {
1093 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1094 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1095 } else if (geCase()) {
1096 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1097 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1098 } else {
1099 cmpOp->emitWarning("unhandled cmp predicate").report();
1100 return;
1101 }
1102
1103 // Now we recurse to each operand
1104 applyInterval(cmpOp, lhs, newLhsInterval);
1105 applyInterval(cmpOp, rhs, newRhsInterval);
1106 };
1107
1108 // Multiplication cases:
1109 // - If the result of a multiplication is non-zero, then both operands must be
1110 // non-zero.
1111 // - If one operand is a constant, we can propagate the new interval when multiplied
1112 // by the multiplicative inverse of the constant.
1113 auto mulCase = [&](MulFeltOp mulOp) {
1114 // We check for the constant case first.
1115 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
1116 auto latVal = getLatticeElement(multiplicand)->getValue().getScalarValue();
1117 APInt constVal = constOperand.getValue();
1118 if (constVal.isZero()) {
1119 // There's no inverse for zero, so we do nothing.
1120 return;
1121 }
1122 Interval updatedInterval = newInterval * Interval::Degenerate(f, f.inv(constVal));
1123 applyInterval(mulOp, multiplicand, updatedInterval);
1124 };
1125
1126 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
1127
1128 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
1129 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
1130 // If both are consts, we don't need to do anything
1131 if (lhsConstOp && rhsConstOp) {
1132 return;
1133 } else if (lhsConstOp) {
1134 constCase(lhsConstOp, rhs);
1135 return;
1136 } else if (rhsConstOp) {
1137 constCase(rhsConstOp, lhs);
1138 return;
1139 }
1140
1141 // Otherwise, try to propagate non-zero information.
1142 auto zeroInt = Interval::Degenerate(f, f.zero());
1143 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1144 // The multiplication may be zero, so we can't reduce the operands to be non-zero
1145 return;
1146 }
1147
1148 auto lhsLat = getLatticeElement(lhs), rhsLat = getLatticeElement(rhs);
1149 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1150 rhsExpr = rhsLat->getValue().getScalarValue();
1151 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1152 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1153 applyInterval(mulOp, lhs, newLhsInterval);
1154 applyInterval(mulOp, rhs, newRhsInterval);
1155 };
1156
1157 auto addCase = [&](AddFeltOp addOp) {
1158 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
1159 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
1160 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1161 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1162
1163 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1164
1165 Interval derivedLhsInt = newInterval - currRhsInt;
1166 Interval derivedRhsInt = newInterval - currLhsInt;
1167
1168 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1169 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1170
1171 applyInterval(addOp, lhs, finalLhsInt);
1172 applyInterval(addOp, rhs, finalRhsInt);
1173 };
1174
1175 auto subCase = [&](SubFeltOp subOp) {
1176 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
1177 Lattice *lhsLat = getLatticeElement(lhs), *rhsLat = getLatticeElement(rhs);
1178 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1179 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1180
1181 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1182
1183 Interval derivedLhsInt = newInterval + currRhsInt;
1184 Interval derivedRhsInt = currLhsInt - newInterval;
1185
1186 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1187 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1188
1189 applyInterval(subOp, lhs, finalLhsInt);
1190 applyInterval(subOp, rhs, finalRhsInt);
1191 };
1192
1193 auto selectCase = [&](arith::SelectOp selectOp) {
1194 Value cond = selectOp.getCondition();
1195 Value trueVal = selectOp.getTrueValue();
1196 Value falseVal = selectOp.getFalseValue();
1197
1198 ExpressionValue condExpr = getLatticeElement(cond)->getValue().getScalarValue();
1199 ExpressionValue trueExpr = getLatticeElement(trueVal)->getValue().getScalarValue();
1200 ExpressionValue falseExpr = getLatticeElement(falseVal)->getValue().getScalarValue();
1201
1202 const Interval &condInterval = condExpr.getInterval();
1203 if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
1204 applyInterval(selectOp, trueVal, newInterval);
1205 return;
1206 }
1207 if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
1208 applyInterval(selectOp, falseVal, newInterval);
1209 return;
1210 }
1211
1212 Interval trueOverlap = trueExpr.getInterval().intersect(newInterval);
1213 Interval falseOverlap = falseExpr.getInterval().intersect(newInterval);
1214 bool truePossible = trueOverlap.isNotEmpty();
1215 bool falsePossible = falseOverlap.isNotEmpty();
1216
1217 if (truePossible && !falsePossible) {
1218 applyInterval(selectOp, cond, Interval::True(f));
1219 applyInterval(selectOp, trueVal, newInterval);
1220 return;
1221 }
1222 if (!truePossible && falsePossible) {
1223 applyInterval(selectOp, cond, Interval::False(f));
1224 applyInterval(selectOp, falseVal, newInterval);
1225 return;
1226 }
1227 if (!truePossible && !falsePossible) {
1228 applyInterval(selectOp, cond, Interval::Empty(f));
1229 }
1230 };
1231
1232 auto readmCase = [&](MemberReadOp) {
1233 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1234
1235 if (sourceRefVal.isSingleValue()) {
1236 const SourceRef &ref = sourceRefVal.getSingleValue();
1237 readResults[ref].insert(valLattice);
1238
1239 // Also propagate to all other member read results for this member
1240 for (Lattice *l : readResults[ref]) {
1241 if (l != valLattice) {
1242 propagateIfChanged(l, l->setValue(newLatticeVal));
1243 }
1244 }
1245 }
1246 };
1247
1248 auto readArrCase = [&](ReadArrayOp) {
1249 auto arrayRef = getArrayAccessRef(valUser, llvm::cast<ReadArrayOp>(definingOp));
1250 if (succeeded(arrayRef)) {
1251 readResults[*arrayRef].insert(valLattice);
1252
1253 for (Lattice *l : readResults[*arrayRef]) {
1254 if (l != valLattice) {
1255 propagateIfChanged(l, l->setValue(newLatticeVal));
1256 }
1257 }
1258 }
1259
1260 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1261
1262 if (sourceRefVal.isSingleValue()) {
1263 const SourceRef &ref = sourceRefVal.getSingleValue();
1264 readResults[ref].insert(valLattice);
1265
1266 // Also propagate to all other member read results for this member
1267 for (Lattice *l : readResults[ref]) {
1268 if (l != valLattice) {
1269 propagateIfChanged(l, l->setValue(newLatticeVal));
1270 }
1271 }
1272 }
1273 };
1274
1275 // For casts, just pass the interval along to the cast's operand.
1276 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
1277
1278 // - Apply the rules given the op.
1279 // NOTE: disabling clang-format for this because it makes the last case statement
1280 // look ugly.
1281 // clang-format off
1282 TypeSwitch<Operation *>(definingOp)
1283 .Case<CmpOp>([&](auto op) { cmpCase(op); })
1284 .Case<AddFeltOp>([&](auto op) { return addCase(op); })
1285 .Case<SubFeltOp>([&](auto op) { return subCase(op); })
1286 .Case<MulFeltOp>([&](auto op) { mulCase(op); })
1287 .Case<arith::SelectOp>([&](auto op) { selectCase(op); })
1288 .Case<MemberReadOp>([&](auto op){ readmCase(op); })
1289 .Case<ReadArrayOp>([&](auto op){ readArrCase(op); })
1290 .Case<IntToFeltOp, FeltToIndexOp>([&](auto op) { castCase(op); })
1291 .Default([&](Operation *) { });
1292 // clang-format on
1293
1294 // Propagate after recursion to avoid having recursive calls unset the value.
1295 propagateIfChanged(valLattice, changed);
1296}
1297
/// Try to recognize the comparison `lhs`/`rhs` as `0 == P`, where exactly one side is a
/// constant zero and the other side, `P`, is a tree of MulFeltOp nodes whose leaves are either
/// a reference value `x` or a SubFeltOp pairing `x` with a felt constant `c`.
/// Every leaf's `x` must resolve (via getSourceRefState) to one and the same single SourceRef.
///
/// On success, returns the set of SSA values that referenced that SourceRef, together with
/// the interval [min c, max c] reduced into the field. A bare `x` leaf contributes c = 0.
/// Returns failure if the operands do not have this shape.
/// NOTE(review): `baseOp` is not read anywhere in this body.
1298FailureOr<std::pair<DenseSet<Value>, Interval>>
1299IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
// True iff `v` is produced by a const op whose (field-reduced, via getConst) value is zero.
1300 auto isZeroConst = [this](Value v) {
1301 Operation *op = v.getDefiningOp();
1302 if (!op) {
1303 return false;
1304 }
1305 if (!isConstOp(op)) {
1306 return false;
1307 }
1308 return getConst(op) == field.get().zero();
1309 };
// Exactly one side must be the zero constant; the other side is the expression to explore.
1310 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1311 Value exprTree = nullptr;
1312 if (lhsIsZero && !rhsIsZero) {
1313 exprTree = rhs;
1314 } else if (!lhsIsZero && rhsIsZero) {
1315 exprTree = lhs;
1316 } else {
1317 return failure();
1318 }
1319
1320 // We now explore the expression tree for multiplications of subtractions/signal values.
// Depth-first worklist walk. `signalRef` is the single SourceRef all leaves must agree on,
// and `consts` collects one constant per leaf.
1321 std::optional<SourceRef> signalRef = std::nullopt;
1322 DenseSet<Value> signalVals;
1323 SmallVector<DynamicAPInt> consts;
1324 SmallVector<Value> frontier {exprTree};
1325 while (!frontier.empty()) {
1326 Value v = frontier.back();
1327 frontier.pop_back();
1328 Operation *op = v.getDefiningOp();
1329
// `c` and `signalVal` are filled in as side effects of the matchers below.
1330 FeltConstantOp c;
1331 Value signalVal;
// Check that `signalVal` maps to exactly one SourceRef, and that it is the same ref as every
// previously seen leaf. On success, record `signalVal`.
1332 auto handleRefValue = [this, &signalRef, &signalVal, &signalVals]() {
1333 SourceRefLatticeValue refSet = getSourceRefState(signalVal);
1334 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1335 return failure();
1336 }
1337 SourceRef r = refSet.getSingleValue();
1338 if (signalRef.has_value() && signalRef.value() != r) {
1339 return failure();
1340 } else if (!signalRef.has_value()) {
1341 signalRef = r;
1342 }
1343 signalVals.insert(signalVal);
1344 return success();
1345 };
1346
// Leaf `x - c` (presumably also `c - x`, given m_CommutativeOp; confirm in Matchers.h).
// Either way the factor vanishes at x == c, so `c` is recorded.
1347 auto subPattern = m_CommutativeOp<SubFeltOp>(m_RefValue(&signalVal), m_Constant(&c));
1348 if (op && matchPattern(op, subPattern)) {
1349 if (failed(handleRefValue())) {
1350 return failure();
1351 }
1352 auto constInt = APSInt(c.getValue());
1353 consts.push_back(field.get().reduce(constInt));
1354 continue;
// Leaf `x` on its own: the factor vanishes at x == 0.
1355 } else if (m_RefValue(&signalVal).match(v)) {
1356 if (failed(handleRefValue())) {
1357 return failure();
1358 }
1359 consts.push_back(field.get().zero());
1360 continue;
1361 }
1362
// Interior node: a multiplication whose two operands are explored further.
1363 Value a, b;
1364 auto mulPattern = m_CommutativeOp<MulFeltOp>(matchers::m_Any(&a), matchers::m_Any(&b));
1365 if (op && matchPattern(op, mulPattern)) {
1366 frontier.push_back(a);
1367 frontier.push_back(b);
1368 continue;
1369 }
1370
// Any other node shape means this is not a decomposition constraint.
1371 return failure();
1372 }
1373
1374 // Now, we aggregate the Interval. If we have sparse values (e.g., 0, 2, 4),
1375 // we will create a larger range of [0, 4], since we don't support multiple intervals.
// `consts` is non-empty here: the walk starts from a non-null value, and every successful
// path through the loop either records a constant or pushes two more values.
1376 std::sort(consts.begin(), consts.end());
1377 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1378 return std::make_pair(std::move(signalVals), iv);
1379}
1380
1381/* StructIntervals */
1382
1384 mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx
1385) {
1386
1387 auto computeIntervalsImpl = [&solver, &ctx, this](
1388 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &memberRanges,
1389 llvm::SetVector<ExpressionValue> & /*solverConstraints*/
1390 ) {
1391 // Since every lattice value does not contain every value, we will traverse
1392 // the function backwards (from most up-to-date to least-up-to-date lattices)
1393 // searching for the source refs. Once a source ref is found, we remove it
1394 // from the search set.
1395
1396 SourceRefSet searchSet;
1397 for (const auto &ref : SourceRef::getAllSourceRefs(structDef, fn)) {
1398 // We only want to compute intervals for field elements and not composite types.
1399 if (!ref.isScalar()) {
1400 continue;
1401 }
1402 searchSet.insert(ref);
1403 }
1404
1405 // Iterate over arguments
1406 for (BlockArgument arg : fn.getArguments()) {
1407 SourceRef ref {arg};
1408 if (searchSet.erase(ref)) {
1409 const IntervalAnalysisLattice *lattice = solver.lookupState<IntervalAnalysisLattice>(arg);
1410 // If we never referenced this argument, use a default value
1411 ExpressionValue expr = lattice->getValue().getScalarValue();
1412 if (!expr.getExpr()) {
1413 expr = expr.withInterval(Interval::Entire(ctx.getField()));
1414 }
1415 memberRanges[ref] = expr.getInterval();
1416 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1417 }
1418 }
1419
1420 // Aggregate all read intervals for a ref. A single ref may be read at multiple program
1421 // points with different precision, so picking an arbitrary lattice from the DenseSet is
1422 // nondeterministic. Joining preserves the overapproximation regardless of iteration order.
1423 for (const auto &[ref, lattices] : ctx.intervalDFA->getReadResults()) {
1424 if (!lattices.empty() && searchSet.erase(ref)) {
1425 Interval joinedInterval = Interval::Empty(ctx.getField());
1426 for (const IntervalAnalysisLattice *lattice : lattices) {
1427 joinedInterval = joinedInterval.join(lattice->getValue().getScalarValue().getInterval());
1428 }
1429 memberRanges[ref] = joinedInterval;
1430 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1431 }
1432 }
1433
1434 for (const auto &[ref, val] : ctx.intervalDFA->getWriteResults()) {
1435 if (searchSet.erase(ref)) {
1436 memberRanges[ref] = val.getInterval();
1437 assert(memberRanges[ref].getField() == ctx.getField() && "bad interval defaults");
1438 }
1439 }
1440
1441 // For all unfound refs, default to the entire range.
1442 for (const auto &ref : searchSet) {
1443 memberRanges[ref] = Interval::Entire(ctx.getField());
1444 }
1445
1446 // Sort the outputs since we assembled things out of order.
1447 //
1448 // `llvm::MapVector` maintains an internal key -> index map. Sorting it in
1449 // place corrupts lookup semantics because the backing vector is reordered
1450 // without rebuilding that map. Reinsert into a fresh MapVector instead.
1451 llvm::SmallVector<std::pair<SourceRef, Interval>> sortedRanges;
1452 sortedRanges.reserve(memberRanges.size());
1453 for (const auto &[ref, interval] : memberRanges) {
1454 sortedRanges.emplace_back(ref, interval);
1455 }
1456 llvm::sort(sortedRanges, [](const auto &a, const auto &b) { return a.first < b.first; });
1457 memberRanges.clear();
1458 for (auto &[ref, interval] : sortedRanges) {
1459 memberRanges[ref] = interval;
1460 }
1461 };
1462
1463 computeIntervalsImpl(structDef.getComputeFuncOp(), computeMemberRanges, computeSolverConstraints);
1464 computeIntervalsImpl(
1465 structDef.getConstrainFuncOp(), constrainMemberRanges, constrainSolverConstraints
1466 );
1467
1468 return success();
1469}
1470
1471void StructIntervals::print(mlir::raw_ostream &os, bool withConstraints, bool printCompute) const {
1472 auto writeIntervals =
1473 [&os, &withConstraints](
1474 const char *fnName, const llvm::MapVector<SourceRef, Interval> &memberRanges,
1475 const llvm::SetVector<ExpressionValue> &solverConstraints, bool printName
1476 ) {
1477 int indent = 4;
1478 if (printName) {
1479 os << '\n';
1480 os.indent(indent) << fnName << " {";
1481 indent += 4;
1482 }
1483
1484 if (memberRanges.empty()) {
1485 os << "}\n";
1486 return;
1487 }
1488
1489 for (const auto &[ref, interval] : memberRanges) {
1490 os << '\n';
1491 os.indent(indent) << ref << " in " << interval;
1492 }
1493
1494 if (withConstraints) {
1495 os << "\n\n";
1496 os.indent(indent) << "Solver Constraints { ";
1497 if (solverConstraints.empty()) {
1498 os << "}\n";
1499 } else {
1500 for (const auto &e : solverConstraints) {
1501 os << '\n';
1502 os.indent(indent + 4);
1503 e.getExpr()->print(os);
1504 }
1505 os << '\n';
1506 os.indent(indent) << '}';
1507 }
1508 }
1509
1510 if (printName) {
1511 os << '\n';
1512 os.indent(indent - 4) << '}';
1513 }
1514 };
1515
1516 os << "StructIntervals { ";
1517 if (constrainMemberRanges.empty() && (!printCompute || computeMemberRanges.empty())) {
1518 os << "}\n";
1519 return;
1520 }
1521
1522 if (printCompute) {
1523 writeIntervals(FUNC_NAME_COMPUTE, computeMemberRanges, computeSolverConstraints, printCompute);
1524 }
1525 writeIntervals(
1526 FUNC_NAME_CONSTRAIN, constrainMemberRanges, constrainSolverConstraints, printCompute
1527 );
1528
1529 os << "\n}\n";
1530}
1531
1532} // namespace llzk
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:35
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
Definition Field.h:80
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:71
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
Definition Field.h:83
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
Definition Field.h:106
llvm::SMTExprRef createSymbol(const llvm::SMTSolverRef &solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
Definition Field.h:111
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
Intervals over a finite field.
Definition Intervals.h:200
bool isEmpty() const
Definition Intervals.h:308
static Interval True(const Field &f)
Definition Intervals.h:219
llvm::DynamicAPInt rhs() const
Definition Intervals.h:333
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
Definition Intervals.h:221
static Interval Entire(const Field &f)
Definition Intervals.h:223
bool isDegenerate() const
Definition Intervals.h:310
bool isNotEmpty() const
Definition Intervals.h:309
static Interval False(const Field &f)
Definition Intervals.h:217
llvm::DynamicAPInt lhs() const
Definition Intervals.h:332
Interval join(const Interval &rhs) const
Union.
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
Defines an index into an LLZK object.
Definition SourceRef.h:42
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:132
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
Definition SourceRef.h:339
std::vector< SourceRefIndex > Path
Definition SourceRef.h:134
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::Type getType() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
Definition Intervals.cpp:63
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
Definition Intervals.cpp:86
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Definition Intervals.cpp:78
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
Definition Intervals.cpp:23
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Definition Intervals.cpp:71
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::Type getElementType() const
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
RefValueCapture m_RefValue()
Definition Matchers.h:69
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
DynamicAPInt toDynamicAPInt(StringRef str)
llvm::SMTExprRef createFieldInverseExpr(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val, StringRef suffix="")
ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-degenerate divisors.
ExpressionValue div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
ExpressionValue boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth)
ConstantCapture m_Constant()
Definition Matchers.h:89
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the returned string.
ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
Definition Matchers.h:47
ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divisor.
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue selectValue(const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal, const ExpressionValue &falseVal)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA