LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Poly.cpp
Go to the documentation of this file.
1//===-- Poly.cpp - Polymorphic dialect C API impl ---------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "llzk-c/Dialect/Poly.h"
11
12#include "llzk/CAPI/Builder.h"
13#include "llzk/CAPI/Support.h"
18#include "llzk/Util/Compare.h"
19
20#include <mlir-c/Pass.h>
21
22#include <mlir/CAPI/AffineExpr.h>
23#include <mlir/CAPI/AffineMap.h>
24#include <mlir/CAPI/Pass.h>
25#include <mlir/CAPI/Registration.h>
26#include <mlir/CAPI/Wrap.h>
27#include <mlir/IR/BuiltinAttributes.h>
28#include <mlir/Support/LLVM.h>
29
30using namespace mlir;
31using namespace llzk;
32using namespace llzk::polymorphic;
33
34static void registerLLZKPolymorphicTransformationPasses() { registerTransformationPasses(); }
35
36// Include the generated CAPI
40
41MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Polymorphic, llzk__polymorphic, PolymorphicDialect)
42
43//===----------------------------------------------------------------------===//
44// TypeVarType
45//===----------------------------------------------------------------------===//
46
47MlirType llzkPoly_TypeVarTypeGetFromStringRef(MlirContext ctx, MlirStringRef name) {
48 return wrap(TypeVarType::get(FlatSymbolRefAttr::get(StringAttr::get(unwrap(ctx), unwrap(name)))));
49}
50
51MlirType llzkPoly_TypeVarTypeGetFromAttr(MlirAttribute attrWrapper) {
52 auto attr = unwrap(attrWrapper);
53 if (auto sym = llvm::dyn_cast<FlatSymbolRefAttr>(attr)) {
54 return wrap(TypeVarType::get(sym));
55 }
56 return wrap(TypeVarType::get(FlatSymbolRefAttr::get(llvm::cast<StringAttr>(attr))));
57}
58
59//===----------------------------------------------------------------------===//
60// TemplateOp
61//===----------------------------------------------------------------------===//
62
63static inline TemplateOp asTemplateOp(MlirOperation op) { return unwrap_cast<TemplateOp>(op); }
64
65static inline void copyAttrs(SmallVector<Attribute> attrs, MlirAttribute *dst) {
66 for (auto [n, attr] : llvm::enumerate(attrs)) {
67 dst[n] = wrap(attr);
68 }
69}
70
71MlirBlock llzkPoly_TemplateOpGetBody(MlirOperation op) { return wrap(asTemplateOp(op).getBody()); }
72
74 return asTemplateOp(op).hasConstOps<TemplateParamOp>();
75}
76
77intptr_t llzkPoly_TemplateOpNumConstParamOps(MlirOperation op) {
78 return llzk::checkedCast<intptr_t>(asTemplateOp(op).numConstOps<TemplateParamOp>());
79}
80
81void llzkPoly_TemplateOpGetConstParamNames(MlirOperation op, MlirAttribute *dst) {
82 copyAttrs(asTemplateOp(op).getConstNames<TemplateParamOp>(), dst);
83}
84
85bool llzkPoly_TemplateOpHasConstParamNamed(MlirOperation op, MlirStringRef find) {
86 return asTemplateOp(op).hasConstNamed<TemplateParamOp>(unwrap(find));
87}
88
89bool llzkPoly_TemplateOpHasConstExprOps(MlirOperation op) {
90 return asTemplateOp(op).hasConstOps<TemplateExprOp>();
91}
92
93intptr_t llzkPoly_TemplateOpNumConstExprOps(MlirOperation op) {
94 return llzk::checkedCast<intptr_t>(asTemplateOp(op).numConstOps<TemplateExprOp>());
95}
96
97void llzkPoly_TemplateOpGetConstExprNames(MlirOperation op, MlirAttribute *dst) {
98 copyAttrs(asTemplateOp(op).getConstNames<TemplateExprOp>(), dst);
99}
100
101bool llzkPoly_TemplateOpHasConstExprNamed(MlirOperation op, MlirStringRef find) {
102 return asTemplateOp(op).hasConstNamed<TemplateExprOp>(unwrap(find));
103}
104
105//===----------------------------------------------------------------------===//
106// ApplyMapOp
107//===----------------------------------------------------------------------===//
108
109LLZK_DEFINE_OP_BUILD_METHOD(Poly, ApplyMapOp, MlirAttribute map, MlirValueRange mapOperands) {
110 SmallVector<Value> mapOperandsSto;
111 return wrap(
113 builder, location, llvm::cast<AffineMapAttr>(unwrap(map)),
114 ValueRange(unwrapList(mapOperands.size, mapOperands.values, mapOperandsSto))
115 )
116 );
117}
118
120 Poly, ApplyMapOp, WithAffineMap, MlirAffineMap map, MlirValueRange mapOperands
121) {
122 SmallVector<Value> mapOperandsSto;
123 return wrap(
125 builder, location, unwrap(map),
126 ValueRange(unwrapList(mapOperands.size, mapOperands.values, mapOperandsSto))
127 )
128 );
129}
130
132 Poly, ApplyMapOp, WithAffineExpr, MlirAffineExpr expr, MlirValueRange mapOperands
133) {
134 SmallVector<Value> mapOperandsSto;
135 return wrap(
137 builder, location, unwrap(expr),
138 ValueRange(unwrapList(mapOperands.size, mapOperands.values, mapOperandsSto))
139 )
140 );
141}
142
143static inline ValueRange dimOperands(MlirOperation op) {
144 return unwrap_cast<ApplyMapOp>(op).getDimOperands();
145}
146
147static inline ValueRange symbolOperands(MlirOperation op) {
148 return unwrap_cast<ApplyMapOp>(op).getSymbolOperands();
149}
150
151static inline void copyValues(ValueRange in, MlirValue *out) {
152 for (auto [n, value] : llvm::enumerate(in)) {
153 out[n] = wrap(value);
154 }
155}
156
158intptr_t llzkPoly_ApplyMapOpGetNumDimOperands(MlirOperation op) {
159 return llzk::checkedCast<intptr_t>(dimOperands(op).size());
160}
161
165void llzkPoly_ApplyMapOpGetDimOperands(MlirOperation op, MlirValue *dst) {
166 copyValues(dimOperands(op), dst);
167}
168
170intptr_t llzkPoly_ApplyMapOpGetNumSymbolOperands(MlirOperation op) {
171 return llzk::checkedCast<intptr_t>(symbolOperands(op).size());
172}
173
177void llzkPoly_ApplyMapOpGetSymbolOperands(MlirOperation op, MlirValue *dst) {
178 copyValues(symbolOperands(op), dst);
179}
void llzkPoly_ApplyMapOpGetSymbolOperands(MlirOperation op, MlirValue *dst)
Writes into the destination buffer the operands that correspond to symbols in the affine map.
Definition Poly.cpp:177
void llzkPoly_TemplateOpGetConstExprNames(MlirOperation op, MlirAttribute *dst)
Writes into the destination buffer the names of all TemplateExprOp children as FlatSymbolRefAttr attributes.
Definition Poly.cpp:97
intptr_t llzkPoly_TemplateOpNumConstExprOps(MlirOperation op)
Returns the number of TemplateExprOp children in the TemplateOp.
Definition Poly.cpp:93
void llzkPoly_TemplateOpGetConstParamNames(MlirOperation op, MlirAttribute *dst)
Writes into the destination buffer the names of all TemplateParamOp children as FlatSymbolRefAttr attributes.
Definition Poly.cpp:81
bool llzkPoly_TemplateOpHasConstExprOps(MlirOperation op)
Returns true if the TemplateOp has any TemplateExprOp children.
Definition Poly.cpp:89
bool llzkPoly_TemplateOpHasConstExprNamed(MlirOperation op, MlirStringRef find)
Returns true if the TemplateOp has a TemplateExprOp with the given name.
Definition Poly.cpp:101
MlirBlock llzkPoly_TemplateOpGetBody(MlirOperation op)
Returns the single body Block within the TemplateOp's Region.
Definition Poly.cpp:71
bool llzkPoly_TemplateOpHasConstParamNamed(MlirOperation op, MlirStringRef find)
Returns true if the TemplateOp has a TemplateParamOp with the given name.
Definition Poly.cpp:85
intptr_t llzkPoly_ApplyMapOpGetNumSymbolOperands(MlirOperation op)
Returns the number of operands that correspond to symbols in the affine map.
Definition Poly.cpp:170
void llzkPoly_ApplyMapOpGetDimOperands(MlirOperation op, MlirValue *dst)
Writes into the destination buffer the operands that correspond to dimensions in the affine map.
Definition Poly.cpp:165
bool llzkPoly_TemplateOpHasConstParamOps(MlirOperation op)
Returns true if the TemplateOp has any TemplateParamOp children.
Definition Poly.cpp:73
MlirType llzkPoly_TypeVarTypeGetFromAttr(MlirAttribute attrWrapper)
Creates a llzk::polymorphic::TypeVarType from either a StringAttr or a FlatSymbolRefAttr.
Definition Poly.cpp:51
MlirType llzkPoly_TypeVarTypeGetFromStringRef(MlirContext ctx, MlirStringRef name)
Creates a llzk::polymorphic::TypeVarType.
Definition Poly.cpp:47
intptr_t llzkPoly_ApplyMapOpGetNumDimOperands(MlirOperation op)
Returns the number of operands that correspond to dimensions in the affine map.
Definition Poly.cpp:158
intptr_t llzkPoly_TemplateOpNumConstParamOps(MlirOperation op)
Returns the number of TemplateParamOp children in the TemplateOp.
Definition Poly.cpp:77
static TypeVarType get(::mlir::MLIRContext *context, ::mlir::FlatSymbolRefAttr nameRef)
Definition Types.cpp.inc:67
#define LLZK_DEFINE_OP_BUILD_METHOD(dialect, op,...)
Definition Support.h:31
#define LLZK_DEFINE_SUFFIX_OP_BUILD_METHOD(dialect, op, suffix,...)
Definition Support.h:27
void registerTransformationPasses()
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
auto unwrap_cast(auto &from)
Definition Support.h:51
Representation of an mlir::ValueRange
Definition Support.h:47
MlirValue const * values
Pointer to the first value in the range.
Definition Support.h:49
intptr_t size
Number of values in the range.
Definition Support.h:51