22#include <mlir/Dialect/SCF/Transforms/Patterns.h>
23#include <mlir/Transforms/DialectConversion.h>
27#define GEN_PASS_DEF_EMPTYTEMPLATEREMOVALPASS
33#define DEBUG_TYPE "llzk-drop-empty-templates"
45 if (ArrayAttr paramList = t.
getParams()) {
46 return paramList.empty();
52class EmptyParamListStructTypeConverter :
public TypeConverter {
54 EmptyParamListStructTypeConverter() : TypeConverter() {
56 addConversion([](Type inputTy) {
return inputTy; });
72class DeleteNoDefTemplatePattern :
public OpConversionPattern<TemplateOp> {
74 using OpConversionPattern<
TemplateOp>::OpConversionPattern;
77 return llvm::any_of(op.
getBodyRegion().getOps(), [](Operation &p) {
78 return llvm::isa<StructDefOp, FuncDefOp>(p);
82 LogicalResult match(TemplateOp op)
const override {
return failure(legal(op)); }
85 rewrite(TemplateOp op, TemplateOpAdaptor, ConversionPatternRewriter &rewriter)
const override {
87 llvm::dbgs() <<
"found template with no struct or function definitions: " << op <<
'\n';
94class ReplaceNoParamTemplatePattern :
public OpConversionPattern<TemplateOp> {
96 using OpConversionPattern<TemplateOp>::OpConversionPattern;
98 static inline bool legal(TemplateOp op) {
99 return op.
hasConstOps<TemplateSymbolBindingOpInterface>();
102 LogicalResult matchAndRewrite(
103 TemplateOp op, TemplateOpAdaptor adaptor, ConversionPatternRewriter &rewriter
109 llvm::dbgs() <<
"found template with no constant parameters or expressions: " << op <<
'\n';
113 if (failed(rewriter.convertRegionTypes(¤tBody, *getTypeConverter()))) {
114 LLVM_DEBUG(llvm::dbgs() <<
"convertRegionTypes(currentBody) failed!\n");
118 ModuleOp newOp = rewriter.create<ModuleOp>(op.getLoc(), adaptor.
getSymName());
121 Region &newOpBody = newOp.getBodyRegion();
122 if (!newOpBody.empty()) {
123 rewriter.eraseBlock(&newOpBody.front());
125 rewriter.inlineRegionBefore(currentBody, newOpBody, newOpBody.end());
126 rewriter.eraseOp(op);
131class EmptyTemplateRemovalPass
134 void runOnOperation()
override {
135 ModuleOp modOp = getOperation();
136 MLIRContext *ctx = modOp.getContext();
137 EmptyParamListStructTypeConverter tyConv;
140 target.addDynamicallyLegalOp<TemplateOp>([](TemplateOp op) {
141 return DeleteNoDefTemplatePattern::legal(op) && ReplaceNoParamTemplatePattern::legal(op);
145 patterns.add<DeleteNoDefTemplatePattern>(tyConv, ctx);
146 patterns.add<ReplaceNoParamTemplatePattern>(tyConv, ctx);
147 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
156 return std::make_unique<EmptyTemplateRemovalPass>();
Common private implementation for poly dialect passes.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::ArrayAttr getParams() const
::mlir::Region & getBodyRegion()
bool hasConstOps()
Return true if there are ops of type OpT within the body region.
::mlir::Region & getBodyRegion()
::llvm::StringRef getSymName()
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
std::unique_ptr< mlir::Pass > createEmptyTemplateRemoval()
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.