LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
IntervalAnalysis.h
Go to the documentation of this file.
1//===-- IntervalAnalysis.h --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
25#include "llzk/Util/Compare.h"
26#include "llzk/Util/Field.h"
27
28#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
29#include <mlir/IR/BuiltinOps.h>
30#include <mlir/Pass/AnalysisManager.h>
31#include <mlir/Support/LLVM.h>
32
33#include <llvm/ADT/DynamicAPInt.h>
34#include <llvm/ADT/MapVector.h>
35#include <llvm/ADT/ScopeExit.h>
36#include <llvm/Support/SMTAPI.h>
37
38#include <array>
39#include <mutex>
40#include <optional>
41#include <unordered_set>
42
43namespace llzk {
44
45/* ExpressionValue */
46
50public:
51 /* Must be default initializable to be a ScalarLatticeValue. */
52 ExpressionValue() : i(), expr(nullptr), unreduced(std::nullopt) {}
53
54 explicit ExpressionValue(const Field &f)
55 : i(Interval::Entire(f)), expr(nullptr), unreduced(std::nullopt) {}
56
57 ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
58 : i(Interval::Entire(f)), expr(exprRef), unreduced(std::nullopt) {}
59
60 ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal)
61 : i(Interval::Degenerate(f, singleVal)), expr(exprRef), unreduced(std::nullopt) {}
62
64 llvm::SMTExprRef exprRef, const Interval &interval,
65 std::optional<UnreducedInterval> unreducedInterval = std::nullopt
66 )
67 : i(interval), expr(exprRef), unreduced(std::move(unreducedInterval)) {}
68
69 llvm::SMTExprRef getExpr() const { return expr; }
70
71 const Interval &getInterval() const { return i; }
72
73 bool hasUnreducedInterval() const { return unreduced.has_value(); }
74
75 const std::optional<UnreducedInterval> &getOptionalUnreducedInterval() const { return unreduced; }
76
78 ensure(unreduced.has_value(), "unreduced interval not set");
79 return *unreduced;
80 }
81
82 const Field &getField() const { return i.getField(); }
83
87 ExpressionValue withInterval(const Interval &newInterval) const {
88 return ExpressionValue(expr, newInterval, unreduced);
89 }
90
92 ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const {
93 return ExpressionValue(newExpr, i, unreduced);
94 }
95
96 ExpressionValue withUnreducedInterval(const UnreducedInterval &newUnreducedInterval) const {
97 return ExpressionValue(expr, i, newUnreducedInterval);
98 }
99
101 withOptionalUnreducedInterval(std::optional<UnreducedInterval> newUnreducedInterval) const {
102 return ExpressionValue(expr, i, std::move(newUnreducedInterval));
103 }
104
105 ExpressionValue dropUnreducedInterval() const { return ExpressionValue(expr, i, std::nullopt); }
106
107 /* Required to be a ScalarLatticeValue. */
111 unreduced = std::nullopt;
112 return *this;
113 }
114
115 bool operator==(const ExpressionValue &rhs) const;
116
117 bool isBoolSort(const llvm::SMTSolverRef &solver) const {
118 return solver->getBoolSort() == solver->getSort(expr);
119 }
120
128 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
129 );
130
137 friend ExpressionValue
138 join(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
139
140 // arithmetic ops
141
142 friend ExpressionValue
143 add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
144
145 friend ExpressionValue
146 sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
147
148 friend ExpressionValue
149 mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
150
151 friend ExpressionValue
152 div(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs,
153 const ExpressionValue &rhs);
154
156 const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs,
157 const ExpressionValue &rhs
158 );
159
161 const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs,
162 const ExpressionValue &rhs
163 );
164
165 friend ExpressionValue
166 mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
167
168 friend ExpressionValue
169 bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
170
171 friend ExpressionValue
172 bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
173
174 friend ExpressionValue
175 bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
176
178 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
179 );
180
182 const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs
183 );
184
185 friend ExpressionValue
186 cmp(const llvm::SMTSolverRef &solver, boolean::CmpOp op, const ExpressionValue &lhs,
187 const ExpressionValue &rhs);
188
189 friend ExpressionValue
190 boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
191
192 friend ExpressionValue
193 boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
194
195 friend ExpressionValue
196 boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs);
197
198 friend ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val);
199
200 friend ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val);
201
202 friend ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val);
203
205 const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &val
206 );
207
208 /* Utility */
209
210 void print(mlir::raw_ostream &os) const;
211
212 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const ExpressionValue &e) {
213 e.print(os);
214 return os;
215 }
216
217 struct Hash {
218 unsigned operator()(const ExpressionValue &e) const {
219 return Interval::Hash {}(e.i) ^ llvm::hash_value(e.expr) ^
220 std::hash<bool> {}(e.unreduced.has_value()) ^
221 (e.unreduced.has_value() ? UnreducedInterval::Hash {}(*e.unreduced) : 0U);
222 }
223 };
224
225private:
226 Interval i;
227 llvm::SMTExprRef expr;
228 std::optional<UnreducedInterval> unreduced;
229};
230
231/* IntervalAnalysisLatticeValue */
232
233// NOLINTNEXTLINE(bugprone-exception-escape)
247
248/* IntervalAnalysisLattice */
249
251
253public:
255 // Map mlir::Values to LatticeValues
256 using ValueMap = mlir::DenseMap<mlir::Value, LatticeValue>;
257 // Map member references to LatticeValues. Used for member reads and writes.
258 // Structure is component value -> member attribute -> latticeValue
259 using MemberMap = mlir::DenseMap<mlir::Value, mlir::DenseMap<mlir::StringAttr, LatticeValue>>;
260 // Expression to interval map for convenience.
261 using ExpressionIntervals = mlir::DenseMap<llvm::SMTExprRef, Interval>;
262 // Tracks all constraints and assignments in insertion order
263 using ConstraintSet = llvm::SetVector<ExpressionValue>;
264
265 using AbstractSparseLattice::AbstractSparseLattice;
266
267 mlir::ChangeResult join(const AbstractSparseLattice &other) override;
268
269 mlir::ChangeResult meet(const AbstractSparseLattice &other) override;
270
271 void print(mlir::raw_ostream &os) const override;
272
273 const LatticeValue &getValue() const { return val; }
274
275 mlir::ChangeResult setValue(const LatticeValue &val);
276 mlir::ChangeResult setValue(const ExpressionValue &e);
277
278 mlir::ChangeResult addSolverConstraint(const ExpressionValue &e);
279
280 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l) {
281 l.print(os);
282 return os;
283 }
284
285 const ConstraintSet &getConstraints() const { return constraints; }
286
287 mlir::FailureOr<Interval> findInterval(llvm::SMTExprRef expr) const;
288 mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i);
289
290private:
291 LatticeValue val;
292 ConstraintSet constraints;
293};
294
295/* IntervalDataFlowAnalysis */
296
298 : public dataflow::SparseForwardDataFlowAnalysis<IntervalAnalysisLattice> {
300 using Lattice = IntervalAnalysisLattice;
301 using LatticeValue = IntervalAnalysisLattice::LatticeValue;
302
303 // Map SourceRefs to their symbols.
304 using SymbolMap = mlir::DenseMap<SourceRef, llvm::SMTExprRef>;
305
306public:
308 mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f,
309 bool propInputConstraints, bool shouldTrackUnreducedIntervals
310 )
311 : Base::SparseForwardDataFlowAnalysis(dataflowSolver), _dataflowSolver(dataflowSolver),
312 smtSolver(std::move(smt)), field(f), propagateInputConstraints(propInputConstraints),
313 trackUnreducedIntervals(shouldTrackUnreducedIntervals) {}
314
315 mlir::LogicalResult visitOperation(
316 mlir::Operation *op, mlir::ArrayRef<const Lattice *> operands,
317 mlir::ArrayRef<Lattice *> results
318 ) override;
319
324 llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r);
325
326 const llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> &getReadResults() const {
327 return readResults;
328 }
329
330 const llvm::DenseMap<SourceRef, ExpressionValue> &getWriteResults() const { return writeResults; }
331
332private:
333 mlir::DataFlowSolver &_dataflowSolver;
334 llvm::SMTSolverRef smtSolver;
335 SymbolMap refSymbols;
336 std::reference_wrapper<const Field> field;
337 bool propagateInputConstraints;
338 bool trackUnreducedIntervals;
339 mlir::SymbolTableCollection tables;
340
341 // Track SourceRef-indexed reads so writes to rooted storage can update existing readers.
342 llvm::DenseMap<SourceRef, llvm::DenseSet<Lattice *>> readResults;
343 // Track SourceRef-indexed writes. For now, we'll overapproximate repeated writes.
344 llvm::DenseMap<SourceRef, ExpressionValue> writeResults;
345
346 void setToEntryState(Lattice *lattice) override {
347 // Initialize the value with an interval in our specified field.
348 (void)lattice->setValue(ExpressionValue(field.get()));
349 }
350
351 static bool isBooleanType(mlir::Type ty) {
352 if (auto intTy = llvm::dyn_cast<mlir::IntegerType>(ty)) {
353 return intTy.getWidth() == 1;
354 }
355 return false;
356 }
357
358 Interval getDefaultIntervalForType(mlir::Type ty) const {
359 return isBooleanType(ty) ? Interval::Boolean(field.get()) : Interval::Entire(field.get());
360 }
361
362 std::optional<UnreducedInterval> getDefaultUnreducedIntervalForType(mlir::Type ty) const;
363
364 std::optional<UnreducedInterval> getRefUnreducedInterval(const SourceRef &ref);
365
366 llvm::SMTExprRef createSymbol(mlir::Type ty, const char *name) const;
367
368 llvm::SMTExprRef createSymbol(const SourceRef &r) const;
369
370 llvm::SMTExprRef createSymbol(mlir::Value val) const;
371
372 ExpressionValue createUnknownValue(mlir::Value val) const {
373 return ExpressionValue(
374 createSymbol(val), getDefaultIntervalForType(val.getType()),
375 getDefaultUnreducedIntervalForType(val.getType())
376 );
377 }
378
379 inline bool isConstOp(mlir::Operation *op) const {
380 return llvm::isa<
381 felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp>(op);
382 }
383
384 inline bool isBoolConstOp(mlir::Operation *op) const {
385 if (auto constIntOp = llvm::dyn_cast<mlir::arith::ConstantIntOp>(op)) {
386 auto valAttr = dyn_cast<mlir::IntegerAttr>(constIntOp.getValue());
387 ensure(valAttr != nullptr, "arith::ConstantIntOp must have an IntegerAttr as its value");
388 return valAttr.getValue().getBitWidth() == 1;
389 }
390 return false;
391 }
392
393 llvm::DynamicAPInt getConst(mlir::Operation *op) const;
394
395 inline llvm::SMTExprRef createConstBitvectorExpr(const llvm::DynamicAPInt &v) const {
396 return createConstBitvectorExpr(toAPSInt(v));
397 }
398
399 inline llvm::SMTExprRef createConstBitvectorExpr(const llvm::APSInt &v) const {
400 return smtSolver->mkBitvector(v, field.get().bitWidth());
401 }
402
403 llvm::SMTExprRef createConstBoolExpr(bool v) const { return smtSolver->mkBoolean(v); }
404
405 bool isArithmeticOp(mlir::Operation *op) const {
406 return llvm::isa<
407 felt::AddFeltOp, felt::SubFeltOp, felt::MulFeltOp, felt::DivFeltOp, felt::UnsignedModFeltOp,
408 felt::SignedModFeltOp, felt::SignedIntDivFeltOp, felt::UnsignedIntDivFeltOp,
409 mlir::arith::XOrIOp, felt::NegFeltOp, felt::InvFeltOp, felt::AndFeltOp, felt::OrFeltOp,
410 felt::XorFeltOp, felt::NotFeltOp, felt::ShlFeltOp, felt::ShrFeltOp, boolean::CmpOp,
411 boolean::AndBoolOp, boolean::OrBoolOp, boolean::XorBoolOp, boolean::NotBoolOp>(op);
412 }
413
414 ExpressionValue
415 performBinaryArithmetic(mlir::Operation *op, const LatticeValue &a, const LatticeValue &b);
416
417 ExpressionValue performUnaryArithmetic(mlir::Operation *op, const LatticeValue &a);
418
425 void applyInterval(mlir::Operation *originalOp, mlir::Value val, Interval newInterval);
426
428 mlir::FailureOr<std::pair<llvm::DenseSet<mlir::Value>, Interval>>
429 getGeneralizedDecompInterval(mlir::Operation *baseOp, mlir::Value lhs, mlir::Value rhs);
430
431 bool isReadOp(mlir::Operation *op) const {
432 return llvm::isa<component::MemberReadOp, polymorphic::ConstReadOp, array::ReadArrayOp>(op);
433 }
434
435 bool isDefinitionOp(mlir::Operation *op) const {
436 return llvm::isa<
437 component::StructDefOp, function::FuncDefOp, component::MemberDefOp, global::GlobalDefOp,
438 mlir::ModuleOp>(op);
439 }
440
441 bool isReturnOp(mlir::Operation *op) const { return llvm::isa<function::ReturnOp>(op); }
442
446 std::vector<SourceRefIndex>
447 getArrayAccessIndices(mlir::Operation *baseOp, array::ArrayAccessOpInterface arrayAccessOp);
448
451 mlir::FailureOr<SourceRef>
452 getArrayAccessRef(mlir::Operation *baseOp, array::ArrayAccessOpInterface arrayAccessOp);
453
456 Interval getRefInterval(const SourceRef &ref);
457
461 ExpressionValue getRefValue(const SourceRef &ref, mlir::Value val);
462
468 void recordRefWrite(
469 const SourceRef &writtenRef, const ExpressionValue &writeVal, bool mayBeSkipped = false
470 );
471
473 SourceRefLatticeValue getSourceRefState(mlir::Value val);
474};
475
476/* StructIntervals */
477
481 llvm::SMTSolverRef smtSolver;
482 std::optional<std::reference_wrapper<const Field>> field;
485
486 llvm::SMTExprRef getSymbol(const SourceRef &r) const { return intervalDFA->getOrCreateSymbol(r); }
487 bool hasField() const { return field.has_value(); }
488 const Field &getField() const {
489 ensure(field.has_value(), "field not set within context");
490 return field->get();
491 }
494
495 friend bool
497};
498
499} // namespace llzk
500
501template <> struct std::hash<llzk::IntervalAnalysisContext> {
503 return llvm::hash_combine(
504 std::hash<const llzk::IntervalDataFlowAnalysis *> {}(c.intervalDFA),
505 std::hash<const llvm::SMTSolver *> {}(c.smtSolver.get()),
506 std::hash<const llzk::Field *> {}(&c.getField()),
507 std::hash<bool> {}(c.propagateInputConstraints),
508 std::hash<bool> {}(c.trackUnreducedIntervals)
509 );
510 }
511};
512
513namespace llzk {
514
515// Suppress false positive from `clang-tidy`
516// NOLINTNEXTLINE(bugprone-exception-escape)
517class StructIntervals {
518public:
528 static mlir::FailureOr<StructIntervals> compute(
529 mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver,
530 mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx
531 ) {
532 StructIntervals si(mod, s);
533 if (si.computeIntervals(solver, am, ctx).failed()) {
534 return mlir::failure();
535 }
536 return si;
537 }
538
539 mlir::LogicalResult computeIntervals(
540 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx
541 );
542
543 void print(
544 mlir::raw_ostream &os, bool withConstraints = false, bool printCompute = false,
545 bool printUnreduced = false
546 ) const;
547
548 const llvm::MapVector<SourceRef, Interval> &getConstrainIntervals() const {
549 return constrainMemberRanges;
550 }
551
552 const llvm::MapVector<SourceRef, UnreducedInterval> &getConstrainUnreducedIntervals() const {
553 return constrainMemberUnreducedRanges;
554 }
555
556 const llvm::SetVector<ExpressionValue> getConstrainSolverConstraints() const {
557 return constrainSolverConstraints;
558 }
559
560 const llvm::MapVector<SourceRef, Interval> &getComputeIntervals() const {
561 return computeMemberRanges;
562 }
563
564 const llvm::MapVector<SourceRef, UnreducedInterval> &getComputeUnreducedIntervals() const {
565 return computeMemberUnreducedRanges;
566 }
567
568 const llvm::SetVector<ExpressionValue> getComputeSolverConstraints() const {
569 return computeSolverConstraints;
570 }
571
572 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const StructIntervals &si) {
573 si.print(os);
574 return os;
575 }
576
577private:
578 mlir::ModuleOp mod;
579 component::StructDefOp structDef;
580 llvm::SMTSolverRef smtSolver;
581 // llvm::MapVector keeps insertion order for consistent iteration
582 llvm::MapVector<SourceRef, Interval> constrainMemberRanges, computeMemberRanges;
583 llvm::MapVector<SourceRef, UnreducedInterval> constrainMemberUnreducedRanges,
584 computeMemberUnreducedRanges;
585 // llvm::SetVector for the same reasons as above
586 llvm::SetVector<ExpressionValue> constrainSolverConstraints, computeSolverConstraints;
587
588 StructIntervals(mlir::ModuleOp m, component::StructDefOp s) : mod(m), structDef(s) {}
589};
590
591/* StructIntervalAnalysis */
592
594
595class StructIntervalAnalysis : public StructAnalysis<StructIntervals, IntervalAnalysisContext> {
596public:
598 ~StructIntervalAnalysis() override = default;
599
600 bool inProgress(const IntervalAnalysisContext &ctx) const {
601 return inProgressContexts.contains(ctx);
602 }
603
604 mlir::LogicalResult runAnalysis(
605 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx
606 ) override {
607 if (inProgress(ctx)) {
608 return mlir::failure();
609 }
610 inProgressContexts.insert(ctx);
611 auto cleanup = llvm::make_scope_exit([this, &ctx] { inProgressContexts.erase(ctx); });
612
613 auto computeRes = StructIntervals::compute(getModule(), getStruct(), solver, am, ctx);
614 if (mlir::failed(computeRes)) {
615 return mlir::failure();
616 }
617 setResult(ctx, std::move(*computeRes));
618 return mlir::success();
619 }
620
621private:
622 std::unordered_set<IntervalAnalysisContext> inProgressContexts;
623};
624
625/* ModuleIntervalAnalysis */
626
628 : public ModuleAnalysis<StructIntervals, IntervalAnalysisContext, StructIntervalAnalysis> {
629
630public:
631 // We set intraprocedural to false for the sake of the SourceRefAnalysis
632 ModuleIntervalAnalysis(mlir::Operation *op)
633 : ModuleAnalysis(op, mlir::DataFlowConfig().setInterprocedural(false)), ctx {} {
634 ctx.smtSolver = llvm::CreateZ3Solver();
635 }
636 ~ModuleIntervalAnalysis() override = default;
637
638 void setField(const Field &f) { ctx.field = f; }
639 void setPropagateInputConstraints(bool prop) { ctx.propagateInputConstraints = prop; }
640 void setTrackUnreducedIntervals(bool track) { ctx.trackUnreducedIntervals = track; }
641
642protected:
643 void initializeSolver() override {
644 ensure(ctx.hasField(), "field not set, could not generate analysis context");
645 (void)solver.load<SourceRefAnalysis>();
646 auto smtSolverRef = ctx.smtSolver;
647 bool prop = ctx.propagateInputConstraints;
648 bool track = ctx.trackUnreducedIntervals;
649 ctx.intervalDFA =
650 solver.load<IntervalDataFlowAnalysis, llvm::SMTSolverRef, const Field &, bool, bool>(
651 std::move(smtSolverRef), ctx.getField(),
652 std::move(prop), // NOLINT(performance-move-const-arg)
653 std::move(track) // NOLINT(performance-move-const-arg)
654 );
655 }
656
657 const IntervalAnalysisContext &getContext() const override {
658 ensure(ctx.field.has_value(), "field not set, could not generate analysis context");
659 return ctx;
660 }
661
662private:
664};
665
666} // namespace llzk
667
668namespace llvm {
669
670template <> struct DenseMapInfo<llzk::ExpressionValue> {
671
672 static SMTExprRef getEmptyExpr() {
673 static const auto *emptyPtr = reinterpret_cast<SMTExprRef>(1);
674 return emptyPtr;
675 }
676 static SMTExprRef getTombstoneExpr() {
677 static const auto *tombstonePtr = reinterpret_cast<SMTExprRef>(2);
678 return tombstonePtr;
679 }
680
687 static unsigned getHashValue(const llzk::ExpressionValue &e) {
688 return llzk::ExpressionValue::Hash {}(e);
689 }
690 static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs) {
691 if (lhs.getExpr() == getEmptyExpr() || lhs.getExpr() == getTombstoneExpr() ||
692 rhs.getExpr() == getEmptyExpr() || rhs.getExpr() == getTombstoneExpr()) {
693 return lhs.getExpr() == rhs.getExpr();
694 }
695 return lhs == rhs;
696 }
697};
698
699} // namespace llvm
Convenience classes for a frequent pattern of dataflow analysis used in LLZK, where an analysis is ru...
This file implements sparse data-flow analysis using the data-flow analysis framework.
Tracks a solver expression and an interval range for that expression.
friend ExpressionValue boolAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withUnreducedInterval(const UnreducedInterval &newUnreducedInterval) const
friend ExpressionValue sintDiv(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withExpression(const llvm::SMTExprRef &newExpr) const
Return the current expression with a new SMT expression.
friend ExpressionValue fallbackUnaryOp(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &val)
friend ExpressionValue bitOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend ExpressionValue shiftRight(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue intersection(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the intersection of the lhs and rhs intervals, and create a solver expression that constrains...
ExpressionValue(llvm::SMTExprRef exprRef, const Interval &interval, std::optional< UnreducedInterval > unreducedInterval=std::nullopt)
friend ExpressionValue uintDiv(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
const Interval & getInterval() const
friend ExpressionValue mul(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue div(const llvm::SMTSolverRef &solver, mlir::Operation *op, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef)
const std::optional< UnreducedInterval > & getOptionalUnreducedInterval() const
ExpressionValue(const Field &f)
ExpressionValue withOptionalUnreducedInterval(std::optional< UnreducedInterval > newUnreducedInterval) const
friend ExpressionValue sub(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue withInterval(const Interval &newInterval) const
Return the current expression with a new interval.
friend ExpressionValue boolXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void print(mlir::raw_ostream &os) const
bool operator==(const ExpressionValue &rhs) const
friend ExpressionValue shiftLeft(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue bitXor(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue bitAnd(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SMTExprRef getExpr() const
bool isBoolSort(const llvm::SMTSolverRef &solver) const
friend ExpressionValue cmp(const llvm::SMTSolverRef &solver, boolean::CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue join(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
Compute the union of the lhs and rhs intervals, and create a solver expression that constrains both s...
bool hasUnreducedInterval() const
ExpressionValue(const Field &f, llvm::SMTExprRef exprRef, const llvm::DynamicAPInt &singleVal)
friend ExpressionValue boolOr(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
friend ExpressionValue boolNot(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const ExpressionValue &e)
const Field & getField() const
const UnreducedInterval & getUnreducedInterval() const
ExpressionValue & join(const ExpressionValue &)
Fold two expressions together when overapproximating array elements.
friend ExpressionValue neg(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
friend ExpressionValue add(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
ExpressionValue dropUnreducedInterval() const
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
IntervalAnalysisLatticeValue & operator=(const IntervalAnalysisLatticeValue &)=default
IntervalAnalysisLatticeValue(IntervalAnalysisLatticeValue &&)=default
IntervalAnalysisLatticeValue(const IntervalAnalysisLatticeValue &)=default
IntervalAnalysisLatticeValue(mlir::ArrayRef< int64_t > shape)
IntervalAnalysisLatticeValue & operator=(IntervalAnalysisLatticeValue &&)=default
IntervalAnalysisLatticeValue(ExpressionValue e)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const IntervalAnalysisLattice &l)
const LatticeValue & getValue() const
llvm::SetVector< ExpressionValue > ConstraintSet
mlir::ChangeResult setValue(const LatticeValue &val)
IntervalAnalysisLatticeValue LatticeValue
mlir::DenseMap< mlir::Value, LatticeValue > ValueMap
mlir::ChangeResult meet(const AbstractSparseLattice &other) override
const ConstraintSet & getConstraints() const
mlir::DenseMap< mlir::Value, mlir::DenseMap< mlir::StringAttr, LatticeValue > > MemberMap
void print(mlir::raw_ostream &os) const override
mlir::ChangeResult join(const AbstractSparseLattice &other) override
mlir::ChangeResult setInterval(llvm::SMTExprRef expr, const Interval &i)
mlir::DenseMap< llvm::SMTExprRef, Interval > ExpressionIntervals
mlir::ChangeResult addSolverConstraint(const ExpressionValue &e)
mlir::FailureOr< Interval > findInterval(llvm::SMTExprRef expr) const
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Visit an operation with the lattices of its operands.
llvm::SMTExprRef getOrCreateSymbol(const SourceRef &r)
Either return the existing SMT expression that corresponds to the SourceRef, or create one.
const llvm::DenseMap< SourceRef, ExpressionValue > & getWriteResults() const
const llvm::DenseMap< SourceRef, llvm::DenseSet< Lattice * > > & getReadResults() const
IntervalDataFlowAnalysis(mlir::DataFlowSolver &dataflowSolver, llvm::SMTSolverRef smt, const Field &f, bool propInputConstraints, bool shouldTrackUnreducedIntervals)
Intervals over a finite field.
Definition Intervals.h:206
static Interval Boolean(const Field &f)
Definition Intervals.h:227
ModuleAnalysis(mlir::Operation *op, const mlir::DataFlowConfig &config=mlir::DataFlowConfig())
ModuleIntervalAnalysis(mlir::Operation *op)
~ModuleIntervalAnalysis() override=default
void setPropagateInputConstraints(bool prop)
void initializeSolver() override
Initialize the shared dataflow solver with any common analyses required by the contained struct analy...
void setTrackUnreducedIntervals(bool track)
const IntervalAnalysisContext & getContext() const override
Return the current Context object.
void setField(const Field &f)
The dataflow analysis that computes the set of references that LLZK operations use and produce.
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:132
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
void setResult(const IntervalAnalysisContext &ctx, StructIntervals &&r)
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx) override
Perform the analysis and construct the Result output.
~StructIntervalAnalysis() override=default
StructAnalysis(mlir::Operation *op)
Assert that this analysis is being run on a StructDefOp and initializes the analysis with the current...
bool inProgress(const IntervalAnalysisContext &ctx) const
const llvm::MapVector< SourceRef, Interval > & getConstrainIntervals() const
static mlir::FailureOr< StructIntervals > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
Compute the struct intervals.
const llvm::SetVector< ExpressionValue > getConstrainSolverConstraints() const
const llvm::MapVector< SourceRef, UnreducedInterval > & getConstrainUnreducedIntervals() const
const llvm::SetVector< ExpressionValue > getComputeSolverConstraints() const
const llvm::MapVector< SourceRef, Interval > & getComputeIntervals() const
void print(mlir::raw_ostream &os, bool withConstraints=false, bool printCompute=false, bool printUnreduced=false) const
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const StructIntervals &si)
const llvm::MapVector< SourceRef, UnreducedInterval > & getComputeUnreducedIntervals() const
mlir::LogicalResult computeIntervals(mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const IntervalAnalysisContext &ctx)
An inclusive interval [a, b] where a and b are arbitrary integers not necessarily bound to a given fi...
Definition Intervals.h:26
mlir::SymbolTableCollection tables
LLZK: Added for use of symbol helper caching.
A sparse forward data-flow analysis for propagating SSA value lattices across the IR by implementing ...
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
APSInt toAPSInt(const DynamicAPInt &i)
static unsigned getHashValue(const llzk::ExpressionValue &e)
static bool isEqual(const llzk::ExpressionValue &lhs, const llzk::ExpressionValue &rhs)
static llzk::ExpressionValue getTombstoneKey()
static llzk::ExpressionValue getEmptyKey()
unsigned operator()(const ExpressionValue &e) const
Parameters and shared objects to pass to child analyses.
const Field & getField() const
friend bool operator==(const IntervalAnalysisContext &a, const IntervalAnalysisContext &b)=default
std::optional< std::reference_wrapper< const Field > > field
IntervalDataFlowAnalysis * intervalDFA
llvm::SMTExprRef getSymbol(const SourceRef &r) const
size_t operator()(const llzk::IntervalAnalysisContext &c) const