LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolLookup.cpp
Go to the documentation of this file.
1//===-- SymbolLookup.cpp - LLZK Symbol lookup helpers -----------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
18
19#include <mlir/IR/BuiltinOps.h>
20#include <mlir/IR/Operation.h>
21#include <mlir/IR/OwningOpRef.h>
22
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/StringRef.h>
25#include <llvm/Support/Debug.h>
26
27#define DEBUG_TYPE "llzk-symbol-lookup"
28
29namespace llzk {
30using namespace mlir;
31using namespace include;
32
33namespace {
34SymbolLookupResultUntyped
35lookupSymbolRec(SymbolTableCollection &tables, SymbolRefAttr symbol, Operation *symTableOp) {
36 // First try a direct lookup via the SymbolTableCollection. Must use a low-level lookup function
37 // in order to properly account for modules that were added due to inlining IncludeOp.
38 {
39 SmallVector<Operation *, 4> symbolsFound;
40 if (succeeded(tables.lookupSymbolIn(symTableOp, symbol, symbolsFound))) {
41 SymbolLookupResultUntyped ret(symbolsFound.back());
42 for (auto it = symbolsFound.rbegin(); it != symbolsFound.rend(); ++it) {
43 Operation *op = *it;
44 if (op->hasAttr(LANG_ATTR_NAME)) {
45 auto symName = op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
46 ret.pushNamespace(symName);
47 if (!llvm::isa<ModuleOp>(op)) {
48 LLVM_DEBUG({ llvm::dbgs() << "[lookupSymbolRec] tracking op as include\n"; });
49
50 ret.trackIncludeAsName(symName);
51 }
52 }
53 }
54 return ret;
55 }
56 }
57 // Otherwise, check if the reference can be found by manually doing a lookup for each part of
58 // the reference in turn, traversing through IncludeOp symbols by parsing the included file.
59 if (Operation *rootOp = tables.lookupSymbolIn(symTableOp, symbol.getRootReference())) {
60 if (IncludeOp rootOpInc = llvm::dyn_cast<IncludeOp>(rootOp)) {
61 FailureOr<OwningOpRef<ModuleOp>> otherMod = rootOpInc.openModule();
62 if (succeeded(otherMod)) {
63 // Create a temporary SymbolTableCollection for caching the external symbols from the
64 // included module rather than adding these symbols to the existing SymbolTableCollection
65 // because it has no means of removing entries from its internal map and it is not safe to
66 // leave the dangling pointers in that map after the external module has been freed.
67 SymbolTableCollection external;
68 auto result = lookupSymbolRec(external, getTailAsSymbolRefAttr(symbol), otherMod->get());
69 if (result) {
70 result.manage(std::move(*otherMod), std::move(external));
71 auto symName = rootOpInc.getSymName();
72 result.pushNamespace(symName);
73 result.trackIncludeAsName(symName);
74 }
75 return result;
76 }
77 } else if (ModuleOp rootOpMod = llvm::dyn_cast<ModuleOp>(rootOp)) {
78 return lookupSymbolRec(tables, getTailAsSymbolRefAttr(symbol), rootOpMod);
79 }
80 }
81 // Otherwise, return empty result
82 return SymbolLookupResultUntyped();
83}
84} // namespace
85
86//===------------------------------------------------------------------===//
87// SymbolLookupResultUntyped
88//===------------------------------------------------------------------===//
89
91Operation *SymbolLookupResultUntyped::operator->() { return op; }
92Operation &SymbolLookupResultUntyped::operator*() { return *op; }
93Operation &SymbolLookupResultUntyped::operator*() const { return *op; }
94Operation *SymbolLookupResultUntyped::get() { return op; }
95Operation *SymbolLookupResultUntyped::get() const { return op; }
96
98SymbolLookupResultUntyped::operator bool() const { return op != nullptr; }
99
102 OwningOpRef<ModuleOp> &&ptr, SymbolTableCollection &&tables
103) {
104 // This may be called multiple times for the same result Operation but we only need to store the
105 // resources from the first call because that call will contain the final ModuleOp loaded in a
106 // chain of IncludeOp and that is the one which contains the result Operation*.
107 if (!managedResources) {
108 managedResources = std::make_shared<std::pair<OwningOpRef<ModuleOp>, SymbolTableCollection>>(
109 std::make_pair(std::move(ptr), std::move(tables))
110 );
111 }
112}
113
115void SymbolLookupResultUntyped::trackIncludeAsName(llvm::StringRef includeOpSymName) {
116 includeSymNameStack.push_back(includeOpSymName);
117}
118
119void SymbolLookupResultUntyped::pushNamespace(llvm::StringRef symName) {
120 namespaceStack.push_back(symName);
121}
122
123void SymbolLookupResultUntyped::prependNamespace(llvm::ArrayRef<llvm::StringRef> ns) {
124 std::vector<llvm::StringRef> newNamespace = ns;
125 newNamespace.insert(newNamespace.end(), namespaceStack.begin(), namespaceStack.end());
126 namespaceStack = newNamespace;
127}
128
129//===------------------------------------------------------------------===//
130// Within
131//===------------------------------------------------------------------===//
132
133Within &Within::operator=(Within &&other) noexcept {
134 if (this != &other) {
135 from = std::move(other.from);
136 }
137 return *this;
138}
139
140FailureOr<SymbolLookupResultUntyped> Within::lookup(
141 SymbolTableCollection &tables, SymbolRefAttr symbol, Operation *origin, bool reportMissing
142) && {
143 if (SymbolLookupResultUntyped *priorRes = std::get_if<SymbolLookupResultUntyped>(&this->from)) {
144 //---- Lookup within an existing result ----//
145 // Use the symbol table from prior result if available, otherwise use the parameter.
146 SymbolTableCollection *cachedTablesForRes = priorRes->getSymbolTableCache();
147 if (!cachedTablesForRes) {
148 cachedTablesForRes = &tables;
149 }
150 if (auto found = lookupSymbolRec(*cachedTablesForRes, symbol, priorRes->op)) {
151 assert(!found.managedResources && "should not have loaded additional modules");
152 // TODO: not quite sure following is true. If not, the result should contain
153 // `priorRes.includeSymNameStack` followed by `found.includeSymNameStack`.
154 assert(found.includeSymNameStack.empty() && "should not have loaded additional modules");
155 // Move stuff from 'priorRes' to the new result
156 found.managedResources = std::move(priorRes->managedResources);
157 found.includeSymNameStack = std::move(priorRes->includeSymNameStack);
158 return found;
159 }
160 } else {
161 //---- Lookup from a given operation or root (if nullptr) ----//
162 Operation *lookupFrom = std::get<Operation *>(this->from);
163 if (!lookupFrom) {
164 FailureOr<ModuleOp> root = getRootModule(origin);
165 if (failed(root)) {
166 return failure(); // getRootModule() already emits a sufficient error message
167 }
168 lookupFrom = root.value();
169 }
170 if (auto found = lookupSymbolRec(tables, symbol, lookupFrom)) {
171 return found;
172 }
173 }
174 // Handle the case where it was not found
175 if (reportMissing) {
176 return origin->emitOpError() << "references unknown symbol \"" << symbol << '"';
177 } else {
178 return failure();
179 }
180}
181
182} // namespace llzk
This file defines methods for symbol lookup across LLZK operations and included files.
void manage(mlir::OwningOpRef< mlir::ModuleOp > &&ptr, mlir::SymbolTableCollection &&tables)
Adds a pointer to the set of resources whose lifetime the result must manage.
void prependNamespace(llvm::ArrayRef< llvm::StringRef > ns)
Adds the given namespace to the beginning of this result's namespace.
void trackIncludeAsName(llvm::StringRef includeOpSymName)
Adds the symbol name from the IncludeOp that caused the module to be loaded.
void pushNamespace(llvm::StringRef symName)
Adds the symbol name from an IncludeOp or ModuleOp where the op is contained.
mlir::Operation * operator->()
Access the internal operation.
mlir::FailureOr< SymbolLookupResultUntyped > lookup(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true) &&
Within()
Lookup within the top-level (root) module.
Within & operator=(const Within &)=delete
static Within root()
constexpr char LANG_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that identifies the ModuleOp as the root module and s...
Definition Constants.h:23
FailureOr< ModuleOp > getRootModule(Operation *from)
mlir::SymbolRefAttr getTailAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the root/head element removed.