LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SymbolHelper.h
Go to the documentation of this file.
1//===-- SymbolHelper.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
#include <mlir/Interfaces/CallInterfaces.h>

#include <ranges>
#include <vector>
17
18namespace llzk {
19
20namespace component {
21class StructType;
22class StructDefOp;
23class MemberDefOp;
24} // namespace component
25
26namespace function {
27class FuncDefOp;
28} // namespace function
29namespace polymorphic {
30class TemplateOp;
31} // namespace polymorphic
32
33llvm::SmallVector<mlir::StringRef> getNames(mlir::SymbolRefAttr ref);
34llvm::SmallVector<mlir::FlatSymbolRefAttr> getPieces(mlir::SymbolRefAttr ref);
35
37inline mlir::FlatSymbolRefAttr
38getFlatSymbolRefAttr(mlir::MLIRContext *context, const mlir::Twine &twine) {
39 return mlir::FlatSymbolRefAttr::get(mlir::StringAttr::get(context, twine));
40}
41
43inline mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail) {
44 return mlir::SymbolRefAttr::get(root, getPieces(tail));
45}
46
48inline mlir::SymbolRefAttr asSymbolRefAttr(llvm::ArrayRef<mlir::FlatSymbolRefAttr> path) {
49 return mlir::SymbolRefAttr::get(path.front().getAttr(), path.drop_front());
50}
51
53inline mlir::SymbolRefAttr asSymbolRefAttr(const std::vector<mlir::FlatSymbolRefAttr> &path) {
54 return asSymbolRefAttr(llvm::ArrayRef<mlir::FlatSymbolRefAttr>(path));
55}
56
58inline mlir::SymbolRefAttr getTailAsSymbolRefAttr(mlir::SymbolRefAttr symbol) {
59 return asSymbolRefAttr(symbol.getNestedReferences());
60}
61
63inline mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol) {
64 return mlir::SymbolRefAttr::get(
65 symbol.getRootReference(), symbol.getNestedReferences().drop_back()
66 );
67}
68
70mlir::SymbolRefAttr replaceLeaf(mlir::SymbolRefAttr orig, mlir::FlatSymbolRefAttr newLeaf);
71inline mlir::SymbolRefAttr replaceLeaf(mlir::SymbolRefAttr orig, mlir::StringAttr newLeaf) {
72 return replaceLeaf(orig, mlir::FlatSymbolRefAttr::get(newLeaf));
73}
74inline mlir::SymbolRefAttr replaceLeaf(mlir::SymbolRefAttr orig, const mlir::Twine &newLeaf) {
75 return replaceLeaf(orig, mlir::StringAttr::get(orig.getContext(), newLeaf));
76}
77
79mlir::SymbolRefAttr appendLeaf(mlir::SymbolRefAttr orig, mlir::FlatSymbolRefAttr newLeaf);
80inline mlir::SymbolRefAttr appendLeaf(mlir::SymbolRefAttr orig, mlir::StringAttr newLeaf) {
81 return appendLeaf(orig, mlir::FlatSymbolRefAttr::get(newLeaf));
82}
83inline mlir::SymbolRefAttr appendLeaf(mlir::SymbolRefAttr orig, const mlir::Twine &newLeaf) {
84 return appendLeaf(orig, mlir::StringAttr::get(orig.getContext(), newLeaf));
85}
86
89mlir::SymbolRefAttr appendLeafName(mlir::SymbolRefAttr orig, const mlir::Twine &newLeafSuffix);
90
93mlir::FailureOr<mlir::ModuleOp> getRootModule(mlir::Operation *from);
94mlir::FailureOr<mlir::SymbolRefAttr>
95getPathFromRoot(mlir::SymbolOpInterface to, mlir::ModuleOp *foundRoot = nullptr);
96mlir::FailureOr<mlir::SymbolRefAttr>
97getPathFromRoot(component::StructDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
98mlir::FailureOr<mlir::SymbolRefAttr>
99getPathFromRoot(component::MemberDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
100mlir::FailureOr<mlir::SymbolRefAttr>
101getPathFromRoot(function::FuncDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
102
105mlir::FailureOr<mlir::ModuleOp> getTopRootModule(mlir::Operation *from);
106mlir::FailureOr<mlir::SymbolRefAttr>
107getPathFromTopRoot(mlir::SymbolOpInterface to, mlir::ModuleOp *foundRoot = nullptr);
108mlir::FailureOr<mlir::SymbolRefAttr>
109getPathFromTopRoot(component::StructDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
110mlir::FailureOr<mlir::SymbolRefAttr>
111getPathFromTopRoot(component::MemberDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
112mlir::FailureOr<mlir::SymbolRefAttr>
113getPathFromTopRoot(function::FuncDefOp &to, mlir::ModuleOp *foundRoot = nullptr);
114
119mlir::FailureOr<llzk::component::StructType> getMainInstanceType(mlir::Operation *lookupFrom);
120
125mlir::FailureOr<SymbolLookupResult<llzk::component::StructDefOp>>
126getMainInstanceDef(mlir::SymbolTableCollection &symbolTable, mlir::Operation *lookupFrom);
127
133template <typename T>
134inline mlir::FailureOr<SymbolLookupResult<T>>
135resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call) {
136 mlir::CallInterfaceCallable callable = call.getCallableForCallee();
137 if (auto symbolVal = llvm::dyn_cast<mlir::Value>(callable)) {
138 return SymbolLookupResult<T>(symbolVal.getDefiningOp());
139 }
140
141 // If the callable isn't a value, lookup the symbol reference.
142 // We first try to resolve in the nearest symbol table, as per the default
143 // MLIR behavior. If the resulting operation is not found, we will then
144 // use the LLZK lookup helpers.
145 auto symbolRef = llvm::cast<mlir::SymbolRefAttr>(callable);
146 mlir::Operation *op = symbolTable.lookupNearestSymbolFrom(call.getOperation(), symbolRef);
147
148 if (op) {
149 return SymbolLookupResult<T>(std::move(op));
150 }
151 // Otherwise, use the top-level lookup.
152 return lookupTopLevelSymbol<T>(symbolTable, symbolRef, call.getOperation());
153}
154
159template <typename T>
160inline mlir::FailureOr<SymbolLookupResult<T>>
161resolveCallableSilently(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call) {
162 mlir::CallInterfaceCallable callable = call.getCallableForCallee();
163 if (auto symbolVal = llvm::dyn_cast<mlir::Value>(callable)) {
164 SymbolLookupResult<T> result(symbolVal.getDefiningOp());
165 if (!result) {
166 return mlir::failure();
167 }
168 return result;
169 }
170
171 auto symbolRef = llvm::cast<mlir::SymbolRefAttr>(callable);
172 if (mlir::Operation *op = symbolTable.lookupNearestSymbolFrom(call.getOperation(), symbolRef)) {
173 SymbolLookupResult<T> result(op);
174 if (!result) {
175 return mlir::failure();
176 }
177 return result;
178 }
180 symbolTable, symbolRef, call.getOperation(), /*reportMissing=*/false
181 );
182}
183
184template <typename T>
185inline mlir::FailureOr<SymbolLookupResult<T>> resolveCallable(mlir::CallOpInterface call) {
186 mlir::SymbolTableCollection symbolTable;
187 return resolveCallable<T>(symbolTable, call);
188}
189
195mlir::FailureOr<polymorphic::TemplateOp>
196getConstResolutionTemplate(mlir::SymbolTableCollection &tables, mlir::Operation *origin);
197
199mlir::LogicalResult verifyParamOfType(
200 mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr param, mlir::Type structOrArrayType,
201 mlir::Operation *origin
202);
203
206mlir::LogicalResult verifyParamsOfType(
207 mlir::SymbolTableCollection &tables, mlir::ArrayRef<mlir::Attribute> tyParams,
208 mlir::Type structOrArrayType, mlir::Operation *origin
209);
210
212mlir::FailureOr<component::StructDefOp> verifyStructTypeResolution(
213 mlir::SymbolTableCollection &tables, component::StructType ty, mlir::Operation *origin
214);
215
217mlir::LogicalResult
218verifyTypeResolution(mlir::SymbolTableCollection &tables, mlir::Operation *origin, mlir::Type type);
219
221template <std::ranges::input_range Range>
222mlir::LogicalResult verifyTypeResolution(
223 mlir::SymbolTableCollection &tables, mlir::Operation *origin, const Range &types
224) {
225 // Check all before returning to present all applicable type errors in one compilation.
226 bool failed = false;
227 for (const auto &t : types) {
228 failed |= mlir::failed(verifyTypeResolution(tables, origin, t));
229 }
230 return mlir::LogicalResult::failure(failed);
231}
232
233} // namespace llzk
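Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.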
Definition LICENSE.txt:45
This file defines methods for symbol lookup across LLZK operations and included files.
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
SymbolRefAttr appendLeafName(SymbolRefAttr orig, const Twine &newLeafSuffix)
mlir::FlatSymbolRefAttr getFlatSymbolRefAttr(mlir::MLIRContext *context, const mlir::Twine &twine)
Construct a FlatSymbolRefAttr with the given content.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
FailureOr< ModuleOp > getRootModule(Operation *from)
FailureOr< TemplateOp > getConstResolutionTemplate(SymbolTableCollection &tables, Operation *origin)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
SymbolRefAttr replaceLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
FailureOr< ModuleOp > getTopRootModule(Operation *from)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallableSilently(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Resolve a callable without emitting a diagnostic for missing top-level symbols.
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
mlir::SymbolRefAttr getTailAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the root/head element removed.
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
FailureOr< SymbolRefAttr > getPathFromTopRoot(SymbolOpInterface to, ModuleOp *foundRoot)
LogicalResult verifyParamOfType(SymbolTableCollection &tables, SymbolRefAttr param, Type parameterizedType, Operation *origin)
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)