LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKInliningExtensions.cpp
Go to the documentation of this file.
1//===-- LLZKInliningExtensions.cpp ------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
24
25#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
26#include <mlir/Transforms/InliningUtils.h>
27
28using namespace mlir;
29using namespace llzk;
30
31namespace {
32
33template <typename InlinerImpl, typename DialectImpl, typename... RequiredDialects>
34// Suppress false positive from `clang-tidy`
35// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
36struct BaseInlinerInterface : public DialectInlinerInterface {
37protected:
38 using DialectInlinerInterface::DialectInlinerInterface;
39
40public:
41 static void registrationHook(MLIRContext *ctx, DialectImpl *dialect) {
42 dialect->template addInterfaces<InlinerImpl>();
43 if constexpr (sizeof...(RequiredDialects) != 0) {
44 ctx->loadDialect<RequiredDialects...>();
45 }
46 }
47};
48
49// Adapted from `mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp`
50struct FuncInlinerInterface
51 : public BaseInlinerInterface<
52 FuncInlinerInterface, function::FunctionDialect, cf::ControlFlowDialect> {
53 using BaseInlinerInterface::BaseInlinerInterface;
54
56 bool isLegalToInline(Operation *, Operation *, bool) const final { return true; }
57 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { return true; }
58 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { return true; }
59
60 void handleTerminator(Operation *op, Block *newDest) const final {
61 // Only return needs to be handled here. Replace the return with a branch to the dest.
62 // Note: This function is only called when there are multiple blocks in the region being
63 // inlined. In LLZK IR, that would only occur when the `cf` dialect is already used (since no
64 // LLZK dialect defines any kind of cross-block branching ops) so it's fine to add a
65 // `cf::BranchOp` here.
66 if (auto returnOp = llvm::dyn_cast<function::ReturnOp>(op)) {
67 OpBuilder builder(op);
68 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
69 op->erase();
70 }
71 }
72
73 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
74 // ASSERT: when region contains a single block, terminator must be ReturnOp
75 assert(llvm::isa<function::ReturnOp>(op));
76
77 // Replace the values directly with the return operands.
78 auto returnOp = llvm::cast<function::ReturnOp>(op);
79 assert(returnOp.getNumOperands() == valuesToRepl.size());
80 for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
81 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
82 }
83 }
84};
85
86template <typename DialectImpl>
87struct FullyLegalForInlining
88 : public BaseInlinerInterface<FullyLegalForInlining<DialectImpl>, DialectImpl> {
89 using BaseInlinerInterface<FullyLegalForInlining<DialectImpl>, DialectImpl>::BaseInlinerInterface;
90
91 bool isLegalToInline(Operation *, Operation *, bool) const override { return true; }
92 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const override { return true; }
93 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const override { return true; }
94};
95
96} // namespace
97
98namespace llzk {
99
100void registerInliningExtensions(DialectRegistry &registry) {
101 registry.addExtension(FuncInlinerInterface::registrationHook);
102 registry.addExtension(FullyLegalForInlining<component::StructDialect>::registrationHook);
103 registry.addExtension(FullyLegalForInlining<constrain::ConstrainDialect>::registrationHook);
104 registry.addExtension(FullyLegalForInlining<string::StringDialect>::registrationHook);
105 registry.addExtension(FullyLegalForInlining<polymorphic::PolymorphicDialect>::registrationHook);
106 registry.addExtension(FullyLegalForInlining<felt::FeltDialect>::registrationHook);
107 registry.addExtension(FullyLegalForInlining<global::GlobalDialect>::registrationHook);
108 registry.addExtension(FullyLegalForInlining<boolean::BoolDialect>::registrationHook);
109 registry.addExtension(FullyLegalForInlining<array::ArrayDialect>::registrationHook);
110 registry.addExtension(FullyLegalForInlining<cast::CastDialect>::registrationHook);
111 registry.addExtension(FullyLegalForInlining<include::IncludeDialect>::registrationHook);
112 registry.addExtension(FullyLegalForInlining<llzk::LLZKDialect>::registrationHook);
113 registry.addExtension(FullyLegalForInlining<pod::PODDialect>::registrationHook);
114}
115
116} // namespace llzk
void registerInliningExtensions(DialectRegistry &registry)