LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolLookup.h
Go to the documentation of this file.
1//===-- SymbolLookup.h - Symbol Lookup Functions ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
16#pragma once
17
#include "llzk/Util/Constants.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OwningOpRef.h>
#include <mlir/IR/SymbolTable.h>

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/Casting.h>

#include <cassert>
#include <memory>
#include <utility>
#include <variant>
#include <vector>
29
30namespace llzk {
31
32template <typename T> class SymbolLookupResult;
33
35 std::shared_ptr<std::pair<mlir::OwningOpRef<mlir::ModuleOp>, mlir::SymbolTableCollection>>;
36
38public:
39 SymbolLookupResultUntyped() : op(nullptr) {}
40 SymbolLookupResultUntyped(mlir::Operation *opPtr) : op(opPtr) {}
41
43 : op(other.op), managedResources(other.managedResources),
44 includeSymNameStack(other.includeSymNameStack), namespaceStack(other.namespaceStack) {}
45 template <typename T> SymbolLookupResultUntyped(const SymbolLookupResult<T> &other);
46
48 if (this == &other) {
49 return *this;
50 }
51 this->op = other.op;
52 this->managedResources = other.managedResources;
53 this->includeSymNameStack = other.includeSymNameStack;
54 this->namespaceStack = other.namespaceStack;
55 return *this;
56 }
57 template <typename T> SymbolLookupResultUntyped &operator=(const SymbolLookupResult<T> &other);
58
60 : op(other.op), managedResources(std::move(other.managedResources)),
61 includeSymNameStack(std::move(other.includeSymNameStack)),
62 namespaceStack(std::move(other.namespaceStack)) {
63 other.op = nullptr;
64 }
65 template <typename T> SymbolLookupResultUntyped(SymbolLookupResult<T> &&other);
66
68 if (this != &other) {
69 this->op = other.op;
70 other.op = nullptr;
71 this->managedResources = std::move(other.managedResources);
72 this->includeSymNameStack = std::move(other.includeSymNameStack);
73 this->namespaceStack = std::move(other.namespaceStack);
74 }
75 return *this;
76 }
77 template <typename T> SymbolLookupResultUntyped &operator=(SymbolLookupResult<T> &&other);
78
80 mlir::Operation *operator->();
81 mlir::Operation &operator*();
82 mlir::Operation &operator*() const;
83 mlir::Operation *get();
84 mlir::Operation *get() const;
85
87 operator bool() const;
88
90 std::vector<llvm::StringRef> getIncludeSymNames() const { return includeSymNameStack; }
91
94 llvm::ArrayRef<llvm::StringRef> getNamespace() const { return namespaceStack; }
95
97 bool viaInclude() const { return !includeSymNameStack.empty(); }
98
99 mlir::SymbolTableCollection *getSymbolTableCache() {
100 if (managedResources) {
101 return &managedResources->second;
102 } else {
103 return nullptr;
104 }
105 }
106
108 bool isManaged() const { return managedResources != nullptr; }
109
111 void manage(mlir::OwningOpRef<mlir::ModuleOp> &&ptr, mlir::SymbolTableCollection &&tables);
112
114 void trackIncludeAsName(llvm::StringRef includeOpSymName);
115
117 void pushNamespace(llvm::StringRef symName);
118
120 void prependNamespace(llvm::ArrayRef<llvm::StringRef> ns);
121
122 bool operator==(const SymbolLookupResultUntyped &rhs) const { return op == rhs.op; }
123
124private:
125 mlir::Operation *op;
128 ManagedResources managedResources;
130 std::vector<llvm::StringRef> includeSymNameStack;
133 std::vector<llvm::StringRef> namespaceStack;
134
135 friend class Within;
136};
137
138template <typename T> class SymbolLookupResult {
139public:
140 SymbolLookupResult(SymbolLookupResultUntyped &&innerRes) : inner(std::move(innerRes)) {}
141
144 T operator->() { return llvm::dyn_cast<T>(*inner); }
145 T operator*() { return llvm::dyn_cast<T>(*inner); }
146 const T operator*() const { return llvm::dyn_cast<T>(*inner); }
147 T get() { return llvm::dyn_cast<T>(inner.get()); }
148 T get() const { return llvm::dyn_cast<T>(inner.get()); }
149
150 operator bool() const { return inner && llvm::isa<T>(*inner); }
151
153 std::vector<llvm::StringRef> getIncludeSymNames() const { return inner.getIncludeSymNames(); }
154
157 llvm::ArrayRef<llvm::StringRef> getNamespace() const { return inner.getNamespace(); }
158
160 void prependNamespace(llvm::ArrayRef<llvm::StringRef> ns) { inner.prependNamespace(ns); }
161
163 bool viaInclude() const { return inner.viaInclude(); }
164
165 bool operator==(const SymbolLookupResult<T> &rhs) const { return inner == rhs.inner; }
166
168 bool isManaged() const { return inner.isManaged(); }
169
170private:
172
173 friend class Within;
175};
176
177// These methods' definitions need to be here, after the declaration of SymbolLookupResult<T>
178
179template <typename T>
182
183template <typename T>
186 *this = other.inner;
187 return *this;
188}
189
190template <typename T>
193
194template <typename T>
196 *this = std::move(other.inner);
197 return *this;
198}
199
200class Within {
201public:
203 Within() : from(nullptr) {}
205 Within(mlir::Operation *op) : from(op) { assert(op && "cannot lookup within nullptr"); }
207 Within(SymbolLookupResultUntyped &&res) : from(std::move(res)) {}
209 template <typename T> Within(SymbolLookupResult<T> &&res) : Within(std::move(res.inner)) {}
210
211 Within(const Within &) = delete;
212 Within(Within &&other) noexcept : from(std::move(other.from)) {}
213 Within &operator=(const Within &) = delete;
214 Within &operator=(Within &&) noexcept;
215
216 inline static Within root() { return Within(); }
217
218 mlir::FailureOr<SymbolLookupResultUntyped> lookup(
219 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
220 bool reportMissing = true
221 ) &&;
222
223private:
224 std::variant<mlir::Operation *, SymbolLookupResultUntyped> from;
225};
226
227inline mlir::FailureOr<SymbolLookupResultUntyped> lookupSymbolIn(
228 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
229 mlir::Operation *origin, bool reportMissing = true
230) {
231 return std::move(lookupWithin).lookup(tables, symbol, origin, reportMissing);
232}
233
234inline mlir::FailureOr<SymbolLookupResultUntyped> lookupTopLevelSymbol(
235 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
236 bool reportMissing = true
237) {
238 return Within().lookup(tables, symbol, origin, reportMissing);
239}
240
241template <typename T>
242inline mlir::FailureOr<SymbolLookupResult<T>> lookupSymbolIn(
243 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
244 mlir::Operation *origin, bool reportMissing = true
245) {
246 auto found = lookupSymbolIn(tables, symbol, std::move(lookupWithin), origin, reportMissing);
247 if (mlir::failed(found)) {
248 return mlir::failure(); // lookupSymbolIn() already emits a sufficient error message
249 }
250 // Keep a copy of the op ptr in case we need it for displaying diagnostics
251 mlir::Operation *op = found->get();
252 // ... since the untyped result gets moved here into a typed result.
253 SymbolLookupResult<T> ret(std::move(*found));
254 if (!ret) {
255 if (reportMissing) {
256 return origin->emitError() << "symbol \"" << symbol << "\" references a '" << op->getName()
257 << "' but expected a '" << T::getOperationName() << '\'';
258 } else {
259 return mlir::failure();
260 }
261 }
262 return ret;
263}
264
265template <typename T>
266inline mlir::FailureOr<SymbolLookupResult<T>> lookupTopLevelSymbol(
267 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
268 bool reportMissing = true
269) {
270 return lookupSymbolIn<T>(tables, symbol, Within(), origin, reportMissing);
271}
272
273} // namespace llzk
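Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from …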
Definition LICENSE.txt:45
mlir::SymbolTableCollection * getSymbolTableCache()
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
void manage(mlir::OwningOpRef< mlir::ModuleOp > &&ptr, mlir::SymbolTableCollection &&tables)
Adds a pointer to the set of resources the result has to manage the lifetime of.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
SymbolLookupResultUntyped(SymbolLookupResultUntyped &&other) noexcept
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
void trackIncludeAsName(llvm::StringRef includeOpSymName)
Adds the symbol name from the IncludeOp that caused the module to be loaded.
SymbolLookupResultUntyped(const SymbolLookupResultUntyped &other)
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this result.
bool operator==(const SymbolLookupResultUntyped &rhs) const
bool isManaged() const
True iff the symbol is managed (i.e., loaded via an IncludeOp).
SymbolLookupResultUntyped(mlir::Operation *opPtr)
SymbolLookupResultUntyped & operator=(const SymbolLookupResultUntyped &other)
void pushNamespace(llvm::StringRef symName)
Adds the symbol name from an IncludeOp or ModuleOp where the op is contained.
mlir::Operation * operator->()
Access the internal operation.
SymbolLookupResultUntyped & operator=(SymbolLookupResultUntyped &&other) noexcept
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this result.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
const T operator*() const
SymbolLookupResult(SymbolLookupResultUntyped &&innerRes)
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
friend class SymbolLookupResultUntyped
bool isManaged() const
Return 'true' if the inner resource is managed (i.e., loaded via an IncludeOp).
bool operator==(const SymbolLookupResult< T > &rhs) const
T operator->()
Access the internal operation as type T.
mlir::FailureOr< SymbolLookupResultUntyped > lookup(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true) &&
Within(const Within &)=delete
Within()
Lookup within the top-level (root) module.
Within(SymbolLookupResultUntyped &&res)
Lookup within the Operation of the given result and transfer managed resources.
Within(Within &&other) noexcept
Within & operator=(const Within &)=delete
Within(mlir::Operation *op)
Lookup within the given Operation (cannot be nullptr)
static Within root()
Within(SymbolLookupResult< T > &&res)
Lookup within the Operation of the given result and transfer managed resources.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::shared_ptr< std::pair< mlir::OwningOpRef< mlir::ModuleOp >, mlir::SymbolTableCollection > > ManagedResources
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)