27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/BuiltinTypes.h>
31#include <mlir/IR/MLIRContext.h>
32#include <mlir/IR/Operation.h>
33#include <mlir/IR/PatternMatch.h>
34#include <mlir/Transforms/DialectConversion.h>
36#include <llvm/ADT/SmallVector.h>
// Fragment of defaultLegalityCheck(tyConv, op): decides whether `op` is legal for
// the type converter `tyConv`. This includes the types carried by TypeAttr attributes.
// NOTE(review): the embedded line numbers jump (46 -> 51, 56 -> 60). The early-return
// statements and closing braces are therefore not visible in this listing.
// Step 1: the converter's own op-level legality check.
46 if (!tyConv.isLegal(op)) {
// Step 2: walk every attribute on the op. Only TypeAttr payloads are examined.
51 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
52 for (mlir::NamedAttribute n : dictAttr.getValue()) {
53 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
54 mlir::Type t = tyAttr.getValue();
// A FunctionType payload is validated as a signature via isSignatureLegal ...
55 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
56 if (!tyConv.isSignatureLegal(funcTy)) {
// ... while other types are checked directly with isLegal. This is presumably the
// non-FunctionType branch; confirm against the full source, since lines 57-59 are missing.
60 if (!tyConv.isLegal(t)) {
// Fragment of replaceOpWithNewOp<OpClass>(rewriter, op, args...). It replaces `op`
// with a newly built OpClass and copies op's discardable attributes onto the
// replacement.
// NOTE(review): original lines 72 and 74 are not in this listing. They presumably
// hold the signature and the call that creates `newOp`.
71template <
typename OpClass,
typename Rewriter,
typename... Args>
// Snapshot the discardable attributes of the op being replaced ...
73 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
// ... and re-attach them to the replacement so they survive the rewrite.
75 newOp->setDiscardableAttrs(attrs);
// Type-level registry of LLZK op classes that (per its name) may involve StructType.
// Only the element *types* of the tuples matter. For example, WithGeneralBuilder is
// expanded with std::apply in newGeneralRewritePatternSet, and the resulting lambda
// only looks at decltype of each element.
// NOTE(review): the embedded line numbers show gaps (81-89, 97-99, 102-103, 106-107,
// 109-114), so the op lists shown here are incomplete.
80static struct OpClassesWithStructTypes {
90 array::ExtractArrayOp,
91 constrain::EmitEqualityOp,
92 constrain::EmitContainmentOp,
93 component::MemberDefOp,
94 component::MemberReadOp,
95 component::MemberWriteOp,
96 component::CreateStructOp,
100 global::GlobalReadOp,
101 global::GlobalWriteOp,
104 polymorphic::UnifiableCastOp,
105 polymorphic::ConstReadOp
// Op classes converted by the generic GeneralTypeReplacePattern.
108 WithGeneralBuilder {};
// Op classes that get dedicated patterns instead: CallOpClassReplacePattern,
// CreateArrayOpClassReplacePattern and NewPodOpClassReplacePattern. This is
// presumably because their builders do not take the generic
// (result types, operands, attributes) form.
115 const std::tuple<function::CallOp, array::CreateArrayOp, pod::NewPodOp> NoGeneralBuilder {};
// A single static instance is declared with the same name as the struct.
117} OpClassesWithStructTypes;
// Generic conversion pattern for OpClass. It rebuilds the op from three pieces:
// - the converted result types,
// - the adaptor's operands,
// - a copy of the op's attributes in which every TypeAttr has been run through the
//   type converter.
// NOTE(review): several original lines are missing from this listing (see the gaps in
// the embedded line numbers), including the end of the matchAndRewrite signature and
// the call that creates the replacement op.
127template <
typename OpClass>
128class GeneralTypeReplacePattern :
public mlir::OpConversionPattern<OpClass> {
// The trailing 0 is the PatternBenefit argument of the OpConversionPattern constructor.
132 GeneralTypeReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
133 : mlir::OpConversionPattern<OpClass>(converter, ctx, 0) {}
135 mlir::LogicalResult matchAndRewrite(
136 OpClass op,
typename OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
// The base class depends on OpClass, so unqualified lookup would not find
// getTypeConverter(). The call therefore names the base explicitly.
138 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
// Convert all result types. Failure is reported as an error on the op.
141 mlir::SmallVector<mlir::Type> newResultTypes;
142 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
143 return op->emitError(
"Could not convert Op result types.");
// Presumably an assertion; the enclosing call on lines 147/149 is not shown. It checks
// that the adaptor's attributes are either empty or all present in op's attribute dictionary.
148 adaptor.getAttributes().empty() ||
150 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
151 return d.contains(a.getName());
// Copy op's attributes and replace each TypeAttr payload with its converted type.
// convertType yields a null Type on failure, which triggers the error below.
156 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
157 for (mlir::NamedAttribute &n : newAttrs) {
158 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
159 if (mlir::Type newType = converter->convertType(t.getValue())) {
160 n.setValue(mlir::TypeAttr::get(newType));
162 return op->emitError().append(
"Could not convert type in attribute: ", t);
// Arguments to the replacement builder call. The call itself (line 167) is not shown.
// It receives the converted result types, the adaptor's operands and the rewritten attributes.
168 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
169 mlir::ArrayRef(newAttrs)
171 return mlir::success();
// Dedicated conversion pattern for array::CreateArrayOp. It is written in the split
// form: a match() override plus a second method whose parameters start on line 191
// (presumably rewrite()).
// NOTE(review): the separate match/rewrite overrides are deprecated in recent MLIR
// releases in favor of matchAndRewrite. Confirm against the MLIR version in use.
// NOTE(review): some original lines are missing from this listing (see gaps in the
// embedded line numbers).
176class CreateArrayOpClassReplacePattern :
public mlir::OpConversionPattern<array::CreateArrayOp> {
// The trailing 0 is the PatternBenefit argument of the OpConversionPattern constructor.
180 CreateArrayOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
181 : mlir::OpConversionPattern<array::CreateArrayOp>(converter, ctx, 0) {}
// Succeeds only if the result type converts. Otherwise it emits an error diagnostic on
// the op and fails.
// NOTE(review): emitting an error from match() reports a diagnostic even though the
// framework treats this as a plain match failure. Check whether notifyMatchFailure
// was intended.
183 mlir::LogicalResult match(array::CreateArrayOp op)
const override {
184 if (getTypeConverter()->convertType(op.getType())) {
185 return mlir::success();
187 return op->emitError(
"Could not convert Op result type.");
// Rewrite step: recompute the converted type, which match() already verified is non-null.
191 array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
193 mlir::Type newType = getTypeConverter()->convertType(op.getType());
194 assert(llvm::isa<array::ArrayType>(newType) &&
"CreateArrayOp must produce ArrayType result");
// The op's numDimsPerMap attribute. It presumably selects between the two builder
// forms below; the branch condition (lines 196-201) is not in this listing.
195 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
// Replacement built from the adapted element operands ...
198 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getElements()
// ... or from the adapted map operands.
202 rewriter, op, llvm::cast<array::ArrayType>(newType), adapter.getMapOperands(),
// Dedicated conversion pattern for function::CallOp. It recreates the call with
// converted result types. It keeps the original callee and numDimsPerMap attributes
// and takes the map and argument operands from the adaptor.
// NOTE(review): some original lines are missing from this listing (see gaps in the
// embedded line numbers), including the builder call on line 222.
210class CallOpClassReplacePattern :
public mlir::OpConversionPattern<function::CallOp> {
// The trailing 0 is the PatternBenefit argument of the OpConversionPattern constructor.
212 CallOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
213 : mlir::OpConversionPattern<function::CallOp>(converter, ctx, 0) {}
215 mlir::LogicalResult matchAndRewrite(
216 function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
// Convert every result type. On failure, report an error on the call op.
218 mlir::SmallVector<mlir::Type> newResultTypes;
219 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
220 return op->emitError(
"Could not convert Op result types.");
// Arguments for the replacement call: converted results, the unchanged callee symbol,
// the adapted map operands, the unchanged numDimsPerMap attribute and the adapted
// call arguments.
223 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
224 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
226 return mlir::success();
// Dedicated conversion pattern for pod::NewPodOp. The converted result type must still
// be a pod::PodType.
// NOTE(review): some original lines are missing from this listing (see gaps in the
// embedded line numbers), including the tail of the builder call (lines 247-248).
231class NewPodOpClassReplacePattern :
public mlir::OpConversionPattern<pod::NewPodOp> {
// The trailing 0 is the PatternBenefit argument of the OpConversionPattern constructor.
233 NewPodOpClassReplacePattern(mlir::TypeConverter &converter, mlir::MLIRContext *ctx)
234 : mlir::OpConversionPattern<pod::NewPodOp>(converter, ctx, 0) {}
236 mlir::LogicalResult matchAndRewrite(
237 pod::NewPodOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter
// dyn_cast_if_present tolerates a null input. Both failure cases therefore reach the
// error below: a failed conversion (null Type) and a conversion to a non-PodType type.
239 auto newResultType = dyn_cast_if_present<pod::PodType>(
240 getTypeConverter()->convertType(op.getResult().getType())
242 if (!newResultType) {
243 return op->emitError(
"Could not convert Op result types.");
// Rebuild with the converted PodType, the adapted map operands and the original
// numDimsPerMap attribute.
246 rewriter, op, newResultType, adaptor.getMapOperands(), op.getNumDimsPerMapAttr(),
249 return mlir::success();
// Calls `inserter` once, passing value-initialised instances of NextOpClass and
// OtherOpClasses... (packed in a std::tuple and expanded by std::apply). The inserter
// in newGeneralRewritePatternSet only inspects decltype of its arguments. The
// instances therefore serve purely as type tags.
// NOTE(review): the closing brace (line 256) is not shown in this listing.
253template <
typename I,
typename NextOpClass,
typename... OtherOpClasses>
254inline void applyToMoreTypes(I inserter) {
255 std::apply(inserter, std::tuple<NextOpClass, OtherOpClasses...> {});
// Base case: a no-op when no op classes follow I.
// With only <I> spelled explicitly, the overload above is not viable because
// NextOpClass cannot be deduced. With extra classes, this overload has too many
// template arguments. Exactly one overload applies in each case.
257template <
typename I>
inline void applyToMoreTypes(I) {}
// Fragment of newGeneralRewritePatternSet<AdditionalOpClasses...>(tyConv, ctx, target).
// It builds a RewritePatternSet with type-conversion patterns for the registered op
// classes plus any caller-supplied AdditionalOpClasses.
// NOTE(review): the function-name line (266), the call head before line 277, and the
// closing/return lines are not shown in this listing.
265template <
typename... AdditionalOpClasses>
267 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
269 mlir::RewritePatternSet patterns(ctx);
// Generic lambda: registers one GeneralTypeReplacePattern per argument type.
// Only decltype of each argument is used.
270 auto inserter = [&](
auto... opClasses) {
271 patterns.add<GeneralTypeReplacePattern<
decltype(opClasses)>...>(tyConv, ctx);
// Register the built-in list first. The caller's extras follow; this is a no-op when
// AdditionalOpClasses is empty.
273 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
274 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
// Dedicated patterns for the op classes that the generic pattern does not cover.
277 CreateArrayOpClassReplacePattern, CallOpClassReplacePattern, NewPodOpClassReplacePattern>(
// Signature conversion for function::FuncDefOp via the FunctionOpInterface helper.
281 mlir::populateFunctionOpInterfaceTypeConversionPattern<function::FuncDefOp>(patterns, tyConv);
// Structural type conversion for SCF ops. This also configures legality on `target`.
282 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
SmallVector< RecordValue > getInitializedRecordValues(ValueRange initialValues, ArrayAttr initializedRecords)
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
bool isNullOrEmpty(mlir::ArrayAttr a)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op)
Check whether an op is legal with respect to the given type converter, including TypeAttr attributes ...