LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
EmptyTemplateRemovalPass.cpp
Go to the documentation of this file.
1//===-- EmptyTemplateRemovalPass.cpp ----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
20#include "llzk/Util/Debug.h"
22#include <mlir/Dialect/SCF/Transforms/Patterns.h>
23#include <mlir/Transforms/DialectConversion.h>
24
25// Include the generated base pass class definitions.
27#define GEN_PASS_DEF_EMPTYTEMPLATEREMOVALPASS
29} // namespace llzk::polymorphic
31#include "SharedImpl.h"
32
33#define DEBUG_TYPE "llzk-drop-empty-templates"
34
35using namespace mlir;
36using namespace llzk::array;
37using namespace llzk::component;
38using namespace llzk::function;
39using namespace llzk::polymorphic;
40using namespace llzk::polymorphic::detail;
42namespace {
43
44static inline bool hasEmptyParamList(StructType t) {
45 if (ArrayAttr paramList = t.getParams()) {
46 return paramList.empty();
47 }
48 return false;
49}
50
52class EmptyParamListStructTypeConverter : public TypeConverter {
53public:
54 EmptyParamListStructTypeConverter() : TypeConverter() {
55
56 addConversion([](Type inputTy) { return inputTy; });
58 addConversion([](StructType inputTy) -> StructType {
59 return hasEmptyParamList(inputTy) ? StructType::get(inputTy.getNameRef()) : inputTy;
60 });
61
62 addConversion([this](ArrayType inputTy) {
63 // Recursively convert element type
64 return ArrayType::get(
65 this->convertType(inputTy.getElementType()), inputTy.getDimensionSizes()
66 );
67 });
68 }
69};
70
72class DeleteNoDefTemplatePattern : public OpConversionPattern<TemplateOp> {
73public:
74 using OpConversionPattern<TemplateOp>::OpConversionPattern;
75
76 static inline bool legal(TemplateOp op) {
77 return llvm::any_of(op.getBodyRegion().getOps(), [](Operation &p) {
78 return llvm::isa<StructDefOp, FuncDefOp>(p);
79 });
80 }
81
82 LogicalResult match(TemplateOp op) const override { return failure(legal(op)); }
83
84 void
85 rewrite(TemplateOp op, TemplateOpAdaptor, ConversionPatternRewriter &rewriter) const override {
86 LLVM_DEBUG({
87 llvm::dbgs() << "found template with no struct or function definitions: " << op << '\n';
88 });
89 rewriter.eraseOp(op);
90 }
91};
92
94class ReplaceNoParamTemplatePattern : public OpConversionPattern<TemplateOp> {
95public:
96 using OpConversionPattern<TemplateOp>::OpConversionPattern;
97
98 static inline bool legal(TemplateOp op) {
99 return op.hasConstOps<TemplateSymbolBindingOpInterface>();
100 }
101
102 LogicalResult matchAndRewrite(
103 TemplateOp op, TemplateOpAdaptor adaptor, ConversionPatternRewriter &rewriter
104 ) const override {
105 if (legal(op)) {
106 return failure();
107 }
108 LLVM_DEBUG({
109 llvm::dbgs() << "found template with no constant parameters or expressions: " << op << '\n';
110 });
111 // Convert types within the current body.
112 Region &currentBody = adaptor.getBodyRegion();
113 if (failed(rewriter.convertRegionTypes(&currentBody, *getTypeConverter()))) {
114 LLVM_DEBUG(llvm::dbgs() << "convertRegionTypes(currentBody) failed!\n");
115 return failure();
116 }
117 // Insert new ModuleOp at location of the current template.
118 ModuleOp newOp = rewriter.create<ModuleOp>(op.getLoc(), adaptor.getSymName());
119 // Move the current body into the module and erase the now-empty template op.
120 // First, clear body region of the new module to prepare for `inlineRegionBefore`.
121 Region &newOpBody = newOp.getBodyRegion();
122 if (!newOpBody.empty()) {
123 rewriter.eraseBlock(&newOpBody.front());
124 }
125 rewriter.inlineRegionBefore(currentBody, newOpBody, newOpBody.end());
126 rewriter.eraseOp(op);
127 return success();
128 }
129};
130
131class EmptyTemplateRemovalPass
132 : public llzk::polymorphic::impl::EmptyTemplateRemovalPassBase<EmptyTemplateRemovalPass> {
133
134 void runOnOperation() override {
135 ModuleOp modOp = getOperation();
136 MLIRContext *ctx = modOp.getContext();
137 EmptyParamListStructTypeConverter tyConv;
138 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
139 // Mark TemplateOp legal only if legal according to both patterns.
140 target.addDynamicallyLegalOp<TemplateOp>([](TemplateOp op) {
141 return DeleteNoDefTemplatePattern::legal(op) && ReplaceNoParamTemplatePattern::legal(op);
142 });
143 RewritePatternSet patterns = llzk::newGeneralRewritePatternSet(tyConv, ctx, target);
144 // Try `DeleteNoDefTemplatePattern` first since full removal is better that replacement.
145 patterns.add<DeleteNoDefTemplatePattern>(tyConv, ctx);
146 patterns.add<ReplaceNoParamTemplatePattern>(tyConv, ctx);
147 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
148 signalPassFailure();
149 }
150 }
151};
152
153} // namespace
154
156 return std::make_unique<EmptyTemplateRemovalPass>();
157};
Common private implementation for poly dialect passes.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::ArrayAttr getParams() const
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:872
bool hasConstOps()
Return true if there are ops of type OpT within the body region.
Definition Ops.h.inc:927
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:81
std::unique_ptr< mlir::Pass > createEmptyTemplateRemoval()
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.