24#include <mlir/Dialect/SCF/Transforms/Patterns.h>
25#include <mlir/Transforms/DialectConversion.h>
29#define GEN_PASS_DEF_EMPTYTEMPLATEREMOVALPASS
35#define DEBUG_TYPE "llzk-drop-empty-templates"
47static inline bool hasEmptyParamList(
StructType t) {
48 if (ArrayAttr paramList = t.
getParams()) {
49 return paramList.empty();
55class EmptyParamListStructTypeConverter :
public TypeConverter {
57 EmptyParamListStructTypeConverter() : TypeConverter() {
59 addConversion([](Type inputTy) {
return inputTy; });
62 return hasEmptyParamList(inputTy) ?
StructType::get(inputTy.getNameRef()) : inputTy;
72 addConversion([
this](
PodType inputTy) {
74 llvm::ArrayRef<RecordAttr> records = inputTy.
getRecords();
75 if (records.empty()) {
78 llvm::SmallVector<RecordAttr> newRecords;
79 newRecords.reserve(records.size());
80 MLIRContext *ctx = inputTy.getContext();
81 for (RecordAttr attr : records) {
83 RecordAttr::get(ctx, attr.getName(), this->convertType(attr.getType()))
92class DeleteNoDefTemplatePattern :
public OpConversionPattern<TemplateOp> {
94 using OpConversionPattern<
TemplateOp>::OpConversionPattern;
97 return llvm::any_of(op.
getBodyRegion().getOps(), [](Operation &p) {
98 return llvm::isa<StructDefOp, FuncDefOp>(p);
102 LogicalResult match(
TemplateOp op)
const override {
return failure(legal(op)); }
107 llvm::dbgs() <<
"found template with no struct or function definitions: " << op <<
'\n';
109 rewriter.eraseOp(op);
114class ReplaceNoParamTemplatePattern :
public OpConversionPattern<TemplateOp> {
116 using OpConversionPattern<
TemplateOp>::OpConversionPattern;
122 LogicalResult matchAndRewrite(
129 llvm::dbgs() <<
"found template with no constant parameters or expressions: " << op <<
'\n';
133 if (failed(rewriter.convertRegionTypes(¤tBody, *getTypeConverter()))) {
134 LLVM_DEBUG(llvm::dbgs() <<
"convertRegionTypes(currentBody) failed!\n");
138 ModuleOp newOp = rewriter.create<ModuleOp>(op.getLoc(), adaptor.
getSymName());
141 Region &newOpBody = newOp.getBodyRegion();
142 if (!newOpBody.empty()) {
143 rewriter.eraseBlock(&newOpBody.front());
145 rewriter.inlineRegionBefore(currentBody, newOpBody, newOpBody.end());
146 rewriter.eraseOp(op);
152 using Base = EmptyTemplateRemovalPassBase<PassImpl>;
155 void runOnOperation()
override {
156 ModuleOp modOp = getOperation();
157 MLIRContext *ctx = modOp.getContext();
158 EmptyParamListStructTypeConverter tyConv;
161 target.addDynamicallyLegalOp<TemplateOp>([](TemplateOp op) {
162 return DeleteNoDefTemplatePattern::legal(op) && ReplaceNoParamTemplatePattern::legal(op);
166 patterns.add<DeleteNoDefTemplatePattern>(tyConv, ctx);
167 patterns.add<ReplaceNoParamTemplatePattern>(tyConv, ctx);
168 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
Common private implementation for poly dialect passes.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::ArrayAttr getParams() const
static PodType get(::mlir::MLIRContext *context, ::llvm::ArrayRef<::llzk::pod::RecordAttr > records)
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::mlir::Region & getBodyRegion()
bool hasConstOps()
Return true if there are ops of type OpT within the body region.
::mlir::Region & getBodyRegion()
::llvm::StringRef getSymName()
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.