LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolTableLLZK.cpp
Go to the documentation of this file.
1//===-- SymbolTableLLZK.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8// Adapted from the LLVM Project's mlir/lib/IR/SymbolTable.cpp
9// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
10// See https://llvm.org/LICENSE.txt for license information.
11// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
12//
13//===----------------------------------------------------------------------===//
34//===----------------------------------------------------------------------===//
35
37
39
40#include <llvm/ADT/SmallPtrSet.h>
41
42using namespace mlir;
43
44//===----------------------------------------------------------------------===//
45// Symbol Use Lists
46//===----------------------------------------------------------------------===//
47
48namespace {
49
52static bool isPotentiallyUnknownSymbolTable(Operation *op) {
53 return op->getNumRegions() == 1 && !op->getDialect();
54}
55
57static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
58 return op->getAttrOfType<StringAttr>(symbolAttrNameId);
59}
60
65static LogicalResult collectValidReferencesFor(
66 Operation *symbol, StringAttr symbolName, Operation *within,
67 SmallVectorImpl<SymbolRefAttr> &results
68) {
69 assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
70 MLIRContext *ctx = symbol->getContext();
71
72 auto leafRef = FlatSymbolRefAttr::get(symbolName);
73 results.push_back(leafRef);
74
75 // Early exit for when 'within' is the parent of 'symbol'.
76 Operation *symbolTableOp = symbol->getParentOp();
77 if (within == symbolTableOp) {
78 return success();
79 }
80
81 // Collect references until 'symbolTableOp' reaches 'within'.
82 SmallVector<FlatSymbolRefAttr, 1> nestedRefs(1, leafRef);
83 StringAttr symbolNameId = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
84 do {
85 // Each parent of 'symbol' should define a symbol table.
86 if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
87 return failure();
88 }
89 // Each parent of 'symbol' should also be a symbol.
90 StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
91 if (!symbolTableName) {
92 return failure();
93 }
94 results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
95
96 symbolTableOp = symbolTableOp->getParentOp();
97 if (symbolTableOp == within) {
98 break;
99 }
100 nestedRefs.insert(nestedRefs.begin(), FlatSymbolRefAttr::get(symbolTableName));
101 } while (true);
102 return success();
103}
104
108static std::optional<WalkResult> walkSymbolTable(
109 MutableArrayRef<Region> regions, function_ref<std::optional<WalkResult>(Operation *)> callback
110) {
111 SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
112 while (!worklist.empty()) {
113 for (Operation &op : worklist.pop_back_val()->getOps()) {
114 std::optional<WalkResult> result = callback(&op);
115 if (result != WalkResult::advance()) {
116 return result;
117 }
118
119 // If this op defines a new symbol table scope, we can't traverse. Any
120 // symbol references nested within 'op' are different semantically.
121 if (!op.hasTrait<OpTrait::SymbolTable>()) {
122 for (Region &region : op.getRegions()) {
123 worklist.push_back(&region);
124 }
125 }
126 }
127 }
128 return WalkResult::advance();
129}
130
133static WalkResult
134walkSymbolRefs(Operation *op, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
135 // This is modified for LLZK.
136 auto walkFn = [&op, &callback](SymbolRefAttr symbolRef) {
137 if (callback({op, symbolRef}).wasInterrupted()) {
138 return WalkResult::interrupt();
139 }
140 return WalkResult::skip(); // Don't walk nested references.
141 };
142 for (Type t : op->getOperandTypes()) {
143 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
144 return WalkResult::interrupt();
145 }
146 }
147 for (Type t : op->getResultTypes()) {
148 if (t.walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
149 return WalkResult::interrupt();
150 }
151 }
152
153 // TODO: Remove this when POD types are updated to use StringAttr.
154 // POD record names are encoded as FlatSymbolRefAttr for parsing/printing
155 // convenience, but they are not real symbol references and must not be
156 // surfaced as symbol uses.
157 auto shouldSkipAttr = [op](NamedAttribute attr) {
158 return attr.getName() == "record_name" &&
159 (isa<llzk::pod::ReadPodOp>(op) || isa<llzk::pod::WritePodOp>(op));
160 };
161 for (NamedAttribute attr : op->getAttrs()) {
162 if (shouldSkipAttr(attr)) {
163 continue;
164 }
165 if (attr.getValue().walk<WalkOrder::PreOrder>(walkFn).wasInterrupted()) {
166 return WalkResult::interrupt();
167 }
168 }
169 return WalkResult::advance();
170}
171
175static std::optional<WalkResult> walkSymbolUses(
176 MutableArrayRef<Region> regions, function_ref<WalkResult(SymbolTable::SymbolUse)> callback
177) {
178 return walkSymbolTable(regions, [&](Operation *op) -> std::optional<WalkResult> {
179 // Check that this isn't a potentially unknown symbol table.
180 if (isPotentiallyUnknownSymbolTable(op)) {
181 return std::nullopt;
182 }
183 return walkSymbolRefs(op, callback);
184 });
185}
186
190static std::optional<WalkResult>
191walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
192 // If this operation has regions, and it, as well as its dialect, isn't
193 // registered then conservatively fail. The operation may define a
194 // symbol table, so we can't opaquely know if we should traverse to find
195 // nested uses.
196 if (isPotentiallyUnknownSymbolTable(from)) {
197 return std::nullopt;
198 }
199
200 // Walk the uses on this operation.
201 if (walkSymbolRefs(from, callback).wasInterrupted()) {
202 return WalkResult::interrupt();
203 }
204
205 // Only recurse if this operation is not a symbol table. A symbol table
206 // defines a new scope, so we can't walk the attributes from within the symbol
207 // table op.
208 if (!from->hasTrait<OpTrait::SymbolTable>()) {
209 return walkSymbolUses(from->getRegions(), callback);
210 }
211 return WalkResult::advance();
212}
213
219struct SymbolScope {
223 template <
224 typename CallbackT,
225 std::enable_if_t<!std::is_same<
226 typename llvm::function_traits<CallbackT>::result_t, void>::value> * = nullptr>
227 std::optional<WalkResult> walk(CallbackT cback) {
228 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
229 return walkSymbolUses(*region, cback);
230 }
231 return walkSymbolUses(llvm::cast<Operation *>(limit), cback);
232 }
235 template <
236 typename CallbackT,
237 std::enable_if_t<std::is_same<
238 typename llvm::function_traits<CallbackT>::result_t, void>::value> * = nullptr>
239 std::optional<WalkResult> walk(CallbackT cback) {
240 return walk([=](SymbolTable::SymbolUse use) { return cback(use), WalkResult::advance(); });
241 }
242
245 template <typename CallbackT> std::optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
246 if (Region *region = llvm::dyn_cast_if_present<Region *>(limit)) {
247 return ::walkSymbolTable(*region, cback);
248 }
249 return ::walkSymbolTable(llvm::cast<Operation *>(limit), cback);
250 }
251
253 SymbolRefAttr symbol;
254
256 llvm::PointerUnion<Operation *, Region *> limit;
257};
258
260static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Operation *limit) {
261 StringAttr symName = SymbolTable::getSymbolName(symbol);
262 assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
263
264 // Compute the ancestors of 'limit'.
265 SetVector<Operation *, SmallVector<Operation *, 4>, SmallPtrSet<Operation *, 4>> limitAncestors;
266 Operation *limitAncestor = limit;
267 do {
268 // Check to see if 'symbol' is an ancestor of 'limit'.
269 if (limitAncestor == symbol) {
270 // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
271 // doesn't support parent references.
272 if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) {
273 return {{SymbolRefAttr::get(symName), limit}};
274 }
275 return {};
276 }
277
278 limitAncestors.insert(limitAncestor);
279 } while ((limitAncestor = limitAncestor->getParentOp()));
280
281 // Try to find the first ancestor of 'symbol' that is an ancestor of 'limit'.
282 Operation *commonAncestor = symbol->getParentOp();
283 do {
284 if (limitAncestors.count(commonAncestor)) {
285 break;
286 }
287 } while ((commonAncestor = commonAncestor->getParentOp()));
288 assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
289
290 // Compute the set of valid nested references for 'symbol' as far up to the
291 // common ancestor as possible.
292 SmallVector<SymbolRefAttr, 2> references;
293 bool collectedAllReferences =
294 succeeded(collectValidReferencesFor(symbol, symName, commonAncestor, references));
295
296 // Handle the case where the common ancestor is 'limit'.
297 if (commonAncestor == limit) {
298 SmallVector<SymbolScope, 2> scopes;
299
300 // Walk each of the ancestors of 'symbol', calling the compute function for
301 // each one.
302 Operation *limitIt = symbol->getParentOp();
303 for (size_t i = 0, e = references.size(); i != e; ++i, limitIt = limitIt->getParentOp()) {
304 assert(limitIt->hasTrait<OpTrait::SymbolTable>());
305 scopes.push_back({references[i], &limitIt->getRegion(0)});
306 }
307 return scopes;
308 }
309
310 // Otherwise, we just need the symbol reference for 'symbol' that will be
311 // used within 'limit'. This is the last reference in the list we computed
312 // above if we were able to collect all references.
313 if (!collectedAllReferences) {
314 return {};
315 }
316 return {{references.back(), limit}};
317}
318
319static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol, Region *limit) {
320 auto scopes = collectSymbolScopes(symbol, limit->getParentOp());
321
322 // If we collected some scopes to walk, make sure to constrain the one for
323 // limit to the specific region requested.
324 if (!scopes.empty()) {
325 scopes.back().limit = limit;
326 }
327 return scopes;
328}
329
330static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Region *limit) {
331 return {{SymbolRefAttr::get(symbol), limit}};
332}
333
334static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol, Operation *limit) {
335 SmallVector<SymbolScope, 1> scopes;
336 auto symbolRef = SymbolRefAttr::get(symbol);
337 for (auto &region : limit->getRegions()) {
338 scopes.push_back({symbolRef, &region});
339 }
340 return scopes;
341}
342
345static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
346 if (ref == subRef) {
347 return true;
348 }
349
350 // If the references are not pointer equal, check to see if `subRef` is a
351 // prefix of `ref`.
352 if (llvm::isa<FlatSymbolRefAttr>(ref) || ref.getRootReference() != subRef.getRootReference()) {
353 return false;
354 }
355
356 auto refLeafs = ref.getNestedReferences();
357 auto subRefLeafs = subRef.getNestedReferences();
358 return subRefLeafs.size() < refLeafs.size() &&
359 subRefLeafs == refLeafs.take_front(subRefLeafs.size());
360}
361
362} // namespace
363
364//===----------------------------------------------------------------------===//
365// llzk::getSymbolUses
366
367namespace {
368
370template <typename FromT>
371static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
372 std::vector<SymbolTable::SymbolUse> uses;
373 auto walkFn = [&](SymbolTable::SymbolUse symbolUse) {
374 uses.push_back(symbolUse);
375 return WalkResult::advance();
376 };
377 auto result = walkSymbolUses(from, walkFn);
378 return result ? std::optional<SymbolTable::UseRange>(std::move(uses)) : std::nullopt;
379}
380
381} // namespace
382
390std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *from) {
391 return getSymbolUsesImpl(from);
392}
393std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Region *from) {
394 return getSymbolUsesImpl(MutableArrayRef<Region>(*from));
395}
396
397//===----------------------------------------------------------------------===//
398// llzk::getSymbolUses (uses of a specific symbol)
399
400namespace {
401
403template <typename SymbolT, typename IRUnitT>
404static std::optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol, IRUnitT *limit) {
405 std::vector<SymbolTable::SymbolUse> uses;
406 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
407 if (!scope.walk([&](SymbolTable::SymbolUse symbolUse) {
408 if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef())) {
409 uses.push_back(symbolUse);
410 }
411 })) {
412 return std::nullopt;
413 }
414 }
415 return SymbolTable::UseRange(std::move(uses));
416}
417
418} // namespace
419
424std::optional<SymbolTable::UseRange> llzk::getSymbolUses(StringAttr symbol, Operation *from) {
425 return getSymbolUsesImpl(symbol, from);
426}
427std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *symbol, Operation *from) {
428 return getSymbolUsesImpl(symbol, from);
429}
430std::optional<SymbolTable::UseRange> llzk::getSymbolUses(StringAttr symbol, Region *from) {
431 return getSymbolUsesImpl(symbol, from);
432}
433std::optional<SymbolTable::UseRange> llzk::getSymbolUses(Operation *symbol, Region *from) {
434 return getSymbolUsesImpl(symbol, from);
435}
436
437//===----------------------------------------------------------------------===//
438// llzk::getSymbolsUsedIn
439
440void llzk::getSymbolsUsedIn(Type t, llvm::SmallDenseSet<SymbolRefAttr> &symbolsUsed) {
441 t.walk([&symbolsUsed](SymbolRefAttr symbolRef) { symbolsUsed.insert(symbolRef); });
442}
443
444void llzk::getSymbolsUsedIn(ArrayRef<Type> types, llvm::SmallDenseSet<SymbolRefAttr> &symbolsUsed) {
445 for (Type t : types) {
446 getSymbolsUsedIn(t, symbolsUsed);
447 }
448}
449
450llvm::SmallDenseSet<SymbolRefAttr> llzk::getSymbolsUsedIn(Type t) {
451 llvm::SmallDenseSet<SymbolRefAttr> symbolsUsed;
452 getSymbolsUsedIn(t, symbolsUsed);
453 return symbolsUsed;
454}
455
456llvm::SmallDenseSet<SymbolRefAttr> llzk::getSymbolsUsedIn(ArrayRef<Type> types) {
457 llvm::SmallDenseSet<SymbolRefAttr> symbolsUsed;
458 getSymbolsUsedIn(types, symbolsUsed);
459 return symbolsUsed;
460}
461
462//===----------------------------------------------------------------------===//
463// llzk::symbolKnownUseEmpty
464
465namespace {
466
468template <typename SymbolT, typename IRUnitT>
469static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
470 for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
471 // Walk all of the symbol uses looking for a reference to 'symbol'.
472 if (scope.walk([&](SymbolTable::SymbolUse symbolUse) {
473 return isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()) ? WalkResult::interrupt()
474 : WalkResult::advance();
475 }) != WalkResult::advance()) {
476 return false;
477 }
478 }
479 return true;
480}
481
482} // namespace
483
488bool llzk::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
489 return symbolKnownUseEmptyImpl(symbol, from);
490}
491bool llzk::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
492 return symbolKnownUseEmptyImpl(symbol, from);
493}
494bool llzk::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
495 return symbolKnownUseEmptyImpl(symbol, from);
496}
497bool llzk::symbolKnownUseEmpty(Operation *symbol, Region *from) {
498 return symbolKnownUseEmptyImpl(symbol, from);
499}
500
501//===----------------------------------------------------------------------===//
502// llzk::getSymbolName
503
504StringAttr llzk::getSymbolName(Operation *op) {
505 // This is modified for LLZK.
506 // `SymbolTable::getSymbolName(Operation*)` asserts if there is no name (ex: in the case of
507 // ModuleOp where the symbol name is optional) and there's no other way to check if the name
508 // exists so this fully involved retrieval method must be used to return `nullptr` if no name.
509 return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
510}
Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION — 1. Definitions. "License" shall mean the terms and conditions for use
Definition LICENSE.txt:9
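Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION — 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from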
Definition LICENSE.txt:45
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
mlir::StringAttr getSymbolName(mlir::Operation *symbol)
Returns the name of the given symbol operation, or nullptr if no symbol is present.
void getSymbolsUsedIn(mlir::Type t, llvm::SmallDenseSet< mlir::SymbolRefAttr > &symbolsUsed)
Add all symbols used within the given Type to the provided set.
bool symbolKnownUseEmpty(mlir::StringAttr symbol, mlir::Operation *from)
Returns true if the given symbol is known to have no uses that are nested within the given operation 'from'...