LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SourceRef.h
Go to the documentation of this file.
1//===-- SourceRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
21#include "llzk/Util/Hash.h"
22
23#include <mlir/Analysis/DataFlowFramework.h>
24#include <mlir/Dialect/Arith/IR/Arith.h>
25#include <mlir/Pass/AnalysisManager.h>
26
27#include <llvm/ADT/ArrayRef.h>
28#include <llvm/ADT/DynamicAPInt.h>
29#include <llvm/ADT/EquivalenceClasses.h>
30#include <llvm/ADT/TypeSwitch.h>
31
32#include <compare>
33#include <unordered_set>
34#include <variant>
35#include <vector>
36
37namespace llzk {
38
43 using IndexRange = std::pair<llvm::DynamicAPInt, llvm::DynamicAPInt>;
44
45public:
46 explicit SourceRefIndex(component::MemberDefOp f) : index(f) {}
48 explicit SourceRefIndex(const llvm::DynamicAPInt &i) : index(i) {}
49 explicit SourceRefIndex(const llvm::APInt &i) : index(toDynamicAPInt(i)) {}
50 explicit SourceRefIndex(int64_t i) : index(llvm::DynamicAPInt(i)) {}
51 SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
52 : index(IndexRange {toDynamicAPInt(low), toDynamicAPInt(high)}) {}
53 explicit SourceRefIndex(IndexRange r) : index(r) {}
54
55 bool isMember() const {
56 return std::holds_alternative<SymbolLookupResult<component::MemberDefOp>>(index) ||
57 std::holds_alternative<component::MemberDefOp>(index);
58 }
60 ensure(isMember(), "SourceRefIndex: member requested but not contained");
61 if (std::holds_alternative<component::MemberDefOp>(index)) {
62 return std::get<component::MemberDefOp>(index);
63 }
64 return std::get<SymbolLookupResult<component::MemberDefOp>>(index).get();
65 }
66
67 bool isIndex() const { return std::holds_alternative<llvm::DynamicAPInt>(index); }
68 llvm::DynamicAPInt getIndex() const {
69 ensure(isIndex(), "SourceRefIndex: index requested but not contained");
70 return std::get<llvm::DynamicAPInt>(index);
71 }
72
73 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
74 IndexRange getIndexRange() const {
75 ensure(isIndexRange(), "SourceRefIndex: index range requested but not contained");
76 return std::get<IndexRange>(index);
77 }
78
79 inline void dump() const { print(llvm::errs()); }
80 void print(mlir::raw_ostream &os) const;
81
82 inline bool operator==(const SourceRefIndex &rhs) const {
83 if (isMember() && rhs.isMember()) {
84 // We compare the underlying members, since the member could be in a symbol
85 // lookup or not.
86 return getMember() == rhs.getMember();
87 }
88 if (isIndex() && rhs.isIndex()) {
89 return getIndex() == rhs.getIndex();
90 }
91 return index == rhs.index;
92 }
93
94 std::strong_ordering operator<=>(const SourceRefIndex &rhs) const;
95
96 struct Hash {
97 size_t operator()(const SourceRefIndex &c) const;
98 };
99
100 inline size_t getHash() const { return Hash {}(*this); }
101
102private:
109 std::variant<
111 IndexRange>
112 index;
113};
114
115static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefIndex &rhs) {
116 rhs.print(os);
117 return os;
118}
119
132class SourceRef {
133public:
134 using Path = std::vector<SourceRefIndex>;
135
136private:
137 // Sort in the following order:
138 // block arg < struct.new < nondet < other rooted result < template const < const index <
139 // const felt.
140 enum class SortCategory {
141 BlockArgument,
142 CreateStruct,
143 NonDet,
144 RootResult,
145 TemplateConstant,
146 ConstantIndex,
147 ConstantFelt,
148 };
149
150 template <typename OpT> static mlir::Value getSingleResultValue(OpT op) {
151 ensure(op, "SourceRef requires a non-null operation");
152 ensure(op->getNumResults() == 1, "SourceRef expects a single-result operation");
153 return op->getResult(0);
154 }
155
156 static mlir::Value getRootResultValue(mlir::OpResult result) {
157 ensure(
158 !llvm::isa<
159 felt::FeltConstantOp, mlir::arith::ConstantIndexOp, mlir::arith::ConstantIntOp,
160 polymorphic::ConstReadOp>(result.getOwner()),
161 "SourceRef rooted OpResult constructors must not be used for constant values"
162 );
163 return result;
164 }
165
166 template <typename OpT> mlir::FailureOr<OpT> getDefiningOp() const {
167 if (auto op = llvm::dyn_cast_if_present<OpT>(value.getDefiningOp())) {
168 return op;
169 }
170 return mlir::failure();
171 }
172
173 SourceRef(mlir::Value sourceValue, bool isConstantStorage, Path sourcePath = {})
174 : value(sourceValue), path(std::move(sourcePath)), constant(isConstantStorage) {
175 ensure(value != nullptr, "SourceRef requires a non-null value");
176 ensure(!constant || this->path.empty(), "constant SourceRef cannot have a path");
177 }
178
179 Path &getPathMut() { return path; }
180 const void *getAsOpaquePointer() const { return value.getAsOpaquePointer(); }
181 SortCategory getSortCategory() const;
182 llvm::StringRef getTemplateConstantName() const;
183 std::strong_ordering compareWithinCategory(const SourceRef &rhs, SortCategory category) const;
184
185public:
/// Produce all possible SourceRefs that are present starting from the given root.
187 static std::vector<SourceRef>
188 getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root);
189
191 static std::vector<SourceRef>
193
196 static std::vector<SourceRef>
198
199 /* Rooted constructors */
200
201 SourceRef(mlir::BlockArgument b, Path p = {}) : SourceRef(b, /*constant=*/false, p) {}
203 : SourceRef(getSingleResultValue(createOp), /*constant=*/false, p) {}
204 SourceRef(NonDetOp nondet, Path p = {})
205 : SourceRef(getSingleResultValue(nondet), /*constant=*/false, p) {}
206 SourceRef(mlir::OpResult rootResult, Path p = {})
207 : SourceRef(getRootResultValue(rootResult), /*constant=*/false, p) {}
208
209 /* Constant constructors */
210
212 : SourceRef(getSingleResultValue(c), /*constant=*/true) {}
213 explicit SourceRef(mlir::arith::ConstantIndexOp c)
214 : SourceRef(getSingleResultValue(c), /*constant=*/true) {}
216 : SourceRef(getSingleResultValue(c), /*constant=*/true) {}
217
218 mlir::Type getType() const;
219
220 bool isConstantFelt() const {
221 return isConstant() && llvm::isa_and_present<felt::FeltConstantOp>(value.getDefiningOp());
222 }
223 bool isConstantIndex() const {
224 return isConstant() &&
225 llvm::isa_and_present<mlir::arith::ConstantIndexOp>(value.getDefiningOp());
226 }
227 bool isTemplateConstant() const {
228 return isConstant() && llvm::isa_and_present<polymorphic::ConstReadOp>(value.getDefiningOp());
229 }
230
231 bool isConstant() const { return constant; }
232 bool isConstantInt() const { return isConstantFelt() || isConstantIndex(); }
233
234 bool isFeltVal() const { return llvm::isa<felt::FeltType>(getType()); }
235 bool isIndexVal() const { return llvm::isa<mlir::IndexType>(getType()); }
236 bool isIntegerVal() const { return llvm::isa<mlir::IntegerType>(getType()); }
237 bool isTypeVarVal() const { return llvm::isa<polymorphic::TypeVarType>(getType()); }
238 bool isScalar() const {
239 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
240 }
241
242 bool isRooted() const { return !constant; }
243 bool isBlockArgument() const { return isRooted() && llvm::isa<mlir::BlockArgument>(value); }
244 mlir::FailureOr<mlir::Value> getRoot() const {
245 if (isRooted()) {
246 return value;
247 }
248 return mlir::failure();
249 }
250 mlir::FailureOr<mlir::Value> getConstant() const {
251 if (isConstant()) {
252 return value;
253 }
254 return mlir::failure();
255 }
256 mlir::FailureOr<mlir::BlockArgument> getBlockArgument() const {
257 if (auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(value)) {
258 return blockArg;
259 }
260 return mlir::failure();
261 }
262 mlir::FailureOr<unsigned> getInputNum() const {
263 auto blockArg = getBlockArgument();
264 if (succeeded(blockArg)) {
265 return blockArg->getArgNumber();
266 }
267 return mlir::failure();
268 }
269
270 bool isCreateStructOp() const { return succeeded(getCreateStructOp()); }
271 mlir::FailureOr<component::CreateStructOp> getCreateStructOp() const {
272 return getDefiningOp<component::CreateStructOp>();
273 }
274
275 bool isNonDetOp() const { return succeeded(getNonDetOp()); }
276 mlir::FailureOr<NonDetOp> getNonDetOp() const { return getDefiningOp<NonDetOp>(); }
277
278 bool isCallResult() const { return succeeded(getCallOp()); }
279 mlir::FailureOr<function::CallOp> getCallOp() const { return getDefiningOp<function::CallOp>(); }
280
281 mlir::FailureOr<llvm::DynamicAPInt> getConstantFeltValue() const {
282 auto feltConst = getDefiningOp<felt::FeltConstantOp>();
283 if (succeeded(feltConst)) {
284 llvm::APInt i = feltConst->getValue();
285 return toDynamicAPInt(i);
286 }
287 return mlir::failure();
288 }
289 mlir::FailureOr<llvm::DynamicAPInt> getConstantIndexValue() const {
290 auto indexConst = getDefiningOp<mlir::arith::ConstantIndexOp>();
291 if (succeeded(indexConst)) {
292 return llvm::DynamicAPInt(indexConst->value());
293 }
294 return mlir::failure();
295 }
296 mlir::FailureOr<llvm::DynamicAPInt> getConstantValue() const {
297 auto feltVal = getConstantFeltValue();
298 if (succeeded(feltVal)) {
299 return *feltVal;
300 }
301 auto indexVal = getConstantIndexValue();
302 if (succeeded(indexVal)) {
303 return *indexVal;
304 }
305 return mlir::failure();
306 }
307
/// Returns true iff `prefix` is a valid prefix of this reference.
309 bool isValidPrefix(const SourceRef &prefix) const;

