16#include <llvm/ADT/StringMap.h>
18#include <clang/Basic/FileManager.h>
19#include <clang/Basic/LangOptions.h>
20#include <clang/Basic/SourceManager.h>
21#include <clang/Lex/Lexer.h>
27llvm::cl::OptionCategory
28 OpGenCat(
"Options for -gen-op-capi-header, -gen-op-capi-impl, and -gen-op-capi-tests");
33 "The dialect name to use for this group of ops. "
34 "Must match across header, implementation, and test generation."
42 "The prefix to use for generated C API function names. "
43 "Default is 'mlir'. Must match across header, implementation, and test generation."
45 llvm::cl::init(
"mlir"), llvm::cl::cat(
OpGenCat)
49 "gen-isa", llvm::cl::desc(
"Generate IsA checks"), llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
53 "gen-op-build", llvm::cl::desc(
"Generate operation build(..) functions"), llvm::cl::init(
true),
58 "gen-operand-getters", llvm::cl::desc(
"Generate operand getters for operations"),
59 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
63 "gen-operand-setters", llvm::cl::desc(
"Generate operand setters for operations"),
64 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
68 "gen-attribute-getters", llvm::cl::desc(
"Generate attribute getters for operations"),
69 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
73 "gen-attribute-setters", llvm::cl::desc(
"Generate attribute setters for operations"),
74 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
78 "gen-region-getters", llvm::cl::desc(
"Generate region getters for operations"),
79 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
83 "gen-result-getters", llvm::cl::desc(
"Generate result getters for operations"),
84 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
88 "gen-type-attr-get", llvm::cl::desc(
"Generate get functions for types and attributes"),
89 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
93 "gen-parameter-getters", llvm::cl::desc(
"Generate parameter getters for types and attributes"),
94 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
98 "gen-extra-class-methods",
99 llvm::cl::desc(
"Generate C API wrappers for methods in `extraClassDeclaration`"),
100 llvm::cl::init(
true), llvm::cl::cat(
OpGenCat)
116 std::unique_ptr<DiagnosticsEngine>
diags;
127 FileSystemOptions fileSystemOpts;
128 fileMgr =
new FileManager(fileSystemOpts);
135 : impl(std::make_unique<
Impl>()) {
137 llvm::errs() <<
"Warning: ClangLexerContext created with empty source\n";
142 std::unique_ptr<llvm::MemoryBuffer> buffer = llvm::MemoryBuffer::getMemBuffer(
source, bufferName);
144 llvm::errs() <<
"Error: Failed to create memory buffer for ClangLexerContext\n";
148 FileID fileID = impl->sourceMgr->createFileID(std::move(buffer), SrcMgr::C_User);
149 llvm::MemoryBufferRef bufferRef = impl->sourceMgr->getBufferOrFake(fileID);
151 if (bufferRef.getBufferSize() == 0 && !
source.empty()) {
152 llvm::errs() <<
"Error: Failed to get buffer from source manager in ClangLexerContext\n";
157 impl->lexer = std::make_unique<Lexer>(fileID, bufferRef, *impl->sourceMgr, impl->langOpts);
159 impl->lexer->SetCommentRetentionState(
true);
160 lexer = impl->lexer.get();
164 assert(lexer &&
"Lexer not initialized - check isValid() before calling getLexer()");
170 impl && impl->sourceMgr &&
171 "SourceManager not initialized - check isValid() before calling getSourceManager()"
173 return *impl->sourceMgr;
178static inline bool isAccessModifier(StringRef tokenText) {
179 return tokenText ==
"private" || tokenText ==
"public" || tokenText ==
"protected";
187 std::vector<Token> tokens;
189 for (Token tok; !lexer.LexFromRawLexer(tok);) {
190 if (tok.is(tok::eof)) {
193 tokens.push_back(tok);
199static inline std::string getDocumentation(
200 size_t returnTypeStart,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr
203 for (
size_t j = returnTypeStart; j > 0; --j) {
204 Token curr = tokens[j - 1];
205 if (curr.is(tok::comment)) {
206 StringRef comment(sourceMgr.getCharacterData(curr.getLocation()), curr.getLength());
207 comment.consume_front(
"///");
208 comment.consume_front(
"//");
209 if (comment.consume_front(
"/*")) {
210 comment.consume_back(
"*/");
212 comment = comment.trim();
215 if (!comment.empty()) {
216 std::string newDoc(
"/// ");
224 }
else if (!curr.is(tok::unknown)) {
227 if (curr.isOneOf(tok::semi, tok::r_brace, tok::l_brace)) {
230 if (curr.is(tok::raw_identifier) && isAccessModifier(curr.getRawIdentifier())) {
255updateAccessLevel(
size_t &i,
const std::vector<Token> &tokens,
AccessLevel ¤tAccess) {
256 if (i + 1 < tokens.size() && tokens[i + 1].is(tok::colon)) {
257 if (tokens[i].is(tok::raw_identifier)) {
258 StringRef name = tokens[i].getRawIdentifier();
259 if (name ==
"private") {
263 }
else if (name ==
"public") {
267 }
else if (name ==
"protected") {
287static std::string extractReturnType(
288 size_t i,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr,
289 size_t &returnTypeStart
291 std::string returnType;
293 bool isStaticMethod =
false;
296 for (
size_t j = i; j > 0; --j) {
297 Token curr = tokens[j - 1];
299 if (curr.isOneOf(tok::semi, tok::r_brace)) {
305 if (curr.is(tok::raw_identifier)) {
306 StringRef text = curr.getRawIdentifier();
307 if (text ==
"static") {
308 isStaticMethod =
true;
311 if (tokens[j].is(tok::colon) && isAccessModifier(text)) {
313 returnTypeStart = j + 1;
314 assert(returnTypeStart < tokens.size());
321 if (isStaticMethod) {
327 while (tokens[returnTypeStart].is(tok::comment)) {
329 assert(returnTypeStart <= i);
333 returnType.reserve(32);
334 llvm::raw_string_ostream returnTypeStream(returnType);
336 for (
size_t j = returnTypeStart; j < i; ++j) {
338 if (tokens[j].is(tok::comment)) {
342 StringRef tokenText(sourceMgr.getCharacterData(tokens[j].getLocation()), tokens[j].getLength());
345 if (tokens[j].is(tok::raw_identifier) && isAccessModifier(tokenText)) {
347 if (j + 1 < i && tokens[j + 1].is(tok::colon)) {
354 if (tokens[j].is(tok::raw_identifier) && tokenText ==
"return") {
366 if (tokens[j].is(tok::colon)) {
368 if (j == 0 || j + 1 >= i || !tokens[j - 1].is(tok::colon)) {
374 if (!returnType.empty() && !returnType.ends_with(
"::") && tokenText !=
"::" &&
375 !tokenText.starts_with(
"::")) {
376 returnTypeStream <<
' ';
378 returnTypeStream << tokenText;
382 return StringRef(returnType).trim().str();
395static bool parseMethodParameters(
396 size_t i,
const std::vector<Token> &tokens,
const SourceManager &sourceMgr,
397 size_t &closeParenIdx,
bool &hasParameters, std::vector<MethodParameter> ¶meters
399 const size_t tokenCount = tokens.size();
400 closeParenIdx = tokenCount;
401 hasParameters =
false;
406 size_t parenDepth = 1;
407 for (
size_t j = i + 2; j < tokenCount; ++j) {
408 if (tokens[j].is(tok::l_paren)) {
410 }
else if (tokens[j].is(tok::r_paren)) {
412 if (parenDepth == 0) {
417 std::vector<Token> paramTokens;
418 for (
size_t k = i + 2; k < j; ++k) {
419 if (k >= tokenCount) {
422 if (!tokens[k].is(tok::comment)) {
423 paramTokens.push_back(tokens[k]);
426 const size_t paramTokenCount = paramTokens.size();
429 if (paramTokenCount == 1) {
430 StringRef paramToken(
431 sourceMgr.getCharacterData(paramTokens[0].getLocation()), paramTokens[0].getLength()
433 if (paramToken !=
"void") {
434 hasParameters =
true;
436 }
else if (paramTokenCount > 1) {
437 hasParameters =
true;
442 std::string currentParamType;
443 std::string currentParamName;
444 bool inDefaultValue =
false;
446 for (
size_t k = 0; k < paramTokenCount; ++k) {
448 if (paramTokens[k].is(tok::comma)) {
450 if (!currentParamType.empty() && !currentParamName.empty()) {
451 parameters.push_back(
MethodParameter(currentParamType, currentParamName));
453 currentParamType.clear();
454 currentParamName.clear();
455 inDefaultValue =
false;
459 if (inDefaultValue) {
463 if (paramTokens[k].is(tok::equal)) {
464 inDefaultValue =
true;
469 sourceMgr.getCharacterData(paramTokens[k].getLocation()), paramTokens[k].getLength()
474 if (paramTokens[k].is(tok::raw_identifier)) {
475 if (k + 1 == paramTokenCount ||
476 (k + 1 < paramTokenCount && paramTokens[k + 1].isOneOf(tok::comma, tok::equal))) {
477 currentParamName = tokenText.str();
483 llvm::raw_string_ostream paramTypeStream(currentParamType);
484 if (!currentParamType.empty() && tokenText !=
"*" && tokenText !=
"&" &&
485 tokenText !=
"::" && !tokenText.starts_with(
"::") &&
486 !StringRef(currentParamType).ends_with(
"::")) {
487 paramTypeStream <<
' ';
489 paramTypeStream << tokenText;
490 currentParamType = paramTypeStream.str();
494 if (!currentParamType.empty() && !currentParamName.empty()) {
495 parameters.push_back(
MethodParameter(currentParamType, currentParamName));
516checkConstAndFindEnd(
size_t closeParenIdx,
const std::vector<Token> &tokens,
size_t &endIdx) {
517 bool isConst =
false;
518 endIdx = closeParenIdx + 1;
520 while (endIdx < tokens.size()) {
521 Token curr = tokens[endIdx];
522 if (curr.isOneOf(tok::semi, tok::l_brace)) {
525 if (curr.is(tok::raw_identifier) && curr.getRawIdentifier() ==
"const") {
553 if (extraDecl.empty()) {
560 llvm::errs() <<
"Error: Failed to create lexer context for parseExtraMethods\n";
565 llvm::StringMap<std::optional<ExtraMethod>> methods;
568 const std::vector<Token> tokens = tokenize(lexerCtx);
569 const size_t tokenCount = tokens.size();
576 for (
size_t i = 0; i < tokenCount; ++i) {
578 if (tokens[i].is(tok::comment)) {
583 if (updateAccessLevel(i, tokens, currentAccess)) {
594 if (i + 1 < tokenCount && tokens[i + 1].is(tok::l_paren) && tokens[i].is(tok::raw_identifier)) {
595 StringRef methodName = tokens[i].getRawIdentifier();
603 size_t returnTypeStart = 0;
604 std::string returnType = extractReturnType(i, tokens, sourceMgr, returnTypeStart);
607 if (returnType.empty()) {
612 size_t closeParenIdx = tokenCount;
613 bool hasParameters =
false;
614 std::vector<MethodParameter> parameters;
615 if (!parseMethodParameters(i, tokens, sourceMgr, closeParenIdx, hasParameters, parameters)) {
622 bool isConst = checkConstAndFindEnd(closeParenIdx, tokens, endIdx);
625 if (!returnType.empty() && !methodName.empty()) {
626 if (methods.contains(methodName)) {
627 warnSkipped(methodName,
"C API does not support method overloading");
628 methods[methodName] = std::nullopt;
633 method.
documentation = getDocumentation(returnTypeStart, tokens, sourceMgr);
637 methods[methodName] = std::make_optional(method);
647 return llvm::to_vector(
649 llvm::make_filter_range(methods, [](
const auto &p) {
return p.second.has_value(); }),
650 [](
const auto &p) {
return p.second.value(); }
657 if (cppType == typeName) {
662 StringRef prefix = cppType;
663 prefix.consume_front(
"::");
664 if (prefix.consume_front(
"mlir::")) {
665 return prefix == typeName;
673 cppType = cppType.trim();
677 return std::make_optional(cppType.str());
682 return std::make_optional(
"int64_t");
687 if (cppType.ends_with(
" *") || cppType.ends_with(
"*")) {
688 size_t starPos = cppType.rfind(
'*');
689 if (starPos != StringRef::npos) {
690 StringRef baseType = cppType.substr(0, starPos).trim();
692 return std::make_optional(
"MlirAsmState");
695 return std::make_optional(
"MlirBytecodeWriterConfig");
698 return std::make_optional(
"MlirContext");
701 return std::make_optional(
"MlirDialect");
704 return std::make_optional(
"MlirDialectRegistry");
707 return std::make_optional(
"MlirOperation");
710 return std::make_optional(
"MlirBlock");
713 return std::make_optional(
"MlirOpOperand");
716 return std::make_optional(
"MlirOpPrintingFlags");
719 return std::make_optional(
"MlirRegion");
722 return std::make_optional(
"MlirSymbolTable");
725 llvm::errs() <<
"Error: Failed to parse pointer type: " << cppType <<
'\n';
732 return std::make_optional(
"MlirAttribute");
735 return std::make_optional(
"MlirIdentifier");
738 return std::make_optional(
"MlirLocation");
741 return std::make_optional(
"MlirModule");
744 return std::make_optional(
"MlirType");
747 return std::make_optional(
"MlirValue");
751 return std::make_optional(
"MlirAffineExpr");
755 return std::make_optional(
"MlirAffineMap");
759 return std::make_optional(
"MlirIntegerSet");
763 return std::make_optional(
"MlirTypeID");
768 return std::make_optional(
"MlirStringRef");
771 return std::make_optional(
"MlirLogicalResult");
775 if (cppType.ends_with(
"Type")) {
776 return std::make_optional(
"MlirType");
778 if (cppType.ends_with(
"Attr")) {
779 return std::make_optional(
"MlirAttribute");
781 if (cppType.ends_with(
"Op")) {
782 return std::make_optional(
"MlirOperation");
791 assert(!
isArrayRefType(cppType) &&
"must check `isArrayRefType()` outside");
794 if (capiTypeOpt.has_value()) {
795 return capiTypeOpt.value();
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
std::string mapCppTypeToCapiType(StringRef cppType)
AccessLevel
Access level tracking for C++ class declarations.
SmallVector< ExtraMethod > parseExtraMethods(StringRef extraDecl)
Parse method declarations from extraClassDeclaration using Clang's Lexer.
bool matchesMLIRClass(StringRef cppType, StringRef typeName)
Check if a C++ type matches an MLIR type pattern.
llvm::cl::OptionCategory OpGenCat
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenTypeOrAttrParamGetters
bool isPrimitiveType(mlir::StringRef cppType)
Check if a C++ type is a known primitive type.
llvm::cl::opt< bool > GenTypeOrAttrGet
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
bool isCppModifierKeyword(mlir::StringRef tokenText)
Check if a token text represents a C++ modifier/specifier keyword.
bool isAPIntType(mlir::StringRef cppType)
Check if a C++ type is APInt.
llvm::cl::opt< std::string > FunctionPrefix
bool isArrayRefType(mlir::StringRef cppType)
Check if a C++ type is an ArrayRef type.
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
void warnSkipped(const S &methodName, const std::string &message)
Print warning about skipping a function.
bool isCppLanguageConstruct(mlir::StringRef methodName)
Check if a method name represents a C++ control flow keyword or language construct.
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation source
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated documentation
RAII wrapper for Clang lexer infrastructure.
clang::SourceManager & getSourceManager() const
Get the source manager instance.
bool isValid() const
Check if the lexer was successfully created.
ClangLexerContext(mlir::StringRef source, mlir::StringRef bufferName="input")
Construct a lexer context for the given source code.
clang::Lexer & getLexer() const
Get the lexer instance.
IntrusiveRefCntPtr< DiagnosticOptions > diagOpts
Diagnostic options for configuring diagnostics.
IntrusiveRefCntPtr< FileManager > fileMgr
File manager for handling virtual files.
IntrusiveRefCntPtr< DiagnosticIDs > diagIDs
Diagnostic IDs for error reporting.
LangOptions langOpts
C++ language options for lexer configuration.
std::unique_ptr< SourceManager > sourceMgr
Source manager for tracking file locations.
std::unique_ptr< Lexer > lexer
The actual lexer instance.
std::unique_ptr< DiagnosticsEngine > diags
Diagnostics engine for handling errors and warnings.
Structure to represent a parameter in a parsed method signature from an extraClassDeclaration