LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeConversionPatterns.h
Go to the documentation of this file.
1//===-- TypeConversionPatterns.h --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
16#pragma once
17
26
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/BuiltinTypes.h>
31#include <mlir/IR/MLIRContext.h>
32#include <mlir/IR/Operation.h>
33#include <mlir/IR/PatternMatch.h>
34#include <mlir/Transforms/DialectConversion.h>
35
36#include <llvm/ADT/SmallVector.h>
37
38#include <tuple>
39
40namespace llzk {
41
44inline bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op) {
45 // Check operand types and result types
46 if (!tyConv.isLegal(op)) {
47 return false;
48 }
49 // Check type attributes
50 // Extend lifetime of temporary to suppress warnings.
51 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
52 for (mlir::NamedAttribute n : dictAttr.getValue()) {
53 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
54 mlir::Type t = tyAttr.getValue();
55 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
56 if (!tyConv.isSignatureLegal(funcTy)) {
57 return false;
58 }
59 } else {
60 if (!tyConv.isLegal(t)) {
61 return false;
62 }
63 }
64 }
65 }
66 return true;
67}
68
71template <typename OpClass, typename Rewriter, typename... Args>
72inline OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args) {
73 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
74 OpClass newOp = rewriter.template replaceOpWithNewOp<OpClass>(op, std::forward<Args>(args)...);
75 newOp->setDiscardableAttrs(attrs);
76 return newOp;
77}
78
80static struct OpClassesWithStructTypes {
81
84 const std::tuple<
85 // clang-format off
86 array::ArrayLengthOp,
87 array::ReadArrayOp,
88 array::WriteArrayOp,
89 array::InsertArrayOp,
90 array::ExtractArrayOp,
91 constrain::EmitEqualityOp,
92 constrain::EmitContainmentOp,
93 component::MemberDefOp,
94 component::MemberReadOp,
95 component::MemberWriteOp,
96 component::CreateStructOp,
97 function::FuncDefOp,
98 function::ReturnOp,
99 global::GlobalDefOp,
100 global::GlobalReadOp,
101 global::GlobalWriteOp,
102 pod::ReadPodOp,
103 pod::WritePodOp,
104 polymorphic::UnifiableCastOp,
105 polymorphic::ConstReadOp
106 // clang-format on
107 >
108 WithGeneralBuilder {};
109
115 const std::tuple<function::CallOp, array::CreateArrayOp, pod::NewPodOp> NoGeneralBuilder {};
116
117} OpClassesWithStructTypes;
118
119namespace {
120
/// Conversion pattern that rebuilds an `OpClass` op with its types routed through the pattern's
/// TypeConverter. This covers result types, operands (taken from the adaptor) and every top-level
/// TypeAttr in the op's attribute dictionary. It is intended for op classes whose builder accepts
/// (result types, operands, attribute list).
///
/// NOTE(review): nothing in this pattern transfers regions from `op` to the new op. If `OpClass`
/// has regions, confirm their contents are not lost. Possibly a different, higher-benefit pattern
/// is expected to legalize such ops first (this one is registered with benefit 0).
127template <typename OpClass>
128class GeneralTypeReplacePattern : public mlir::OpConversionPattern<OpClass> {
129public:
 /// @param converter type converter applied to results, operands and TypeAttr values.
 /// @param ctx       context the pattern belongs to.
 /// The trailing `0` is forwarded to OpConversionPattern as the pattern benefit.
132 GeneralTypeReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
133 : mlir::OpConversionPattern<OpClass>(converter, ctx, 0) {}
134
 /// Convert `op`'s result types and TypeAttr attributes, then replace `op` with a newly built
 /// `OpClass` that uses the adaptor's (already converted) operands.
 /// On any conversion failure, an error diagnostic is emitted on `op` and failure is returned.
135 mlir::LogicalResult matchAndRewrite(
136 OpClass op, typename OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
137 ) const override {
138 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
139 assert(converter);
140 // Convert result types
141 mlir::SmallVector<mlir::Type> newResultTypes;
142 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
143 return op->emitError("Could not convert Op result types.");
144 }
145 // ASSERT: 'adaptor.getAttributes()' is empty or a subset of 'op->getAttrDictionary()' so the
146 // former can be ignored without losing anything.
147 assert(
148 adaptor.getAttributes().empty() ||
149 llvm::all_of(
150 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
151 return d.contains(a.getName());
152 }
153 )
154 );
155 // Convert any TypeAttr in the attribute list.
 // Only top-level TypeAttr values are rewritten; TypeAttrs nested inside other attributes
 // (e.g. arrays or dictionaries) are copied unchanged.
 // NOTE(review): a FunctionType held in a TypeAttr is converted as a whole with convertType()
 // here, while defaultLegalityCheck() judges it with isSignatureLegal(). Confirm the converter
 // handles FunctionType so that the two agree.
156 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
157 for (mlir::NamedAttribute &n : newAttrs) {
158 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
159 if (mlir::Type newType = converter->convertType(t.getValue())) {
160 n.setValue(mlir::TypeAttr::get(newType));
161 } else {
162 return op->emitError().append("Could not convert type in attribute: ", t);
163 }
164 }
165 }
166 // Build a new Op in place of the current one
 // NOTE(review): original header line 167 opens the call whose arguments follow, but it is
 // missing from this listing. It is presumably `replaceOpWithNewOp<OpClass>(` (the wrapper
 // defined earlier in this header); restore it from the source header.
