LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKConversionUtils.h
Go to the documentation of this file.
1//===-- LLZKConversionUtils.h -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Shared utilities for dialect converting transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
17
18#include <mlir/IR/PatternMatch.h>
19
20namespace llzk {
21
25
26protected:
/// Hook: map the function's current input types to its new input types.
/// convert() calls this with the inputs of the existing function type and uses
/// the returned list as the inputs of the replacement function type.
27 virtual llvm::SmallVector<mlir::Type> convertInputs(mlir::ArrayRef<mlir::Type> origTypes) = 0;
/// Hook: map the function's current result types to its new result types.
/// convert() calls this with the results of the existing function type and uses
/// the returned list as the results of the replacement function type.
28 virtual llvm::SmallVector<mlir::Type> convertResults(mlir::ArrayRef<mlir::Type> origTypes) = 0;
29
/// Hook: produce the argument attributes that go with the new input types.
/// convert() calls this after the function type has been replaced, passing the
/// function's current argument attribute array (`origAttrs`; convert()'s
/// assertions allow it to be null) and the new input types (`newTypes`).
/// A non-null return value is installed as the argument attributes; a null
/// return value leaves the existing attributes untouched. convert() asserts
/// afterwards that the attribute array, if present, has one entry per argument.
30 virtual mlir::ArrayAttr
31 convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
/// Hook: produce the result attributes that go with the new result types.
/// Same contract as convertInputAttrs(), applied to the result attribute array
/// and the new result types; convert() asserts afterwards that the attribute
/// array, if present, has one entry per result.
32 virtual mlir::ArrayAttr
33 convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
34
/// Hook: rewrite the entry block's arguments to match the new function inputs.
/// convert() calls this only when the function has a callable region and the
/// entry block's argument count differs from the number of new input types.
/// After it returns, convert() asserts that the two counts are equal.
35 virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter) = 0;
36
37public:
38 virtual ~FunctionTypeConverter() = default;
39
40 void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter) {
41 // Update in/out types of the function
42 mlir::FunctionType oldTy = op.getFunctionType();
43 llvm::SmallVector<mlir::Type> newInputs = convertInputs(oldTy.getInputs());
44 llvm::SmallVector<mlir::Type> newResults = convertResults(oldTy.getResults());
45 mlir::FunctionType newTy = mlir::FunctionType::get(
46 oldTy.getContext(), mlir::TypeRange(newInputs), mlir::TypeRange(newResults)
47 );
48 if (newTy == oldTy) {
49 return; // nothing to change
50 }
51
52 // Pre-condition: arg/result count equals corresponding attribute count
53 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
54 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
55 rewriter.modifyOpInPlace(op, [&]() {
56 op.setFunctionType(newTy);
57
58 // If any input or result types were added, ensure the attributes are updated too.
59 if (mlir::ArrayAttr newArgAttrs = convertInputAttrs(op.getArgAttrsAttr(), newInputs)) {
60 op.setArgAttrsAttr(newArgAttrs);
61 }
62 if (mlir::ArrayAttr newResAttrs = convertResultAttrs(op.getResAttrsAttr(), newResults)) {
63 op.setResAttrsAttr(newResAttrs);
64 }
65 });
66 // Post-condition: arg/result count equals corresponding attribute count
67 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
68 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
69
70 // If the function has a body, ensure the entry block arguments match the function inputs.
71 if (mlir::Region *body = op.getCallableRegion()) {
72 mlir::Block &entryBlock = body->front();
73 if (!std::cmp_equal(entryBlock.getNumArguments(), newInputs.size())) {
74 processBlockArgs(entryBlock, rewriter);
75 // Post-condition: block args must match function inputs
76 assert(std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()));
77 }
78 }
79 }
80};
81
82} // namespace llzk
General helper for converting a FuncDefOp by changing its input and/or result types and the associated attributes.
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter)
virtual ~FunctionTypeConverter()=default
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
void setArgAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:733
void setResAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:737
::mlir::ArrayAttr getArgAttrsAttr()
Definition Ops.h.inc:713
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1003
::mlir::Region * getCallableRegion()
Required by FunctionOpInterface.
Definition Ops.h.inc:834
::mlir::ArrayAttr getResAttrsAttr()
Definition Ops.h.inc:718