LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
OpCAPITestGen.cpp
Go to the documentation of this file.
1//===- OpCAPITestGen.cpp - C API test generator for operations ------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// OpCAPITestGen generates unit tests for the C API operations generated by
11// OpCAPIGen. These are link-time tests that ensure all generated functions
12// compile and link properly, using a pattern where the C API function is
13// wrapped in a conditional that is always false but the compiler still ensures
14// the function within will link correctly.
15//
16// Test Strategy:
17// - Each test creates a dummy operation from a different dialect (arith.constant)
18// - Tests then call the generated C API functions inside an if statement that
19// checks if the dummy op is of the target type (always false)
20// - The compiler still verifies the function signatures and the linker ensures
21// the symbols are defined, even though the code never executes at runtime
22//
23// These tests will catch the following kinds of issues:
24// - Functions declared but not defined (link errors)
25// - Signature mismatches between header and implementation
26// - Missing functions in the build system
27// - ABI compatibility issues
28// - Refactoring breaks
29//
30// However, the following issues will NOT be caught:
31// - Generator logic bugs (if generator is wrong, tests will be wrong too)
32// - Runtime behavior
33// - Semantic correctness
34//
35//===----------------------------------------------------------------------===//
36
37#include "CommonCAPIGen.h"
38#include "OpCAPIParamHelper.h"
39
40#include <mlir/TableGen/GenInfo.h>
41#include <mlir/TableGen/Operator.h>
42
43#include <llvm/ADT/StringExtras.h>
44#include <llvm/Support/CommandLine.h>
45#include <llvm/Support/FormatVariadic.h>
46#include <llvm/TableGen/Record.h>
47#include <llvm/TableGen/TableGenBackend.h>
48
49using namespace mlir;
50using namespace mlir::tblgen;
51
52namespace {
53
/// Emits link-time C API tests for the operations of the dialect selected by
/// `DialectName`.
/// Every generated test follows the same pattern:
///  - create a dummy op with `createIndexOperation()` (an `arith.constant`,
///    per the generated comment);
///  - guard the C API call behind the op's IsA function, which is expected to
///    be false;
///  - destroy the dummy op with `mlirOperationDestroy`.
/// The guarded call is therefore type-checked and linked, but never executed.
59struct OpTestGenerator : public TestGenerator {
  /// @param outputStream Stream that receives the generated test source. The
  ///        "Operation" label is forwarded to the TestGenerator base.
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator("Operation", outputStream) {}
63
  /// Cleanup hook required by TestGenerator: names the C API function that
  /// destroys an MlirOperation.
69 std::string genCleanup() const override { return "mlirOperationDestroy"; };
70
  /// Emits a link test for `<FunctionPrefix><Dialect>_<Class>Build`.
  ///
  /// The Build call receives a builder, an unknown location and the dummy
  /// arguments produced by generateBuildDummyParams() / generateBuildParamList().
  ///
  /// Also emits:
  ///  - a `<Class>BuildFuncHelper`, a TestAnyBuildFuncHelper<CAPITest> whose
  ///    `callIsA` forwards to the generated IsA function;
  ///  - a `<Class>_build_pass` test that calls `get()->run(*this)` on it.
  ///
  /// The helper's static `get()` is only declared by this template.
  /// NOTE(review): its definition is presumably hand-written elsewhere; confirm.
  ///
  /// @param op Operation whose Build wrapper is exercised.
  /// Requires `className` to be set (asserted).
73 void genBuildOpTests(const Operator &op) const {
    // Template placeholders:
    //   {0} FunctionPrefix          {1} capitalized dialect name
    //   {2} C++ op class name       {3} dummy-argument declarations
    //   {4} extra Build arguments
    // `{{` is formatv's escape for a literal `{`.
    // NOTE(review): the generated comment below says "{2} dialect check", but
    // {2} expands to the op class name, not the dialect name.
74 static constexpr char fmt[] = R"(
76TEST_F({1}OperationLinkTests, {0}_{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
79
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA_{1}_{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
85{3}
86 (void){0}{1}_{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
88 }
89
90 mlirOperationDestroy(testOp);
91}
92
93struct {2}BuildFuncHelper : public TestAnyBuildFuncHelper<CAPITest> {
94 virtual bool callIsA(MlirOperation op) override { return {0}OperationIsA_{1}_{2}(op); }
98 static std::unique_ptr<{2}BuildFuncHelper> get();
99
100protected:
101 {2}BuildFuncHelper() = default;
102};
106TEST_F(CAPITest, {2}_build_pass) { {2}BuildFuncHelper::get()->run(*this); }
107)";
108
109 assert(!className.empty() && "className must be set");
110 os << llvm::formatv(
111 fmt,
112 FunctionPrefix, // {0}
113 dialectNameCapitalized, // {1}
114 className, // {2}
115 generateBuildDummyParams(op), // {3}
116 generateBuildParamList(op) // {4}
117 );
118 }
119
  /// Emits link tests for the generated operand accessors of `op`.
  ///
  /// `<Name>` below is the operand's TableGen name passed through toPascalCase.
  ///  - Variadic operands get `Get<Name>Count` / `Get<Name>At` tests, plus a
  ///    setter test that passes a count and an array.
  ///  - Variadic-of-variadic operands pass an array of `MlirValueRange` groups
  ///    to that setter; plain variadic operands pass an array of `MlirValue`.
  ///  - All other operands get `Get<Name>` and `Set<Name>` tests.
  ///
  /// Template placeholders: {0} FunctionPrefix, {1} capitalized dialect name,
  /// {2} op class name, {3} PascalCase operand name.
  ///
  /// NOTE(review): getter vs. setter emission is presumably gated by the
  /// GenOpOperandGetters / GenOpOperandSetters options. Confirm in the full source.
122 void genOperandTests(const Operator &op) const {
123 static constexpr char OperandGetterTest[] = R"(
124TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
125 auto testOp = createIndexOperation();
126
127 if ({0}OperationIsA_{1}_{2}(testOp)) {{
128 (void){0}{1}_{2}Get{3}(testOp);
129 }
130
131 mlirOperationDestroy(testOp);
132}
133)";
134
135 static constexpr char OperandSetterTest[] = R"(
136TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
137 auto testOp = createIndexOperation();
138
139 if ({0}OperationIsA_{1}_{2}(testOp)) {{
140 auto dummyValue = mlirOperationGetResult(testOp, 0);
141 {0}{1}_{2}Set{3}(testOp, dummyValue);
142 }
143
144 mlirOperationDestroy(testOp);
145}
146)";
147
148 static constexpr char VariadicOperandGetterTest[] = R"(
149TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
150 auto testOp = createIndexOperation();
151
152 if ({0}OperationIsA_{1}_{2}(testOp)) {{
153 (void){0}{1}_{2}Get{3}Count(testOp);
154 }
155
156 mlirOperationDestroy(testOp);
157}
158
159TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
160 auto testOp = createIndexOperation();
161
162 if ({0}OperationIsA_{1}_{2}(testOp)) {{
163 (void){0}{1}_{2}Get{3}At(testOp, 0);
164 }
165
166 mlirOperationDestroy(testOp);
167}
168)";
169
170 static constexpr char VariadicOperandSetterTest[] = R"(
171TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
172 auto testOp = createIndexOperation();
173
174 if ({0}OperationIsA_{1}_{2}(testOp)) {{
175 auto dummyValue = mlirOperationGetResult(testOp, 0);
176 MlirValue values[] = {{dummyValue};
177 {0}{1}_{2}Set{3}(testOp, 1, values);
178 }
179
180 mlirOperationDestroy(testOp);
181}
182)";
183
184 static constexpr char VariadicOfVariadicOperandSetterTest[] = R"(
185TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_VariadicOfVariadic) {{
186 auto testOp = createIndexOperation();
187
188 if ({0}OperationIsA_{1}_{2}(testOp)) {{
189 auto dummyValue = mlirOperationGetResult(testOp, 0);
190 MlirValueRange groups[1];
191 groups[0].values = &dummyValue;
192 groups[0].size = 1;
193 {0}{1}_{2}Set{3}(testOp, 1, groups);
194 }
195
196 mlirOperationDestroy(testOp);
197}
198)";
199 assert(!className.empty() && "className must be set");
200
    // Operands are visited in declaration order; test and accessor names use
    // the PascalCase form of each operand's TableGen name.
201 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
202 const auto &operand = op.getOperand(i);
203 std::string capName = toPascalCase(operand.name);
204 if (operand.isVariadic()) {
206 os << llvm::formatv(
207 VariadicOperandGetterTest,
208 FunctionPrefix, // {0}
209 dialectNameCapitalized, // {1}
210 className, // {2}
211 capName // {3}
212 );
213 }
      // The setter's array argument depends on nesting: variadic-of-variadic
      // setters take MlirValueRange groups, plain variadic setters take MlirValue.
215 os << llvm::formatv(
216 operand.isVariadicOfVariadic() ? VariadicOfVariadicOperandSetterTest
217 : VariadicOperandSetterTest,
218 FunctionPrefix, // {0}
219 dialectNameCapitalized, // {1}
220 className, // {2}
221 capName // {3}
222 );
223 }
224 } else {
226 os << llvm::formatv(
227 OperandGetterTest,
228 FunctionPrefix, // {0}
229 dialectNameCapitalized, // {1}
230 className, // {2}
231 capName // {3}
232 );
233 }
235 os << llvm::formatv(
236 OperandSetterTest,
237 FunctionPrefix, // {0}
238 dialectNameCapitalized, // {1}
239 className, // {2}
240 capName // {3}
241 );
242 }
243 }
244 }
245 }
246
  /// Emits `Get<Name>` / `Set<Name>` link tests for every attribute of `op`.
  ///
  /// The generated TEST_F names carry an extra `Attr` suffix. The setter is
  /// called with `createIndexAttribute()`.
  ///
  /// Template placeholders: {0} FunctionPrefix, {1} capitalized dialect name,
  /// {2} op class name, {3} PascalCase attribute name.
  ///
  /// NOTE(review): emission is presumably gated by GenOpAttributeGetters /
  /// GenOpAttributeSetters. TODO confirm.
