LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
CommonAttrOrTypeCAPIGen.h
Go to the documentation of this file.
1//===- CommonAttrOrTypeCAPIGen.h ------------------------------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Common utilities shared between Attr and Type CAPI generators
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
16#include "CommonCAPIGen.h"
17
18#include <mlir/TableGen/AttrOrTypeDef.h>
19
25 using HeaderGenerator::HeaderGenerator;
26 ~AttrOrTypeHeaderGenerator() override = default;
27
30 void setParamName(mlir::StringRef name) {
31 this->paramName = name;
33 }
34
36 virtual void genParameterGetterDecl(mlir::StringRef cppType) const {
37 static constexpr char fmt[] = R"(
39MLIR_CAPI_EXPORTED {7} {0}{2}_{3}Get{4}(Mlir{1});
40)";
41 assert(!dialectNamespace.empty() && "Dialect must be set");
42 assert(!paramName.empty() && "paramName must be set");
43 os << llvm::formatv(
44 fmt,
45 FunctionPrefix, // {0}
46 kind, // {1}
48 className, // {3}
50 paramName, // {5}
51 dialectNamespace, // {6}
52 mapCppTypeToCapiType(cppType) // {7}
53 );
54 }
55
57 virtual void genArrayRefParameterGetterDecls(mlir::StringRef cppType) const {
58 static constexpr char fmt[] = R"(
60MLIR_CAPI_EXPORTED intptr_t {0}{2}_{3}Get{4}Count(Mlir{1});
61
63MLIR_CAPI_EXPORTED {7} {0}{2}_{3}Get{4}At(Mlir{1}, intptr_t pos);
64)";
65 assert(!dialectNamespace.empty() && "Dialect must be set");
66 assert(!paramName.empty() && "paramName must be set");
67 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
68 os << llvm::formatv(
69 fmt,
70 FunctionPrefix, // {0}
71 kind, // {1}
73 className, // {3}
75 paramName, // {5}
76 dialectNamespace, // {6}
77 mapCppTypeToCapiType(cppElemType) // {7}
78 );
79 }
80
82 virtual void genDefaultGetBuilderDecl(const mlir::tblgen::AttrOrTypeDef &def) const {
83 static constexpr char fmt[] = R"(
85MLIR_CAPI_EXPORTED Mlir{1} {0}{2}_{3}Get(MlirContext ctx{4});
86)";
87 assert(!dialectNamespace.empty() && "Dialect must be set");
88
89 // Use raw_string_ostream for efficient string building of parameter list
90 std::string paramListBuffer;
91 llvm::raw_string_ostream paramListStream(paramListBuffer);
92 for (const auto &param : def.getParameters()) {
93 mlir::StringRef cppType = param.getCppType();
94 if (isArrayRefType(cppType)) {
95 // For ArrayRef parameters, use intptr_t for count and pointer to element type
96 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
97 paramListStream << ", intptr_t " << param.getName() << "Count, "
98 << mapCppTypeToCapiType(cppElemType) << " const *" << param.getName();
99 } else {
100 paramListStream << ", " << mapCppTypeToCapiType(cppType) << ' ' << param.getName();
101 }
102 }
104 os << llvm::formatv(
105 fmt,
106 FunctionPrefix, // {0}
107 kind, // {1}
109 className, // {3}
110 paramListBuffer, // {4}
111 dialectNamespace // {5}
112 );
113 }
114
115 void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef &def) {
116 mlir::tblgen::Dialect defDialect = def.getDialect();
117
118 // Generate for the selected dialect only
119 if (defDialect.getName() != DialectName) {
120 return;
121 }
122
123 this->setNamespaceAndClassName(defDialect, def.getCppClassName());
124
125 // Generate IsA check
126 if (GenIsA) {
127 this->genIsADecl();
128 }
129
130 // Generate default Get builder if not skipped
131 if (GenTypeOrAttrGet && !def.skipDefaultBuilders()) {
132 this->genDefaultGetBuilderDecl(def);
133 }
134
135 // Generate parameter getters
137 for (const auto &param : def.getParameters()) {
138 this->setParamName(param.getName());
139 mlir::StringRef cppType = param.getCppType();
140 if (isArrayRefType(cppType)) {
141 this->genArrayRefParameterGetterDecls(cppType);
142 } else {
143 this->genParameterGetterDecl(cppType);
144 }
145 }
148 // Generate extra class method declarations
150 std::optional<mlir::StringRef> extraDecls = def.getExtraDecls();
151 if (extraDecls.has_value()) {
152 this->genExtraMethods(extraDecls.value());
153 }
154 }
155 }
156
protected:
  /// Name of the tablegen parameter currently being generated; set via setParamName().
  /// Non-owning view of the string passed there (genCompleteRecord passes param.getName()).
  mlir::StringRef paramName;
  /// NOTE(review): presumably the PascalCase spelling of `paramName` used in generated
  /// C function names — confirm where it is assigned.
  std::string paramNameCapitalized;
