LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolUseGraph.cpp
Go to the documentation of this file.
1//===-- SymbolUseGraph.cpp --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
15#include "llzk/Util/Compare.h"
16#include "llzk/Util/Constants.h"
20
21#include <mlir/IR/BuiltinOps.h>
22
23#include <llvm/ADT/SmallPtrSet.h>
24#include <llvm/ADT/SmallSet.h>
25#include <llvm/Support/GraphWriter.h>
26
27using namespace mlir;
28
29namespace llzk {
30
31//===----------------------------------------------------------------------===//
32// SymbolUseGraphNode
33//===----------------------------------------------------------------------===//
34
35void SymbolUseGraphNode::addSuccessor(SymbolUseGraphNode *node) {
36 if (this->successors.insert(node)) {
37 node->predecessors.insert(this);
38 }
39}
40
41void SymbolUseGraphNode::removeSuccessor(SymbolUseGraphNode *node) {
42 if (this->successors.remove(node)) {
43 node->predecessors.remove(this);
44 }
45}
46
47FailureOr<SymbolLookupResultUntyped>
48SymbolUseGraphNode::lookupSymbol(SymbolTableCollection &tables, bool reportMissing) const {
49 if (!isRealNode()) {
50 return failure();
51 }
52 Operation *lookupFrom = getSymbolPathRoot().getOperation();
53 auto res = lookupSymbolIn(tables, getSymbolPath(), lookupFrom, lookupFrom, reportMissing);
54 if (succeeded(res) || !reportMissing) {
55 return res;
56 }
57 // This is likely an error in the use graph and not a case that should ever happen.
58 return lookupFrom->emitError().append(
59 "Could not find symbol referenced in UseGraph: ", getSymbolPath()
60 );
61}
62
63//===----------------------------------------------------------------------===//
64// SymbolUseGraph
65//===----------------------------------------------------------------------===//
66
67namespace {
68
69template <typename R>
70R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
71 assert(defOp); // pre-condition
72
73 ModuleOp foundRoot;
74 FailureOr<SymbolRefAttr> path = llzk::getPathFromRoot(defOp, &foundRoot);
75 if (failed(path)) {
76 // This occurs if there is no root module with LANG_ATTR_NAME attribute
77 // or there is an unnamed module between the root module and the symbol.
78 auto diag = defOp.emitError("in SymbolUseGraph, failed to build symbol path");
79 diag.attachNote(defOp.getLoc()).append("for this SymbolOp");
80 diag.report();
81 return nullptr;
82 }
83 return callback(foundRoot, path.value());
84}
85
86} // namespace
87
88SymbolUseGraph::SymbolUseGraph(SymbolOpInterface rootSymbolOp) {
89 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
90 buildGraph(rootSymbolOp);
91}
92
94SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(const SymbolTable::SymbolUse &u) {
95 SymbolOpInterface userSymbol = getSelfOrParentOfType<SymbolOpInterface>(u.getUser());
96 return getPathAndCall<SymbolUseGraphNode *>(
97 userSymbol, [this, &userSymbol](ModuleOp r, SymbolRefAttr p) {
98 auto *n = this->getOrAddNode(r, p, nullptr);
99 n->opsThatUseTheSymbol.insert(userSymbol);
100 return n;
101 }
102 );
103}
104
105void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
106 auto walkFn = [this](Operation *op, bool) {
107 assert(op->hasTrait<OpTrait::SymbolTable>());
108 FailureOr<ModuleOp> opRootModule = llzk::getRootModule(op);
109 if (failed(opRootModule)) {
110 return;
111 }
112
113 SymbolTableCollection tables;
114 if (auto usesOpt = llzk::getSymbolUses(&op->getRegion(0))) {
115 // Create child node for each Symbol use, as successor of the user Symbol op.
116 for (SymbolTable::SymbolUse u : usesOpt.value()) {
117 bool isTemplateSymbol = false;
118 Operation *user = u.getUser();
119 SymbolRefAttr symRef = u.getSymbolRef();
120 // Pending [LLZK-272] only a heuristic approach is possible. Check for FlatSymbolRefAttr
121 // where the user is a MemberRefOpInterface or the user is located within a TemplateOp and
122 // append the TemplateOp path with the FlatSymbolRefAttr.
123 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
124 if (auto fref = llvm::dyn_cast<component::MemberRefOpInterface>(user);
125 fref && fref.getMemberNameAttr() == flatSymRef) {
126 symRef = llzk::appendLeaf(fref.getStructType().getNameRef(), flatSymRef);
127 } else if (auto userTemplate = getSelfOrParentOfType<polymorphic::TemplateOp>(user)) {
128 StringAttr localName = flatSymRef.getAttr();
129 isTemplateSymbol =
130 userTemplate.hasConstNamed<polymorphic::TemplateSymbolBindingOpInterface>(
131 localName
132 );
133 if (isTemplateSymbol || tables.getSymbolTable(userTemplate).lookup(localName)) {
134 // If 'flatSymRef' is defined in the SymbolTable for 'userTemplate' then it's
135 // a local symbol so prepend the full path of the template itself.
136 auto parentPath = llzk::getPathFromRoot(userTemplate);
137 assert(succeeded(parentPath));
138 symRef = llzk::appendLeaf(parentPath.value(), flatSymRef);
139 }
140 }
141 }
142 auto *node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
143 node->isTemplateSymBinding = isTemplateSymbol;
144 node->opsThatUseTheSymbol.insert(user);
145 }
146 }
147 };
148 SymbolTable::walkSymbolTables(symbolOp.getOperation(), true, walkFn);
149
150 // Find all nodes with no successors and add the tail node as successor.
151 for (SymbolUseGraphNode *n : nodesIter()) {
152 if (!n->hasSuccessor()) {
153 n->addSuccessor(&tail);
154 }
155 }
156}
157
158SymbolUseGraphNode *SymbolUseGraph::getOrAddNode(
159 ModuleOp pathRoot, SymbolRefAttr path, SymbolUseGraphNode *predecessorNode
160) {
161 NodeMapKeyT key = std::make_pair(pathRoot, path);
162 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
163 if (!nodeRef) {
164 nodeRef.reset(new SymbolUseGraphNode(pathRoot, path));
165 // When creating a new node, ensure it's attached to the graph, either as successor
166 // to the predecessor node (if given) else as successor to the root node.
167 if (predecessorNode) {
168 predecessorNode->addSuccessor(nodeRef.get());
169 } else {
170 root.addSuccessor(nodeRef.get());
171 }
172 } else if (predecessorNode) {
173 // When the node already exists and an additional predecessor node is given, add the node as a
174 // successor to the given predecessor node and detach from the 'root' (unless it's a self edge).
175 SymbolUseGraphNode *node = nodeRef.get();
176 predecessorNode->addSuccessor(node);
177 if (node != predecessorNode) {
178 root.removeSuccessor(node);
179 }
180 }
181 return nodeRef.get();
182}
183
184const SymbolUseGraphNode *SymbolUseGraph::lookupNode(ModuleOp pathRoot, SymbolRefAttr path) const {
185 NodeMapKeyT key = std::make_pair(pathRoot, path);
186 const auto *it = nodes.find(key);
187 return it == nodes.end() ? nullptr : it->second.get();
188}
189
190const SymbolUseGraphNode *SymbolUseGraph::lookupNode(SymbolOpInterface symbolDef) const {
191 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [this](ModuleOp r, SymbolRefAttr p) {
192 return this->lookupNode(r, p);
193 });
194}
195
196//===----------------------------------------------------------------------===//
197// Printing
198//===----------------------------------------------------------------------===//
199
200std::string SymbolUseGraphNode::toString(bool showLocations) const {
201 return buildStringViaPrint(*this, showLocations);
202}
203
204namespace {
205
206inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
207 if (root) {
208 FailureOr<SymbolRefAttr> unambiguousRoot = getPathFromTopRoot(root);
209 if (succeeded(unambiguousRoot)) {
210 os << unambiguousRoot.value() << '\n';
211 } else {
212 os << "<<unknown path>>\n";
213 }
214 } else {
215 os << "<<NULL MODULE>>\n";
216 }
217}
218
219} // namespace
220
222 llvm::raw_ostream &os, bool showLocations, const std::string &locationLinePrefix
223) const {
224 os << '\'' << symbolPath << '\'';
225 if (isTemplateSymBinding) {
226 os << " (struct param)";
227 }
228 os << " with root module ";
229 safeAppendPathRoot(os, symbolPathRoot);
230 if (showLocations) {
231 // Print the user op locations (sorted for stable output). Printing only the location rather
232 // than the full Operation gives a short (single-line) format that's still useful for human
233 // debugging.
234 llvm::SmallSet<mlir::Location, 3, LocationComparator> locations;
235 for (Operation *user : getUserOps()) {
236 locations.insert(user->getLoc());
237 }
238 for (Location loc : locations) {
239 os << locationLinePrefix << loc << '\n';
240 }
241 }
242}
243
244void SymbolUseGraph::print(llvm::raw_ostream &os) const {
245 const SymbolUseGraphNode *rootPtr = &this->root;
246
247 // Tracks nodes that have been printed to ensure they are only printed once.
248 SmallPtrSet<SymbolUseGraphNode *, 16> done;
249
250 std::function<void(SymbolUseGraphNode *)> printNode = [rootPtr, &printNode, &done,
251 &os](SymbolUseGraphNode *node) {
252 // Skip if the node has been printed before
253 if (!done.insert(node).second) {
254 return;
255 }
256 // Print the current node
257 os << "// - Node : [" << node << "] ";
258 node->print(os, true, "// --- ");
259 // Print list of IDs for the predecessors (excluding root) and successors
260 os << "// --- Predecessors : [";
261 llvm::interleaveComma(
262 llvm::make_filter_range(
263 node->predecessorIter(), [rootPtr](SymbolUseGraphNode *n) { return n != rootPtr; }
264 ),
265 os
266 );
267 os << "]\n";
268 os << "// --- Successors : [";
269 llvm::interleaveComma(node->successorIter(), os);
270 os << "]\n";
271 // Recursively print the successors
272 for (SymbolUseGraphNode *c : node->successorIter()) {
273 printNode(c);
274 }
275 };
276
277 os << "// ---- SymbolUseGraph ----\n";
278 for (SymbolUseGraphNode *r : rootPtr->successorIter()) {
279 printNode(r);
280 }
281 os << "// ------------------------\n";
282 assert(done.size() == this->size() && "All nodes were not printed!");
283}
284
285void SymbolUseGraph::dumpToDotFile(std::string filename) const {
287 llvm::WriteGraph(this, "SymbolUseGraph", /*ShortNames*/ false, title, std::move(filename));
288}
289
290} // namespace llzk
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
std::string toString(bool showLocations=false) const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void print(llvm::raw_ostream &os, bool showLocations=false, const std::string &locationLinePrefix="") const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
Definition OpHelpers.h:56
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)