249 void genAttributeTests(const Operator &op) const {
250 static constexpr char AttributeGetterTest[] = R"(
251TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
252 auto testOp = createIndexOperation();
253
254 if ({0}OperationIsA_{1}_{2}(testOp)) {{
255 (void){0}{1}_{2}Get{3}(testOp);
256 }
257
258 mlirOperationDestroy(testOp);
259}
260)";
261
262 static constexpr char AttributeSetterTest[] = R"(
263TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
264 auto testOp = createIndexOperation();
265
266 if ({0}OperationIsA_{1}_{2}(testOp)) {{
267 {0}{1}_{2}Set{3}(testOp, createIndexAttribute());
268 }
269
270 mlirOperationDestroy(testOp);
271}
272)";
273 assert(!className.empty() && "className must be set");
274
275 for (const auto &namedAttr : op.getAttributes()) {
276 std::string capName = toPascalCase(namedAttr.name);
278 os << llvm::formatv(
279 AttributeGetterTest,
280 FunctionPrefix, // {0}
281 dialectNameCapitalized, // {1}
282 className, // {2}
283 capName // {3}
284 );
285 }
287 os << llvm::formatv(
288 AttributeSetterTest,
289 FunctionPrefix, // {0}
290 dialectNameCapitalized, // {1}
291 className, // {2}
292 capName // {3}
293 );
294 }
295 }
296 }
297
  /// Emits getter link tests for every result of `op`.
  ///
  /// Variadic results get `Get<Name>Count` and `Get<Name>At` tests; all other
  /// results get a single `Get<Name>` test.
  ///
  /// Template placeholders: {0} FunctionPrefix, {1} capitalized dialect name,
  /// {2} op class name, {3} result name (see loop below).
