16#include <mlir/IR/BuiltinOps.h>
18#include <llvm/ADT/DepthFirstIterator.h>
19#include <llvm/ADT/SmallSet.h>
20#include <llvm/Support/GraphWriter.h>
31 assert(!node->parent &&
"def cannot be in more than one symbol table");
33 children.insert(node);
42void assertProperBuild(SymbolOpInterface root,
const SymbolDefTree *tree) {
44 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
46 if (SymbolOpInterface s = r->getOp()) {
52 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
59 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
60 buildTree(rootSymbol,
nullptr);
61 assertProperBuild(rootSymbol,
this);
64void SymbolDefTree::buildTree(SymbolOpInterface symbolOp,
SymbolDefTreeNode *parentNode) {
66 parentNode = getOrAddNode(symbolOp, parentNode);
68 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
69 for (Operation &op : symbolOp->getRegion(0).front()) {
70 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
71 buildTree(childSym, parentNode);
78SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef,
SymbolDefTreeNode *parentNode) {
79 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
81 node.reset(
new SymbolDefTreeNode(symbolDef));
84 parentNode->addChild(node.get());
86 root.addChild(node.get());
93 const auto *it = nodes.find(symbolDef);
94 return it == nodes.
end() ? nullptr : it->second.get();
104 os <<
'\'' << symbolDef->getName() <<
"' ";
106 os <<
"named " << name <<
'\n';
108 os <<
"without a name\n";
115 os <<
"// - Node : [" << node <<
"] ";
118 os <<
"// --- Children : [";
119 llvm::interleaveComma(node->children, os);
127 os <<
"// ---- SymbolDefTree ----\n";
131 os <<
"// -----------------------\n";
136 llvm::WriteGraph(
this,
"SymbolDefTree",
false, title, std::move(filename));
child_iterator end() const
std::string toString() const
Print the node in a human readable format.
void print(llvm::raw_ostream &os) const
Builds a tree structure representing the symbol table structure.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to file in dot graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned strin...
static std::string getGraphName(GraphType)