20#include <mlir/IR/PatternMatch.h>
21#include <mlir/IR/SymbolTable.h>
22#include <mlir/Transforms/DialectConversion.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/StringSet.h>
26#include <llvm/ADT/Twine.h>
35inline mlir::DictionaryAttr
37 mlir::NamedAttrList newAttrs(attrs);
38 newAttrs.set(attrName, mlir::StringAttr::get(attrs.getContext(), name));
39 return newAttrs.getDictionary(attrs.getContext());
44inline mlir::DictionaryAttr
50inline mlir::DictionaryAttr
58 if (!usedNames.contains(desiredName)) {
59 usedNames.insert(desiredName);
60 return desiredName.str();
63 for (
unsigned suffix = 1;; ++suffix) {
64 std::string candidate = (desiredName +
"#" + llvm::Twine(suffix)).str();
65 if (!usedNames.contains(candidate)) {
66 usedNames.insert(candidate);
73inline std::optional<mlir::StringAttr>
75 if (!attrs || index >= attrs.size()) {
78 if (
auto dictAttr = llvm::dyn_cast<mlir::DictionaryAttr>(attrs[index])) {
79 if (
auto nameAttr = llvm::dyn_cast<mlir::StringAttr>(dictAttr.get(attrName))) {
94template <
typename GetNameAttrFn,
typename GetSplitSuffixesFn>
96 mlir::ArrayRef<mlir::Type> origTypes, GetNameAttrFn &&getNameAttr,
97 GetSplitSuffixesFn &&getSplitSuffixes
102 for (
auto [i, type] : llvm::enumerate(origTypes)) {
103 if (std::optional<mlir::StringAttr> nameAttr = getNameAttr(i)) {
117 mlir::ArrayAttr origAttrs,
const llvm::SmallVector<size_t> &originalIdxToSize,
118 const llvm::SmallVector<mlir::Type> &newTypes, llvm::StringRef functionNameAttrName,
119 llvm::ArrayRef<std::optional<llvm::StringRef>> origNames = {},
120 llvm::ArrayRef<llvm::StringRef> existingNames = {},
121 llvm::ArrayRef<llvm::SmallVector<std::string>> splitNameSuffixes = {}
126 assert(originalIdxToSize.size() == origAttrs.size());
127 if (originalIdxToSize.size() == newTypes.size()) {
131 llvm::SmallVector<mlir::Attribute> newAttrs;
132 llvm::StringSet<> usedNames;
133 if (!origNames.empty()) {
134 for (llvm::StringRef name : existingNames) {
135 usedNames.insert(name);
139 for (
auto [i, s] : llvm::enumerate(originalIdxToSize)) {
140 mlir::Attribute attr = origAttrs[i];
141 if (!origNames.empty() && !splitNameSuffixes.empty() && s != 1 && origNames[i]) {
142 assert(i < splitNameSuffixes.size());
143 assert(splitNameSuffixes[i].size() == s);
144 auto dictAttr = llvm::cast<mlir::DictionaryAttr>(attr);
145 llvm::StringRef name = *origNames[i];
146 for (llvm::StringRef suffix : splitNameSuffixes[i]) {
147 std::string desiredName = (llvm::Twine(name) + suffix).str();
154 newAttrs.append(s, attr);
156 return mlir::ArrayAttr::get(origAttrs.getContext(), newAttrs);
165 mlir::Location loc, mlir::TypeRange newResultTypes,
function::CallOp oldCall,
166 llvm::ArrayRef<mlir::ValueRange> mapOperands, mlir::ValueRange argOperands,
167 mlir::ConversionPatternRewriter &rewriter
169 llvm::SmallVector<mlir::Attribute> templateParams;
171 templateParams.append(templateParamsAttr.begin(), templateParamsAttr.end());
176 loc, newResultTypes, oldCall.
getCalleeAttr(), argOperands, templateParams
182 argOperands, templateParams
191 virtual llvm::SmallVector<mlir::Type>
convertInputs(mlir::ArrayRef<mlir::Type> origTypes) = 0;
192 virtual llvm::SmallVector<mlir::Type>
convertResults(mlir::ArrayRef<mlir::Type> origTypes) = 0;
194 virtual mlir::ArrayAttr
196 virtual mlir::ArrayAttr
207 llvm::SmallVector<mlir::Type> newInputs =
convertInputs(oldTy.getInputs());
208 llvm::SmallVector<mlir::Type> newResults =
convertResults(oldTy.getResults());
209 mlir::FunctionType newTy = mlir::FunctionType::get(
210 oldTy.getContext(), mlir::TypeRange(newInputs), mlir::TypeRange(newResults)
212 if (newTy == oldTy) {
219 rewriter.modifyOpInPlace(op, [&]() {
236 mlir::Block &entryBlock = body->front();
237 bool blockArgsNeedUpdate =
238 !std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()) ||
239 llvm::any_of(llvm::zip_equal(entryBlock.getArgumentTypes(), newInputs), [](
auto pair) {
240 return std::get<0>(pair) != std::get<1>(pair);
242 if (blockArgsNeedUpdate) {
245 assert(std::cmp_equal(entryBlock.getNumArguments(), newInputs.size()));
246 for (
unsigned i = 0, e = entryBlock.getNumArguments(); i < e; ++i) {
247 assert(entryBlock.getArgument(i).getType() == newInputs[i]);
263 typename GenHeaderType,
typename IdType>
275 mlir::SymbolTableCollection &tables;
279 inline static void ensureImplementedAtCompile() {
281 sizeof(MemberRefOpClass) == 0,
282 "SplitAggregateInMemberRefOp not implemented for requested type."
291 static GenHeaderType
genHeader(MemberRefOpClass, mlir::ConversionPatternRewriter &) {
292 ensureImplementedAtCompile();
293 llvm_unreachable(
"must have concrete instantiation");
300 mlir::ConversionPatternRewriter &
302 ensureImplementedAtCompile();
303 llvm_unreachable(
"must have concrete instantiation");
310 mlir::MLIRContext *ctx, mlir::SymbolTableCollection &symTables,
313 :
mlir::OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables),
314 repMapRef(memberRepMap) {}
316 static bool legal(MemberRefOpClass) {
317 ensureImplementedAtCompile();
318 llvm_unreachable(
"must have concrete instantiation");
322 mlir::LogicalResult
match(MemberRefOpClass op)
const override {
323 return mlir::failure(ImplClass::legal(op));
327 MemberRefOpClass op,
OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter
330 llvm::cast<component::MemberRefOpInterface>(op.getOperation()).getStructType();
333 assert(mlir::succeeded(tgtStructDef));
335 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
338 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
340 for (
auto [
id, newMember] : idToName) {
341 ImplClass::forId(op.getLoc(), prefixResult,
id, newMember, adaptor, rewriter);
343 if constexpr (
requires { ImplClass::finalize(op, prefixResult, adaptor, rewriter); }) {
344 ImplClass::finalize(op, prefixResult, adaptor, rewriter);
346 rewriter.eraseOp(op);
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
void convert(function::FuncDefOp op, mlir::RewriterBase &rewriter)
virtual ~FunctionTypeConverter()=default
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
void rewrite(MemberRefOpClass op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
llvm::DenseMap< component::StructDefOp, llvm::DenseMap< mlir::StringAttr, LocalMemberReplacementMap > > MemberReplacementMap
Maps struct -> original aggregate-type member name -> LocalMemberReplacementMap.
static bool legal(MemberRefOpClass)
std::pair< mlir::StringAttr, mlir::Type > MemberInfo
Scalar member name and type.
SplitAggregateInMemberRefOp(mlir::MLIRContext *ctx, mlir::SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap)
typename MemberRefOpClass::Adaptor OpAdaptor
static GenHeaderType genHeader(MemberRefOpClass, mlir::ConversionPatternRewriter &)
Executed at the start of rewrite() to (optionally) generate anything that should appear before the pe...
static void forId(mlir::Location, GenHeaderType &, IdType, MemberInfo, OpAdaptor, mlir::ConversionPatternRewriter &)
Executed for each scalar id in the aggregate type of the original member to generate the per-scalar o...
llvm::DenseMap< IdType, MemberInfo > LocalMemberReplacementMap
Maps a scalar element identifier within the aggregate to its new scalar member info.
mlir::LogicalResult match(MemberRefOpClass op) const override
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::SymbolRefAttr getCalleeAttr()
::mlir::ArrayAttr getTemplateParamsAttr()
::mlir::OperandRangeRange getMapOperands()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::FunctionType getFunctionType()
void setArgAttrsAttr(::mlir::ArrayAttr attr)
void setResAttrsAttr(::mlir::ArrayAttr attr)
::mlir::ArrayAttr getArgAttrsAttr()
void setFunctionType(::mlir::FunctionType attrValue)
::mlir::Region * getCallableRegion()
Required by FunctionOpInterface.
::mlir::ArrayAttr getResAttrsAttr()
Restricts a template parameter to Op classes that implement the given OpInterface.
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
mlir::DictionaryAttr withFunctionResNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given result attribute dictionary with function.res_name set to name.
mlir::DictionaryAttr withFunctionNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef attrName, llvm::StringRef name)
Return a copy of the given function argument/result attribute dictionary with attrName set to name.
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
mlir::DictionaryAttr withFunctionArgNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given argument attribute dictionary with function.arg_name set to name.
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::optional< mlir::StringAttr > getAttrAtIndexWithName(mlir::ArrayAttr attrs, unsigned index, llvm::StringRef attrName)
Return the function arg/result attribute at index for the given name, if present.
std::string reserveUniqueAttrName(llvm::StringSet<> &usedNames, llvm::StringRef desiredName)
Reserve and return a unique function argument/result name based on desiredName.
Cached function arg/result names and split suffixes used while rewriting a function signature.
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes