LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKInliningExtensions.cpp
Go to the documentation of this file.
1//===-- LLZKInliningExtensions.cpp ------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
24
25#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
26#include <mlir/Transforms/InliningUtils.h>
27
28using namespace mlir;
29using namespace llzk;
30
31namespace {
32
33template <typename InlinerImpl, typename DialectImpl, typename... RequiredDialects>
34struct BaseInlinerInterface : public DialectInlinerInterface {
35 using DialectInlinerInterface::DialectInlinerInterface;
36
37 static void registrationHook(MLIRContext *ctx, DialectImpl *dialect) {
38 dialect->template addInterfaces<InlinerImpl>();
39 if constexpr (sizeof...(RequiredDialects) != 0) {
40 ctx->loadDialect<RequiredDialects...>();
41 }
42 }
43};
44
45// Adapted from `mlir/lib/Dialect/Func/Extensions/InlinerExtension.cpp`
46struct FuncInlinerInterface
47 : public BaseInlinerInterface<
48 FuncInlinerInterface, function::FunctionDialect, cf::ControlFlowDialect> {
49 using BaseInlinerInterface::BaseInlinerInterface;
50
52 bool isLegalToInline(Operation *, Operation *, bool) const final { return true; }
53 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { return true; }
54 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final { return true; }
55
56 void handleTerminator(Operation *op, Block *newDest) const final {
57 // Only return needs to be handled here. Replace the return with a branch to the dest.
58 // Note: This function is only called when there are multiple blocks in the region being
59 // inlined. In LLZK IR, that would only occur when the `cf` dialect is already used (since no
60 // LLZK dialect defines any kind of cross-block branching ops) so it's fine to add a
61 // `cf::BranchOp` here.
62 if (auto returnOp = llvm::dyn_cast<function::ReturnOp>(op)) {
63 OpBuilder builder(op);
64 builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
65 op->erase();
66 }
67 }
68
69 void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
70 // ASSERT: when region contains a single block, terminator must be ReturnOp
71 assert(llvm::isa<function::ReturnOp>(op));
72
73 // Replace the values directly with the return operands.
74 auto returnOp = llvm::cast<function::ReturnOp>(op);
75 assert(returnOp.getNumOperands() == valuesToRepl.size());
76 for (const auto &it : llvm::enumerate(returnOp.getOperands())) {
77 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
78 }
79 }
80};
81
82template <typename DialectImpl>
83struct FullyLegalForInlining
84 : public BaseInlinerInterface<FullyLegalForInlining<DialectImpl>, DialectImpl> {
85 using BaseInlinerInterface<FullyLegalForInlining<DialectImpl>, DialectImpl>::BaseInlinerInterface;
86
87 bool isLegalToInline(Operation *, Operation *, bool) const override { return true; }
88 bool isLegalToInline(Region *, Region *, bool, IRMapping &) const override { return true; }
89 bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const override { return true; }
90};
91
92} // namespace
93
94namespace llzk {
95
96void registerInliningExtensions(DialectRegistry &registry) {
97 registry.addExtension(FuncInlinerInterface::registrationHook);
98 registry.addExtension(FullyLegalForInlining<component::StructDialect>::registrationHook);
99 registry.addExtension(FullyLegalForInlining<constrain::ConstrainDialect>::registrationHook);
100 registry.addExtension(FullyLegalForInlining<string::StringDialect>::registrationHook);
101 registry.addExtension(FullyLegalForInlining<polymorphic::PolymorphicDialect>::registrationHook);
102 registry.addExtension(FullyLegalForInlining<felt::FeltDialect>::registrationHook);
103 registry.addExtension(FullyLegalForInlining<global::GlobalDialect>::registrationHook);
104 registry.addExtension(FullyLegalForInlining<boolean::BoolDialect>::registrationHook);
105 registry.addExtension(FullyLegalForInlining<array::ArrayDialect>::registrationHook);
106 registry.addExtension(FullyLegalForInlining<cast::CastDialect>::registrationHook);
107 registry.addExtension(FullyLegalForInlining<include::IncludeDialect>::registrationHook);
108 registry.addExtension(FullyLegalForInlining<llzk::LLZKDialect>::registrationHook);
109 registry.addExtension(FullyLegalForInlining<pod::PODDialect>::registrationHook);
110}
111
112} // namespace llzk
void registerInliningExtensions(DialectRegistry &registry)