LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
OpHelpers.h
Go to the documentation of this file.
1//===-- OpHelpers.h ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
14#include "llzk/Util/Constants.h"
17
18#include <mlir/IR/OpImplementation.h>
19#include <mlir/IR/Operation.h>
20#include <mlir/IR/SymbolTable.h>
21#include <mlir/Support/LogicalResult.h>
22
23#include <llvm/ADT/SmallString.h>
24#include <llvm/ADT/StringRef.h>
25
26namespace llzk {
27
// Op trait (CRTP over `TypeClass` via `mlir::OpTrait::TraitBase`) whose region
// verifier re-runs `mlir::detail::verifySymbolTable` on every ancestor of the
// op that carries the `mlir::OpTrait::SymbolTable` trait, walking outward from
// the closest one. See LLZKSymbolTable ODS documentation for details.
29template <typename TypeClass>
30// Suppress false positive from `clang-tidy`
31// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
33 : public mlir::OpTrait::TraitBase<TypeClass, LLZKSymbolTableImplTrait> {
34public:
35 static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op) {
36 // Note: the current op will be checked by the normal `SymbolTable` trait that is
37 // included in `LLZKSymbolTable`. Checking it here would cause the same error described
38 // in `LLZKSymbolTable`.
 // `op` is reassigned to the next enclosing symbol-table op on each iteration;
 // `getParentWithTrait` yields null once no such ancestor remains, ending the loop.
 // The first ancestor that fails verification aborts with failure.
39 while ((op = op->getParentWithTrait<mlir::OpTrait::SymbolTable>())) {
40 if (mlir::failed(mlir::detail::verifySymbolTable(op))) {
41 return mlir::failure();
42 }
43 }
44 return mlir::success();
45 }
46};
47
51template <typename OpClass> inline llvm::StringLiteral getOperationName() {
52 return OpClass::getOperationName();
53}
54
56template <typename OpClass> inline OpClass getSelfOrParentOfType(mlir::Operation *op) {
57 if (op) {
58 if (OpClass self = llvm::dyn_cast<OpClass>(op)) {
59 return self;
60 }
61 if (OpClass parent = op->getParentOfType<OpClass>()) {
62 return parent;
63 }
64 }
65 return {};
66}
67
69template <typename OpClass> inline OpClass getParentOfType(mlir::Operation *op) {
70 if (op) {
71 if (OpClass p = op->getParentOfType<OpClass>()) {
72 return p;
73 }
74 }
75 return {};
76}
77
// Op trait that reports inconsistencies among the attributes/operands of
// `TypeClass` that describe several affine maps (`mapOpGroupSizes`,
// `mapOperands`, `numDimsPerMap`). `OperandSegmentIndex` is the position of
// `$mapOperands` within `operandSegmentSizes` when the op has
// `AttrSizedOperandSegments`.
80template <int OperandSegmentIndex> struct VerifySizesForMultiAffineOps {
81 template <typename TypeClass> class Impl : public mlir::OpTrait::TraitBase<TypeClass, Impl> {
 // Casts `op` to `TypeClass` and forwards its map-related attributes and operands,
 // plus `segmentSize` (-1 means "skip the operandSegmentSizes checks"), to the
 // shared size verifier (presumably `verifySizesForMultiAffineOps`).
82 inline static mlir::LogicalResult verifyHelper(mlir::Operation *op, int32_t segmentSize) {
83 TypeClass c = llvm::cast<TypeClass>(op);
85 op, segmentSize, c.getMapOpGroupSizesAttr(), c.getMapOperands(), c.getNumDimsPerMapAttr()
86 );
87 }
88
89 public:
 // Trait verifier entry point. Looks up the `$mapOperands` segment size when the op
 // has `AttrSizedOperandSegments`; otherwise passes -1 so segment checks are skipped.
 // NOTE(review): the `static_assert` below sits inside a runtime `if`, so it is
 // evaluated for every instantiation, including ops WITHOUT `AttrSizedOperandSegments`.
 // If a negative index is meant to be allowed for such ops (as the `else` comment
 // suggests), this should be `if constexpr` (Op::hasTrait is constexpr) — confirm.
90 static mlir::LogicalResult verifyTrait(mlir::Operation *op) {
91 if (TypeClass::template hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
92 // If the AttrSizedOperandSegments trait is present, must have `OperandSegmentIndex`.
93 static_assert(
94 OperandSegmentIndex >= 0,
95 "When the `AttrSizedOperandSegments` trait is present, the index of `$mapOperands` "
96 "within the `operandSegmentSizes` attribute must be specified."
97 );
 // NOTE(review): assumes the segment-sizes attribute is always present here
 // (a null result would be dereferenced below) — confirm trait verification order.
98 mlir::DenseI32ArrayAttr segmentSizes = op->getAttrOfType<mlir::DenseI32ArrayAttr>(
99 mlir::OpTrait::AttrSizedOperandSegments<TypeClass>::getOperandSegmentSizeAttr()
100 );
101 assert(
102 OperandSegmentIndex < segmentSizes.size() &&
103 "Parameter of `VerifySizesForMultiAffineOps` exceeds the number of ODS-declared "
104 "operands"
105 );
106 return verifyHelper(op, segmentSizes[OperandSegmentIndex]);
107 } else {
108 // If the trait is not present, the `OperandSegmentIndex` is ignored. Pass `-1` to indicate
109 // that the checks against `operandSegmentSizes` should be skipped.
110 return verifyHelper(op, -1);
111 }
112 }
113 };
114};
115
116template <unsigned N>
117inline mlir::ParseResult parseDimAndSymbolList(
118 mlir::OpAsmParser &parser,
119 mlir::SmallVector<mlir::OpAsmParser::UnresolvedOperand, N> &mapOperands,
120 mlir::IntegerAttr &numDims
121) {
122 return affineMapHelpers::parseDimAndSymbolList(parser, mapOperands, numDims);
123}
124
126 mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRange mapOperands,
127 mlir::IntegerAttr numDims
128) {
 // Thin wrapper: forwards every argument unchanged to
 // `affineMapHelpers::printDimAndSymbolList` (presumably the printer counterpart
 // of `parseDimAndSymbolList`).
129 return affineMapHelpers::printDimAndSymbolList(printer, op, mapOperands, numDims);
130}
131
132inline mlir::ParseResult parseMultiDimAndSymbolList(
133 mlir::OpAsmParser &parser,
134 mlir::SmallVector<mlir::SmallVector<mlir::OpAsmParser::UnresolvedOperand>> &multiMapOperands,
135 mlir::DenseI32ArrayAttr &numDimsPerMap
136) {
137 return affineMapHelpers::parseMultiDimAndSymbolList(parser, multiMapOperands, numDimsPerMap);
138}
139
141 mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands,
142 mlir::DenseI32ArrayAttr numDimsPerMap
143) {
 // Thin wrapper: forwards every argument unchanged to
 // `affineMapHelpers::printMultiDimAndSymbolList` (presumably the printer
 // counterpart of `parseMultiDimAndSymbolList`).
144 return affineMapHelpers::printMultiDimAndSymbolList(printer, op, multiMapOperands, numDimsPerMap);
145}
146
147inline mlir::ParseResult parseAttrDictWithWarnings(
148 mlir::OpAsmParser &parser, mlir::NamedAttrList &extraAttrs, mlir::OperationState &state
149) {
150 return affineMapHelpers::parseAttrDictWithWarnings(parser, extraAttrs, state);
151}
152
// Printer hook for an op's attribute dictionary. Forwards unchanged to
// `affineMapHelpers::printAttrDictWithWarnings`; `state` is the op's properties
// type as selected by `mlir::PropertiesSelector<ConcreteOp>`.
153template <typename ConcreteOp>
155 mlir::OpAsmPrinter &printer, ConcreteOp op, mlir::DictionaryAttr extraAttrs,
156 typename mlir::PropertiesSelector<ConcreteOp>::type state
157) {
158 return affineMapHelpers::printAttrDictWithWarnings(printer, op, extraAttrs, state);
159}
160
161inline mlir::ParseResult parseTemplateParams(mlir::AsmParser &parser, mlir::ArrayAttr &value) {
162 auto parseResult = mlir::FieldParser<mlir::ArrayAttr>::parse(parser);
163 if (mlir::failed(parseResult)) {
164 return parser.emitError(parser.getCurrentLocation(), "failed to parse template parameters");
165 }
166 auto emitError = [&parser] {
167 return llzk::InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
168 };
169 mlir::FailureOr<mlir::SmallVector<mlir::Attribute>> res =
170 forceIntAttrTypes(parseResult->getValue(), emitError);
171 if (mlir::failed(res)) {
172 return mlir::failure();
173 }
174 value = parser.getBuilder().getArrayAttr(*res);
175 return mlir::success();
176}
177
178// 2 parameter version used by types
179inline void printTemplateParams(mlir::AsmPrinter &printer, mlir::ArrayAttr value) {
180 printer << '[';
181 printAttrs(printer, value.getValue(), ", ");
182 printer << ']';
183}
184
185// 3 parameter version used by ops
186inline void printTemplateParams(mlir::AsmPrinter &printer, void *, mlir::ArrayAttr value) {
187 printTemplateParams(printer, value);
188}
189
190} // namespace llzk
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
See LLZKSymbolTable ODS documentation for details.
Definition OpHelpers.h:33
static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op)
Definition OpHelpers.h:35
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
Definition OpHelpers.h:90
LogicalResult verifySizesForMultiAffineOps(Operation *op, int32_t segmentSize, ArrayRef< int32_t > mapOpGroupSizes, OperandRangeRange mapOperands, ArrayRef< int32_t > numDimsPerMap)
ParseResult parseMultiDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &multiMapOperands, DenseI32ArrayAttr &numDimsPerMap)
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, IntegerAttr &numDims)
ParseResult parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state)
void printMultiDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRangeRange multiMapOperands, DenseI32ArrayAttr numDimsPerMap)
void printDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRange mapOperands, IntegerAttr numDims)
void printAttrDictWithWarnings(mlir::OpAsmPrinter &printer, ConcreteOp, mlir::DictionaryAttr extraAttrs, typename ConcreteOp::Properties)
void printTemplateParams(mlir::AsmPrinter &printer, mlir::ArrayAttr value)
Definition OpHelpers.h:179
void printDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRange mapOperands, mlir::IntegerAttr numDims)
Definition OpHelpers.h:125
void printAttrs(AsmPrinter &printer, ArrayRef< Attribute > attrs, const StringRef &separator)
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:69
mlir::ParseResult parseAttrDictWithWarnings(mlir::OpAsmParser &parser, mlir::NamedAttrList &extraAttrs, mlir::OperationState &state)
Definition OpHelpers.h:147
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Definition OpHelpers.h:140
void printAttrDictWithWarnings(mlir::OpAsmPrinter &printer, ConcreteOp op, mlir::DictionaryAttr extraAttrs, typename mlir::PropertiesSelector< ConcreteOp >::type state)
Definition OpHelpers.h:154
mlir::ParseResult parseTemplateParams(mlir::AsmParser &parser, mlir::ArrayAttr &value)
Definition OpHelpers.h:161
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
Definition OpHelpers.h:132
llvm::StringLiteral getOperationName()
Get the operation name, like "constrain.eq" for the given OpClass.
Definition OpHelpers.h:51
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest operation that is of type 'OpClass', either the op itself or an ancestor.
Definition OpHelpers.h:56
mlir::ParseResult parseDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand, N > &mapOperands, mlir::IntegerAttr &numDims)
Definition OpHelpers.h:117
Produces errors if there is an inconsistency in the various attributes/values that are used to suppor...
Definition OpHelpers.h:80