40#include <llvm/ADT/SmallPtrSet.h>
52static bool isPotentiallyUnknownSymbolTable(Operation *op) {
53 return op->getNumRegions() == 1 && !op->getDialect();
57static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
58 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
65static LogicalResult collectValidReferencesFor(
66 Operation *symbol, StringAttr symbolName, Operation *within,
67 SmallVectorImpl<SymbolRefAttr> &results
69 assert(within->isAncestor(symbol) &&
"expected 'within' to be an ancestor");
70 MLIRContext *ctx = symbol->getContext();
72 auto leafRef = FlatSymbolRefAttr::get(symbolName);
73 results.push_back(leafRef);
76 Operation *symbolTableOp = symbol->getParentOp();
77 if (within == symbolTableOp) {
82 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
83 StringAttr symbolNameId = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
86 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
90 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
91 if (!symbolTableName) {
94 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
96 symbolTableOp = symbolTableOp->getParentOp();
97 if (symbolTableOp == within) {
100 nestedRefs.insert(nestedRefs.begin(), FlatSymbolRefAttr::get(symbolTableName));
108static std::optional<WalkResult> walkSymbolTable(
109 MutableArrayRef<Region> regions, function_ref<std::optional<WalkResult>(Operation *)> callback
111 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
112 while (!worklist.empty()) {
113 for (Operation &op : worklist.pop_back_val()->getOps()) {
114 std::optional<WalkResult> result = callback(&op);
115 if (result != WalkResult::advance()) {
121 if (!op.hasTrait<OpTrait::SymbolTable>()) {
122 for (Region ®ion : op.getRegions()) {
123 worklist.push_back(®ion);
128 return WalkResult::advance();
134walkSymbolRefs(Operation *op, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
136 auto walkFn = [&op, &callback](SymbolRefAttr symbolRef) {
137 if (callback({op, symbolRef}).wasInterrupted()) {
138 return WalkResult::interrupt();
140 return WalkResult::skip();
142 for (Type t : op->getOperandTypes()) {
143 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
144 return WalkResult::interrupt();
147 for (Type t : op->getResultTypes()) {
148 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
149 return WalkResult::interrupt();
157 auto shouldSkipAttr = [op](NamedAttribute attr) {
158 return attr.getName() ==
"record_name" &&
159 (isa<llzk::pod::ReadPodOp>(op) || isa<llzk::pod::WritePodOp>(op));
161 for (NamedAttribute attr : op->getAttrs()) {
162 if (shouldSkipAttr(attr)) {
165 if (attr.getValue().walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
166 return WalkResult::interrupt();
169 return WalkResult::advance();
175static std::optional<WalkResult> walkSymbolUses(
176 MutableArrayRef<Region> regions, function_ref<WalkResult(SymbolTable::SymbolUse)> callback
178 return walkSymbolTable(regions, [&](Operation *op) -> std::optional<WalkResult> {
180 if (isPotentiallyUnknownSymbolTable(op)) {
183 return walkSymbolRefs(op, callback);
190static std::optional<WalkResult>
191walkSymbolUses(Operation *
from, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
196 if (isPotentiallyUnknownSymbolTable(
from)) {
201 if (walkSymbolRefs(
from, callback).wasInterrupted()) {
202 return WalkResult::interrupt();
208 if (!
from->hasTrait<OpTrait::SymbolTable>()) {
209 return walkSymbolUses(
from->getRegions(), callback);
211 return WalkResult::advance();
225 std::enable_if_t<!std::is_same<
226 typename llvm::function_traits<CallbackT>::result_t,
void>::value> * =
nullptr>
227 std::optional<WalkResult> walk(CallbackT cback) {
// Walks the symbol uses within this scope, invoking `cback` on each one.
// This overload is selected (via enable_if) when `cback` returns a non-void
// result; `cback` is handed directly to walkSymbolUses.
// If `limit` holds a Region, only that region is walked. Otherwise `limit`
// holds an Operation and the Operation overload of walkSymbolUses is used.
// NOTE(review): the `template <typename CallbackT,` head line is missing from
// this listing.
228 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
229 return walkSymbolUses(*region, cback);
231 return walkSymbolUses(llvm::cast<Operation *>(limit), cback);
237 std::enable_if_t<std::is_same<
238 typename llvm::function_traits<CallbackT>::result_t,
void>::value> * =
nullptr>
239 std::optional<WalkResult> walk(CallbackT cback) {
// Overload selected when `cback` returns void. `cback` is wrapped in a
// callback that always returns WalkResult::advance(), so the walk is never
// interrupted by the callback, and it is forwarded to the non-void overload.
// `cback` is captured by copy ([=]).
240 return walk([=](SymbolTable::SymbolUse
use) {
return cback(
use), WalkResult::advance(); });
245 template <
typename CallbackT> std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
// Visits the operations under this scope using the file-level
// ::walkSymbolTable overloads. The region-range overload is used when `limit`
// holds a Region; otherwise the Operation overload is used.
246 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
247 return ::walkSymbolTable(*region, cback);
249 return ::walkSymbolTable(llvm::cast<Operation *>(limit), cback);
// The reference form of the symbol that is valid within this scope; uses are
// matched against it with isReferencePrefixOf.
253 SymbolRefAttr symbol;
// The IR unit bounding this scope: either an Operation or a single Region.
256 llvm::PointerUnion<Operation *, Region *> limit;
260static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Operation *limit) {
261 StringAttr symName = SymbolTable::getSymbolName(symbol);
262 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
265 SetVector<Operation *, SmallVector<Operation *, 4>, SmallPtrSet<Operation *, 4>> limitAncestors;
266 Operation *limitAncestor = limit;
269 if (limitAncestor == symbol) {
272 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) {
273 return {{SymbolRefAttr::get(symName), limit}};
278 limitAncestors.insert(limitAncestor);
279 }
while ((limitAncestor = limitAncestor->getParentOp()));
282 Operation *commonAncestor = symbol->getParentOp();
284 if (limitAncestors.count(commonAncestor)) {
287 }
while ((commonAncestor = commonAncestor->getParentOp()));
288 assert(commonAncestor &&
"'limit' and 'symbol' have no common ancestor");
292 SmallVector<SymbolRefAttr, 2> references;
293 bool collectedAllReferences =
294 succeeded(collectValidReferencesFor(symbol, symName, commonAncestor, references));
297 if (commonAncestor == limit) {
298 SmallVector<SymbolScope, 2> scopes;
302 Operation *limitIt = symbol->getParentOp();
303 for (
size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) {
304 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
305 scopes.push_back({references[i], &limitIt->getRegion(0)});
313 if (!collectedAllReferences) {
316 return {{references.back(), limit}};
319static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Region *limit) {
320 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
324 if (!scopes.empty()) {
325 scopes.back().limit = limit;
330static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Region *limit) {
331 return {{SymbolRefAttr::get(symbol), limit}};
334static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Operation *limit) {
335 SmallVector<SymbolScope, 1> scopes;
336 auto symbolRef = SymbolRefAttr::get(symbol);
337 for (
auto ®ion : limit->getRegions()) {
338 scopes.push_back({symbolRef, ®ion});
345static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
352 if (llvm::isa<FlatSymbolRefAttr>(ref) || ref.getRootReference() != subRef.getRootReference()) {
356 auto refLeafs = ref.getNestedReferences();
357 auto subRefLeafs = subRef.getNestedReferences();
358 return subRefLeafs.size() < refLeafs.size() &&
359 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
370template <
typename FromT>
371static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT
from) {
372 std::vector<SymbolTable::SymbolUse> uses;
373 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
374 uses.push_back(symbolUse);
375 return WalkResult::advance();
377 auto result = walkSymbolUses(
from, walkFn);
378 return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) : std::nullopt;
// Body of a getSymbolUses overload taking `from` directly.
// NOTE(review): the enclosing signature is not shown in this listing;
// presumably getSymbolUses(Operation *from).
391 return getSymbolUsesImpl(
from);
// Body of the overload whose `from` is dereferenced into a one-element
// MutableArrayRef<Region>, so the region-range walker can be reused.
// Presumably getSymbolUses(Region *from); confirm.
394 return getSymbolUsesImpl(MutableArrayRef<Region>(*
from));
403template <
typename SymbolT,
typename IRUnitT>
404static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, IRUnitT *limit) {
405 std::vector<SymbolTable::SymbolUse> uses;
406 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
407 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
408 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) {
409 uses.push_back(symbolUse);
415 return SymbolTable::UseRange(std::move(uses));
// Bodies of four (symbol, from) overloads of getSymbolUses. Each forwards to
// the templated getSymbolUsesImpl(SymbolT, IRUnitT *).
// NOTE(review): the enclosing signatures are not shown in this listing;
// presumably {StringAttr, Operation *} symbol x {Operation *, Region *} from.
425 return getSymbolUsesImpl(symbol,
from);
428 return getSymbolUsesImpl(symbol,
from);
431 return getSymbolUsesImpl(symbol,
from);
434 return getSymbolUsesImpl(symbol,
from);
// Record every SymbolRefAttr reachable from type `t` into `symbolsUsed`.
441 t.walk([&symbolsUsed](SymbolRefAttr symbolRef) { symbolsUsed.insert(symbolRef); });
// Accumulate symbols across a range of types `types`. The loop body and the
// enclosing signature are not part of this listing.
445 for (Type t : types) {
// Local set collecting the symbols used (enclosing function not shown here).
451 llvm::SmallDenseSet<SymbolRefAttr> symbolsUsed;
// Local set collecting the symbols used (enclosing function not shown here).
457 llvm::SmallDenseSet<SymbolRefAttr> symbolsUsed;
468template <
typename SymbolT,
typename IRUnitT>
469static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
470 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
472 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
473 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt()
474 : WalkResult::advance();
475 }) != WalkResult::advance()) {
// Bodies of four (symbol, from) overloads of symbolKnownUseEmpty. Each forwards
// to the templated symbolKnownUseEmptyImpl(SymbolT, IRUnitT *).
// NOTE(review): the enclosing signatures are not shown in this listing;
// presumably {StringAttr, Operation *} symbol x {Operation *, Region *} from.
489 return symbolKnownUseEmptyImpl(symbol,
from);
492 return symbolKnownUseEmptyImpl(symbol,
from);
495 return symbolKnownUseEmptyImpl(symbol,
from);
498 return symbolKnownUseEmptyImpl(symbol,
from);
// Returns the StringAttr stored on `op` under SymbolTable::getSymbolAttrName(),
// or a null StringAttr if the attribute is absent or is not a StringAttr.
509 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
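Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.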
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
void getSymbolsUsedIn(mlir::Type t, llvm::SmallDenseSet< mlir::SymbolRefAttr > &symbolsUsed)
Add all symbols used within the given Type to the provided set.
bool symbolKnownUseEmpty(mlir::StringAttr symbol, mlir::Operation *from)
Returns true if the given symbol is known to have no uses that are nested within the given operation 'from'...