18#include <mlir/Dialect/SCF/IR/SCF.h>
20#include <llvm/ADT/TypeSwitch.h>
37 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &val,
43 DynamicAPInt invVal = field.
inv(iv.
lhs());
51 if (!suffix.empty()) {
52 symName += suffix.str();
54 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
55 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.
bitWidth());
57 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
58 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
59 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
60 solver->addConstraint(constraint);
65 if (expr ==
nullptr && rhs.expr ==
nullptr) {
68 if (expr ==
nullptr || rhs.expr ==
nullptr) {
71 return i == rhs.i && *expr == *rhs.expr;
76 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
77 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
78 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.
getExpr(), one, zero);
98 llvm::SMTExprRef resultExpr =
107 const auto *exprEq = solver->mkEqual(lhs.expr, rhs.expr);
114 res.i = lhs.i + rhs.i;
115 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
122 res.i = lhs.i - rhs.i;
123 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
130 res.i = lhs.i * rhs.i;
131 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
139 auto divRes =
feltDiv(lhs.i, rhs.i);
140 if (failed(divRes)) {
147 "non-degenerate felt.div divisors are not tracked precisely, and the divisor may "
148 "contain zero. Range of division result will be treated as unbounded."
153 "non-degenerate felt.div divisors are not tracked precisely because precise field "
154 "division over intervals would require enumerating divisor inverses. Range of "
155 "division result will be treated as unbounded."
161 "divisor is zero, leading to a divide-by-zero error. Range of division result will "
162 "be treated as unbounded."
171 res.expr = solver->mkBVMul(lhs.expr, invExpr);
176 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &lhs,
181 if (failed(divRes)) {
183 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
184 " Range of division result will be treated as unbounded."
191 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
196 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &lhs,
201 if (failed(divRes)) {
203 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
204 " Range of division result will be treated as unbounded."
211 res.expr = solver->mkBVSDiv(lhs.expr, rhs.expr);
218 res.i = lhs.i % rhs.i;
219 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
226 res.i = lhs.i & rhs.i;
227 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
234 res.i = lhs.i | rhs.i;
235 res.expr = solver->mkBVOr(lhs.expr, rhs.expr);
242 return boolXor(solver, lhs, rhs);
246 res.i = lhs.i ^ rhs.i;
247 res.expr = solver->mkBVXor(lhs.expr, rhs.expr);
255 res.i = lhs.i << rhs.i;
256 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
264 res.i = lhs.i >> rhs.i;
265 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
277 case FeltCmpPredicate::EQ:
278 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
285 case FeltCmpPredicate::NE:
286 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
293 case FeltCmpPredicate::LT:
294 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
302 case FeltCmpPredicate::LE:
303 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
311 case FeltCmpPredicate::GT:
312 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
320 case FeltCmpPredicate::GE:
321 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
337 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
344 res.i =
boolOr(lhs.i, rhs.i);
345 res.expr = solver->mkOr(lhs.expr, rhs.expr);
354 res.expr = solver->mkAnd(
355 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
363 res.expr = solver->mkBVNeg(val.expr);
370 res.expr = solver->mkBVNot(val.expr);
377 res.expr = solver->mkNot(val.expr);
386 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
389 }).Default([](Operation *unsupported) {
390 llvm::report_fatal_error(
391 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
403 os <<
"<null expression>";
406 os <<
" ( interval: " << i <<
" )";
414 return ChangeResult::NoChange;
420 return ChangeResult::NoChange;
424 os <<
"IntervalAnalysisLattice { " << val <<
" }";
429 return ChangeResult::NoChange;
432 return ChangeResult::Change;
441 if (!constraints.contains(e)) {
442 constraints.insert(e);
443 return ChangeResult::Change;
445 return ChangeResult::NoChange;
454std::vector<SourceRefIndex> IntervalDataFlowAnalysis::getArrayAccessIndices(
455 Operation *baseOp, ArrayAccessOpInterface arrayAccessOp
457 std::vector<SourceRefIndex> indices;
458 ArrayType arrayType = arrayAccessOp.getArrRefType();
459 size_t numIndices = arrayAccessOp.getIndices().size();
460 indices.reserve(numIndices);
462 for (
size_t i = 0; i < numIndices; ++i) {
463 Value idxOperand = arrayAccessOp.getIndices()[i];
464 SourceRefLatticeValue idxVals = getSourceRefState(idxOperand);
467 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
468 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
470 auto lower = APInt::getZero(64);
471 APInt upper(64, arrayType.getDimSize(i));
472 indices.emplace_back(lower, upper);
479mlir::FailureOr<SourceRef> IntervalDataFlowAnalysis::getArrayAccessRef(
482 std::vector<SourceRefIndex> indices = getArrayAccessIndices(baseOp, arrayAccessOp);
483 Value arrayVal = arrayAccessOp.getArrRef();
484 if (
auto blockArg = llvm::dyn_cast<BlockArgument>(arrayVal)) {
485 return SourceRef(blockArg, std::move(indices));
487 if (
auto result = llvm::dyn_cast<OpResult>(arrayVal)) {
488 return SourceRef(result, std::move(indices));
494 if (
auto it = writeResults.find(ref); it != writeResults.end()) {
495 return it->second.getInterval();
498 if (ref.isConstantInt()) {
499 auto constVal = ref.getConstantValue();
500 if (succeeded(constVal)) {
505 if (ref.isRooted() && ref.getPath().empty()) {
506 auto rootVal = ref.getRoot();
507 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
509 if (rootExpr.getExpr() !=
nullptr) {
510 return rootExpr.getInterval();
515 return getDefaultIntervalForType(ref.getType());
519 if (
auto it = writeResults.find(ref); it != writeResults.end()) {
522 return createUnknownValue(val).
withInterval(getRefInterval(ref));
525void IntervalDataFlowAnalysis::recordRefWrite(
528 ExpressionValue written = writeVal;
530 if (
auto it = writeResults.find(writtenRef); it != writeResults.end()) {
531 const ExpressionValue &old = it->second;
532 Interval combinedWrite = old.getInterval().join(written.getInterval());
533 if (old.getExpr() !=
nullptr && written.getExpr() !=
nullptr &&
534 *old.getExpr() == *written.getExpr()) {
535 writeResults[writtenRef] = old.withInterval(combinedWrite);
538 writeResults[writtenRef] = ExpressionValue(expr, combinedWrite);
541 writeResults[writtenRef] = written;
544 for (Lattice *readerLattice : readResults[writtenRef]) {
545 ExpressionValue prior = readerLattice->getValue().getScalarValue();
547 ExpressionValue newVal = prior.withInterval(
intersection);
548 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
553 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
562 if (operands.empty() && results.empty()) {
567 llvm::SmallVector<LatticeValue> operandVals;
568 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
569 for (
unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
570 Value val = op->getOperand(opNum);
575 operandRefs.push_back(std::nullopt);
578 auto priorState = operands[opNum]->getValue();
579 if (priorState.getScalarValue().getExpr() !=
nullptr) {
580 operandVals.push_back(priorState);
584 if (
auto readArr = llvm::dyn_cast_if_present<ReadArrayOp>(val.getDefiningOp())) {
585 auto arrayRef = getArrayAccessRef(op, readArr);
586 if (succeeded(arrayRef)) {
587 if (
auto it = writeResults.find(*arrayRef); it != writeResults.end()) {
588 operandVals.emplace_back(it->second);
590 (void)operandLattice->
setValue(it->second);
599 Type valTy = val.getType();
600 if (llvm::isa<ArrayType, StructType>(valTy)) {
602 operandVals.emplace_back(anyVal);
606 ensure(refSet.
isScalar(),
"should have ruled out array values already");
614 "state of ", val,
" is empty; defining operation is unsupported by SourceRef analysis"
623 joinedInterval = joinedInterval.
join(getRefInterval(ref));
626 operandVals.emplace_back(anyVal);
629 operandVals.emplace_back(getRefValue(ref, val));
636 (void)operandLattice->
setValue(operandVals[opNum]);
641 llvm::DynamicAPInt constVal = getConst(op);
642 llvm::SMTExprRef expr;
643 if (isBoolConstOp(op)) {
644 expr = createConstBoolExpr(constVal != 0);
646 expr = createConstBitvectorExpr(constVal);
650 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
651 }
else if (isArithmeticOp(op)) {
653 if (operands.size() == 2) {
654 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
656 result = performUnaryArithmetic(op, operandVals[0]);
660 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
664 propagateIfChanged(results[0], results[0]->setValue(result));
665 }
else if (
auto selectOp = llvm::dyn_cast<arith::SelectOp>(op)) {
667 smtSolver, operandVals[0].getScalarValue(), operandVals[1].getScalarValue(),
668 operandVals[2].getScalarValue()
670 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
674 propagateIfChanged(results[0], results[0]->setValue(result));
675 }
else if (
EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
676 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
682 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
683 if (succeeded(res)) {
684 for (Value signalVal : res->first) {
685 applyInterval(emitEq, signalVal, res->second);
693 applyInterval(emitEq, lhsVal, constrainInterval);
694 applyInterval(emitEq, rhsVal, constrainInterval);
695 }
else if (
auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
698 Value cond = assertOp.getCondition();
701 auto assertExpr = operandVals[0].getScalarValue();
704 }
else if (
auto writem = llvm::dyn_cast<MemberWriteOp>(op)) {
707 auto cmp = writem.getComponent();
711 auto memberDefRes = writem.getMemberDefOp(tables);
712 if (succeeded(memberDefRes)) {
715 ensure(succeeded(memberRefRes),
"could not create SourceRef child for member write");
717 Type memberTy = writem.getVal().
getType();
718 if (!llvm::isa<ArrayType, StructType>(memberTy)) {
720 recordRefWrite(memberRef, writeVal);
723 std::optional<SourceRef> rhsPrefix;
724 if (operandRefs[1].has_value() && operandRefs[1]->isRooted()) {
725 rhsPrefix = operandRefs[1];
726 }
else if (
auto blockArg = llvm::dyn_cast<BlockArgument>(writem.getVal())) {
728 }
else if (
auto result = llvm::dyn_cast<OpResult>(writem.getVal())) {
732 if (rhsPrefix.has_value()) {
733 llvm::SmallVector<std::pair<SourceRef, ExpressionValue>> remappedWrites;
734 for (
const auto &[writtenRef, writtenVal] : writeResults) {
735 if (!writtenRef.isValidPrefix(*rhsPrefix)) {
739 auto translatedRef = writtenRef.translate(*rhsPrefix, memberRef);
740 ensure(succeeded(translatedRef),
"could not translate composite member write");
741 remappedWrites.emplace_back(*translatedRef, writtenVal);
744 for (
const auto &[translatedRef, translatedVal] : remappedWrites) {
745 recordRefWrite(translatedRef, translatedVal);
751 }
else if (
auto writeArr = llvm::dyn_cast<WriteArrayOp>(op)) {
753 auto arrayRef = getArrayAccessRef(op, writeArr);
754 if (succeeded(arrayRef)) {
755 recordRefWrite(*arrayRef, writeVal);
760 std::vector<SourceRefIndex> indices = getArrayAccessIndices(op, writeArr);
761 auto targetRefsRes = arrayVals.
extract(indices);
762 ensure(succeeded(targetRefsRes),
"could not create SourceRef child for array write");
763 auto [targetRefs, _] = *targetRefsRes;
764 ensure(targetRefs.isScalar(),
"array write must resolve to scalar references");
765 for (
const SourceRef &ref : targetRefs.getScalarValue()) {
766 recordRefWrite(ref, writeVal);
769 }
else if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
770 const auto &elements = createArray.getElements();
771 ArrayType arrayTy = createArray.getType();
774 if (!elements.empty() && !llvm::isa<ArrayType, StructType>(elemTy)) {
775 ensure(arrayTy.hasStaticShape(),
"array.new with explicit elements must have static shape");
777 std::cmp_equal(elements.size(), arrayTy.getNumElements()),
778 "array.new explicit initializer length must match array shape"
782 auto arrayRes = llvm::cast<OpResult>(createArray->getResult(0));
783 for (
unsigned i = 0; i < elements.size(); ++i) {
784 auto maybeIndices = indexGen.
delinearize(i, op->getContext());
785 ensure(maybeIndices.has_value(),
"could not delinearize array.new element index");
788 path.reserve(maybeIndices->size());
789 for (Attribute attr : *maybeIndices) {
790 auto idxAttr = llvm::dyn_cast<IntegerAttr>(attr);
791 ensure(idxAttr !=
nullptr,
"array.new delinearize should produce integer attributes");
792 path.emplace_back(idxAttr.getValue());
795 recordRefWrite(
SourceRef(arrayRes, std::move(path)), operandVals[i].getScalarValue());
798 }
else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
805 expr =
boolToFelt(smtSolver, expr, field.get().bitWidth());
807 propagateIfChanged(results[0], results[0]->setValue(expr));
808 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
811 Operation *parent = op->getParentOp();
812 ensure(parent,
"yield operation must have parent operation");
814 for (
unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
815 Value parentRes = parent->getResult(idx);
821 if (
auto loopOp = llvm::dyn_cast<LoopLikeOpInterface>(parent)) {
825 if (exprVal.
getExpr() !=
nullptr) {
830 propagateIfChanged(resLattice, resLattice->
setValue(newResVal));
840 && !isDefinitionOp(op)
842 && !llvm::isa<CreateArrayOp, CreateStructOp, NonDetOp>(op)
844 op->emitWarning(
"unhandled operation, analysis may be incomplete").report();
851 auto it = refSymbols.find(r);
852 if (it != refSymbols.end()) {
855 llvm::SMTExprRef sym = createSymbol(r);
860llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(mlir::Type ty,
const char *name)
const {
861 if (isBooleanType(ty)) {
862 return smtSolver->mkSymbol(name, smtSolver->getBoolSort());
864 return field.get().createSymbol(smtSolver, name);
867llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(
const SourceRef &r)
const {
869 return createSymbol(r.getType(), name.c_str());
872llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(Value v)
const {
874 return createSymbol(v.getType(), name.c_str());
877llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
878 ensure(isConstOp(op),
"op is not a const op");
882 llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
883 .Case<FeltConstantOp>([&](
auto feltConst) {
884 llvm::APSInt constOpVal(feltConst.getValue());
885 return field.get().reduce(constOpVal);
887 .Case<arith::ConstantIndexOp>([&](
auto indexConst) {
888 return DynamicAPInt(indexConst.value());
890 .Case<arith::ConstantIntOp>([&](
auto intConst) {
891 auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
892 ensure(valAttr !=
nullptr,
"arith::ConstantIntOp must have an IntegerAttr as its value");
895 .Default([](
auto *illegalOp) {
897 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
898 llvm::report_fatal_error(Twine(err));
899 return llvm::DynamicAPInt();
906 Operation *op,
const LatticeValue &a,
const LatticeValue &b
908 ensure(isArithmeticOp(op),
"is not arithmetic op");
910 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
911 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
912 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
915 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
916 .Case<AddFeltOp>([&](
auto) {
return add(smtSolver, lhs, rhs); })
917 .Case<SubFeltOp>([&](
auto) {
return sub(smtSolver, lhs, rhs); })
918 .Case<MulFeltOp>([&](
auto) {
return mul(smtSolver, lhs, rhs); })
919 .Case<DivFeltOp>([&](
auto) {
return div(smtSolver, op, lhs, rhs); })
920 .Case<UnsignedIntDivFeltOp>([&](
auto) {
return uintDiv(smtSolver, op, lhs, rhs); })
921 .Case<SignedIntDivFeltOp>([&](
auto) {
return sintDiv(smtSolver, op, lhs, rhs); })
922 .Case<UnsignedModFeltOp>([&](
auto) {
return mod(smtSolver, lhs, rhs); })
923 .Case<AndFeltOp>([&](
auto) {
return bitAnd(smtSolver, lhs, rhs); })
924 .Case<OrFeltOp>([&](
auto) {
return bitOr(smtSolver, lhs, rhs); })
925 .Case<XorFeltOp, arith::XOrIOp>([&](
auto) {
return bitXor(smtSolver, lhs, rhs); })
926 .Case<ShlFeltOp>([&](
auto) {
return shiftLeft(smtSolver, lhs, rhs); })
927 .Case<ShrFeltOp>([&](
auto) {
return shiftRight(smtSolver, lhs, rhs); })
928 .Case<CmpOp>([&](
auto cmpOp) {
return cmp(smtSolver, cmpOp, lhs, rhs); })
929 .Case<AndBoolOp>([&](
auto) {
return boolAnd(smtSolver, lhs, rhs); })
930 .Case<OrBoolOp>([&](
auto) {
return boolOr(smtSolver, lhs, rhs); })
931 .Case<XorBoolOp>([&](
auto) {
return boolXor(smtSolver, lhs, rhs); })
932 .Default([&](
auto *unsupported) {
935 "unsupported binary arithmetic operation"
938 return ExpressionValue();
942 ensure(res.getExpr(),
"arithmetic produced null smt expr");
947IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
948 ensure(isArithmeticOp(op),
"is not arithmetic op");
950 auto val = a.getScalarValue();
951 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
953 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
954 .Case<NegFeltOp>([&](
auto) {
return neg(smtSolver, val); })
955 .Case<NotFeltOp>([&](
auto) {
return notOp(smtSolver, val); })
956 .Case<NotBoolOp>([&](
auto) {
return boolNot(smtSolver, val); })
958 .Case<InvFeltOp>([&](
auto inv) {
960 }).Default([&](
auto *unsupported) {
963 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
969 ensure(res.getExpr(),
"arithmetic produced null smt expr");
973void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val,
Interval newInterval) {
975 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
978 ExpressionValue newLatticeVal = oldLatticeVal.withInterval(
intersection);
979 ChangeResult changed = valLattice->setValue(newLatticeVal);
981 if (
auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
982 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
985 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
986 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
987 auto structOp = fnOp->getParentOfType<StructDefOp>();
988 FuncDefOp computeFn = structOp.getComputeFuncOp();
989 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
992 SourceRef ref(computeArg);
994 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
999 Operation *definingOp = val.getDefiningOp();
1001 propagateIfChanged(valLattice, changed);
1005 const Field &f = field.get();
1013 auto cmpCase = [&](CmpOp cmpOp) {
1019 newInterval.isBoolean() || newInterval.isEmpty(),
1020 "new interval for CmpOp is not boolean or empty"
1022 if (!newInterval.isDegenerate()) {
1027 bool cmpTrue = newInterval.rhs() == f.one();
1029 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
1031 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1032 rhsExpr = rhsLat->getValue().getScalarValue();
1034 Interval newLhsInterval, newRhsInterval;
1035 const Interval &lhsInterval = lhsExpr.getInterval();
1036 const Interval &rhsInterval = rhsExpr.getInterval();
1040 auto eqCase = [&]() {
1041 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1042 (pred == FeltCmpPredicate::NE && !cmpTrue);
1044 auto neCase = [&]() {
1045 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1046 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1048 auto ltCase = [&]() {
1049 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1050 (pred == FeltCmpPredicate::GE && !cmpTrue);
1052 auto leCase = [&]() {
1053 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1054 (pred == FeltCmpPredicate::GT && !cmpTrue);
1056 auto gtCase = [&]() {
1057 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1058 (pred == FeltCmpPredicate::LE && !cmpTrue);
1060 auto geCase = [&]() {
1061 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1062 (pred == FeltCmpPredicate::LT && !cmpTrue);
1067 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1068 }
else if (neCase()) {
1069 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1073 }
else if (lhsInterval.isDegenerate()) {
1075 newLhsInterval = lhsInterval;
1076 newRhsInterval = rhsInterval.difference(lhsInterval);
1077 }
else if (rhsInterval.isDegenerate()) {
1079 newLhsInterval = lhsInterval.difference(rhsInterval);
1080 newRhsInterval = rhsInterval;
1083 newLhsInterval = lhsInterval;
1084 newRhsInterval = rhsInterval;
1086 }
else if (ltCase()) {
1087 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1088 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1089 }
else if (leCase()) {
1090 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1091 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1092 }
else if (gtCase()) {
1093 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1094 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1095 }
else if (geCase()) {
1096 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1097 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1099 cmpOp->emitWarning(
"unhandled cmp predicate").report();
1104 applyInterval(cmpOp, lhs, newLhsInterval);
1105 applyInterval(cmpOp, rhs, newRhsInterval);
1113 auto mulCase = [&](MulFeltOp mulOp) {
1115 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
1117 APInt constVal = constOperand.getValue();
1118 if (constVal.isZero()) {
1123 applyInterval(mulOp, multiplicand, updatedInterval);
1126 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
1128 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
1129 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
1131 if (lhsConstOp && rhsConstOp) {
1133 }
else if (lhsConstOp) {
1134 constCase(lhsConstOp, rhs);
1136 }
else if (rhsConstOp) {
1137 constCase(rhsConstOp, lhs);
1143 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1149 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1150 rhsExpr = rhsLat->getValue().getScalarValue();
1151 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1152 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1153 applyInterval(mulOp, lhs, newLhsInterval);
1154 applyInterval(mulOp, rhs, newRhsInterval);
1157 auto addCase = [&](AddFeltOp addOp) {
1158 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
1160 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1161 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1163 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1165 Interval derivedLhsInt = newInterval - currRhsInt;
1166 Interval derivedRhsInt = newInterval - currLhsInt;
1168 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1169 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1171 applyInterval(addOp, lhs, finalLhsInt);
1172 applyInterval(addOp, rhs, finalRhsInt);
1175 auto subCase = [&](SubFeltOp subOp) {
1176 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
1178 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1179 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1181 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1183 Interval derivedLhsInt = newInterval + currRhsInt;
1184 Interval derivedRhsInt = currLhsInt - newInterval;
1186 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1187 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1189 applyInterval(subOp, lhs, finalLhsInt);
1190 applyInterval(subOp, rhs, finalRhsInt);
1193 auto selectCase = [&](arith::SelectOp selectOp) {
1194 Value cond = selectOp.getCondition();
1195 Value trueVal = selectOp.getTrueValue();
1196 Value falseVal = selectOp.getFalseValue();
1202 const Interval &condInterval = condExpr.getInterval();
1203 if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
1204 applyInterval(selectOp, trueVal, newInterval);
1207 if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
1208 applyInterval(selectOp, falseVal, newInterval);
1212 Interval trueOverlap = trueExpr.getInterval().intersect(newInterval);
1213 Interval falseOverlap = falseExpr.getInterval().intersect(newInterval);
1214 bool truePossible = trueOverlap.isNotEmpty();
1215 bool falsePossible = falseOverlap.isNotEmpty();
1217 if (truePossible && !falsePossible) {
1219 applyInterval(selectOp, trueVal, newInterval);
1222 if (!truePossible && falsePossible) {
1224 applyInterval(selectOp, falseVal, newInterval);
1227 if (!truePossible && !falsePossible) {
1232 auto readmCase = [&](MemberReadOp) {
1233 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1235 if (sourceRefVal.isSingleValue()) {
1236 const SourceRef &ref = sourceRefVal.getSingleValue();
1237 readResults[ref].insert(valLattice);
1240 for (Lattice *l : readResults[ref]) {
1241 if (l != valLattice) {
1242 propagateIfChanged(l, l->setValue(newLatticeVal));
1248 auto readArrCase = [&](ReadArrayOp) {
1249 auto arrayRef = getArrayAccessRef(valUser, llvm::cast<ReadArrayOp>(definingOp));
1250 if (succeeded(arrayRef)) {
1251 readResults[*arrayRef].insert(valLattice);
1253 for (Lattice *l : readResults[*arrayRef]) {
1254 if (l != valLattice) {
1255 propagateIfChanged(l, l->setValue(newLatticeVal));
1260 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1262 if (sourceRefVal.isSingleValue()) {
1263 const SourceRef &ref = sourceRefVal.getSingleValue();
1264 readResults[ref].insert(valLattice);
1267 for (Lattice *l : readResults[ref]) {
1268 if (l != valLattice) {
1269 propagateIfChanged(l, l->setValue(newLatticeVal));
1276 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
1282 TypeSwitch<Operation *>(definingOp)
1283 .Case<CmpOp>([&](
auto op) { cmpCase(op); })
1284 .Case<AddFeltOp>([&](
auto op) {
return addCase(op); })
1285 .Case<SubFeltOp>([&](
auto op) {
return subCase(op); })
1286 .Case<MulFeltOp>([&](
auto op) { mulCase(op); })
1287 .Case<arith::SelectOp>([&](
auto op) { selectCase(op); })
1288 .Case<MemberReadOp>([&](
auto op){ readmCase(op); })
1289 .Case<ReadArrayOp>([&](
auto op){ readArrCase(op); })
1290 .Case<IntToFeltOp, FeltToIndexOp>([&](
auto op) { castCase(op); })
1291 .Default([&](Operation *) { });
1295 propagateIfChanged(valLattice, changed);
1298FailureOr<std::pair<DenseSet<Value>,
Interval>>
1299IntervalDataFlowAnalysis::getGeneralizedDecompInterval(Operation *baseOp, Value lhs, Value rhs) {
1300 auto isZeroConst = [
this](Value v) {
1301 Operation *op = v.getDefiningOp();
1305 if (!isConstOp(op)) {
1308 return getConst(op) == field.get().zero();
1310 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1311 Value exprTree =
nullptr;
1312 if (lhsIsZero && !rhsIsZero) {
1314 }
else if (!lhsIsZero && rhsIsZero) {
1321 std::optional<SourceRef> signalRef = std::nullopt;
1322 DenseSet<Value> signalVals;
1323 SmallVector<DynamicAPInt> consts;
1324 SmallVector<Value> frontier {exprTree};
1325 while (!frontier.empty()) {
1326 Value v = frontier.back();
1327 frontier.pop_back();
1328 Operation *op = v.getDefiningOp();
1332 auto handleRefValue = [
this, &signalRef, &signalVal, &signalVals]() {
1333 SourceRefLatticeValue refSet = getSourceRefState(signalVal);
1334 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1337 SourceRef r = refSet.getSingleValue();
1338 if (signalRef.has_value() && signalRef.value() != r) {
1340 }
else if (!signalRef.has_value()) {
1343 signalVals.insert(signalVal);
1348 if (op && matchPattern(op, subPattern)) {
1349 if (failed(handleRefValue())) {
1352 auto constInt = APSInt(c.getValue());
1353 consts.push_back(field.get().reduce(constInt));
1355 }
else if (
m_RefValue(&signalVal).match(v)) {
1356 if (failed(handleRefValue())) {
1359 consts.push_back(field.get().zero());
1365 if (op && matchPattern(op, mulPattern)) {
1366 frontier.push_back(a);
1367 frontier.push_back(b);
1376 std::sort(consts.begin(), consts.end());
1377 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1378 return std::make_pair(std::move(signalVals), iv);
1387 auto computeIntervalsImpl = [&solver, &ctx,
this](
1388 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &memberRanges,
1389 llvm::SetVector<ExpressionValue> &
1399 if (!ref.isScalar()) {
1402 searchSet.insert(ref);
1406 for (BlockArgument arg : fn.getArguments()) {
1408 if (searchSet.erase(ref)) {
1412 if (!expr.getExpr()) {
1415 memberRanges[ref] = expr.getInterval();
1416 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1424 if (!lattices.empty() && searchSet.erase(ref)) {
1427 joinedInterval = joinedInterval.
join(lattice->getValue().getScalarValue().getInterval());
1429 memberRanges[ref] = joinedInterval;
1430 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1435 if (searchSet.erase(ref)) {
1436 memberRanges[ref] = val.getInterval();
1437 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1442 for (
const auto &ref : searchSet) {
1451 llvm::SmallVector<std::pair<SourceRef, Interval>> sortedRanges;
1452 sortedRanges.reserve(memberRanges.size());
1453 for (
const auto &[ref, interval] : memberRanges) {
1454 sortedRanges.emplace_back(ref, interval);
1456 llvm::sort(sortedRanges, [](
const auto &a,
const auto &b) {
return a.first < b.first; });
1457 memberRanges.clear();
1458 for (
auto &[ref, interval] : sortedRanges) {
1459 memberRanges[ref] = interval;
1463 computeIntervalsImpl(structDef.getComputeFuncOp(), computeMemberRanges, computeSolverConstraints);
1464 computeIntervalsImpl(
1465 structDef.getConstrainFuncOp(), constrainMemberRanges, constrainSolverConstraints
1472 auto writeIntervals =
1473 [&os, &withConstraints](
1474 const char *fnName,
const llvm::MapVector<SourceRef, Interval> &memberRanges,
1475 const llvm::SetVector<ExpressionValue> &solverConstraints,
bool printName
1480 os.indent(indent) << fnName <<
" {";
1484 if (memberRanges.empty()) {
1489 for (
const auto &[ref, interval] : memberRanges) {
1491 os.indent(indent) << ref <<
" in " << interval;
1494 if (withConstraints) {
1496 os.indent(indent) <<
"Solver Constraints { ";
1497 if (solverConstraints.empty()) {
1500 for (
const auto &e : solverConstraints) {
1502 os.indent(indent + 4);
1503 e.getExpr()->print(os);
1506 os.indent(indent) <<
'}';
1512 os.indent(indent - 4) <<
'}';
1516 os <<
"StructIntervals { ";
1517 if (constrainMemberRanges.empty() && (!printCompute || computeMemberRanges.empty())) {
1523 writeIntervals(
FUNC_NAME_COMPUTE, computeMemberRanges, computeSolverConstraints, printCompute);
Tracks a solver expression and an interval range for that expression.
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
const Field & getField() const
Information about the prime finite field used for the interval analysis.
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
llvm::SMTExprRef createSymbol(const llvm::SMTSolverRef &solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
Intervals over a finite field.
static Interval True(const Field &f)
llvm::DynamicAPInt rhs() const
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
static Interval Entire(const Field &f)
bool isDegenerate() const
static Interval False(const Field &f)
llvm::DynamicAPInt lhs() const
Interval join(const Interval &rhs) const
Union.
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
A reference to a "source", which is the base value from which other SSA values are derived.
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
std::vector< SourceRefIndex > Path
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::Type getType() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, const IntervalAnalysisContext &ctx)
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::Type getElementType() const
::llzk::boolean::FeltCmpPredicate getPredicate()
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
RefValueCapture m_RefValue()
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val)
constexpr char FUNC_NAME_CONSTRAIN[]
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
DynamicAPInt toDynamicAPInt(StringRef str)
llvm::SMTExprRef createFieldInverseExpr(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val, StringRef suffix="")
ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-Degenerate divisors.
ExpressionValue div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth)
ConstantCapture m_Constant()
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divis...
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue selectValue(const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal, const ExpressionValue &falseVal)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA