17#include <mlir/TableGen/GenInfo.h>
18#include <mlir/TableGen/Operator.h>
20#include <llvm/Support/CommandLine.h>
21#include <llvm/Support/FormatVariadic.h>
22#include <llvm/TableGen/Error.h>
23#include <llvm/TableGen/Record.h>
24#include <llvm/TableGen/TableGenBackend.h>
30using namespace mlir::tblgen;
38 name.empty() ? llvm::formatv(
"Result{0}", resultIndex).str() :
toPascalCase(name);
42 name.empty() ? llvm::formatv(
"Region{0}", regionIndex).str() :
toPascalCase(name);
54 using HeaderGenerator::HeaderGenerator;
58 static constexpr char fmt[] = R
"(
60MLIR_CAPI_EXPORTED MlirOperation {0}{1}_{2}Build(MlirOpBuilder builder, MlirLocation location{3});
62 assert(!
className.empty() &&
"className must be set");
74 static constexpr char fmt[] = R
"(
76MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}(MlirOperation op);
78 assert(!
className.empty() &&
"className must be set");
91 static constexpr char fmt[] = R
"(
93MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, MlirValue value);
95 assert(!
className.empty() &&
"className must be set");
108 static constexpr char fmt[] = R
"(
110MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
113MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
115 assert(!
className.empty() &&
"className must be set");
129 static constexpr char fmt[] = R
"(
131MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values);
133 assert(!
className.empty() &&
"className must be set");
147 static constexpr char fmt[] = R
"(
151MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t groupCount, MlirValueRange const *groups);
153 assert(!
className.empty() &&
"className must be set");
166 static constexpr char fmt[] = R
"(
168MLIR_CAPI_EXPORTED MlirAttribute {0}{1}_{2}Get{3}(MlirOperation op);
183 static constexpr char fmt[] = R
"(
185MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, MlirAttribute attr);
187 assert(!
className.empty() &&
"className must be set");
200 static constexpr char fmt[] = R
"(
202MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}(MlirOperation op);
204 assert(!
className.empty() &&
"className must be set");
217 static constexpr char fmt[] = R
"(
219MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
222MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
224 assert(!
className.empty() &&
"className must be set");
237 static constexpr char fmt[] = R
"(
239MLIR_CAPI_EXPORTED MlirRegion {0}{1}_{2}Get{3}(MlirOperation op);
241 assert(!
className.empty() &&
"className must be set");
254 static constexpr char fmt[] = R
"(
256MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
259MLIR_CAPI_EXPORTED MlirRegion {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
261 assert(!
className.empty() &&
"className must be set");
279static std::string generateCAPIBuildParams(
const Operator &op) {
282 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
284 if (result.isVariadic()) {
285 os << llvm::formatv(
", intptr_t {0}Size, MlirType const *{0}Types", resultName);
287 os << llvm::formatv(
", MlirType {0}Type", resultName);
290 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
291 if (operand.isVariadic()) {
292 os << llvm::formatv(
", intptr_t {0}Size, MlirValue const *{0}", operand.name);
294 os << llvm::formatv(
", MlirValue {0}", operand.name);
297 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
299 os << llvm::formatv(
", {0} {1}", attrType.value_or(
"MlirAttribute"), attr.name);
301 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion &region)
override {
302 if (region.isVariadic()) {
303 os << llvm::formatv(
", unsigned {0}Count", region.name);
306 } paramStringGenerator;
307 return paramStringGenerator.gen(op);
311static bool emitOpCAPIHeader(
const llvm::RecordKeeper &records, raw_ostream &os) {
312 emitSourceFileHeader(
"Op C API Declarations", os, records);
315 generator.genPrologue();
317 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
318 const Operator op(def);
319 const Dialect &dialect = op.getDialect();
326 generator.setNamespaceAndClassName(dialect, op.getCppClassName());
329 if (
GenOpBuild && !op.skipDefaultBuilders()) {
330 generator.genOpBuildDecl(generateCAPIBuildParams(op));
335 generator.genIsADecl();
339 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
340 const auto &operand = op.getOperand(i);
341 generator.setOperandName(operand.name);
342 if (operand.isVariadic()) {
344 generator.genVariadicOperandGetterDecl();
347 if (operand.isVariadicOfVariadic()) {
348 generator.genVariadicOfVariadicOperandSetterDecl();
350 generator.genVariadicOperandSetterDecl();
355 generator.genOperandGetterDecl();
358 generator.genOperandSetterDecl();
364 for (
const auto &namedAttr : op.getAttributes()) {
365 generator.setAttributeName(namedAttr.name);
367 generator.genAttributeGetterDecl();
370 generator.genAttributeSetterDecl();
376 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
377 const auto &result = op.getResult(i);
378 generator.setResultName(result.name, i);
379 if (result.isVariadic()) {
380 generator.genVariadicResultGetterDecl();
382 generator.genResultGetterDecl();
389 for (
unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
390 const auto &region = op.getRegion(i);
391 generator.setRegionName(region.name, i);
392 if (region.isVariadic()) {
393 generator.genVariadicRegionGetterDecl();
395 generator.genRegionGetterDecl();
402 generator.genExtraMethods(op.getExtraClassDeclaration());
406 generator.genEpilogue();
412 using ImplementationGenerator::ImplementationGenerator;
429 const std::string &operationName,
const std::string &params,
const std::string &assignments
431 static constexpr char fmt[] = R
"(
432MlirOperation {0}{1}_{2}Build(MlirOpBuilder builder, MlirLocation location{3}) {{
433 MlirOperationState state = mlirOperationStateGet(mlirStringRefCreateFromCString("{4}"), location);
435 return mlirOpBuilderInsert(builder, mlirOperationCreate(&state));
438 assert(!className.empty() && "className must be set");
451 static constexpr char fmt[] = R
"(
452MlirValue {0}{1}_{2}Get{3}(MlirOperation op) {{
453 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
454 assert(range.second == 1 && "expected fixed operand segment size");
456 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
457 "operand index exceeds intptr_t range"
459 return mlirOperationGetOperand(op, static_cast<intptr_t>(range.first));
462 assert(!className.empty() && "className must be set");
475 static constexpr char fmt[] = R
"(
476void {0}{1}_{2}Set{3}(MlirOperation op, MlirValue value) {{
477 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
478 assert(range.second == 1 && "expected fixed operand segment size");
480 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
481 "operand index exceeds intptr_t range"
483 mlirOperationSetOperand(op, static_cast<intptr_t>(range.first), value);
486 assert(!className.empty() && "className must be set");
499 static constexpr char fmt[] = R
"(
500intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
501 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
505MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
506 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
507 assert(index >= 0 && index < range.second && "variadic operand index out of range");
509 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
510 "operand index exceeds intptr_t range"
512 return mlirOperationGetOperand(op, static_cast<intptr_t>(range.first) + index);
515 assert(!className.empty() && "className must be set");
530 static constexpr char fmt[] = R
"(
531void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values) {{
534 ::llvm::SmallVector<::mlir::Value> vals;
535 vals.reserve(static_cast<size_t>(count));
536 for (intptr_t i = 0; i < count; ++i)
537 vals.push_back(unwrap(values[i]));
538 ::llvm::cast<{2}>(unwrap(op)).get{3}Mutable().assign(vals);
541 assert(!className.empty() && "className must be set");
557 static constexpr char fmt[] = R
"(
558void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t groupCount, MlirValueRange const *groups) {{
562 ::llvm::SmallVector<::mlir::Value> vals;
563 for (intptr_t g = 0; g < groupCount; ++g) {{
564 assert(groups[g].size >= 0 && "group size must be non-negative");
565 for (intptr_t i = 0; i < groups[g].size; ++i) {{
566 vals.push_back(unwrap(groups[g].values[i]));
569 ::llvm::cast<{2}>(unwrap(op)).get{3}Mutable().join().assign(vals);
571 ::llvm::SmallVector<int32_t> newGroupSizes;
572 newGroupSizes.reserve(static_cast<size_t>(groupCount));
573 for (intptr_t g = 0; g < groupCount; ++g) {{
575 groups[g].size <= static_cast<intptr_t>(std::numeric_limits<int32_t>::max()) &&
576 "group size exceeds int32_t range"
578 newGroupSizes.push_back(static_cast<int32_t>(groups[g].size));
580 MlirContext ctx = mlirOperationGetContext(op);
582 newGroupSizes.size() <= static_cast<size_t>(std::numeric_limits<intptr_t>::max()) &&
583 "group count exceeds intptr_t range"
585 mlirOperationSetAttributeByName(
586 op, mlirStringRefCreateFromCString("{4}"),
587 mlirDenseI32ArrayGet(ctx, static_cast<intptr_t>(newGroupSizes.size()), newGroupSizes.data())
591 assert(!className.empty() && "className must be set");
604 static constexpr char fmt[] = R
"(
605MlirAttribute {0}{1}_{2}Get{3}(MlirOperation op) {{
606 return mlirOperationGetAttributeByName(op, mlirStringRefCreateFromCString("{4}"));
609 assert(!className.empty() && "className must be set");
622 static constexpr char fmt[] = R
"(
623void {0}{1}_{2}Set{3}(MlirOperation op, MlirAttribute attr) {{
624 mlirOperationSetAttributeByName(op, mlirStringRefCreateFromCString("{4}"), attr);
627 assert(!className.empty() && "className must be set");
640 static constexpr char fmt[] = R
"(
641MlirValue {0}{1}_{2}Get{3}(MlirOperation op) {{
642 return mlirOperationGetResult(op, {4});
645 assert(!className.empty() && "className must be set");
658 static constexpr char fmt[] = R
"(
659intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
660 intptr_t count = mlirOperationGetNumResults(op);
661 assert(count >= {4} && "result count less than start index");
665MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
666 return mlirOperationGetResult(op, {4} + index);
669 assert(!className.empty() && "className must be set");
682 static constexpr char fmt[] = R
"(
683MlirRegion {0}{1}_{2}Get{3}(MlirOperation op) {{
684 return mlirOperationGetRegion(op, {4});
687 assert(!className.empty() && "className must be set");
700 static constexpr char fmt[] = R
"(
701intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
702 intptr_t count = mlirOperationGetNumRegions(op);
703 assert(count >= {4} && "region count less than start index");
707MlirRegion {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
708 return mlirOperationGetRegion(op, {4} + index);
711 assert(!className.empty() && "className must be set");
729static std::string generateCAPIAssignments(
const Operator &op) {
736 void genResultInferred(llvm::raw_ostream &os)
override {
737 os <<
" mlirOperationStateEnableResultTypeInference(&state);\n";
740 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
742 if (result.isVariadic()) {
744 " mlirOperationStateAddResults(&state, {0}Size, {0}Types);\n", resultName
747 os << llvm::formatv(
" mlirOperationStateAddResults(&state, 1, &{0}Type);\n", resultName);
750 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
751 if (operand.isVariadic()) {
753 " mlirOperationStateAddOperands(&state, {0}Size, {0});\n", operand.name
756 os << llvm::formatv(
" mlirOperationStateAddOperands(&state, 1, &{0});\n", operand.name);
759 void genAttributesPrefix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
760 os <<
" MlirContext ctx = mlirOpBuilderGetContext(builder);\n";
761 os <<
" llvm::SmallVector<MlirNamedAttribute, " << op.getNumAttributes()
762 <<
"> attributes;\n";
764 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
769 std::string attrValue;
770 if (attrType.has_value() && attrType.value() ==
"MlirIdentifier") {
771 attrValue =
"reinterpret_cast<MlirAttribute&>(" + attr.name.str() +
")";
773 attrValue = attr.name.str();
776 os <<
" if (!mlirAttributeIsNull(" << attrValue <<
")) {\n";
777 os <<
" attributes.push_back(mlirNamedAttributeGet(mlirIdentifierGet(ctx, "
778 <<
"mlirStringRefCreateFromCString(\"" << attr.name <<
"\")), " << attrValue <<
"));\n";
782 genAttributesSuffix(llvm::raw_ostream &os,
const mlir::tblgen::Operator & )
override {
783 os <<
" mlirOperationStateAddAttributes(&state, attributes.size(), attributes.data());\n";
785 void genRegionsPrefix(llvm::raw_ostream &os,
const mlir::tblgen::Operator &op)
override {
786 os <<
" llvm::SmallVector<MlirRegion, " << op.getNumRegions() <<
"> regions;\n";
788 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion &region)
override {
789 if (region.isVariadic()) {
790 os << llvm::formatv(
" for (unsigned i = 0; i < {0}Count; ++i)\n ", region.name);
792 os <<
" regions.push_back(mlirRegionCreate());\n";
794 void genRegionsSuffix(llvm::raw_ostream &os,
const mlir::tblgen::Operator & )
override {
795 os <<
" mlirOperationStateAddOwnedRegions(&state, regions.size(), regions.data());\n";
797 } paramStringGenerator;
798 return paramStringGenerator.gen(op);
802static bool emitOpCAPIImpl(
const llvm::RecordKeeper &records, raw_ostream &os) {
803 emitSourceFileHeader(
"Op C API Definitions", os, records);
806 generator.genPrologue();
808 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
809 const Operator op(def);
810 const Dialect &dialect = op.getDialect();
817 generator.setNamespaceAndClassName(dialect, op.getCppClassName());
820 if (
GenOpBuild && !op.skipDefaultBuilders()) {
821 std::string assignments = generateCAPIAssignments(op);
822 generator.genOpBuildImpl(op.getOperationName(), generateCAPIBuildParams(op), assignments);
827 generator.genIsAImpl();
831 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
832 const auto &operand = op.getOperand(i);
833 generator.setOperandName(operand.name);
834 if (operand.isVariadic()) {
836 generator.genVariadicOperandGetterImpl(i);
839 if (operand.isVariadicOfVariadic()) {
840 generator.genVariadicOfVariadicOperandSetterImpl(
841 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
844 generator.genVariadicOperandSetterImpl();
849 generator.genOperandGetterImpl(i);
852 generator.genOperandSetterImpl(i);
858 for (
const auto &namedAttr : op.getAttributes()) {
859 generator.setAttributeName(namedAttr.name);
861 generator.genAttributeGetterImpl(namedAttr.name);
864 generator.genAttributeSetterImpl(namedAttr.name);
870 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
871 const auto &result = op.getResult(i);
872 generator.setResultName(result.name, i);
873 if (result.isVariadic()) {
874 generator.genVariadicResultGetterImpl(i);
876 generator.genResultGetterImpl(i);
883 for (
unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
884 const auto &region = op.getRegion(i);
885 generator.setRegionName(region.name, i);
886 if (region.isVariadic()) {
887 generator.genVariadicRegionGetterImpl(i);
889 generator.genRegionGetterImpl(i);
896 generator.genExtraMethods(op.getExtraClassDeclaration());
// Static registration object for the C API header generator: hands the
// generator name "gen-op-capi-header", its description text, and the
// emitOpCAPIHeader callback to mlir::GenRegistration.
// NOTE(review): presumably this is what exposes the generator as a tblgen
// command-line action — confirm against mlir/TableGen/GenInfo.h.
903static mlir::GenRegistration
904 genOpCAPIHeader(
"gen-op-capi-header",
"Generate operation C API header", &emitOpCAPIHeader);
// Static registration object for the C API implementation generator: hands
// the generator name "gen-op-capi-impl", its description text, and the
// emitOpCAPIImpl callback to mlir::GenRegistration.
// NOTE(review): presumably this is what exposes the generator as a tblgen
// command-line action — confirm against mlir/TableGen/GenInfo.h.
906static mlir::GenRegistration
907 genOpCAPIImpl(
"gen-op-capi-impl",
"Generate operation C API implementation", &emitOpCAPIImpl);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Helper struct to generate a string from operation operand, attribute, and result pieces.
mlir::StringRef className
mlir::StringRef dialectNamespace
std::string dialectNameCapitalized
Generator for common C implementation file elements.
Common between header and implementation generators for operations.
void setAttributeName(mlir::StringRef name)
void setResultName(mlir::StringRef name, int resultIndex)
void setOperandName(mlir::StringRef name)
std::string resultNameCapitalized
std::string attrNameCapitalized
std::string operandNameCapitalized
void setRegionName(mlir::StringRef name, unsigned regionIndex)
std::string regionNameCapitalized
Generator for operation C implementation files.
void genOperandGetterImpl(int index) const
void genRegionGetterImpl(unsigned index) const
void genAttributeGetterImpl(mlir::StringRef attrName) const
void genVariadicRegionGetterImpl(unsigned startIdx) const
void genVariadicResultGetterImpl(int startIdx) const
void genVariadicOfVariadicOperandSetterImpl(mlir::StringRef segSizeAttrName) const
void genOpBuildImpl(const std::string &operationName, const std::string &params, const std::string &assignments) const
Generate operation "Build" function implementation.
void genResultGetterImpl(int index) const
~OpImplementationGenerator() override=default
void genOperandSetterImpl(int index) const
void genVariadicOperandGetterImpl(int index) const
void genPrologue() const override
void genAttributeSetterImpl(mlir::StringRef attrName) const
void genVariadicOperandSetterImpl() const