LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
IntervalAnalysisPass.cpp
Go to the documentation of this file.
1//===-- IntervalAnalysisPass.cpp --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include "llzk/Util/Constants.h"
21
22#include <mlir/IR/AsmState.h>
23
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/SmallVector.h>
26#include <llvm/Support/Debug.h>
27#include <llvm/Support/ErrorHandling.h>
28
29namespace llzk {
30#define GEN_PASS_DEF_INTERVALANALYSISPRINTERPASS
32} // namespace llzk
33
34#define DEBUG_TYPE "llzk-interval-analysis-pass"
35
36using namespace mlir;
37using namespace llzk;
38using namespace llzk::component;
39using namespace llzk::function;
40
41namespace {
42
43class PassImpl : public llzk::impl::IntervalAnalysisPrinterPassBase<PassImpl> {
44 using Base = IntervalAnalysisPrinterPassBase<PassImpl>;
45 using Base::Base;
46
47 void runOnOperation() override {
48 markAllAnalysesPreserved();
49
50 // Suppress false positive from `clang-tidy`
51 // NOLINTNEXTLINE(clang-analyzer-core.NonNullParamChecker)
52 auto modOp = llvm::dyn_cast<ModuleOp>(getOperation());
53 if (!modOp) {
54 constexpr const char *msg = "IntervalAnalysisPrinterPass error: should be run on ModuleOp!";
55 getOperation()->emitError(msg).report();
56 return;
57 }
58
59 // Initialize to the fallback field value
60 FieldRef selectedField = Field::getField("bn128");
61 if (!fieldName.empty()) {
62 auto fieldLookupRes = Field::tryGetField(fieldName.c_str());
63 if (failed(fieldLookupRes)) {
64 modOp->emitError()
65 .append(
66 "IntervalAnalysisPrinterPass error: unknown field \"", fieldName, "\" specified"
67 )
68 .report();
69 return;
70 }
71 selectedField = fieldLookupRes.value();
72 LLVM_DEBUG(
73 llvm::dbgs() << "[IntervalAnalysisPrinterPass] using explicit -field override '"
74 << selectedField.get().name() << "'\n";
75 );
76 } else if (auto detectedField = tryDetectSpecifiedField(modOp)) {
77 selectedField = detectedField.value();
78 LLVM_DEBUG(
79 llvm::dbgs() << "[IntervalAnalysisPrinterPass] detected module field '"
80 << selectedField.get().name() << "' from module felt usage\n";
81 );
82 } else {
83 modOp->emitWarning() << "could not detect a unique module field; falling back to '"
84 << selectedField.get().name() << '\'';
85 LLVM_DEBUG(
86 llvm::dbgs() << "[IntervalAnalysisPrinterPass] no explicit or detectable module field; "
87 "falling back to '"
88 << selectedField.get().name() << "'\n";
89 );
90 }
91
92 auto &mia = getAnalysis<ModuleIntervalAnalysis>();
93 mia.setField(selectedField);
94 mia.setPropagateInputConstraints(propagateInputConstraints);
95 mia.setTrackUnreducedIntervals(printUnreducedIntervals);
96 auto am = getAnalysisManager();
97 mia.ensureAnalysisRun(am);
98 AsmState asmState(modOp);
99
100 auto printValueInterval = [this, &asmState, &mia](raw_ostream &out, int indent, Value value) {
101 if (llvm::isa<llzk::array::ArrayType, StructType>(value.getType())) {
102 return;
103 }
104 const auto *lattice = mia.getSolver().lookupState<IntervalAnalysisLattice>(value);
105 if (!lattice) {
106 return;
107 }
108 const ExpressionValue &expr = lattice->getValue().getScalarValue();
109 out << '\n';
110 out.indent(indent);
111 value.printAsOperand(out, asmState);
112 if (auto opResult = llvm::dyn_cast<OpResult>(value)) {
113 out << " [" << opResult.getOwner()->getName().getStringRef() << "]";
114 }
115 out << " in " << expr.getInterval();
116 if (printUnreducedIntervals && expr.hasUnreducedInterval()) {
117 out << " ( unreduced: " << expr.getUnreducedInterval() << " )";
118 }
119 };
120
121 auto printFunctionSSAIntervals =
122 [&printValueInterval](raw_ostream &out, FuncDefOp fn, llvm::StringRef fnName) {
123 if (!fn) {
124 return;
125 }
126
127 out << '\n';
128 out.indent(4) << fnName << " {";
129 for (BlockArgument arg : fn.getArguments()) {
130 printValueInterval(out, 8, arg);
131 }
132 fn.walk([&](Operation *op) {
133 if (op == fn.getOperation()) {
134 return;
135 }
136 for (Value result : op->getResults()) {
137 printValueInterval(out, 8, result);
138 }
139 });
140 out << '\n';
141 out.indent(4) << '}';
142 };
143
144 auto &os = llzk::toStream(outputStream);
145 for (const auto &[s, si] : mia.getCurrentResults()) {
146 auto &structDef = const_cast<StructDefOp &>(s);
147 auto fullName = getPathFromTopRoot(structDef);
148 ensure(
149 succeeded(fullName),
150 "could not resolve fully qualified name of struct " + Twine(structDef.getName())
151 );
152 os << fullName.value() << ' ';
153 si.get().print(os, printSolverConstraints, printComputeIntervals, printUnreducedIntervals);
154 if (printSSAIntervals) {
155 os << fullName.value() << " SSAIntervals {";
156 if (printComputeIntervals) {
157 printFunctionSSAIntervals(os, structDef.getComputeFuncOp(), FUNC_NAME_COMPUTE);
158 }
159 printFunctionSSAIntervals(os, structDef.getConstrainFuncOp(), FUNC_NAME_CONSTRAIN);
160 if (auto productFn = structDef.getProductFuncOp();
161 productFn && (!structDef.getConstrainFuncOp() || printComputeIntervals)) {
162 printFunctionSSAIntervals(os, productFn, FUNC_NAME_PRODUCT);
163 }
164 os << "\n}\n";
165 }
166 }
167 }
168};
169
170} // namespace
Tracks a solver expression and an interval range for that expression.
const Interval & getInterval() const
bool hasUnreducedInterval() const
const UnreducedInterval & getUnreducedInterval() const
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
Definition Field.cpp:56
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation function (and, respectively, FUNC_NAME_CONSTRAIN for the constraint generation function).
Definition Constants.h:16
std::reference_wrapper< const Field > FieldRef
Typealias for a stable reference to a known Field.
Definition Field.h:156
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
llvm::raw_ostream & toStream(OutputStream val)
void ensure(bool condition, const llvm::Twine &errMsg)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
std::optional< std::reference_wrapper< const Field > > tryDetectSpecifiedField(mlir::Operation *root)
Try to detect a uniquely used field from the enclosing LLZK module.