300 void genResultTests(const Operator &op) const {
301 static constexpr char ResultGetterTest[] = R"(
302TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
303 auto testOp = createIndexOperation();
304
305 if ({0}OperationIsA_{1}_{2}(testOp)) {{
306 (void){0}{1}_{2}Get{3}(testOp);
307 }
308
309 mlirOperationDestroy(testOp);
310}
311)";
312
313 static constexpr char VariadicResultGetterTest[] = R"(
314TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
315 auto testOp = createIndexOperation();
316
317 if ({0}OperationIsA_{1}_{2}(testOp)) {{
318 (void){0}{1}_{2}Get{3}Count(testOp);
319 }
320
321 mlirOperationDestroy(testOp);
322}
323
324TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
325 auto testOp = createIndexOperation();
326
327 if ({0}OperationIsA_{1}_{2}(testOp)) {{
328 (void){0}{1}_{2}Get{3}At(testOp, 0);
329 }
330
331 mlirOperationDestroy(testOp);
332}
333)";
334 assert(!className.empty() && "className must be set");
335
336 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
337 const auto &result = op.getResult(i);
      // Unnamed results fall back to a positional name such as `Result0`;
      // named results use the PascalCase form of their TableGen name.
338 llvm::StringRef name = result.name;
339 std::string capName = name.empty() ? llvm::formatv("Result{0}", i).str() : toPascalCase(name);
340
341 if (result.isVariadic()) {
342 os << llvm::formatv(
343 VariadicResultGetterTest,
344 FunctionPrefix, // {0}
345 dialectNameCapitalized, // {1}
346 className, // {2}
347 capName // {3}
348 );
349 } else {
350 os << llvm::formatv(
351 ResultGetterTest,
352 FunctionPrefix, // {0}
353 dialectNameCapitalized, // {1}
354 className, // {2}
355 capName // {3}
356 );
357 }
358 }
359 }
360
  /// Emits getter link tests for every region of `op`.
  ///
  /// Variadic regions get `Get<Name>Count` and `Get<Name>At` tests; all other
  /// regions get a single test named `Get<Name>Region` that calls `Get<Name>`.
  ///
  /// NOTE(review): the variadic region test names lack the `Region` suffix, so
  /// they could collide with a same-named variadic operand/result test.
