LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
EmptyTemplateRemovalPass.cpp
Go to the documentation of this file.
1//===-- EmptyTemplateRemovalPass.cpp ----------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
23
24#include <mlir/Dialect/SCF/Transforms/Patterns.h>
25#include <mlir/Transforms/DialectConversion.h>
27// Include the generated base pass class definitions.
29#define GEN_PASS_DEF_EMPTYTEMPLATEREMOVALPASS
31} // namespace llzk::polymorphic
33#include "SharedImpl.h"
35#define DEBUG_TYPE "llzk-drop-empty-templates"
36
37using namespace mlir;
38using namespace llzk::array;
39using namespace llzk::component;
40using namespace llzk::function;
41using namespace llzk::pod;
42using namespace llzk::polymorphic;
44
45namespace {
47static inline bool hasEmptyParamList(StructType t) {
48 if (ArrayAttr paramList = t.getParams()) {
49 return paramList.empty();
50 }
51 return false;
53
55class EmptyParamListStructTypeConverter : public TypeConverter {
56public:
57 EmptyParamListStructTypeConverter() : TypeConverter() {
58
59 addConversion([](Type inputTy) { return inputTy; });
60
61 addConversion([](StructType inputTy) -> StructType {
62 return hasEmptyParamList(inputTy) ? StructType::get(inputTy.getNameRef()) : inputTy;
63 });
64
65 addConversion([this](ArrayType inputTy) {
66 // Recursively convert element type
67 return ArrayType::get(
68 this->convertType(inputTy.getElementType()), inputTy.getDimensionSizes()
69 );
70 });
71
72 addConversion([this](PodType inputTy) {
73 // Recursively convert record types
74 llvm::ArrayRef<RecordAttr> records = inputTy.getRecords();
75 if (records.empty()) {
76 return inputTy;
77 }
78 llvm::SmallVector<RecordAttr> newRecords;
79 newRecords.reserve(records.size());
80 MLIRContext *ctx = inputTy.getContext();
81 for (RecordAttr attr : records) {
82 newRecords.push_back(
83 RecordAttr::get(ctx, attr.getName(), this->convertType(attr.getType()))
84 );
85 }
86 return PodType::get(ctx, newRecords);
87 });
88 }
89};
90
92class DeleteNoDefTemplatePattern : public OpConversionPattern<TemplateOp> {
93public:
94 using OpConversionPattern<TemplateOp>::OpConversionPattern;
95
96 static inline bool legal(TemplateOp op) {
97 return llvm::any_of(op.getBodyRegion().getOps(), [](Operation &p) {
98 return llvm::isa<StructDefOp, FuncDefOp>(p);
99 });
100 }
101
102 LogicalResult match(TemplateOp op) const override { return failure(legal(op)); }
103
104 void
105 rewrite(TemplateOp op, TemplateOpAdaptor, ConversionPatternRewriter &rewriter) const override {
106 LLVM_DEBUG({
107 llvm::dbgs() << "found template with no struct or function definitions: " << op << '\n';
108 });
109 rewriter.eraseOp(op);
110 }
111};
112
// Conversion pattern that replaces a TemplateOp for which `legal()` returns
// false with a builtin ModuleOp of the same symbol name. The template's body
// region (after type conversion) is moved into that module.
114class ReplaceNoParamTemplatePattern : public OpConversionPattern<TemplateOp> {
115public:
116 using OpConversionPattern<TemplateOp>::OpConversionPattern;
117
// Return true if `op` must be left untouched by this pattern; matchAndRewrite
// only rewrites templates for which this returns false.
// NOTE(review): the statement(s) forming this function's body are missing from
// this listing. Judging from the debug message in matchAndRewrite ("no constant
// parameters or expressions"), it presumably returns whether the template body
// contains constant parameter/expression ops — confirm against the real source.
118 static inline bool legal(TemplateOp op) {
120 }
121
// Replace `op` with a ModuleOp named after `op`'s symbol name. Before that,
// the types inside `op`'s body region are converted with this pattern's type
// converter, and the converted region is then inlined into the new module.
// Returns failure() if `op` is legal or if the region type conversion fails;
// otherwise erases `op` and returns success().
122 LogicalResult matchAndRewrite(
123 TemplateOp op, TemplateOpAdaptor adaptor, ConversionPatternRewriter &rewriter
124 ) const override {
125 if (legal(op)) {
126 return failure();
127 }
128 LLVM_DEBUG({
129 llvm::dbgs() << "found template with no constant parameters or expressions: " << op << '\n';
130 });
131 // Convert types within the current body.
132 Region &currentBody = adaptor.getBodyRegion();
133 if (failed(rewriter.convertRegionTypes(&currentBody, *getTypeConverter()))) {
134 LLVM_DEBUG(llvm::dbgs() << "convertRegionTypes(currentBody) failed!\n");
135 return failure();
136 }
137 // Insert new ModuleOp at location of the current template.
138 ModuleOp newOp = rewriter.create<ModuleOp>(op.getLoc(), adaptor.getSymName());
139 // Move the current body into the module and erase the now-empty template op.
140 // First, clear body region of the new module to prepare for `inlineRegionBefore`.
// The freshly created module may already own a block; erase it so the module's
// body consists solely of the inlined template body.
141 Region &newOpBody = newOp.getBodyRegion();
142 if (!newOpBody.empty()) {
143 rewriter.eraseBlock(&newOpBody.front());
144 }
145 rewriter.inlineRegionBefore(currentBody, newOpBody, newOpBody.end());
146 rewriter.eraseOp(op);
147 return success();
148 }
149};
150
151class PassImpl : public llzk::polymorphic::impl::EmptyTemplateRemovalPassBase<PassImpl> {
152 using Base = EmptyTemplateRemovalPassBase<PassImpl>;
153 using Base::Base;
154
155 void runOnOperation() override {
156 ModuleOp modOp = getOperation();
157 MLIRContext *ctx = modOp.getContext();
158 EmptyParamListStructTypeConverter tyConv;
159 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
160 // Mark TemplateOp legal only if legal according to both patterns.
161 target.addDynamicallyLegalOp<TemplateOp>([](TemplateOp op) {
162 return DeleteNoDefTemplatePattern::legal(op) && ReplaceNoParamTemplatePattern::legal(op);
163 });
164 RewritePatternSet patterns = llzk::newGeneralRewritePatternSet(tyConv, ctx, target);
165 // Try `DeleteNoDefTemplatePattern` first since full removal is better that replacement.
166 patterns.add<DeleteNoDefTemplatePattern>(tyConv, ctx);
167 patterns.add<ReplaceNoParamTemplatePattern>(tyConv, ctx);
168 if (failed(applyFullConversion(modOp, target, std::move(patterns)))) {
169 signalPassFailure();
170 }
171 }
172};
173
174} // namespace
Common private implementation for poly dialect passes.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::ArrayAttr getParams() const
static PodType get(::mlir::MLIRContext *context, ::llvm::ArrayRef<::llzk::pod::RecordAttr > records)
Definition Types.cpp.inc:68
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:872
bool hasConstOps()
Return true if there are ops of type OpT within the body region.
Definition Ops.h.inc:927
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:81
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.