LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolLookup.cpp
Go to the documentation of this file.
1//===-- SymbolLookup.cpp - LLZK Symbol lookup helpers -----------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16
19
20#include <mlir/IR/BuiltinOps.h>
21#include <mlir/IR/Operation.h>
22#include <mlir/IR/OwningOpRef.h>
23
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/StringRef.h>
26#include <llvm/Support/Debug.h>
27
28#define DEBUG_TYPE "llzk-symbol-lookup"
29
30namespace llzk {
31using namespace mlir;
32using namespace include;
33
34namespace {
35SymbolLookupResultUntyped
36lookupSymbolRec(SymbolTableCollection &tables, SymbolRefAttr symbol, Operation *symTableOp) {
37 // First try a direct lookup via the SymbolTableCollection. Must use a low-level lookup function
38 // in order to properly account for modules that were added due to inlining IncludeOp.
39 {
40 SmallVector<Operation *, 4> symbolsFound;
41 if (succeeded(tables.lookupSymbolIn(symTableOp, symbol, symbolsFound))) {
42 SymbolLookupResultUntyped ret(symbolsFound.back());
43 for (auto it = symbolsFound.rbegin(); it != symbolsFound.rend(); ++it) {
44 Operation *op = *it;
45 if (op->hasAttr(LANG_ATTR_NAME)) {
46 auto symName = op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
47 ret.pushNamespace(symName);
48 if (!llvm::isa<ModuleOp>(op)) {
49 LLVM_DEBUG({ llvm::dbgs() << "[lookupSymbolRec] tracking op as include\n"; });
50
51 ret.trackIncludeAsName(symName);
52 }
53 }
54 }
55 return ret;
56 }
57 }
58 // Otherwise, check if the reference can be found by manually doing a lookup for each part of
59 // the reference in turn, traversing through IncludeOp symbols by parsing the included file.
60 if (Operation *rootOp = tables.lookupSymbolIn(symTableOp, symbol.getRootReference())) {
61 if (IncludeOp rootOpInc = llvm::dyn_cast<IncludeOp>(rootOp)) {
62 FailureOr<OwningOpRef<ModuleOp>> otherMod = rootOpInc.openModule();
63 if (succeeded(otherMod)) {
64 // Create a temporary SymbolTableCollection for caching the external symbols from the
65 // included module rather than adding these symbols to the existing SymbolTableCollection
66 // because it has no means of removing entries from its internal map and it is not safe to
67 // leave the dangling pointers in that map after the external module has been freed.
68 SymbolTableCollection external;
69 auto result = lookupSymbolRec(external, getTailAsSymbolRefAttr(symbol), otherMod->get());
70 if (result) {
71 result.manage(std::move(*otherMod), std::move(external));
72 auto symName = rootOpInc.getSymName();
73 result.pushNamespace(symName);
74 result.trackIncludeAsName(symName);
75 }
76 return result;
77 }
78 } else if (ModuleOp rootOpMod = llvm::dyn_cast<ModuleOp>(rootOp)) {
79 return lookupSymbolRec(tables, getTailAsSymbolRefAttr(symbol), rootOpMod);
80 }
81 }
82 // Otherwise, return empty result
83 return SymbolLookupResultUntyped();
84}
85} // namespace
86
87//===------------------------------------------------------------------===//
88// SymbolLookupResultUntyped
89//===------------------------------------------------------------------===//
90
92Operation *SymbolLookupResultUntyped::operator->() { return op; }
93Operation &SymbolLookupResultUntyped::operator*() { return *op; }
94Operation &SymbolLookupResultUntyped::operator*() const { return *op; }
95Operation *SymbolLookupResultUntyped::get() { return op; }
96Operation *SymbolLookupResultUntyped::get() const { return op; }
97
99SymbolLookupResultUntyped::operator bool() const { return op != nullptr; }
100
103 OwningOpRef<ModuleOp> &&ptr, SymbolTableCollection &&tables
104) {
105 // This may be called multiple times for the same result Operation but we only need to store the
106 // resources from the first call because that call will contain the final ModuleOp loaded in a
107 // chain of IncludeOp and that is the one which contains the result Operation*.
108 if (!managedResources) {
109 managedResources = std::make_shared<std::pair<OwningOpRef<ModuleOp>, SymbolTableCollection>>(
110 std::make_pair(std::move(ptr), std::move(tables))
111 );
112 }
113}
114
116void SymbolLookupResultUntyped::trackIncludeAsName(llvm::StringRef includeOpSymName) {
117 includeSymNameStack.push_back(includeOpSymName);
118}
119
120void SymbolLookupResultUntyped::pushNamespace(llvm::StringRef symName) {
121 namespaceStack.push_back(symName);
122}
123
124void SymbolLookupResultUntyped::prependNamespace(llvm::ArrayRef<llvm::StringRef> ns) {
125 std::vector<llvm::StringRef> newNamespace = ns;
126 newNamespace.insert(newNamespace.end(), namespaceStack.begin(), namespaceStack.end());
127 namespaceStack = newNamespace;
128}
129
130//===------------------------------------------------------------------===//
131// Within
132//===------------------------------------------------------------------===//
133
134Within &Within::operator=(Within &&other) noexcept {
135 if (this != &other) {
136 from = std::move(other.from);
137 }
138 return *this;
139}
140
141FailureOr<SymbolLookupResultUntyped> Within::lookup(
142 SymbolTableCollection &tables, SymbolRefAttr symbol, Operation *origin, bool reportMissing
143) && {
144 if (SymbolLookupResultUntyped *priorRes = std::get_if<SymbolLookupResultUntyped>(&this->from)) {
145 //---- Lookup within an existing result ----//
146 // Use the symbol table from prior result if available, otherwise use the parameter.
147 SymbolTableCollection *cachedTablesForRes = priorRes->getSymbolTableCache();
148 if (!cachedTablesForRes) {
149 cachedTablesForRes = &tables;
150 }
151 if (auto found = lookupSymbolRec(*cachedTablesForRes, symbol, priorRes->op)) {
152 assert(!found.managedResources && "should not have loaded additional modules");
153 // TODO: not quite sure following is true. If not, the result should contain
154 // `priorRes.includeSymNameStack` followed by `found.includeSymNameStack`.
155 assert(found.includeSymNameStack.empty() && "should not have loaded additional modules");
156 // Move stuff from 'priorRes' to the new result
157 found.managedResources = std::move(priorRes->managedResources);
158 found.includeSymNameStack = std::move(priorRes->includeSymNameStack);
159 return found;
160 }
161 } else {
162 //---- Lookup from a given operation or root (if nullptr) ----//
163 Operation *lookupFrom = std::get<Operation *>(this->from);
164 if (!lookupFrom) {
165 FailureOr<ModuleOp> root = getRootModule(origin);
166 if (failed(root)) {
167 return failure(); // getRootModule() already emits a sufficient error message
168 }
169 lookupFrom = root.value();
170 }
171 if (auto found = lookupSymbolRec(tables, symbol, lookupFrom)) {
172 return found;
173 }
174 }
175 // Handle the case where it was not found
176 if (reportMissing) {
177 return origin->emitOpError() << "references unknown symbol \"" << symbol << '"';
178 } else {
179 return failure();
180 }
181}
182
183} // namespace llzk
This file defines methods for symbol lookup across LLZK operations and included files.
void manage(mlir::OwningOpRef< mlir::ModuleOp > &&ptr, mlir::SymbolTableCollection &&tables)
Adds a pointer to the set of resources the result has to manage the lifetime of.
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
void trackIncludeAsName(llvm::StringRef includeOpSymName)
Adds the symbol name from the IncludeOp that caused the module to be loaded.
void pushNamespace(llvm::StringRef symName)
Adds the symbol name from an IncludeOp or ModuleOp where the op is contained.
mlir::Operation * operator->()
Access the internal operation.
mlir::FailureOr< SymbolLookupResultUntyped > lookup(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true) &&
Within()
Lookup within the top-level (root) module.
Within & operator=(const Within &)=delete
static Within root()
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that identifies the ModuleOp as the root module and s...
Definition Constants.h:23
FailureOr< ModuleOp > getRootModule(Operation *from)
mlir::SymbolRefAttr getTailAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the root/head element removed.