LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SpecializedMemoryPasses.h
Go to the documentation of this file.
1//===-- SpecializedMemoryPasses.h - Targeted SROA / mem2reg -----*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
16//===----------------------------------------------------------------------===//
17
18#pragma once
19
20#include <mlir/Analysis/DataLayoutAnalysis.h>
21#include <mlir/IR/Builders.h>
22#include <mlir/IR/Dominance.h>
23#include <mlir/Interfaces/MemorySlotInterfaces.h>
24#include <mlir/Pass/Pass.h>
25#include <mlir/Transforms/Mem2Reg.h>
26#include <mlir/Transforms/SROA.h>
27
28#include <llvm/ADT/SmallVector.h>
29
30namespace llzk {
31
35template <typename AllocOpTy>
36struct SpecializedSROA : mlir::PassWrapper<SpecializedSROA<AllocOpTy>, mlir::OperationPass<>> {
37
38 mlir::StringRef getArgument() const override { return "llzk-specialized-sroa"; }
39
40 mlir::StringRef getDescription() const override {
41 return "Scalar replacement of aggregates for a specific allocator op type";
42 }
43
44 void runOnOperation() override {
45 mlir::Operation *scopeOp = this->getOperation();
46
47 auto &dataLayoutAnalysis = this->template getAnalysis<mlir::DataLayoutAnalysis>();
48 const mlir::DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
49
50 bool changed = false;
51
52 for (mlir::Region &region : scopeOp->getRegions()) {
53 if (region.getBlocks().empty()) {
54 continue;
55 }
56
57 mlir::OpBuilder builder(&region.front(), region.front().begin());
58
59 mlir::SmallVector<mlir::DestructurableAllocationOpInterface> allocators;
60 region.walk([&](AllocOpTy allocator) { allocators.emplace_back(allocator); });
61
62 if (mlir::succeeded(mlir::tryToDestructureMemorySlots(allocators, builder, dataLayout))) {
63 changed = true;
64 }
65 }
66
67 if (!changed) {
68 this->markAllAnalysesPreserved();
69 }
70 }
71};
72
73// Pass factory for `SpecializedSROA`.
74template <typename AllocOpTy>
75std::unique_ptr<SpecializedSROA<AllocOpTy>> createSpecializedSROAPass() {
76 return std::make_unique<SpecializedSROA<AllocOpTy>>();
77}
78
82template <typename AllocOpTy>
84 : mlir::PassWrapper<SpecializedMem2Reg<AllocOpTy>, mlir::OperationPass<>> {
85
86 mlir::StringRef getArgument() const override { return "llzk-specialized-mem2reg"; }
87
88 mlir::StringRef getDescription() const override {
89 return "Promotes memory slots of a specific allocator op type into values";
90 }
91
92 void runOnOperation() override {
93 mlir::Operation *scopeOp = this->getOperation();
94
95 auto &dataLayoutAnalysis = this->template getAnalysis<mlir::DataLayoutAnalysis>();
96 const mlir::DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
97 auto &dominance = this->template getAnalysis<mlir::DominanceInfo>();
98
99 bool changed = false;
100
101 for (mlir::Region &region : scopeOp->getRegions()) {
102 if (region.getBlocks().empty()) {
103 continue;
104 }
105
106 mlir::OpBuilder builder(&region.front(), region.front().begin());
107
108 mlir::SmallVector<mlir::PromotableAllocationOpInterface> allocators;
109 region.walk([&](AllocOpTy allocator) { allocators.emplace_back(allocator); });
110
111 if (mlir::succeeded(
112 mlir::tryToPromoteMemorySlots(allocators, builder, dataLayout, dominance)
113 )) {
114 changed = true;
115 }
116 }
117
118 if (!changed) {
119 this->markAllAnalysesPreserved();
120 }
121 }
122};
123
124// Pass factory for `SpecializedMem2Reg`.
125template <typename AllocOpTy>
126std::unique_ptr<SpecializedMem2Reg<AllocOpTy>> createSpecializedMem2RegPass() {
127 return std::make_unique<SpecializedMem2Reg<AllocOpTy>>();
128}
129
130} // namespace llzk
std::unique_ptr< SpecializedMem2Reg< AllocOpTy > > createSpecializedMem2RegPass()
std::unique_ptr< SpecializedSROA< AllocOpTy > > createSpecializedSROAPass()
A variant of the MLIR mem2reg pass that only promotes memory slots belonging to allocators of type AllocOpTy.
mlir::StringRef getArgument() const override
mlir::StringRef getDescription() const override
A variant of the MLIR sroa pass that only destructures memory slots belonging to allocators of type AllocOpTy.
mlir::StringRef getDescription() const override
mlir::StringRef getArgument() const override