LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeConversionPatterns.h
Go to the documentation of this file.
1//===-- TypeConversionPatterns.h --------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
14//===----------------------------------------------------------------------===//
15
16#pragma once
17
25
26#include <mlir/Dialect/SCF/Transforms/Patterns.h>
27#include <mlir/IR/Attributes.h>
28#include <mlir/IR/BuiltinAttributes.h>
29#include <mlir/IR/BuiltinTypes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
34
35#include <llvm/ADT/SmallVector.h>
36
37#include <tuple>
38
39namespace llzk {
40
43inline bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op) {
44 // Check operand types and result types
45 if (!tyConv.isLegal(op)) {
46 return false;
47 }
48 // Check type attributes
49 // Extend lifetime of temporary to suppress warnings.
50 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
51 for (mlir::NamedAttribute n : dictAttr.getValue()) {
52 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
53 mlir::Type t = tyAttr.getValue();
54 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
55 if (!tyConv.isSignatureLegal(funcTy)) {
56 return false;
57 }
58 } else {
59 if (!tyConv.isLegal(t)) {
60 return false;
61 }
62 }
63 }
64 }
65 return true;
66}
67
70template <typename OpClass, typename Rewriter, typename... Args>
71inline OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args) {
72 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
73 OpClass newOp = rewriter.template replaceOpWithNewOp<OpClass>(op, std::forward<Args>(args)...);
74 newOp->setDiscardableAttrs(attrs);
75 return newOp;
76}
77
79static struct OpClassesWithStructTypes {
80
83 const std::tuple<
84 // clang-format off
85 array::ArrayLengthOp,
86 array::ReadArrayOp,
87 array::WriteArrayOp,
88 array::InsertArrayOp,
89 array::ExtractArrayOp,
90 constrain::EmitEqualityOp,
91 constrain::EmitContainmentOp,
92 component::MemberDefOp,
93 component::MemberReadOp,
94 component::MemberWriteOp,
95 component::CreateStructOp,
96 function::FuncDefOp,
97 function::ReturnOp,
98 global::GlobalDefOp,
99 global::GlobalReadOp,
100 global::GlobalWriteOp,
101 polymorphic::UnifiableCastOp,
102 polymorphic::ConstReadOp
103 // clang-format on
104 >
105 WithGeneralBuilder {};
106
112 const std::tuple<function::CallOp, array::CreateArrayOp> NoGeneralBuilder {};
113
114} OpClassesWithStructTypes;
115
116namespace {
117
124template <typename OpClass>
125class GeneralTypeReplacePattern : public mlir::OpConversionPattern<OpClass> {
126public:
129 GeneralTypeReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
130 : mlir::OpConversionPattern<OpClass>(converter, ctx, 0) {}
131
132 mlir::LogicalResult matchAndRewrite(
133 OpClass op, typename OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
134 ) const override {
135 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
136 assert(converter);
137 // Convert result types
138 mlir::SmallVector<mlir::Type> newResultTypes;
139 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
140 return op->emitError("Could not convert Op result types.");
141 }
142 // ASSERT: 'adaptor.getAttributes()' is empty or a subset of 'op->getAttrDictionary()' so the
143 // former can be ignored without losing anything.
144 assert(
145 adaptor.getAttributes().empty() ||
146 llvm::all_of(
147 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
148 return d.contains(a.getName());
149 }
150 )
151 );
152 // Convert any TypeAttr in the attribute list.
153 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
154 for (mlir::NamedAttribute &n : newAttrs) {
155 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
156 if (mlir::Type newType = converter->convertType(t.getValue())) {
157 n.setValue(mlir::TypeAttr::get(newType));
158 } else {
159 return op->emitError().append("Could not convert type in attribute: ", t);
160 }
161 }
162 }
163 // Build a new Op in place of the current one
165 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
166 mlir::ArrayRef(newAttrs)
167 );
168 return mlir::success();
169 }
170};
171
173class CreateArrayOpClassReplacePattern : public mlir::OpConversionPattern<array::CreateArrayOp> {
174public:
177 CreateArrayOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
178 : mlir::OpConversionPattern<array::CreateArrayOp>(converter, ctx, 0) {}
179
180 mlir::LogicalResult match(array::CreateArrayOp op) const override {
181 if (getTypeConverter()->convertType(op.getType())) {
182 return mlir::success();
183 }
184 return op->emitError("Could not convert Op result type.");
185 }
186
187 void rewrite(
188 array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
189 ) const override {
190 mlir::Type newType = getTypeConverter()->convertType(op.getType());
191 assert(llvm::isa<array::ArrayType>(newType) && "CreateArrayOp must produce ArrayType result");
192 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
193 if (isNullOrEmpty(numDimsPerMap)) {
195 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getElements()
196 );
197 } else {
199 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getMapOperands(),
200 numDimsPerMap
201 );
202 }
203 }
204};
205
207class CallOpClassReplacePattern : public mlir::OpConversionPattern<function::CallOp> {
208public:
209 CallOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
210 : mlir::OpConversionPattern<function::CallOp>(converter, ctx, 0) {}
211
212 mlir::LogicalResult matchAndRewrite(
213 function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
214 ) const override {
215 mlir::SmallVector<mlir::Type> newResultTypes;
216 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
217 return op->emitError("Could not convert Op result types.");
218 }
220 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
221 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
222 );
223 return mlir::success();
224 }
225};
226
/// Invokes `inserter` once, passing one default-constructed object of each type in
/// `OpClasses`. When `OpClasses` is empty, `inserter` is not called at all. This lets callers
/// forward a possibly-empty type pack without ever calling `inserter` with no arguments.
template <typename I, typename... OpClasses>
inline void applyToMoreTypes(I inserter) {
  if constexpr (sizeof...(OpClasses) != 0) {
    inserter(OpClasses {}...);
  }
}
232
233} // namespace
234
239template <typename... AdditionalOpClasses>
240inline mlir::RewritePatternSet newGeneralRewritePatternSet(
241 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
242) {
243 mlir::RewritePatternSet patterns(ctx);
244 auto inserter = [&](auto... opClasses) {
245 patterns.add<GeneralTypeReplacePattern<decltype(opClasses)>...>(tyConv, ctx);
246 };
247 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
248 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
249 // Special cases for ops where GeneralTypeReplacePattern doesn't work
250 patterns.add<CreateArrayOpClassReplacePattern, CallOpClassReplacePattern>(tyConv, ctx);
251 // Add builtin FunctionType and SCF op converters
252 mlir::populateFunctionOpInterfaceTypeConversionPattern<function::FuncDefOp>(patterns, tyConv);
253 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
254 return patterns;
255}
256
257} // namespace llzk
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i.e., attributes that are not inherent to the op) from the replaced op onto the new op.
bool isNullOrEmpty(mlir::ArrayAttr a)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op)
Check whether an op is legal with respect to the given type converter, including any TypeAttr attributes in its attribute dictionary (FunctionType attributes are checked as signatures).