363 void genRegionTests(const Operator &op) const {
364 static constexpr char RegionGetterTest[] = R"(
365TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
366 auto testOp = createIndexOperation();
367
368 if ({0}OperationIsA_{1}_{2}(testOp)) {{
369 (void){0}{1}_{2}Get{3}(testOp);
370 }
371
372 mlirOperationDestroy(testOp);
373}
374)";
375
376 static constexpr char VariadicRegionGetterTest[] = R"(
377TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
378 auto testOp = createIndexOperation();
379
380 if ({0}OperationIsA_{1}_{2}(testOp)) {{
381 (void){0}{1}_{2}Get{3}Count(testOp);
382 }
383
384 mlirOperationDestroy(testOp);
385}
386
387TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
388 auto testOp = createIndexOperation();
389
390 if ({0}OperationIsA_{1}_{2}(testOp)) {{
391 (void){0}{1}_{2}Get{3}At(testOp, 0);
392 }
393
394 mlirOperationDestroy(testOp);
395}
396)";
397 assert(!className.empty() && "className must be set");
398
399 for (unsigned int i = 0, e = op.getNumRegions(); i < e; ++i) {
400 const auto &region = op.getRegion(i);
      // Unnamed regions fall back to a positional name such as `Region0`.
401 llvm::StringRef name = region.name;
402 std::string capName = name.empty() ? llvm::formatv("Region{0}", i).str() : toPascalCase(name);
403
404 if (region.isVariadic()) {
405 os << llvm::formatv(
406 VariadicRegionGetterTest,
407 FunctionPrefix, // {0}
408 dialectNameCapitalized, // {1}
409 className, // {2}
410 capName // {3}
411 );
412 } else {
413 os << llvm::formatv(
414 RegionGetterTest,
415 FunctionPrefix, // {0}
416 dialectNameCapitalized, // {1}
417 className, // {2}
418 capName // {3}
419 );
420 }
421 }
422 }
423
  /// Emits every enabled test section for `op`.
  ///
  /// Ops whose dialect name differs from `DialectName` are skipped entirely.
  /// Otherwise the generator records the op's namespace/class name, then emits:
  ///  - the IsA test, when GenIsA is set;
  ///  - the Build test, when GenOpBuild is set and the op does not skip default
  ///    builders;
  ///  - region getter tests, when GenOpRegionGetters is set;
  ///  - result getter tests, when GenOpResultGetters is set.
  ///
  /// NOTE(review): the operand, attribute and extra-method sections are
  /// presumably gated by the GenOpOperand*, GenOpAttribute* and
  /// GenExtraClassMethods options. TODO confirm.
426 void genCompleteRecord(const Operator &op) {
427 const Dialect &defDialect = op.getDialect();
428
429 // Generate for the selected dialect only
430 if (defDialect.getName() != DialectName) {
431 return;
432 }
433
434 this->setNamespaceAndClassName(defDialect, op.getCppClassName());
435
436 if (GenIsA) {
437 this->genIsATest();
438 }
439 if (GenOpBuild && !op.skipDefaultBuilders()) {
440 this->genBuildOpTests(op);
441 }
443 this->genOperandTests(op);
444 }
446 this->genAttributeTests(op);
447 }
448 if (GenOpRegionGetters) {
449 this->genRegionTests(op);
450 }
451 if (GenOpResultGetters) {
452 this->genResultTests(op);
453 }
455 this->genExtraMethods(op.getExtraClassDeclaration());
456 }
457 }
458
459private:
  /// Builds the local declarations that the generated Build test passes as
  /// arguments (spliced into the template as `{3}`). It declares:
  ///  - `dummyValue`, result 0 of `testOp`;
  ///  - per result: `<result>Type`, or for variadic results a one-element
  ///    `<result>Types` array plus `<result>Size = 0`;
  ///  - per variadic operand: a one-element `<operand>Values` array plus
  ///    `<operand>Size = 0`;
  ///  - per attribute: `<attr>Attr`;
  ///  - per variadic region: `<region>Count = 0`.
  ///
  /// Variable names must stay in sync with generateBuildParamList(). Result
  /// names come from GenStringFromOpPieces (presumably declared in
  /// OpCAPIParamHelper.h).
