LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - POD operation implementations -----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
15
16#include <mlir/IR/Builders.h>
17#include <mlir/IR/BuiltinAttributes.h>
18#include <mlir/IR/Diagnostics.h>
19#include <mlir/IR/OpImplementation.h>
20#include <mlir/IR/OperationSupport.h>
21#include <mlir/Support/LLVM.h>
22
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/SmallString.h>
25#include <llvm/ADT/SmallVectorExtras.h>
26#include <llvm/ADT/StringSet.h>
27#include <llvm/ADT/TypeSwitch.h>
28#include <llvm/Support/Debug.h>
29
30#include <cstdint>
31
32// TableGen'd implementation files
33#define GET_OP_CLASSES
35
36using namespace mlir;
37
38namespace llzk::pod {
39
40//===----------------------------------------------------------------------===//
41// NewPodOp
42//===----------------------------------------------------------------------===//
43
44namespace {
45static void buildCommon(
46 OpBuilder &builder, OperationState &state, PodType result, InitializedRecords initialValues
47) {
48 SmallVector<Value, 4> values;
49 SmallVector<StringRef, 4> names;
50
51 for (const auto &record : initialValues) {
52 names.push_back(record.name);
53 values.push_back(record.value);
54 }
55
56 auto &props = state.getOrAddProperties<NewPodOp::Properties>();
57 state.addTypes(result);
58 state.addOperands(values);
59 props.setInitializedRecords(builder.getStrArrayAttr(names));
60}
61} // namespace
62
64 OpBuilder &builder, OperationState &state, PodType result, ArrayRef<ValueRange> mapOperands,
65 DenseI32ArrayAttr numDimsPerMap, InitializedRecords initialValues
66) {
67 buildCommon(builder, state, result, initialValues);
68 affineMapHelpers::buildInstantiationAttrs<NewPodOp>(builder, state, mapOperands, numDimsPerMap);
69}
70
72 OpBuilder &builder, OperationState &state, PodType result, InitializedRecords initialValues
73) {
74 buildCommon(builder, state, result, initialValues);
75 assert(std::cmp_less_equal(initialValues.size(), std::numeric_limits<int32_t>::max()));
77 builder, state, static_cast<int32_t>(initialValues.size())
78 );
79}
80
81void NewPodOp::getAsmResultNames(llvm::function_ref<void(Value, StringRef)> setNameFn) {
82 setNameFn(getResult(), "pod");
83}
84
85namespace {
86
87static void collectMapAttrs(Type type, SmallVector<AffineMapAttr> &mapAttrs) {
88 llvm::TypeSwitch<Type, void>(type)
89 .Case([&mapAttrs](PodType t) {
90 for (auto record : t.getRecords()) {
91 collectMapAttrs(record.getType(), mapAttrs);
92 }
93 })
94 .Case([&mapAttrs](array::ArrayType t) {
95 for (auto a : t.getDimensionSizes()) {
96 if (auto m = llvm::dyn_cast<AffineMapAttr>(a)) {
97 mapAttrs.push_back(m);
98 }
99 }
100 })
101 .Case([&mapAttrs](component::StructType t) {
102 for (auto param : t.getParams()) {
103 if (auto m = llvm::dyn_cast<AffineMapAttr>(param)) {
104 mapAttrs.push_back(m);
105 }
106 }
107 }).Default([](Type) {});
108}
109
117static LogicalResult verifyInitialValues(
118 ValueRange values, ArrayRef<Attribute> names, PodType retTy,
119 llvm::function_ref<InFlightDiagnostic()> emitError
120) {
121 bool failed = false;
122 if (names.size() != values.size()) {
123 emitError() << "number of initialized records and initial values does not match ("
124 << names.size() << " != " << values.size() << ")";
125 failed = true;
126 }
127
128 llvm::StringMap<Type> records = retTy.getRecordMap();
129 llvm::StringSet<> seenNames;
130 for (auto [nameAttr, value] : llvm::zip_equal(names, values)) {
131 auto name = llvm::cast<StringAttr>(nameAttr).getValue(); // Per the ODS spec.
132 if (seenNames.contains(name)) {
133 emitError() << "found duplicated record name '" << name << '\'';
134 failed = true;
135 }
136 seenNames.insert(name);
137
138 if (!records.contains(name)) {
139 emitError() << "record '" << name << "' is not part of the struct";
140 failed = true;
141 continue;
142 }
143
144 auto valueTy = value.getType();
145 auto recordTy = records.at(name);
146 if (valueTy != recordTy) {
147 auto err = emitError();
148 err << "record '" << name << "' expected type " << recordTy << " but got " << valueTy;
149 if (typesUnify(valueTy, recordTy)) {
150 err.attachNote()
151 << "types " << valueTy << " and " << recordTy
152 << " can be unified. Perhaps you can add a 'poly.unifiable_cast' operation?";
153 }
154 failed = true;
155 }
156 }
157
158 return failure(failed);
159}
160
161static LogicalResult verifyAffineMapOperands(NewPodOp *op, Type retTy) {
162 SmallVector<AffineMapAttr> mapAttrs;
163 collectMapAttrs(retTy, mapAttrs);
165 op->getMapOperands(), op->getNumDimsPerMap(), mapAttrs, *op
166 );
167}
168
169} // namespace
170
171#define check(x) \
172 { \
173 failed = failed || mlir::failed(x); \
174 }
175
176LogicalResult NewPodOp::verify() {
177 auto retTy = llvm::dyn_cast<PodType>(getResult().getType());
178 assert(retTy); // per ODS spec of NewPodOp
179
180 bool failed = false;
181 check(
182 verifyInitialValues(getInitialValues(), getInitializedRecords().getValue(), retTy, [this]() {
183 return this->emitError();
184 })
185 );
186 check(verifyAffineMapOperands(this, retTy));
187
188 return failure(failed);
189}
190
191#undef check
192
193using UnresolvedOp = OpAsmParser::UnresolvedOperand;
194
195ParseResult
196parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand) {
197 if (failed(parser.parseSymbolName(name))) {
198 return failure();
199 }
200
201 if (parser.parseEqual()) {
202 return failure();
203 }
204 return parser.parseOperand(operand);
205}
206
207ParseResult NewPodOp::parse(OpAsmParser &parser, OperationState &result) {
208 /* Grammar
209 * op : record_init map_operands `:` type($result) attr-dict
210 * record_init : `{` record_inits `}`| `{` `}` | $
211 * map_operands : custom<MapOperands> | $
212 * record_inits : symbol `=` operand `,` record_inits | symbol `=` operand
213 */
214
215 auto &props = result.getOrAddProperties<NewPodOp::Properties>();
216
217 SmallVector<Attribute> initializedRecords;
218 // The map may not preserve the order of the operands so it needs to be iterated using
219 // `initializedRecords` that preserves the original order.
220 llvm::StringMap<UnresolvedOp> initialValuesOperands;
221 auto parseElementFn = [&parser, &initializedRecords, &initialValuesOperands] {
222 StringAttr name;
223 UnresolvedOp operand;
224 if (failed(parseRecordInitialization(parser, name, operand))) {
225 return failure();
226 }
227 initializedRecords.push_back(name);
228 initialValuesOperands.insert({name.getValue(), operand});
229 return success();
230 };
231 auto initialValuesLoc = parser.getCurrentLocation();
232 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::OptionalBraces, parseElementFn)) {
233 return failure();
234 }
235 SmallVector<int32_t> mapOperandsGroupSizes;
236 SmallVector<UnresolvedOp> allMapOperands;
237 Type indexTy = parser.getBuilder().getIndexType();
238 bool colonAlreadyParsed = true;
239 auto mapOperandsLoc = parser.getCurrentLocation();
240 // Peek to see if we have affine map operands.
241 // If we don't then the next token must be `:`
242 if (failed(parser.parseOptionalColon())) {
243 colonAlreadyParsed = false;
244 SmallVector<SmallVector<UnresolvedOp>> mapOperands {};
245 if (parseMultiDimAndSymbolList(parser, mapOperands, props.numDimsPerMap)) {
246 return failure();
247 }
248
249 mapOperandsGroupSizes.reserve(mapOperands.size());
250 for (const auto &subRange : mapOperands) {
251 allMapOperands.append(subRange.begin(), subRange.end());
252 assert(std::cmp_less_equal(subRange.size(), std::numeric_limits<int32_t>::max()));
253 mapOperandsGroupSizes.push_back(static_cast<int32_t>(subRange.size()));
254 }
255 }
256
257 if (!colonAlreadyParsed && parser.parseColon()) {
258 return failure();
259 }
260
261 PodType resultType;
262 if (parser.parseCustomTypeWithFallback(resultType)) {
263 return failure();
264 }
265 // Now that we have the struct type we can resolve the operands
266 // using the types of the struct.
267 for (auto attr : initializedRecords) {
268 auto name = llvm::cast<StringAttr>(attr); // Per ODS spec of RecordAttr
269 auto lookup = resultType.getRecord(name.getValue(), [&parser, initialValuesLoc] {
270 return parser.emitError(initialValuesLoc);
271 });
272 if (failed(lookup)) {
273 return failure();
274 }
275 const auto &operand = initialValuesOperands.at(name.getValue());
276 if (failed(parser.resolveOperands({operand}, *lookup, initialValuesLoc, result.operands))) {
277 return failure();
278 }
279 }
280 assert(std::cmp_less_equal(initializedRecords.size(), std::numeric_limits<int32_t>::max()));
281 assert(std::cmp_less_equal(allMapOperands.size(), std::numeric_limits<int32_t>::max()));
282 props.operandSegmentSizes = {
283 static_cast<int32_t>(initializedRecords.size()), static_cast<int32_t>(allMapOperands.size())
284 };
285 props.mapOpGroupSizes = parser.getBuilder().getDenseI32ArrayAttr(mapOperandsGroupSizes);
286 props.initializedRecords = parser.getBuilder().getArrayAttr(initializedRecords);
287 result.addTypes({resultType});
288
289 if (failed(parser.resolveOperands(allMapOperands, indexTy, mapOperandsLoc, result.operands))) {
290 return failure();
291 }
292 {
293 auto loc = parser.getCurrentLocation();
294 if (parser.parseOptionalAttrDict(result.attributes)) {
295 return failure();
296 }
297 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
298 return parser.emitError(loc) << '\'' << result.name.getStringRef() << "' op ";
299 }))) {
300 return failure();
301 }
302 }
303
304 return success();
305}
306
307void NewPodOp::print(OpAsmPrinter &printer) {
308 auto &os = printer.getStream();
309 auto initializedRecords = getInitializedRecordValues();
310 if (!initializedRecords.empty()) {
311 os << " { ";
312 llvm::interleaveComma(initializedRecords, os, [&os, &printer](auto record) {
313 printer.printSymbolName(record.name);
314 os << " = ";
315 printer.printOperand(record.value);
316 });
317 os << " } ";
318 }
320
321 os << " : ";
322
323 auto type = getResult().getType();
324 if (auto validType = llvm::dyn_cast<PodType>(type)) {
325 printer.printStrippedAttrOrType(validType);
326 } else {
327 printer.printType(type);
328 }
329
330 printer.printOptionalAttrDict(
331 (*this)->getAttrs(),
332 {"initializedRecords", "mapOpGroupSizes", "numDimsPerMap", "operandSegmentSizes"}
333 );
334}
335
336SmallVector<RecordValue> NewPodOp::getInitializedRecordValues() {
337 return llvm::map_to_vector(
338 llvm::zip_equal(getInitialValues(), getInitializedRecords()), [](auto pair) {
339 auto [value, name] = pair;
340 return RecordValue {.name = llvm::cast<StringAttr>(name).getValue(), .value = value};
341 }
342 );
343}
344
345//===----------------------------------------------------------------------===//
346// ReadPodOp
347//===----------------------------------------------------------------------===//
348
349LogicalResult ReadPodOp::verify() {
350 auto podTy = llvm::dyn_cast<PodType>(getPodRef().getType());
351 if (!podTy) {
352 return emitError() << "reference operand expected a plain-old-data struct but got "
353 << getPodRef().getType();
354 }
355
356 auto lookup = podTy.getRecord(getRecordName(), [this]() { return this->emitError(); });
357 if (failed(lookup)) {
358 return lookup;
359 }
360
361 if (getResult().getType() != *lookup) {
362 return emitError() << "operation result type and type of record do not match ("
363 << getResult().getType() << " != " << *lookup << ")";
364 }
365
366 return success();
367}
368
369//===----------------------------------------------------------------------===//
370// WritePodOp
371//===----------------------------------------------------------------------===//
372
373LogicalResult WritePodOp::verify() {
374 auto podTy = llvm::dyn_cast<PodType>(getPodRef().getType());
375 if (!podTy) {
376 return emitError() << "reference operand expected a plain-old-data struct but got "
377 << getPodRef().getType();
378 }
379
380 auto lookup = podTy.getRecord(getRecordName(), [this]() { return this->emitError(); });
381 if (failed(lookup)) {
382 return lookup;
383 }
384
385 if (getValue().getType() != *lookup) {
386 return emitError() << "type of source value and type of record do not match ("
387 << getValue().getType() << " != " << *lookup << ")";
388 }
389
390 return success();
391}
392
393//===----------------------------------------------------------------------===//
394// Parsing/Printing helpers
395//===----------------------------------------------------------------------===//
396
397ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name) {
398 return parser.parseCustomAttributeWithFallback(name);
399}
400
401void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name) {
402 printer.printSymbolName(name.getValue());
403}
404
405} // namespace llzk::pod
within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names
Definition LICENSE.txt:139
#define check(x)
Definition Ops.cpp:171
void print(::mlir::OpAsmPrinter &p)
Definition Ops.cpp:307
::mlir::Operation::operand_range getInitialValues()
Definition Ops.h.inc:237
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:241
::mlir::SmallVector<::llzk::pod::RecordValue > getInitializedRecordValues()
Definition Ops.cpp:336
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llzk::pod::InitializedRecords initialValues={})
Definition Ops.cpp.inc:458
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:275
::mlir::TypedValue<::llzk::pod::PodType > getResult()
Definition Ops.h.inc:257
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:81
FoldAdaptor::Properties Properties
Definition Ops.h.inc:188
::llvm::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError)
Definition Ops.cpp.inc:355
::llvm::LogicalResult verify()
Definition Ops.cpp:176
::mlir::ArrayAttr getInitializedRecords()
Definition Ops.cpp.inc:435
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
Definition Ops.cpp:207
::llvm::FailureOr<::mlir::Type > getRecord(::llvm::StringRef name, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) const
Searches a record by name.
Definition Types.cpp:50
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::llvm::StringRef getRecordName()
Definition Ops.cpp.inc:643
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:492
::llvm::LogicalResult verify()
Definition Ops.cpp:349
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:473
::llvm::LogicalResult verify()
Definition Ops.cpp:373
::mlir::TypedValue<::mlir::Type > getValue()
Definition Ops.h.inc:694
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:690
::llvm::StringRef getRecordName()
Definition Ops.cpp.inc:926
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
mlir::ArrayRef< RecordValue > InitializedRecords
Definition Types.h:23
OpAsmParser::UnresolvedOperand UnresolvedOp
Definition Ops.cpp:193
ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name)
Definition Ops.cpp:397
ParseResult parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand)
Definition Ops.cpp:196
void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name)
Definition Ops.cpp:401
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Definition OpHelpers.h:115
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
Definition OpHelpers.h:107
void setInitializedRecords(const ::mlir::ArrayAttr &propValue)
Definition Ops.h.inc:46