LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolDefTree.cpp
Go to the documentation of this file.
1//===-- SymbolDefTree.cpp ---------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
12#include "llzk/Util/Constants.h"
15
16#include <mlir/IR/BuiltinOps.h>
17
18#include <llvm/ADT/DepthFirstIterator.h>
19#include <llvm/ADT/SmallSet.h>
20#include <llvm/Support/GraphWriter.h>
21
22using namespace mlir;
23
24namespace llzk {
25
26//===----------------------------------------------------------------------===//
27// SymbolDefTreeNode
28//===----------------------------------------------------------------------===//
29
30void SymbolDefTreeNode::addChild(SymbolDefTreeNode *node) {
31 assert(!node->parent && "def cannot be in more than one symbol table");
32 node->parent = this;
33 children.insert(node);
34}
35
36//===----------------------------------------------------------------------===//
37// SymbolDefTree
38//===----------------------------------------------------------------------===//
39
40namespace {
41
42void assertProperBuild(SymbolOpInterface root, const SymbolDefTree *tree) {
43 // Collect all Symbols in the graph
44 llvm::SmallSet<SymbolOpInterface, 16> fromGraph;
45 for (const SymbolDefTreeNode *r : llvm::depth_first(tree)) {
46 if (SymbolOpInterface s = r->getOp()) {
47 fromGraph.insert(s);
48 }
49 }
50 // Ensure every symbol reachable from the 'root' is represented in the graph
51#ifndef NDEBUG
52 root.walk([&fromGraph](SymbolOpInterface s) { assert(fromGraph.contains(s)); });
53#endif
54}
55
56} // namespace
57
58SymbolDefTree::SymbolDefTree(SymbolOpInterface rootSymbol) {
59 assert(rootSymbol->hasTrait<OpTrait::SymbolTable>());
60 buildTree(rootSymbol, /*parentNode=*/nullptr);
61 assertProperBuild(rootSymbol, this);
62}
63
64void SymbolDefTree::buildTree(SymbolOpInterface symbolOp, SymbolDefTreeNode *parentNode) {
65 // Add node for the current symbol
66 parentNode = getOrAddNode(symbolOp, parentNode);
67 // If this symbol is also its own SymbolTable, recursively add child symbols
68 if (symbolOp->hasTrait<OpTrait::SymbolTable>()) {
69 for (Operation &op : symbolOp->getRegion(0).front()) {
70 if (SymbolOpInterface childSym = llvm::dyn_cast<SymbolOpInterface>(&op)) {
71 buildTree(childSym, parentNode);
72 }
73 }
74 }
75}
76
78SymbolDefTree::getOrAddNode(SymbolOpInterface symbolDef, SymbolDefTreeNode *parentNode) {
79 std::unique_ptr<SymbolDefTreeNode> &node = nodes[symbolDef];
80 if (!node) {
81 node.reset(new SymbolDefTreeNode(symbolDef));
82 // Add this node to the given parent node if given, else the root node.
83 if (parentNode) {
84 parentNode->addChild(node.get());
85 } else {
86 root.addChild(node.get());
87 }
88 }
89 return node.get();
90}
91
92const SymbolDefTreeNode *SymbolDefTree::lookupNode(SymbolOpInterface symbolDef) const {
93 const auto *it = nodes.find(symbolDef);
94 return it == nodes.end() ? nullptr : it->second.get();
95}
96
97//===----------------------------------------------------------------------===//
98// Printing
99//===----------------------------------------------------------------------===//
100
101std::string SymbolDefTreeNode::toString() const { return buildStringViaPrint(*this); }
102
103void SymbolDefTreeNode::print(llvm::raw_ostream &os) const {
104 os << '\'' << symbolDef->getName() << "' ";
105 if (StringAttr name = llzk::getSymbolName(symbolDef)) {
106 os << "named " << name << '\n';
107 } else {
108 os << "without a name\n";
109 }
110}
111
112void SymbolDefTree::print(llvm::raw_ostream &os) const {
113 std::function<void(SymbolDefTreeNode *)> printNode = [&os, &printNode](SymbolDefTreeNode *node) {
114 // Print the current node
115 os << "// - Node : [" << node << "] ";
116 node->print(os);
117 // Print list of IDs for the children
118 os << "// --- Children : [";
119 llvm::interleaveComma(node->children, os);
120 os << "]\n";
121 // Recursively print the children
122 for (SymbolDefTreeNode *c : node->children) {
123 printNode(c);
124 }
125 };
126
127 os << "// ---- SymbolDefTree ----\n";
128 for (SymbolDefTreeNode *r : root.children) {
129 printNode(r);
130 }
131 os << "// -----------------------\n";
132}
133
134void SymbolDefTree::dumpToDotFile(std::string filename) const {
136 llvm::WriteGraph(this, "SymbolDefTree", /*ShortNames*/ false, title, std::move(filename));
137}
138
139} // namespace llzk
child_iterator end() const
std::string toString() const
Print the node in a human-readable format.
void print(llvm::raw_ostream &os) const
Builds a tree representing the nesting of symbol definitions within symbol tables.
const SymbolDefTreeNode * lookupNode(mlir::SymbolOpInterface symbolOp) const
Lookup the node for the given symbol Op, or nullptr if none exists.
SymbolDefTree(mlir::SymbolOpInterface root)
void dumpToDotFile(std::string filename="") const
Dump the tree to a file in DOT graph format.
void print(llvm::raw_ostream &os) const
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
std::string buildStringViaPrint(const T &base, Args &&...args)
Generate a string by calling base.print(llvm::raw_ostream &) on a stream backed by the returned string.