LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
CommonCAPIGen.cpp
Go to the documentation of this file.
1//===- CommonCAPIGen.cpp - Common utilities for C API generation ----------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Shared command-line options for all CAPI generators (ops, attrs, types)
11//
12//===----------------------------------------------------------------------===//
13
14#include "CommonCAPIGen.h"
15
16#include <llvm/ADT/StringMap.h>
17
18#include <clang/Basic/FileManager.h>
19#include <clang/Basic/LangOptions.h>
20#include <clang/Basic/SourceManager.h>
21#include <clang/Lex/Lexer.h>
22#include <optional>
23
24using namespace mlir;
25using namespace clang;
26
27llvm::cl::OptionCategory
28 OpGenCat("Options for -gen-op-capi-header, -gen-op-capi-impl, and -gen-op-capi-tests");
29
30llvm::cl::opt<std::string> DialectName(
31 "dialect",
32 llvm::cl::desc(
33 "The dialect name to use for this group of ops. "
34 "Must match across header, implementation, and test generation."
35 ),
36 llvm::cl::cat(OpGenCat)
37);
38
39llvm::cl::opt<std::string> FunctionPrefix(
40 "prefix",
41 llvm::cl::desc(
42 "The prefix to use for generated C API function names. "
43 "Default is 'mlir'. Must match across header, implementation, and test generation."
44 ),
45 llvm::cl::init("mlir"), llvm::cl::cat(OpGenCat)
46);
47
48llvm::cl::opt<bool> GenIsA(
49 "gen-isa", llvm::cl::desc("Generate IsA checks"), llvm::cl::init(true), llvm::cl::cat(OpGenCat)
50);
51
52llvm::cl::opt<bool> GenOpBuild(
53 "gen-op-build", llvm::cl::desc("Generate operation build(..) functions"), llvm::cl::init(true),
54 llvm::cl::cat(OpGenCat)
55);
56
57llvm::cl::opt<bool> GenOpOperandGetters(
58 "gen-operand-getters", llvm::cl::desc("Generate operand getters for operations"),
59 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
60);
61
62llvm::cl::opt<bool> GenOpOperandSetters(
63 "gen-operand-setters", llvm::cl::desc("Generate operand setters for operations"),
64 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
65);
66
67llvm::cl::opt<bool> GenOpAttributeGetters(
68 "gen-attribute-getters", llvm::cl::desc("Generate attribute getters for operations"),
69 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
70);
71
72llvm::cl::opt<bool> GenOpAttributeSetters(
73 "gen-attribute-setters", llvm::cl::desc("Generate attribute setters for operations"),
74 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
75);
76
77llvm::cl::opt<bool> GenOpRegionGetters(
78 "gen-region-getters", llvm::cl::desc("Generate region getters for operations"),
79 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
80);
81
82llvm::cl::opt<bool> GenOpResultGetters(
83 "gen-result-getters", llvm::cl::desc("Generate result getters for operations"),
84 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
85);
86
87llvm::cl::opt<bool> GenTypeOrAttrGet(
88 "gen-type-attr-get", llvm::cl::desc("Generate get functions for types and attributes"),
89 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
90);
91
92llvm::cl::opt<bool> GenTypeOrAttrParamGetters(
93 "gen-parameter-getters", llvm::cl::desc("Generate parameter getters for types and attributes"),
94 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
95);
96
97llvm::cl::opt<bool> GenExtraClassMethods(
98 "gen-extra-class-methods",
99 llvm::cl::desc("Generate C API wrappers for methods in `extraClassDeclaration`"),
100 llvm::cl::init(true), llvm::cl::cat(OpGenCat)
101);
102
103//===----------------------------------------------------------------------===//
104// ClangLexerContext Implementation
105//===----------------------------------------------------------------------===//
108 LangOptions langOpts;
110 IntrusiveRefCntPtr<FileManager> fileMgr;
112 IntrusiveRefCntPtr<DiagnosticIDs> diagIDs;
114 IntrusiveRefCntPtr<DiagnosticOptions> diagOpts;
116 std::unique_ptr<DiagnosticsEngine> diags;
118 std::unique_ptr<SourceManager> sourceMgr;
120 std::unique_ptr<Lexer> lexer;
121
122 Impl() : diagIDs(new DiagnosticIDs()), diagOpts(new DiagnosticOptions()) {
123 // Enable C++ language features for lexing
124 langOpts.CPlusPlus = true;
125 langOpts.CPlusPlus11 = true;
126
127 FileSystemOptions fileSystemOpts;
128 fileMgr = new FileManager(fileSystemOpts);
129 diags = std::make_unique<DiagnosticsEngine>(diagIDs, diagOpts);
130 sourceMgr = std::make_unique<SourceManager>(*diags, *fileMgr);
131 }
132};
133
134ClangLexerContext::ClangLexerContext(StringRef source, StringRef bufferName)
135 : impl(std::make_unique<Impl>()) {
136 if (source.empty()) {
137 llvm::errs() << "Warning: ClangLexerContext created with empty source\n";
138 return;
139 }
140
141 // Create a memory buffer for the input
142 std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(source, bufferName);
143 if (!buffer) {
144 llvm::errs() << "Error: Failed to create memory buffer for ClangLexerContext\n";
145 return;
146 }
147
148 FileID fileID = impl->sourceMgr->createFileID(std::move(buffer), SrcMgr::C_User);
149 llvm::MemoryBufferRef bufferRef = impl->sourceMgr->getBufferOrFake(fileID);
150
151 if (bufferRef.getBufferSize() == 0 && !source.empty()) {
152 llvm::errs() << "Error: Failed to get buffer from source manager in ClangLexerContext\n";
153 return;
154 }
155
156 // Create the lexer
157 impl->lexer = std::make_unique<Lexer>(fileID, bufferRef, *impl->sourceMgr, impl->langOpts);
158 // Enable comment parsing for extraClassDeclaration method extraction
159 impl->lexer->SetCommentRetentionState(true);
160 lexer = impl->lexer.get();
161}
162
164 assert(lexer && "Lexer not initialized - check isValid() before calling getLexer()");
165 return *lexer;
166}
167
169 assert(
170 impl && impl->sourceMgr &&
171 "SourceManager not initialized - check isValid() before calling getSourceManager()"
172 );
173 return *impl->sourceMgr;
174}
175
176namespace {
177
178static inline bool isAccessModifier(StringRef tokenText) {
179 return tokenText == "private" || tokenText == "public" || tokenText == "protected";
180}
181
185static inline std::vector<Token> tokenize(const ClangLexerContext &lexerCtx) {
186 Lexer &lexer = lexerCtx.getLexer();
187 std::vector<Token> tokens;
188 tokens.reserve(128); // Reasonable default
189 for (Token tok; !lexer.LexFromRawLexer(tok);) {
190 if (tok.is(tok::eof)) {
191 break;
192 }
193 tokens.push_back(tok);
194 }
195 return tokens;
196}
197
199static inline std::string getDocumentation(
200 size_t returnTypeStart, const std::vector<Token> &tokens, const SourceManager &sourceMgr
201) {
202 std::string documentation;
203 for (size_t j = returnTypeStart; j > 0; --j) {
204 Token curr = tokens[j - 1];
205 if (curr.is(tok::comment)) {
206 StringRef comment(sourceMgr.getCharacterData(curr.getLocation()), curr.getLength());
207 comment.consume_front("///");
208 comment.consume_front("//");
209 if (comment.consume_front("/*")) {
210 comment.consume_back("*/");
211 }
212 comment = comment.trim();
213
214 // Trim whitespace
215 if (!comment.empty()) {
216 std::string newDoc("/// ");
217 newDoc += comment;
218 if (!documentation.empty()) {
219 newDoc += '\n';
220 newDoc += documentation;
221 }
222 documentation = std::move(newDoc);
223 }
224 } else if (!curr.is(tok::unknown)) {
225 // Stop looking backwards when we hit a non-comment, non-whitespace token
226 // that could be part of another declaration
227 if (curr.isOneOf(tok::semi, tok::r_brace, tok::l_brace)) {
228 break;
229 }
230 if (curr.is(tok::raw_identifier) && isAccessModifier(curr.getRawIdentifier())) {
231 break;
232 }
233 }
234 }
235 return documentation;
236}
237
238} // namespace
239
240//===----------------------------------------------------------------------===//
241// Method Parsing Implementation
242//===----------------------------------------------------------------------===//
243
/// Access level in effect at a given point while scanning a C++ class body.
enum class AccessLevel : std::uint8_t {
  Public,    ///< Members after `public:` (also the initial level for extra declarations).
  Private,   ///< Members after `private:`.
  Protected, ///< Members after `protected:`.
};
246
254static bool
255updateAccessLevel(size_t &i, const std::vector<Token> &tokens, AccessLevel &currentAccess) {
256 if (i + 1 < tokens.size() && tokens[i + 1].is(tok::colon)) {
257 if (tokens[i].is(tok::raw_identifier)) {
258 StringRef name = tokens[i].getRawIdentifier();
259 if (name == "private") {
260 currentAccess = AccessLevel::Private;
261 i++; // extra skip for the colon
262 return true;
263 } else if (name == "public") {
264 currentAccess = AccessLevel::Public;
265 i++; // extra skip for the colon
266 return true;
267 } else if (name == "protected") {
268 currentAccess = AccessLevel::Protected;
269 i++; // extra skip for the colon
270 return true;
271 }
272 }
273 }
274 return false;
275}
276
287static std::string extractReturnType(
288 size_t i, const std::vector<Token> &tokens, const SourceManager &sourceMgr,
289 size_t &returnTypeStart
290) {
291 std::string returnType;
292 returnTypeStart = 0;
293 bool isStaticMethod = false;
294
295 // Look backwards for return type start, stopping at declaration boundaries
296 for (size_t j = i; j > 0; --j) {
297 Token curr = tokens[j - 1];
298 // Semicolon or right brace indicates lookback has reached the end of a prior declaration.
299 if (curr.isOneOf(tok::semi, tok::r_brace)) {
300 returnTypeStart = j;
301 break;
302 }
303 // Check for "static" or access modifiers in the return type lookback (both appear as
304 // `raw_identifier` in raw token stream).
305 if (curr.is(tok::raw_identifier)) {
306 StringRef text = curr.getRawIdentifier();
307 if (text == "static") {
308 isStaticMethod = true;
309 break;
310 }
311 if (tokens[j].is(tok::colon) && isAccessModifier(text)) {
312 // In this case, `returnTypeStart` must be after the colon.
313 returnTypeStart = j + 1;
314 assert(returnTypeStart < tokens.size());
315 break;
316 }
317 }
318 }
319
320 // Skip static methods (for now)
321 if (isStaticMethod) {
322 return "";
323 }
324
325 // Adjust `returnTypeStart` for potential comment tokens. Skip as many
326 // sequential comments as needed.
327 while (tokens[returnTypeStart].is(tok::comment)) {
328 returnTypeStart++;
329 assert(returnTypeStart <= i);
330 }
331
332 // Build return type from tokens, skipping modifiers and comments.
333 returnType.reserve(32); // Reasonable default for most type names
334 llvm::raw_string_ostream returnTypeStream(returnType);
335
336 for (size_t j = returnTypeStart; j < i; ++j) {
337 // Skip comments - they should be extracted as documentation, not part of the return type
338 if (tokens[j].is(tok::comment)) {
339 continue;
340 }
341
342 StringRef tokenText(sourceMgr.getCharacterData(tokens[j].getLocation()), tokens[j].getLength());
343
344 // Skip access specifiers (e.g., "private", "public", "protected").
345 if (tokens[j].is(tok::raw_identifier) && isAccessModifier(tokenText)) {
346 // If followed by a colon, skip that too
347 if (j + 1 < i && tokens[j + 1].is(tok::colon)) {
348 j++; // Skip the colon too
349 }
350 continue;
351 }
352
353 // Skip common implementation keywords that indicate we're in code, not a declaration.
354 if (tokens[j].is(tok::raw_identifier) && tokenText == "return") {
355 // This indicates we've hit implementation code, stop parsing
356 returnType.clear();
357 break;
358 }
359
360 // Skip modifiers and language keywords that shouldn't be in the return type
361 if (tokens[j].is(tok::raw_identifier) && isCppModifierKeyword(tokenText)) {
362 continue;
363 }
364
365 // Skip standalone colons (from lookback to access specifiers)
366 if (tokens[j].is(tok::colon)) {
367 // Only skip if it's not part of ::
368 if (j == 0 || j + 1 >= i || !tokens[j - 1].is(tok::colon)) {
369 continue;
370 }
371 }
372
373 // Add spacing between tokens (but not around ::)
374 if (!returnType.empty() && !returnType.ends_with("::") && tokenText != "::" &&
375 !tokenText.starts_with("::")) {
376 returnTypeStream << ' ';
377 }
378 returnTypeStream << tokenText;
379 }
380
381 // Trim possible whitespace
382 return StringRef(returnType).trim().str();
383}
384
395static bool parseMethodParameters(
396 size_t i, const std::vector<Token> &tokens, const SourceManager &sourceMgr,
397 size_t &closeParenIdx, bool &hasParameters, std::vector<MethodParameter> &parameters
398) {
399 const size_t tokenCount = tokens.size();
400 closeParenIdx = tokenCount;
401 hasParameters = false;
402 parameters.clear();
403
404 // Initialize parenDepth to 1 to account for the opening '(' at tokens[i+1]
405 // Start scanning from i+2 (the first token after the opening paren)
406 size_t parenDepth = 1;
407 for (size_t j = i + 2; j < tokenCount; ++j) {
408 if (tokens[j].is(tok::l_paren)) {
409 parenDepth++;
410 } else if (tokens[j].is(tok::r_paren)) {
411 parenDepth--;
412 if (parenDepth == 0) {
413 closeParenIdx = j;
414
415 // Parse parameters between '(' and ')'
416 // Parameters follow the pattern: type name [, type name ...]
417 std::vector<Token> paramTokens;
418 for (size_t k = i + 2; k < j; ++k) {
419 if (k >= tokenCount) {
420 break;
421 }
422 if (!tokens[k].is(tok::comment)) {
423 paramTokens.push_back(tokens[k]);
424 }
425 }
426 const size_t paramTokenCount = paramTokens.size();
427
428 // Check if we have actual parameters (excluding just "void")
429 if (paramTokenCount == 1) {
430 StringRef paramToken(
431 sourceMgr.getCharacterData(paramTokens[0].getLocation()), paramTokens[0].getLength()
432 );
433 if (paramToken != "void") {
434 hasParameters = true;
435 }
436 } else if (paramTokenCount > 1) {
437 hasParameters = true;
438 }
439
440 // Parse individual parameters
441 if (hasParameters) {
442 std::string currentParamType;
443 std::string currentParamName;
444 bool inDefaultValue = false;
445
446 for (size_t k = 0; k < paramTokenCount; ++k) {
447 // Check for end of current parameter
448 if (paramTokens[k].is(tok::comma)) {
449 // Add the current parameter if valid
450 if (!currentParamType.empty() && !currentParamName.empty()) {
451 parameters.push_back(MethodParameter(currentParamType, currentParamName));
452 }
453 currentParamType.clear();
454 currentParamName.clear();
455 inDefaultValue = false;
456 continue;
457 }
458 // Skip tokens that are part of the default value
459 if (inDefaultValue) {
460 continue;
461 }
462 // Check for '=' which indicates start of default value
463 if (paramTokens[k].is(tok::equal)) {
464 inDefaultValue = true;
465 continue;
466 }
467
468 StringRef tokenText(
469 sourceMgr.getCharacterData(paramTokens[k].getLocation()), paramTokens[k].getLength()
470 );
471
472 // Identifier token could be part of the type or the parameter name.
473 // Simple heuristic: last identifier before comma, equal, or end is the name
474 if (paramTokens[k].is(tok::raw_identifier)) {
475 if (k + 1 == paramTokenCount ||
476 (k + 1 < paramTokenCount && paramTokens[k + 1].isOneOf(tok::comma, tok::equal))) {
477 currentParamName = tokenText.str();
478 continue;
479 }
480 }
481
482 // Other identifiers and other tokens (keywords, ::, *, &, etc.) are part of type.
483 llvm::raw_string_ostream paramTypeStream(currentParamType);
484 if (!currentParamType.empty() && tokenText != "*" && tokenText != "&" &&
485 tokenText != "::" && !tokenText.starts_with("::") &&
486 !StringRef(currentParamType).ends_with("::")) {
487 paramTypeStream << ' ';
488 }
489 paramTypeStream << tokenText;
490 currentParamType = paramTypeStream.str();
491 }
492
493 // Add the last parameter if valid
494 if (!currentParamType.empty() && !currentParamName.empty()) {
495 parameters.push_back(MethodParameter(currentParamType, currentParamName));
496 }
497 }
498
499 return true;
500 }
501 }
502 }
503
504 // Couldn't find closing paren
505 return false;
506}
507
515static bool
516checkConstAndFindEnd(size_t closeParenIdx, const std::vector<Token> &tokens, size_t &endIdx) {
517 bool isConst = false;
518 endIdx = closeParenIdx + 1;
519
520 while (endIdx < tokens.size()) {
521 Token curr = tokens[endIdx];
522 if (curr.isOneOf(tok::semi, tok::l_brace)) {
523 break;
524 }
525 if (curr.is(tok::raw_identifier) && curr.getRawIdentifier() == "const") {
526 isConst = true;
527 }
528 endIdx++;
529 }
530
531 return isConst;
532}
533
552SmallVector<ExtraMethod> parseExtraMethods(StringRef extraDecl) {
553 if (extraDecl.empty()) {
554 return {};
555 }
556
557 // Use ClangLexerContext for simplified setup
558 const ClangLexerContext lexerCtx(extraDecl, "extraClassDecl");
559 if (!lexerCtx.isValid()) {
560 llvm::errs() << "Error: Failed to create lexer context for parseExtraMethods\n";
561 return {};
562 }
563
564 // Store methods uniqued by name to detect and skip overloads (duplicate method names).
565 llvm::StringMap<std::optional<ExtraMethod>> methods;
566
567 // Parse tokens to find method declarations
568 const std::vector<Token> tokens = tokenize(lexerCtx);
569 const size_t tokenCount = tokens.size();
570 const SourceManager &sourceMgr = lexerCtx.getSourceManager();
571
572 // Track current access level to avoid generating C API wrappers for private functions. Code
573 // generated by `mlir-tblgen` puts the extra declarations in the public section by default.
574 AccessLevel currentAccess = AccessLevel::Public;
575
576 for (size_t i = 0; i < tokenCount; ++i) {
577 // Skip comments (they'll be extracted separately)
578 if (tokens[i].is(tok::comment)) {
579 continue;
580 }
581
582 // Check for access specifier changes (e.g., "private:", "public:", "protected:").
583 if (updateAccessLevel(i, tokens, currentAccess)) {
584 continue;
585 }
586
587 // Skip private and protected methods - no need to generate C API wrappers
588 if (currentAccess != AccessLevel::Public) {
589 continue;
590 }
591
592 // Look for pattern: [modifiers] <return_type> <identifier> '(' [params] ')' [const] ';'
593 // Look for an identifier followed by '('
594 if (i + 1 < tokenCount && tokens[i + 1].is(tok::l_paren) && tokens[i].is(tok::raw_identifier)) {
595 StringRef methodName = tokens[i].getRawIdentifier();
596
597 // Skip control flow keywords and other language constructs that use parentheses
598 if (isCppLanguageConstruct(methodName)) {
599 continue;
600 }
601
602 // Extract return type (everything before method name)
603 size_t returnTypeStart = 0;
604 std::string returnType = extractReturnType(i, tokens, sourceMgr, returnTypeStart);
605
606 // Skip static methods (return type is empty if static)
607 if (returnType.empty()) {
608 continue;
609 }
610
611 // Parse method parameters
612 size_t closeParenIdx = tokenCount;
613 bool hasParameters = false;
614 std::vector<MethodParameter> parameters;
615 if (!parseMethodParameters(i, tokens, sourceMgr, closeParenIdx, hasParameters, parameters)) {
616 // Couldn't find closing paren, skip this method
617 continue;
618 }
619
620 // Check for 'const' and find declaration end
621 size_t endIdx;
622 bool isConst = checkConstAndFindEnd(closeParenIdx, tokens, endIdx);
623
624 // Create method struct
625 if (!returnType.empty() && !methodName.empty()) {
626 if (methods.contains(methodName)) {
627 warnSkipped(methodName, "C API does not support method overloading");
628 methods[methodName] = std::nullopt;
629 } else {
630 ExtraMethod method;
631 method.returnType = returnType;
632 method.methodName = methodName;
633 method.documentation = getDocumentation(returnTypeStart, tokens, sourceMgr);
634 method.isConst = isConst;
635 method.hasParameters = hasParameters;
636 method.parameters = parameters;
637 methods[methodName] = std::make_optional(method);
638 }
639 }
640
641 // Skip to end of this declaration for the next iteration.
642 i = endIdx;
643 }
644 }
645
646 // Return valid methods, skipping overloaded names (nullopt entries).
647 return llvm::to_vector(
648 llvm::map_range(
649 llvm::make_filter_range(methods, [](const auto &p) { return p.second.has_value(); }),
650 [](const auto &p) { return p.second.value(); }
651 )
652 );
653}
654
656bool matchesMLIRClass(StringRef cppType, StringRef typeName) {
657 if (cppType == typeName) {
658 return true;
659 }
660
661 // Check for "::mlir::" or "mlir::" prefix
662 StringRef prefix = cppType;
663 prefix.consume_front("::");
664 if (prefix.consume_front("mlir::")) {
665 return prefix == typeName;
666 }
667
668 return false;
669}
670
672std::optional<std::string> tryCppTypeToCapiType(StringRef cppType) {
673 cppType = cppType.trim();
674
675 // Primitive types are unchanged
676 if (isPrimitiveType(cppType)) {
677 return std::make_optional(cppType.str());
678 }
679
680 // APInt type is converted via llzk::fromAPInt()
681 if (isAPIntType(cppType)) {
682 return std::make_optional("int64_t");
683 }
684
685 // Pointer type conversions happen via the `unwrap()` function generated
686 // by `DEFINE_C_API_PTR_METHODS()` in `mlir/CAPI/IR.h`
687 if (cppType.ends_with(" *") || cppType.ends_with("*")) {
688 size_t starPos = cppType.rfind('*');
689 if (starPos != StringRef::npos) {
690 StringRef baseType = cppType.substr(0, starPos).trim();
691 if (matchesMLIRClass(baseType, "AsmState")) {
692 return std::make_optional("MlirAsmState");
693 }
694 if (matchesMLIRClass(baseType, "BytecodeWriterConfig")) {
695 return std::make_optional("MlirBytecodeWriterConfig");
696 }
697 if (matchesMLIRClass(baseType, "MLIRContext")) {
698 return std::make_optional("MlirContext");
699 }
700 if (matchesMLIRClass(baseType, "Dialect")) {
701 return std::make_optional("MlirDialect");
702 }
703 if (matchesMLIRClass(baseType, "DialectRegistry")) {
704 return std::make_optional("MlirDialectRegistry");
705 }
706 if (matchesMLIRClass(baseType, "Operation")) {
707 return std::make_optional("MlirOperation");
708 }
709 if (matchesMLIRClass(baseType, "Block")) {
710 return std::make_optional("MlirBlock");
711 }
712 if (matchesMLIRClass(baseType, "OpOperand")) {
713 return std::make_optional("MlirOpOperand");
714 }
715 if (matchesMLIRClass(baseType, "OpPrintingFlags")) {
716 return std::make_optional("MlirOpPrintingFlags");
717 }
718 if (matchesMLIRClass(baseType, "Region")) {
719 return std::make_optional("MlirRegion");
720 }
721 if (matchesMLIRClass(baseType, "SymbolTable")) {
722 return std::make_optional("MlirSymbolTable");
723 }
724 } else {
725 llvm::errs() << "Error: Failed to parse pointer type: " << cppType << '\n';
726 }
727 }
728
729 // These have `wrap()`/`unwrap()` generated by `DEFINE_C_API_METHODS()` in...
730 // ... `mlir/CAPI/IR.h`
731 if (matchesMLIRClass(cppType, "Attribute")) {
732 return std::make_optional("MlirAttribute");
733 }
734 if (matchesMLIRClass(cppType, "StringAttr")) {
735 return std::make_optional("MlirIdentifier");
736 }
737 if (matchesMLIRClass(cppType, "Location")) {
738 return std::make_optional("MlirLocation");
739 }
740 if (matchesMLIRClass(cppType, "ModuleOp")) {
741 return std::make_optional("MlirModule");
742 }
743 if (matchesMLIRClass(cppType, "Type")) {
744 return std::make_optional("MlirType");
745 }
746 if (matchesMLIRClass(cppType, "Value")) {
747 return std::make_optional("MlirValue");
748 }
749 // ... `mlir/CAPI/AffineExpr.h`
750 if (matchesMLIRClass(cppType, "AffineExpr")) {
751 return std::make_optional("MlirAffineExpr");
752 }
753 // ... `mlir/CAPI/AffineMap.h`
754 if (matchesMLIRClass(cppType, "AffineMap")) {
755 return std::make_optional("MlirAffineMap");
756 }
757 // ... `mlir/CAPI/IntegerSet.h`
758 if (matchesMLIRClass(cppType, "IntegerSet")) {
759 return std::make_optional("MlirIntegerSet");
760 }
761 // ... `mlir/CAPI/Support.h`
762 if (matchesMLIRClass(cppType, "TypeID")) {
763 return std::make_optional("MlirTypeID");
764 }
765
766 // These have `wrap()`/`unwrap()` manually defined in `mlir/CAPI/Support.h`
767 if (matchesMLIRClass(cppType, "StringRef")) {
768 return std::make_optional("MlirStringRef");
769 }
770 if (matchesMLIRClass(cppType, "LogicalResult")) {
771 return std::make_optional("MlirLogicalResult");
772 }
773
774 // Heuristically map custom dialect classes to their C API equivalents
775 if (cppType.ends_with("Type")) {
776 return std::make_optional("MlirType");
777 }
778 if (cppType.ends_with("Attr")) {
779 return std::make_optional("MlirAttribute");
780 }
781 if (cppType.ends_with("Op")) {
782 return std::make_optional("MlirOperation");
783 }
784
785 // Otherwise, not sure how to convert it
786 return std::nullopt;
787}
788
789// Map C++ type to corresponding C API type
790std::string mapCppTypeToCapiType(StringRef cppType) {
791 assert(!isArrayRefType(cppType) && "must check `isArrayRefType()` outside");
792
793 std::optional<std::string> capiTypeOpt = tryCppTypeToCapiType(cppType);
794 if (capiTypeOpt.has_value()) {
795 return capiTypeOpt.value();
796 }
797
798 // Otherwise assume it's a type where the C name is a direct translation from the C++ name.
799 return toPascalCase(cppType);
800}
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
std::string mapCppTypeToCapiType(StringRef cppType)
AccessLevel
Access level tracking for C++ class declarations.
SmallVector< ExtraMethod > parseExtraMethods(StringRef extraDecl)
Parse method declarations from extraClassDeclaration using Clang's Lexer.
bool matchesMLIRClass(StringRef cppType, StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation source
Definition LICENSE.txt:28
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated documentation
Definition LICENSE.txt:32
RAII wrapper for Clang lexer infrastructure.
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
IntrusiveRefCntPtr< DiagnosticOptions > diagOpts
Diagnostic options for configuring diagnostics.
IntrusiveRefCntPtr< FileManager > fileMgr
File manager for handling virtual files.
IntrusiveRefCntPtr< DiagnosticIDs > diagIDs
Diagnostic IDs for error reporting.
LangOptions langOpts
C++ language options for lexer configuration.
std::unique_ptr< SourceManager > sourceMgr
Source manager for tracking file locations.
std::unique_ptr< Lexer > lexer
The actual lexer instance.
std::unique_ptr< DiagnosticsEngine > diags
Diagnostics engine for handling errors and warnings.
Structure to represent a parsed method signature from an extraClassDeclaration
bool isConst
Whether the method is const-qualified.
bool hasParameters
Whether the method has parameters (unsupported for now)
std::vector< MethodParameter > parameters
The parameters of the method.
std::string returnType
The C++ return type of the method.
std::string methodName
The name of the method.
std::string documentation
Properly escaped documentation comment (if any)
Structure to represent a parameter in a parsed method signature from an extraClassDeclaration