160};
161
167 using ImplementationGenerator::ImplementationGenerator;
169
172 void setParamName(mlir::StringRef name) {
173 this->paramName = name;
175 }
176
177 virtual void genPrologue() const {
178 os << R"(
179#include <mlir/CAPI/IR.h>
180#include <mlir/CAPI/Support.h>
181#include <llvm/ADT/TypeSwitch.h>
182#include <utility>
183
184using namespace mlir;
185using namespace llvm;
186)";
187 }
188
189 virtual void genArrayRefParameterImpls(mlir::StringRef cppType) const {
190 static constexpr char fmt[] = R"(
191intptr_t {0}{2}_{3}Get{4}Count(Mlir{1} inp) {{
192 auto size = llvm::cast<{3}>(unwrap(inp)).get{4}().size();
193 assert(std::in_range<intptr_t>(size) && "lossy conversion");
194 return static_cast<intptr_t>(size);
195}
196
197{5} {0}{2}_{3}Get{4}At(Mlir{1} inp, intptr_t pos) {{
198 return {6}(llvm::cast<{3}>(unwrap(inp)).get{4}()[pos]);
199}
200 )";
201 assert(!className.empty() && "className must be set");
202 assert(!paramName.empty() && "paramName must be set");
203 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
204 os << llvm::formatv(
205 fmt,
206 FunctionPrefix, // {0}
207 kind, // {1}
209 className, // {3}
211 mapCppTypeToCapiType(cppElemType), // {5}
212 isPrimitiveType(cppElemType) ? "" : "wrap" // {6}
213 );
214 }
216 virtual void genParameterGetterImpl(mlir::StringRef cppType) const {
217 static constexpr char fmt[] = R"(
218{5} {0}{2}_{3}Get{4}(Mlir{1} inp) {{
219 return {6}(llvm::cast<{3}>(unwrap(inp)).get{4}());
220}
221 )";
222 assert(!className.empty() && "className must be set");
223 assert(!paramName.empty() && "paramName must be set");
224 os << llvm::formatv(
225 fmt,
226 FunctionPrefix, // {0}
227 kind, // {1}
229 className, // {3}
231 mapCppTypeToCapiType(cppType), // {5}
232 isPrimitiveType(cppType) ? "" : "wrap" // {6}
233 );
234 }
235
237 virtual void genDefaultGetBuilderImpl(const mlir::tblgen::AttrOrTypeDef &def) const {
238 static constexpr char fmt[] = R"(
239Mlir{1} {0}{2}_{3}Get(MlirContext ctx{4}) {{
240 {6}
241 return wrap({3}::get(unwrap(ctx){5}));
242}
243 )";
244 assert(!className.empty() && "className must be set");
245
246 // Use raw_string_ostream for efficient string building
247 std::string paramListBuffer;
248 std::string argListBuffer;
249 std::string prefixBuffer;
250 llvm::raw_string_ostream paramListStream(paramListBuffer);
251 llvm::raw_string_ostream argListStream(argListBuffer);
252 llvm::raw_string_ostream prefixStream(prefixBuffer);
253
254 for (const auto &param : def.getParameters()) {
255 mlir::StringRef pName = param.getName();
256 mlir::StringRef cppType = param.getCppType();
257 if (isArrayRefType(cppType)) {
258 // For ArrayRef parameters, convert from pointer + count to ArrayRef
259 mlir::StringRef cppElemType = extractArrayRefElementType(cppType);
260 std::string capiElemType = mapCppTypeToCapiType(cppElemType);
261 paramListStream << ", intptr_t " << pName << "Count, " << capiElemType << " const *"
262 << pName;
263
264 // In the call, we need to convert back to ArrayRef. Check if elements need unwrapping.
265 if (isPrimitiveType(cppElemType)) {
266 argListStream << ", ::llvm::ArrayRef<" << cppElemType << ">(" << pName << ", " << pName
267 << "Count)";
268 } else {
269 prefixStream << "SmallVector<Attribute> storage;";
270 argListStream << ", llvm::map_to_vector(unwrapList(" << pName << "Count, " << pName
271 << ", storage), [](auto a) {return llvm::cast<RecordAttr>(a);})";
272 }
273 } else {
274 std::string capiType = mapCppTypeToCapiType(cppType);
275 paramListStream << ", " << capiType << ' ' << pName;
277 // Add unwrapping if needed
278 argListStream << ", ";
279 if (isPrimitiveType(cppType)) {
280 argListStream << pName;
281 } else if (capiType == "MlirAttribute" || capiType == "MlirType") {
282 // Needs additional cast to the specific attribute/type class
283 argListStream << "::llvm::cast<" << cppType << ">(unwrap(" << pName << "))";
284 } else {
285 // Any other cases, just use an "unwrap" function
286 argListStream << "unwrap(" << pName << ")";
287 }
288 }
289 }
290
291 os << llvm::formatv(
292 fmt,
293 FunctionPrefix, // {0}
294 kind, // {1}
296 className, // {3}
297 paramListBuffer, // {4}
298 argListBuffer, // {5}
299 prefixBuffer // {6}
300 );
301 }
302
303 void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef &def) {
304 mlir::tblgen::Dialect defDialect = def.getDialect();
305
306 // Generate for the selected dialect only
307 if (defDialect.getName() != DialectName) {
308 return;
309 }
310
311 this->setNamespaceAndClassName(defDialect, def.getCppClassName());
312
313 // Generate IsA check implementation
314 if (GenIsA) {
315 this->genIsAImpl();
316 }
317
318 // Generate default Get builder implementation if not skipped
319 if (GenTypeOrAttrGet && !def.skipDefaultBuilders()) {
321 }
322
323 // Generate parameter getter implementations
325 for (const auto &param : def.getParameters()) {
326 this->setParamName(param.getName());
327 mlir::StringRef cppType = param.getCppType();
328 if (isArrayRefType(cppType)) {
329 this->genArrayRefParameterImpls(cppType);
330 } else {
331 this->genParameterGetterImpl(cppType);
332 }
333 }
334 }
335
336 // Generate extra class method implementations
338 std::optional<mlir::StringRef> extraDecls = def.getExtraDecls();
339 if (extraDecls.has_value()) {
340 this->genExtraMethods(extraDecls.value());
341 }
342 }
343 }
344
protected:
  /// Name of the tablegen parameter currently being generated; set via setParamName().
  /// Non-owning view of the string passed there (genCompleteRecord passes param.getName()).
  mlir::StringRef paramName;
  /// NOTE(review): presumably the PascalCase spelling of `paramName` used in generated
  /// C function names and C++ getter calls — confirm where it is assigned.
  std::string paramNameCapitalized;