168 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
169 mlir::ArrayRef(newAttrs)
170 );
171 return mlir::success();
172 }
173};
174
/// Dedicated conversion pattern for array::CreateArrayOp, which is rebuilt from one of two
/// argument shapes rather than a generic (types, operands, attributes) list.
///
/// The pattern uses the split match()/rewrite() form. match() only checks that the result type is
/// convertible; rewrite() recomputes the converted type.
176class CreateArrayOpClassReplacePattern : public mlir::OpConversionPattern<array::CreateArrayOp> {
177public:
 /// @param converter type converter applied to the result type and operands.
 /// @param ctx       context the pattern belongs to.
 /// The trailing `0` is forwarded to OpConversionPattern as the pattern benefit.
180 CreateArrayOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
181 : mlir::OpConversionPattern<array::CreateArrayOp>(converter, ctx, 0) {}
182
 /// Succeeds iff the converter produces a (non-null) type for the op's result type.
 /// NOTE(review): on failure this emits an error diagnostic rather than only reporting a match
 /// failure, so an unconvertible CreateArrayOp is surfaced as an error at this point.
183 mlir::LogicalResult match(array::CreateArrayOp op) const override {
184 if (getTypeConverter()->convertType(op.getType())) {
185 return mlir::success();
186 }
187 return op->emitError("Could not convert Op result type.");
188 }
189
 /// Rebuild the op with the converted ArrayType result.
 /// - If `numDimsPerMap` is null or empty, the new op is built from the converted elements.
 /// - Otherwise it is built from the converted map operands together with `numDimsPerMap`.
190 void rewrite(
191 array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
192 ) const override {
193 mlir::Type newType = getTypeConverter()->convertType(op.getType());
194 assert(llvm::isa<array::ArrayType>(newType) && "CreateArrayOp must produce ArrayType result");
195 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
196 if (isNullOrEmpty(numDimsPerMap)) {
 // NOTE(review): original header line 197 (the call opening this argument list, presumably
 // `replaceOpWithNewOp<array::CreateArrayOp>(`) is missing from this listing.
198 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getElements()
199 );
200 } else {
 // NOTE(review): original header line 201 (the call opening this argument list, presumably
 // `replaceOpWithNewOp<array::CreateArrayOp>(`) is missing from this listing.
202 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getMapOperands(),
203 numDimsPerMap
204 );
205 }
206 }
207};
208
/// Dedicated conversion pattern for function::CallOp.
///
/// The replacement keeps the callee symbol attribute and the `numDimsPerMap` attribute from the
/// original op. It takes its map operands and argument operands from the adaptor (i.e. after
/// operand conversion) and uses the converted result types.
210class CallOpClassReplacePattern : public mlir::OpConversionPattern<function::CallOp> {
211public:
 /// @param converter type converter applied to result types and operands.
 /// @param ctx       context the pattern belongs to.
 /// The trailing `0` is forwarded to OpConversionPattern as the pattern benefit.
212 CallOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
213 : mlir::OpConversionPattern<function::CallOp>(converter, ctx, 0) {}
214
 /// Convert the result types and replace `op` with a rebuilt CallOp.
 /// If any result type cannot be converted, an error diagnostic is emitted on `op` and failure is
 /// returned.
215 mlir::LogicalResult matchAndRewrite(
216 function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
217 ) const override {
218 mlir::SmallVector<mlir::Type> newResultTypes;
219 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
220 return op->emitError("Could not convert Op result types.");
221 }
 // NOTE(review): original header line 222 (the call opening this argument list, presumably
 // `replaceOpWithNewOp<function::CallOp>(`) is missing from this listing.
223 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
224 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
225 );
226 return mlir::success();
227 }
228};
229
/// Dedicated conversion pattern for pod::NewPodOp.
///
/// The replacement gets the converted PodType result. It takes its map operands from the adaptor
/// and keeps the original `numDimsPerMap` attribute. Its initial record values are recomputed with
/// pod::getInitializedRecordValues() from the adaptor's (converted) initial values and the
/// original op's initialized-records attribute.
231class NewPodOpClassReplacePattern : public mlir::OpConversionPattern<pod::NewPodOp> {
232public:
 /// @param converter type converter applied to the result type and operands.
 /// @param ctx       context the pattern belongs to.
 /// The trailing `0` is forwarded to OpConversionPattern as the pattern benefit.
233 NewPodOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
234 : mlir::OpConversionPattern<pod::NewPodOp>(converter, ctx, 0) {}
235
 /// Convert the result type and replace `op` with a rebuilt NewPodOp.
 /// Fails, with an error diagnostic on `op`, when the conversion yields no type or a type that
 /// is not a pod::PodType (dyn_cast_if_present returns null in both cases).
236 mlir::LogicalResult matchAndRewrite(
237 pod::NewPodOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter
238 ) const override {
239 auto newResultType = dyn_cast_if_present<pod::PodType>(
240 getTypeConverter()->convertType(op.getResult().getType())
241 );
242 if (!newResultType) {
243 return op->emitError("Could not convert Op result types.");
244 }
 // NOTE(review): original header line 245 (the call opening this argument list, presumably
 // `replaceOpWithNewOp<pod::NewPodOp>(`) is missing from this listing.
246 rewriter, op, newResultType, adaptor.getMapOperands(), op.getNumDimsPerMapAttr(),
247 pod::getInitializedRecordValues(adaptor.getInitialValues(), op.getInitializedRecords())
248 );
249 return mlir::success();
250 }
251};
252
/// Invoke `inserter` once, passing a value-initialized instance of each type in `OpClasses` as
/// separate arguments (via std::apply). Does nothing when `OpClasses` is empty, so callers may
/// forward a possibly-empty pack without special-casing it.
///
/// @tparam I         callable accepting one argument per type in `OpClasses`.
/// @tparam OpClasses types to instantiate and pass to `inserter`; may be empty.
template <typename I, typename... OpClasses>
inline void applyToMoreTypes([[maybe_unused]] I inserter) {
  if constexpr (sizeof...(OpClasses) > 0) {
    std::apply(inserter, std::tuple<OpClasses...> {});
  }
}
258
259} // namespace
260
265template <typename... AdditionalOpClasses>
266inline mlir::RewritePatternSet newGeneralRewritePatternSet(
267 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
268) {
269 mlir::RewritePatternSet patterns(ctx);
270 auto inserter = [&](auto... opClasses) {
271 patterns.add<GeneralTypeReplacePattern<decltype(opClasses)>...>(tyConv, ctx);
272 };
273 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
274 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
275 // Special cases for ops where GeneralTypeReplacePattern doesn't work
276 patterns.add<
277 CreateArrayOpClassReplacePattern, CallOpClassReplacePattern, NewPodOpClassReplacePattern>(
278 tyConv, ctx
279 );
280 // Add builtin FunctionType and SCF op converters
281 mlir::populateFunctionOpInterfaceTypeConversionPattern<function::FuncDefOp>(patterns, tyConv);
282 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
283 return patterns;
284}
285
286} // namespace llzk
SmallVector< RecordValue > getInitializedRecordValues(ValueRange initialValues, ArrayAttr initializedRecords)
Definition Ops.cpp:449
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i.e., attributes that are not inherent to the op) from the replaced op onto the newly created op.
bool isNullOrEmpty(mlir::ArrayAttr a)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op)
Check whether an op is legal with respect to the given type converter, including any TypeAttr values in the op's attribute dictionary (FunctionType values are checked with isSignatureLegal()).