40#include <mlir/TableGen/GenInfo.h>
41#include <mlir/TableGen/Operator.h>
43#include <llvm/ADT/StringExtras.h>
44#include <llvm/Support/CommandLine.h>
45#include <llvm/Support/FormatVariadic.h>
46#include <llvm/TableGen/Record.h>
47#include <llvm/TableGen/TableGenBackend.h>
50using namespace mlir::tblgen;
// Constructs the operation test generator that writes generated test code to
// `outputStream`. Forwards the label "Operation" and the stream to the
// TestGenerator base class.
// NOTE(review): how the base uses the "Operation" label is not visible here
// (presumably as a naming prefix for generated fixtures/tests — TODO confirm).
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator(
"Operation", outputStream) {}
// Returns the name of the MLIR C API function that destroys an MlirOperation
// ("mlirOperationDestroy").
// NOTE(review): the caller lives in the TestGenerator base (not shown);
// presumably it splices this name into generated cleanup code — TODO confirm.
69 std::string genCleanup()
const override {
return "mlirOperationDestroy"; };
// Emits tests for the generated `...Build` C API function of `op`.
//
// The first template (TEST_F `<...>_Build`) creates an index operation and
// calls the Build function only inside an `OperationIsA_...` branch. The
// template's own comments say this branch never runs, so the test checks that
// the symbol compiles and links, not its runtime behavior. The builder and
// location it creates use a fixture variable named `context`.
//
// The second part emits a `<Op>BuildFuncHelper` struct deriving from
// TestAnyBuildFuncHelper<CAPITest> (overriding callIsA) and a
// `<Op>_build_pass` test that runs it.
//
// The formatv arguments visible here are dialectNameCapitalized,
// generateBuildDummyParams(op) and generateBuildParamList(op). The latter
// emits a ", ..."-prefixed argument list and is presumably the `{4}` appended
// after `location`. The remaining argument lines are not shown; presumably
// {0} is the function prefix and {2} the op class name (TODO confirm).
//
// Precondition: className must already be set (asserted).
73 void genBuildOpTests(
const Operator &op)
const {
74 static constexpr char fmt[] = R
"(
76TEST_F({1}OperationLinkTests, {0}_{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA_{1}_{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
86 (void){0}{1}_{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
90 mlirOperationDestroy(testOp);
93struct {2}BuildFuncHelper : public TestAnyBuildFuncHelper<CAPITest> {
94 virtual bool callIsA(MlirOperation op) override { return {0}OperationIsA_{1}_{2}(op); }
98 static std::unique_ptr<{2}BuildFuncHelper> get();
101 {2}BuildFuncHelper() = default;
106TEST_F(CAPITest, {2}_build_pass) { {2}BuildFuncHelper::get()->run(*this); }
109 assert(!className.empty() &&
"className must be set");
113 dialectNameCapitalized,
115 generateBuildDummyParams(op),
116 generateBuildParamList(op)
// Emits link-only tests for the operand accessors generated for `op`.
//
// Templates (llvm::formatv; `{{` is an escaped literal brace):
//  - OperandGetterTest:         calls `...Get<Operand>(op)`.
//  - OperandSetterTest:         calls `...Set<Operand>(op, value)`, where the
//                               value is result 0 of the test operation.
//  - VariadicOperandGetterTest: calls `...Get<Operand>Count(op)` and
//                               `...Get<Operand>At(op, 0)`.
//  - VariadicOperandSetterTest: calls `...Set<Operand>(op, 1, values)` with a
//                               one-element array holding that value.
// Every call sits inside an `OperationIsA_...` check on an operation from
// createIndexOperation(). Each test destroys that operation at the end.
//
// NOTE(review): the following are on lines not shown in this listing:
//  - how the operand's capitalized name is computed,
//  - which templates the non-variadic branch emits,
//  - whether getters/setters are gated by the GenOpOperandGetters /
//    GenOpOperandSetters options.
// Precondition: className must already be set (asserted).
122 void genOperandTests(
const Operator &op)
const {
123 static constexpr char OperandGetterTest[] = R
"(
124TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
125 auto testOp = createIndexOperation();
127 if ({0}OperationIsA_{1}_{2}(testOp)) {{
128 (void){0}{1}_{2}Get{3}(testOp);
131 mlirOperationDestroy(testOp);
135 static constexpr char OperandSetterTest[] = R
"(
136TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
137 auto testOp = createIndexOperation();
139 if ({0}OperationIsA_{1}_{2}(testOp)) {{
140 auto dummyValue = mlirOperationGetResult(testOp, 0);
141 {0}{1}_{2}Set{3}(testOp, dummyValue);
144 mlirOperationDestroy(testOp);
148 static constexpr char VariadicOperandGetterTest[] = R
"(
149TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
150 auto testOp = createIndexOperation();
152 if ({0}OperationIsA_{1}_{2}(testOp)) {{
153 (void){0}{1}_{2}Get{3}Count(testOp);
156 mlirOperationDestroy(testOp);
159TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
160 auto testOp = createIndexOperation();
162 if ({0}OperationIsA_{1}_{2}(testOp)) {{
163 (void){0}{1}_{2}Get{3}At(testOp, 0);
166 mlirOperationDestroy(testOp);
170 static constexpr char VariadicOperandSetterTest[] = R
"(
171TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
172 auto testOp = createIndexOperation();
174 if ({0}OperationIsA_{1}_{2}(testOp)) {{
175 auto dummyValue = mlirOperationGetResult(testOp, 0);
176 MlirValue values[] = {{dummyValue};
177 {0}{1}_{2}Set{3}(testOp, 1, values);
180 mlirOperationDestroy(testOp);
183 assert(!className.empty() && "className must be set");
// One group of tests per operand; the template set is chosen by
// operand.isVariadic().
185 for (
int i = 0, e = op.getNumOperands(); i < e; ++i) {
186 const auto &operand = op.getOperand(i);
188 if (operand.isVariadic()) {
191 VariadicOperandGetterTest,
193 dialectNameCapitalized,
200 VariadicOperandSetterTest,
202 dialectNameCapitalized,
212 dialectNameCapitalized,
221 dialectNameCapitalized,
// Emits link-only tests for the attribute accessors generated for `op`.
//  - AttributeGetterTest calls `...Get<Attr>(op)`.
//  - AttributeSetterTest calls `...Set<Attr>(op, createIndexAttribute())`.
// The generated test names carry an extra `Attr` suffix (`_Get<Attr>Attr`,
// `_Set<Attr>Attr`) that the called function names do not.
// As with the other accessor tests, each call sits inside an
// `OperationIsA_...` check on an operation from createIndexOperation().
// Each test destroys that operation at the end.
// NOTE(review): the per-attribute name handling and any getter/setter
// option gating inside the loop over op.getAttributes() are on lines not
// shown here.
// Precondition: className must already be set (asserted).
232 void genAttributeTests(
const Operator &op)
const {
233 static constexpr char AttributeGetterTest[] = R
"(
234TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
235 auto testOp = createIndexOperation();
237 if ({0}OperationIsA_{1}_{2}(testOp)) {{
238 (void){0}{1}_{2}Get{3}(testOp);
241 mlirOperationDestroy(testOp);
245 static constexpr char AttributeSetterTest[] = R
"(
246TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
247 auto testOp = createIndexOperation();
249 if ({0}OperationIsA_{1}_{2}(testOp)) {{
250 {0}{1}_{2}Set{3}(testOp, createIndexAttribute());
253 mlirOperationDestroy(testOp);
256 assert(!className.empty() && "className must be set");
// One group of tests per named attribute declared on the op.
258 for (
const auto &namedAttr : op.getAttributes()) {
264 dialectNameCapitalized,
273 dialectNameCapitalized,
// Emits link-only tests for the result accessors generated for `op`.
//  - ResultGetterTest calls `...Get<Result>(op)`.
//  - VariadicResultGetterTest calls `...Get<Result>Count(op)` and
//    `...Get<Result>At(op, 0)`.
// Variadic results use VariadicResultGetterTest. The non-variadic branch is
// on lines not shown here; presumably it uses ResultGetterTest.
// Precondition: className must already be set (asserted).
283 void genResultTests(
const Operator &op)
const {
284 static constexpr char ResultGetterTest[] = R
"(
285TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
286 auto testOp = createIndexOperation();
288 if ({0}OperationIsA_{1}_{2}(testOp)) {{
289 (void){0}{1}_{2}Get{3}(testOp);
292 mlirOperationDestroy(testOp);
296 static constexpr char VariadicResultGetterTest[] = R
"(
297TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
298 auto testOp = createIndexOperation();
300 if ({0}OperationIsA_{1}_{2}(testOp)) {{
301 (void){0}{1}_{2}Get{3}Count(testOp);
304 mlirOperationDestroy(testOp);
307TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
308 auto testOp = createIndexOperation();
310 if ({0}OperationIsA_{1}_{2}(testOp)) {{
311 (void){0}{1}_{2}Get{3}At(testOp, 0);
314 mlirOperationDestroy(testOp);
317 assert(!className.empty() && "className must be set");
319 for (
int i = 0, e = op.getNumResults(); i < e; ++i) {
320 const auto &result = op.getResult(i);
// Unnamed results fall back to a positional name ("Result0", "Result1", ...).
// Named results are converted with toPascalCase.
321 llvm::StringRef name = result.name;
322 std::string capName = name.empty() ? llvm::formatv(
"Result{0}", i).str() :
toPascalCase(name);
324 if (result.isVariadic()) {
326 VariadicResultGetterTest,
328 dialectNameCapitalized,
336 dialectNameCapitalized,
// Emits link-only tests for the region accessors generated for `op`.
//  - RegionGetterTest calls `...Get<Region>(op)`.
//  - VariadicRegionGetterTest calls `...Get<Region>Count(op)` and
//    `...Get<Region>At(op, 0)`.
// Variadic regions use VariadicRegionGetterTest. The non-variadic branch is
// on lines not shown here; presumably it uses RegionGetterTest.
// Precondition: className must already be set (asserted).
346 void genRegionTests(
const Operator &op)
const {
347 static constexpr char RegionGetterTest[] = R
"(
348TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
349 auto testOp = createIndexOperation();
351 if ({0}OperationIsA_{1}_{2}(testOp)) {{
352 (void){0}{1}_{2}Get{3}(testOp);
355 mlirOperationDestroy(testOp);
359 static constexpr char VariadicRegionGetterTest[] = R
"(
360TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
361 auto testOp = createIndexOperation();
363 if ({0}OperationIsA_{1}_{2}(testOp)) {{
364 (void){0}{1}_{2}Get{3}Count(testOp);
367 mlirOperationDestroy(testOp);
370TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
371 auto testOp = createIndexOperation();
373 if ({0}OperationIsA_{1}_{2}(testOp)) {{
374 (void){0}{1}_{2}Get{3}At(testOp, 0);
377 mlirOperationDestroy(testOp);
380 assert(!className.empty() && "className must be set");
382 for (
unsigned int i = 0, e = op.getNumRegions(); i < e; ++i) {
// NOTE(review): `®ion` on the next line is presumably a mis-decoded `&region`
// (the HTML entity `&reg` rendered as ®). The variable is referenced as
// `region` below, so restore `&region` before compiling.
383 const auto ®ion = op.getRegion(i);
// Unnamed regions fall back to a positional name ("Region0", "Region1", ...).
// Named regions are converted with toPascalCase.
384 llvm::StringRef name = region.name;
385 std::string capName = name.empty() ? llvm::formatv(
"Region{0}", i).str() :
toPascalCase(name);
387 if (region.isVariadic()) {
389 VariadicRegionGetterTest,
391 dialectNameCapitalized,
399 dialectNameCapitalized,
// Emits every test group for one ODS operation.
// The groups are emitted in this order:
//  1. build tests,
//  2. operand tests,
//  3. attribute tests,
//  4. region tests,
//  5. result tests,
//  6. genExtraMethods on the op's extraClassDeclaration.
// NOTE(review): any option gating for groups 2-6 is on lines not shown here
// (presumably GenOpOperandGetters, GenOpAttributeGetters, etc. — TODO confirm).
409 void genCompleteRecord(
const Operator &op) {
410 const Dialect &defDialect = op.getDialect();
// Records the namespace and C++ class name that the gen*Tests methods format
// into their templates; those methods assert that className is set.
417 this->setNamespaceAndClassName(defDialect, op.getCppClassName());
// Build tests are emitted only when GenOpBuild is enabled and the op does not
// opt out of default builders.
422 if (
GenOpBuild && !op.skipDefaultBuilders()) {
423 this->genBuildOpTests(op);
426 this->genOperandTests(op);
429 this->genAttributeTests(op);
432 this->genRegionTests(op);
435 this->genResultTests(op);
438 this->genExtraMethods(op.getExtraClassDeclaration());
// Builds the C++ snippet that declares dummy local arguments for a generated
// `...Build` call on `op`. The variable names must match those referenced by
// generateBuildParamList:
//  - `dummyValue` from result 0 of `testOp`;
//  - `<res>Type`, or `<res>Types` / `<res>Size` for variadic results;
//  - `<operand>Values` / `<operand>Size` for variadic operands;
//  - `<attr>Attr` for attributes;
//  - `<region>Count` for variadic regions.
// Variadic sizes are declared as 0 even though the arrays hold one element.
// That is harmless here, since the surrounding Build test never executes the
// call.
// NOTE(review): the piece-visiting order is defined by the
// GenStringFromOpPieces base, which is not shown.
446 static std::string generateBuildDummyParams(
const Operator &op) {
447 struct : GenStringFromOpPieces {
// Header: declares the shared dummy MlirValue taken from result 0 of testOp.
448 void genHeader(llvm::raw_ostream &os)
override {
450 os <<
" auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
// genResult (its signature head is on a line not shown):
//  - variadic results: a one-element MlirType array plus a size;
//  - otherwise: a single index type.
453 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
455 if (result.isVariadic()) {
457 " auto {0}TypeArray = createIndexType();\n"
458 " MlirType {0}Types[] = {{{0}TypeArray};\n"
459 " intptr_t {0}Size = 0;\n",
463 os << llvm::formatv(
" auto {0}Type = createIndexType();\n", resultName);
// Operands:
//  - variadic operands: a one-element MlirValue array plus a size;
//  - non-variadic operands: reuse `dummyValue` from the header (see
//    generateBuildParamList).
466 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
469 if (operand.isVariadic()) {
471 " MlirValue {0}Values[] = {{dummyValue};\n"
472 " intptr_t {0}Size = 0;\n",
// Attributes:
//  - if the C API type is MlirIdentifier, use mlirOperationGetName(testOp),
//    which returns an MlirIdentifier;
//  - otherwise, use createIndexAttribute().
// `attrType` and `rhs` are declared on lines not shown (presumably via
// tryCppTypeToCapiType — TODO confirm).
477 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
480 if (attrType.has_value() && attrType.value() ==
"MlirIdentifier") {
481 rhs =
"mlirOperationGetName(testOp)";
483 rhs =
"createIndexAttribute()";
485 os << llvm::formatv(
" auto {0}Attr = {1};\n", attr.name, rhs);
// Regions: only variadic regions need a local (a zero count).
// NOTE(review): `®ion` is presumably a mis-decoded `&region` (HTML `&reg`).
// The body uses `region`, so restore `&region` before compiling.
487 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
488 if (region.isVariadic()) {
489 os << llvm::formatv(
" unsigned {0}Count = 0;\n", region.name);
492 } paramsStringGenerator;
493 return paramsStringGenerator.gen(op);
// Builds the trailing argument list for a generated `...Build` call on `op`.
// Every piece is emitted with a leading ", " so the result can be appended
// directly after `location`. The names referenced here must match the locals
// declared by generateBuildDummyParams.
499 static std::string generateBuildParamList(
const Operator &op) {
500 struct : GenStringFromOpPieces {
// genResult (its signature head is on a line not shown):
//  - variadic results pass `<res>Size, <res>Types`;
//  - others pass `<res>Type`.
502 llvm::raw_ostream &os,
const NamedTypeConstraint &result,
const std::string &resultName
504 if (result.isVariadic()) {
505 os << llvm::formatv(
", {0}Size, {0}Types", resultName);
507 os << llvm::formatv(
", {0}Type", resultName);
// Operands:
//  - variadic operands pass `<operand>Size, <operand>Values`;
//  - every non-variadic operand passes the shared `dummyValue`.
510 void genOperand(llvm::raw_ostream &os,
const NamedTypeConstraint &operand)
override {
511 if (operand.isVariadic()) {
512 os << llvm::formatv(
", {0}Size, {0}Values", operand.name);
514 os <<
", dummyValue";
// Attributes: each passes its `<attr>Attr` local.
517 void genAttribute(llvm::raw_ostream &os,
const NamedAttribute &attr)
override {
518 os << llvm::formatv(
", {0}Attr", attr.name);
// Regions: only variadic regions contribute an argument (`<region>Count`).
// NOTE(review): `®ion` is presumably a mis-decoded `&region` (HTML `&reg`).
// The body uses `region`, so restore `&region` before compiling.
520 void genRegion(llvm::raw_ostream &os,
const mlir::tblgen::NamedRegion ®ion)
override {
521 if (region.isVariadic()) {
522 os << llvm::formatv(
", {0}Count", region.name);
525 } paramsStringGenerator;
526 return paramsStringGenerator.gen(op);
// mlir-tblgen entry point for `-gen-op-capi-tests`.
// Steps:
//  1. Writes the standard generated-file header ("Op C API Tests").
//  2. Emits the test class prologue.
//  3. Emits a complete set of tests for every record derived from the ODS
//     `Op` class.
// NOTE(review): the per-record Operator construction and the return statement
// are on lines not shown. Per the mlir-tblgen GenFunction convention, it
// presumably returns false on success — TODO confirm.
533static bool emitOpCAPITests(
const llvm::RecordKeeper &records, raw_ostream &os) {
535 emitSourceFileHeader(
"Op C API Tests", os, records);
538 OpTestGenerator generator(os);
541 generator.genTestClassPrologue();
544 for (
const auto *def : records.getAllDerivedDefinitions(
"Op")) {
546 generator.genCompleteRecord(op);
// Registers emitOpCAPITests with mlir-tblgen so that it can be selected with
// the `-gen-op-capi-tests` command-line flag.
552static mlir::GenRegistration
553 genOpCAPITests(
"gen-op-capi-tests",
"Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.