LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SourceRef.h
Go to the documentation of this file.
1//===-- SourceRef.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
19#include "llzk/Util/Hash.h"
20
21#include <mlir/Analysis/DataFlowFramework.h>
22#include <mlir/Dialect/Arith/IR/Arith.h>
23#include <mlir/Pass/AnalysisManager.h>
24
25#include <llvm/ADT/DynamicAPInt.h>
26#include <llvm/ADT/EquivalenceClasses.h>
27
28#include <unordered_set>
29#include <vector>
30
31namespace llzk {
32
37 using IndexRange = std::pair<llvm::DynamicAPInt, llvm::DynamicAPInt>;
38
39public:
40 explicit SourceRefIndex(component::MemberDefOp f) : index(f) {}
42 explicit SourceRefIndex(const llvm::DynamicAPInt &i) : index(i) {}
43 explicit SourceRefIndex(const llvm::APInt &i) : index(toDynamicAPInt(i)) {}
44 explicit SourceRefIndex(int64_t i) : index(llvm::DynamicAPInt(i)) {}
45 SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
46 : index(IndexRange {toDynamicAPInt(low), toDynamicAPInt(high)}) {}
47 explicit SourceRefIndex(IndexRange r) : index(r) {}
48
49 bool isMember() const {
50 return std::holds_alternative<SymbolLookupResult<component::MemberDefOp>>(index) ||
51 std::holds_alternative<component::MemberDefOp>(index);
52 }
54 ensure(isMember(), "SourceRefIndex: member requested but not contained");
55 if (std::holds_alternative<component::MemberDefOp>(index)) {
56 return std::get<component::MemberDefOp>(index);
57 }
58 return std::get<SymbolLookupResult<component::MemberDefOp>>(index).get();
59 }
60
61 bool isIndex() const { return std::holds_alternative<llvm::DynamicAPInt>(index); }
62 llvm::DynamicAPInt getIndex() const {
63 ensure(isIndex(), "SourceRefIndex: index requested but not contained");
64 return std::get<llvm::DynamicAPInt>(index);
65 }
66
67 bool isIndexRange() const { return std::holds_alternative<IndexRange>(index); }
68 IndexRange getIndexRange() const {
69 ensure(isIndexRange(), "SourceRefIndex: index range requested but not contained");
70 return std::get<IndexRange>(index);
71 }
72
73 inline void dump() const { print(llvm::errs()); }
74 void print(mlir::raw_ostream &os) const;
75
76 inline bool operator==(const SourceRefIndex &rhs) const {
77 if (isMember() && rhs.isMember()) {
78 // We compare the underlying members, since the member could be in a symbol
79 // lookup or not.
80 return getMember() == rhs.getMember();
81 }
82 if (isIndex() && rhs.isIndex()) {
83 return getIndex() == rhs.getIndex();
84 }
85 return index == rhs.index;
86 }
87
88 bool operator<(const SourceRefIndex &rhs) const;
89
90 bool operator>(const SourceRefIndex &rhs) const { return rhs < *this; }
91
92 struct Hash {
93 size_t operator()(const SourceRefIndex &c) const;
94 };
95
96 inline size_t getHash() const { return Hash {}(*this); }
97
98private:
105 std::variant<
107 IndexRange>
108 index;
109};
110
111static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefIndex &rhs) {
112 rhs.print(os);
113 return os;
114}
115
128
129public:
/// Produce all possible SourceRefs that are present starting from the given root.
/// `tables` and `mod` are used for symbol lookup (defined out of line).
131 static std::vector<SourceRef>
132 getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root);
133
135 static std::vector<SourceRef>
137
140 static std::vector<SourceRef>
142
143 explicit SourceRef(mlir::BlockArgument b) : root(b), memberRefs(), constantVal(std::nullopt) {}
144 SourceRef(mlir::BlockArgument b, std::vector<SourceRefIndex> f)
145 : root(b), memberRefs(std::move(f)), constantVal(std::nullopt) {}
146
148 : root(createOp), memberRefs(), constantVal(std::nullopt) {}
149 SourceRef(component::CreateStructOp createOp, std::vector<SourceRefIndex> f)
150 : root(createOp), memberRefs(std::move(f)), constantVal(std::nullopt) {}
151
152 explicit SourceRef(felt::FeltConstantOp c) : root(std::nullopt), memberRefs(), constantVal(c) {}
153 explicit SourceRef(mlir::arith::ConstantIndexOp c)
154 : root(std::nullopt), memberRefs(), constantVal(c) {}
156 : root(std::nullopt), memberRefs(), constantVal(c) {}
157
158 mlir::Type getType() const;
159
160 bool isConstantFelt() const {
161 return constantVal.has_value() && std::holds_alternative<felt::FeltConstantOp>(*constantVal);
162 }
163 bool isConstantIndex() const {
164 return constantVal.has_value() &&
165 std::holds_alternative<mlir::arith::ConstantIndexOp>(*constantVal);
166 }
167 bool isTemplateConstant() const {
168 return constantVal.has_value() &&
169 std::holds_alternative<polymorphic::ConstReadOp>(*constantVal);
170 }
171 bool isConstant() const { return constantVal.has_value(); }
172 bool isConstantInt() const { return isConstantFelt() || isConstantIndex(); }
173
174 bool isFeltVal() const { return llvm::isa<felt::FeltType>(getType()); }
175 bool isIndexVal() const { return llvm::isa<mlir::IndexType>(getType()); }
176 bool isIntegerVal() const { return llvm::isa<mlir::IntegerType>(getType()); }
177 bool isTypeVarVal() const { return llvm::isa<polymorphic::TypeVarType>(getType()); }
178 bool isScalar() const {
179 return isConstant() || isFeltVal() || isIndexVal() || isIntegerVal() || isTypeVarVal();
180 }
181
182 bool isBlockArgument() const {
183 return root.has_value() && std::holds_alternative<mlir::BlockArgument>(*root);
184 }
185 mlir::BlockArgument getBlockArgument() const {
186 ensure(isBlockArgument(), "is not a block argument");
187 return std::get<mlir::BlockArgument>(*root);
188 }
189 unsigned getInputNum() const { return getBlockArgument().getArgNumber(); }
190
191 bool isCreateStructOp() const {
192 return root.has_value() && std::holds_alternative<component::CreateStructOp>(*root);
193 }
195 ensure(isCreateStructOp(), "is not a create struct op");
196 return std::get<component::CreateStructOp>(*root);
197 }
198
199 llvm::DynamicAPInt getConstantFeltValue() const {
200 ensure(
201 isConstantFelt(), mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant felt!")
202 );
203 llvm::APInt i = std::get<felt::FeltConstantOp>(*constantVal).getValue();
204 return toDynamicAPInt(i);
205 }
206 llvm::DynamicAPInt getConstantIndexValue() const {
207 ensure(
208 isConstantIndex(), mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant index!")
209 );
210 return llvm::DynamicAPInt(std::get<mlir::arith::ConstantIndexOp>(*constantVal).value());
211 }
212 llvm::DynamicAPInt getConstantValue() const {
213 ensure(
215 mlir::Twine(mlir::StringRef(__FUNCTION__), " requires a constant int type!")
216 );
218 }
219
/// Returns true iff `prefix` is a valid prefix of this reference.
221 bool isValidPrefix(const SourceRef &prefix) const;
222
/// If `prefix` is a valid prefix of this reference, return the suffix that
/// remains after removing the prefix. Presumably returns failure otherwise
/// (implied by the FailureOr return type; defined out of line).
227 mlir::FailureOr<std::vector<SourceRefIndex>> getSuffix(const SourceRef &prefix) const;
228
/// Create a new reference with `prefix` replaced by `other`, iff `prefix` is a
/// valid prefix for this reference. Presumably returns failure otherwise.
235 mlir::FailureOr<SourceRef> translate(const SourceRef &prefix, const SourceRef &other) const;
236
238 mlir::FailureOr<SourceRef> getParentPrefix() const {
239 if (isConstantFelt() || memberRefs.empty()) {
240 return mlir::failure();
241 }
242 auto copy = *this;
243 copy.memberRefs.pop_back();
244 return copy;
245 }
246
/// Get all direct children of this SourceRef, assuming this ref is not a scalar.
/// `tables` and `mod` are used for symbol lookup (defined out of line).
248 std::vector<SourceRef>
249 getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const;
250
252 auto copy = *this;
253 copy.memberRefs.push_back(r);
254 return copy;
255 }
256
258 assert(other.isConstantIndex());
260 }
261
262 const std::vector<SourceRefIndex> &getPieces() const { return memberRefs; }
263
264 void print(mlir::raw_ostream &os) const;
265 void dump() const { print(llvm::errs()); }
266
267 bool operator==(const SourceRef &rhs) const;
268
269 bool operator!=(const SourceRef &rhs) const { return !(*this == rhs); }
270
271 // required for EquivalenceClasses usage
272 bool operator<(const SourceRef &rhs) const;
273
274 bool operator>(const SourceRef &rhs) const { return rhs < *this; }
275
276 struct Hash {
277 size_t operator()(const SourceRef &val) const;
278 };
279
280private:
/// Base value of the reference: a block argument or a struct-creation op.
/// Left empty by the constant constructors.
290 std::optional<std::variant<mlir::BlockArgument, component::CreateStructOp>> root;
291
/// Index path (members, indices, or index ranges) applied to `root`;
/// appended by createChild and exposed through getPieces().
292 std::vector<SourceRefIndex> memberRefs;
293 // using mutable to reduce constant casts for certain get* functions.
/// Set only for constant references (felt, index, or template constant).
294 mutable std::optional<
295 std::variant<felt::FeltConstantOp, mlir::arith::ConstantIndexOp, polymorphic::ConstReadOp>>
296 constantVal;
297};
298
/// Stream a SourceRef to `os` (defined out of line).
299mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRef &rhs);
300
301/* SourceRefSet */
302
/// An unordered set of SourceRefs (hashed with SourceRef::Hash) with a `join`
/// operation and stream printing, both defined out of line.
303class SourceRefSet : public std::unordered_set<SourceRef, SourceRef::Hash> {
304 using Base = std::unordered_set<SourceRef, SourceRef::Hash>;
305
306public:
// Inherit all std::unordered_set constructors.
307 using Base::Base;
308
// Merge `rhs` into this set and return *this; presumably a set union — confirm in SourceRef.cpp.
309 SourceRefSet &join(const SourceRefSet &rhs);
310
311 friend mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs);
312};
313
314static_assert(
316 "SourceRefSet must satisfy the ScalarLatticeValue requirements"
317);
318
319} // namespace llzk
320
321namespace llvm {
322
323template <> struct DenseMapInfo<llzk::SourceRef> {
325 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
326 }
328 return llzk::SourceRef(mlir::BlockArgument(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
329 }
330 static unsigned getHashValue(const llzk::SourceRef &ref) {
331 if (ref == getEmptyKey() || ref == getTombstoneKey()) {
332 return llvm::hash_value(ref.getBlockArgument().getAsOpaquePointer());
333 }
334 return llzk::SourceRef::Hash {}(ref);
335 }
336 static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs) { return lhs == rhs; }
337};
338
339} // namespace llvm
This file implements helper methods for constructing DynamicAPInts.
void print(llvm::raw_ostream &os) const
Defines an index into an LLZK object.
Definition SourceRef.h:36
bool operator==(const SourceRefIndex &rhs) const
Definition SourceRef.h:76
bool isIndexRange() const
Definition SourceRef.h:67
bool operator<(const SourceRefIndex &rhs) const
Definition SourceRef.cpp:49
bool operator>(const SourceRefIndex &rhs) const
Definition SourceRef.h:90
size_t getHash() const
Definition SourceRef.h:96
bool isIndex() const
Definition SourceRef.h:61
bool isMember() const
Definition SourceRef.h:49
SourceRefIndex(const llvm::DynamicAPInt &i)
Definition SourceRef.h:42
SourceRefIndex(const llvm::APInt &low, const llvm::APInt &high)
Definition SourceRef.h:45
SourceRefIndex(const llvm::APInt &i)
Definition SourceRef.h:43
void dump() const
Definition SourceRef.h:73
llvm::DynamicAPInt getIndex() const
Definition SourceRef.h:62
void print(mlir::raw_ostream &os) const
Definition SourceRef.cpp:34
IndexRange getIndexRange() const
Definition SourceRef.h:68
component::MemberDefOp getMember() const
Definition SourceRef.h:53
SourceRefIndex(SymbolLookupResult< component::MemberDefOp > f)
Definition SourceRef.h:41
SourceRefIndex(IndexRange r)
Definition SourceRef.h:47
SourceRefIndex(int64_t i)
Definition SourceRef.h:44
SourceRefIndex(component::MemberDefOp f)
Definition SourceRef.h:40
SourceRefSet & join(const SourceRefSet &rhs)
friend mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const SourceRefSet &rhs)
A reference to a "source", which is the base value from which other SSA values are derived.
Definition SourceRef.h:127
SourceRef(component::CreateStructOp createOp, std::vector< SourceRefIndex > f)
Definition SourceRef.h:149
bool isIntegerVal() const
Definition SourceRef.h:176
bool isBlockArgument() const
Definition SourceRef.h:182
llvm::DynamicAPInt getConstantIndexValue() const
Definition SourceRef.h:206
std::vector< SourceRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this SourceRef, assuming this ref is not a scalar.
SourceRef createChild(SourceRef other) const
Definition SourceRef.h:257
mlir::FailureOr< SourceRef > getParentPrefix() const
Create a new reference that is the immediate prefix of this reference if possible.
Definition SourceRef.h:238
void print(mlir::raw_ostream &os) const
mlir::FailureOr< std::vector< SourceRefIndex > > getSuffix(const SourceRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the prefix.
bool isScalar() const
Definition SourceRef.h:178
bool operator==(const SourceRef &rhs) const
bool isConstantFelt() const
Definition SourceRef.h:160
component::CreateStructOp getCreateStructOp() const
Definition SourceRef.h:194
SourceRef(felt::FeltConstantOp c)
Definition SourceRef.h:152
bool isValidPrefix(const SourceRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
const std::vector< SourceRefIndex > & getPieces() const
Definition SourceRef.h:262
bool isConstantIndex() const
Definition SourceRef.h:163
SourceRef createChild(SourceRefIndex r) const
Definition SourceRef.h:251
mlir::BlockArgument getBlockArgument() const
Definition SourceRef.h:185
llvm::DynamicAPInt getConstantValue() const
Definition SourceRef.h:212
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, SourceRef root)
Produce all possible SourceRefs that are present starting from the given root.
void dump() const
Definition SourceRef.h:265
SourceRef(mlir::arith::ConstantIndexOp c)
Definition SourceRef.h:153
bool isIndexVal() const
Definition SourceRef.h:175
SourceRef(mlir::BlockArgument b, std::vector< SourceRefIndex > f)
Definition SourceRef.h:144
bool operator<(const SourceRef &rhs) const
SourceRef(mlir::BlockArgument b)
Definition SourceRef.h:143
mlir::FailureOr< SourceRef > translate(const SourceRef &prefix, const SourceRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this reference.
SourceRef(polymorphic::ConstReadOp c)
Definition SourceRef.h:155
bool isTemplateConstant() const
Definition SourceRef.h:167
bool isTypeVarVal() const
Definition SourceRef.h:177
bool isConstant() const
Definition SourceRef.h:171
bool operator!=(const SourceRef &rhs) const
Definition SourceRef.h:269
SourceRef(component::CreateStructOp createOp)
Definition SourceRef.h:147
llvm::DynamicAPInt getConstantFeltValue() const
Definition SourceRef.h:199
bool operator>(const SourceRef &rhs) const
Definition SourceRef.h:274
bool isFeltVal() const
Definition SourceRef.h:174
unsigned getInputNum() const
Definition SourceRef.h:189
bool isConstantInt() const
Definition SourceRef.h:172
bool isCreateStructOp() const
Definition SourceRef.h:191
mlir::Type getType() const
void ensure(bool condition, const llvm::Twine &errMsg)
DynamicAPInt toDynamicAPInt(StringRef str)
Interval operator<<(const Interval &lhs, const Interval &rhs)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
static bool isEqual(const llzk::SourceRef &lhs, const llzk::SourceRef &rhs)
Definition SourceRef.h:336
static unsigned getHashValue(const llzk::SourceRef &ref)
Definition SourceRef.h:330
static llzk::SourceRef getTombstoneKey()
Definition SourceRef.h:327
static llzk::SourceRef getEmptyKey()
Definition SourceRef.h:324
size_t operator()(const SourceRefIndex &c) const
Definition SourceRef.cpp:72
size_t operator()(const SourceRef &val) const