LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
POD.cpp
Go to the documentation of this file.
1//===-- POD.cpp - POD dialect C API implementation --------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "llzk/CAPI/Builder.h"
11#include "llzk/CAPI/Support.h"
14
15#include "llzk-c/Dialect/POD.h"
16#include "llzk-c/Support.h"
17
18#include <mlir/CAPI/IR.h>
19#include <mlir/CAPI/Registration.h>
20#include <mlir/CAPI/Support.h>
21#include <mlir/CAPI/Wrap.h>
22#include <mlir/IR/Attributes.h>
23#include <mlir/IR/Diagnostics.h>
24#include <mlir/Support/LLVM.h>
25
26#include <mlir-c/IR.h>
27
28#include <llvm/ADT/STLExtras.h>
29#include <llvm/ADT/SmallVectorExtras.h>
30
31#include <cstdint>
32
33using namespace mlir;
34using namespace llzk;
35using namespace llzk::pod;
36
37MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(POD, llzk__pod, PODDialect)
38
39namespace {
40
41static SmallVector<RecordValue>
42fromRawRecordValues(intptr_t nValues, LlzkRecordValue const *values) {
43 return llvm::map_to_vector(ArrayRef(values, nValues), [](const auto &record) {
44 return RecordValue {.name = unwrap(record.name), .value = unwrap(record.value)};
45 });
46}
47
48} // namespace
49
50//===----------------------------------------------------------------------===//
51// RecordAttr
52//===----------------------------------------------------------------------===//
53
54MlirAttribute llzkRecordAttrGet(MlirStringRef name, MlirType type) {
55 auto t = unwrap(type);
56 return wrap(RecordAttr::get(t.getContext(), StringAttr::get(t.getContext(), unwrap(name)), t));
57}
58
59bool llzkAttributeIsARecordAttr(MlirAttribute attr) { return mlir::isa<RecordAttr>(unwrap(attr)); }
60
61MlirStringRef llzkRecordAttrGetName(MlirAttribute attr) {
62 return wrap(unwrap_cast<RecordAttr>(attr).getName().getValue());
63}
64
65MlirAttribute llzkRecordAttrGetNameSym(MlirAttribute attr) {
66 return wrap(unwrap_cast<RecordAttr>(attr).getNameSym());
67}
68
69MlirType llzkRecordAttrGetType(MlirAttribute attr) {
70 return wrap(unwrap_cast<RecordAttr>(attr).getType());
71}
72
73//===----------------------------------------------------------------------===//
74// PodType
75//===----------------------------------------------------------------------===//
76
77MlirType llzkPodTypeGet(MlirContext context, intptr_t nRecords, MlirAttribute const *records) {
78 SmallVector<Attribute> recordsSto;
79 auto recordAttrs = llvm::map_to_vector(unwrapList(nRecords, records, recordsSto), [](auto attr) {
80 return mlir::cast<RecordAttr>(attr);
81 });
82 return wrap(PodType::get(unwrap(context), recordAttrs));
83}
84
86 MlirContext context, intptr_t nRecords, LlzkRecordValue const *records
87) {
88 auto initialValues = fromRawRecordValues(nRecords, records);
89 return wrap(PodType::fromInitialValues(unwrap(context), initialValues));
90}
91
92bool llzkTypeIsAPodType(MlirType type) { return mlir::isa<PodType>(unwrap(type)); }
93
94intptr_t llzkPodTypeGetNumRecords(MlirType type) {
95 return static_cast<intptr_t>(unwrap_cast<PodType>(type).getRecords().size());
96}
97
98void llzkPodTypeGetRecords(MlirType type, MlirAttribute *dst) {
99 auto records = unwrap_cast<PodType>(type).getRecords();
100 MutableArrayRef<MlirAttribute> dstRef(dst, records.size());
101 llvm::transform(records, dstRef.begin(), [](auto record) { return wrap(record); });
102}
103
104MlirAttribute llzkPodTypeGetNthRecord(MlirType type, intptr_t n) {
105 return wrap(unwrap_cast<PodType>(type).getRecords()[n]);
106}
107
108namespace {
109static MlirType
110lookupRecordImpl(PodType type, StringRef name, llvm::function_ref<InFlightDiagnostic()> emitError) {
111 auto attr = type.getRecord(name, emitError);
112 if (failed(attr)) {
113 return MlirType {.ptr = nullptr};
114 }
115 return wrap(*attr);
116}
117} // namespace
118
119MlirType llzkPodTypeLookupRecord(MlirType type, MlirStringRef name) {
120 auto pod = unwrap_cast<PodType>(type);
121 return lookupRecordImpl(pod, unwrap(name), [pod] {
122 auto *ctx = pod.getContext();
123 return ctx->getDiagEngine().emit(Builder(ctx).getUnknownLoc(), DiagnosticSeverity::Error);
124 });
125}
126
127MlirType
128llzkPodTypeLookupRecordWithinLocation(MlirType type, MlirStringRef name, MlirLocation loc) {
129 auto pod = unwrap_cast<PodType>(type);
130 return lookupRecordImpl(pod, unwrap(name), [pod, loc] {
131 return pod.getContext()->getDiagEngine().emit(unwrap(loc), DiagnosticSeverity::Error);
132 });
133}
134
135MlirType
136llzkPodTypeLookupRecordWithinOperation(MlirType type, MlirStringRef name, MlirOperation op) {
137 return lookupRecordImpl(unwrap_cast<PodType>(type), unwrap(name), [op] {
138 return unwrap(op)->emitError();
139 });
140}
141
142//===----------------------------------------------------------------------===//
143// NewPodOp
144//===----------------------------------------------------------------------===//
145
147 NewPodOp, InferredFromInitialValues, intptr_t nValues, LlzkRecordValue const *values
148) {
149 auto recordValues = fromRawRecordValues(nValues, values);
150 return wrap(create<NewPodOp>(builder, location, recordValues));
151}
152
154 NewPodOp, MlirType type, intptr_t nValues, LlzkRecordValue const *values
155) {
156 auto recordValues = fromRawRecordValues(nValues, values);
157 return wrap(create<NewPodOp>(builder, location, unwrap_cast<PodType>(type), recordValues));
158}
159
// Builder for NewPodOp that also takes affine map operands (suffix `WithMapOperands`).
// NOTE(review): this listing appears to have lost lines here. The embedded numbering skips
// 160, 162 and 169, so the macro invocation line, the declaration of `mapOperands` and the
// callee of the wrapped call are not visible. Confirm against the real source file.
161 NewPodOp, WithMapOperands, MlirType type, intptr_t nValues, LlzkRecordValue const *values,
163) {
// Convert the raw C record values to their C++ form.
// NOTE(review): `recordValues` is not among the visible arguments of the call below; confirm
// whether the initial values are meant to be forwarded to the build method.
164 auto recordValues = fromRawRecordValues(nValues, values);
// Helper that unwraps the C map-operand arguments; it is dereferenced in the call below.
165 MapOperandsHelper<> mapOps(mapOperands.nMapOperands, mapOperands.mapOperands);
// Number of dimensions per map, obtained as an attribute in the context of `location`.
166 auto numDimsPerMap =
167 llzkAffineMapOperandsBuilderGetDimsPerMapAttr(mapOperands, mlirLocationGetContext(location));
168 return wrap(
170 builder, location, unwrap_cast<PodType>(type), *mapOps,
171 unwrap_cast<DenseI32ArrayAttr>(numDimsPerMap)
172 )
173 );
174}
175
176bool llzkOperationIsANewPodOp(MlirOperation op) { return mlir::isa<NewPodOp>(unwrap(op)); }
177
178//===----------------------------------------------------------------------===//
179// ReadPodOp
180//===----------------------------------------------------------------------===//
181
182bool llzkOperationIsAReadPodOp(MlirOperation op) { return mlir::isa<ReadPodOp>(unwrap(op)); }
183
184//===----------------------------------------------------------------------===//
185// WritePodOp
186//===----------------------------------------------------------------------===//
187
188bool llzkOperationIsAWritePodOp(MlirOperation op) { return mlir::isa<WritePodOp>(unwrap(op)); }
bool llzkOperationIsANewPodOp(MlirOperation op)
Definition POD.cpp:176
bool llzkTypeIsAPodType(MlirType type)
Definition POD.cpp:92
MlirAttribute llzkRecordAttrGetNameSym(MlirAttribute attr)
Returns the name of the record as a flat symbol attribute.
Definition POD.cpp:65
MlirType llzkPodTypeLookupRecordWithinLocation(MlirType type, MlirStringRef name, MlirLocation loc)
Looks up a record type by name.
Definition POD.cpp:128
MlirAttribute llzkPodTypeGetNthRecord(MlirType type, intptr_t n)
Returns the n-th record in the struct.
Definition POD.cpp:104
MlirStringRef llzkRecordAttrGetName(MlirAttribute attr)
Returns the name of the record.
Definition POD.cpp:61
MlirType llzkPodTypeGetFromInitialValues(MlirContext context, intptr_t nRecords, LlzkRecordValue const *records)
Creates an llzk::pod::PodType using a list of values for inferring the records.
Definition POD.cpp:85
MlirType llzkPodTypeGet(MlirContext context, intptr_t nRecords, MlirAttribute const *records)
Creates an llzk::pod::PodType using a list of attributes as records.
Definition POD.cpp:77
void llzkPodTypeGetRecords(MlirType type, MlirAttribute *dst)
Writes the records into the given array that must have been previously allocated with enough space.
Definition POD.cpp:98
MlirType llzkPodTypeLookupRecordWithinOperation(MlirType type, MlirStringRef name, MlirOperation op)
Looks up a record type by name.
Definition POD.cpp:136
MlirAttribute llzkRecordAttrGet(MlirStringRef name, MlirType type)
Creates a new llzk::pod::RecordAttr.
Definition POD.cpp:54
bool llzkOperationIsAReadPodOp(MlirOperation op)
Definition POD.cpp:182
MlirType llzkRecordAttrGetType(MlirAttribute attr)
Returns the type of the record.
Definition POD.cpp:69
bool llzkOperationIsAWritePodOp(MlirOperation op)
Definition POD.cpp:188
intptr_t llzkPodTypeGetNumRecords(MlirType type)
Returns the number of records in the struct.
Definition POD.cpp:94
MlirType llzkPodTypeLookupRecord(MlirType type, MlirStringRef name)
Looks up a record type by name.
Definition POD.cpp:119
bool llzkAttributeIsARecordAttr(MlirAttribute attr)
Definition POD.cpp:59
MlirAttribute llzkAffineMapOperandsBuilderGetDimsPerMapAttr(LlzkAffineMapOperandsBuilder builder, MlirContext context)
Returns the number of dimensions per map represented as an attribute.
Definition Support.cpp:186
Helper for unwrapping the C arguments for the map operands.
Definition Support.h:36
::llvm::FailureOr<::mlir::Type > getRecord(::llvm::StringRef name, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) const
Searches a record by name.
Definition Types.cpp:50
static PodType fromInitialValues(::mlir::MLIRContext *ctx, InitializedRecords init)
Creates a new type from a set of initialized records.
Definition Types.cpp:42
static PodType get(::mlir::MLIRContext *context, ::llvm::ArrayRef<::llzk::pod::RecordAttr > records)
Definition Types.cpp.inc:68
#define LLZK_DEFINE_OP_BUILD_METHOD(op,...)
Definition Support.h:27
#define LLZK_DEFINE_SUFFIX_OP_BUILD_METHOD(op, suffix,...)
Definition Support.h:25
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
mlir::Location getUnknownLoc(mlir::MLIRContext *context)
Definition Builders.h:25
auto unwrap_cast(auto &from)
Definition Support.h:30
Encapsulates the arguments related to affine maps that are common in operation constructors that supp...
Definition Support.h:105
MlirValueRange * mapOperands
Definition Support.h:109