348};
std::string mapCppTypeToCapiType(StringRef cppType)
mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType)
Extract element type from ArrayRef<...>
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for attribute/type C header files.
virtual void genArrayRefParameterGetterDecls(mlir::StringRef cppType) const
Generate accessor function for ArrayRef parameter elements.
void setParamName(mlir::StringRef name)
Set the parameter name for code generation.
void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef &def)
virtual void genDefaultGetBuilderDecl(const mlir::tblgen::AttrOrTypeDef &def) const
Generate default Get builder declaration.
~AttrOrTypeHeaderGenerator() override=default
virtual void genParameterGetterDecl(mlir::StringRef cppType) const
Generate regular getter for non-ArrayRef type parameter.
Generator for attribute/type C implementation files.
void genCompleteRecord(const mlir::tblgen::AttrOrTypeDef &def)
void setParamName(mlir::StringRef name)
Set the parameter name for code generation.
virtual void genDefaultGetBuilderImpl(const mlir::tblgen::AttrOrTypeDef &def) const
Generate default Get builder implementation.
virtual void genArrayRefParameterImpls(mlir::StringRef cppType) const
~AttrOrTypeImplementationGenerator() override=default
virtual void genParameterGetterImpl(mlir::StringRef cppType) const
mlir::StringRef className
mlir::StringRef dialectNamespace
virtual void genExtraMethods(mlir::StringRef extraDecl) const
Generate code for extra methods from an extraClassDeclaration
virtual void setNamespaceAndClassName(const mlir::tblgen::Dialect &d, mlir::StringRef cppClassName)
Set the dialect and class name for code generation.
std::string dialectNameCapitalized
llvm::raw_ostream & os
std::string kind
Generator for common C header file elements.
virtual void genIsADecl() const
Generator for common C implementation file elements.
virtual void genIsAImpl() const