463 static std::string generateBuildDummyParams(const Operator &op) {
464 struct : GenStringFromOpPieces {
465 void genHeader(llvm::raw_ostream &os) override {
466 // Declare dummyValue first
467 os << " auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
468 }
469 void genResult(
470 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
471 ) override {
472 if (result.isVariadic()) {
473 os << llvm::formatv(
474 " auto {0}TypeArray = createIndexType();\n"
475 " MlirType {0}Types[] = {{{0}TypeArray};\n"
476 " intptr_t {0}Size = 0;\n",
477 resultName
478 );
479 } else {
480 os << llvm::formatv(" auto {0}Type = createIndexType();\n", resultName);
481 }
482 }
483 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
484 // Per `generateBuildParamList()`, only variadic operands need extra locals;
485 // non-variadic operands pass `dummyValue` directly.
486 if (operand.isVariadic()) {
487 os << llvm::formatv(
488 " MlirValue {0}Values[] = {{dummyValue};\n"
489 " intptr_t {0}Size = 0;\n",
490 operand.name
491 );
492 }
493 }
494 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
        // Attributes whose C API parameter type is MlirIdentifier get the op's
        // name identifier; every other attribute gets an index attribute.
495 std::string rhs;
496 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
497 if (attrType.has_value() && attrType.value() == "MlirIdentifier") {
498 rhs = "mlirOperationGetName(testOp)";
499 } else {
500 rhs = "createIndexAttribute()";
501 }
502 os << llvm::formatv(" auto {0}Attr = {1};\n", attr.name, rhs);
503 }
504 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
        // Only variadic regions contribute a Build argument (see generateBuildParamList).
505 if (region.isVariadic()) {
506 os << llvm::formatv(" unsigned {0}Count = 0;\n", region.name);
507 }
508 }
509 } paramsStringGenerator;
510 return paramsStringGenerator.gen(op);
511 }
512
  /// Builds the trailing argument list for the generated Build call (spliced
  /// into the template as `{4}`). Each entry is prefixed with ", ".
  ///
  /// Entries reference the variables declared by generateBuildDummyParams():
  ///  - non-variadic operands all pass `dummyValue`;
  ///  - non-variadic regions contribute no argument.
  ///
  /// NOTE(review): argument order is whatever GenStringFromOpPieces::gen visits.
  /// It presumably matches the generated Build signature; confirm in
  /// OpCAPIParamHelper.h.
516 static std::string generateBuildParamList(const Operator &op) {
517 struct : GenStringFromOpPieces {
518 void genResult(
519 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
520 ) override {
521 if (result.isVariadic()) {
522 os << llvm::formatv(", {0}Size, {0}Types", resultName);
523 } else {
524 os << llvm::formatv(", {0}Type", resultName);
525 }
526 }
527 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
528 if (operand.isVariadic()) {
529 os << llvm::formatv(", {0}Size, {0}Values", operand.name);
530 } else {
531 os << ", dummyValue";
532 }
533 }
534 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
535 os << llvm::formatv(", {0}Attr", attr.name);
536 }
537 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
538 if (region.isVariadic()) {
539 os << llvm::formatv(", {0}Count", region.name);
540 }
541 }
542 } paramsStringGenerator;
543 return paramsStringGenerator.gen(op);
544 }
545};
546
547} // namespace
548
550static bool emitOpCAPITests(const llvm::RecordKeeper &records, raw_ostream &os) {
551 // Generate file header
552 emitSourceFileHeader("Op C API Tests", os, records);
553
554 // Create generator
555 OpTestGenerator generator(os);
556
557 // Generate test class prologue
558 generator.genTestClassPrologue();
559
560 // Generate tests for each operation
561 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
562 Operator op(def);
563 generator.genCompleteRecord(op);
564 }
565
566 return false;
567}
568
569static mlir::GenRegistration
570 genOpCAPITests("gen-op-capi-tests", "Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.