LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
OpCAPITestGen.cpp
Go to the documentation of this file.
1//===- OpCAPITestGen.cpp - C API test generator for operations ------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// OpCAPITestGen generates unit tests for the C API operations generated by
11// OpCAPIGen. These are link-time tests that ensure all generated functions
12// compile and link properly, using a pattern where the C API function is
13// wrapped in a conditional that is always false but the compiler still ensures
14// the function within will link correctly.
15//
16// Test Strategy:
17// - Each test creates a dummy operation from a different dialect (arith.constant)
18// - Tests then call the generated C API functions inside an if statement that
19// checks if the dummy op is of the target type (always false)
20// - The compiler still verifies the function signatures and the linker ensures
21// the symbols are defined, even though the code never executes at runtime
22//
23// These tests will catch the following kinds of issues:
24// - Functions declared but not defined (link errors)
25// - Signature mismatches between header and implementation
26// - Missing functions in the build system
27// - ABI compatibility issues
28// - Refactoring breaks
29//
30// However, the following issues will NOT be caught:
31// - Generator logic bugs (if generator is wrong, tests will be wrong too)
32// - Runtime behavior
33// - Semantic correctness
34//
35//===----------------------------------------------------------------------===//
36
37#include "CommonCAPIGen.h"
38#include "OpCAPIParamHelper.h"
39
40#include <mlir/TableGen/GenInfo.h>
41#include <mlir/TableGen/Operator.h>
42
43#include <llvm/ADT/StringExtras.h>
44#include <llvm/Support/CommandLine.h>
45#include <llvm/Support/FormatVariadic.h>
46#include <llvm/TableGen/Record.h>
47#include <llvm/TableGen/TableGenBackend.h>
48
49using namespace mlir;
50using namespace mlir::tblgen;
51
52namespace {
53
59struct OpTestGenerator : public TestGenerator {
62 OpTestGenerator(llvm::raw_ostream &outputStream) : TestGenerator("Operation", outputStream) {}
63
69 std::string genCleanup() const override { return "mlirOperationDestroy"; };
70
73 void genBuildOpTests(const Operator &op) const {
74 static constexpr char fmt[] = R"(
76TEST_F({1}OperationLinkTests, {0}_{2}_Build) {{
77 // Returns an `arith.constant` op, which will never match the {2} dialect check.
78 auto testOp = createIndexOperation();
79
80 // This condition is always false, so the function is never actually called.
81 // We only verify it compiles and links correctly.
82 if ({0}OperationIsA_{1}_{2}(testOp)) {{
83 MlirOpBuilder builder = mlirOpBuilderCreate(context);
84 MlirLocation location = mlirLocationUnknownGet(context);
85{3}
86 (void){0}{1}_{2}Build(builder, location{4});
87 // No need to destroy builder or op since this code never runs.
88 }
89
90 mlirOperationDestroy(testOp);
91}
92
93struct {2}BuildFuncHelper : public TestAnyBuildFuncHelper<CAPITest> {
94 virtual bool callIsA(MlirOperation op) override { return {0}OperationIsA_{1}_{2}(op); }
98 static std::unique_ptr<{2}BuildFuncHelper> get();
99
100protected:
101 {2}BuildFuncHelper() = default;
102};
106TEST_F(CAPITest, {2}_build_pass) { {2}BuildFuncHelper::get()->run(*this); }
107)";
108
109 assert(!className.empty() && "className must be set");
110 os << llvm::formatv(
111 fmt,
112 FunctionPrefix, // {0}
113 dialectNameCapitalized, // {1}
114 className, // {2}
115 generateBuildDummyParams(op), // {3}
116 generateBuildParamList(op) // {4}
117 );
118 }
119
122 void genOperandTests(const Operator &op) const {
123 static constexpr char OperandGetterTest[] = R"(
124TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
125 auto testOp = createIndexOperation();
126
127 if ({0}OperationIsA_{1}_{2}(testOp)) {{
128 (void){0}{1}_{2}Get{3}(testOp);
129 }
130
131 mlirOperationDestroy(testOp);
132}
133)";
134
135 static constexpr char OperandSetterTest[] = R"(
136TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}) {{
137 auto testOp = createIndexOperation();
138
139 if ({0}OperationIsA_{1}_{2}(testOp)) {{
140 auto dummyValue = mlirOperationGetResult(testOp, 0);
141 {0}{1}_{2}Set{3}(testOp, dummyValue);
142 }
143
144 mlirOperationDestroy(testOp);
145}
146)";
147
148 static constexpr char VariadicOperandGetterTest[] = R"(
149TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
150 auto testOp = createIndexOperation();
151
152 if ({0}OperationIsA_{1}_{2}(testOp)) {{
153 (void){0}{1}_{2}Get{3}Count(testOp);
154 }
155
156 mlirOperationDestroy(testOp);
157}
158
159TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
160 auto testOp = createIndexOperation();
161
162 if ({0}OperationIsA_{1}_{2}(testOp)) {{
163 (void){0}{1}_{2}Get{3}At(testOp, 0);
164 }
165
166 mlirOperationDestroy(testOp);
167}
168)";
169
170 static constexpr char VariadicOperandSetterTest[] = R"(
171TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}_Variadic) {{
172 auto testOp = createIndexOperation();
173
174 if ({0}OperationIsA_{1}_{2}(testOp)) {{
175 auto dummyValue = mlirOperationGetResult(testOp, 0);
176 MlirValue values[] = {{dummyValue};
177 {0}{1}_{2}Set{3}(testOp, 1, values);
178 }
179
180 mlirOperationDestroy(testOp);
181}
182)";
183 assert(!className.empty() && "className must be set");
184
185 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
186 const auto &operand = op.getOperand(i);
187 std::string capName = toPascalCase(operand.name);
188 if (operand.isVariadic()) {
190 os << llvm::formatv(
191 VariadicOperandGetterTest,
192 FunctionPrefix, // {0}
193 dialectNameCapitalized, // {1}
194 className, // {2}
195 capName // {3}
196 );
197 }
199 os << llvm::formatv(
200 VariadicOperandSetterTest,
201 FunctionPrefix, // {0}
202 dialectNameCapitalized, // {1}
203 className, // {2}
204 capName // {3}
205 );
206 }
207 } else {
209 os << llvm::formatv(
210 OperandGetterTest,
211 FunctionPrefix, // {0}
212 dialectNameCapitalized, // {1}
213 className, // {2}
214 capName // {3}
215 );
216 }
218 os << llvm::formatv(
219 OperandSetterTest,
220 FunctionPrefix, // {0}
221 dialectNameCapitalized, // {1}
222 className, // {2}
223 capName // {3}
224 );
225 }
226 }
227 }
228 }
229
232 void genAttributeTests(const Operator &op) const {
233 static constexpr char AttributeGetterTest[] = R"(
234TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Attr) {{
235 auto testOp = createIndexOperation();
236
237 if ({0}OperationIsA_{1}_{2}(testOp)) {{
238 (void){0}{1}_{2}Get{3}(testOp);
239 }
240
241 mlirOperationDestroy(testOp);
242}
243)";
244
245 static constexpr char AttributeSetterTest[] = R"(
246TEST_F({1}OperationLinkTests, {0}_{2}_Set{3}Attr) {{
247 auto testOp = createIndexOperation();
248
249 if ({0}OperationIsA_{1}_{2}(testOp)) {{
250 {0}{1}_{2}Set{3}(testOp, createIndexAttribute());
251 }
252
253 mlirOperationDestroy(testOp);
254}
255)";
256 assert(!className.empty() && "className must be set");
257
258 for (const auto &namedAttr : op.getAttributes()) {
259 std::string capName = toPascalCase(namedAttr.name);
261 os << llvm::formatv(
262 AttributeGetterTest,
263 FunctionPrefix, // {0}
264 dialectNameCapitalized, // {1}
265 className, // {2}
266 capName // {3}
267 );
268 }
270 os << llvm::formatv(
271 AttributeSetterTest,
272 FunctionPrefix, // {0}
273 dialectNameCapitalized, // {1}
274 className, // {2}
275 capName // {3}
276 );
277 }
278 }
279 }
280
283 void genResultTests(const Operator &op) const {
284 static constexpr char ResultGetterTest[] = R"(
285TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}) {{
286 auto testOp = createIndexOperation();
287
288 if ({0}OperationIsA_{1}_{2}(testOp)) {{
289 (void){0}{1}_{2}Get{3}(testOp);
290 }
291
292 mlirOperationDestroy(testOp);
293}
294)";
295
296 static constexpr char VariadicResultGetterTest[] = R"(
297TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
298 auto testOp = createIndexOperation();
299
300 if ({0}OperationIsA_{1}_{2}(testOp)) {{
301 (void){0}{1}_{2}Get{3}Count(testOp);
302 }
303
304 mlirOperationDestroy(testOp);
305}
306
307TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
308 auto testOp = createIndexOperation();
309
310 if ({0}OperationIsA_{1}_{2}(testOp)) {{
311 (void){0}{1}_{2}Get{3}At(testOp, 0);
312 }
313
314 mlirOperationDestroy(testOp);
315}
316)";
317 assert(!className.empty() && "className must be set");
318
319 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
320 const auto &result = op.getResult(i);
321 llvm::StringRef name = result.name;
322 std::string capName = name.empty() ? llvm::formatv("Result{0}", i).str() : toPascalCase(name);
323
324 if (result.isVariadic()) {
325 os << llvm::formatv(
326 VariadicResultGetterTest,
327 FunctionPrefix, // {0}
328 dialectNameCapitalized, // {1}
329 className, // {2}
330 capName // {3}
331 );
332 } else {
333 os << llvm::formatv(
334 ResultGetterTest,
335 FunctionPrefix, // {0}
336 dialectNameCapitalized, // {1}
337 className, // {2}
338 capName // {3}
339 );
340 }
341 }
342 }
343
346 void genRegionTests(const Operator &op) const {
347 static constexpr char RegionGetterTest[] = R"(
348TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Region) {{
349 auto testOp = createIndexOperation();
350
351 if ({0}OperationIsA_{1}_{2}(testOp)) {{
352 (void){0}{1}_{2}Get{3}(testOp);
353 }
354
355 mlirOperationDestroy(testOp);
356}
357)";
358
359 static constexpr char VariadicRegionGetterTest[] = R"(
360TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}Count) {{
361 auto testOp = createIndexOperation();
362
363 if ({0}OperationIsA_{1}_{2}(testOp)) {{
364 (void){0}{1}_{2}Get{3}Count(testOp);
365 }
366
367 mlirOperationDestroy(testOp);
368}
369
370TEST_F({1}OperationLinkTests, {0}_{2}_Get{3}At) {{
371 auto testOp = createIndexOperation();
372
373 if ({0}OperationIsA_{1}_{2}(testOp)) {{
374 (void){0}{1}_{2}Get{3}At(testOp, 0);
375 }
376
377 mlirOperationDestroy(testOp);
378}
379)";
380 assert(!className.empty() && "className must be set");
381
382 for (unsigned int i = 0, e = op.getNumRegions(); i < e; ++i) {
383 const auto &region = op.getRegion(i);
384 llvm::StringRef name = region.name;
385 std::string capName = name.empty() ? llvm::formatv("Region{0}", i).str() : toPascalCase(name);
386
387 if (region.isVariadic()) {
388 os << llvm::formatv(
389 VariadicRegionGetterTest,
390 FunctionPrefix, // {0}
391 dialectNameCapitalized, // {1}
392 className, // {2}
393 capName // {3}
394 );
395 } else {
396 os << llvm::formatv(
397 RegionGetterTest,
398 FunctionPrefix, // {0}
399 dialectNameCapitalized, // {1}
400 className, // {2}
401 capName // {3}
402 );
403 }
404 }
405 }
406
409 void genCompleteRecord(const Operator &op) {
410 const Dialect &defDialect = op.getDialect();
411
412 // Generate for the selected dialect only
413 if (defDialect.getName() != DialectName) {
414 return;
415 }
416
417 this->setNamespaceAndClassName(defDialect, op.getCppClassName());
418
419 if (GenIsA) {
420 this->genIsATest();
421 }
422 if (GenOpBuild && !op.skipDefaultBuilders()) {
423 this->genBuildOpTests(op);
424 }
426 this->genOperandTests(op);
427 }
429 this->genAttributeTests(op);
430 }
431 if (GenOpRegionGetters) {
432 this->genRegionTests(op);
433 }
434 if (GenOpResultGetters) {
435 this->genResultTests(op);
436 }
438 this->genExtraMethods(op.getExtraClassDeclaration());
439 }
440 }
441
442private:
446 static std::string generateBuildDummyParams(const Operator &op) {
447 struct : GenStringFromOpPieces {
448 void genHeader(llvm::raw_ostream &os) override {
449 // Declare dummyValue first
450 os << " auto dummyValue = mlirOperationGetResult(testOp, 0);\n";
451 }
452 void genResult(
453 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
454 ) override {
455 if (result.isVariadic()) {
456 os << llvm::formatv(
457 " auto {0}TypeArray = createIndexType();\n"
458 " MlirType {0}Types[] = {{{0}TypeArray};\n"
459 " intptr_t {0}Size = 0;\n",
460 resultName
461 );
462 } else {
463 os << llvm::formatv(" auto {0}Type = createIndexType();\n", resultName);
464 }
465 }
466 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
467 // per `generateParamList()` only need to create something additional in case
468 // of variadic operand, otherwise `dummyValue` is used directly.
469 if (operand.isVariadic()) {
470 os << llvm::formatv(
471 " MlirValue {0}Values[] = {{dummyValue};\n"
472 " intptr_t {0}Size = 0;\n",
473 operand.name
474 );
475 }
476 }
477 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
478 std::string rhs;
479 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
480 if (attrType.has_value() && attrType.value() == "MlirIdentifier") {
481 rhs = "mlirOperationGetName(testOp)";
482 } else {
483 rhs = "createIndexAttribute()";
484 }
485 os << llvm::formatv(" auto {0}Attr = {1};\n", attr.name, rhs);
486 }
487 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
488 if (region.isVariadic()) {
489 os << llvm::formatv(" unsigned {0}Count = 0;\n", region.name);
490 }
491 }
492 } paramsStringGenerator;
493 return paramsStringGenerator.gen(op);
494 }
495
499 static std::string generateBuildParamList(const Operator &op) {
500 struct : GenStringFromOpPieces {
501 void genResult(
502 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
503 ) override {
504 if (result.isVariadic()) {
505 os << llvm::formatv(", {0}Size, {0}Types", resultName);
506 } else {
507 os << llvm::formatv(", {0}Type", resultName);
508 }
509 }
510 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
511 if (operand.isVariadic()) {
512 os << llvm::formatv(", {0}Size, {0}Values", operand.name);
513 } else {
514 os << ", dummyValue";
515 }
516 }
517 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
518 os << llvm::formatv(", {0}Attr", attr.name);
519 }
520 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
521 if (region.isVariadic()) {
522 os << llvm::formatv(", {0}Count", region.name);
523 }
524 }
525 } paramsStringGenerator;
526 return paramsStringGenerator.gen(op);
527 }
528};
529
530} // namespace
531
533static bool emitOpCAPITests(const llvm::RecordKeeper &records, raw_ostream &os) {
534 // Generate file header
535 emitSourceFileHeader("Op C API Tests", os, records);
536
537 // Create generator
538 OpTestGenerator generator(os);
539
540 // Generate test class prologue
541 generator.genTestClassPrologue();
542
543 // Generate tests for each operation
544 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
545 Operator op(def);
546 generator.genCompleteRecord(op);
547 }
548
549 return false;
550}
551
552static mlir::GenRegistration
553 genOpCAPITests("gen-op-capi-tests", "Generate operation C API unit tests", &emitOpCAPITests);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Generator for common test implementation file elements.