LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKConversionUtils.h
Go to the documentation of this file.
1//===-- LLZKConversionUtils.h -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9//
10// Shared utilities for dialect converting transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15
#include "llzk/Util/Concepts.h"

#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/SymbolTable.h>
#include <mlir/Transforms/DialectConversion.h>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/StringSet.h>
#include <llvm/ADT/Twine.h>

#include <cassert>
#include <optional>
#include <string>
#include <utility>
30
31namespace llzk {
32
35inline mlir::DictionaryAttr
36withFunctionNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef attrName, llvm::StringRef name) {
37 mlir::NamedAttrList newAttrs(attrs);
38 newAttrs.set(attrName, mlir::StringAttr::get(attrs.getContext(), name));
39 return newAttrs.getDictionary(attrs.getContext());
40}
41
44inline mlir::DictionaryAttr
45withFunctionArgNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name) {
47}
48
50inline mlir::DictionaryAttr
51withFunctionResNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name) {
53}
54
56inline std::string
57reserveUniqueAttrName(llvm::StringSet<> &usedNames, llvm::StringRef desiredName) {
58 if (!usedNames.contains(desiredName)) {
59 usedNames.insert(desiredName);
60 return desiredName.str();
61 }
62
63 for (unsigned suffix = 1;; ++suffix) {
64 std::string candidate = (desiredName + "#" + llvm::Twine(suffix)).str();
65 if (!usedNames.contains(candidate)) {
66 usedNames.insert(candidate);
67 return candidate;
68 }
69 }
70}
71
73inline std::optional<mlir::StringAttr>
74getAttrAtIndexWithName(mlir::ArrayAttr attrs, unsigned index, llvm::StringRef attrName) {
75 if (!attrs || index >= attrs.size()) {
76 return std::nullopt;
77 }
78 if (auto dictAttr = llvm::dyn_cast<mlir::DictionaryAttr>(attrs[index])) {
79 if (auto nameAttr = llvm::dyn_cast<mlir::StringAttr>(dictAttr.get(attrName))) {
80 return nameAttr;
81 }
82 }
83 return std::nullopt;
84}
85
88 llvm::SmallVector<std::optional<llvm::StringRef>> originalNames;
89 llvm::SmallVector<llvm::StringRef> existingNames;
90 llvm::SmallVector<llvm::SmallVector<std::string>> splitNameSuffixes;
91};
92
94template <typename GetNameAttrFn, typename GetSplitSuffixesFn>
96 mlir::ArrayRef<mlir::Type> origTypes, GetNameAttrFn &&getNameAttr,
97 GetSplitSuffixesFn &&getSplitSuffixes
98) {
100 info.originalNames.reserve(origTypes.size());
101 info.splitNameSuffixes.reserve(origTypes.size());
102 for (auto [i, type] : llvm::enumerate(origTypes)) {
103 if (std::optional<mlir::StringAttr> nameAttr = getNameAttr(i)) {
104 info.originalNames.push_back(nameAttr->getValue());
105 info.existingNames.push_back(nameAttr->getValue());
106 } else {
107 info.originalNames.push_back(std::nullopt);
108 }
109 info.splitNameSuffixes.push_back(getSplitSuffixes(type));
110 }
111 return info;
112}
113
117 mlir::ArrayAttr origAttrs, const llvm::SmallVector<size_t> &originalIdxToSize,
118 const llvm::SmallVector<mlir::Type> &newTypes, llvm::StringRef functionNameAttrName,
119 llvm::ArrayRef<std::optional<llvm::StringRef>> origNames = {},
120 llvm::ArrayRef<llvm::StringRef> existingNames = {},
121 llvm::ArrayRef<llvm::SmallVector<std::string>> splitNameSuffixes = {}
122) {
123 if (!origAttrs) {
124 return nullptr;
125 }
126 assert(originalIdxToSize.size() == origAttrs.size());
127 if (originalIdxToSize.size() == newTypes.size()) {
128 return nullptr;
129 }
130
131 llvm::SmallVector<mlir::Attribute> newAttrs;
132 llvm::StringSet<> usedNames;
133 if (!origNames.empty()) {
134 for (llvm::StringRef name : existingNames) {
135 usedNames.insert(name);
136 }
137 }
138
139 for (auto [i, s] : llvm::enumerate(originalIdxToSize)) {
140 mlir::Attribute attr = origAttrs[i];
141 if (!origNames.empty() && !splitNameSuffixes.empty() && s != 1 && origNames[i]) {
142 assert(i < splitNameSuffixes.size());
143 assert(splitNameSuffixes[i].size() == s);
144 auto dictAttr = llvm::cast<mlir::DictionaryAttr>(attr);
145 llvm::StringRef name = *origNames[i];
146 for (llvm::StringRef suffix : splitNameSuffixes[i]) {
147 std::string desiredName = (llvm::Twine(name) + suffix).str();
148 newAttrs.push_back(withFunctionNameAttr(
149 dictAttr, functionNameAttrName, reserveUniqueAttrName(usedNames, desiredName)
150 ));
151 }
152 continue;
153 }
154 newAttrs.append(s, attr);
155 }
156 return mlir::ArrayAttr::get(origAttrs.getContext(), newAttrs);
157}
158
165 mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall,
166 llvm::ArrayRef<mlir::ValueRange> mapOperands, mlir::ValueRange argOperands,
167 mlir::ConversionPatternRewriter &rewriter
168) {
169 llvm::SmallVector<mlir::Attribute> templateParams;
170 if (mlir::ArrayAttr templateParamsAttr = oldCall.getTemplateParamsAttr()) {
171 templateParams.append(templateParamsAttr.begin(), templateParamsAttr.end());
172 }
173
174 if (oldCall.getMapOperands().empty()) {
175 return rewriter.create<function::CallOp>(
176 loc, newResultTypes, oldCall.getCalleeAttr(), argOperands, templateParams
177 );
178 }
179
180 return rewriter.create<function::CallOp>(
181 loc, newResultTypes, oldCall.getCalleeAttr(), mapOperands, oldCall.getNumDimsPerMapAttr(),
182 argOperands, templateParams
183 );
184}
185
189
190protected:
191 virtual llvm::SmallVector<mlir::Type> convertInputs(mlir::ArrayRef<mlir::Type> origTypes) = 0;
192 virtual llvm::SmallVector<mlir::Type> convertResults(mlir::ArrayRef<mlir::Type> origTypes) = 0;
193
194 virtual mlir::ArrayAttr
195 convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
196 virtual mlir::ArrayAttr
197 convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector<mlir::Type> newTypes) = 0;
198
199 virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter) = 0;
200
201public:
202 virtual ~FunctionTypeConverter() = default;
203
204 void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter) {
205 // Update in/out types of the function
206 mlir::FunctionType oldTy = op.getFunctionType();
207 llvm::SmallVector<mlir::Type> newInputs = convertInputs(oldTy.getInputs());
208 llvm::SmallVector<mlir::Type> newResults = convertResults(oldTy.getResults());
209 mlir::FunctionType newTy = mlir::FunctionType::get(
210 oldTy.getContext(), mlir::TypeRange(newInputs), mlir::TypeRange(newResults)
211 );
212 if (newTy == oldTy) {
213 return; // nothing to change
214 }
215
216 // Pre-condition: arg/result count equals corresponding attribute count
217 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
218 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
219 rewriter.modifyOpInPlace(op, [&]() {
220 op.setFunctionType(newTy);
221
222 // If any input or result types were added, ensure the attributes are updated too.
223 if (mlir::ArrayAttr newArgAttrs = convertInputAttrs(op.getArgAttrsAttr(), newInputs)) {
224 op.setArgAttrsAttr(newArgAttrs);
225 }
226 if (mlir::ArrayAttr newResAttrs = convertResultAttrs(op.getResAttrsAttr(), newResults)) {
227 op.setResAttrsAttr(newResAttrs);
228 }
229 });
230 // Post-condition: arg/result count equals corresponding attribute count
231 assert(!op.getResAttrsAttr() || op.getResAttrsAttr().size() == op.getNumResults());
232 assert(!op.getArgAttrsAttr() || op.getArgAttrsAttr().size() == op.getNumArguments());
233
234 // If the function has a body, ensure the entry block arguments match the function inputs.
235 if (mlir::Region *body = op.getCallableRegion()) {
236 mlir::Block &entryBlock = body->front();
237 bool blockArgsNeedUpdate =
238 !std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()) ||
239 llvm::any_of(llvm::zip_equal(entryBlock.getArgumentTypes(), newInputs), [](auto pair) {
240 return std::get<0>(pair) != std::get<1>(pair);
241 });
242 if (blockArgsNeedUpdate) {
243 processBlockArgs(entryBlock, rewriter);
244 // Post-condition: block args must match function inputs in both arity and type.
245 assert(std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()));
246 for (unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
247 assert(entryBlock.getArgument(i).getType() == newInputs[i]);
248 }
249 }
250 }
251 }
252};
253
261template <
262 typename ImplClass, HasInterface<component::MemberRefOpInterface> MemberRefOpClass,
263 typename GenHeaderType, typename IdType>
264class SplitAggregateInMemberRefOp : public mlir::OpConversionPattern<MemberRefOpClass> {
265public:
267 using MemberInfo = std::pair<mlir::StringAttr, mlir::Type>;
269 using LocalMemberReplacementMap = llvm::DenseMap<IdType, MemberInfo>;
271 using MemberReplacementMap = llvm::DenseMap<
272 component::StructDefOp, llvm::DenseMap<mlir::StringAttr, LocalMemberReplacementMap>>;
273
274private:
275 mlir::SymbolTableCollection &tables;
276 const MemberReplacementMap &repMapRef;
277
278 // Static check to ensure the methods are implemented in all subclasses.
279 inline static void ensureImplementedAtCompile() {
280 static_assert(
281 sizeof(MemberRefOpClass) == 0,
282 "SplitAggregateInMemberRefOp not implemented for requested type."
283 );
284 }
285
286protected:
287 using OpAdaptor = typename MemberRefOpClass::Adaptor;
288
291 static GenHeaderType genHeader(MemberRefOpClass, mlir::ConversionPatternRewriter &) {
292 ensureImplementedAtCompile();
293 llvm_unreachable("must have concrete instantiation");
294 }
295
298 static void forId(
299 mlir::Location, GenHeaderType &, IdType, MemberInfo, OpAdaptor,
300 mlir::ConversionPatternRewriter &
301 ) {
302 ensureImplementedAtCompile();
303 llvm_unreachable("must have concrete instantiation");
304 }
305
306public:
307 // Suppress false positive from `clang-tidy`
308 // NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
310 mlir::MLIRContext *ctx, mlir::SymbolTableCollection &symTables,
311 const MemberReplacementMap &memberRepMap
312 )
313 : mlir::OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables),
314 repMapRef(memberRepMap) {}
315
316 static bool legal(MemberRefOpClass) {
317 ensureImplementedAtCompile();
318 llvm_unreachable("must have concrete instantiation");
319 return false;
320 }
321
322 mlir::LogicalResult match(MemberRefOpClass op) const override {
323 return mlir::failure(ImplClass::legal(op));
324 }
325
327 MemberRefOpClass op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter
328 ) const override {
329 component::StructType tgtStructTy =
330 llvm::cast<component::MemberRefOpInterface>(op.getOperation()).getStructType();
331 assert(tgtStructTy);
332 auto tgtStructDef = tgtStructTy.getDefinition(tables, op);
333 assert(mlir::succeeded(tgtStructDef));
334
335 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
336
337 const LocalMemberReplacementMap &idToName =
338 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
339 // Split the aggregate member into a series of scalar member ops.
340 for (auto [id, newMember] : idToName) {
341 ImplClass::forId(op.getLoc(), prefixResult, id, newMember, adaptor, rewriter);
342 }
343 if constexpr (requires { ImplClass::finalize(op, prefixResult, adaptor, rewriter); }) {
344 ImplClass::finalize(op, prefixResult, adaptor, rewriter);
345 }
346 rewriter.eraseOp(op);
347 }
348};
349
350} // namespace llzk
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter)
virtual ~FunctionTypeConverter()=default
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
void rewrite(MemberRefOpClass op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
llvm::DenseMap< component::StructDefOp, llvm::DenseMap< mlir::StringAttr, LocalMemberReplacementMap > > MemberReplacementMap
Maps struct -> original aggregate-type member name -> LocalMemberReplacementMap.
static bool legal(MemberRefOpClass)
std::pair< mlir::StringAttr, mlir::Type > MemberInfo
Scalar member name and type.
SplitAggregateInMemberRefOp(mlir::MLIRContext *ctx, mlir::SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap)
typename MemberRefOpClass::Adaptor OpAdaptor
static GenHeaderType genHeader(MemberRefOpClass, mlir::ConversionPatternRewriter &)
Executed at the start of rewrite() to (optionally) generate anything that should appear before the pe...
static void forId(mlir::Location, GenHeaderType &, IdType, MemberInfo, OpAdaptor, mlir::ConversionPatternRewriter &)
Executed for each scalar id in the aggregate type of the original member to generate the per-scalar o...
llvm::DenseMap< IdType, MemberInfo > LocalMemberReplacementMap
Maps a scalar element identifier within the aggregate to its new scalar member info.
mlir::LogicalResult match(MemberRefOpClass op) const override
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::mlir::ArrayAttr getTemplateParamsAttr()
Definition Ops.h.inc:297
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:302
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
void setArgAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:733
void setResAttrsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:737
::mlir::ArrayAttr getArgAttrsAttr()
Definition Ops.h.inc:713
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1003
::mlir::Region * getCallableRegion()
Required by FunctionOpInterface.
Definition Ops.h.inc:846
::mlir::ArrayAttr getResAttrsAttr()
Definition Ops.h.inc:718
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
Definition Ops.h:34
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
Definition Ops.h:37
mlir::DictionaryAttr withFunctionResNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given result attribute dictionary with function.res_name set to name.
mlir::DictionaryAttr withFunctionNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef attrName, llvm::StringRef name)
Return a copy of the given function argument/result attribute dictionary with attrName set to name.
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
mlir::DictionaryAttr withFunctionArgNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given argument attribute dictionary with function.arg_name set to name.
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::optional< mlir::StringAttr > getAttrAtIndexWithName(mlir::ArrayAttr attrs, unsigned index, llvm::StringRef attrName)
Return the function arg/result attribute at index for the given name, if present.
std::string reserveUniqueAttrName(llvm::StringSet<> &usedNames, llvm::StringRef desiredName)
Reserve and return a unique function argument/result name based on desiredName.
Cached function arg/result names and split suffixes used while rewriting a function signature.
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes