LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
OpHelpers.h
Go to the documentation of this file.
1//===-- OpHelpers.h ---------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
14#include "llzk/Util/Constants.h"
17
18#include <mlir/IR/BuiltinTypes.h>
19#include <mlir/IR/OpImplementation.h>
20#include <mlir/IR/Operation.h>
21#include <mlir/IR/SymbolTable.h>
22#include <mlir/Support/LogicalResult.h>
23
24#include <llvm/ADT/SmallString.h>
25#include <llvm/ADT/StringRef.h>
26
27namespace llzk {
28
32template <typename OpClass> inline llvm::StringLiteral getOperationName() {
33 return OpClass::getOperationName();
34}
35
38template <typename OpClass> inline OpClass getSelfOrParentOfType(mlir::Operation *op) {
39 if (op) {
40 if (OpClass self = llvm::dyn_cast<OpClass>(op)) {
41 return self;
42 }
43 if (OpClass parent = op->getParentOfType<OpClass>()) {
44 return parent;
45 }
46 }
47 return {};
48}
49
51template <typename OpClass> inline OpClass getParentOfType(mlir::Operation *op) {
52 if (op) {
53 if (OpClass p = op->getParentOfType<OpClass>()) {
54 return p;
55 }
56 }
57 return {};
58}
59
62template <typename... OpTys> bool hasParentThatIsa(mlir::Operation *op) {
63 while ((op = op->getParentOp())) {
64 if (llvm::isa<OpTys...>(op)) {
65 return true;
66 }
67 }
68 return false;
69}
70
72template <typename TypeClass>
73// Suppress false positive from `clang-tidy`
74// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
76 : public mlir::OpTrait::TraitBase<TypeClass, LLZKSymbolTableImplTrait> {
77public:
78 static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op) {
79 // Note: the current op will be checked by the normal `SymbolTable` trait that is
80 // included in `LLZKSymbolTable`. Checking it here would cause the same error described
81 // in `LLZKSymbolTable`.
82 while ((op = op->getParentWithTrait<mlir::OpTrait::SymbolTable>())) {
83 if (mlir::failed(mlir::detail::verifySymbolTable(op))) {
84 return mlir::failure();
85 }
86 }
87 return mlir::success();
88 }
89};
90
92template <typename... Ancestors> struct HasAncestor {
93 template <typename ConcreteType>
94 static void appendTypeName(mlir::InFlightDiagnostic &diag, bool &first) {
95 if (!first) {
96 diag << ", ";
97 }
98 first = false;
99 diag << '\'' << ConcreteType::getOperationName() << '\'';
100 }
101
102 template <typename ConcreteType>
103 // Suppress false positive from `clang-tidy`
104 // NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
105 struct Impl : public mlir::OpTrait::TraitBase<ConcreteType, Impl> {
106 static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op) {
108 return mlir::success();
109 }
110 auto diag = op->emitOpError();
111
112 if constexpr (sizeof...(Ancestors) == 1) {
113 diag << "must have an ancestor of type ";
114 } else {
115 diag << "must have an ancestor of one of the following types: ";
116 }
117
118 bool first = true;
119 (HasAncestor::template appendTypeName<Ancestors>(diag, first), ...);
120
121 return diag;
122 }
123 };
124};
125
128template <int OperandSegmentIndex> struct VerifySizesForMultiAffineOps {
129 template <typename TypeClass> class Impl : public mlir::OpTrait::TraitBase<TypeClass, Impl> {
130 inline static mlir::LogicalResult verifyHelper(mlir::Operation *op, int32_t segmentSize) {
131 TypeClass c = llvm::cast<TypeClass>(op);
133 op, segmentSize, c.getMapOpGroupSizesAttr(), c.getMapOperands(), c.getNumDimsPerMapAttr()
134 );
135 }
136
137 public:
138 static mlir::LogicalResult verifyTrait(mlir::Operation *op) {
139 if (TypeClass::template hasTrait<mlir::OpTrait::AttrSizedOperandSegments>()) {
140 // If the AttrSizedOperandSegments trait is present, must have `OperandSegmentIndex`.
141 static_assert(
142 OperandSegmentIndex >= 0,
143 "When the `AttrSizedOperandSegments` trait is present, the index of `$mapOperands` "
144 "within the `operandSegmentSizes` attribute must be specified."
145 );
146 mlir::DenseI32ArrayAttr segmentSizes = op->getAttrOfType<mlir::DenseI32ArrayAttr>(
147 mlir::OpTrait::AttrSizedOperandSegments<TypeClass>::getOperandSegmentSizeAttr()
148 );
149 assert(
150 OperandSegmentIndex < segmentSizes.size() &&
151 "Parameter of `VerifySizesForMultiAffineOps` exceeds the number of ODS-declared "
152 "operands"
153 );
154 return verifyHelper(op, segmentSizes[OperandSegmentIndex]);
155 } else {
156 // If the trait is not present, the `OperandSegmentIndex` is ignored. Pass `-1` to indicate
157 // that the checks against `operandSegmentSizes` should be skipped.
158 return verifyHelper(op, -1);
159 }
160 }
161 };
162};
163
164template <unsigned N>
165inline mlir::ParseResult parseDimAndSymbolList(
166 mlir::OpAsmParser &parser,
167 mlir::SmallVector<mlir::OpAsmParser::UnresolvedOperand, N> &mapOperands,
168 mlir::IntegerAttr &numDims
169) {
170 return affineMapHelpers::parseDimAndSymbolList(parser, mapOperands, numDims);
171}
172
174 mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRange mapOperands,
175 mlir::IntegerAttr numDims
176) {
177 return affineMapHelpers::printDimAndSymbolList(printer, op, mapOperands, numDims);
178}
179
180inline mlir::ParseResult parseMultiDimAndSymbolList(
181 mlir::OpAsmParser &parser,
182 mlir::SmallVector<mlir::SmallVector<mlir::OpAsmParser::UnresolvedOperand>> &multiMapOperands,
183 mlir::DenseI32ArrayAttr &numDimsPerMap
184) {
185 return affineMapHelpers::parseMultiDimAndSymbolList(parser, multiMapOperands, numDimsPerMap);
186}
187
189 mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands,
190 mlir::DenseI32ArrayAttr numDimsPerMap
191) {
192 return affineMapHelpers::printMultiDimAndSymbolList(printer, op, multiMapOperands, numDimsPerMap);
193}
194
195inline mlir::ParseResult parseAttrDictWithWarnings(
196 mlir::OpAsmParser &parser, mlir::NamedAttrList &extraAttrs, mlir::OperationState &state
197) {
198 return affineMapHelpers::parseAttrDictWithWarnings(parser, extraAttrs, state);
199}
200
201template <typename ConcreteOp>
203 mlir::OpAsmPrinter &printer, ConcreteOp op, mlir::DictionaryAttr extraAttrs,
204 typename mlir::PropertiesSelector<ConcreteOp>::type state
205) {
206 return affineMapHelpers::printAttrDictWithWarnings(printer, op, extraAttrs, state);
207}
208
209inline mlir::ParseResult parseTemplateParams(mlir::AsmParser &parser, mlir::ArrayAttr &value) {
210 mlir::SmallVector<mlir::Attribute> elements;
211 auto parseElement = [&]() -> mlir::ParseResult {
212 // `?` is a wildcard meaning "infer this parameter"; only valid for tvar-restricted params.
213 if (mlir::succeeded(parser.parseOptionalQuestion())) {
214 elements.push_back(parser.getBuilder().getIndexAttr(mlir::ShapedType::kDynamic));
215 return mlir::success();
216 }
217 auto attrParseResult = mlir::FieldParser<mlir::Attribute>::parse(parser);
218 if (mlir::failed(attrParseResult)) {
219 return parser.emitError(
220 parser.getCurrentLocation(), "failed to parse template parameter attribute"
221 );
222 }
223 auto emitError = [&parser] {
224 return llzk::InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
225 };
226 mlir::FailureOr<mlir::Attribute> forced = forceIntAttrType(*attrParseResult, emitError);
227 if (mlir::failed(forced)) {
228 return mlir::failure();
229 }
230 elements.push_back(*forced);
231 return mlir::success();
232 };
233 auto res = parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::Square, parseElement);
234 if (mlir::failed(res)) {
235 return res; // parseElement() already emits a sufficient error message
236 }
237 value = parser.getBuilder().getArrayAttr(elements);
238 return mlir::success();
239}
240
241// 2 parameter version used by types
242inline void printTemplateParams(mlir::AsmPrinter &printer, mlir::ArrayAttr value) {
243 printer << '[';
244 printAttrs(printer, value.getValue(), ", ");
245 printer << ']';
246}
247
248// 3 parameter version used by ops
249inline void printTemplateParams(mlir::AsmPrinter &printer, void *, mlir::ArrayAttr value) {
250 printTemplateParams(printer, value);
251}
252
253} // namespace llzk
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
See LLZKSymbolTable ODS documentation for details.
Definition OpHelpers.h:76
static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op)
Definition OpHelpers.h:78
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
Definition OpHelpers.h:138
LogicalResult verifySizesForMultiAffineOps(Operation *op, int32_t segmentSize, ArrayRef< int32_t > mapOpGroupSizes, OperandRangeRange mapOperands, ArrayRef< int32_t > numDimsPerMap)
ParseResult parseMultiDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &multiMapOperands, DenseI32ArrayAttr &numDimsPerMap)
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapOperands, IntegerAttr &numDims)
ParseResult parseAttrDictWithWarnings(OpAsmParser &parser, NamedAttrList &extraAttrs, OperationState &state)
void printMultiDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRangeRange multiMapOperands, DenseI32ArrayAttr numDimsPerMap)
void printDimAndSymbolList(OpAsmPrinter &printer, Operation *, OperandRange mapOperands, IntegerAttr numDims)
void printAttrDictWithWarnings(mlir::OpAsmPrinter &printer, ConcreteOp, mlir::DictionaryAttr extraAttrs, typename ConcreteOp::Properties)
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
void printTemplateParams(mlir::AsmPrinter &printer, mlir::ArrayAttr value)
Definition OpHelpers.h:242
void printDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRange mapOperands, mlir::IntegerAttr numDims)
Definition OpHelpers.h:173
void printAttrs(AsmPrinter &printer, ArrayRef< Attribute > attrs, const StringRef &separator)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
mlir::ParseResult parseAttrDictWithWarnings(mlir::OpAsmParser &parser, mlir::NamedAttrList &extraAttrs, mlir::OperationState &state)
Definition OpHelpers.h:195
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Definition OpHelpers.h:188
void printAttrDictWithWarnings(mlir::OpAsmPrinter &printer, ConcreteOp op, mlir::DictionaryAttr extraAttrs, typename mlir::PropertiesSelector< ConcreteOp >::type state)
Definition OpHelpers.h:202
mlir::ParseResult parseTemplateParams(mlir::AsmParser &parser, mlir::ArrayAttr &value)
Definition OpHelpers.h:209
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
Definition OpHelpers.h:180
llvm::StringLiteral getOperationName()
Get the operation name (e.g., "constrain.eq") of the given OpClass.
Definition OpHelpers.h:32
bool hasParentThatIsa(mlir::Operation *op)
Return true if the parameter has a parent/ancestor op that is an instance of one of the template type parameters.
Definition OpHelpers.h:62
OpClass getSelfOrParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass',...
Definition OpHelpers.h:38
mlir::ParseResult parseDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand, N > &mapOperands, mlir::IntegerAttr &numDims)
Definition OpHelpers.h:165
static mlir::LogicalResult verifyRegionTrait(mlir::Operation *op)
Definition OpHelpers.h:106
See HasAncestor ODS documentation for details.
Definition OpHelpers.h:92
static void appendTypeName(mlir::InFlightDiagnostic &diag, bool &first)
Definition OpHelpers.h:94
Produces errors if there is an inconsistency in the various attributes/values that are used to suppor...
Definition OpHelpers.h:128