LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
CommonCAPIGen.h
Go to the documentation of this file.
1//===- CommonCAPIGen.h - Common utilities for C API generation ------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Common utilities shared between all CAPI generators (ops, attrs, types)
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
16#include <mlir/TableGen/Dialect.h>
17
18#include <llvm/ADT/StringExtras.h>
19#include <llvm/ADT/StringRef.h>
20#include <llvm/ADT/StringSwitch.h>
21#include <llvm/Support/CommandLine.h>
22#include <llvm/Support/FormatVariadic.h>
23
24#include <memory>
25#include <string>
26
27constexpr bool WARN_SKIPPED_METHODS = false;
28
/// Print warning about skipping a function.
///
/// Writes `Warning: Skipping method '<methodName>' - <message>` followed by a newline to
/// `llvm::errs()`. `S` only has to be streamable into that output stream.
// NOTE(review): this listing shows one more `}` than `{` and skips a line number directly after
// the signature, so a guarding statement appears to have been dropped by the extraction --
// presumably a check of WARN_SKIPPED_METHODS; confirm against the repository.
30template <typename S> inline void warnSkipped(const S &methodName, const std::string &message) {
32 llvm::errs() << "Warning: Skipping method '" << methodName << "' - " << message << '\n';
33 }
34}
35
/// Print warning about skipping a function due to no conversion of C++ type to C API type.
///
/// Forwards to `warnSkipped` with the message `no conversion to C API type for '<cppType>'`.
// NOTE(review): as with `warnSkipped`, the braces in this listing are unbalanced and a line
// number is skipped after the signature; a guarding statement was presumably dropped by the
// extraction -- confirm against the repository.
37template <typename S>
38inline void warnSkippedNoConversion(const S &methodName, const std::string &cppType) {
40 warnSkipped(methodName, "no conversion to C API type for '" + cppType + '\'');
41 }
42}
43
44// Forward declarations for Clang classes
45namespace clang {
46class Lexer;
47class SourceManager;
48} // namespace clang
49
50// Shared command-line options used by all CAPI generators
51extern llvm::cl::OptionCategory OpGenCat;
52extern llvm::cl::opt<std::string> DialectName;
53extern llvm::cl::opt<std::string> FunctionPrefix;
54
55// Shared flags for controlling code generation
56extern llvm::cl::opt<bool> GenIsA;
57extern llvm::cl::opt<bool> GenOpBuild;
58extern llvm::cl::opt<bool> GenOpOperandGetters;
59extern llvm::cl::opt<bool> GenOpOperandSetters;
60extern llvm::cl::opt<bool> GenOpAttributeGetters;
61extern llvm::cl::opt<bool> GenOpAttributeSetters;
62extern llvm::cl::opt<bool> GenOpRegionGetters;
63extern llvm::cl::opt<bool> GenOpResultGetters;
64extern llvm::cl::opt<bool> GenTypeOrAttrGet;
65extern llvm::cl::opt<bool> GenTypeOrAttrParamGetters;
66extern llvm::cl::opt<bool> GenExtraClassMethods;
67
75inline std::string toPascalCase(mlir::StringRef str) {
76 if (str.empty()) {
77 return "";
78 }
79
80 std::string result;
81 result.reserve(str.size());
82 llvm::raw_string_ostream resultStream(result);
83 bool capitalizeNext = true;
84
85 for (char c : str) {
86 if (c == '_' || c == ':') {
87 capitalizeNext = true;
88 } else {
89 resultStream << (capitalizeNext ? llvm::toUpper(c) : c);
90 capitalizeNext = false;
91 }
92 }
93
94 return result;
95}
96
100inline bool isIntegerType(mlir::StringRef type) {
101 // Consume optional root namespace token
102 type.consume_front("::");
103 // Handle special names first
104 if (type == "signed" || type == "unsigned" || type == "size_t" || type == "char32_t" ||
105 type == "char16_t" || type == "char8_t" || type == "wchar_t") {
106 return true;
107 }
108 // Handle standard integer types with optional signed/unsigned prefix
109 type.consume_front("signed ") || type.consume_front("unsigned ");
110 if (type == "char" || type == "int" || type == "short" || type == "short int" || type == "long" ||
111 type == "long int" || type == "long long" || type == "long long int") {
112 return true;
113 }
114 // Handle fixed-width integer types (https://cppreference.com/w/cpp/types/integer.html)
115 type.consume_front("std::"); // optional
116 if (type.consume_back("_t") && (type.consume_front("int") || type.consume_front("uint"))) {
117 // intmax_t, intptr_t, uintmax_t, uintptr_t
118 if (type == "max" || type == "ptr") {
119 return true;
120 }
121 // Optional "_fast" or "_least" and finally bit width to cover the rest
122 type.consume_back("_fast") || type.consume_back("_least");
123 if (type == "8" || type == "16" || type == "32" || type == "64") {
124 return true;
125 }
126 }
127 return false;
128}
129
136inline bool isPrimitiveType(mlir::StringRef cppType) {
137 cppType.consume_front("::");
138 return cppType == "void" || cppType == "bool" || cppType == "float" || cppType == "double" ||
139 cppType == "long double" || isIntegerType(cppType);
140}
141
145inline bool isCppModifierKeyword(mlir::StringRef tokenText) {
146 return llvm::StringSwitch<bool>(tokenText)
147 .Case("inline", true)
148 .Case("static", true)
149 .Case("virtual", true)
150 .Case("explicit", true)
151 .Case("constexpr", true)
152 .Case("consteval", true)
153 .Case("extern", true)
154 .Case("mutable", true)
155 .Case("friend", true)
156 .Default(false);
157}
158
162inline bool isCppLanguageConstruct(mlir::StringRef methodName) {
163 return llvm::StringSwitch<bool>(methodName)
164 .Case("if", true)
165 .Case("for", true)
166 .Case("while", true)
167 .Case("switch", true)
168 .Case("return", true)
169 .Case("sizeof", true)
170 .Case("decltype", true)
171 .Case("alignof", true)
172 .Case("typeid", true)
173 .Case("static_assert", true)
174 .Case("noexcept", true)
175 .Default(false);
176}
177
181inline bool isAPIntType(mlir::StringRef cppType) {
182 cppType.consume_front("::");
183 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
184 return cppType == "APInt";
185}
186
190inline bool isArrayRefType(mlir::StringRef cppType) {
191 cppType.consume_front("::");
192 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
193 return cppType.starts_with("ArrayRef<");
194}
195
197inline mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType) {
198 assert(isArrayRefType(cppType) && "must check `isArrayRefType()` outside");
199
200 // Remove "ArrayRef<" prefix and ">" suffix
201 cppType.consume_front("::");
202 cppType.consume_front("llvm::") || cppType.consume_front("mlir::");
203 cppType.consume_front("ArrayRef<") && cppType.consume_back(">");
204 return cppType;
205}
206
// NOTE(review): the line that opens this class is not present in this listing; judging by the
// constructor below it is `ClangLexerContext` -- confirm against the repository.
216public:
/// Construct a lexer context for the given source code.
/// `bufferName` defaults to "input".
220 explicit ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName = "input");
221
/// Get the lexer instance.
224 clang::Lexer &getLexer() const;
225
/// Get the source manager instance.
228 clang::SourceManager &getSourceManager() const;
229
/// Check if the lexer was successfully created (i.e. `lexer` is non-null).
232 bool isValid() const { return lexer != nullptr; }
233
234private:
// `Impl` is only forward-declared here; its definition is not part of this header listing.
235 struct Impl;
236 std::unique_ptr<Impl> impl;
// Null until a lexer exists; `isValid()` reports on this pointer.
237 clang::Lexer *lexer = nullptr;
238};
239
// NOTE(review): the line that opens this struct is not present in this listing; judging by the
// constructor below it is `MethodParameter` -- confirm against the repository.
/// The C++ type of the parameter.
244 std::string type;
/// The name of the parameter.
246 std::string name;
247
/// Construct a new Method Parameter object.
/// Both strings are stored with leading and trailing whitespace trimmed.
251 MethodParameter(const std::string &paramType, const std::string &paramName)
252 : type(mlir::StringRef(paramType).trim().str()),
253 name(mlir::StringRef(paramName).trim().str()) {}
254};
255
// Structure to represent a parsed method signature from an extraClassDeclaration.
// NOTE(review): the line that opens this struct is not present in this listing; other
// declarations in this file refer to it as `ExtraMethod` -- confirm against the repository.
/// The C++ return type of the method.
262 std::string returnType;
/// The name of the method.
264 std::string methodName;
/// Properly escaped documentation comment (if any)
266 std::string documentation;
/// Whether the method is const-qualified.
268 bool isConst = false;
/// Whether the method has parameters.
270 bool hasParameters = false;
/// The parameters of the method.
272 std::vector<MethodParameter> parameters;
273};
274
/// Parse method declarations from an extraClassDeclaration using Clang's Lexer.
/// Returns one `ExtraMethod` per parsed declaration.
301llvm::SmallVector<ExtraMethod> parseExtraMethods(mlir::StringRef extraDecl);
302
/// Check if a C++ type matches an MLIR type pattern.
// NOTE(review): the exact matching rules live in the implementation file, which is not visible
// here.
307bool matchesMLIRClass(mlir::StringRef cppType, mlir::StringRef typeName);
308
/// Convert C++ type to MLIR C API type.
/// Callers in this header treat an empty optional as "no conversion available" and skip the
/// method being generated.
312std::optional<std::string> tryCppTypeToCapiType(mlir::StringRef cppType);
313
/// Map C++ type to corresponding C API type.
// NOTE(review): behavior for types without a known mapping is not visible from this header.
320std::string mapCppTypeToCapiType(mlir::StringRef cppType);
321
/// Abstract base shared by the generators in this header: it holds the record kind, the output
/// stream and the current dialect namespace / class name, and drives generation of extra
/// methods.
323struct Generator {
/// Stores `recordKind` in `kind`, binds `os` to `outputStream`, and initializes
/// `dialectNameCapitalized` with the PascalCase form of the `DialectName` option.
324 Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
325 : kind(recordKind), os(outputStream), dialectNameCapitalized(toPascalCase(DialectName)) {}
326 virtual ~Generator() = default;
327
/// Set the dialect and class name for code generation.
/// Records the dialect's C++ namespace and `cppClassName` in the members below.
// NOTE(review): both are kept as `mlir::StringRef`, so presumably the referenced strings must
// outlive this generator -- confirm with the callers.
331 virtual void
332 setNamespaceAndClassName(const mlir::tblgen::Dialect &d, mlir::StringRef cppClassName) {
333 this->dialectNamespace = d.getCppNamespace();
334 this->className = cppClassName;
335 }
336
/// Generate code for extra methods from an extraClassDeclaration.
/// Does nothing for an empty declaration; otherwise calls `genExtraMethod` once for every
/// method returned by `parseExtraMethods`.
339 virtual void genExtraMethods(mlir::StringRef extraDecl) const {
340 if (extraDecl.empty()) {
341 return;
342 }
343 for (const ExtraMethod &method : parseExtraMethods(extraDecl)) {
344 genExtraMethod(method);
345 }
346 }
347
/// Generate code for an extra method. Implemented by each concrete generator.
350 virtual void genExtraMethod(const ExtraMethod &method) const = 0;
351
352protected:
// Record kind passed to the constructor; substituted into generated names such as `Mlir{kind}`.
353 std::string kind;
// Destination of all generated text.
354 llvm::raw_ostream &os;
// NOTE(review): the constructor initializes `dialectNameCapitalized`, but its declaration does
// not appear in this listing (a line number is skipped here) -- confirm against the repository.
356 mlir::StringRef dialectNamespace;
357 mlir::StringRef className;
358};
359
/// Generator for common C header file elements.
// NOTE(review): several line numbers are skipped inside this struct in the listing (right after
// the opening line, inside raw strings and inside some argument lists), so parts of the original
// text appear to be missing here -- confirm against the repository before relying on details.
361struct HeaderGenerator : public Generator {
363 ~HeaderGenerator() override = default;
364
/// Write the opening of a generated header to `os`: the two `#include` lines followed by the
/// `extern "C" {` opener guarded by `#ifdef __cplusplus`.
365 virtual void genPrologue() const {
366 os << R"(
367#include "llzk-c/Builder.h"
368#include <mlir-c/IR.h>
370#ifdef __cplusplus
371extern "C" {
372#endif
373)";
374 }
375
/// Write the closing of a generated header to `os`: the `}` that ends the `extern "C"` block,
/// guarded by `#ifdef __cplusplus`.
376 virtual void genEpilogue() const {
377 os << R"(
378#ifdef __cplusplus
379}
380#endif
381)";
382 }
383
/// Write the declaration of the `...IsA_...` predicate for the current class to `os`.
/// Asserts that `dialectNamespace` has been set.
// NOTE(review): the visible format string uses placeholders {0}..{3}, yet the argument list
// shown has no entry commented `{2}` and passes an extra `{4}`; the listing skips line numbers
// at both spots, so a format line and an argument were presumably dropped -- confirm.
384 virtual void genIsADecl() const {
385 static constexpr char fmt[] = R"(
387MLIR_CAPI_EXPORTED bool {0}{1}IsA_{2}_{3}(Mlir{1});
388)";
389 assert(!dialectNamespace.empty() && "Dialect must be set");
390 os << llvm::formatv(
391 fmt,
392 FunctionPrefix, // {0}
393 kind, // {1}
395 className, // {3}
396 dialectNamespace // {4}
397 );
398 }
399
/// Generate declaration for an extra method from an extraClassDeclaration.
/// Emits nothing when the return type or any parameter type has no C API equivalent.
/// Otherwise writes a documentation line (the method's own documentation, or `/// <name>` when
/// it has none) followed by an `MLIR_CAPI_EXPORTED` prototype whose first parameter is
/// `Mlir{kind} inp`.
401 void genExtraMethod(const ExtraMethod &method) const override {
402 // Convert return type to C API type, skip if it can't be converted
403 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
404 if (!capiReturnTypeOpt.has_value()) {
406 return;
407 }
408 std::string capiReturnType = capiReturnTypeOpt.value();
409
410 // Build parameter list
411 std::string paramList;
412 llvm::raw_string_ostream paramListStream(paramList);
// The wrapped object itself is always the first parameter of the C function.
413 paramListStream << llvm::formatv("Mlir{0} inp", kind);
414 for (const auto &param : method.parameters) {
415 // Convert C++ type to C API type for parameter, skip if it can't be converted
416 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
417 if (!capiParamTypeOpt.has_value()) {
418 warnSkippedNoConversion(method.methodName, param.type);
419 return;
420 }
421 const std::string &capiParamType = capiParamTypeOpt.value();
422 paramListStream << ", " << capiParamType << ' ' << param.name;
423 }
424
425 // Generate declaration
426 if (method.documentation.empty()) {
427 os << llvm::formatv("\n/// {0}\n", method.methodName);
428 } else {
429 os << llvm::formatv("\n{0}\n", method.documentation);
430 }
// NOTE(review): the format string references {2}, but no argument commented `{2}` is visible
// (a line number is skipped in the listing) -- presumably dropped by the extraction; confirm.
431 os << llvm::formatv(
432 "MLIR_CAPI_EXPORTED {0} {1}{2}_{3}{4}({5});\n",
433 capiReturnType, // {0}
434 FunctionPrefix, // {1}
436 className, // {3}
437 toPascalCase(method.methodName), // {4}
438 paramList // {5}
439 );
440 }
441};
442
// Generator for common C implementation file elements.
// NOTE(review): the line that opens this struct is not present in this listing (the destructor
// below names it `ImplementationGenerator`), and further line numbers are skipped inside the
// body -- confirm against the repository before relying on details.
446 ~ImplementationGenerator() override = default;
447
/// Prepare the implementation of the `...IsA_...` predicate for the current class.
/// Asserts that `className` has been set.
// NOTE(review): `fmt` is defined but the statement that writes it to `os` is not visible here
// (a line number is skipped before the closing brace) -- presumably dropped; confirm.
448 virtual void genIsAImpl() const {
449 static constexpr char fmt[] = R"(
450bool {0}{1}IsA_{2}_{3}(Mlir{1} inp) {{
451 return llvm::isa<{3}>(unwrap(inp));
452}
453)";
454 assert(!className.empty() && "className must be set");
456 }
457
/// Generate implementation for an extra method from an extraClassDeclaration.
/// Emits nothing when the return type or any parameter type has no C API equivalent, or when a
/// non-void, non-primitive return value cannot be wrapped. Otherwise writes a C function that
/// casts `inp` to the class, calls the C++ method with (possibly unwrapped) arguments and
/// returns the (possibly wrapped) result.
459 void genExtraMethod(const ExtraMethod &method) const override {
460 // Convert return type to C API type, skip if it can't be converted
461 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
462 if (!capiReturnTypeOpt.has_value()) {
464 return;
465 }
466 std::string capiReturnType = capiReturnTypeOpt.value();
467
468 // Build the return statement prefix and suffix
469 std::string returnPrefix;
470 std::string returnSuffix;
471 mlir::StringRef cppReturnType = method.returnType;
472
473 if (cppReturnType == "void") {
474 // "void" type doesn't even need "return"
475 returnPrefix = "";
476 returnSuffix = "";
477 } else {
478 // Check if return needs wrapping
479 if (isPrimitiveType(cppReturnType)) {
480 // Primitive types don't need wrapping
481 returnPrefix = "return ";
482 returnSuffix = "";
483 } else if (capiReturnType.starts_with("Mlir") || isAPIntType(cppReturnType)) {
484 // MLIR C API types and APInt type need wrapping
485 returnPrefix = "return wrap(";
486 returnSuffix = ")";
487 } else {
// Any other return type: give up on this method without emitting anything.
488 return;
489 }
490 }
491
492 // Build parameter list for C API function signature
493 std::string paramList;
494 llvm::raw_string_ostream paramListStream(paramList);
// The wrapped object itself is always the first parameter of the C function.
495 paramListStream << llvm::formatv("Mlir{0} inp", kind);
496 for (const auto &param : method.parameters) {
497 // Convert C++ type to C API type for parameter, skip if it can't be converted
498 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
499 if (!capiParamTypeOpt.has_value()) {
500 warnSkippedNoConversion(method.methodName, param.type);
501 return;
502 }
503 const std::string &capiParamType = capiParamTypeOpt.value();
504 paramListStream << ", " << capiParamType << ' ' << param.name;
505 }
506
507 // Build argument list for C++ method call
508 std::string argList;
509 llvm::raw_string_ostream argListStream(argList);
510 for (size_t i = 0; i < method.parameters.size(); ++i) {
511 if (i > 0) {
512 argListStream << ", ";
513 }
514 const auto &param = method.parameters[i];
515
516 // Check if parameter needs unwrapping
517 mlir::StringRef cppParamType = param.type;
518 if (isPrimitiveType(cppParamType)) {
519 // Primitive types don't need unwrapping
520 argListStream << param.name;
521 } else if (isAPIntType(cppParamType)) {
522 // APInt needs unwrapping
523 argListStream << "unwrap(" << param.name << ')';
524 } else {
525 // Convert C++ type to C API type for parameter, skip if it can't be converted
526 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(cppParamType);
527 if (capiParamTypeOpt.has_value() && capiParamTypeOpt->starts_with("Mlir")) {
528 // MLIR C API types need unwrapping
529 argListStream << "unwrap(" << param.name << ')';
530 } else {
531 warnSkippedNoConversion(method.methodName, cppParamType.str());
532 return;
533 }
534 }
535 }
536
537 // Generate implementation
538 os << '\n';
// NOTE(review): the format string references {2}, but no argument commented `{2}` is visible
// (a line number is skipped in the listing) -- presumably dropped by the extraction; confirm.
539 os << llvm::formatv(
540 "{0} {1}{2}_{3}{4}({5}) {{\n",
541 capiReturnType, // {0}
542 FunctionPrefix, // {1}
544 className, // {3}
545 toPascalCase(method.methodName), // {4}
546 paramList // {5}
547 );
// Body: `<returnPrefix>llvm::cast<Class>(unwrap(inp)).<method>(<args>)<returnSuffix>;`
548 os << llvm::formatv(
549 " {0}llvm::cast<{1}>(unwrap(inp)).{2}({3}){4};\n",
550 returnPrefix, // {0}
551 className, // {1}
552 method.methodName, // {2}
553 argList, // {3}
554 returnSuffix // {4}
555 );
556 os << "}\n";
557 }
558};
559
/// Generator for common test implementation file elements.
// NOTE(review): several line numbers are skipped inside this struct in the listing (right after
// the opening line, inside raw strings and inside some argument lists), so parts of the original
// text appear to be missing here -- confirm against the repository before relying on details.
561struct TestGenerator : public Generator {
563 ~TestGenerator() override = default;
564
/// Generate the test class prologue: an empty `<Dialect><kind>LinkTests` fixture class derived
/// from `CAPITest`.
566 virtual void genTestClassPrologue() const {
567 static constexpr char fmt[] = "class {0}{1}LinkTests : public CAPITest {{};\n";
568 os << llvm::formatv(fmt, dialectNameCapitalized, kind);
569 }
570
/// Generate IsA test for a class.
/// The emitted test creates an index-typed object, expects the `...IsA_...` predicate to be
/// false for it, and ends with the line produced from `genCleanup()`.
/// Asserts that `className` has been set.
// NOTE(review): the format string references {2}, but no argument commented `{2}` is visible
// (a line number is skipped in the listing) -- presumably dropped by the extraction; confirm.
572 virtual void genIsATest() const {
573 static constexpr char fmt[] = R"(
575TEST_F({2}{1}LinkTests, IsA_{2}_{3}) {{
576 auto test{1} = createIndex{1}();
577
578 // This will always return false since `createIndex*` returns an MLIR builtin
579 EXPECT_FALSE({0}{1}IsA_{2}_{3}(test{1}));
580
581 {4}(test{1});
582}
583)";
584 assert(!className.empty() && "className must be set");
585 os << llvm::formatv(
586 fmt,
587 FunctionPrefix, // {0}
588 kind, // {1}
590 className, // {3}
591 genCleanup() // {4}
592 );
593 }
594
/// Generate test for an extra method from extraClassDeclaration.
/// Emits nothing when the return type or any parameter type has no C API equivalent.
/// Otherwise writes a test that declares one dummy value per parameter and calls the generated
/// C function inside an `if (...IsA_...)` guard, discarding the result.
/// Asserts that `className` has been set.
596 void genExtraMethod(const ExtraMethod &method) const override {
597 // Convert return type to C API type, skip if it can't be converted
598 std::optional<std::string> capiReturnTypeOpt = tryCppTypeToCapiType(method.returnType);
599 if (!capiReturnTypeOpt.has_value()) {
601 return;
602 }
603
604 // Build parameter list for dummy values
// `dummyParams` collects the declarations; `paramList` collects the ", name" call arguments.
605 std::string dummyParams;
606 llvm::raw_string_ostream dummyParamsStream(dummyParams);
607 std::string paramList;
608 llvm::raw_string_ostream paramListStream(paramList);
609
610 for (const auto &param : method.parameters) {
611 // Convert C++ type to C API type for parameter, skip if it can't be converted
612 std::optional<std::string> capiParamTypeOpt = tryCppTypeToCapiType(param.type);
613 if (!capiParamTypeOpt.has_value()) {
614 warnSkippedNoConversion(method.methodName, param.type);
615 return;
616 }
617 const std::string &capiParamType = capiParamTypeOpt.value();
618 std::string name = param.name;
619
620 // Generate dummy value creation for each parameter
621 if (capiParamType == "bool") {
622 dummyParamsStream << " bool " << name << " = false;\n";
623 } else if (capiParamType == "MlirValue") {
// NOTE(review): the emitted line refers to a variable named `testOp` in the generated test;
// where that variable comes from is not visible in this header -- confirm.
624 dummyParamsStream << " auto " << name << " = mlirOperationGetResult(testOp, 0);\n";
625 } else if (capiParamType == "MlirType") {
626 dummyParamsStream << " auto " << name << " = createIndexType();\n";
627 } else if (capiParamType == "MlirAttribute") {
628 dummyParamsStream << " auto " << name << " = createIndexAttribute();\n";
629 } else if (capiParamType == "MlirStringRef") {
630 dummyParamsStream << " auto " << name << " = mlirStringRefCreateFromCString(\"\");\n";
631 } else if (isIntegerType(capiParamType)) {
632 dummyParamsStream << " " << capiParamType << ' ' << name << " = 0;\n";
633 } else {
634 // For unknown types, create a default-initialized variable
635 dummyParamsStream << " " << capiParamType << ' ' << name << " = {};\n";
636 }
637
638 paramListStream << ", " << name;
639 }
640
// NOTE(review): the format string references {2}, but no argument commented `{2}` is visible
// (a line number is skipped in the listing) -- presumably dropped by the extraction; confirm.
641 static constexpr char fmt[] = R"(
643TEST_F({2}{1}LinkTests, {0}_{3}_{4}) {{
644 auto test{1} = createIndex{1}();
645
646 if ({0}{1}IsA_{2}_{3}(test{1})) {{
647{5}
648 (void){0}{2}_{3}{4}(test{1}{6});
649 }
650
651 {7}(test{1});
652}
653)";
654 assert(!className.empty() && "className must be set");
655 os << llvm::formatv(
656 fmt,
657 FunctionPrefix, // {0}
658 kind, // {1}
660 className, // {3}
661 toPascalCase(method.methodName), // {4}
662 dummyParams, // {5}
663 paramList, // {6}
664 genCleanup() // {7}
665 );
666 }
667
/// Generate cleanup code for test methods.
/// The returned text is placed directly in front of `(test<kind>);` at the end of each
/// generated test, so the default "//" turns that line into a comment; overrides can return
/// the name of a function to call instead.
677 virtual std::string genCleanup() const {
678 // The default case is to just comment out the rest of the cleanup line
679 return "//";
680 }
681};
mlir::StringRef extractArrayRefElementType(mlir::StringRef cppType)
Extract element type from ArrayRef<...>
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
void warnSkippedNoConversion(const S &methodName, const std::string &cppType)
Print warning about skipping a function due to no conversion of C++ type to C API type.
llvm::cl::opt< bool > GenIsA
std::string mapCppTypeToCapiType(mlir::StringRef cppType)
Map C++ type to corresponding C API type.
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
std::optional< std::string > tryCppTypeToCapiType(mlir::StringRef cppType)
Convert C++ type to MLIR C API type.
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
constexpr bool WARN_SKIPPED_METHODS
bool isIntegerType(mlir::StringRef type)
Check if a C++ type is a known integer type.
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
bool matchesMLIRClass(mlir::StringRef cppType, mlir::StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
llvm::SmallVector< ExtraMethod > parseExtraMethods(mlir::StringRef extraDecl)
Parse method declarations from an extraClassDeclaration using Clang's Lexer.
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation source
Definition LICENSE.txt:28
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
Structure to represent a parsed method signature from an extraClassDeclaration
bool isConst
Whether the method is const-qualified.
bool hasParameters
Whether the method has parameters (unsupported for now)
std::vector< MethodParameter > parameters
The parameters of the method.
std::string returnType
The C++ return type of the method.
std::string methodName
The name of the method.
std::string documentation
Properly escaped documentation comment (if any)
virtual ~Generator()=default
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
mlir::StringRef className
mlir::StringRef dialectNamespace
virtual void genExtraMethods(mlir::StringRef extraDecl) const
Generate code for extra methods from an extraClassDeclaration
virtual void setNamespaceAndClassName(const mlir::tblgen::Dialect &d, mlir::StringRef cppClassName)
Set the dialect and class name for code generation.
virtual void genExtraMethod(const ExtraMethod &method) const =0
Generate code for an extra method.
std::string dialectNameCapitalized
llvm::raw_ostream & os
std::string kind
Generator for common C header file elements.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
virtual void genPrologue() const
void genExtraMethod(const ExtraMethod &method) const override
Generate declaration for an extra method from an extraClassDeclaration
~HeaderGenerator() override=default
virtual void genEpilogue() const
virtual void genIsADecl() const
Generator for common C implementation file elements.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
virtual void genIsAImpl() const
void genExtraMethod(const ExtraMethod &method) const override
Generate implementation for an extra method from an extraClassDeclaration
~ImplementationGenerator() override=default
std::string name
The name of the parameter.
std::string type
The C++ type of the parameter.
MethodParameter(const std::string &paramType, const std::string &paramName)
Construct a new Method Parameter object.
Generator for common test implementation file elements.
virtual void genTestClassPrologue() const
Generate the test class prologue.
Generator(std::string_view recordKind, llvm::raw_ostream &outputStream)
~TestGenerator() override=default
void genExtraMethod(const ExtraMethod &method) const override
Generate test for an extra method from extraClassDeclaration.
virtual void genIsATest() const
Generate IsA test for a class.
virtual std::string genCleanup() const
Generate cleanup code for test methods.