LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
OpCAPIGen.cpp
Go to the documentation of this file.
1//===- OpCAPIGen.cpp - C API generator for operations ---------------------===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// OpCAPIGen uses the description of operations to generate C API for the ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "CommonCAPIGen.h"
15#include "OpCAPIParamHelper.h"
16
17#include <mlir/TableGen/GenInfo.h>
18#include <mlir/TableGen/Operator.h>
19
20#include <llvm/Support/CommandLine.h>
21#include <llvm/Support/FormatVariadic.h>
22#include <llvm/TableGen/Error.h>
23#include <llvm/TableGen/Record.h>
24#include <llvm/TableGen/TableGenBackend.h>
25
26#include <string>
27#include <vector>
28
29using namespace mlir;
30using namespace mlir::tblgen;
31
34 void setOperandName(mlir::StringRef name) { this->operandNameCapitalized = toPascalCase(name); }
35 void setAttributeName(mlir::StringRef name) { this->attrNameCapitalized = toPascalCase(name); }
/// Derive the capitalized result name used in generated function names: the PascalCase form of
/// `name`, or "Result<resultIndex>" when the result is unnamed (`name` is empty).
36 void setResultName(mlir::StringRef name, int resultIndex) {
38 name.empty() ? llvm::formatv("Result{0}", resultIndex).str() : toPascalCase(name);
39 }
/// Derive the capitalized region name used in generated function names: the PascalCase form of
/// `name`, or "Region<regionIndex>" when the region is unnamed (`name` is empty).
40 void setRegionName(mlir::StringRef name, unsigned regionIndex) {
42 name.empty() ? llvm::formatv("Region{0}", regionIndex).str() : toPascalCase(name);
43 }
44
45protected:
50};
51
54 using HeaderGenerator::HeaderGenerator;
55 ~OpHeaderGenerator() override = default;
56
/// Write to `os` the exported C prototype of the operation's `Build` function, which returns an
/// MlirOperation.
///
/// @param params Extra parameter text appended verbatim right after the `location` parameter, so
///               it is expected to be empty or to start with ", ".
///
/// Asserts that `className` has been set.
57 void genOpBuildDecl(const std::string &params) const {
58 static constexpr char fmt[] = R"(
60MLIR_CAPI_EXPORTED MlirOperation {0}{1}_{2}Build(MlirOpBuilder builder, MlirLocation location{3});
61)";
62 assert(!className.empty() && "className must be set");
63 os << llvm::formatv(
64 fmt,
65 FunctionPrefix, // {0}
67 className, // {2}
68 params, // {3}
69 dialectNamespace // {4}
70 );
71 }
72
/// Write to `os` the exported C prototype of the getter for a single operand; the generated
/// function takes the MlirOperation and returns an MlirValue.
///
/// Asserts that both the class name and the current operand name have been set.
73 void genOperandGetterDecl() const {
74 static constexpr char fmt[] = R"(
76MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}(MlirOperation op);
77)";
78 assert(!className.empty() && "className must be set");
79 assert(!operandNameCapitalized.empty() && "operandName must be set");
80 os << llvm::formatv(
81 fmt,
82 FunctionPrefix, // {0}
84 className, // {2}
86 dialectNamespace // {4}
87 );
88 }
89
/// Write to `os` the exported C prototype of the setter for a single operand; the generated
/// function takes the MlirOperation and the new MlirValue and returns nothing.
///
/// Asserts that both the class name and the current operand name have been set.
90 void genOperandSetterDecl() const {
91 static constexpr char fmt[] = R"(
93MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, MlirValue value);
94)";
95 assert(!className.empty() && "className must be set");
96 assert(!operandNameCapitalized.empty() && "operandName must be set");
97 os << llvm::formatv(
98 fmt,
99 FunctionPrefix, // {0}
101 className, // {2}
103 dialectNamespace // {4}
104 );
105 }
106
/// Write to `os` the two exported C prototypes used to read a variadic operand group: a
/// `...Count` function returning intptr_t and a `...At` function returning the MlirValue at a
/// given index.
///
/// Asserts that both the class name and the current operand name have been set.
107 void genVariadicOperandGetterDecl() const {
108 static constexpr char fmt[] = R"(
110MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
111
112
113MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
114)";
115 assert(!className.empty() && "className must be set");
116 assert(!operandNameCapitalized.empty() && "operandName must be set");
117 os << llvm::formatv(
118 fmt,
119 FunctionPrefix, // {0}
121 className, // {2}
123 dialectNamespace // {4}
124
125 );
126 }
127
// Writes to `os` the exported C prototype of the setter for a variadic operand group; the
// generated function takes the operation, an element count and a pointer to MlirValue elements.
// Asserts that both the class name and the current operand name have been set.
129 static constexpr char fmt[] = R"(
131MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values);
132)";
133 assert(!className.empty() && "className must be set");
134 assert(!operandNameCapitalized.empty() && "operandName must be set");
135 os << llvm::formatv(
136 fmt,
137 FunctionPrefix, // {0}
139 className, // {2}
141 dialectNamespace // {4}
143 );
144 }
145
// Writes to `os` the exported C prototype of the setter for a variadic-of-variadic operand; the
// generated function takes the operation, a group count and a pointer to MlirValueRange groups.
// Asserts that both the class name and the current operand name have been set.
147 static constexpr char fmt[] = R"(
151MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t groupCount, MlirValueRange const *groups);
152)";
153 assert(!className.empty() && "className must be set");
154 assert(!operandNameCapitalized.empty() && "operandName must be set");
155 os << llvm::formatv(
156 fmt,
157 FunctionPrefix, // {0}
159 className, // {2}
161 dialectNamespace // {4}
162 );
163 }
164
/// Write to `os` the exported C prototype of an attribute getter; the generated function takes
/// the MlirOperation and returns an MlirAttribute.
///
/// Asserts that both the class name and the current attribute name have been set.
165 void genAttributeGetterDecl() const {
166 static constexpr char fmt[] = R"(
168MLIR_CAPI_EXPORTED MlirAttribute {0}{1}_{2}Get{3}(MlirOperation op);
169)";
170 assert(!className.empty() && "className must be set");
171 assert(!attrNameCapitalized.empty() && "attrName must be set");
172 os << llvm::formatv(
173 fmt,
174 FunctionPrefix, // {0}
176 className, // {2}
177 attrNameCapitalized, // {3}
178 dialectNamespace // {4}
179 );
180 }
181
/// Write to `os` the documented, exported C prototype of an attribute setter; the generated
/// function takes the MlirOperation and the new MlirAttribute and returns nothing.
///
/// Asserts that both the class name and the current attribute name have been set.
182 void genAttributeSetterDecl() const {
183 static constexpr char fmt[] = R"(
184/// Set {3} attribute of {4}::{2} Operation.
185MLIR_CAPI_EXPORTED void {0}{1}_{2}Set{3}(MlirOperation op, MlirAttribute attr);
186)";
187 assert(!className.empty() && "className must be set");
188 assert(!attrNameCapitalized.empty() && "attrName must be set");
189 os << llvm::formatv(
190 fmt,
191 FunctionPrefix, // {0}
193 className, // {2}
194 attrNameCapitalized, // {3}
195 dialectNamespace // {4}
196 );
197 }
/// Write to `os` the exported C prototype of the getter for a single result; the generated
/// function takes the MlirOperation and returns an MlirValue.
///
/// Asserts that both the class name and the current result name have been set.
199 void genResultGetterDecl() const {
200 static constexpr char fmt[] = R"(
202MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}(MlirOperation op);
203)";
204 assert(!className.empty() && "className must be set");
205 assert(!resultNameCapitalized.empty() && "resultName must be set");
206 os << llvm::formatv(
207 fmt,
208 FunctionPrefix, // {0}
210 className, // {2}
213 );
214 }
215
/// Write to `os` the two exported C prototypes used to read a variadic result group: a
/// `...Count` function returning intptr_t and a `...At` function returning the MlirValue at a
/// given index.
///
/// Asserts that both the class name and the current result name have been set.
216 void genVariadicResultGetterDecl() const {
217 static constexpr char fmt[] = R"(
219MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
220
222MLIR_CAPI_EXPORTED MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
223)";
224 assert(!className.empty() && "className must be set");
225 assert(!resultNameCapitalized.empty() && "resultName must be set");
226 os << llvm::formatv(
227 fmt,
228 FunctionPrefix, // {0}
230 className, // {2}
232 dialectNamespace // {4}
233 );
234 }
235
/// Write to `os` the exported C prototype of the getter for a single region; the generated
/// function takes the MlirOperation and returns an MlirRegion.
///
/// Asserts that both the class name and the current region name have been set.
236 void genRegionGetterDecl() const {
237 static constexpr char fmt[] = R"(
239MLIR_CAPI_EXPORTED MlirRegion {0}{1}_{2}Get{3}(MlirOperation op);
240)";
241 assert(!className.empty() && "className must be set");
242 assert(!regionNameCapitalized.empty() && "regionName must be set");
243 os << llvm::formatv(
244 fmt,
245 FunctionPrefix, // {0}
247 className, // {2}
249 dialectNamespace // {4}
250 );
251 }
252
/// Write to `os` the two exported C prototypes used to read a variadic region group: a
/// `...Count` function returning intptr_t and a `...At` function returning the MlirRegion at a
/// given index.
///
/// Asserts that both the class name and the current region name have been set.
253 void genVariadicRegionGetterDecl() const {
254 static constexpr char fmt[] = R"(
256MLIR_CAPI_EXPORTED intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op);
257
259MLIR_CAPI_EXPORTED MlirRegion {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index);
260)";
261 assert(!className.empty() && "className must be set");
262 assert(!regionNameCapitalized.empty() && "regionName must be set");
263 os << llvm::formatv(
264 fmt,
265 FunctionPrefix, // {0}
267 className, // {2}
269 dialectNamespace // {4}
270 );
271 }
272};
273
279static std::string generateCAPIBuildParams(const Operator &op) {
280 struct : GenStringFromOpPieces {
281 void genResult(
282 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
283 ) override {
284 if (result.isVariadic()) {
285 os << llvm::formatv(", intptr_t {0}Size, MlirType const *{0}Types", resultName);
286 } else {
287 os << llvm::formatv(", MlirType {0}Type", resultName);
288 }
289 }
290 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
291 if (operand.isVariadic()) {
292 os << llvm::formatv(", intptr_t {0}Size, MlirValue const *{0}", operand.name);
293 } else {
294 os << llvm::formatv(", MlirValue {0}", operand.name);
295 }
296 }
297 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
298 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
299 os << llvm::formatv(", {0} {1}", attrType.value_or("MlirAttribute"), attr.name);
300 }
301 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
302 if (region.isVariadic()) {
303 os << llvm::formatv(", unsigned {0}Count", region.name);
304 }
305 }
306 } paramStringGenerator;
307 return paramStringGenerator.gen(op);
308}
309
/// Generator entry point that writes the C API header (declarations) for operations to `os`.
///
/// Walks every record derived from "Op", skips those whose dialect name differs from
/// `DialectName`, and for each remaining op emits the Build, IsA, operand, attribute, result,
/// region and extra-method declarations, between the generator's prologue and epilogue.
///
/// @param records TableGen records to scan.
/// @param os      Stream receiving the generated header text.
/// @return Always `false` (presumably the "no error" convention of mlir::GenRegistration
///         callbacks — confirm).
311static bool emitOpCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) {
312 emitSourceFileHeader("Op C API Declarations", os, records);
313
314 OpHeaderGenerator generator("Operation", os);
315 generator.genPrologue();
316
317 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
318 const Operator op(def);
319 const Dialect &dialect = op.getDialect();
320
321 // Generate for the selected dialect only (specified via -dialect command-line option)
322 if (dialect.getName() != DialectName) {
323 continue;
324 }
325
326 generator.setNamespaceAndClassName(dialect, op.getCppClassName());
327
328 // Generate "Build" function
329 if (GenOpBuild && !op.skipDefaultBuilders()) {
330 generator.genOpBuildDecl(generateCAPIBuildParams(op));
331 }
332
333 // Generate IsA check
334 if (GenIsA) {
335 generator.genIsADecl();
336 }
337
338 // Generate operand getters and setters
339 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
340 const auto &operand = op.getOperand(i);
341 generator.setOperandName(operand.name);
342 if (operand.isVariadic()) {
344 generator.genVariadicOperandGetterDecl();
345 }
347 if (operand.isVariadicOfVariadic()) {
348 generator.genVariadicOfVariadicOperandSetterDecl();
349 } else {
350 generator.genVariadicOperandSetterDecl();
351 }
352 }
353 } else {
355 generator.genOperandGetterDecl();
356 }
358 generator.genOperandSetterDecl();
359 }
360 }
361 }
362
363 // Generate attribute getters and setters
364 for (const auto &namedAttr : op.getAttributes()) {
365 generator.setAttributeName(namedAttr.name);
367 generator.genAttributeGetterDecl();
368 }
370 generator.genAttributeSetterDecl();
371 }
372 }
373
374 // Generate result getters
375 if (GenOpResultGetters) {
376 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
377 const auto &result = op.getResult(i);
378 generator.setResultName(result.name, i);
379 if (result.isVariadic()) {
380 generator.genVariadicResultGetterDecl();
381 } else {
382 generator.genResultGetterDecl();
383 }
384 }
385 }
386
387 // Generate region getters
388 if (GenOpRegionGetters) {
389 for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
390 const auto &region = op.getRegion(i);
391 generator.setRegionName(region.name, i);
392 if (region.isVariadic()) {
393 generator.genVariadicRegionGetterDecl();
394 } else {
395 generator.genRegionGetterDecl();
396 }
397 }
398 }
399
400 // Generate extra class method wrappers
402 generator.genExtraMethods(op.getExtraClassDeclaration());
403 }
404 }
405
406 generator.genEpilogue();
407 return false;
408}
409
412 using ImplementationGenerator::ImplementationGenerator;
413 ~OpImplementationGenerator() override = default;
414
415 void genPrologue() const override {
416 os << R"(
417#include <limits>
418
419using namespace mlir;
420using namespace llvm;
421)";
422 }
/// Write to `os` the definition of the operation's C `Build` function.
///
/// The generated function creates an MlirOperationState for `operationName` at `location`,
/// runs the `assignments` text against that state, creates the operation and returns the result
/// of inserting it through the builder.
///
/// @param operationName Operation name placed in the generated string literal.
/// @param params        Extra parameter text appended verbatim after the `location` parameter.
/// @param assignments   Statements spliced into the function body; they may refer to `builder`,
///                      `location`, `state` and the names declared by `params`.
///
/// Asserts that `className` has been set.
428 void genOpBuildImpl(
429 const std::string &operationName, const std::string &params, const std::string &assignments
430 ) const {
431 static constexpr char fmt[] = R"(
432MlirOperation {0}{1}_{2}Build(MlirOpBuilder builder, MlirLocation location{3}) {{
433 MlirOperationState state = mlirOperationStateGet(mlirStringRefCreateFromCString("{4}"), location);
434{5}
435 return mlirOpBuilderInsert(builder, mlirOperationCreate(&state));
436}
437)";
438 assert(!className.empty() && "className must be set");
439 os << llvm::formatv(
440 fmt,
441 FunctionPrefix, // {0}
443 className, // {2}
444 params, // {3}
445 operationName, // {4}
446 assignments // {5}
447 );
448 }
449
/// Write to `os` the definition of the getter for a single operand.
///
/// The generated function asks the op class for the operand segment via
/// `getODSOperandIndexAndLength(index)`, asserts that the segment has length 1 and that its
/// start fits in intptr_t, then returns the operand at the segment start.
///
/// @param index ODS operand index embedded in the generated code.
///
/// Asserts that both the class name and the current operand name have been set.
450 void genOperandGetterImpl(int index) const {
451 static constexpr char fmt[] = R"(
452MlirValue {0}{1}_{2}Get{3}(MlirOperation op) {{
453 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
454 assert(range.second == 1 && "expected fixed operand segment size");
455 assert(
456 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
457 "operand index exceeds intptr_t range"
458 );
459 return mlirOperationGetOperand(op, static_cast<intptr_t>(range.first));
460}
461)";
462 assert(!className.empty() && "className must be set");
463 assert(!operandNameCapitalized.empty() && "operandName must be set");
464 os << llvm::formatv(
465 fmt,
466 FunctionPrefix, // {0}
468 className, // {2}
470 index // {4}
471 );
472 }
473
/// Write to `os` the definition of the setter for a single operand.
///
/// The generated function asks the op class for the operand segment via
/// `getODSOperandIndexAndLength(index)`, asserts that the segment has length 1 and that its
/// start fits in intptr_t, then replaces the operand at the segment start with `value`.
///
/// @param index ODS operand index embedded in the generated code.
///
/// Asserts that both the class name and the current operand name have been set.
474 void genOperandSetterImpl(int index) const {
475 static constexpr char fmt[] = R"(
476void {0}{1}_{2}Set{3}(MlirOperation op, MlirValue value) {{
477 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
478 assert(range.second == 1 && "expected fixed operand segment size");
479 assert(
480 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
481 "operand index exceeds intptr_t range"
482 );
483 mlirOperationSetOperand(op, static_cast<intptr_t>(range.first), value);
484}
485)";
486 assert(!className.empty() && "className must be set");
487 assert(!operandNameCapitalized.empty() && "operandName must be set");
488 os << llvm::formatv(
489 fmt,
490 FunctionPrefix, // {0}
492 className, // {2}
494 index // {4}
495 );
496 }
497
/// Write to `os` the definitions of the `...Count` and `...At` getters for a variadic operand
/// group.
///
/// Both generated functions obtain the operand segment from
/// `getODSOperandIndexAndLength(index)`. `...Count` returns the segment length; `...At` asserts
/// that the requested index lies within the segment and that the segment start fits in
/// intptr_t, then returns the operand at segment start + index.
///
/// @param index ODS operand index embedded in the generated code.
///
/// Asserts that both the class name and the current operand name have been set.
498 void genVariadicOperandGetterImpl(int index) const {
499 static constexpr char fmt[] = R"(
500intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
501 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
502 return range.second;
503}
504
505MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
506 auto range = llvm::cast<{2}>(unwrap(op)).getODSOperandIndexAndLength({4});
507 assert(index >= 0 && index < range.second && "variadic operand index out of range");
508 assert(
509 static_cast<uintptr_t>(range.first) <= static_cast<uintptr_t>(std::numeric_limits<intptr_t>::max()) &&
510 "operand index exceeds intptr_t range"
511 );
512 return mlirOperationGetOperand(op, static_cast<intptr_t>(range.first) + index);
513}
514)";
515 assert(!className.empty() && "className must be set");
516 assert(!operandNameCapitalized.empty() && "operandName must be set");
517 os << llvm::formatv(
518 fmt,
519 FunctionPrefix, // {0}
521 className, // {2}
523 index // {4}
524 );
525 }
526
/// Write to `os` the definition of the setter for a variadic operand group.
///
/// The generated function silently returns when `count` is negative; otherwise it unwraps the
/// `count` values into a vector and hands them to the op's mutable operand accessor.
///
/// Asserts that both the class name and the current operand name have been set.
527 // Delegate to the ODS-generated mutable accessor. assign() keeps operandSegmentSizes in sync
528 // automatically via MutableOperandRange::updateLength.
529 void genVariadicOperandSetterImpl() const {
530 static constexpr char fmt[] = R"(
531void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t count, MlirValue const *values) {{
532 if (count < 0)
533 return;
534 ::llvm::SmallVector<::mlir::Value> vals;
535 vals.reserve(static_cast<size_t>(count));
536 for (intptr_t i = 0; i < count; ++i)
537 vals.push_back(unwrap(values[i]));
538 ::llvm::cast<{2}>(unwrap(op)).get{3}Mutable().assign(vals);
539}
540)";
541 assert(!className.empty() && "className must be set");
542 assert(!operandNameCapitalized.empty() && "operandName must be set");
543 os << llvm::formatv(
544 fmt,
545 FunctionPrefix, // {0}
547 className, // {2}
549 );
550 }
551
/// Write to `os` the definition of the setter for a variadic-of-variadic operand.
///
/// The generated function silently returns when `groupCount` is negative. Otherwise it flattens
/// all groups into one value list, assigns that list through the op's mutable accessor, and
/// then stores the per-group sizes (each asserted to fit in int32_t) as a dense i32 array
/// attribute named `segSizeAttrName`.
///
/// @param segSizeAttrName Name of the attribute that records the size of every group.
///
/// Asserts that both the class name and the current operand name have been set.
552 // For VariadicOfVariadic operands the caller passes one MlirValueRange per group. The flat
553 // operand list is rebuilt from all groups, operandSegmentSizes is updated automatically by
554 // join().assign(), and the per-group segment-size attribute (whose name is read from the
555 // TableGen constraint via getVariadicOfVariadicSegmentSizeAttr()) is updated explicitly.
556 void genVariadicOfVariadicOperandSetterImpl(mlir::StringRef segSizeAttrName) const {
557 static constexpr char fmt[] = R"(
558void {0}{1}_{2}Set{3}(MlirOperation op, intptr_t groupCount, MlirValueRange const *groups) {{
559 if (groupCount < 0)
560 return;
561
562 ::llvm::SmallVector<::mlir::Value> vals;
563 for (intptr_t g = 0; g < groupCount; ++g) {{
564 assert(groups[g].size >= 0 && "group size must be non-negative");
565 for (intptr_t i = 0; i < groups[g].size; ++i) {{
566 vals.push_back(unwrap(groups[g].values[i]));
567 }
568 }
569 ::llvm::cast<{2}>(unwrap(op)).get{3}Mutable().join().assign(vals);
571 ::llvm::SmallVector<int32_t> newGroupSizes;
572 newGroupSizes.reserve(static_cast<size_t>(groupCount));
573 for (intptr_t g = 0; g < groupCount; ++g) {{
574 assert(
575 groups[g].size <= static_cast<intptr_t>(std::numeric_limits<int32_t>::max()) &&
576 "group size exceeds int32_t range"
577 );
578 newGroupSizes.push_back(static_cast<int32_t>(groups[g].size));
579 }
580 MlirContext ctx = mlirOperationGetContext(op);
581 assert(
582 newGroupSizes.size() <= static_cast<size_t>(std::numeric_limits<intptr_t>::max()) &&
583 "group count exceeds intptr_t range"
584 );
585 mlirOperationSetAttributeByName(
586 op, mlirStringRefCreateFromCString("{4}"),
587 mlirDenseI32ArrayGet(ctx, static_cast<intptr_t>(newGroupSizes.size()), newGroupSizes.data())
588 );
589}
590)";
591 assert(!className.empty() && "className must be set");
592 assert(!operandNameCapitalized.empty() && "operandName must be set");
593 os << llvm::formatv(
594 fmt,
595 FunctionPrefix, // {0}
597 className, // {2}
599 segSizeAttrName // {4}
600 );
601 }
602
/// Write to `os` the definition of an attribute getter; the generated function returns the
/// result of looking the attribute up by name on the operation.
///
/// @param attrName Attribute name placed in the generated string literal (the capitalized name
///                 set via setAttributeName is used only in the function name).
///
/// Asserts that both the class name and the current attribute name have been set.
603 void genAttributeGetterImpl(mlir::StringRef attrName) const {
604 static constexpr char fmt[] = R"(
605MlirAttribute {0}{1}_{2}Get{3}(MlirOperation op) {{
606 return mlirOperationGetAttributeByName(op, mlirStringRefCreateFromCString("{4}"));
607}
608)";
609 assert(!className.empty() && "className must be set");
610 assert(!attrNameCapitalized.empty() && "attrName must be set");
611 os << llvm::formatv(
612 fmt,
613 FunctionPrefix, // {0}
615 className, // {2}
616 attrNameCapitalized, // {3}
617 attrName // {4}
618 );
619 }
620
/// Write to `os` the definition of an attribute setter; the generated function sets the
/// attribute by name on the operation.
///
/// @param attrName Attribute name placed in the generated string literal (the capitalized name
///                 set via setAttributeName is used only in the function name).
///
/// Asserts that both the class name and the current attribute name have been set.
621 void genAttributeSetterImpl(mlir::StringRef attrName) const {
622 static constexpr char fmt[] = R"(
623void {0}{1}_{2}Set{3}(MlirOperation op, MlirAttribute attr) {{
624 mlirOperationSetAttributeByName(op, mlirStringRefCreateFromCString("{4}"), attr);
625}
626)";
627 assert(!className.empty() && "className must be set");
628 assert(!attrNameCapitalized.empty() && "attrName must be set");
629 os << llvm::formatv(
630 fmt,
631 FunctionPrefix, // {0}
633 className, // {2}
634 attrNameCapitalized, // {3}
635 attrName // {4}
636 );
637 }
638
/// Write to `os` the definition of the getter for a single result; the generated function
/// returns the operation's result at position `index`.
///
/// @param index Result position embedded in the generated code.
///
/// Asserts that both the class name and the current result name have been set.
639 void genResultGetterImpl(int index) const {
640 static constexpr char fmt[] = R"(
641MlirValue {0}{1}_{2}Get{3}(MlirOperation op) {{
642 return mlirOperationGetResult(op, {4});
643}
644)";
645 assert(!className.empty() && "className must be set");
646 assert(!resultNameCapitalized.empty() && "resultName must be set");
647 os << llvm::formatv(
648 fmt,
649 FunctionPrefix, // {0}
651 className, // {2}
653 index // {4}
654 );
655 }
656
/// Write to `os` the definitions of the `...Count` and `...At` getters for a variadic result
/// group.
///
/// `...Count` returns the operation's total result count minus `startIdx` (after asserting the
/// total is at least `startIdx`); `...At` returns the result at `startIdx + index`.
///
/// @param startIdx Position of the first result of the group, embedded in the generated code.
///
/// NOTE(review): the arithmetic attributes every result from `startIdx` onward to this group,
/// so it looks like it assumes the variadic result is the last one — confirm. Unlike the
/// variadic operand getter, the generated `...At` carries no bounds assertion on `index`.
///
/// Asserts that both the class name and the current result name have been set.
657 void genVariadicResultGetterImpl(int startIdx) const {
658 static constexpr char fmt[] = R"(
659intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
660 intptr_t count = mlirOperationGetNumResults(op);
661 assert(count >= {4} && "result count less than start index");
662 return count - {4};
663}
664
665MlirValue {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
666 return mlirOperationGetResult(op, {4} + index);
667}
668)";
669 assert(!className.empty() && "className must be set");
670 assert(!resultNameCapitalized.empty() && "resultName must be set");
671 os << llvm::formatv(
672 fmt,
673 FunctionPrefix, // {0}
675 className, // {2}
677 startIdx // {4}
678 );
679 }
680
/// Write to `os` the definition of the getter for a single region; the generated function
/// returns the operation's region at position `index`.
///
/// @param index Region position embedded in the generated code.
///
/// Asserts that both the class name and the current region name have been set.
681 void genRegionGetterImpl(unsigned index) const {
682 static constexpr char fmt[] = R"(
683MlirRegion {0}{1}_{2}Get{3}(MlirOperation op) {{
684 return mlirOperationGetRegion(op, {4});
685}
686)";
687 assert(!className.empty() && "className must be set");
688 assert(!regionNameCapitalized.empty() && "regionName must be set");
689 os << llvm::formatv(
690 fmt,
691 FunctionPrefix, // {0}
693 className, // {2}
695 index // {4}
696 );
697 }
698
/// Write to `os` the definitions of the `...Count` and `...At` getters for a variadic region
/// group.
///
/// `...Count` returns the operation's total region count minus `startIdx` (after asserting the
/// total is at least `startIdx`); `...At` returns the region at `startIdx + index`.
///
/// @param startIdx Position of the first region of the group, embedded in the generated code.
///
/// NOTE(review): the arithmetic attributes every region from `startIdx` onward to this group,
/// so it looks like it assumes the variadic region is the last one — confirm. The generated
/// `...At` carries no bounds assertion on `index`.
///
/// Asserts that both the class name and the current region name have been set.
699 void genVariadicRegionGetterImpl(unsigned startIdx) const {
700 static constexpr char fmt[] = R"(
701intptr_t {0}{1}_{2}Get{3}Count(MlirOperation op) {{
702 intptr_t count = mlirOperationGetNumRegions(op);
703 assert(count >= {4} && "region count less than start index");
704 return count - {4};
705}
706
707MlirRegion {0}{1}_{2}Get{3}At(MlirOperation op, intptr_t index) {{
708 return mlirOperationGetRegion(op, {4} + index);
709}
710)";
711 assert(!className.empty() && "className must be set");
712 assert(!regionNameCapitalized.empty() && "regionName must be set");
713 os << llvm::formatv(
714 fmt,
715 FunctionPrefix, // {0}
717 className, // {2}
719 startIdx // {4}
720 );
721 }
722};
723
729static std::string generateCAPIAssignments(const Operator &op) {
730 // Code generated here can use the following variables:
731 // - MlirOpBuilder builder
732 // - MlirLocation location
733 // - MlirOperationState state
734 // - Operand/Attribute/Result parameters per `generateCAPIBuildParams()`
735 struct : GenStringFromOpPieces {
736 void genResultInferred(llvm::raw_ostream &os) override {
737 os << " mlirOperationStateEnableResultTypeInference(&state);\n";
738 }
739 void genResult(
740 llvm::raw_ostream &os, const NamedTypeConstraint &result, const std::string &resultName
741 ) override {
742 if (result.isVariadic()) {
743 os << llvm::formatv(
744 " mlirOperationStateAddResults(&state, {0}Size, {0}Types);\n", resultName
745 );
746 } else {
747 os << llvm::formatv(" mlirOperationStateAddResults(&state, 1, &{0}Type);\n", resultName);
748 }
749 }
750 void genOperand(llvm::raw_ostream &os, const NamedTypeConstraint &operand) override {
751 if (operand.isVariadic()) {
752 os << llvm::formatv(
753 " mlirOperationStateAddOperands(&state, {0}Size, {0});\n", operand.name
754 );
755 } else {
756 os << llvm::formatv(" mlirOperationStateAddOperands(&state, 1, &{0});\n", operand.name);
757 }
758 }
759 void genAttributesPrefix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
760 os << " MlirContext ctx = mlirOpBuilderGetContext(builder);\n";
761 os << " llvm::SmallVector<MlirNamedAttribute, " << op.getNumAttributes()
762 << "> attributes;\n";
763 }
764 void genAttribute(llvm::raw_ostream &os, const NamedAttribute &attr) override {
765 // The second parameter to `mlirNamedAttributeGet()` must be an "MlirAttribute". However, if
766 // it ends up as "MlirIdentifier", a reinterpret cast is needed. These C structs have the same
767 // layout and the C++ mlir::StringAttr is a subclass of mlir::Attribute so the cast is safe.
768 std::optional<std::string> attrType = tryCppTypeToCapiType(attr.attr.getStorageType());
769 std::string attrValue;
770 if (attrType.has_value() && attrType.value() == "MlirIdentifier") {
771 attrValue = "reinterpret_cast<MlirAttribute&>(" + attr.name.str() + ")";
772 } else {
773 attrValue = attr.name.str();
774 }
775
776 os << " if (!mlirAttributeIsNull(" << attrValue << ")) {\n";
777 os << " attributes.push_back(mlirNamedAttributeGet(mlirIdentifierGet(ctx, "
778 << "mlirStringRefCreateFromCString(\"" << attr.name << "\")), " << attrValue << "));\n";
779 os << " }\n";
780 }
781 void
782 genAttributesSuffix(llvm::raw_ostream &os, const mlir::tblgen::Operator & /*op*/) override {
783 os << " mlirOperationStateAddAttributes(&state, attributes.size(), attributes.data());\n";
784 }
785 void genRegionsPrefix(llvm::raw_ostream &os, const mlir::tblgen::Operator &op) override {
786 os << " llvm::SmallVector<MlirRegion, " << op.getNumRegions() << "> regions;\n";
787 }
788 void genRegion(llvm::raw_ostream &os, const mlir::tblgen::NamedRegion &region) override {
789 if (region.isVariadic()) {
790 os << llvm::formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", region.name);
791 }
792 os << " regions.push_back(mlirRegionCreate());\n";
793 }
794 void genRegionsSuffix(llvm::raw_ostream &os, const mlir::tblgen::Operator & /*op*/) override {
795 os << " mlirOperationStateAddOwnedRegions(&state, regions.size(), regions.data());\n";
796 }
797 } paramStringGenerator;
798 return paramStringGenerator.gen(op);
799}
800
/// Generator entry point that writes the C API implementation (definitions) for operations to
/// `os`.
///
/// Walks every record derived from "Op", skips those whose dialect name differs from
/// `DialectName`, and for each remaining op emits the Build, IsA, operand, attribute, result,
/// region and extra-method definitions after the generator's prologue. Unlike
/// emitOpCAPIHeader, no epilogue is emitted.
///
/// @param records TableGen records to scan.
/// @param os      Stream receiving the generated implementation text.
/// @return Always `false` (presumably the "no error" convention of mlir::GenRegistration
///         callbacks — confirm).
802static bool emitOpCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) {
803 emitSourceFileHeader("Op C API Definitions", os, records);
804
805 OpImplementationGenerator generator("Operation", os);
806 generator.genPrologue();
807
808 for (const auto *def : records.getAllDerivedDefinitions("Op")) {
809 const Operator op(def);
810 const Dialect &dialect = op.getDialect();
811
812 // Generate for the selected dialect only (specified via -dialect command-line option)
813 if (dialect.getName() != DialectName) {
814 continue;
815 }
816
817 generator.setNamespaceAndClassName(dialect, op.getCppClassName());
818
819 // Generate "Build" function
820 if (GenOpBuild && !op.skipDefaultBuilders()) {
821 std::string assignments = generateCAPIAssignments(op);
822 generator.genOpBuildImpl(op.getOperationName(), generateCAPIBuildParams(op), assignments);
823 }
824
825 // Generate IsA check implementation
826 if (GenIsA) {
827 generator.genIsAImpl();
828 }
829
830 // Generate operand getters and setters
831 for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
832 const auto &operand = op.getOperand(i);
833 generator.setOperandName(operand.name);
834 if (operand.isVariadic()) {
836 generator.genVariadicOperandGetterImpl(i);
837 }
839 if (operand.isVariadicOfVariadic()) {
840 generator.genVariadicOfVariadicOperandSetterImpl(
841 operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
842 );
843 } else {
844 generator.genVariadicOperandSetterImpl();
845 }
846 }
847 } else {
849 generator.genOperandGetterImpl(i);
850 }
852 generator.genOperandSetterImpl(i);
853 }
854 }
855 }
856
857 // Generate attribute getters and setters
858 for (const auto &namedAttr : op.getAttributes()) {
859 generator.setAttributeName(namedAttr.name);
861 generator.genAttributeGetterImpl(namedAttr.name);
862 }
864 generator.genAttributeSetterImpl(namedAttr.name);
865 }
866 }
867
868 // Generate result getters
869 if (GenOpResultGetters) {
870 for (int i = 0, e = op.getNumResults(); i < e; ++i) {
871 const auto &result = op.getResult(i);
872 generator.setResultName(result.name, i);
873 if (result.isVariadic()) {
874 generator.genVariadicResultGetterImpl(i);
875 } else {
876 generator.genResultGetterImpl(i);
877 }
878 }
879 }
880
881 // Generate region getters
882 if (GenOpRegionGetters) {
883 for (unsigned i = 0, e = op.getNumRegions(); i < e; ++i) {
884 const auto &region = op.getRegion(i);
885 generator.setRegionName(region.name, i);
886 if (region.isVariadic()) {
887 generator.genVariadicRegionGetterImpl(i);
888 } else {
889 generator.genRegionGetterImpl(i);
890 }
891 }
892 }
893
894 // Generate extra class method implementations
896 generator.genExtraMethods(op.getExtraClassDeclaration());
897 }
898 }
899
900 return false;
901}
902
903static mlir::GenRegistration
904 genOpCAPIHeader("gen-op-capi-header", "Generate operation C API header", &emitOpCAPIHeader);
905
906static mlir::GenRegistration
907 genOpCAPIImpl("gen-op-capi-impl", "Generate operation C API implementation", &emitOpCAPIImpl);
std::optional< std::string > tryCppTypeToCapiType(StringRef cppType)
Convert C++ type to MLIR C API type.
llvm::cl::opt< bool > GenOpOperandSetters
llvm::cl::opt< bool > GenIsA
llvm::cl::opt< bool > GenOpBuild
llvm::cl::opt< std::string > DialectName
llvm::cl::opt< std::string > FunctionPrefix
llvm::cl::opt< bool > GenOpRegionGetters
llvm::cl::opt< bool > GenOpResultGetters
llvm::cl::opt< bool > GenOpAttributeGetters
llvm::cl::opt< bool > GenOpAttributeSetters
llvm::cl::opt< bool > GenOpOperandGetters
std::string toPascalCase(mlir::StringRef str)
Convert names separated by underscore or colon to PascalCase.
llvm::cl::opt< bool > GenExtraClassMethods
Helper struct to generate a string from operation operand, attribute, and result pieces.
mlir::StringRef className
mlir::StringRef dialectNamespace
std::string dialectNameCapitalized
llvm::raw_ostream & os
Generator for common C header file elements.
Generator for common C implementation file elements.
Common between header and implementation generators for operations.
Definition OpCAPIGen.cpp:33
void setAttributeName(mlir::StringRef name)
Definition OpCAPIGen.cpp:35
void setResultName(mlir::StringRef name, int resultIndex)
Definition OpCAPIGen.cpp:36
void setOperandName(mlir::StringRef name)
Definition OpCAPIGen.cpp:34
std::string resultNameCapitalized
Definition OpCAPIGen.cpp:48
std::string attrNameCapitalized
Definition OpCAPIGen.cpp:47
std::string operandNameCapitalized
Definition OpCAPIGen.cpp:46
void setRegionName(mlir::StringRef name, unsigned regionIndex)
Definition OpCAPIGen.cpp:40
std::string regionNameCapitalized
Definition OpCAPIGen.cpp:49
Generator for operation C header files.
Definition OpCAPIGen.cpp:53
void genVariadicRegionGetterDecl() const
void genVariadicOfVariadicOperandSetterDecl() const
void genVariadicResultGetterDecl() const
~OpHeaderGenerator() override=default
void genOperandSetterDecl() const
Definition OpCAPIGen.cpp:84
void genVariadicOperandSetterDecl() const
void genAttributeGetterDecl() const
void genRegionGetterDecl() const
void genAttributeSetterDecl() const
void genVariadicOperandGetterDecl() const
Definition OpCAPIGen.cpp:98
void genResultGetterDecl() const
void genOpBuildDecl(const std::string &params) const
Definition OpCAPIGen.cpp:57
void genOperandGetterDecl() const
Definition OpCAPIGen.cpp:70
Generator for operation C implementation files.
void genOperandGetterImpl(int index) const
void genRegionGetterImpl(unsigned index) const
void genAttributeGetterImpl(mlir::StringRef attrName) const
void genVariadicRegionGetterImpl(unsigned startIdx) const
void genVariadicResultGetterImpl(int startIdx) const
void genVariadicOfVariadicOperandSetterImpl(mlir::StringRef segSizeAttrName) const
void genOpBuildImpl(const std::string &operationName, const std::string &params, const std::string &assignments) const
Generate operation "Build" function implementation.
void genResultGetterImpl(int index) const
~OpImplementationGenerator() override=default
void genOperandSetterImpl(int index) const
void genVariadicOperandGetterImpl(int index) const
void genPrologue() const override
void genAttributeSetterImpl(mlir::StringRef attrName) const
void genVariadicOperandSetterImpl() const