LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Struct.cpp
Go to the documentation of this file.
1//===-- Struct.cpp - Struct dialect C API implementation --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
12#include "llzk/CAPI/Builder.h"
13#include "llzk/CAPI/Support.h"
18#include "llzk/Util/Compare.h"
21
22#include <mlir-c/Support.h>
23
24#include <mlir/CAPI/AffineMap.h>
25#include <mlir/CAPI/Registration.h>
26#include <mlir/CAPI/Support.h>
27#include <mlir/CAPI/Wrap.h>
28#include <mlir/IR/BuiltinAttributes.h>
29#include <mlir/IR/SymbolTable.h>
30
31#include <llvm/ADT/STLExtras.h>
32
33using namespace mlir;
34using namespace llzk;
35using namespace llzk::component;
36
37// Include the generated CAPI
40
// Expose llzk::component::StructDialect through the MLIR C API dialect-handle
// mechanism under the C name `Struct` and namespace `llzk__component`.
// The macro comes from <mlir/CAPI/Registration.h> (included above); see that
// header for the exact function it defines.
41MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Struct, llzk__component, StructDialect)
42
43//===----------------------------------------------------------------------===//
44// StructType
45//===----------------------------------------------------------------------===//
46
47MlirType llzkStruct_StructTypeGet(MlirAttribute name) {
48 return wrap(StructType::get(llvm::cast<SymbolRefAttr>(unwrap(name))));
49}
50
51MlirType llzkStruct_StructTypeGetWithArrayAttr(MlirAttribute name, MlirAttribute params) {
52 return wrap(
54 llvm::cast<SymbolRefAttr>(unwrap(name)), llvm::cast<ArrayAttr>(unwrap(params))
55 )
56 );
57}
58
60 MlirAttribute name, intptr_t numParams, MlirAttribute const *params
61) {
62 SmallVector<Attribute> paramsSto;
63 return wrap(
65 llvm::cast<SymbolRefAttr>(unwrap(name)), unwrapList(numParams, params, paramsSto)
66 )
67 );
68}
70 MlirType type, MlirOperation root, LlzkSymbolLookupResult *result
71) {
72 auto structType = mlir::unwrap_cast<StructType>(type);
73 auto *rootOp = unwrap(root);
74 SymbolTableCollection stc;
75 mlir::FailureOr<llzk::SymbolLookupResult<StructDefOp>> lookup =
76 structType.getDefinition(stc, rootOp);
77
78 if (succeeded(lookup)) {
79 // Allocate the result in the heap and store the pointer in the out var.
80 result->ptr = new llzk::SymbolLookupResultUntyped(std::move(*lookup));
81 }
82 return wrap(lookup);
83}
84
86 MlirType type, MlirModule root, LlzkSymbolLookupResult *result
87) {
88 return llzkStructStructTypeGetDefinition(type, mlirModuleGetOperation(root), result);
89}
90
91//===----------------------------------------------------------------------===//
92// StructDefOp
93//===----------------------------------------------------------------------===//
94
95MlirBlock llzkStruct_StructDefOpGetBody(MlirOperation op) {
96 return wrap(llvm::cast<StructDefOp>(unwrap(op)).getBody());
97}
98
99MlirType llzkStruct_StructDefOpGetType(MlirOperation op) {
100 return wrap(llvm::cast<StructDefOp>(unwrap(op)).getType());
101}
102
103MlirType llzkStruct_StructDefOpGetTypeWithParams(MlirOperation op, MlirAttribute attr) {
104 return wrap(llvm::cast<StructDefOp>(unwrap(op)).getType(llvm::cast<ArrayAttr>(unwrap(attr))));
105}
106
107void llzkStruct_StructDefOpGetMemberDefs(MlirOperation op, MlirOperation *dst) {
108 for (auto [offset, member] :
109 llvm::enumerate(llvm::cast<StructDefOp>(unwrap(op)).getMemberDefs())) {
110 dst[offset] = wrap(member);
111 }
112}
113
114intptr_t llzkStruct_StructDefOpGetNumMemberDefs(MlirOperation op) {
115 return llzk::checkedCast<intptr_t>(llvm::cast<StructDefOp>(unwrap(op)).getMemberDefs().size());
116}
117
119 MlirOperation op, intptr_t *strSize, char *(*alloc_string)(size_t)
120) {
121 auto header = llvm::cast<StructDefOp>(unwrap(op)).getHeaderString();
122 *strSize = llzk::checkedCast<intptr_t>(header.size()) + 1; // Plus one because it's a C string.
123 char *dst = alloc_string(*strSize);
124 dst[header.size()] = 0;
125 memcpy(dst, header.data(), header.size());
126 return dst;
127}
128
129void llzkStruct_StructDefOpGetTemplateParamOpNames(MlirOperation op, MlirAttribute *dst) {
130 for (auto [offset, attr] :
131 llvm::enumerate(llvm::cast<StructDefOp>(unwrap(op)).getTemplateParamOpNames())) {
132 dst[offset] = wrap(attr);
133 }
134}
135
138 llvm::cast<StructDefOp>(unwrap(op)).getTemplateParamOpNames().size()
139 );
140}
141
142void llzkStruct_StructDefOpGetTemplateExprOpNames(MlirOperation op, MlirAttribute *dst) {
143 for (auto [offset, attr] :
144 llvm::enumerate(llvm::cast<StructDefOp>(unwrap(op)).getTemplateExprOpNames())) {
145 dst[offset] = wrap(attr);
146 }
147}
148
151 llvm::cast<StructDefOp>(unwrap(op)).getTemplateExprOpNames().size()
152 );
153}
154
155//===----------------------------------------------------------------------===//
156// MemberReadOp
157//===----------------------------------------------------------------------===//
158
160 Struct, MemberReadOp, MlirType memberType, MlirValue component, MlirIdentifier memberName
161) {
162 return wrap(
164 builder, location, unwrap(memberType), unwrap(component), unwrap(memberName)
165 )
166 );
167}
168
170 Struct, MemberReadOp, WithAffineMapDistance, MlirType memberType, MlirValue component,
171 MlirIdentifier memberName, MlirAffineMap map, MlirValueRange mapOperands
172) {
173 SmallVector<Value> mapOperandsSto;
174 auto mapAttr = AffineMapAttr::get(unwrap(map));
175 return wrap(
177 builder, location, unwrap(memberType), unwrap(component), unwrap(memberName), mapAttr,
178 unwrapList(mapOperands.size, mapOperands.values, mapOperandsSto),
179 mapAttr.getAffineMap().getNumDims()
180 )
181 );
182}
183
185 Struct, MemberReadOp, WithTemplateSymbolDistance, MlirType memberType, MlirValue component,
186 MlirIdentifier memberName, MlirStringRef symbol
187) {
188 return wrap(
190 builder, location, unwrap(memberType), unwrap(component), unwrap(memberName),
191 FlatSymbolRefAttr::get(unwrap(builder)->getStringAttr(unwrap(symbol)))
192 )
193 );
194}
195
197 Struct, MemberReadOp, WithLiteralDistance, MlirType memberType, MlirValue component,
198 MlirIdentifier memberName, int64_t distance
199) {
200 return wrap(
202 builder, location, unwrap(memberType), unwrap(component), unwrap(memberName),
203 unwrap(builder)->getIndexAttr(distance)
204 )
205 );
206}
MlirType llzkStruct_StructTypeGet(MlirAttribute name)
Creates a llzk::component::StructType.
Definition Struct.cpp:47
intptr_t llzkStruct_StructDefOpGetNumMemberDefs(MlirOperation op)
Returns the number of MemberDefOp operations defined in this struct.
Definition Struct.cpp:114
MlirLogicalResult llzkStructStructTypeGetDefinitionFromModule(MlirType type, MlirModule root, LlzkSymbolLookupResult *result)
Looks up the definition Operation of the given StructType using the given Module as root for the looku...
Definition Struct.cpp:85
MlirType llzkStruct_StructTypeGetWithArrayAttr(MlirAttribute name, MlirAttribute params)
Creates a llzk::component::StructType with an ArrayAttr as parameters.
Definition Struct.cpp:51
void llzkStruct_StructDefOpGetTemplateExprOpNames(MlirOperation op, MlirAttribute *dst)
If this struct.def is within a poly.template, add names of all poly.expr within the poly....
Definition Struct.cpp:142
MlirType llzkStruct_StructTypeGetWithAttrs(MlirAttribute name, intptr_t numParams, MlirAttribute const *params)
Creates a llzk::component::StructType with an array of parameters.
Definition Struct.cpp:59
MlirType llzkStruct_StructDefOpGetTypeWithParams(MlirOperation op, MlirAttribute attr)
Returns the associated StructType to this op using the given const params instead of the parameters d...
Definition Struct.cpp:103
MlirLogicalResult llzkStructStructTypeGetDefinition(MlirType type, MlirOperation root, LlzkSymbolLookupResult *result)
Looks up the definition Operation of the given StructType using the given Operation as root for the lo...
Definition Struct.cpp:69
void llzkStruct_StructDefOpGetMemberDefs(MlirOperation op, MlirOperation *dst)
Fills the given array with the MemberDefOp operations inside this struct.
Definition Struct.cpp:107
void llzkStruct_StructDefOpGetTemplateParamOpNames(MlirOperation op, MlirAttribute *dst)
If this struct.def is within a poly.template, add names of all poly.param within the poly....
Definition Struct.cpp:129
MlirBlock llzkStruct_StructDefOpGetBody(MlirOperation op)
Returns the single body Block within the StructDefOp's Region.
Definition Struct.cpp:95
const char * llzkStruct_StructDefOpGetHeaderString(MlirOperation op, intptr_t *strSize, char *(*alloc_string)(size_t))
Returns the header string of the struct.
Definition Struct.cpp:118
MlirType llzkStruct_StructDefOpGetType(MlirOperation op)
Returns the associated StructType to this op using the const params defined by the op.
Definition Struct.cpp:99
intptr_t llzkStruct_StructDefOpGetNumTemplateParamOpNames(MlirOperation op)
Returns the number of poly.param operations defined within this template.
Definition Struct.cpp:136
intptr_t llzkStruct_StructDefOpGetNumTemplateExprOpNames(MlirOperation op)
Returns the number of poly.expr operations defined within this template.
Definition Struct.cpp:149
This file defines methods for symbol lookup across LLZK operations and included files.
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
#define LLZK_DEFINE_OP_BUILD_METHOD(dialect, op,...)
Definition Support.h:31
#define LLZK_DEFINE_SUFFIX_OP_BUILD_METHOD(dialect, op, suffix,...)
Definition Support.h:27
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
auto unwrap_cast(auto &from)
Definition Support.h:51
Owned result of an LLZK symbol lookup.
Definition Support.h:56
void * ptr
raw pointer to the result
Definition Support.h:58
Representation of an mlir::ValueRange
Definition Support.h:47
MlirValue const * values
Pointer to the first value in the range.
Definition Support.h:49
intptr_t size
Number of values in the range.
Definition Support.h:51