21#include <mlir/IR/BuiltinOps.h>
23#include <llvm/ADT/SmallPtrSet.h>
24#include <llvm/ADT/SmallSet.h>
25#include <llvm/Support/GraphWriter.h>
36 if (this->successors.insert(node)) {
37 node->predecessors.insert(
this);
42 if (this->successors.remove(node)) {
43 node->predecessors.remove(
this);
47FailureOr<SymbolLookupResultUntyped>
54 if (succeeded(res) || !reportMissing) {
58 return lookupFrom->emitError().append(
59 "Could not find symbol referenced in UseGraph: ",
getSymbolPath()
70R getPathAndCall(SymbolOpInterface defOp, llvm::function_ref<R(ModuleOp, SymbolRefAttr)> callback) {
78 auto diag = defOp.emitError(
"in SymbolUseGraph, failed to build symbol path");
79 diag.attachNote(defOp.getLoc()).append(
"for this SymbolOp");
83 return callback(foundRoot, path.value());
89 assert(rootSymbolOp->hasTrait<OpTrait::SymbolTable>());
90 buildGraph(rootSymbolOp);
94SymbolUseGraphNode *SymbolUseGraph::getSymbolUserNode(
const SymbolTable::SymbolUse &u) {
96 return getPathAndCall<SymbolUseGraphNode *>(
97 userSymbol, [
this, &userSymbol](ModuleOp r, SymbolRefAttr p) {
98 auto *n = this->getOrAddNode(r, p,
nullptr);
99 n->opsThatUseTheSymbol.insert(userSymbol);
105void SymbolUseGraph::buildGraph(SymbolOpInterface symbolOp) {
106 auto walkFn = [
this](Operation *op, bool) {
107 assert(op->hasTrait<OpTrait::SymbolTable>());
109 if (failed(opRootModule)) {
113 SymbolTableCollection tables;
116 for (SymbolTable::SymbolUse u : usesOpt.value()) {
117 bool isTemplateSymbol =
false;
118 Operation *user = u.getUser();
119 SymbolRefAttr symRef = u.getSymbolRef();
123 if (FlatSymbolRefAttr flatSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(symRef)) {
124 if (
auto fref = llvm::dyn_cast<component::MemberRefOpInterface>(user);
125 fref && fref.getMemberNameAttr() == flatSymRef) {
128 StringAttr localName = flatSymRef.getAttr();
130 userTemplate.hasConstNamed<polymorphic::TemplateSymbolBindingOpInterface>(
133 if (isTemplateSymbol || tables.getSymbolTable(userTemplate).lookup(localName)) {
137 assert(succeeded(parentPath));
142 auto *node = this->getOrAddNode(opRootModule.value(), symRef, getSymbolUserNode(u));
143 node->isTemplateSymBinding = isTemplateSymbol;
144 node->opsThatUseTheSymbol.insert(user);
148 SymbolTable::walkSymbolTables(symbolOp.getOperation(),
true, walkFn);
151 for (SymbolUseGraphNode *n :
nodesIter()) {
152 if (!n->hasSuccessor()) {
153 n->addSuccessor(&tail);
161 NodeMapKeyT key = std::make_pair(pathRoot, path);
162 std::unique_ptr<SymbolUseGraphNode> &nodeRef = nodes[key];
164 nodeRef.reset(
new SymbolUseGraphNode(pathRoot, path));
167 if (predecessorNode) {
168 predecessorNode->addSuccessor(nodeRef.get());
170 root.addSuccessor(nodeRef.get());
172 }
else if (predecessorNode) {
175 SymbolUseGraphNode *node = nodeRef.get();
176 predecessorNode->addSuccessor(node);
177 if (node != predecessorNode) {
178 root.removeSuccessor(node);
181 return nodeRef.get();
185 NodeMapKeyT key = std::make_pair(pathRoot, path);
186 const auto *it = nodes.find(key);
187 return it == nodes.end() ? nullptr : it->second.get();
191 return getPathAndCall<const SymbolUseGraphNode *>(symbolDef, [
this](ModuleOp r, SymbolRefAttr p) {
206inline void safeAppendPathRoot(llvm::raw_ostream &os, ModuleOp root) {
209 if (succeeded(unambiguousRoot)) {
210 os << unambiguousRoot.value() <<
'\n';
212 os <<
"<<unknown path>>\n";
215 os <<
"<<NULL MODULE>>\n";
222 llvm::raw_ostream &os,
bool showLocations,
const std::string &locationLinePrefix
224 os <<
'\'' << symbolPath <<
'\'';
225 if (isTemplateSymBinding) {
226 os <<
" (struct param)";
228 os <<
" with root module ";
229 safeAppendPathRoot(os, symbolPathRoot);
234 llvm::SmallSet<mlir::Location, 3, LocationComparator> locations;
236 locations.insert(user->getLoc());
238 for (Location loc : locations) {
239 os << locationLinePrefix << loc <<
'\n';
248 SmallPtrSet<SymbolUseGraphNode *, 16> done;
253 if (!done.insert(node).second) {
257 os <<
"// - Node : [" << node <<
"] ";
258 node->print(os,
true,
"// --- ");
260 os <<
"// --- Predecessors : [";
261 llvm::interleaveComma(
262 llvm::make_filter_range(
268 os <<
"// --- Successors : [";
269 llvm::interleaveComma(node->successorIter(), os);
277 os <<
"// ---- SymbolUseGraph ----\n";
281 os <<
"// ------------------------\n";
282 assert(done.size() == this->size() &&
"All nodes were not printed!");
287 llvm::WriteGraph(
this,
"SymbolUseGraph",
false, title, std::move(filename));
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
std::string toString(bool showLocations=false) const
Print the node in a human readable format.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
const OpSet & getUserOps() const
The set of operations that use the symbol.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void print(llvm::raw_ostream &os, bool showLocations=false, const std::string &locationLinePrefix="") const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
SymbolUseGraph(mlir::SymbolOpInterface rootSymbolOp)
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
void print(llvm::raw_ostream &os) const
const SymbolUseGraphNode * lookupNode(mlir::ModuleOp pathRoot, mlir::SymbolRefAttr path) const
Return the existing node for the symbol reference relative to the given module, else nullptr.
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< ModuleOp > getRootModule(Operation *from)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
static std::string getGraphName(GraphType)