20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Dialect/SCF/IR/SCF.h>
23#include <llvm/ADT/EquivalenceClasses.h>
24#include <llvm/ADT/TypeSwitch.h>
42std::optional<UnreducedInterval> mergeUnreducedIntervals(
43 const std::optional<UnreducedInterval> &lhs,
const std::optional<UnreducedInterval> &rhs
45 if (!lhs.has_value() || !rhs.has_value()) {
48 return lhs->doUnion(*rhs);
52std::optional<UnreducedInterval>
54 if (!lhs.hasUnreducedInterval() || !rhs.hasUnreducedInterval()) {
57 return fn(lhs.getUnreducedInterval(), rhs.getUnreducedInterval());
62 if (expr.getInterval() != newInterval) {
63 refined = refined.dropUnreducedInterval();
68bool isInMaybeSkippedScfRegion(Operation *op) {
69 for (Operation *parent = op->getParentOp(); parent !=
nullptr; parent = parent->getParentOp()) {
70 if (llvm::isa<FuncDefOp>(parent)) {
77 if (llvm::isa<scf::ForOp, scf::IfOp, scf::WhileOp>(parent)) {
84std::optional<UnreducedInterval> getBooleanUnreducedInterval(
const Interval &interval) {
85 return interval.isBoolean() ? std::optional<UnreducedInterval>(interval.firstUnreduced())
89FailureOr<std::vector<SourceRef>>
91 std::vector<SourceRef> refs;
92 for (
const auto &[prefix, vals] : translations) {
93 if (!ref.isValidPrefix(prefix)) {
98 auto suffix = ref.getSuffix(prefix);
99 ensure(succeeded(suffix),
"prefix checked before SourceRef suffix extraction");
101 std::vector<SourceRefIndex> arraySuffix, remainingSuffix;
102 bool suffixIsPastArray =
false;
104 if (!suffixIsPastArray && arraySuffix.size() < vals.getNumArrayDims() &&
105 (idx.isIndex() || idx.isIndexRange())) {
106 arraySuffix.push_back(idx);
109 suffixIsPastArray =
true;
110 remainingSuffix.push_back(idx);
113 auto resolvedValsRes = vals.extract(arraySuffix);
114 ensure(succeeded(resolvedValsRes),
"could not resolve translated SourceRef array child");
115 SourceRefSet folded = resolvedValsRes->first.foldToScalar();
116 if (remainingSuffix.empty()) {
117 refs.insert(refs.end(), folded.begin(), folded.end());
121 for (
const SourceRef &baseRef : folded) {
122 auto translatedRef = mlir::FailureOr<SourceRef>(baseRef);
124 if (failed(translatedRef)) {
127 translatedRef = translatedRef->createChild(idx);
129 if (succeeded(translatedRef)) {
130 refs.push_back(*translatedRef);
134 for (
const SourceRef &replacement : vals.getScalarValue()) {
135 auto translated = ref.translate(prefix, replacement);
136 if (succeeded(translated)) {
137 refs.push_back(*translated);
149bool isDirectSourceRefValue(Value value) {
150 if (llvm::isa<BlockArgument>(value)) {
154 Operation *definingOp = value.getDefiningOp();
155 return llvm::isa_and_present<MemberReadOp, ReadArrayOp, polymorphic::ConstReadOp>(definingOp);
158std::optional<SourceRefLatticeValue>
159getIdentitySourceRefState(DataFlowSolver &solver, Value value) {
160 if (isDirectSourceRefValue(value)) {
162 if (val.isScalar()) {
168 auto createArray = llvm::dyn_cast_if_present<CreateArrayOp>(value.getDefiningOp());
174 for (
auto [idx, element] : llvm::enumerate(createArray.getElements())) {
175 std::optional<SourceRefLatticeValue> elementVal = getIdentitySourceRefState(solver, element);
176 if (!elementVal.has_value()) {
179 (void)arrayVal.getElemFlatIdx(idx).setValue(*elementVal);
184llvm::EquivalenceClasses<SourceRef>
185collectDirectEqualityRefs(DataFlowSolver &solver,
FuncDefOp fn) {
186 llvm::EquivalenceClasses<SourceRef> eqRefs;
188 Operation *op = eqOp.getOperation();
193 Value lhs = eqOp.getLhs();
194 Value rhs = eqOp.getRhs();
195 if (!isDirectSourceRefValue(lhs) || !isDirectSourceRefValue(rhs)) {
201 if (!lhsState.isScalar() || !rhsState.isScalar() || !lhsState.isSingleValue() ||
202 !rhsState.isSingleValue()) {
206 const SourceRef &lhsRef = lhsState.getSingleValue();
207 const SourceRef &rhsRef = rhsState.getSingleValue();
208 if (lhsRef.isConstant() || rhsRef.isConstant()) {
211 eqRefs.unionSets(lhsRef, rhsRef);
221 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &val,
222 StringRef suffix =
""
227 DynamicAPInt invVal = field.
inv(iv.
lhs());
235 if (!suffix.empty()) {
236 symName += suffix.str();
238 llvm::SMTExprRef invSym = field.
createSymbol(solver, symName.c_str());
239 llvm::SMTExprRef one = solver->mkBitvector(APSInt::get(1), field.
bitWidth());
241 llvm::SMTExprRef mult = solver->mkBVMul(val.
getExpr(), invSym);
242 llvm::SMTExprRef
mod = solver->mkBVURem(mult, prime);
243 llvm::SMTExprRef constraint = solver->mkEqual(
mod, one);
244 solver->addConstraint(constraint);
249 if (expr ==
nullptr && rhs.expr ==
nullptr) {
250 return i == rhs.i && unreduced == rhs.unreduced;
252 if (expr ==
nullptr || rhs.expr ==
nullptr) {
255 return i == rhs.i && unreduced == rhs.unreduced && *expr == *rhs.expr;
260 llvm::SMTExprRef zero = solver->mkBitvector(mlir::APSInt::get(0), bitwidth);
261 llvm::SMTExprRef one = solver->mkBitvector(mlir::APSInt::get(1), bitwidth);
262 llvm::SMTExprRef boolToFeltConv = solver->mkIte(expr.
getExpr(), one, zero);
282 llvm::SMTExprRef resultExpr =
284 std::optional<UnreducedInterval> resultUnreduced;
292 resultUnreduced = mergeUnreducedIntervals(
296 return ExpressionValue(resultExpr, resultInterval, std::move(resultUnreduced));
303 const auto *exprEq = solver->mkEqual(lhs.expr, rhs.expr);
310 res.i = lhs.i + rhs.i;
311 res.expr = solver->mkBVAdd(lhs.expr, rhs.expr);
319 res.i = lhs.i - rhs.i;
320 res.expr = solver->mkBVSub(lhs.expr, rhs.expr);
328 res.i = lhs.i * rhs.i;
329 res.expr = solver->mkBVMul(lhs.expr, rhs.expr);
338 auto divRes =
feltDiv(lhs.i, rhs.i);
339 if (failed(divRes)) {
346 "non-degenerate felt.div divisors are not tracked precisely, and the divisor may "
347 "contain zero. Range of division result will be treated as unbounded."
352 "non-degenerate felt.div divisors are not tracked precisely because precise field "
353 "division over intervals would require enumerating divisor inverses. Range of "
354 "division result will be treated as unbounded."
360 "divisor is zero, leading to a divide-by-zero error. Range of division result will "
361 "be treated as unbounded."
370 res.expr = solver->mkBVMul(lhs.expr, invExpr);
375 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &lhs,
380 if (failed(divRes)) {
382 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
383 " Range of division result will be treated as unbounded."
390 res.expr = solver->mkBVUDiv(lhs.expr, rhs.expr);
395 const llvm::SMTSolverRef &solver, Operation *op,
const ExpressionValue &lhs,
400 if (failed(divRes)) {
402 "divisor is not restricted to non-zero values, leading to potential divide-by-zero error."
403 " Range of division result will be treated as unbounded."
410 res.expr = solver->mkBVSDiv(lhs.expr, rhs.expr);
417 res.i = lhs.i % rhs.i;
418 res.expr = solver->mkBVURem(lhs.expr, rhs.expr);
433 res.i = lhs.i & rhs.i;
434 res.expr = solver->mkBVAnd(lhs.expr, rhs.expr);
441 res.i = lhs.i | rhs.i;
442 res.expr = solver->mkBVOr(lhs.expr, rhs.expr);
449 return boolXor(solver, lhs, rhs);
453 res.i = lhs.i ^ rhs.i;
454 res.expr = solver->mkBVXor(lhs.expr, rhs.expr);
462 res.i = lhs.i << rhs.i;
463 res.expr = solver->mkBVShl(lhs.expr, rhs.expr);
471 res.i = lhs.i >> rhs.i;
472 res.expr = solver->mkBVLshr(lhs.expr, rhs.expr);
484 case FeltCmpPredicate::EQ:
485 res.expr = solver->mkEqual(lhs.expr, rhs.expr);
492 case FeltCmpPredicate::NE:
493 res.expr = solver->mkNot(solver->mkEqual(lhs.expr, rhs.expr));
500 case FeltCmpPredicate::LT:
501 res.expr = solver->mkBVUlt(lhs.expr, rhs.expr);
509 case FeltCmpPredicate::LE:
510 res.expr = solver->mkBVUle(lhs.expr, rhs.expr);
518 case FeltCmpPredicate::GT:
519 res.expr = solver->mkBVUgt(lhs.expr, rhs.expr);
527 case FeltCmpPredicate::GE:
528 res.expr = solver->mkBVUge(lhs.expr, rhs.expr);
545 res.expr = solver->mkAnd(lhs.expr, rhs.expr);
553 res.i =
boolOr(lhs.i, rhs.i);
554 res.expr = solver->mkOr(lhs.expr, rhs.expr);
564 res.expr = solver->mkAnd(
565 solver->mkOr(lhs.expr, rhs.expr), solver->mkNot(solver->mkAnd(lhs.expr, rhs.expr))
574 res.expr = solver->mkBVNeg(val.expr);
584 res.expr = solver->mkBVNot(val.expr);
591 res.expr = solver->mkNot(val.expr);
601 res.expr = TypeSwitch<Operation *, llvm::SMTExprRef>(op)
604 }).Default([](Operation *unsupported) {
605 llvm::report_fatal_error(
606 "no fallback provided for " + mlir::Twine(unsupported->getName().getStringRef())
611 if (llvm::isa<InvFeltOp>(op)) {
625 os <<
"<null expression>";
628 os <<
" ( interval: " << i <<
" )";
629 if (unreduced.has_value()) {
630 os <<
" ( unreduced: " << *unreduced <<
" )";
639 return ChangeResult::NoChange;
645 return ChangeResult::NoChange;
649 os <<
"IntervalAnalysisLattice { " << val <<
" }";
654 return ChangeResult::NoChange;
657 return ChangeResult::Change;
666 if (!constraints.contains(e)) {
667 constraints.insert(e);
668 return ChangeResult::Change;
670 return ChangeResult::NoChange;
679std::vector<SourceRefIndex> IntervalDataFlowAnalysis::getArrayAccessIndices(
680 Operation * , ArrayAccessOpInterface arrayAccessOp
682 std::vector<SourceRefIndex> indices;
683 ArrayType arrayType = arrayAccessOp.getArrRefType();
684 size_t numIndices = arrayAccessOp.getIndices().size();
685 indices.reserve(numIndices);
687 for (
size_t i = 0; i < numIndices; ++i) {
688 Value idxOperand = arrayAccessOp.getIndices()[i];
689 SourceRefLatticeValue idxVals = getSourceRefState(idxOperand);
692 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
693 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
695 auto lower = APInt::getZero(64);
696 APInt upper(64, arrayType.getDimSize(i));
697 indices.emplace_back(lower, upper);
704mlir::FailureOr<SourceRef> IntervalDataFlowAnalysis::getArrayAccessRef(
707 std::vector<SourceRefIndex> indices = getArrayAccessIndices(baseOp, arrayAccessOp);
708 Value arrayVal = arrayAccessOp.getArrRef();
709 if (
auto blockArg = llvm::dyn_cast<BlockArgument>(arrayVal)) {
710 return SourceRef(blockArg, std::move(indices));
712 if (
auto result = llvm::dyn_cast<OpResult>(arrayVal)) {
713 return SourceRef(result, std::move(indices));
719 if (
auto it = writeResults.find(ref); it != writeResults.end()) {
720 return it->second.getInterval();
723 if (ref.isConstantInt()) {
724 auto constVal = ref.getConstantValue();
725 if (succeeded(constVal)) {
730 if (ref.isRooted() && ref.getPath().empty()) {
731 auto rootVal = ref.getRoot();
732 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
734 if (rootExpr.getExpr() !=
nullptr) {
735 return rootExpr.getInterval();
740 return getDefaultIntervalForType(ref.getType());
743std::optional<UnreducedInterval>
744IntervalDataFlowAnalysis::getDefaultUnreducedIntervalForType(mlir::Type ty)
const {
745 if (!trackUnreducedIntervals) {
748 if (isBooleanType(ty)) {
749 return UnreducedInterval(0, 1);
751 return UnreducedInterval(field.get().zero(), field.get().maxVal());
754std::optional<UnreducedInterval>
755IntervalDataFlowAnalysis::getRefUnreducedInterval(
const SourceRef &ref) {
756 if (!trackUnreducedIntervals) {
760 if (
auto it = writeResults.find(ref); it != writeResults.end()) {
761 return it->second.getOptionalUnreducedInterval();
764 if (ref.isConstantInt()) {
765 auto constVal = ref.getConstantValue();
766 if (succeeded(constVal)) {
767 return UnreducedInterval(*constVal, *constVal);
771 if (ref.isRooted() && ref.getPath().empty()) {
772 auto rootVal = ref.getRoot();
773 if (succeeded(rootVal) && !llvm::isa<ArrayType, StructType>(rootVal->getType())) {
775 if (rootExpr.hasUnreducedInterval()) {
776 return rootExpr.getUnreducedInterval();
781 return getRefInterval(ref).firstUnreduced();
785 if (
auto it = writeResults.find(ref); it != writeResults.end()) {
788 return createUnknownValue(val)
793void IntervalDataFlowAnalysis::recordRefWrite(
796 auto joinStoredWrite = [
this, &writtenRef](
797 const ExpressionValue &old,
const ExpressionValue &next
798 ) -> ExpressionValue {
799 Interval combinedWrite = old.getInterval().join(next.getInterval());
800 auto combinedUnreduced = mergeUnreducedIntervals(
801 old.getOptionalUnreducedInterval(), next.getOptionalUnreducedInterval()
803 if (old.getExpr() !=
nullptr && next.getExpr() !=
nullptr &&
804 *old.getExpr() == *next.getExpr()) {
805 return old.withInterval(combinedWrite).withOptionalUnreducedInterval(combinedUnreduced);
808 return ExpressionValue(
813 if (
auto it = writeResults.find(writtenRef); it != writeResults.end()) {
814 it->second = joinStoredWrite(it->second, writeVal);
815 }
else if (mayBeSkipped) {
816 ExpressionValue noWrite(
818 getRefUnreducedInterval(writtenRef)
820 writeResults[writtenRef] = joinStoredWrite(noWrite, writeVal);
822 writeResults[writtenRef] = writeVal;
825 const ExpressionValue &readerUpdate = mayBeSkipped ? writeResults[writtenRef] : writeVal;
826 for (Lattice *readerLattice : readResults[writtenRef]) {
827 ExpressionValue prior = readerLattice->getValue().getScalarValue();
829 ExpressionValue newVal = prior.withInterval(
intersection);
830 propagateIfChanged(readerLattice, readerLattice->setValue(newVal));
835 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
844 if (operands.empty() && results.empty()) {
849 llvm::SmallVector<LatticeValue> operandVals;
850 llvm::SmallVector<std::optional<SourceRef>> operandRefs;
851 auto resolveRefStateValue =
853 ensure(refSet.isScalar(),
"should have ruled out array values already");
855 if (refSet.getScalarValue().empty()) {
862 " is empty; defining operation is unsupported by SourceRef analysis"
868 if (!refSet.isSingleValue()) {
870 std::optional<UnreducedInterval> joinedUnreduced = std::nullopt;
871 bool sawFirst =
false;
872 for (
const SourceRef &ref : refSet.getScalarValue()) {
873 joinedInterval = joinedInterval.
join(getRefInterval(ref));
874 auto refUnreduced = getRefUnreducedInterval(ref);
876 joinedUnreduced = refUnreduced;
879 joinedUnreduced = mergeUnreducedIntervals(joinedUnreduced, refUnreduced);
885 return LatticeValue(anyVal);
888 return LatticeValue(getRefValue(refSet.getSingleValue(), value));
890 for (
unsigned opNum = 0; opNum < op->getNumOperands(); ++opNum) {
891 Value val = op->getOperand(opNum);
896 operandRefs.push_back(std::nullopt);
899 auto priorState = operands[opNum]->getValue();
900 if (priorState.getScalarValue().getExpr() !=
nullptr) {
901 operandVals.push_back(priorState);
905 if (
auto readArr = llvm::dyn_cast_if_present<ReadArrayOp>(val.getDefiningOp())) {
906 auto arrayRef = getArrayAccessRef(op, readArr);
907 if (succeeded(arrayRef)) {
908 if (
auto it = writeResults.find(*arrayRef); it != writeResults.end()) {
909 operandVals.emplace_back(it->second);
911 (void)operandLattice->
setValue(it->second);
920 Type valTy = val.getType();
921 if (llvm::isa<ArrayType, StructType>(valTy)) {
923 operandVals.emplace_back(anyVal);
927 auto resolvedValue = resolveRefStateValue(val, refSet);
928 if (!resolvedValue.has_value()) {
933 operandVals.push_back(*resolvedValue);
939 (void)operandLattice->
setValue(operandVals[opNum]);
942 if (isReadOp(op) && op->getNumResults() == 1) {
943 Value resultVal = op->getResult(0);
944 if (!llvm::isa<ArrayType, StructType>(resultVal.getType())) {
945 auto resolvedValue = resolveRefStateValue(resultVal, getSourceRefState(resultVal));
946 if (resolvedValue.has_value()) {
947 propagateIfChanged(results[0], results[0]->setValue(*resolvedValue));
955 llvm::DynamicAPInt constVal = getConst(op);
956 llvm::SMTExprRef expr;
957 if (isBoolConstOp(op)) {
958 expr = createConstBoolExpr(constVal != 0);
960 expr = createConstBitvectorExpr(constVal);
964 if (trackUnreducedIntervals) {
967 propagateIfChanged(results[0], results[0]->setValue(latticeVal));
968 }
else if (isArithmeticOp(op)) {
970 if (operands.size() == 2) {
971 result = performBinaryArithmetic(op, operandVals[0], operandVals[1]);
973 result = performUnaryArithmetic(op, operandVals[0]);
977 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
981 propagateIfChanged(results[0], results[0]->setValue(result));
982 }
else if (
auto selectOp = llvm::dyn_cast<arith::SelectOp>(op)) {
984 smtSolver, operandVals[0].getScalarValue(), operandVals[1].getScalarValue(),
985 operandVals[2].getScalarValue()
987 const ExpressionValue &prior = results[0]->getValue().getScalarValue();
991 propagateIfChanged(results[0], results[0]->setValue(result));
992 }
else if (
EmitEqualityOp emitEq = llvm::dyn_cast<EmitEqualityOp>(op)) {
993 Value lhsVal = emitEq.getLhs(), rhsVal = emitEq.getRhs();
999 auto res = getGeneralizedDecompInterval(op, lhsVal, rhsVal);
1000 if (succeeded(res)) {
1001 for (Value signalVal : res->first) {
1002 applyInterval(emitEq, signalVal, res->second);
1010 applyInterval(emitEq, lhsVal, constrainInterval);
1011 applyInterval(emitEq, rhsVal, constrainInterval);
1012 }
else if (
auto assertOp = llvm::dyn_cast<AssertOp>(op)) {
1015 Value cond = assertOp.getCondition();
1018 auto assertExpr = operandVals[0].getScalarValue();
1021 }
else if (
auto writem = llvm::dyn_cast<MemberWriteOp>(op)) {
1022 const bool maySkipWrite = isInMaybeSkippedScfRegion(op);
1025 auto cmp = writem.getComponent();
1029 auto memberDefRes = writem.getMemberDefOp(tables);
1030 if (succeeded(memberDefRes)) {
1033 ensure(succeeded(memberRefRes),
"could not create SourceRef child for member write");
1034 const SourceRef &memberRef = *memberRefRes;
1035 Type memberTy = writem.getVal().
getType();
1036 if (!llvm::isa<ArrayType, StructType>(memberTy)) {
1038 recordRefWrite(memberRef, writeVal, maySkipWrite);
1041 std::optional<SourceRef> rhsPrefix;
1042 if (operandRefs[1].has_value() && operandRefs[1]->isRooted()) {
1043 rhsPrefix = operandRefs[1];
1044 }
else if (
auto blockArg = llvm::dyn_cast<BlockArgument>(writem.getVal())) {
1046 }
else if (
auto result = llvm::dyn_cast<OpResult>(writem.getVal())) {
1050 if (rhsPrefix.has_value()) {
1051 llvm::SmallVector<std::pair<SourceRef, ExpressionValue>> remappedWrites;
1052 for (
const auto &[writtenRef, writtenVal] : writeResults) {
1053 if (!writtenRef.isValidPrefix(*rhsPrefix)) {
1057 auto translatedRef = writtenRef.translate(*rhsPrefix, memberRef);
1058 ensure(succeeded(translatedRef),
"could not translate composite member write");
1059 remappedWrites.emplace_back(*translatedRef, writtenVal);
1062 for (
const auto &[translatedRef, translatedVal] : remappedWrites) {
1063 recordRefWrite(translatedRef, translatedVal, maySkipWrite);
1069 }
else if (
auto writeArr = llvm::dyn_cast<WriteArrayOp>(op)) {
1070 const bool maySkipWrite = isInMaybeSkippedScfRegion(op);
1072 auto arrayRef = getArrayAccessRef(op, writeArr);
1073 if (succeeded(arrayRef)) {
1074 recordRefWrite(*arrayRef, writeVal, maySkipWrite);
1079 std::vector<SourceRefIndex> indices = getArrayAccessIndices(op, writeArr);
1080 auto targetRefsRes = arrayVals.
extract(indices);
1081 ensure(succeeded(targetRefsRes),
"could not create SourceRef child for array write");
1082 auto [targetRefs, _] = *targetRefsRes;
1083 ensure(targetRefs.isScalar(),
"array write must resolve to scalar references");
1084 for (
const SourceRef &ref : targetRefs.getScalarValue()) {
1085 recordRefWrite(ref, writeVal, maySkipWrite);
1088 }
else if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
1089 const auto &elements = createArray.getElements();
1090 ArrayType arrayTy = createArray.getType();
1093 if (!elements.empty() && !llvm::isa<ArrayType, StructType>(elemTy)) {
1094 ensure(arrayTy.hasStaticShape(),
"array.new with explicit elements must have static shape");
1096 std::cmp_equal(elements.size(), arrayTy.getNumElements()),
1097 "array.new explicit initializer length must match array shape"
1101 auto arrayRes = llvm::cast<OpResult>(createArray->getResult(0));
1102 for (
unsigned i = 0; i < elements.size(); ++i) {
1103 auto maybeIndices = indexGen.
delinearize(i, op->getContext());
1104 ensure(maybeIndices.has_value(),
"could not delinearize array.new element index");
1107 path.reserve(maybeIndices->size());
1108 for (Attribute attr : *maybeIndices) {
1109 auto idxAttr = llvm::dyn_cast<IntegerAttr>(attr);
1110 ensure(idxAttr !=
nullptr,
"array.new delinearize should produce integer attributes");
1111 path.emplace_back(idxAttr.getValue());
1114 recordRefWrite(
SourceRef(arrayRes, std::move(path)), operandVals[i].getScalarValue());
1117 }
else if (isa<IntToFeltOp, FeltToIndexOp>(op)) {
1124 expr =
boolToFelt(smtSolver, expr, field.get().bitWidth());
1126 propagateIfChanged(results[0], results[0]->setValue(expr));
1127 }
else if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1130 Operation *parent = op->getParentOp();
1131 ensure(parent,
"yield operation must have parent operation");
1133 for (
unsigned idx = 0; idx < yieldOp.getResults().size(); ++idx) {
1134 Value parentRes = parent->getResult(idx);
1140 if (
auto loopOp = llvm::dyn_cast<LoopLikeOpInterface>(parent)) {
1144 if (exprVal.
getExpr() !=
nullptr) {
1156 propagateIfChanged(resLattice, resLattice->
setValue(newResVal));
1166 && !isDefinitionOp(op)
1168 && !llvm::isa<CreateArrayOp, CreateStructOp, NonDetOp>(op)
1170 op->emitWarning(
"unhandled operation, analysis may be incomplete").report();
1177 auto it = refSymbols.find(r);
1178 if (it != refSymbols.end()) {
1181 llvm::SMTExprRef sym = createSymbol(r);
1182 refSymbols[r] = sym;
1186llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(mlir::Type ty,
const char *name)
const {
1187 if (isBooleanType(ty)) {
1188 return smtSolver->mkSymbol(name, smtSolver->getBoolSort());
1190 return field.get().createSymbol(smtSolver, name);
1193llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(
const SourceRef &r)
const {
1195 return createSymbol(r.getType(), name.c_str());
1198llvm::SMTExprRef IntervalDataFlowAnalysis::createSymbol(Value v)
const {
1200 return createSymbol(v.getType(), name.c_str());
1203llvm::DynamicAPInt IntervalDataFlowAnalysis::getConst(Operation *op)
const {
1204 ensure(isConstOp(op),
"op is not a const op");
1208 llvm::DynamicAPInt fieldConst = TypeSwitch<Operation *, llvm::DynamicAPInt>(op)
1209 .Case<FeltConstantOp>([&](
auto feltConst) {
1210 llvm::APSInt constOpVal(feltConst.getValue());
1211 return field.get().reduce(constOpVal);
1213 .Case<arith::ConstantIndexOp>([&](
auto indexConst) {
1214 return DynamicAPInt(indexConst.value());
1216 .Case<arith::ConstantIntOp>([&](
auto intConst) {
1217 auto valAttr = dyn_cast<IntegerAttr>(intConst.getValue());
1218 ensure(valAttr !=
nullptr,
"arith::ConstantIntOp must have an IntegerAttr as its value");
1221 .Default([](
auto *illegalOp) {
1223 debug::Appender(err) <<
"unhandled getConst case: " << *illegalOp;
1224 llvm::report_fatal_error(Twine(err));
1225 return llvm::DynamicAPInt();
1232 Operation *op,
const LatticeValue &a,
const LatticeValue &b
1234 ensure(isArithmeticOp(op),
"is not arithmetic op");
1236 auto lhs = a.getScalarValue(), rhs = b.getScalarValue();
1237 ensure(lhs.getExpr(),
"cannot perform arithmetic over null lhs smt expr");
1238 ensure(rhs.getExpr(),
"cannot perform arithmetic over null rhs smt expr");
1241 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
1242 .Case<AddFeltOp>([&](
auto) {
return add(smtSolver, lhs, rhs); })
1243 .Case<SubFeltOp>([&](
auto) {
return sub(smtSolver, lhs, rhs); })
1244 .Case<MulFeltOp>([&](
auto) {
return mul(smtSolver, lhs, rhs); })
1245 .Case<DivFeltOp>([&](
auto) {
return div(smtSolver, op, lhs, rhs); })
1246 .Case<UnsignedIntDivFeltOp>([&](
auto) {
return uintDiv(smtSolver, op, lhs, rhs); })
1247 .Case<SignedIntDivFeltOp>([&](
auto) {
return sintDiv(smtSolver, op, lhs, rhs); })
1248 .Case<UnsignedModFeltOp>([&](
auto) {
return mod(smtSolver, lhs, rhs); })
1249 .Case<SignedModFeltOp>([&](
auto) {
return sintMod(smtSolver, lhs, rhs); })
1250 .Case<AndFeltOp>([&](
auto) {
return bitAnd(smtSolver, lhs, rhs); })
1251 .Case<OrFeltOp>([&](
auto) {
return bitOr(smtSolver, lhs, rhs); })
1252 .Case<XorFeltOp, arith::XOrIOp>([&](
auto) {
return bitXor(smtSolver, lhs, rhs); })
1253 .Case<ShlFeltOp>([&](
auto) {
return shiftLeft(smtSolver, lhs, rhs); })
1254 .Case<ShrFeltOp>([&](
auto) {
return shiftRight(smtSolver, lhs, rhs); })
1255 .Case<CmpOp>([&](
auto cmpOp) {
return cmp(smtSolver, cmpOp, lhs, rhs); })
1256 .Case<AndBoolOp>([&](
auto) {
return boolAnd(smtSolver, lhs, rhs); })
1257 .Case<OrBoolOp>([&](
auto) {
return boolOr(smtSolver, lhs, rhs); })
1258 .Case<XorBoolOp>([&](
auto) {
return boolXor(smtSolver, lhs, rhs); })
1259 .Default([&](
auto *unsupported) {
1262 "unsupported binary arithmetic operation"
1265 return ExpressionValue();
1269 ensure(res.getExpr(),
"arithmetic produced null smt expr");
1274IntervalDataFlowAnalysis::performUnaryArithmetic(Operation *op,
const LatticeValue &a) {
1275 ensure(isArithmeticOp(op),
"is not arithmetic op");
1277 auto val = a.getScalarValue();
1278 ensure(val.getExpr(),
"cannot perform arithmetic over null smt expr");
1280 auto res = TypeSwitch<Operation *, ExpressionValue>(op)
1281 .Case<NegFeltOp>([&](
auto) {
return neg(smtSolver, val); })
1282 .Case<NotFeltOp>([&](
auto) {
return notOp(smtSolver, val); })
1283 .Case<NotBoolOp>([&](
auto) {
return boolNot(smtSolver, val); })
1285 .Case<InvFeltOp>([&](
auto inv) {
1287 }).Default([&](
auto *unsupported) {
1290 "unsupported unary arithmetic operation, defaulting to over-approximated interval"
1296 ensure(res.getExpr(),
"arithmetic produced null smt expr");
1300void IntervalDataFlowAnalysis::applyInterval(Operation *valUser, Value val,
Interval newInterval) {
1302 ExpressionValue oldLatticeVal = valLattice->getValue().getScalarValue();
1305 ExpressionValue newLatticeVal = refineReducedInterval(oldLatticeVal,
intersection);
1306 ChangeResult changed = valLattice->setValue(newLatticeVal);
1308 if (
auto blockArg = llvm::dyn_cast<BlockArgument>(val)) {
1309 auto fnOp = dyn_cast<FuncDefOp>(blockArg.getOwner()->getParentOp());
1312 if (propagateInputConstraints && fnOp && fnOp.isStructConstrain() &&
1313 blockArg.getArgNumber() > 0 && !newInterval.isEntire()) {
1314 auto structOp = fnOp->getParentOfType<StructDefOp>();
1315 FuncDefOp computeFn = structOp.getComputeFuncOp();
1316 BlockArgument computeArg = computeFn.getArgument(blockArg.getArgNumber() - 1);
1319 SourceRef ref(computeArg);
1320 ExpressionValue newArgVal(
1322 trackUnreducedIntervals ? std::optional<UnreducedInterval>(newInterval.firstUnreduced())
1325 propagateIfChanged(computeEntryLattice, computeEntryLattice->setValue(newArgVal));
1330 Operation *definingOp = val.getDefiningOp();
1332 propagateIfChanged(valLattice, changed);
1336 const Field &f = field.get();
1344 auto cmpCase = [&](CmpOp cmpOp) {
1350 newInterval.isBoolean() || newInterval.isEmpty(),
1351 "new interval for CmpOp is not boolean or empty"
1353 if (!newInterval.isDegenerate()) {
1358 bool cmpTrue = newInterval.rhs() == f.one();
1360 Value lhs = cmpOp.getLhs(), rhs = cmpOp.getRhs();
1362 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1363 rhsExpr = rhsLat->getValue().getScalarValue();
1365 Interval newLhsInterval, newRhsInterval;
1366 const Interval &lhsInterval = lhsExpr.getInterval();
1367 const Interval &rhsInterval = rhsExpr.getInterval();
1371 auto eqCase = [&]() {
1372 return (pred == FeltCmpPredicate::EQ && cmpTrue) ||
1373 (pred == FeltCmpPredicate::NE && !cmpTrue);
1375 auto neCase = [&]() {
1376 return (pred == FeltCmpPredicate::NE && cmpTrue) ||
1377 (pred == FeltCmpPredicate::EQ && !cmpTrue);
1379 auto ltCase = [&]() {
1380 return (pred == FeltCmpPredicate::LT && cmpTrue) ||
1381 (pred == FeltCmpPredicate::GE && !cmpTrue);
1383 auto leCase = [&]() {
1384 return (pred == FeltCmpPredicate::LE && cmpTrue) ||
1385 (pred == FeltCmpPredicate::GT && !cmpTrue);
1387 auto gtCase = [&]() {
1388 return (pred == FeltCmpPredicate::GT && cmpTrue) ||
1389 (pred == FeltCmpPredicate::LE && !cmpTrue);
1391 auto geCase = [&]() {
1392 return (pred == FeltCmpPredicate::GE && cmpTrue) ||
1393 (pred == FeltCmpPredicate::LT && !cmpTrue);
1398 newLhsInterval = newRhsInterval = lhsInterval.intersect(rhsInterval);
1399 }
else if (neCase()) {
1400 if (lhsInterval.isDegenerate() && rhsInterval.isDegenerate() && lhsInterval == rhsInterval) {
1404 }
else if (lhsInterval.isDegenerate()) {
1406 newLhsInterval = lhsInterval;
1407 newRhsInterval = rhsInterval.difference(lhsInterval);
1408 }
else if (rhsInterval.isDegenerate()) {
1410 newLhsInterval = lhsInterval.difference(rhsInterval);
1411 newRhsInterval = rhsInterval;
1414 newLhsInterval = lhsInterval;
1415 newRhsInterval = rhsInterval;
1417 }
else if (ltCase()) {
1418 newLhsInterval = lhsInterval.toUnreduced().computeLTPart(rhsInterval.toUnreduced()).reduce(f);
1419 newRhsInterval = rhsInterval.toUnreduced().computeGEPart(lhsInterval.toUnreduced()).reduce(f);
1420 }
else if (leCase()) {
1421 newLhsInterval = lhsInterval.toUnreduced().computeLEPart(rhsInterval.toUnreduced()).reduce(f);
1422 newRhsInterval = rhsInterval.toUnreduced().computeGTPart(lhsInterval.toUnreduced()).reduce(f);
1423 }
else if (gtCase()) {
1424 newLhsInterval = lhsInterval.toUnreduced().computeGTPart(rhsInterval.toUnreduced()).reduce(f);
1425 newRhsInterval = rhsInterval.toUnreduced().computeLEPart(lhsInterval.toUnreduced()).reduce(f);
1426 }
else if (geCase()) {
1427 newLhsInterval = lhsInterval.toUnreduced().computeGEPart(rhsInterval.toUnreduced()).reduce(f);
1428 newRhsInterval = rhsInterval.toUnreduced().computeLTPart(lhsInterval.toUnreduced()).reduce(f);
1430 cmpOp->emitWarning(
"unhandled cmp predicate").report();
1435 applyInterval(cmpOp, lhs, newLhsInterval);
1436 applyInterval(cmpOp, rhs, newRhsInterval);
1444 auto mulCase = [&](MulFeltOp mulOp) {
1446 auto constCase = [&](FeltConstantOp constOperand, Value multiplicand) {
1448 APInt constVal = constOperand.getValue();
1449 if (constVal.isZero()) {
1454 applyInterval(mulOp, multiplicand, updatedInterval);
1457 Value lhs = mulOp.getLhs(), rhs = mulOp.getRhs();
1459 auto lhsConstOp = dyn_cast_if_present<FeltConstantOp>(lhs.getDefiningOp());
1460 auto rhsConstOp = dyn_cast_if_present<FeltConstantOp>(rhs.getDefiningOp());
1462 if (lhsConstOp && rhsConstOp) {
1464 }
else if (lhsConstOp) {
1465 constCase(lhsConstOp, rhs);
1467 }
else if (rhsConstOp) {
1468 constCase(rhsConstOp, lhs);
1474 if (newInterval.intersect(zeroInt).isNotEmpty()) {
1480 ExpressionValue lhsExpr = lhsLat->getValue().getScalarValue(),
1481 rhsExpr = rhsLat->getValue().getScalarValue();
1482 Interval newLhsInterval = lhsExpr.getInterval().difference(zeroInt);
1483 Interval newRhsInterval = rhsExpr.getInterval().difference(zeroInt);
1484 applyInterval(mulOp, lhs, newLhsInterval);
1485 applyInterval(mulOp, rhs, newRhsInterval);
1488 auto addCase = [&](AddFeltOp addOp) {
1489 Value lhs = addOp.getLhs(), rhs = addOp.getRhs();
1491 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1492 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1494 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1496 Interval derivedLhsInt = newInterval - currRhsInt;
1497 Interval derivedRhsInt = newInterval - currLhsInt;
1499 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1500 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1502 applyInterval(addOp, lhs, finalLhsInt);
1503 applyInterval(addOp, rhs, finalRhsInt);
1506 auto subCase = [&](SubFeltOp subOp) {
1507 Value lhs = subOp.getLhs(), rhs = subOp.getRhs();
1509 ExpressionValue lhsVal = lhsLat->getValue().getScalarValue();
1510 ExpressionValue rhsVal = rhsLat->getValue().getScalarValue();
1512 const Interval &currLhsInt = lhsVal.getInterval(), &currRhsInt = rhsVal.getInterval();
1514 Interval derivedLhsInt = newInterval + currRhsInt;
1515 Interval derivedRhsInt = currLhsInt - newInterval;
1517 Interval finalLhsInt = currLhsInt.intersect(derivedLhsInt);
1518 Interval finalRhsInt = currRhsInt.intersect(derivedRhsInt);
1520 applyInterval(subOp, lhs, finalLhsInt);
1521 applyInterval(subOp, rhs, finalRhsInt);
1524 auto selectCase = [&](arith::SelectOp selectOp) {
1525 Value cond = selectOp.getCondition();
1526 Value trueVal = selectOp.getTrueValue();
1527 Value falseVal = selectOp.getFalseValue();
1533 const Interval &condInterval = condExpr.getInterval();
1534 if (condInterval.isDegenerate() && condInterval.rhs() == f.one()) {
1535 applyInterval(selectOp, trueVal, newInterval);
1538 if (condInterval.isDegenerate() && condInterval.rhs() == f.zero()) {
1539 applyInterval(selectOp, falseVal, newInterval);
1543 Interval trueOverlap = trueExpr.getInterval().intersect(newInterval);
1544 Interval falseOverlap = falseExpr.getInterval().intersect(newInterval);
1545 bool truePossible = trueOverlap.isNotEmpty();
1546 bool falsePossible = falseOverlap.isNotEmpty();
1548 if (truePossible && !falsePossible) {
1550 applyInterval(selectOp, trueVal, newInterval);
1553 if (!truePossible && falsePossible) {
1555 applyInterval(selectOp, falseVal, newInterval);
1558 if (!truePossible && !falsePossible) {
1563 auto readmCase = [&](MemberReadOp) {
1564 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1566 if (sourceRefVal.isSingleValue()) {
1567 const SourceRef &ref = sourceRefVal.getSingleValue();
1568 readResults[ref].insert(valLattice);
1571 for (Lattice *l : readResults[ref]) {
1572 if (l != valLattice) {
1573 propagateIfChanged(l, l->setValue(newLatticeVal));
1579 auto readArrCase = [&](ReadArrayOp) {
1580 auto arrayRef = getArrayAccessRef(valUser, llvm::cast<ReadArrayOp>(definingOp));
1581 if (succeeded(arrayRef)) {
1582 readResults[*arrayRef].insert(valLattice);
1584 for (Lattice *l : readResults[*arrayRef]) {
1585 if (l != valLattice) {
1586 propagateIfChanged(l, l->setValue(newLatticeVal));
1591 SourceRefLatticeValue sourceRefVal = getSourceRefState(val);
1593 if (sourceRefVal.isSingleValue()) {
1594 const SourceRef &ref = sourceRefVal.getSingleValue();
1595 readResults[ref].insert(valLattice);
1598 for (Lattice *l : readResults[ref]) {
1599 if (l != valLattice) {
1600 propagateIfChanged(l, l->setValue(newLatticeVal));
1607 auto castCase = [&](Operation *op) { applyInterval(op, op->getOperand(0), newInterval); };
1613 TypeSwitch<Operation *>(definingOp)
1614 .Case<CmpOp>([&](
auto op) { cmpCase(op); })
1615 .Case<AddFeltOp>([&](
auto op) {
return addCase(op); })
1616 .Case<SubFeltOp>([&](
auto op) {
return subCase(op); })
1617 .Case<MulFeltOp>([&](
auto op) { mulCase(op); })
1618 .Case<arith::SelectOp>([&](
auto op) { selectCase(op); })
1619 .Case<MemberReadOp>([&](
auto op){ readmCase(op); })
1620 .Case<ReadArrayOp>([&](
auto op){ readArrCase(op); })
1621 .Case<IntToFeltOp, FeltToIndexOp>([&](
auto op) { castCase(op); })
1622 .Default([&](Operation *) { });
1626 propagateIfChanged(valLattice, changed);
1629FailureOr<std::pair<DenseSet<Value>,
Interval>>
1630IntervalDataFlowAnalysis::getGeneralizedDecompInterval(
1631 Operation * , Value lhs, Value rhs
1633 auto isZeroConst = [
this](Value v) {
1634 Operation *op = v.getDefiningOp();
1638 if (!isConstOp(op)) {
1641 return getConst(op) == field.get().zero();
1643 bool lhsIsZero = isZeroConst(lhs), rhsIsZero = isZeroConst(rhs);
1644 Value exprTree =
nullptr;
1645 if (lhsIsZero && !rhsIsZero) {
1647 }
else if (!lhsIsZero && rhsIsZero) {
1654 std::optional<SourceRef> signalRef = std::nullopt;
1655 DenseSet<Value> signalVals;
1656 SmallVector<DynamicAPInt> consts;
1657 SmallVector<Value> frontier {exprTree};
1658 while (!frontier.empty()) {
1659 Value v = frontier.back();
1660 frontier.pop_back();
1661 Operation *op = v.getDefiningOp();
1665 auto handleRefValue = [
this, &signalRef, &signalVal, &signalVals]() {
1666 SourceRefLatticeValue refSet = getSourceRefState(signalVal);
1667 if (!refSet.isScalar() || !refSet.isSingleValue()) {
1670 SourceRef r = refSet.getSingleValue();
1671 if (signalRef.has_value() && signalRef.value() != r) {
1673 }
else if (!signalRef.has_value()) {
1676 signalVals.insert(signalVal);
1681 if (op && matchPattern(op, subPattern)) {
1682 if (failed(handleRefValue())) {
1685 auto constInt = APSInt(c.getValue());
1686 consts.push_back(field.get().reduce(constInt));
1688 }
else if (
m_RefValue(&signalVal).match(v)) {
1689 if (failed(handleRefValue())) {
1692 consts.push_back(field.get().zero());
1698 if (op && matchPattern(op, mulPattern)) {
1699 frontier.push_back(a);
1700 frontier.push_back(b);
1709 std::sort(consts.begin(), consts.end());
1710 Interval iv = UnreducedInterval(consts.front(), consts.back()).reduce(field.get());
1711 return std::make_pair(std::move(signalVals), iv);
1719 SymbolTableCollection tables;
1721 auto computeIntervalsImpl =
1722 [&solver, &am, &ctx, &tables,
this](
1723 FuncDefOp fn, llvm::MapVector<SourceRef, Interval> &memberRanges,
1724 llvm::MapVector<SourceRef, UnreducedInterval> &memberUnreducedRanges,
1725 llvm::SetVector<ExpressionValue> &
1727 auto setUnreducedRange =
1729 memberUnreducedRanges.erase(ref);
1730 memberUnreducedRanges.insert({ref, interval});
1743 searchSet.insert(ref);
1747 auto mergeInterval = [&memberRanges, &memberUnreducedRanges](
1749 std::optional<UnreducedInterval> unreducedInterval = std::nullopt
1751 auto *existing = memberRanges.find(ref);
1752 if (existing != memberRanges.end()) {
1754 bool intervalChanged = mergedInterval != existing->second;
1755 existing->second = mergedInterval;
1757 if (unreducedInterval.has_value()) {
1758 auto *existingUnreduced = memberUnreducedRanges.find(ref);
1759 if (existingUnreduced != memberUnreducedRanges.end()) {
1760 existingUnreduced->second = existingUnreduced->second.intersect(*unreducedInterval);
1762 memberUnreducedRanges.insert({ref, *unreducedInterval});
1764 }
else if (intervalChanged) {
1765 memberUnreducedRanges.erase(ref);
1770 memberRanges[ref] = interval;
1771 if (unreducedInterval.has_value()) {
1772 memberUnreducedRanges.insert({ref, *unreducedInterval});
1777 for (BlockArgument arg : fn.getArguments()) {
1779 if (searchSet.erase(ref)) {
1783 if (!expr.getExpr()) {
1786 expr = expr.withUnreducedInterval(expr.getInterval().firstUnreduced());
1789 memberRanges[ref] = expr.getInterval();
1790 if (expr.hasUnreducedInterval()) {
1791 setUnreducedRange(ref, expr.getUnreducedInterval());
1793 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1801 if (!lattices.empty() && searchSet.erase(ref)) {
1803 std::optional<UnreducedInterval> joinedUnreduced = std::nullopt;
1804 bool sawFirst =
false;
1816 memberRanges[ref] = joinedInterval;
1817 if (joinedUnreduced.has_value()) {
1818 setUnreducedRange(ref, *joinedUnreduced);
1820 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1825 if (searchSet.erase(ref)) {
1826 memberRanges[ref] = val.getInterval();
1827 if (val.hasUnreducedInterval()) {
1828 setUnreducedRange(ref, val.getUnreducedInterval());
1830 assert(memberRanges[ref].getField() == ctx.
getField() &&
"bad interval defaults");
1837 if (fn.isStructConstrain()) {
1838 auto mergeChildConstrainIntervals = [&](
CallOp fnCall) {
1853 auto calledStruct = calledFn->getParentOfType<
StructDefOp>();
1854 if (calledStruct == structDef) {
1859 if (childAnalysis.inProgress(ctx)) {
1862 if (!childAnalysis.constructed(ctx)) {
1864 succeeded(childAnalysis.runAnalysis(solver, am, ctx)),
1865 "could not construct interval analysis for child struct"
1872 llvm::MapVector<SourceRef, Interval> callOperandIntervals;
1873 for (
unsigned i = 0; i < calledFn.getNumArguments(); i++) {
1874 SourceRef prefix(calledFn.getArgument(i));
1875 Value operand = fnCall.getOperand(i);
1876 std::optional<SourceRefLatticeValue> identityVal =
1877 getIdentitySourceRefState(solver, operand);
1878 if (identityVal.has_value()) {
1879 identityTranslations.push_back({prefix, *identityVal});
1882 if (!llvm::isa<ArrayType, StructType>(operand.getType())) {
1885 if (lattice !=
nullptr) {
1887 callOperandIntervals[prefix] = expr.
getInterval();
1892 const StructIntervals &childIntervals = childAnalysis.getResult(ctx);
1895 for (
const auto &[childRef, childInterval] : constrainIntervals) {
1896 auto translatedRefs = translateRef(childRef, identityTranslations);
1897 if (failed(translatedRefs)) {
1901 std::optional<UnreducedInterval> childUnreduced = std::nullopt;
1902 if (
auto *childUnreducedIt = constrainUnreducedIntervals.find(childRef);
1903 childUnreducedIt != constrainUnreducedIntervals.end()) {
1904 childUnreduced = childUnreducedIt->second;
1908 for (
const SourceRef &translatedRef : *translatedRefs) {
1909 uniqueTranslatedRefs.insert(translatedRef);
1911 if (uniqueTranslatedRefs.size() != 1) {
1915 const SourceRef &translatedRef = *uniqueTranslatedRefs.begin();
1916 if (functionRefs.contains(translatedRef)) {
1917 mergeInterval(translatedRef, childInterval, childUnreduced);
1918 searchSet.erase(translatedRef);
1924 llvm::EquivalenceClasses<SourceRef> directEqRefs =
1925 collectDirectEqualityRefs(solver, calledFn);
1926 for (
auto leaderIt = directEqRefs.begin(); leaderIt != directEqRefs.end(); ++leaderIt) {
1927 if (!leaderIt->isLeader()) {
1931 llvm::MapVector<SourceRef, Interval> translatedEqRefs;
1933 bool hasInterval =
false;
1934 bool ambiguousTranslation =
false;
1936 for (
auto memberIt = directEqRefs.member_begin(leaderIt);
1937 memberIt != directEqRefs.member_end(); ++memberIt) {
1939 if (
const auto *childIntervalIt = constrainIntervals.find(*memberIt);
1940 childIntervalIt != constrainIntervals.end()) {
1941 memberInterval = memberInterval.
intersect(childIntervalIt->second);
1943 if (
auto *callOperandIt = callOperandIntervals.find(*memberIt);
1944 callOperandIt != callOperandIntervals.end()) {
1945 memberInterval = memberInterval.
intersect(callOperandIt->second);
1946 contextualInterval = contextualInterval.
intersect(memberInterval);
1950 auto translatedRefs = translateRef(*memberIt, identityTranslations);
1951 if (failed(translatedRefs)) {
1956 for (
const SourceRef &translatedRef : *translatedRefs) {
1957 uniqueTranslatedRefs.insert(translatedRef);
1959 if (uniqueTranslatedRefs.size() != 1) {
1960 ambiguousTranslation =
true;
1964 const SourceRef &translatedRef = *uniqueTranslatedRefs.begin();
1965 if (!functionRefs.contains(translatedRef)) {
1969 if (
auto *parentIntervalIt = memberRanges.find(translatedRef);
1970 parentIntervalIt != memberRanges.end()) {
1971 memberInterval = memberInterval.
intersect(parentIntervalIt->second);
1974 translatedEqRefs[translatedRef] = memberInterval;
1975 contextualInterval = contextualInterval.
intersect(memberInterval);
1979 if (ambiguousTranslation || !hasInterval || translatedEqRefs.empty()) {
1983 for (
const auto &[translatedRef, _] : translatedEqRefs) {
1984 mergeInterval(translatedRef, contextualInterval);
1985 searchSet.erase(translatedRef);
1990 fn.walk(mergeChildConstrainIntervals);
1994 for (
const auto &ref : searchSet) {
1997 setUnreducedRange(ref, memberRanges[ref].firstUnreduced());
2006 llvm::SmallVector<std::pair<SourceRef, Interval>> sortedRanges;
2007 sortedRanges.reserve(memberRanges.size());
2008 for (
const auto &[ref, interval] : memberRanges) {
2009 sortedRanges.emplace_back(ref, interval);
2011 llvm::sort(sortedRanges, [](
const auto &a,
const auto &b) {
return a.first < b.first; });
2012 llvm::SmallVector<std::pair<SourceRef, UnreducedInterval>> sortedUnreducedRanges;
2013 sortedUnreducedRanges.reserve(memberUnreducedRanges.size());
2014 for (
const auto &[ref, interval] : memberUnreducedRanges) {
2015 sortedUnreducedRanges.emplace_back(ref, interval);
2017 llvm::sort(sortedUnreducedRanges, [](
const auto &a,
const auto &b) {
2018 return a.first < b.first;
2020 memberRanges.clear();
2021 memberUnreducedRanges.clear();
2022 for (
auto &[ref, interval] : sortedRanges) {
2023 memberRanges[ref] = interval;
2025 for (
auto &[ref, interval] : sortedUnreducedRanges) {
2026 memberUnreducedRanges.insert({ref, interval});
2030 if (
auto computeFn = structDef.getComputeFuncOp()) {
2031 computeIntervalsImpl(
2032 computeFn, computeMemberRanges, computeMemberUnreducedRanges, computeSolverConstraints
2035 if (
auto constrainFn = structDef.getConstrainFuncOp()) {
2036 computeIntervalsImpl(
2037 constrainFn, constrainMemberRanges, constrainMemberUnreducedRanges,
2038 constrainSolverConstraints
2046 mlir::raw_ostream &os,
bool withConstraints,
bool printCompute,
bool printUnreduced
2048 auto writeIntervals =
2049 [&os, &withConstraints, &printUnreduced](
2050 const char *fnName,
const llvm::MapVector<SourceRef, Interval> &memberRanges,
2051 const llvm::MapVector<SourceRef, UnreducedInterval> &memberUnreducedRanges,
2052 const llvm::SetVector<ExpressionValue> &solverConstraints,
bool printName
2057 os.indent(indent) << fnName <<
" {";
2061 if (memberRanges.empty()) {
2066 for (
const auto &[ref, interval] : memberRanges) {
2068 os.indent(indent) << ref <<
" in " << interval;
2069 if (printUnreduced) {
2070 const auto *unreducedIt = memberUnreducedRanges.find(ref);
2071 if (unreducedIt != memberUnreducedRanges.end()) {
2072 os <<
" ( " << unreducedIt->second <<
" )";
2077 if (withConstraints) {
2079 os.indent(indent) <<
"Solver Constraints { ";
2080 if (solverConstraints.empty()) {
2083 for (
const auto &e : solverConstraints) {
2085 os.indent(indent + 4);
2086 e.getExpr()->print(os);
2089 os.indent(indent) <<
'}';
2095 os.indent(indent - 4) <<
'}';
2099 os <<
"StructIntervals { ";
2100 if (constrainMemberRanges.empty() && (!printCompute || computeMemberRanges.empty())) {
2108 computeSolverConstraints, printCompute
2113 constrainSolverConstraints, printCompute
Tracks a solver expression and an interval range for that expression.
ExpressionValue withUnreducedInterval(const UnreducedInterval &newUnreducedInterval) const
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
const Interval & getInterval() const
const std::optional< UnreducedInterval > & getOptionalUnreducedInterval() const
ExpressionValue withOptionalUnreducedInterval(std::optional< UnreducedInterval > newUnreducedInterval) const
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
bool hasUnreducedInterval() const
const Field & getField() const
const UnreducedInterval & getUnreducedInterval() const
Information about the prime finite field used for the interval analysis.
llvm::DynamicAPInt zero() const
Returns 0 at the bitwidth of the field.
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
llvm::DynamicAPInt one() const
Returns 1 at the bitwidth of the field.
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
llvm::SMTExprRef createSymbol(const llvm::SMTSolverRef &solver, const char *name) const
Create a SMT solver symbol with the current field's bitwidth.
llvm::DynamicAPInt maxVal() const
Returns p - 1, which is the max value possible in a prime field described by p.
const LatticeValue & getValue() const
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
Intervals over a finite field.
static Interval True(const Field &f)
llvm::DynamicAPInt rhs() const
Interval intersect(const Interval &rhs) const
Intersect.
UnreducedInterval toUnreduced() const
Convert to an UnreducedInterval.
static Interval Boolean(const Field &f)
UnreducedInterval firstUnreduced() const
Get the first side of the interval for TypeF intervals, otherwise just get the full interval as an Un...
static Interval Entire(const Field &f)
bool isDegenerate() const
static Interval False(const Field &f)
llvm::DynamicAPInt lhs() const
Interval join(const Interval &rhs) const
Union.
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
const SourceRef & getSingleValue() const
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
A reference to a "source", which is the base value from which other SSA values are derived.
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
std::vector< SourceRefIndex > Path
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::Type getType() const
const llvm::MapVector< SourceRef, Interval > & getConstrainIntervals() const
const llvm::MapVector< SourceRef, UnreducedInterval > & getConstrainUnreducedIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false, bool printUnreduced=false) const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
UnreducedInterval computeLTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is guaranteed to be less than the rhs's max value.
UnreducedInterval computeGEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than or equal to the rhs's lower bound.
UnreducedInterval computeGTPart(const UnreducedInterval &rhs) const
Return the part of the interval that is greater than the rhs's lower bound.
Interval reduce(const Field &field) const
Reduce the interval to an interval in the given field.
UnreducedInterval computeLEPart(const UnreducedInterval &rhs) const
Return the part of the interval that is less than or equal to the rhs's upper bound.
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
::mlir::Type getElementType() const
::llzk::boolean::FeltCmpPredicate getPredicate()
bool isSingleValue() const
const ScalarTy & getScalarValue() const
IntervalAnalysisLattice * getLatticeElement(mlir::Value value) override
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
bool isOperationLive(DataFlowSolver &solver, Operation *op)
ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue sintMod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
RefValueCapture m_RefValue()
ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > signedIntDiv(const Interval &lhs, const Interval &rhs)
Computes signed integer division with possibly non-Degenerate divisors.
std::vector< std::pair< SourceRef, SourceRefLatticeValue > > SourceRefRemappings
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val)
constexpr char FUNC_NAME_CONSTRAIN[]
Interval signedMod(const Interval &lhs, const Interval &rhs)
Computes signed integer remainder with possibly non-Degenerate divisors.
void ensure(bool condition, const llvm::Twine &errMsg)
ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
DynamicAPInt toDynamicAPInt(StringRef str)
llvm::SMTExprRef createFieldInverseExpr(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &val, StringRef suffix="")
ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< Interval > unsignedIntDiv(const Interval &lhs, const Interval &rhs)
Computes unsigned integer division with possibly non-Degenerate divisors.
ExpressionValue div(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
ExpressionValue boolToFelt(const llvm::SMTSolverRef &solver, const ExpressionValue &expr, unsigned bitwidth)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallableSilently(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Resolve a callable without emitting a diagnostic for missing top-level symbols.
ConstantCapture m_Constant()
std::string buildStringViaInsertionOp(Args &&...args)
Generate a string by using the insertion operator (<<) to append all args to a stream backed by the r...
ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
APSInt toAPSInt(const DynamicAPInt &i)
ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
auto m_CommutativeOp(LhsMatcher lhs, RhsMatcher rhs)
ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
FailureOr< Interval > feltDiv(const Interval &lhs, const Interval &rhs)
Computes finite-field division by multiplying the dividend by the multiplicative inverse of the divis...
ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue selectValue(const llvm::SMTSolverRef &solver, const ExpressionValue &cond, const ExpressionValue &trueVal, const ExpressionValue &falseVal)
Parameters and shared objects to pass to child analyses.
const Field & getField() const
IntervalDataFlowAnalysis * intervalDFA
bool doTrackUnreducedIntervals() const