18#include <mlir/Dialect/Arith/IR/Arith.h>
19#include <mlir/Dialect/Utils/IndexingUtils.h>
20#include <mlir/IR/Attributes.h>
21#include <mlir/IR/BuiltinOps.h>
22#include <mlir/IR/Diagnostics.h>
23#include <mlir/IR/Matchers.h>
24#include <mlir/IR/OwningOpRef.h>
25#include <mlir/IR/SymbolTable.h>
26#include <mlir/IR/ValueRange.h>
27#include <mlir/Support/LogicalResult.h>
29#include <llvm/ADT/ArrayRef.h>
30#include <llvm/ADT/Twine.h>
50 OpBuilder &odsBuilder, OperationState &odsState,
ArrayType result, ValueRange elements
52 odsState.addTypes(result);
53 odsState.addOperands(elements);
62 OpBuilder &odsBuilder, OperationState &odsState,
ArrayType result,
63 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap
65 odsState.addTypes(result);
67 odsBuilder, odsState, mapOperands, numDimsPerMap
80llvm::SmallVector<Type> CreateArrayOp::resultTypeToElementsTypes(Type resultType) {
82 ArrayType a = llvm::cast<ArrayType>(resultType);
83 return llvm::SmallVector<Type>(a.getNumElements(), a.
getElementType());
86ParseResult CreateArrayOp::parseInferredArrayType(
87 OpAsmParser & , llvm::SmallVector<Type, 1> &elementsTypes,
88 ArrayRef<OpAsmParser::UnresolvedOperand> elements, Type resultType
90 assert(elementsTypes.size() == 0);
93 if (elements.size() > 0) {
94 elementsTypes.append(resultTypeToElementsTypes(resultType));
99void CreateArrayOp::printInferredArrayType(
100 OpAsmPrinter &printer,
CreateArrayOp, TypeRange, OperandRange, Type
107 assert(llvm::isa<ArrayType>(retTy));
110 SmallVector<AffineMapAttr> mapAttrs;
112 ArrayType arrTy = llvm::cast<ArrayType>(retTy);
114 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
115 mapAttrs.push_back(m);
125 assert(
getElements().empty() &&
"must run after initialization is split from allocation");
127 if (!arrType.hasStaticShape() || arrType.getNumElements() == 1) {
131 return {DestructurableMemorySlot {{
getResult(), arrType}, std::move(*destructured)}};
138 const DestructurableMemorySlot &slot,
const SmallPtrSetImpl<Attribute> &usedIndices,
139 OpBuilder &builder, SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators
142 assert(slot.elemType == getType());
144 builder.setInsertionPointAfter(*
this);
146 DenseMap<Attribute, MemorySlot> slotMap;
147 for (Attribute index : usedIndices) {
149 ArrayAttr indexAsArray = llvm::dyn_cast<ArrayAttr>(index);
150 assert(indexAsArray &&
"expected ArrayAttr");
152 Type destructAs = getType().getTypeAtIndex(indexAsArray);
153 assert(destructAs == slot.subelementTypes.lookup(indexAsArray));
155 ArrayType destructAsArrayTy = llvm::dyn_cast<ArrayType>(destructAs);
156 assert(destructAsArrayTy &&
"expected ArrayType");
158 auto subCreate = builder.create<
CreateArrayOp>(getLoc(), destructAsArrayTy);
159 newAllocators.push_back(subCreate);
160 slotMap.try_emplace<MemorySlot>(index, {subCreate.getResult(), destructAs});
168 const DestructurableMemorySlot &slot, OpBuilder &
178 if (!arrType.hasStaticShape()) {
183 if (arrType.getNumElements() != 1) {
199 const MemorySlot & , Value defaultValue, OpBuilder &
201 if (defaultValue.use_empty()) {
202 defaultValue.getDefiningOp()->erase();
218 if (arrTy.hasStaticShape()) {
220 return ArrayAttr::get(getContext(), *converted);
228 const DestructurableMemorySlot &slot, SmallPtrSetImpl<Attribute> &usedIndices,
229 SmallVectorImpl<MemorySlot> & ,
const DataLayout &
246 usedIndices.insert(indexAsAttr);
252 const DestructurableMemorySlot &slot, DenseMap<Attribute, MemorySlot> &subslots,
253 OpBuilder &builder,
const DataLayout &
261 assert(indexAsAttr &&
"canRewire() should have returned false");
262 const MemorySlot &memorySlot = subslots.at(indexAsAttr);
265 OpBuilder::InsertionGuard guard(builder);
266 builder.setInsertionPoint(this->getOperation());
273 return DeletionKind::Keep;
286 auto compare = numIndices <=> dims.size();
288 return errFn().append(
289 "has ", (compare < 0 ?
"insufficient" :
"too many"),
" indexed dimensions: expected ",
290 dims.size(),
" but found ", numIndices
305 llvm::SmallVectorImpl<Type> &inferredReturnTypes
307 inferredReturnTypes.resize(1);
308 Type lvalType = adaptor.
getArrRef().getType();
309 assert(llvm::isa<ArrayType>(lvalType));
310 inferredReturnTypes[0] = llvm::cast<ArrayType>(lvalType).getElementType();
325 const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
326 SmallVectorImpl<OpOperand *> & ,
const DataLayout &
328 if (blockingUses.size() != 1) {
331 Value blockingUse = (*blockingUses.begin())->get();
332 return blockingUse == slot.ptr &&
getArrRef() == slot.ptr &&
338 const MemorySlot & ,
const SmallPtrSetImpl<OpOperand *> & ,
339 OpBuilder & , Value reachingDefinition,
const DataLayout &
342 getResult().replaceAllUsesWith(reachingDefinition);
343 return DeletionKind::Delete;
364 const MemorySlot &slot,
const SmallPtrSetImpl<OpOperand *> &blockingUses,
365 SmallVectorImpl<OpOperand *> & ,
const DataLayout &
367 if (blockingUses.size() != 1) {
370 Value blockingUse = (*blockingUses.begin())->get();
371 return blockingUse == slot.ptr &&
getArrRef() == slot.ptr &&
getRvalue() != slot.ptr &&
377 const MemorySlot &,
const SmallPtrSetImpl<OpOperand *> &, OpBuilder &, Value,
const DataLayout &
379 return DeletionKind::Delete;
393 llvm::SmallVectorImpl<Type> &inferredReturnTypes
395 size_t numToSkip = adaptor.
getIndices().size();
396 Type arrRefType = adaptor.
getArrRef().getType();
397 assert(llvm::isa<ArrayType>(arrRefType));
398 ArrayType arrRefArrType = llvm::cast<ArrayType>(arrRefType);
402 auto compare = numToSkip <=> arrRefDimSizes.size();
404 return mlir::emitOptionalError(
409 }
else if (compare > 0) {
410 return mlir::emitOptionalError(
412 "' op cannot select more dimensions than exist in the source array"
417 inferredReturnTypes.resize(1);
418 inferredReturnTypes[0] =
441 assert(llvm::isa<ArrayType>(rValueType));
442 ArrayType rValueArrType = llvm::cast<ArrayType>(rValueType);
450 if (numIndices > lhsDims) {
451 return emitOpError(
"cannot select more dimensions than exist in the source array");
455 auto compare = (numIndices + rhsDims) <=> lhsDims;
457 return emitOpError().append(
458 "has ", (compare < 0 ?
"insufficient" :
"too many"),
" indexed dimensions: expected ",
459 (lhsDims - rhsDims),
" but found ", numIndices
481 if (!matchPattern(dimValue, m_ConstantInt(&dim))) {
485 std::optional<int64_t> idxOpt = dim.trySExtValue();
486 if (!idxOpt || *idxOpt < 0) {
487 auto diag = emitOpError(
"dimension must be a non-negative 64-bit integer");
488 if (!llvm::isa<UnknownLoc>(dimValue.getLoc())) {
489 diag.attachNote(dimValue.getLoc()).append(
"dimension defined here");
496 InFlightDiagnostic diag = emitOpError().append(
497 "dimension index ", idx,
" is not valid for array with ", rank,
" dimensions"
499 if (!llvm::isa<UnknownLoc>(
getArrRef().getLoc())) {
500 diag.attachNote(
getArrRef().getLoc()).append(
"array defined here");
502 if (!llvm::isa<UnknownLoc>(dimValue.getLoc())) {
503 diag.attachNote(dimValue.getLoc()).append(
"dimension defined here");
::mlir::DeletionKind rewire(const ::mlir::DestructurableMemorySlot &slot, ::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot > &subslots, ::mlir::OpBuilder &builder, const ::mlir::DataLayout &dataLayout)
Required by companion interface DestructurableAccessorOpInterface / SROA pass.
::mlir::Operation::operand_range getIndices()
Gets the operand range containing the index for each dimension.
::mlir::OpOperand & getArrRefMutable()
Gets the mutable operand slot holding the SSA Value for the referenced array.
inline ::mlir::ArrayRef<::mlir::Attribute > getValueOperandDims()
Compute the dimensions of the read/write value.
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
bool canRewire(const ::mlir::DestructurableMemorySlot &slot, ::llvm::SmallPtrSetImpl<::mlir::Attribute > &usedIndices, ::mlir::SmallVectorImpl<::mlir::MemorySlot > &mustBeSafelyUsed, const ::mlir::DataLayout &dataLayout)
Required by companion interface DestructurableAccessorOpInterface / SROA pass.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
::mlir::MutableOperandRange getIndicesMutable()
Gets the mutable operand range containing the index for each dimension.
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::mlir::TypedValue<::mlir::IndexType > getDim()
::mlir::Type getElementType() const
::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type > > getSubelementIndexMap() const
Required by DestructurableTypeInterface / SROA pass.
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llzk::array::ArrayType result, ::mlir::ValueRange elements={})
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::llvm::SmallVector<::mlir::DestructurableMemorySlot > getDestructurableSlots()
Required by DestructurableAllocationOpInterface / SROA pass.
::std::optional<::mlir::PromotableAllocationOpInterface > handlePromotionComplete(const ::mlir::MemorySlot &slot, ::mlir::Value defaultValue, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
::std::optional<::mlir::DestructurableAllocationOpInterface > handleDestructuringComplete(const ::mlir::DestructurableMemorySlot &slot, ::mlir::OpBuilder &builder)
Required by DestructurableAllocationOpInterface / SROA pass.
::llvm::LogicalResult verify()
::mlir::Value getDefaultValue(const ::mlir::MemorySlot &slot, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
::llvm::SmallVector<::mlir::MemorySlot > getPromotableSlots()
Required by PromotableAllocationOpInterface / mem2reg pass.
::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot > destructure(const ::mlir::DestructurableMemorySlot &slot, const ::llvm::SmallPtrSetImpl<::mlir::Attribute > &usedIndices, ::mlir::OpBuilder &builder, ::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface > &newAllocators)
Required by DestructurableAllocationOpInterface / SROA pass.
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
void handleBlockArgument(const ::mlir::MemorySlot &slot, ::mlir::BlockArgument argument, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
::mlir::OperandRangeRange getMapOperands()
::mlir::Operation::operand_range getElements()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::Operation::operand_range getIndices()
::llvm::LogicalResult verify()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::mlir::DeletionKind removeBlockingUses(const ::mlir::MemorySlot &slot, const ::llvm::SmallPtrSetImpl< mlir::OpOperand * > &blockingUses, ::mlir::OpBuilder &builder, ::mlir::Value reachingDefinition, const ::mlir::DataLayout &dataLayout)
Required by PromotableMemOpInterface / mem2reg pass.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::TypedValue<::mlir::Type > getResult()
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
::mlir::Operation::operand_range getIndices()
::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location > location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type > &inferredReturnTypes)
::llvm::LogicalResult verify()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static constexpr ::llvm::StringLiteral getOperationName()
static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r)
bool canUsesBeRemoved(const ::mlir::MemorySlot &slot, const ::llvm::SmallPtrSetImpl<::mlir::OpOperand * > &blockingUses, ::llvm::SmallVectorImpl<::mlir::OpOperand * > &newBlockingUses, const ::mlir::DataLayout &datalayout)
Required by PromotableMemOpInterface / mem2reg pass.
::llvm::LogicalResult verify()
::mlir::Operation::operand_range getIndices()
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
bool canUsesBeRemoved(const ::mlir::MemorySlot &slot, const ::llvm::SmallPtrSetImpl<::mlir::OpOperand * > &blockingUses, ::llvm::SmallVectorImpl<::mlir::OpOperand * > &newBlockingUses, const ::mlir::DataLayout &datalayout)
Required by PromotableMemOpInterface / mem2reg pass.
::mlir::DeletionKind removeBlockingUses(const ::mlir::MemorySlot &slot, const ::llvm::SmallPtrSetImpl< mlir::OpOperand * > &blockingUses, ::mlir::OpBuilder &builder, ::mlir::Value reachingDefinition, const ::mlir::DataLayout &dataLayout)
Required by PromotableMemOpInterface / mem2reg pass.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::mlir::TypedValue<::mlir::Type > getRvalue()
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
bool singletonTypeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
constexpr T checkedCast(U u) noexcept
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...