LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SourceRefLattice.cpp
Go to the documentation of this file.
//===-- SourceRefLattice.cpp - SourceRef lattice & utils --*- C++ -*-===//
//
// Part of the LLZK Project, under the Apache License v2.0.
// See LICENSE.txt for license information.
// Copyright 2025 Veridise Inc.
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
9
11
16#include "llzk/Util/Hash.h"
18
19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
21
22#include <llvm/Support/Debug.h>
23
24#include <numeric>
25#include <unordered_set>
26
27#define DEBUG_TYPE "llzk-constrain-ref-lattice"
28
29using namespace mlir;
30
31namespace llzk {
32
33using namespace array;
34using namespace component;
35using namespace felt;
36using namespace polymorphic;
37
38/* SourceRefLatticeValue */
39
40mlir::ChangeResult SourceRefLatticeValue::insert(const SourceRef &rhs) {
41 auto rhsVal = SourceRefLatticeValue(rhs);
42 if (isScalar()) {
43 return updateScalar(rhsVal.getScalarValue());
44 } else {
45 return foldAndUpdate(rhsVal);
46 }
47}
48
49std::pair<SourceRefLatticeValue, mlir::ChangeResult>
51 auto newVal = *this;
52 auto res = mlir::ChangeResult::NoChange;
53 if (newVal.isScalar()) {
54 res = newVal.translateScalar(translation);
55 } else {
56 for (auto &elem : newVal.getArrayValue()) {
57 auto [newElem, elemRes] = elem->translate(translation);
58 (*elem) = newElem;
59 res |= elemRes;
60 }
61 }
62 return {newVal, res};
63}
64
65mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
67 SourceRefIndex idx(std::move(memberRef));
68 auto transform = [&idx](const SourceRef &r) -> mlir::FailureOr<SourceRef> {
69 return r.createChild(idx);
70 };
71 return elementwiseTransform(transform);
72}
73
74mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
75SourceRefLatticeValue::extract(const std::vector<SourceRefIndex> &indices) const {
76 if (isArray()) {
77 ensure(indices.size() <= getNumArrayDims(), "invalid extract array operands");
78
79 // First, compute what chunk(s) to index
80 std::vector<size_t> currIdxs {0};
81 for (unsigned i = 0; i < indices.size(); i++) {
82 const auto &idx = indices[i];
83 auto currDim = getArrayDim(i);
84
85 std::vector<size_t> newIdxs;
86 ensure(idx.isIndex() || idx.isIndexRange(), "wrong type of index for array");
87 if (idx.isIndex()) {
88 int64_t idxVal(idx.getIndex());
89 std::transform(
90 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
91 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
92 );
93 } else {
94 auto [low, high] = idx.getIndexRange();
95 int64_t lowInt(low), highInt(high);
96 for (int64_t idxVal = lowInt; idxVal < highInt; idxVal++) {
97 std::transform(
98 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
99 [&currDim, &idxVal](size_t j) { return j * currDim + idxVal; }
100 );
101 }
102 }
103
104 currIdxs = newIdxs;
105 }
106 std::vector<int64_t> newArrayDims;
107 size_t chunkSz = 1;
108 for (size_t i = indices.size(); i < getNumArrayDims(); i++) {
109 auto dim = getArrayDim(i);
110 newArrayDims.push_back(dim);
111 chunkSz *= dim;
112 }
113 if (newArrayDims.empty()) {
114 // read case, where the return value is a scalar (single element)
115 SourceRefLatticeValue extractedVal;
116 for (auto idx : currIdxs) {
117 (void)extractedVal.update(getElemFlatIdx(idx));
118 }
119 return std::make_pair(extractedVal, mlir::ChangeResult::Change);
120 } else {
121 // extract case, where the return value is an array of fewer dimensions.
122 SourceRefLatticeValue extractedVal(newArrayDims);
123 for (auto chunkStart : currIdxs) {
124 for (size_t i = 0; i < chunkSz; i++) {
125 (void)extractedVal.getElemFlatIdx(i).update(getElemFlatIdx(chunkStart + i));
126 }
127 }
128 return std::make_pair(extractedVal, mlir::ChangeResult::Change);
129 }
130 } else {
131 auto currVal = *this;
132 auto res = mlir::ChangeResult::NoChange;
133 for (const auto &idx : indices) {
134 auto transform = [&idx](const SourceRef &r) -> mlir::FailureOr<SourceRef> {
135 return r.createChild(idx);
136 };
137 auto transformedVal = currVal.elementwiseTransform(transform);
138 if (failed(transformedVal)) {
139 return mlir::failure();
140 }
141 auto [newVal, transformRes] = *transformedVal;
142 currVal = std::move(newVal);
143 res |= transformRes;
144 }
145 return std::make_pair(currVal, res);
146 }
147}
148
149mlir::ChangeResult SourceRefLatticeValue::translateScalar(const TranslationMap &translation) {
150 auto res = mlir::ChangeResult::NoChange;
151 // copy the current value
152 auto currVal = getScalarValue();
153 // reset this value
154 getValue() = ScalarTy();
155 // For each current element, see if the translation map contains a valid prefix.
156 // If so, translate the current element with all replacement prefixes indicated
157 // by the translation value.
158 for (const SourceRef &currRef : currVal) {
159 for (const auto &[prefix, replacementVal] : translation) {
160 if (currRef.isValidPrefix(prefix)) {
161 for (const SourceRef &replacementPrefix : replacementVal.foldToScalar()) {
162 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
163 if (succeeded(translatedRefRes)) {
164 res |= insert(*translatedRefRes);
165 }
166 }
167 }
168 }
169 }
170 return res;
171}
172
173mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
175 llvm::function_ref<mlir::FailureOr<SourceRef>(const SourceRef &)> transform
176) const {
177 auto newVal = *this;
178 auto res = mlir::ChangeResult::NoChange;
179 if (newVal.isScalar()) {
180 ScalarTy indexed;
181 for (const auto &ref : newVal.getScalarValue()) {
182 auto transformedRef = transform(ref);
183 if (failed(transformedRef)) {
184 return mlir::failure();
185 }
186 auto [_, inserted] = indexed.insert(*transformedRef);
187 if (inserted) {
188 res |= mlir::ChangeResult::Change;
189 }
190 }
191 newVal.getScalarValue() = indexed;
192 } else {
193 for (auto &elem : newVal.getArrayValue()) {
194 auto transformedElem = elem->elementwiseTransform(transform);
195 if (failed(transformedElem)) {
196 return mlir::failure();
197 }
198 auto [newElem, elemRes] = *transformedElem;
199 (*elem) = std::move(newElem);
200 res |= elemRes;
201 }
202 }
203 return std::make_pair(newVal, res);
204}
205
206mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefLatticeValue &v) {
207 v.print(os);
208 return os;
209}
210
211/* SourceRefLattice */
212
213mlir::FailureOr<SourceRef> SourceRefLattice::getSourceRef(mlir::Value val) {
214 if (auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(val)) {
215 return SourceRef(blockArg);
216 } else if (auto *defOp = val.getDefiningOp()) {
217 if (auto feltConst = llvm::dyn_cast<FeltConstantOp>(defOp)) {
218 return SourceRef(feltConst);
219 } else if (auto constIdx = llvm::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
220 return SourceRef(constIdx);
221 } else if (auto readConst = llvm::dyn_cast<ConstReadOp>(defOp)) {
222 return SourceRef(readConst);
223 } else if (auto structNew = llvm::dyn_cast<CreateStructOp>(defOp)) {
224 return SourceRef(structNew);
225 } else if (auto nonDet = llvm::dyn_cast<NonDetOp>(defOp)) {
226 return SourceRef(nonDet);
227 } else if (auto createArray = llvm::dyn_cast<CreateArrayOp>(defOp)) {
228 return SourceRef(createArray->getResult(0));
229 } else if (llvm::isa<function::CallOp>(defOp)) {
230 auto callResult = llvm::dyn_cast<mlir::OpResult>(val);
231 ensure(callResult != nullptr, "function.call value should be an OpResult");
232 return SourceRef(callResult);
233 }
234 }
235 return mlir::failure();
236}
237
239 if (auto asVal = llvm::dyn_cast_if_present<Value>(v)) {
240 auto sourceRef = getSourceRef(asVal);
241 if (mlir::succeeded(sourceRef)) {
242 return SourceRefLatticeValue(*sourceRef);
243 }
244 }
245 return SourceRefLatticeValue();
246}
247
248ChangeResult SourceRefLattice::join(const AbstractSparseLattice &rhs) {
249 return value.update(static_cast<const SourceRefLattice &>(rhs).value);
250}
251
252ChangeResult SourceRefLattice::meet(const AbstractSparseLattice & /*rhs*/) {
253 llvm::report_fatal_error("meet operation is not supported for SourceRefLattice");
254 return ChangeResult::NoChange;
255}
256
257void SourceRefLattice::print(mlir::raw_ostream &os) const {
258 os << "SourceRefLattice { " << value << " }";
259}
260
261ChangeResult SourceRefLattice::setValue(const LatticeValue &newValue) {
262 return value.setValue(newValue);
263}
264
265ChangeResult SourceRefLattice::setValue(const SourceRef &ref) {
266 return value.setValue(LatticeValue(ref));
267}
268
269} // namespace llzk
270
271namespace llvm {
272
273raw_ostream &operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
274 if (auto asVal = llvm::dyn_cast_if_present<Value>(ptr)) {
275 os << asVal;
276 } else if (auto *asOp = llvm::dyn_cast_if_present<Operation *>(ptr)) {
277 os << *asOp;
278 } else {
279 os << "<<null PointerUnion>>";
280 }
281 return os;
282}
283} // namespace llvm
Defines an index into an LLZK object.
Definition SourceRef.h:42
A value at a given point of the SourceRefLattice.
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > referenceMember(SymbolLookupResult< component::MemberDefOp > memberRef) const
Add the given memberRef to the SourceRefs contained within this value.
virtual mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > elementwiseTransform(llvm::function_ref< mlir::FailureOr< SourceRef >(const SourceRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the mo...
mlir::ChangeResult insert(const SourceRef &rhs)
Directly insert the ref into this value.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
Sparse SSA-value lattice for SourceRef propagation.
mlir::ChangeResult join(const AbstractSparseLattice &rhs) override
mlir::ChangeResult setValue(const LatticeValue &newValue)
mlir::ChangeResult meet(const AbstractSparseLattice &rhs) override
static SourceRefLatticeValue getDefaultValue(ValueTy v)
void print(mlir::raw_ostream &os) const override
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument, an allocation-like op result,...
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
SourceRefLatticeValue LatticeValue
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:132
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap