LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolLookup.h
Go to the documentation of this file.
1//===-- SymbolLookup.h - Symbol Lookup Functions ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
16#pragma once
17
18#include "llzk/Util/Constants.h"
19
20#include <mlir/IR/BuiltinOps.h>
21#include <mlir/IR/Operation.h>
22#include <mlir/IR/OwningOpRef.h>
23
24#include <llvm/ADT/ArrayRef.h>
25#include <llvm/ADT/StringRef.h>
26
27#include <variant>
28#include <vector>
29
30namespace llzk {
31
32template <typename T> class SymbolLookupResult;
33
35 std::shared_ptr<std::pair<mlir::OwningOpRef<mlir::ModuleOp>, mlir::SymbolTableCollection>>;
36
38public:
39 SymbolLookupResultUntyped() : op(nullptr) {}
40 SymbolLookupResultUntyped(mlir::Operation *opPtr) : op(opPtr) {}
41
43 : op(other.op), managedResources(other.managedResources),
44 includeSymNameStack(other.includeSymNameStack), namespaceStack(other.namespaceStack) {}
45 template <typename T> SymbolLookupResultUntyped(const SymbolLookupResult<T> &other);
46
48 this->op = other.op;
49 this->managedResources = other.managedResources;
50 this->includeSymNameStack = other.includeSymNameStack;
51 this->namespaceStack = other.namespaceStack;
52 return *this;
53 }
54 template <typename T> SymbolLookupResultUntyped &operator=(const SymbolLookupResult<T> &other);
55
57 : op(other.op), managedResources(std::move(other.managedResources)),
58 includeSymNameStack(std::move(other.includeSymNameStack)),
59 namespaceStack(std::move(other.namespaceStack)) {
60 other.op = nullptr;
61 }
62 template <typename T> SymbolLookupResultUntyped(SymbolLookupResult<T> &&other);
63
65 if (this != &other) {
66 this->op = other.op;
67 other.op = nullptr;
68 this->managedResources = std::move(other.managedResources);
69 this->includeSymNameStack = std::move(other.includeSymNameStack);
70 this->namespaceStack = std::move(other.namespaceStack);
71 }
72 return *this;
73 }
74 template <typename T> SymbolLookupResultUntyped &operator=(SymbolLookupResult<T> &&other);
75
77 mlir::Operation *operator->();
78 mlir::Operation &operator*();
79 mlir::Operation &operator*() const;
80 mlir::Operation *get();
81 mlir::Operation *get() const;
82
84 operator bool() const;
85
87 std::vector<llvm::StringRef> getIncludeSymNames() const { return includeSymNameStack; }
88
91 llvm::ArrayRef<llvm::StringRef> getNamespace() const { return namespaceStack; }
92
94 bool viaInclude() const { return !includeSymNameStack.empty(); }
95
96 mlir::SymbolTableCollection *getSymbolTableCache() {
97 if (managedResources) {
98 return &managedResources->second;
99 } else {
100 return nullptr;
101 }
102 }
103
105 void manage(mlir::OwningOpRef<mlir::ModuleOp> &&ptr, mlir::SymbolTableCollection &&tables);
106
108 void trackIncludeAsName(llvm::StringRef includeOpSymName);
109
111 void pushNamespace(llvm::StringRef symName);
112
114 void prependNamespace(llvm::ArrayRef<llvm::StringRef> ns);
115
116 bool operator==(const SymbolLookupResultUntyped &rhs) const { return op == rhs.op; }
117
118private:
119 mlir::Operation *op;
122 ManagedResources managedResources;
124 std::vector<llvm::StringRef> includeSymNameStack;
127 std::vector<llvm::StringRef> namespaceStack;
128
129 friend class Within;
130};
131
132template <typename T> class SymbolLookupResult {
133public:
134 SymbolLookupResult(SymbolLookupResultUntyped &&innerRes) : inner(std::move(innerRes)) {}
135
138 T operator->() { return llvm::dyn_cast<T>(*inner); }
139 T operator*() { return llvm::dyn_cast<T>(*inner); }
140 const T operator*() const { return llvm::dyn_cast<T>(*inner); }
141 T get() { return llvm::dyn_cast<T>(inner.get()); }
142 T get() const { return llvm::dyn_cast<T>(inner.get()); }
143
144 operator bool() const { return inner && llvm::isa<T>(*inner); }
145
147 std::vector<llvm::StringRef> getIncludeSymNames() const { return inner.getIncludeSymNames(); }
148
151 llvm::ArrayRef<llvm::StringRef> getNamespace() const { return inner.getNamespace(); }
152
154 void prependNamespace(llvm::ArrayRef<llvm::StringRef> ns) { inner.prependNamespace(ns); }
155
157 bool viaInclude() const { return inner.viaInclude(); }
158
159 bool operator==(const SymbolLookupResult<T> &rhs) const { return inner == rhs.inner; }
160
161private:
163
164 friend class Within;
166};
167
168// These methods' definitions need to be here, after the declaration of SymbolLookupResult<T>
169
170template <typename T>
173
174template <typename T>
177 *this = other.inner;
178 return *this;
179}
180
181template <typename T>
184
185template <typename T>
187 *this = std::move(other.inner);
188 return *this;
189}
190
191class Within {
192public:
194 Within() : from(nullptr) {}
196 Within(mlir::Operation *op) : from(op) { assert(op && "cannot lookup within nullptr"); }
198 Within(SymbolLookupResultUntyped &&res) : from(std::move(res)) {}
200 template <typename T> Within(SymbolLookupResult<T> &&res) : Within(std::move(res.inner)) {}
201
202 Within(const Within &) = delete;
203 Within(Within &&other) noexcept : from(std::move(other.from)) {}
204 Within &operator=(const Within &) = delete;
205 Within &operator=(Within &&) noexcept;
206
207 inline static Within root() { return Within(); }
208
209 mlir::FailureOr<SymbolLookupResultUntyped> lookup(
210 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
211 bool reportMissing = true
212 ) &&;
213
214private:
215 std::variant<mlir::Operation *, SymbolLookupResultUntyped> from;
216};
217
218inline mlir::FailureOr<SymbolLookupResultUntyped> lookupSymbolIn(
219 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
220 mlir::Operation *origin, bool reportMissing = true
221) {
222 return std::move(lookupWithin).lookup(tables, symbol, origin, reportMissing);
223}
224
225inline mlir::FailureOr<SymbolLookupResultUntyped> lookupTopLevelSymbol(
226 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
227 bool reportMissing = true
228) {
229 return Within().lookup(tables, symbol, origin, reportMissing);
230}
231
232template <typename T>
233inline mlir::FailureOr<SymbolLookupResult<T>> lookupSymbolIn(
234 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin,
235 mlir::Operation *origin, bool reportMissing = true
236) {
237 auto found = lookupSymbolIn(tables, symbol, std::move(lookupWithin), origin, reportMissing);
238 if (mlir::failed(found)) {
239 return mlir::failure(); // lookupSymbolIn() already emits a sufficient error message
240 }
241 // Keep a copy of the op ptr in case we need it for displaying diagnostics
242 mlir::Operation *op = found->get();
243 // ... since the untyped result gets moved here into a typed result.
244 SymbolLookupResult<T> ret(std::move(*found));
245 if (!ret) {
246 if (reportMissing) {
247 return origin->emitError() << "symbol \"" << symbol << "\" references a '" << op->getName()
248 << "' but expected a '" << T::getOperationName() << '\'';
249 } else {
250 return mlir::failure();
251 }
252 }
253 return ret;
254}
255
256template <typename T>
257inline mlir::FailureOr<SymbolLookupResult<T>> lookupTopLevelSymbol(
258 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin,
259 bool reportMissing = true
260) {
261 return lookupSymbolIn<T>(tables, symbol, Within(), origin, reportMissing);
262}
263
264} // namespace llzk
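Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from…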
Definition LICENSE.txt:45
mlir::SymbolTableCollection * getSymbolTableCache()
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
void manage(mlir::OwningOpRef< mlir::ModuleOp > &&ptr, mlir::SymbolTableCollection &&tables)
Adds a pointer to the set of resources the result has to manage the lifetime of.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
void trackIncludeAsName(llvm::StringRef includeOpSymName)
Adds the symbol name from the IncludeOp that caused the module to be loaded.
SymbolLookupResultUntyped(const SymbolLookupResultUntyped &other)
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this result.
bool operator==(const SymbolLookupResultUntyped &rhs) const
SymbolLookupResultUntyped(mlir::Operation *opPtr)
SymbolLookupResultUntyped & operator=(const SymbolLookupResultUntyped &other)
SymbolLookupResultUntyped & operator=(SymbolLookupResultUntyped &&other)
void pushNamespace(llvm::StringRef symName)
Adds the symbol name from an IncludeOp or ModuleOp where the op is contained.
mlir::Operation * operator->()
Access the internal operation.
SymbolLookupResultUntyped(SymbolLookupResultUntyped &&other)
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this result.
std::vector< llvm::StringRef > getIncludeSymNames() const
Return the stack of symbol names from the IncludeOp that were traversed to load this result.
const T operator*() const
SymbolLookupResult(SymbolLookupResultUntyped &&innerRes)
bool viaInclude() const
Return 'true' if at least one IncludeOp was traversed to load this result.
friend class SymbolLookupResultUntyped
bool operator==(const SymbolLookupResult< T > &rhs) const
T operator->()
Access the internal operation as type T.
mlir::FailureOr< SymbolLookupResultUntyped > lookup(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true) &&
Within(const Within &)=delete
Within()
Look up within the top-level (root) module.
Within(SymbolLookupResultUntyped &&res)
Look up within the Operation of the given result and transfer managed resources.
Within(Within &&other) noexcept
Within & operator=(const Within &)=delete
Within(mlir::Operation *op)
Look up within the given Operation (cannot be nullptr).
static Within root()
Within(SymbolLookupResult< T > &&res)
Look up within the Operation of the given result and transfer managed resources.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::shared_ptr< std::pair< mlir::OwningOpRef< mlir::ModuleOp >, mlir::SymbolTableCollection > > ManagedResources
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)