28#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
29#include <mlir/IR/BuiltinOps.h>
30#include <mlir/Pass/AnalysisManager.h>
31#include <mlir/Support/LLVM.h>
33#include <llvm/ADT/DynamicAPInt.h>
34#include <llvm/ADT/MapVector.h>
35#include <llvm/ADT/ScopeExit.h>
36#include <llvm/Support/SMTAPI.h>
41#include <unordered_set>
55 : i(
Interval::Entire(f)), expr(nullptr), unreduced(std::nullopt) {}
58 : i(
Interval::Entire(f)), expr(exprRef), unreduced(std::nullopt) {}
61 : i(
Interval::Degenerate(f, singleVal)), expr(exprRef), unreduced(std::nullopt) {}
64 llvm::SMTExprRef exprRef,
const Interval &interval,
65 std::optional<UnreducedInterval> unreducedInterval = std::nullopt
67 : i(interval), expr(exprRef), unreduced(std::move(unreducedInterval)) {}
69 llvm::SMTExprRef
getExpr()
const {
return expr; }
78 ensure(unreduced.has_value(),
"unreduced interval not set");
111 unreduced = std::nullopt;
118 return solver->getBoolSort() == solver->getSort(expr);
156 const llvm::SMTSolverRef &solver, mlir::Operation *op,
const ExpressionValue &lhs,
161 const llvm::SMTSolverRef &solver, mlir::Operation *op,
const ExpressionValue &lhs,
205 const llvm::SMTSolverRef &solver, mlir::Operation *op,
const ExpressionValue &val
210 void print(mlir::raw_ostream &os)
const;
220 std::hash<bool> {}(e.unreduced.has_value()) ^
227 llvm::SMTExprRef expr;
228 std::optional<UnreducedInterval> unreduced;
256 using ValueMap = mlir::DenseMap<mlir::Value, LatticeValue>;
259 using MemberMap = mlir::DenseMap<mlir::Value, mlir::DenseMap<mlir::StringAttr, LatticeValue>>;
265 using AbstractSparseLattice::AbstractSparseLattice;
267 mlir::ChangeResult
join(
const AbstractSparseLattice &other)
override;
269 mlir::ChangeResult
meet(
const AbstractSparseLattice &other)
override;
271 void print(mlir::raw_ostream &os)
const override;
304 using SymbolMap = mlir::DenseMap<SourceRef, llvm::SMTExprRef>;
308 mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef
smt,
const Field &f,
309 bool propInputConstraints,
bool shouldTrackUnreducedIntervals
312 smtSolver(std::move(
smt)), field(f), propagateInputConstraints(propInputConstraints),
313 trackUnreducedIntervals(shouldTrackUnreducedIntervals) {}
316 mlir::Operation *op, mlir::ArrayRef<const Lattice *> operands,
317 mlir::ArrayRef<Lattice *> results
326 const llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> &
getReadResults()
const {
330 const llvm::DenseMap<SourceRef, ExpressionValue> &
getWriteResults()
const {
return writeResults; }
333 mlir::DataFlowSolver &_dataflowSolver;
334 llvm::SMTSolverRef smtSolver;
335 SymbolMap refSymbols;
336 std::reference_wrapper<const Field> field;
337 bool propagateInputConstraints;
338 bool trackUnreducedIntervals;
339 mlir::SymbolTableCollection
tables;
342 llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> readResults;
344 llvm::DenseMap<SourceRef, ExpressionValue> writeResults;
346 void setToEntryState(Lattice *lattice)
override {
351 static bool isBooleanType(mlir::Type ty) {
352 if (
auto intTy = llvm::dyn_cast<mlir::IntegerType>(ty)) {
353 return intTy.getWidth() == 1;
358 Interval getDefaultIntervalForType(mlir::Type ty)
const {
359 return isBooleanType(ty) ?
Interval::Boolean(field.get()) : Interval::Entire(field.get());
362 std::optional<UnreducedInterval> getDefaultUnreducedIntervalForType(mlir::Type ty)
const;
364 std::optional<UnreducedInterval> getRefUnreducedInterval(
const SourceRef &ref);
366 llvm::SMTExprRef createSymbol(mlir::Type ty,
const char *name)
const;
368 llvm::SMTExprRef createSymbol(
const SourceRef &r)
const;
370 llvm::SMTExprRef createSymbol(mlir::Value val)
const;
372 ExpressionValue createUnknownValue(mlir::Value val)
const {
373 return ExpressionValue(
374 createSymbol(val), getDefaultIntervalForType(val.getType()),
375 getDefaultUnreducedIntervalForType(val.getType())
379 inline bool isConstOp(mlir::Operation *op)
const {
381 felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp>(op);
384 inline bool isBoolConstOp(mlir::Operation *op)
const {
385 if (
auto constIntOp = llvm::dyn_cast<mlir::arith::ConstantIntOp>(op)) {
386 auto valAttr = dyn_cast<mlir::IntegerAttr>(constIntOp.getValue());
387 ensure(valAttr !=
nullptr,
"arith::ConstantIntOp must have an IntegerAttr as its value");
388 return valAttr.getValue().getBitWidth() == 1;
393 llvm::DynamicAPInt getConst(mlir::Operation *op)
const;
395 inline llvm::SMTExprRef createConstBitvectorExpr(
const llvm::DynamicAPInt &v)
const {
396 return createConstBitvectorExpr(
toAPSInt(v));
399 inline llvm::SMTExprRef createConstBitvectorExpr(
const llvm::APSInt &v)
const {
400 return smtSolver->mkBitvector(v, field.get().bitWidth());
403 llvm::SMTExprRef createConstBoolExpr(
bool v)
const {
return smtSolver->mkBoolean(v); }
405 bool isArithmeticOp(mlir::Operation *op)
const {
407 felt::AddFeltOp, felt::SubFeltOp, felt::MulFeltOp, felt::DivFeltOp, felt::UnsignedModFeltOp,
408 felt::SignedModFeltOp, felt::SignedIntDivFeltOp, felt::UnsignedIntDivFeltOp,
409 mlir::arith::XOrIOp, felt::NegFeltOp, felt::InvFeltOp, felt::AndFeltOp, felt::OrFeltOp,
410 felt::XorFeltOp, felt::NotFeltOp, felt::ShlFeltOp, felt::ShrFeltOp, boolean::CmpOp,
411 boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(op);
415 performBinaryArithmetic(mlir::Operation *op,
const LatticeValue &a,
const LatticeValue &b);
417 ExpressionValue performUnaryArithmetic(mlir::Operation *op,
const LatticeValue &a);
425 void applyInterval(mlir::Operation *originalOp, mlir::Value val, Interval newInterval);
428 mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>>
429 getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs);
431 bool isReadOp(mlir::Operation *op)
const {
432 return llvm::isa<component::MemberReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
435 bool isDefinitionOp(mlir::Operation *op)
const {
437 component::StructDefOp, function::FuncDefOp, component::MemberDefOp, global::GlobalDefOp,
441 bool isReturnOp(mlir::Operation *op)
const {
return llvm::isa<function::ReturnOp>(op); }
446 std::vector<SourceRefIndex>
447 getArrayAccessIndices(mlir::Operation *baseOp, array::ArrayAccessOpInterface arrayAccessOp);
451 mlir::FailureOr<SourceRef>
452 getArrayAccessRef(mlir::Operation *baseOp, array::ArrayAccessOpInterface arrayAccessOp);
456 Interval getRefInterval(
const SourceRef &ref);
461 ExpressionValue getRefValue(
const SourceRef &ref, mlir::Value val);
469 const SourceRef &writtenRef,
const ExpressionValue &writeVal,
bool mayBeSkipped =
false
473 SourceRefLatticeValue getSourceRefState(mlir::Value val);
482 std::optional<std::reference_wrapper<const Field>>
field;
489 ensure(
field.has_value(),
"field not set within context");
501template <>
struct std::hash<
llzk::IntervalAnalysisContext> {
503 return llvm::hash_combine(
504 std::hash<const llzk::IntervalDataFlowAnalysis *> {}(c.
intervalDFA),
505 std::hash<const llvm::SMTSolver *> {}(c.
smtSolver.get()),
506 std::hash<const llzk::Field *> {}(&c.
getField()),
517class StructIntervals {
528 static mlir::FailureOr<StructIntervals>
compute(
532 StructIntervals si(mod, s);
534 return mlir::failure();
544 mlir::raw_ostream &os,
bool withConstraints =
false,
bool printCompute =
false,
545 bool printUnreduced =
false
549 return constrainMemberRanges;
553 return constrainMemberUnreducedRanges;
557 return constrainSolverConstraints;
561 return computeMemberRanges;
565 return computeMemberUnreducedRanges;
569 return computeSolverConstraints;
572 friend mlir::raw_ostream &
operator<<(mlir::raw_ostream &os,
const StructIntervals &si) {
580 llvm::SMTSolverRef smtSolver;
582 llvm::MapVector<SourceRef, Interval> constrainMemberRanges, computeMemberRanges;
583 llvm::MapVector<SourceRef, UnreducedInterval> constrainMemberUnreducedRanges,
584 computeMemberUnreducedRanges;
586 llvm::SetVector<ExpressionValue> constrainSolverConstraints, computeSolverConstraints;
601 return inProgressContexts.contains(ctx);
608 return mlir::failure();
610 inProgressContexts.insert(ctx);
611 auto cleanup = llvm::make_scope_exit([
this, &ctx] { inProgressContexts.erase(ctx); });
614 if (mlir::failed(computeRes)) {
615 return mlir::failure();
618 return mlir::success();
622 std::unordered_set<IntervalAnalysisContext> inProgressContexts;
628 :
public ModuleAnalysis<StructIntervals, IntervalAnalysisContext, StructIntervalAnalysis> {
634 ctx.smtSolver = llvm::CreateZ3Solver();
644 ensure(ctx.hasField(),
"field not set, could not generate analysis context");
646 auto smtSolverRef = ctx.smtSolver;
647 bool prop = ctx.propagateInputConstraints;
648 bool track = ctx.trackUnreducedIntervals;
651 std::move(smtSolverRef), ctx.getField(),
658 ensure(ctx.field.has_value(),
"field not set, could not generate analysis context");
670template <>
struct DenseMapInfo<
llzk::ExpressionValue> {
673 static const auto *emptyPtr =
reinterpret_cast<SMTExprRef
>(1);
677 static const auto *tombstonePtr =
reinterpret_cast<SMTExprRef
>(2);
Convenience classes for a frequent pattern of dataflow analysis used in LLZK, where an analysis is ru...
This file implements sparse data-flow analysis using the data-flow analysis framework.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withUnreducedInterval(const UnreducedInterval &newUnreducedInterval) const
friend ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
friend ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &val)
friend ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
ExpressionValue(llvm::SMTExprRef exprRef, const Interval &interval, std::optional< UnreducedInterval > unreducedInterval=std::nullopt)
friend ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
const Interval & getInterval() const
friend ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue div(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
const std::optional< UnreducedInterval > & getOptionalUnreducedInterval() const
ExpressionValue(const Field &f)
ExpressionValue withOptionalUnreducedInterval(std::optional< UnreducedInterval > newUnreducedInterval) const
friend ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
friend ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
friend ExpressionValue cmp(const llvm::SMTSolverRef &solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue join(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the union of the lhs and rhs intervals, and create a solver expression that constrains both s...
bool hasUnreducedInterval() const
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal)
friend ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ExpressionValue &e)
const Field & getField() const
const UnreducedInterval & getUnreducedInterval() const
ExpressionValue & join(const ExpressionValue &)
Fold two expressions together when overapproximating array elements.
friend ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue dropUnreducedInterval() const
Information about the prime finite field used for the interval analysis.
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
IntervalAnalysisLatticeValue & operator=(const IntervalAnalysisLatticeValue &)=default
IntervalAnalysisLatticeValue(IntervalAnalysisLatticeValue &&)=default
IntervalAnalysisLatticeValue(const IntervalAnalysisLatticeValue &)=default
IntervalAnalysisLatticeValue(mlir::ArrayRef< int64_t > shape)
IntervalAnalysisLatticeValue & operator=(IntervalAnalysisLatticeValue &&)=default
IntervalAnalysisLatticeValue(ExpressionValue e)
IntervalAnalysisLatticeValue()
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l)
const LatticeValue & getValue() const
llvm::SetVector< ExpressionValue > ConstraintSet
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::DenseMap< mlir::Value, LatticeValue > ValueMap
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
const ConstraintSet & getConstraints() const
mlir::DenseMap< mlir::Value, mlir::DenseMap< mlir::StringAttr, LatticeValue > > MemberMap
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i)
mlir::DenseMap< llvm::SMTExprRef, Interval > ExpressionIntervals
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
IntervalDataFlowAnalysis(mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f, bool propInputConstraints, bool shouldTrackUnreducedIntervals)
Intervals over a finite field.
static Interval Boolean(const Field &f)
mlir::DataFlowSolver solver
ModuleAnalysis(mlir::Operation *op, const mlir::DataFlowConfig &config=mlir::DataFlowConfig())
ModuleIntervalAnalysis(mlir::Operation *op)
~ModuleIntervalAnalysis() override=default
void setPropagateInputConstraints(bool prop)
void initializeSolver() override
Initialize the shared dataflow solver with any common analyses required by the contained struct analy...
void setTrackUnreducedIntervals(bool track)
const IntervalAnalysisContext & getContext() const override
Return the current Context object.
void setField(const Field &f)
The dataflow analysis that computes the set of references that LLZK operations use and produce.
A reference to a "source", which is the base value from which other SSA values are derived.
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
component::StructDefOp getStruct() const
void setResult(const IntervalAnalysisContext &ctx, StructIntervals &&r)
mlir::ModuleOp getModule() const
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx) override
Perform the analysis and construct the Result output.
~StructIntervalAnalysis() override=default
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
bool inProgress(const IntervalAnalysisContext &ctx) const
const llvm::MapVector< SourceRef, Interval > & getConstrainIntervals() const
static mlir::FailureOr< StructIntervals > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
Compute the struct intervals.
const llvm::SetVector< ExpressionValue > getConstrainSolverConstraints() const
const llvm::MapVector< SourceRef, UnreducedInterval > & getConstrainUnreducedIntervals() const
const llvm::SetVector< ExpressionValue > getComputeSolverConstraints() const
const llvm::MapVector< SourceRef, Interval > & getComputeIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false, bool printUnreduced=false) const
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const StructIntervals &si)
const llvm::MapVector< SourceRef, UnreducedInterval > & getComputeUnreducedIntervals() const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
mlir::SymbolTableCollection tables
LLZK: Added for use of symbol helper caching.
A sparse forward data-flow analysis for propagating SSA value lattices across the IR by implementing ...
SparseForwardDataFlowAnalysis(mlir::DataFlowSolver &s)
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
APSInt toAPSInt(const DynamicAPInt &i)
static unsigned getHashValue(const llzk::ExpressionValue &e)
static SMTExprRef getTombstoneExpr()
static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs)
static llzk::ExpressionValue getTombstoneKey()
static llzk::ExpressionValue getEmptyKey()
static SMTExprRef getEmptyExpr()
unsigned operator()(const ExpressionValue &e) const
Parameters and shared objects to pass to child analyses.
bool trackUnreducedIntervals
bool doInputConstraintPropagation() const
const Field & getField() const
friend bool operator==(const IntervalAnalysisContext &a, const IntervalAnalysisContext &b)=default
std::optional< std::reference_wrapper< const Field > > field
IntervalDataFlowAnalysis * intervalDFA
bool doTrackUnreducedIntervals() const
bool propagateInputConstraints
llvm::SMTSolverRef smtSolver
llvm::SMTExprRef getSymbol(const SourceRef &r) const
size_t operator()(const llzk::IntervalAnalysisContext &c) const