40#include <mlir/TableGen/GenInfo.h>
41#include <mlir/TableGen/Operator.h>
43#include <llvm/ADT/StringExtras.h>
44#include <llvm/Support/CommandLine.h>
45#include <llvm/Support/FormatVariadic.h>
46#include <llvm/TableGen/Record.h>
47#include <llvm/TableGen/TableGenBackend.h>
50using namespace mlir::tblgen;
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator(
"Operation", outputStream) {}
69 std::string genCleanup()
const override {
return "mlirOperationDestroy"; };
/// Emits three things for `op`:
///  - a link-only test for the generated `<prefix><Dialect>_<Op>Build` C API
///    function;
///  - a `<Op>BuildFuncHelper` struct;
///  - a `CAPITest` test named `<Op>_build_pass` that runs
///    `<Op>BuildFuncHelper::get()->run(*this)`.
///
/// The emitted Build call is placed inside
/// `if (<prefix>OperationIsA_<Dialect>_<Op>(testOp))`. `testOp` comes from
/// `createIndexOperation()`, which the template's own comments describe as an
/// `arith.constant`. The branch therefore never runs, and the test only shows
/// that the generated code compiles and links.
///
/// The dummy locals come from `generateBuildDummyParams(op)`. The trailing
/// Build arguments come from `generateBuildParamList(op)`.
/// NOTE(review): the full formatv argument list is not visible here; confirm
/// which placeholder ({0}..{4}) each argument fills.
/// Precondition: `className` must already be set (asserted).
73 void genBuildOpTests(
const Operator &op)
const {
74 static constexpr char fmt[] = R
"(
76TEST_F({1}OperationLinkTests, {0}_{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA_{1}_{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
86 (void){0}{1}_{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
90 mlirOperationDestroy(testOp);
93struct {2}BuildFuncHelper : public TestAnyBuildFuncHelper<CAPITest> {
94 virtual bool callIsA(MlirOperation op) override { return {0}OperationIsA_{1}_{2}(op); }
98 static std::unique_ptr<{2}BuildFuncHelper> get();
101 {2}BuildFuncHelper() = default;
106TEST_F(CAPITest, {2}_build_pass) { {2}BuildFuncHelper::get()->run(*this); }
109 assert(!className.empty() &&
"className must be set");
113 dialectNameCapitalized,
115 generateBuildDummyParams(op),
116 generateBuildParamList(op)
/// Emits link-only tests for the operand accessors generated for `op`.
///
/// Each emitted test:
///  - creates `testOp` with `createIndexOperation()`;
///  - calls the accessor only inside
///    `if (<prefix>OperationIsA_<Dialect>_<Op>(testOp))`;
///  - destroys `testOp`.
/// Setter tests pass the op's result 0 as a dummy value.
///
/// For a variadic operand, `VariadicOperandGetterTest` (Count/At getters) is
/// emitted. The setter template is chosen as follows:
///  - `VariadicOfVariadicOperandSetterTest` (an `MlirValueRange` array) when
///    the operand is variadic-of-variadic;
///  - `VariadicOperandSetterTest` (an `MlirValue` array) otherwise.
///
/// NOTE(review): the non-variadic branch presumably uses `OperandGetterTest`
/// and `OperandSetterTest`. Setter emission may also be gated by an option
/// such as `GenOpOperandSetters`. Neither is visible here; confirm.
/// Precondition: `className` must already be set (asserted).
122 void genOperandTests(
const Operator &op)
const {
123 static constexpr char OperandGetterTest[] = R
"(
124TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
125 auto testOp = createIndexOperation();
127 if ({0}OperationIsA_{1}_{2}(testOp)) {{
128 (void){0}{1}_{2}Get{3}(testOp);
131 mlirOperationDestroy(testOp);
135 static constexpr char OperandSetterTest[] = R
"(
136TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
137 auto testOp = createIndexOperation();
139 if ({0}OperationIsA_{1}_{2}(testOp)) {{
140 auto dummyValue = mlirOperationGetResult(testOp, 0);
141 {0}{1}_{2}Set{3}(testOp, dummyValue);
144 mlirOperationDestroy(testOp);
148 static constexpr char VariadicOperandGetterTest[] = R
"(
149TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
150 auto testOp = createIndexOperation();
152 if ({0}OperationIsA_{1}_{2}(testOp)) {{
153 (void){0}{1}_{2}Get{3}Count(testOp);
156 mlirOperationDestroy(testOp);
159TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
160 auto testOp = createIndexOperation();
162 if ({0}OperationIsA_{1}_{2}(testOp)) {{
163 (void){0}{1}_{2}Get{3}At(testOp, 0);
166 mlirOperationDestroy(testOp);
170 static constexpr char VariadicOperandSetterTest[] = R
"(
171TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
172 auto testOp = createIndexOperation();
174 if ({0}OperationIsA_{1}_{2}(testOp)) {{
175 auto dummyValue = mlirOperationGetResult(testOp, 0);
176 MlirValue values[] = {{dummyValue};
177 {0}{1}_{2}Set{3}(testOp, 1, values);
180 mlirOperationDestroy(testOp);
184 static constexpr char VariadicOfVariadicOperandSetterTest[] = R
"(
185TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_VariadicOfVariadic) {{
186 auto testOp = createIndexOperation();
188 if ({0}OperationIsA_{1}_{2}(testOp)) {{
189 auto dummyValue = mlirOperationGetResult(testOp, 0);
190 MlirValueRange groups[1];
191 groups[0].values = &dummyValue;
193 {0}{1}_{2}Set{3}(testOp, 1, groups);
196 mlirOperationDestroy(testOp);
199 assert(!className.empty() && "className must be set");
201 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
202 const auto &operand = op.getOperand(i);
204 if (operand.isVariadic()) {
207 VariadicOperandGetterTest,
209 dialectNameCapitalized,
216 operand.isVariadicOfVariadic() ? VariadicOfVariadicOperandSetterTest
217 : VariadicOperandSetterTest,
219 dialectNameCapitalized,
229 dialectNameCapitalized,
238 dialectNameCapitalized,
/// Emits link-only tests for the attribute accessors generated for `op`,
/// iterating over `op.getAttributes()`.
///
/// Templates:
///  - `AttributeGetterTest` names its test `..._Get<Name>Attr` but calls
///    `<prefix><Dialect>_<Op>Get<Name>(testOp)`.
///  - `AttributeSetterTest` names its test `..._Set<Name>Attr` and passes
///    `createIndexAttribute()` as the new value.
/// Both call the accessor only inside the generated IsA check on the op from
/// `createIndexOperation()`, and destroy that op afterwards.
///
/// NOTE(review): how `<Name>` is derived from each attribute, and whether
/// getter/setter emission is gated by options, is not visible here; confirm.
/// Precondition: `className` must already be set (asserted).
249 void genAttributeTests(
const Operator &op)
const {
250 static constexpr char AttributeGetterTest[] = R
"(
251TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
252 auto testOp = createIndexOperation();
254 if ({0}OperationIsA_{1}_{2}(testOp)) {{
255 (void){0}{1}_{2}Get{3}(testOp);
258 mlirOperationDestroy(testOp);
262 static constexpr char AttributeSetterTest[] = R
"(
263TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
264 auto testOp = createIndexOperation();
266 if ({0}OperationIsA_{1}_{2}(testOp)) {{
267 {0}{1}_{2}Set{3}(testOp, createIndexAttribute());
270 mlirOperationDestroy(testOp);
273 assert(!className.empty() && "className must be set");
275 for (
const auto &namedAttr : op.getAttributes()) {
281 dialectNameCapitalized,
290 dialectNameCapitalized,
/// Emits link-only tests for the result getters generated for `op`.
///
/// Each result is referred to as `Result<i>` when unnamed. Otherwise the
/// PascalCase form of its name (`toPascalCase`) is used.
///
/// Variadic results use `VariadicResultGetterTest`, which emits
/// `Get<Name>Count(testOp)` and `Get<Name>At(testOp, 0)` tests. The other
/// branch presumably uses `ResultGetterTest` (`Get<Name>(testOp)`); that call
/// is not visible here, so confirm.
///
/// Every emitted call is guarded by the generated IsA check on the op from
/// `createIndexOperation()`, which is destroyed afterwards.
/// Precondition: `className` must already be set (asserted).
300 void genResultTests(
const Operator &op)
const {
301 static constexpr char ResultGetterTest[] = R
"(
302TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
303 auto testOp = createIndexOperation();
305 if ({0}OperationIsA_{1}_{2}(testOp)) {{
306 (void){0}{1}_{2}Get{3}(testOp);
309 mlirOperationDestroy(testOp);
313 static constexpr char VariadicResultGetterTest[] = R
"(
314TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
315 auto testOp = createIndexOperation();
317 if ({0}OperationIsA_{1}_{2}(testOp)) {{
318 (void){0}{1}_{2}Get{3}Count(testOp);
321 mlirOperationDestroy(testOp);
324TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
325 auto testOp = createIndexOperation();
327 if ({0}OperationIsA_{1}_{2}(testOp)) {{
328 (void){0}{1}_{2}Get{3}At(testOp, 0);
331 mlirOperationDestroy(testOp);
334 assert(!className.empty() && "className must be set");
336 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
337 const auto &result = op.getResult(i);
338 llvm::StringRef name = result.name;
339 std::string capName = name.empty() ? llvm::formatv(
"Result{0}", i).str() :
toPascalCase(name);
341 if (result.isVariadic()) {
343 VariadicResultGetterTest,
345 dialectNameCapitalized,
353 dialectNameCapitalized,
/// Emits link-only tests for the region getters generated for `op`.
///
/// Each region is referred to as `Region<i>` when unnamed. Otherwise the
/// PascalCase form of its name (`toPascalCase`) is used.
///
/// Test names by kind:
///  - variadic regions use `VariadicRegionGetterTest`, whose tests are named
///    `..._Get<Name>Count` / `..._Get<Name>At`;
///  - `RegionGetterTest` names its test `..._Get<Name>Region`.
/// Every emitted call is guarded by the generated IsA check on the op from
/// `createIndexOperation()`, which is destroyed afterwards.
///
/// NOTE(review): `®ion` on the `op.getRegion(i)` line is a mis-encoded
/// `&region` in this listing; the declaration should read
/// `const auto &region = op.getRegion(i);`.
/// Precondition: `className` must already be set (asserted).
363 void genRegionTests(
const Operator &op)
const {
364 static constexpr char RegionGetterTest[] = R
"(
365TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
366 auto testOp = createIndexOperation();
368 if ({0}OperationIsA_{1}_{2}(testOp)) {{
369 (void){0}{1}_{2}Get{3}(testOp);
372 mlirOperationDestroy(testOp);
376 static constexpr char VariadicRegionGetterTest[] = R
"(
377TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
378 auto testOp = createIndexOperation();
380 if ({0}OperationIsA_{1}_{2}(testOp)) {{
381 (void){0}{1}_{2}Get{3}Count(testOp);
384 mlirOperationDestroy(testOp);
387TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
388 auto testOp = createIndexOperation();
390 if ({0}OperationIsA_{1}_{2}(testOp)) {{
391 (void){0}{1}_{2}Get{3}At(testOp, 0);
394 mlirOperationDestroy(testOp);
397 assert(!className.empty() && "className must be set");
399 for (
unsigned int i = 0, e = op.getNumRegions(); i < e; ++i) {
400 const auto ®ion = op.getRegion(i);
401 llvm::StringRef name = region.name;
402 std::string capName = name.empty() ? llvm::formatv(
"Region{0}", i).str() :
toPascalCase(name);
404 if (region.isVariadic()) {
406 VariadicRegionGetterTest,
408 dialectNameCapitalized,
416 dialectNameCapitalized,
/// Emits every test for a single op record, in this order:
///  1. calls `setNamespaceAndClassName` with the op's dialect and C++ class
///     name;
///  2. emits build tests, only when `GenOpBuild` is enabled and the op does
///     not skip default builders;
///  3. emits operand, attribute, region and result tests;
///  4. calls `genExtraMethods(op.getExtraClassDeclaration())`.
/// NOTE(review): whether the later calls are gated by their own Gen* options
/// (e.g. `GenOpOperandGetters`) is not visible here; confirm.
426 void genCompleteRecord(
const Operator &op) {
427 const Dialect &defDialect = op.getDialect();
434 this->setNamespaceAndClassName(defDialect, op.getCppClassName());
439 if (
GenOpBuild && !op.skipDefaultBuilders()) {
440 this->genBuildOpTests(op);
443 this->genOperandTests(op);
446 this->genAttributeTests(op);
449 this->genRegionTests(op);
452 this->genResultTests(op);
455 this->genExtraMethods(op.getExtraClassDeclaration());
/// Returns C++ statements that declare the dummy locals passed to the
/// generated Build call. `generateBuildParamList` emits the matching argument
/// list. The locals are:
///  - header: `dummyValue`, taken from `mlirOperationGetResult(testOp, 0)`;
///  - each non-variadic result: `<name>Type` from `createIndexType()`;
///  - each variadic result: a one-element `<name>Types` array plus
///    `<name>Size = 0`;
///  - each variadic operand: a one-element `<name>Values` array holding
///    `dummyValue`, plus `<name>Size = 0`;
///  - each attribute: `<name>Attr`, bound to `mlirOperationGetName(testOp)`
///    when `attrType` is "MlirIdentifier", and to `createIndexAttribute()`
///    otherwise;
///  - each variadic region: `<name>Count = 0`.
/// NOTE(review): `attrType` and `rhs` are declared on lines not shown here;
/// presumably `attrType` is the attribute's C API type. In `genRegion`'s
/// parameter list, `®ion` is a mis-encoded `&region` in this listing.
463 static std::string generateBuildDummyParams(
const Operator &op) {
464 struct : GenStringFromOpPieces {
465 void genHeader(llvm::raw_ostream &os)
override {
467 os <<
" auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
470 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
472 if (result.isVariadic()) {
474 " auto {0}TypeArray = createIndexType();\n"
475 " MlirType {0}Types[] = {{{0}TypeArray};\n"
476 " intptr_t {0}Size = 0;\n",
480 os << llvm::formatv(
" auto {0}Type = createIndexType();\n", resultName);
483 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
486 if (operand.isVariadic()) {
488 " MlirValue {0}Values[] = {{dummyValue};\n"
489 " intptr_t {0}Size = 0;\n",
494 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
497 if (attrType.has_value() && attrType.value() ==
"MlirIdentifier") {
498 rhs =
"mlirOperationGetName(testOp)";
500 rhs =
"createIndexAttribute()";
502 os << llvm::formatv(
" auto {0}Attr = {1};\n", attr.name, rhs);
504 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
505 if (region.isVariadic()) {
506 os << llvm::formatv(
" unsigned {0}Count = 0;\n", region.name);
509 } paramsStringGenerator;
510 return paramsStringGenerator.gen(op);
/// Returns the trailing argument text for the generated Build call. Every
/// piece starts with ", " so the result can be appended directly after
/// `location` in the Build call template. The pieces reference the locals
/// declared by `generateBuildDummyParams`:
///  - result: `<name>Size, <name>Types` if variadic, else `<name>Type`;
///  - operand: `<name>Size, <name>Values` if variadic, else the shared
///    `dummyValue`;
///  - attribute: `<name>Attr`;
///  - variadic region: `<name>Count`.
/// NOTE(review): in `genRegion`'s parameter list, `®ion` is a mis-encoded
/// `&region` in this listing.
516 static std::string generateBuildParamList(
const Operator &op) {
517 struct : GenStringFromOpPieces {
519 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
521 if (result.isVariadic()) {
522 os << llvm::formatv(
", {0}Size, {0}Types", resultName);
524 os << llvm::formatv(
", {0}Type", resultName);
527 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
528 if (operand.isVariadic()) {
529 os << llvm::formatv(
", {0}Size, {0}Values", operand.name);
531 os <<
", dummyValue";
534 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
535 os << llvm::formatv(
", {0}Attr", attr.name);
537 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
538 if (region.isVariadic()) {
539 os << llvm::formatv(
", {0}Count", region.name);
542 } paramsStringGenerator;
543 return paramsStringGenerator.gen(op);
/// Backend function registered as the `gen-op-capi-tests` generator. It writes:
///  - the "Op C API Tests" source-file header;
///  - the test-class prologue;
///  - the complete set of tests for every record derived from `Op`.
/// NOTE(review): the construction of `op` from `def`, any epilogue, and the
/// returned value are on lines not shown here; confirm.
550static bool emitOpCAPITests(
const llvm::RecordKeeper &records, raw_ostream &os) {
552 emitSourceFileHeader(
"Op C API Tests", os, records);
555 OpTestGenerator generator(os);
558 generator.genTestClassPrologue();
561 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
563 generator.genCompleteRecord(op);
569static mlir::GenRegistration
570 genOpCAPITests(
"gen-op-capi-tests",
"Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.