LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - POD operation implementations -----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
16
17#include <mlir/IR/Builders.h>
18#include <mlir/IR/BuiltinAttributes.h>
19#include <mlir/IR/Diagnostics.h>
20#include <mlir/IR/OpImplementation.h>
21#include <mlir/IR/OperationSupport.h>
22#include <mlir/Support/LLVM.h>
23
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/SmallString.h>
26#include <llvm/ADT/SmallVectorExtras.h>
27#include <llvm/ADT/StringSet.h>
28#include <llvm/ADT/TypeSwitch.h>
29#include <llvm/Support/Debug.h>
30
31#include <cstdint>
32
33// TableGen'd implementation files
34#define GET_OP_CLASSES
36
37using namespace mlir;
38
39namespace llzk::pod {
40
41//===----------------------------------------------------------------------===//
42// NewPodOp
43//===----------------------------------------------------------------------===//
44
45namespace {
46static void buildCommon(
47 OpBuilder &builder, OperationState &state, PodType result, InitializedRecords initialValues
48) {
49 SmallVector<Value, 4> values;
50 SmallVector<StringRef, 4> names;
51
52 for (const auto &record : initialValues) {
53 names.push_back(record.name);
54 values.push_back(record.value);
55 }
56
57 auto &props = state.getOrAddProperties<NewPodOp::Properties>();
58 state.addTypes(result);
59 state.addOperands(values);
60 props.setInitializedRecords(builder.getStrArrayAttr(names));
61}
62} // namespace
63
65 OpBuilder &builder, OperationState &state, PodType result, ArrayRef<ValueRange> mapOperands,
66 DenseI32ArrayAttr numDimsPerMap, InitializedRecords initialValues
67) {
68 buildCommon(builder, state, result, initialValues);
69 affineMapHelpers::buildInstantiationAttrs<NewPodOp>(builder, state, mapOperands, numDimsPerMap);
70}
71
// NOTE(review): two lines are absent from this extract: the one that opens this definition
// (listing line 72) and the one naming the callee of the trailing call (listing line 76).
// Presumably this is the `NewPodOp` builder without affine map operands and the callee is the
// `buildInstantiationAttrsEmpty` helper -- confirm against the original file.
//
// buildCommon registers the result type, the initial-value operands and the record names; the
// trailing call receives the number of initial values, checked-cast to int32_t.
73 OpBuilder &builder, OperationState &state, PodType result, InitializedRecords initialValues
74) {
75 buildCommon(builder, state, result, initialValues);
77 builder, state, llzk::checkedCast<int32_t>(initialValues.size())
78 );
79}
80
81void NewPodOp::getAsmResultNames(llvm::function_ref<void(Value, StringRef)> setNameFn) {
82 setNameFn(getResult(), "pod");
83}
84
85namespace {
86
87static void collectMapAttrs(Type type, SmallVector<AffineMapAttr> &mapAttrs) {
88 // clang-format off
89 llvm::TypeSwitch<Type, void>(type)
90 .Case([&mapAttrs](PodType t) {
91 for (auto record : t.getRecords()) {
92 collectMapAttrs(record.getType(), mapAttrs);
93 }
94 })
95 .Case([&mapAttrs](array::ArrayType t) {
96 for (auto a : t.getDimensionSizes()) {
97 if (auto m = llvm::dyn_cast<AffineMapAttr>(a)) {
98 mapAttrs.push_back(m);
99 }
100 }
101 })
102 .Case([&mapAttrs](component::StructType t) {
103 if (ArrayAttr params = t.getParams()) {
104 for (auto param : params) {
105 if (auto m = llvm::dyn_cast<AffineMapAttr>(param)) {
106 mapAttrs.push_back(m);
107 }
108 }
109 }
110 }).Default([](Type) {});
111 // clang-format on
112}
113
121static LogicalResult verifyInitialValues(
122 ValueRange values, ArrayRef<Attribute> names, PodType retTy,
123 llvm::function_ref<InFlightDiagnostic()> emitError
124) {
125 bool failed = false;
126 if (names.size() != values.size()) {
127 emitError() << "number of initialized records and initial values does not match ("
128 << names.size() << " != " << values.size() << ")";
129 failed = true;
130 }
131
132 llvm::StringMap<Type> records = retTy.getRecordMap();
133 llvm::StringSet<> seenNames;
134 for (auto [nameAttr, value] : llvm::zip_equal(names, values)) {
135 auto name = llvm::cast<StringAttr>(nameAttr).getValue(); // Per the ODS spec.
136 if (seenNames.contains(name)) {
137 emitError() << "found duplicated record name '" << name << '\'';
138 failed = true;
139 }
140 seenNames.insert(name);
141
142 if (!records.contains(name)) {
143 emitError() << "record '" << name << "' is not part of the struct";
144 failed = true;
145 continue;
146 }
147
148 auto valueTy = value.getType();
149 auto recordTy = records.at(name);
150 if (valueTy != recordTy) {
151 auto err = emitError();
152 err << "record '" << name << "' expected type " << recordTy << " but got " << valueTy;
153 if (typesUnify(valueTy, recordTy)) {
154 err.attachNote()
155 << "types " << valueTy << " and " << recordTy
156 << " can be unified. Perhaps you can add a 'poly.unifiable_cast' operation?";
157 }
158 failed = true;
159 }
160 }
161
162 return failure(failed);
163}
164
// Collects the affine map attributes reachable from `retTy` (see collectMapAttrs) and hands
// them, together with the map operands and the dims-per-map counts of `op`, to a verification
// helper whose result is returned.
//
// NOTE(review): the line naming that helper (listing line 168) is absent from this extract;
// presumably it is `return ...verifyAffineMapInstantiations(` -- confirm against the original.
165static LogicalResult verifyAffineMapOperands(NewPodOp *op, Type retTy) {
166 SmallVector<AffineMapAttr> mapAttrs;
167 collectMapAttrs(retTy, mapAttrs);
169 op->getMapOperands(), op->getNumDimsPerMap(), mapAttrs, *op
170 );
171}
172
173} // namespace
174
175#define check(x) \
176 { \
177 failed = failed || mlir::failed(x); \
178 }
179
180LogicalResult NewPodOp::verify() {
181 auto retTy = llvm::dyn_cast<PodType>(getResult().getType());
182 assert(retTy); // per ODS spec of NewPodOp
183
184 bool failed = false;
185 check(
186 verifyInitialValues(getInitialValues(), getInitializedRecords().getValue(), retTy, [this]() {
187 return this->emitError();
188 })
189 );
190 check(verifyAffineMapOperands(this, retTy));
191
192 return failure(failed);
193}
194
195#undef check
196
197using UnresolvedOp = OpAsmParser::UnresolvedOperand;
198
199ParseResult
200parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand) {
201 if (failed(parser.parseSymbolName(name))) {
202 return failure();
203 }
204
205 if (parser.parseEqual()) {
206 return failure();
207 }
208 return parser.parseOperand(operand);
209}
210
211ParseResult NewPodOp::parse(OpAsmParser &parser, OperationState &result) {
212 /* Grammar
213 * op : record_init map_operands `:` type($result) attr-dict
214 * record_init : `{` record_inits `}`| `{` `}` | $
215 * map_operands : custom<MapOperands> | $
216 * record_inits : symbol `=` operand `,` record_inits | symbol `=` operand
217 */
218
219 auto &props = result.getOrAddProperties<NewPodOp::Properties>();
220
221 SmallVector<Attribute> initializedRecords;
222 // The map may not preserve the order of the operands so it needs to be iterated using
223 // `initializedRecords` that preserves the original order.
224 llvm::StringMap<UnresolvedOp> initialValuesOperands;
225 auto parseElementFn = [&parser, &initializedRecords, &initialValuesOperands] {
226 StringAttr name;
227 UnresolvedOp operand;
228 if (failed(parseRecordInitialization(parser, name, operand))) {
229 return failure();
230 }
231 initializedRecords.push_back(name);
232 initialValuesOperands.insert({name.getValue(), operand});
233 return success();
234 };
235 auto initialValuesLoc = parser.getCurrentLocation();
236 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::OptionalBraces, parseElementFn)) {
237 return failure();
238 }
239 SmallVector<int32_t> mapOperandsGroupSizes;
240 SmallVector<UnresolvedOp> allMapOperands;
241 Type indexTy = parser.getBuilder().getIndexType();
242 bool colonAlreadyParsed = true;
243 auto mapOperandsLoc = parser.getCurrentLocation();
244 // Peek to see if we have affine map operands.
245 // If we don't then the next token must be `:`
246 if (failed(parser.parseOptionalColon())) {
247 colonAlreadyParsed = false;
248 SmallVector<SmallVector<UnresolvedOp>> mapOperands {};
249 if (parseMultiDimAndSymbolList(parser, mapOperands, props.numDimsPerMap)) {
250 return failure();
251 }
252
253 mapOperandsGroupSizes.reserve(mapOperands.size());
254 for (const auto &subRange : mapOperands) {
255 allMapOperands.append(subRange.begin(), subRange.end());
256 mapOperandsGroupSizes.push_back(llzk::checkedCast<int32_t>(subRange.size()));
257 }
258 }
259
260 if (!colonAlreadyParsed && parser.parseColon()) {
261 return failure();
262 }
263
264 PodType resultType;
265 if (parser.parseCustomTypeWithFallback(resultType)) {
266 return failure();
267 }
268 // Now that we have the struct type we can resolve the operands
269 // using the types of the struct.
270 for (auto attr : initializedRecords) {
271 auto name = llvm::cast<StringAttr>(attr); // Per ODS spec of RecordAttr
272 auto lookup = resultType.getRecord(name.getValue(), [&parser, initialValuesLoc] {
273 return parser.emitError(initialValuesLoc);
274 });
275 if (failed(lookup)) {
276 return failure();
277 }
278 const auto &operand = initialValuesOperands.at(name.getValue());
279 if (failed(parser.resolveOperands({operand}, *lookup, initialValuesLoc, result.operands))) {
280 return failure();
281 }
282 }
283 props.operandSegmentSizes = {
284 llzk::checkedCast<int32_t>(initializedRecords.size()),
285 llzk::checkedCast<int32_t>(allMapOperands.size())
286 };
287 props.mapOpGroupSizes = parser.getBuilder().getDenseI32ArrayAttr(mapOperandsGroupSizes);
288 props.initializedRecords = parser.getBuilder().getArrayAttr(initializedRecords);
289 result.addTypes({resultType});
290
291 if (failed(parser.resolveOperands(allMapOperands, indexTy, mapOperandsLoc, result.operands))) {
292 return failure();
293 }
294 {
295 auto loc = parser.getCurrentLocation();
296 if (parser.parseOptionalAttrDict(result.attributes)) {
297 return failure();
298 }
299 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
300 return parser.emitError(loc) << '\'' << result.name.getStringRef() << "' op ";
301 }))) {
302 return failure();
303 }
304 }
305
306 return success();
307}
308
// Prints the custom assembly form of NewPodOp:
//   1. when at least one record is initialized, ` { @name = %operand, ... } `;
//   2. ` : ` followed by the result type (in stripped form when it is a PodType);
//   3. the attribute dictionary, leaving out the names listed in the final call.
//
// NOTE(review): one line between steps 1 and 2 (listing line 321) is absent from this extract;
// presumably it prints the affine map operands -- confirm against the original file.
309void NewPodOp::print(OpAsmPrinter &printer) {
310 auto &os = printer.getStream();
311 auto initializedRecords = getInitializedRecordValues();
// Nothing is printed for the record list when no record is initialized.
312 if (!initializedRecords.empty()) {
313 os << " { ";
314 llvm::interleaveComma(initializedRecords, os, [&os, &printer](auto record) {
315 printer.printSymbolName(record.name);
316 os << " = ";
317 printer.printOperand(record.value);
318 });
319 os << " } ";
320 }
322
323 os << " : ";
324
325 auto type = getResult().getType();
326 if (auto validType = llvm::dyn_cast<PodType>(type)) {
327 printer.printStrippedAttrOrType(validType);
328 } else {
329 printer.printType(type);
330 }
331
// The elided names are the ones the steps above are responsible for.
332 printer.printOptionalAttrDict(
333 (*this)->getAttrs(),
334 {"initializedRecords", "mapOpGroupSizes", "numDimsPerMap", "operandSegmentSizes"}
335 );
336}
337
338SmallVector<RecordValue> NewPodOp::getInitializedRecordValues() {
339 return llvm::map_to_vector(
340 llvm::zip_equal(getInitialValues(), getInitializedRecords()), [](auto pair) {
341 auto [value, name] = pair;
342 return RecordValue {.name = llvm::cast<StringAttr>(name).getValue(), .value = value};
343 }
344 );
345}
346
347//===----------------------------------------------------------------------===//
348// ReadPodOp
349//===----------------------------------------------------------------------===//
350
351LogicalResult ReadPodOp::verify() {
352 auto podTy = llvm::dyn_cast<PodType>(getPodRef().getType());
353 if (!podTy) {
354 return emitError() << "reference operand expected a plain-old-data struct but got "
355 << getPodRef().getType();
356 }
357
358 auto lookup = podTy.getRecord(getRecordName(), [this]() { return this->emitError(); });
359 if (failed(lookup)) {
360 return lookup;
361 }
362
363 if (getResult().getType() != *lookup) {
364 return emitError() << "operation result type and type of record do not match ("
365 << getResult().getType() << " != " << *lookup << ")";
366 }
367
368 return success();
369}
370
371//===----------------------------------------------------------------------===//
372// WritePodOp
373//===----------------------------------------------------------------------===//
374
375LogicalResult WritePodOp::verify() {
376 auto podTy = llvm::dyn_cast<PodType>(getPodRef().getType());
377 if (!podTy) {
378 return emitError() << "reference operand expected a plain-old-data struct but got "
379 << getPodRef().getType();
380 }
381
382 auto lookup = podTy.getRecord(getRecordName(), [this]() { return this->emitError(); });
383 if (failed(lookup)) {
384 return lookup;
385 }
386
387 if (getValue().getType() != *lookup) {
388 return emitError() << "type of source value and type of record do not match ("
389 << getValue().getType() << " != " << *lookup << ")";
390 }
391
392 return success();
393}
394
395//===----------------------------------------------------------------------===//
396// Parsing/Printing helpers
397//===----------------------------------------------------------------------===//
398
399ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name) {
400 return parser.parseCustomAttributeWithFallback(name);
401}
402
403void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name) {
404 printer.printSymbolName(name.getValue());
405}
406
407} // namespace llzk::pod
within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names
Definition LICENSE.txt:139
#define check(x)
Definition Ops.cpp:175
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:309
::mlir::Operation::operand_range getInitialValues()
Definition Ops.h.inc:237
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:241
::mlir::SmallVector<::llzk::pod::RecordValue > getInitializedRecordValues()
Definition Ops.cpp:338
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llzk::pod::InitializedRecords initialValues={})
Definition Ops.cpp.inc:458
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:275
::mlir::TypedValue<::llzk::pod::PodType > getResult()
Definition Ops.h.inc:257
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:81
FoldAdaptor::Properties Properties
Definition Ops.h.inc:188
::llvm::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError)
Definition Ops.cpp.inc:355
::llvm::LogicalResult verify()
Definition Ops.cpp:180
::mlir::ArrayAttr getInitializedRecords()
Definition Ops.cpp.inc:435
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:211
::llvm::FailureOr<::mlir::Type > getRecord(::llvm::StringRef name, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) const
Searches a record by name.
Definition Types.cpp:50
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::llvm::StringRef getRecordName()
Definition Ops.cpp.inc:643
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:492
::llvm::LogicalResult verify()
Definition Ops.cpp:351
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:473
::llvm::LogicalResult verify()
Definition Ops.cpp:375
::mlir::TypedValue<::mlir::Type > getValue()
Definition Ops.h.inc:694
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:690
::llvm::StringRef getRecordName()
Definition Ops.cpp.inc:926
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
mlir::ArrayRef< RecordValue > InitializedRecords
Definition Types.h:23
OpAsmParser::UnresolvedOperand UnresolvedOp
Definition Ops.cpp:197
ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name)
Definition Ops.cpp:399
ParseResult parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand)
Definition Ops.cpp:200
void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name)
Definition Ops.cpp:403
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Definition OpHelpers.h:140
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
Definition OpHelpers.h:132
void setInitializedRecords(const ::mlir::ArrayAttr &propValue)
Definition Ops.h.inc:46