/// If `prefix` is a valid prefix of this reference, returns the suffix that
/// remains after removing the prefix; otherwise returns failure.
315 mlir::FailureOr<std::vector<SourceRefIndex>> getSuffix(const SourceRef &prefix) const;

/// Creates a new reference with `prefix` replaced by `other`, iff `prefix` is
/// a valid prefix for this reference; otherwise returns failure.
323 mlir::FailureOr<SourceRef> translate(const SourceRef &prefix, const SourceRef &other) const;
324
326 mlir::FailureOr<SourceRef> getParentPrefix() const {
327 if (!isRooted() || getPath().empty()) {
328 return mlir::failure();
329 }
330 auto copy = *this;
331 copy.getPathMut().pop_back();
332 return copy;
333 }
334
336 std::vector<SourceRef>
337 getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const;
338
339 mlir::FailureOr<SourceRef> createChild(const SourceRefIndex &r) const {
340 if (!isRooted()) {
341 return mlir::failure();
342 }
343 auto copy = *this;
344 copy.getPathMut().push_back(r);
345 return copy;
346 }
347
348 mlir::FailureOr<SourceRef> createChild(const SourceRef &other) const {
349 auto idxVal = other.getConstantIndexValue();
350 if (failed(idxVal)) {
351 return mlir::failure();
352 }
353 return createChild(SourceRefIndex(*idxVal));
354 }
355
356 [[deprecated("Use getPath() instead")]]
357 // NOTE: When this function is removed, do not delete it, rewrite as `... = delete`.
358 llvm::ArrayRef<SourceRefIndex> getPieces() const {
359 return path;
360 }
361 llvm::ArrayRef<SourceRefIndex> getPath() const { return path; }
362
363 void print(mlir::raw_ostream &os) const;
364 void dump() const { print(llvm::errs()); }
365
366 bool operator==(const SourceRef &rhs) const;
367
368 bool operator!=(const SourceRef &rhs) const { return !(*this == rhs); }
369
370 // required for EquivalenceClasses usage
371 std::strong_ordering operator<=>(const SourceRef &rhs) const;
372
373 struct Hash {
374 size_t operator()(const SourceRef &val) const;
375 };
376
377 friend struct llvm::DenseMapInfo<SourceRef>;
378
379private:
// The root value (rooted references) or the constant's result value (constants).
380 mlir::Value value;
// Indices applied to `value`; the private constructor ensures it is empty for constants.
381 Path path;
// True for constant leaves; isRooted() is defined as its negation.
382 bool constant;
383};
384
/// Stream operator for SourceRef; only declared in this header.
385mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRef &rhs);
386
387/* SourceRefSet */
388
389class SourceRefSet : public std::unordered_set<SourceRef, SourceRef::Hash> {
390 using Base = std::unordered_set<SourceRef, SourceRef::Hash>;
391
392public:
393 using Base::Base;
394
395 SourceRefSet &join(const SourceRefSet &rhs);
396
397 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs);
398};
399
400static_assert(
402 "SourceRefSet must satisfy the ScalarLatticeValue requirements"
403);
404
405} // namespace llzk
406
407namespace llvm {
408
409template <> struct DenseMapInfo<llzk::SourceRef> {
411 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
412 }
414 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
415 }
416 static unsigned getHashValue(const llzk::SourceRef &ref) {
417 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
418 return llvm::hash_value(ref.getAsOpaquePointer());
419 }
420 return llzk::SourceRef::Hash {}(ref);
421 }
422 static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs) { return lhs == rhs; }
423};
424
425} // namespace llvm
This file implements helper methods for constructing DynamicAPInts.
void print(llvm::raw_ostream &os) const
Defines an index into an LLZK object.
Definition SourceRef.h:42
std::strong_ordering operator<=>(const SourceRefIndex &rhs) const
Definition SourceRef.cpp:88
bool operator==(const SourceRefIndex &rhs) const
Definition SourceRef.h:82
bool isIndexRange() const
Definition SourceRef.h:73
size_t getHash() const
Definition SourceRef.h:100
bool isIndex() const
Definition SourceRef.h:67
bool isMember() const
Definition SourceRef.h:55
SourceRefIndex(const llvm::DynamicAPInt &i)
Definition SourceRef.h:48
SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
Definition SourceRef.h:51
SourceRefIndex(const llvm::APInt &i)
Definition SourceRef.h:49
void dump() const
Definition SourceRef.h:79
llvm::DynamicAPInt getIndex() const
Definition SourceRef.h:68
void print(mlir::raw_ostream &os) const
Definition SourceRef.cpp:73
IndexRange getIndexRange() const
Definition SourceRef.h:74
component::MemberDefOp getMember() const
Definition SourceRef.h:59
SourceRefIndex(SymbolLookupResult< component::MemberDefOp > f)
Definition SourceRef.h:47
SourceRefIndex(IndexRange r)
Definition SourceRef.h:53
SourceRefIndex(int64_t i)
Definition SourceRef.h:50
SourceRefIndex(component::MemberDefOp f)
Definition SourceRef.h:46
SourceRefSet & join(const SourceRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs)
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:132
bool isIntegerVal() const
Definition SourceRef.h:236
bool isBlockArgument() const
Definition SourceRef.h:243
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
Definition SourceRef.h:339
std::vector< SourceRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this SourceRef, assuming this ref is not a scalar.
mlir::FailureOr< std::vector< SourceRefIndex > > getSuffix(const SourceRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
mlir::FailureOr< SourceRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
Definition SourceRef.h:326
mlir::FailureOr< function::CallOp > getCallOp() const
Definition SourceRef.h:279
void print(mlir::raw_ostream &os) const
bool isCallResult() const
Definition SourceRef.h:278
bool isScalar() const
Definition SourceRef.h:238
bool operator==(const SourceRef &rhs) const
mlir::FailureOr< component::CreateStructOp > getCreateStructOp() const
Definition SourceRef.h:271
bool isConstantFelt() const
Definition SourceRef.h:220
bool isRooted() const
Definition SourceRef.h:242
SourceRef(felt::FeltConstantOp c)
Definition SourceRef.h:211
SourceRef(component::CreateStructOp createOp, Path p={})
Definition SourceRef.h:202
llvm::ArrayRef< SourceRefIndex > getPath() const
Definition SourceRef.h:361
bool isValidPrefix(const SourceRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
std::strong_ordering operator<=>(const SourceRef &rhs) const
SourceRef(mlir::BlockArgument b, Path p={})
Definition SourceRef.h:201
mlir::FailureOr< llvm::DynamicAPInt > getConstantFeltValue() const
Definition SourceRef.h:281
bool isConstantIndex() const
Definition SourceRef.h:223
std::vector< SourceRefIndex > Path
Definition SourceRef.h:134
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::FailureOr< llvm::DynamicAPInt > getConstantValue() const
Definition SourceRef.h:296
mlir::FailureOr< unsigned > getInputNum() const
Definition SourceRef.h:262
mlir::FailureOr< NonDetOp > getNonDetOp() const
Definition SourceRef.h:276
void dump() const
Definition SourceRef.h:364
llvm::ArrayRef< SourceRefIndex > getPieces() const
Definition SourceRef.h:358
SourceRef(mlir::arith::ConstantIndexOp c)
Definition SourceRef.h:213
mlir::FailureOr< mlir::BlockArgument > getBlockArgument() const
Definition SourceRef.h:256
bool isIndexVal() const
Definition SourceRef.h:235
SourceRef(NonDetOp nondet, Path p={})
Definition SourceRef.h:204
SourceRef(mlir::OpResult rootResult, Path p={})
Definition SourceRef.h:206
mlir::FailureOr< SourceRef > createChild(const SourceRef &other) const
Definition SourceRef.h:348
mlir::FailureOr< llvm::DynamicAPInt > getConstantIndexValue() const
Definition SourceRef.h:289
mlir::FailureOr< SourceRef > translate(const SourceRef &prefix, const SourceRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
bool isNonDetOp() const
Definition SourceRef.h:275
mlir::FailureOr< mlir::Value > getConstant() const
Definition SourceRef.h:250
SourceRef(polymorphic::ConstReadOp c)
Definition SourceRef.h:215
bool isTemplateConstant() const
Definition SourceRef.h:227
bool isTypeVarVal() const
Definition SourceRef.h:237
bool isConstant() const
Definition SourceRef.h:231
bool operator!=(const SourceRef &rhs) const
Definition SourceRef.h:368
mlir::FailureOr< mlir::Value > getRoot() const
Definition SourceRef.h:244
bool isFeltVal() const
Definition SourceRef.h:234
bool isConstantInt() const
Definition SourceRef.h:232
bool isCreateStructOp() const
Definition SourceRef.h:270
mlir::Type getType() const
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
DynamicAPInt toDynamicAPInt(StringRef str)
Interval operator<<(const Interval &lhs, const Interval &rhs)
static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs)
Definition SourceRef.h:422
static unsigned getHashValue(const llzk::SourceRef &ref)
Definition SourceRef.h:416
static llzk::SourceRef getTombstoneKey()
Definition SourceRef.h:413
static llzk::SourceRef getEmptyKey()
Definition SourceRef.h:410
size_t operator()(const SourceRefIndex &c) const
size_t operator()(const SourceRef &val) const