65#include <mlir/IR/BuiltinOps.h>
66#include <mlir/Pass/PassManager.h>
67#include <mlir/Transforms/DialectConversion.h>
68#include <mlir/Transforms/Passes.h>
70#include <llvm/Support/Debug.h>
74#define GEN_PASS_DEF_ARRAYTOSCALARPASS
84#define DEBUG_TYPE "llzk-array-to-scalar"
// Returns `at` itself when its shape is fully static, i.e. it can be split into a
// fixed number of scalar elements; otherwise returns a null ArrayType. Callers use
// the null result as "false" in patterns such as
// `if (ArrayType at = splittableArray(...))`.
89inline ArrayType splittableArray(
ArrayType at) {
return at.hasStaticShape() ? at :
nullptr; }
93 if (
ArrayType at = dyn_cast<ArrayType>(t)) {
94 return splittableArray(at);
101inline bool containsSplittableArrayType(Type t) {
104 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
109template <
typename T>
bool containsSplittableArrayType(ValueTypeRange<T> types) {
110 for (Type t : types) {
111 if (containsSplittableArrayType(t)) {
120size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
122 int64_t n = at.getNumElements();
124 assert(std::cmp_less_equal(n, std::numeric_limits<size_t>::max()));
129 collect.push_back(t);
135template <
typename TypeCollection>
136inline void splitArrayTypeTo(
137 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
139 for (Type t : types) {
140 size_t count = splitArrayTypeTo(t, collect);
141 if (originalIdxToSize) {
142 originalIdxToSize->push_back(count);
149template <
typename TypeCollection>
150inline SmallVector<Type>
151splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize =
nullptr) {
152 SmallVector<Type> collect;
153 splitArrayTypeTo(types, collect, originalIdxToSize);
159SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
160 SmallVector<Value> operands;
161 for (Attribute a : index) {
163 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
164 assert(ia && ia.getType().isIndex());
165 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
171genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
172 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
173 return rewriter.create<
WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
179CallOp newCallOpWithSplitResults(
182 OpBuilder::InsertionGuard guard(rewriter);
183 rewriter.setInsertionPointAfter(oldCall);
185 Operation::result_range oldResults = oldCall.getResults();
187 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.
getCallee(),
191 auto newResults = newCall.getResults().begin();
192 for (Value oldVal : oldResults) {
193 if (
ArrayType at = splittableArray(oldVal.getType())) {
194 Location loc = oldVal.getLoc();
197 rewriter.replaceAllUsesWith(oldVal, newArray);
203 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
204 for (ArrayAttr subIdx : allIndices.value()) {
205 genWrite(loc, newArray, subIdx, *newResults, rewriter);
213 rewriter.eraseOp(oldCall);
219genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
220 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
221 return rewriter.create<
ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
226void processInputOperand(
227 Location loc, Value operand, SmallVector<Value> &newOperands,
228 ConversionPatternRewriter &rewriter
230 if (
ArrayType at = splittableArray(operand.getType())) {
232 assert(indices.has_value() &&
"passed earlier hasStaticShape() check");
233 for (ArrayAttr index : indices.value()) {
234 newOperands.push_back(genRead(loc, operand, index, rewriter));
237 newOperands.push_back(operand);
242void processInputOperands(
243 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
244 ConversionPatternRewriter &rewriter
246 SmallVector<Value> newOperands;
247 for (Value v : operands) {
248 processInputOperand(op->getLoc(), v, newOperands, rewriter);
250 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
251 outputOpRef.assign(ValueRange(newOperands));
257enum Direction : std::uint8_t {
266template <Direction dir>
267inline void rewriteImpl(
269 ConversionPatternRewriter &rewriter
272 Location loc = op.getLoc();
273 MLIRContext *ctx = op.getContext();
282 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
283 for (ArrayAttr indexingTail : subIndices.value()) {
284 SmallVector<Attribute> joined;
285 joined.append(indexAsAttr.begin(), indexAsAttr.end());
286 joined.append(indexingTail.begin(), indexingTail.end());
287 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
289 if constexpr (dir == Direction::SMALL_TO_LARGE) {
290 auto init = genRead(loc, smallArr, indexingTail, rewriter);
291 genWrite(loc, largeArr, fullIndex, init, rewriter);
292 }
else if constexpr (dir == Direction::LARGE_TO_SMALL) {
293 auto init = genRead(loc, largeArr, fullIndex, rewriter);
294 genWrite(loc, smallArr, indexingTail, init, rewriter);
301class SplitInsertArrayOp :
public OpConversionPattern<InsertArrayOp> {
303 using OpConversionPattern<
InsertArrayOp>::OpConversionPattern;
306 return !containsSplittableArrayType(op.
getRvalue().getType());
// The pattern applies only to ops that `legal()` rejects. `failure(true)` (the op is
// already legal) tells the conversion driver to skip this op.
// NOTE(review): newer MLIR releases dropped separate `match`/`rewrite` overrides in
// favor of `matchAndRewrite`; confirm against the MLIR version in use.
309 LogicalResult match(
InsertArrayOp op)
const override {
return failure(legal(op)); }
312 rewrite(
InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
314 rewriteImpl<SMALL_TO_LARGE>(
315 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
316 adaptor.getArrRef(), rewriter
318 rewriter.eraseOp(op);
322class SplitExtractArrayOp :
public OpConversionPattern<ExtractArrayOp> {
327 return !containsSplittableArrayType(op.
getResult().getType());
// Matches only ExtractArrayOps whose result still contains a splittable array type;
// ops that are already legal yield failure() and are left alone.
330 LogicalResult match(
ExtractArrayOp op)
const override {
return failure(legal(op)); }
333 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
337 auto newArray = rewriter.replaceOpWithNewOp<
CreateArrayOp>(op, at);
338 rewriteImpl<LARGE_TO_SMALL>(
339 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
345class SplitInitFromCreateArrayOp :
public OpConversionPattern<CreateArrayOp> {
347 using OpConversionPattern<
CreateArrayOp>::OpConversionPattern;
// Matches only CreateArrayOps that `legal()` rejects; legal ops yield failure() so
// the driver does not rewrite them.
351 LogicalResult match(
CreateArrayOp op)
const override {
return failure(legal(op)); }
354 rewrite(
CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
358 rewriter.setInsertionPointAfter(op);
359 Location loc = op.getLoc();
361 for (
auto [i, init] : llvm::enumerate(adaptor.getElements())) {
363 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
364 std::optional<SmallVector<Value>> multiDimIdxVals =
365 idxGen.
delinearize(
static_cast<int64_t
>(i), loc, rewriter);
368 assert(multiDimIdxVals.has_value());
375class SplitArrayInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
377 using OpConversionPattern<
FuncDefOp>::OpConversionPattern;
386 static ArrayAttr replicateAttributesAsNeeded(
387 ArrayAttr origAttrs,
const SmallVector<size_t> &originalIdxToSize,
388 const SmallVector<Type> &newTypes
391 assert(originalIdxToSize.size() == origAttrs.size());
392 if (originalIdxToSize.size() != newTypes.size()) {
393 SmallVector<Attribute> newArgAttrs;
394 for (
auto [i, s] : llvm::enumerate(originalIdxToSize)) {
395 newArgAttrs.append(s, origAttrs[i]);
397 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
// Matches only FuncDefOps that `legal()` rejects, i.e. functions whose signature
// still needs its splittable array inputs/results expanded into scalars.
403 LogicalResult match(
FuncDefOp op)
const override {
return failure(legal(op)); }
405 void rewrite(
FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
408 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
411 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
412 return splitArrayType(origTypes, &originalInputIdxToSize);
414 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
415 return splitArrayType(origTypes, &originalResultIdxToSize);
417 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
418 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
420 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
421 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
428 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
429 OpBuilder::InsertionGuard guard(rewriter);
430 rewriter.setInsertionPointToStart(&entryBlock);
432 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
433 Value oldV = entryBlock.getArgument(i);
434 if (
ArrayType at = splittableArray(oldV.getType())) {
435 Location loc = oldV.getLoc();
438 rewriter.replaceAllUsesWith(oldV, newArray);
440 entryBlock.eraseArgument(i);
445 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
446 for (ArrayAttr subIdx : allIndices.value()) {
447 BlockArgument newArg = entryBlock.insertArgument(i, at.
getElementType(), loc);
448 genWrite(loc, newArray, subIdx, newArg, rewriter);
457 Impl().convert(op, rewriter);
461class SplitArrayInReturnOp :
public OpConversionPattern<ReturnOp> {
463 using OpConversionPattern<
ReturnOp>::OpConversionPattern;
465 inline static bool legal(
ReturnOp op) {
466 return !containsSplittableArrayType(op.
getOperands().getTypes());
// Matches only ReturnOps with at least one splittable-array operand (see `legal`);
// all other returns yield failure() and are not rewritten.
469 LogicalResult match(
ReturnOp op)
const override {
return failure(legal(op)); }
471 void rewrite(
ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
476class SplitArrayInCallOp :
public OpConversionPattern<CallOp> {
478 using OpConversionPattern<
CallOp>::OpConversionPattern;
480 inline static bool legal(
CallOp op) {
481 return !containsSplittableArrayType(op.
getArgOperands().getTypes()) &&
482 !containsSplittableArrayType(op.getResultTypes());
// Matches only CallOps whose argument or result types contain a splittable array
// (see `legal`); legal calls yield failure() and are skipped.
485 LogicalResult match(
CallOp op)
const override {
return failure(legal(op)); }
487 void rewrite(
CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
491 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
492 processInputOperands(
498class ReplaceKnownArrayLengthOp :
public OpConversionPattern<ArrayLengthOp> {
500 using OpConversionPattern<
ArrayLengthOp>::OpConversionPattern;
503 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx,
ArrayType baseArrType) {
504 if (splittableArray(baseArrType)) {
506 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
507 uint64_t idx64 = idxAP.getZExtValue();
508 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
509 Attribute dimSizeAttr = baseArrType.
getDimensionSizes()[
static_cast<size_t>(idx64)];
510 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
// Matches only ArrayLengthOps that `legal()` rejects. `rewrite` asserts that the
// dimension size is then known, so `legal` must reject only ops whose length can be
// resolved statically.
523 LogicalResult match(
ArrayLengthOp op)
const override {
return failure(legal(op)); }
526 rewrite(
ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
527 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
529 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
530 assert(len.has_value());
531 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
llzk::fromAPInt(len.value()));
// Describes one scalar member created to replace an array member: the symbol name
// returned by inserting the new MemberDefOp, paired with its element type.
536using MemberInfo = std::pair<StringAttr, Type>;
// For a single original array member: maps each element index (ArrayAttr of index
// constants) to the scalar member that now holds that element.
538using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
// Per struct, per original array member name: the index -> scalar member map.
// Step 1 fills it in; step 2 reads it to rewrite member reads and writes.
540using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
542class SplitArrayInMemberDefOp :
public OpConversionPattern<MemberDefOp> {
543 SymbolTableCollection &tables;
544 MemberReplacementMap &repMapRef;
547 SplitArrayInMemberDefOp(
548 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
550 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
// A member definition is legal (needs no splitting) when its declared type contains
// no statically-shaped array type. This is also used as the dynamic-legality
// callback in step 1.
552 inline static bool legal(
MemberDefOp op) {
return !containsSplittableArrayType(op.
getType()); }
// Matches only MemberDefOps whose type contains a splittable array; legal members
// yield failure() and are not rewritten.
554 LogicalResult match(
MemberDefOp op)
const override {
return failure(legal(op)); }
556 void rewrite(
MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
559 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
564 assert(subIdxs.has_value());
567 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
568 for (ArrayAttr idx : subIdxs.value()) {
575 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
577 rewriter.eraseOp(op);
588class SplitArrayInMemberRefOp :
public OpConversionPattern<MemberRefOpClass> {
589 SymbolTableCollection &tables;
590 const MemberReplacementMap &repMapRef;
593 inline static void ensureImplementedAtCompile() {
595 sizeof(MemberRefOpClass) == 0,
"SplitArrayInMemberRefOp not implemented for requested type."
600 using OpAdaptor =
typename MemberRefOpClass::Adaptor;
604 static GenHeaderType genHeader(MemberRefOpClass, ConversionPatternRewriter &) {
605 ensureImplementedAtCompile();
606 llvm_unreachable(
"must have concrete instantiation");
612 forIndex(Location, GenHeaderType, ArrayAttr, MemberInfo, OpAdaptor, ConversionPatternRewriter &) {
613 ensureImplementedAtCompile();
614 llvm_unreachable(
"must have concrete instantiation");
618 SplitArrayInMemberRefOp(
619 MLIRContext *ctx, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap
621 : OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables), repMapRef(memberRepMap) {}
623 static bool legal(MemberRefOpClass) {
624 ensureImplementedAtCompile();
625 llvm_unreachable(
"must have concrete instantiation");
// CRTP dispatch: legality is decided by the concrete subclass (`ImplClass::legal`).
// The generic base `legal` is intentionally unreachable.
629 LogicalResult match(MemberRefOpClass op)
const override {
return failure(ImplClass::legal(op)); }
632 MemberRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
634 StructType tgtStructTy = llvm::cast<MemberRefOpInterface>(op.getOperation()).getStructType();
637 assert(succeeded(tgtStructDef));
639 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
641 const LocalMemberReplacementMap &idxToName =
642 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
644 for (
auto [idx, newMember] : idxToName) {
645 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newMember, adaptor, rewriter);
647 rewriter.eraseOp(op);
651class SplitArrayInMemberWriteOp
652 :
public SplitArrayInMemberRefOp<SplitArrayInMemberWriteOp, MemberWriteOp, void *> {
654 using SplitArrayInMemberRefOp<
655 SplitArrayInMemberWriteOp,
MemberWriteOp,
void *>::SplitArrayInMemberRefOp;
658 return !containsSplittableArrayType(op.
getVal().getType());
// A MemberWriteOp needs no shared setup before the per-index scalar writes, so the
// header is a null placeholder; `forIndex` ignores its `void *` argument.
661 static void *genHeader(
MemberWriteOp, ConversionPatternRewriter &) {
return nullptr; }
663 static void forIndex(
664 Location loc,
void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
665 ConversionPatternRewriter &rewriter
667 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
669 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
674class SplitArrayInMemberReadOp
675 :
public SplitArrayInMemberRefOp<SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp> {
677 using SplitArrayInMemberRefOp<
681 return !containsSplittableArrayType(op.getResult().getType());
686 rewriter.create<
CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
687 rewriter.replaceAllUsesWith(op, newArray);
691 static void forIndex(
692 Location loc,
CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter
696 loc, newMember.second, adaptor.getComponent(), newMember.first
698 genWrite(loc, newArray, idx, scalarRead, rewriter);
703step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
704 MLIRContext *ctx = modOp.getContext();
706 RewritePatternSet patterns(ctx);
708 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
710 ConversionTarget target(*ctx);
711 target.addLegalDialect<
716 target.addLegalOp<ModuleOp>();
717 target.addDynamicallyLegalOp<
MemberDefOp>(SplitArrayInMemberDefOp::legal);
719 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split array members\n";);
720 return applyFullConversion(modOp, target, std::move(patterns));
724step2(ModuleOp modOp, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap) {
725 MLIRContext *ctx = modOp.getContext();
727 RewritePatternSet patterns(ctx);
730 SplitInitFromCreateArrayOp,
733 SplitArrayInFuncDefOp,
734 SplitArrayInReturnOp,
736 ReplaceKnownArrayLengthOp
742 SplitArrayInMemberWriteOp,
743 SplitArrayInMemberReadOp
745 >(ctx, symTables, memberRepMap);
747 ConversionTarget target(*ctx);
748 target.addLegalDialect<
752 target.addLegalOp<ModuleOp>();
753 target.addDynamicallyLegalOp<
CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
754 target.addDynamicallyLegalOp<
InsertArrayOp>(SplitInsertArrayOp::legal);
755 target.addDynamicallyLegalOp<
ExtractArrayOp>(SplitExtractArrayOp::legal);
756 target.addDynamicallyLegalOp<
FuncDefOp>(SplitArrayInFuncDefOp::legal);
757 target.addDynamicallyLegalOp<
ReturnOp>(SplitArrayInReturnOp::legal);
758 target.addDynamicallyLegalOp<
CallOp>(SplitArrayInCallOp::legal);
759 target.addDynamicallyLegalOp<
ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
760 target.addDynamicallyLegalOp<
MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
761 target.addDynamicallyLegalOp<
MemberReadOp>(SplitArrayInMemberReadOp::legal);
763 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other array ops\n";);
764 return applyFullConversion(modOp, target, std::move(patterns));
767LogicalResult splitArrayCreateInit(ModuleOp modOp) {
768 SymbolTableCollection symTables;
769 MemberReplacementMap memberRepMap;
775 if (failed(step1(modOp, symTables, memberRepMap))) {
779 llvm::dbgs() <<
"After step 1:\n";
782 if (failed(step2(modOp, symTables, memberRepMap))) {
786 llvm::dbgs() <<
"After step 2:\n";
793 void runOnOperation()
override {
794 ModuleOp module = getOperation();
797 if (failed(splitArrayCreateInit(module))) {
801 OpPassManager nestedPM(ModuleOp::getOperationName());
805 nestedPM.addPass(createSROA());
807 nestedPM.addPass(createMem2Reg());
809 nestedPM.addPass(createRemoveDeadValuesPass());
810 if (failed(runPipeline(nestedPM, module))) {
815 llvm::dbgs() <<
"After SROA+Mem2Reg pipeline:\n";
824 return std::make_unique<ArrayToScalarPass>();
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::TypedValue<::mlir::IndexType > getDim()
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::MutableOperandRange getElementsMutable()
::mlir::Operation::operand_range getElements()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
void setPublicAttr(bool newValue=true)
::mlir::StringAttr getSymNameAttr()
::mlir::TypedValue<::mlir::Type > getVal()
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::SymbolRefAttr getCallee()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::Operation::operand_range getArgOperands()
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
::mlir::FunctionType getFunctionType()
::mlir::MutableOperandRange getOperandsMutable()
::mlir::Operation::operand_range getOperands()
Restricts a template parameter to Op classes that implement the given OpInterface.
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(const llvm::APInt &i)