66#include <mlir/IR/BuiltinOps.h>
67#include <mlir/Pass/PassManager.h>
68#include <mlir/Transforms/DialectConversion.h>
69#include <mlir/Transforms/Passes.h>
71#include <llvm/Support/Debug.h>
75#define GEN_PASS_DEF_ARRAYTOSCALARPASS
85#define DEBUG_TYPE "llzk-array-to-scalar"
90inline ArrayType splittableArray(
ArrayType at) {
return at.hasStaticShape() ? at :
nullptr; }
94 if (
ArrayType at = dyn_cast<ArrayType>(t)) {
95 return splittableArray(at);
102inline bool containsSplittableArrayType(Type t) {
105 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
110template <
typename T>
bool containsSplittableArrayType(ValueTypeRange<T> types) {
111 for (Type t : types) {
112 if (containsSplittableArrayType(t)) {
121size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
127 collect.push_back(t);
133template <
typename TypeCollection>
134inline void splitArrayTypeTo(
135 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
137 for (Type t : types) {
138 size_t count = splitArrayTypeTo(t, collect);
139 if (originalIdxToSize) {
140 originalIdxToSize->push_back(count);
147template <
typename TypeCollection>
148inline SmallVector<Type>
149splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize =
nullptr) {
150 SmallVector<Type> collect;
151 splitArrayTypeTo(types, collect, originalIdxToSize);
157SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
158 SmallVector<Value> operands;
159 for (Attribute a : index) {
161 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
162 assert(ia && ia.getType().isIndex());
163 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
169genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
170 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
171 return rewriter.create<
WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
177CallOp newCallOpWithSplitResults(
180 OpBuilder::InsertionGuard guard(rewriter);
181 rewriter.setInsertionPointAfter(oldCall);
183 Operation::result_range oldResults = oldCall.getResults();
185 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.
getCallee(),
189 auto newResults = newCall.getResults().begin();
190 for (Value oldVal : oldResults) {
191 if (
ArrayType at = splittableArray(oldVal.getType())) {
192 Location loc = oldVal.getLoc();
195 rewriter.replaceAllUsesWith(oldVal, newArray);
201 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
202 for (ArrayAttr subIdx : allIndices.value()) {
203 genWrite(loc, newArray, subIdx, *newResults, rewriter);
211 rewriter.eraseOp(oldCall);
217genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
218 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
219 return rewriter.create<
ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
224void processInputOperand(
225 Location loc, Value operand, SmallVector<Value> &newOperands,
226 ConversionPatternRewriter &rewriter
228 if (
ArrayType at = splittableArray(operand.getType())) {
230 assert(indices.has_value() &&
"passed earlier hasStaticShape() check");
231 for (ArrayAttr index : indices.value()) {
232 newOperands.push_back(genRead(loc, operand, index, rewriter));
235 newOperands.push_back(operand);
240void processInputOperands(
241 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
242 ConversionPatternRewriter &rewriter
244 SmallVector<Value> newOperands;
245 for (Value v : operands) {
246 processInputOperand(op->getLoc(), v, newOperands, rewriter);
248 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
249 outputOpRef.assign(ValueRange(newOperands));
255enum Direction : std::uint8_t {
264template <Direction dir>
265inline void rewriteImpl(
267 ConversionPatternRewriter &rewriter
270 Location loc = op.getLoc();
271 MLIRContext *ctx = op.getContext();
280 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
281 for (ArrayAttr indexingTail : subIndices.value()) {
282 SmallVector<Attribute> joined;
283 joined.append(indexAsAttr.begin(), indexAsAttr.end());
284 joined.append(indexingTail.begin(), indexingTail.end());
285 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
287 if constexpr (dir == Direction::SMALL_TO_LARGE) {
288 auto init = genRead(loc, smallArr, indexingTail, rewriter);
289 genWrite(loc, largeArr, fullIndex, init, rewriter);
290 }
else if constexpr (dir == Direction::LARGE_TO_SMALL) {
291 auto init = genRead(loc, largeArr, fullIndex, rewriter);
292 genWrite(loc, smallArr, indexingTail, init, rewriter);
299class SplitInsertArrayOp :
public OpConversionPattern<InsertArrayOp> {
301 using OpConversionPattern<
InsertArrayOp>::OpConversionPattern;
304 return !containsSplittableArrayType(op.
getRvalue().getType());
307 LogicalResult match(
InsertArrayOp op)
const override {
return failure(legal(op)); }
310 rewrite(
InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
312 rewriteImpl<SMALL_TO_LARGE>(
313 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
314 adaptor.getArrRef(), rewriter
316 rewriter.eraseOp(op);
320class SplitExtractArrayOp :
public OpConversionPattern<ExtractArrayOp> {
325 return !containsSplittableArrayType(op.
getResult().getType());
328 LogicalResult match(
ExtractArrayOp op)
const override {
return failure(legal(op)); }
331 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
335 auto newArray = rewriter.replaceOpWithNewOp<
CreateArrayOp>(op, at);
336 rewriteImpl<LARGE_TO_SMALL>(
337 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
343class SplitInitFromCreateArrayOp :
public OpConversionPattern<CreateArrayOp> {
345 using OpConversionPattern<
CreateArrayOp>::OpConversionPattern;
349 LogicalResult match(
CreateArrayOp op)
const override {
return failure(legal(op)); }
352 rewrite(
CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
356 rewriter.setInsertionPointAfter(op);
357 Location loc = op.getLoc();
359 for (
auto [i, init] : llvm::enumerate(adaptor.getElements())) {
361 std::optional<SmallVector<Value>> multiDimIdxVals =
365 assert(multiDimIdxVals.has_value());
372class SplitArrayInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
374 using OpConversionPattern<
FuncDefOp>::OpConversionPattern;
383 static ArrayAttr replicateAttributesAsNeeded(
384 ArrayAttr origAttrs,
const SmallVector<size_t> &originalIdxToSize,
385 const SmallVector<Type> &newTypes
388 assert(originalIdxToSize.size() == origAttrs.size());
389 if (originalIdxToSize.size() != newTypes.size()) {
390 SmallVector<Attribute> newArgAttrs;
391 for (
auto [i, s] : llvm::enumerate(originalIdxToSize)) {
392 newArgAttrs.append(s, origAttrs[i]);
394 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
400 LogicalResult match(
FuncDefOp op)
const override {
return failure(legal(op)); }
402 void rewrite(
FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
405 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
408 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
409 return splitArrayType(origTypes, &originalInputIdxToSize);
411 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
412 return splitArrayType(origTypes, &originalResultIdxToSize);
414 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
415 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
417 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
418 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
425 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
426 OpBuilder::InsertionGuard guard(rewriter);
427 rewriter.setInsertionPointToStart(&entryBlock);
429 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
430 Value oldV = entryBlock.getArgument(i);
431 if (
ArrayType at = splittableArray(oldV.getType())) {
432 Location loc = oldV.getLoc();
435 rewriter.replaceAllUsesWith(oldV, newArray);
437 entryBlock.eraseArgument(i);
442 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
443 for (ArrayAttr subIdx : allIndices.value()) {
444 BlockArgument newArg = entryBlock.insertArgument(i, at.
getElementType(), loc);
445 genWrite(loc, newArray, subIdx, newArg, rewriter);
454 Impl().convert(op, rewriter);
458class SplitArrayInReturnOp :
public OpConversionPattern<ReturnOp> {
460 using OpConversionPattern<
ReturnOp>::OpConversionPattern;
462 inline static bool legal(
ReturnOp op) {
463 return !containsSplittableArrayType(op.
getOperands().getTypes());
466 LogicalResult match(
ReturnOp op)
const override {
return failure(legal(op)); }
468 void rewrite(
ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
473class SplitArrayInCallOp :
public OpConversionPattern<CallOp> {
475 using OpConversionPattern<
CallOp>::OpConversionPattern;
477 inline static bool legal(
CallOp op) {
478 return !containsSplittableArrayType(op.
getArgOperands().getTypes()) &&
479 !containsSplittableArrayType(op.getResultTypes());
482 LogicalResult match(
CallOp op)
const override {
return failure(legal(op)); }
484 void rewrite(
CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
488 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
489 processInputOperands(
495class ReplaceKnownArrayLengthOp :
public OpConversionPattern<ArrayLengthOp> {
497 using OpConversionPattern<
ArrayLengthOp>::OpConversionPattern;
500 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx,
ArrayType baseArrType) {
501 if (splittableArray(baseArrType)) {
503 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
506 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
519 LogicalResult match(
ArrayLengthOp op)
const override {
return failure(legal(op)); }
522 rewrite(
ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
523 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
525 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
526 assert(len.has_value());
527 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
llzk::fromAPInt(len.value()));
/// Describes one scalar member that replaces an element of a split array member:
/// `first` is the new member's symbol name and `second` is its (element) type.
532using MemberInfo = std::pair<StringAttr, Type>;
/// For one original array member, maps each element index (as an ArrayAttr of
/// index attributes) to the scalar member created for that element.
534using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
/// Per struct, per original array member name, the element-to-scalar-member map.
/// Step 1 (member definition splitting) fills it in; step 2 (member read/write
/// rewriting) reads it.
536using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
538class SplitArrayInMemberDefOp :
public OpConversionPattern<MemberDefOp> {
539 SymbolTableCollection &tables;
540 MemberReplacementMap &repMapRef;
543 SplitArrayInMemberDefOp(
544 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
546 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
548 inline static bool legal(
MemberDefOp op) {
return !containsSplittableArrayType(op.
getType()); }
550 LogicalResult match(
MemberDefOp op)
const override {
return failure(legal(op)); }
552 void rewrite(
MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
555 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
560 assert(subIdxs.has_value());
563 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
564 for (ArrayAttr idx : subIdxs.value()) {
571 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
573 rewriter.eraseOp(op);
584class SplitArrayInMemberRefOp :
public OpConversionPattern<MemberRefOpClass> {
585 SymbolTableCollection &tables;
586 const MemberReplacementMap &repMapRef;
589 inline static void ensureImplementedAtCompile() {
591 sizeof(MemberRefOpClass) == 0,
"SplitArrayInMemberRefOp not implemented for requested type."
596 using OpAdaptor =
typename MemberRefOpClass::Adaptor;
600 static GenHeaderType genHeader(MemberRefOpClass, ConversionPatternRewriter &) {
601 ensureImplementedAtCompile();
602 llvm_unreachable(
"must have concrete instantiation");
608 forIndex(Location, GenHeaderType, ArrayAttr, MemberInfo, OpAdaptor, ConversionPatternRewriter &) {
609 ensureImplementedAtCompile();
610 llvm_unreachable(
"must have concrete instantiation");
616 SplitArrayInMemberRefOp(
617 MLIRContext *ctx, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap
619 : OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables), repMapRef(memberRepMap) {}
621 static bool legal(MemberRefOpClass) {
622 ensureImplementedAtCompile();
623 llvm_unreachable(
"must have concrete instantiation");
627 LogicalResult match(MemberRefOpClass op)
const override {
return failure(ImplClass::legal(op)); }
630 MemberRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
632 StructType tgtStructTy = llvm::cast<MemberRefOpInterface>(op.getOperation()).getStructType();
635 assert(succeeded(tgtStructDef));
637 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
639 const LocalMemberReplacementMap &idxToName =
640 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
642 for (
auto [idx, newMember] : idxToName) {
643 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newMember, adaptor, rewriter);
645 rewriter.eraseOp(op);
649class SplitArrayInMemberWriteOp
650 :
public SplitArrayInMemberRefOp<SplitArrayInMemberWriteOp, MemberWriteOp, void *> {
652 using SplitArrayInMemberRefOp<
653 SplitArrayInMemberWriteOp,
MemberWriteOp,
void *>::SplitArrayInMemberRefOp;
656 return !containsSplittableArrayType(op.
getVal().getType());
659 static void *genHeader(
MemberWriteOp, ConversionPatternRewriter &) {
return nullptr; }
661 static void forIndex(
662 Location loc,
void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter
665 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
667 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
672class SplitArrayInMemberReadOp
673 :
public SplitArrayInMemberRefOp<SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp> {
675 using SplitArrayInMemberRefOp<
679 return !containsSplittableArrayType(op.getResult().getType());
684 rewriter.create<
CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
685 rewriter.replaceAllUsesWith(op, newArray);
689 static void forIndex(
690 Location loc,
CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter
694 loc, newMember.second, adaptor.getComponent(), newMember.first
696 genWrite(loc, newArray, idx, scalarRead, rewriter);
700static void baseTargetSetup(ConversionTarget &target) {
701 target.addLegalDialect<
706 target.addLegalOp<ModuleOp>();
710step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
711 MLIRContext *ctx = modOp.getContext();
713 RewritePatternSet patterns(ctx);
715 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
717 ConversionTarget target(*ctx);
718 baseTargetSetup(target);
719 target.addDynamicallyLegalOp<
MemberDefOp>(SplitArrayInMemberDefOp::legal);
721 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split array members\n";);
722 return applyFullConversion(modOp, target, std::move(patterns));
726step2(ModuleOp modOp, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap) {
727 MLIRContext *ctx = modOp.getContext();
729 RewritePatternSet patterns(ctx);
732 SplitInitFromCreateArrayOp,
735 SplitArrayInFuncDefOp,
736 SplitArrayInReturnOp,
738 ReplaceKnownArrayLengthOp
744 SplitArrayInMemberWriteOp,
745 SplitArrayInMemberReadOp
747 >(ctx, symTables, memberRepMap);
749 ConversionTarget target(*ctx);
750 baseTargetSetup(target);
751 target.addDynamicallyLegalOp<
CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
752 target.addDynamicallyLegalOp<
InsertArrayOp>(SplitInsertArrayOp::legal);
753 target.addDynamicallyLegalOp<
ExtractArrayOp>(SplitExtractArrayOp::legal);
754 target.addDynamicallyLegalOp<
FuncDefOp>(SplitArrayInFuncDefOp::legal);
755 target.addDynamicallyLegalOp<
ReturnOp>(SplitArrayInReturnOp::legal);
756 target.addDynamicallyLegalOp<
CallOp>(SplitArrayInCallOp::legal);
757 target.addDynamicallyLegalOp<
ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
758 target.addDynamicallyLegalOp<
MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
759 target.addDynamicallyLegalOp<
MemberReadOp>(SplitArrayInMemberReadOp::legal);
761 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other array ops\n";);
762 return applyFullConversion(modOp, target, std::move(patterns));
765LogicalResult splitArrayCreateInit(ModuleOp modOp) {
766 SymbolTableCollection symTables;
767 MemberReplacementMap memberRepMap;
773 if (failed(step1(modOp, symTables, memberRepMap))) {
777 llvm::dbgs() <<
"After step 1:\n";
780 if (failed(step2(modOp, symTables, memberRepMap))) {
784 llvm::dbgs() <<
"After step 2:\n";
791 void runOnOperation()
override {
792 ModuleOp module = getOperation();
795 if (failed(splitArrayCreateInit(module))) {
799 OpPassManager nestedPM(ModuleOp::getOperationName());
803 nestedPM.addPass(createSROA());
805 nestedPM.addPass(createMem2Reg());
807 nestedPM.addPass(createRemoveDeadValuesPass());
808 if (failed(runPipeline(nestedPM, module))) {
813 llvm::dbgs() <<
"After SROA+Mem2Reg pipeline:\n";
822 return std::make_unique<ArrayToScalarPass>();
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::TypedValue<::mlir::IndexType > getDim()
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::MutableOperandRange getElementsMutable()
::mlir::Operation::operand_range getElements()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
void setPublicAttr(bool newValue=true)
::mlir::StringAttr getSymNameAttr()
::mlir::TypedValue<::mlir::Type > getVal()
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::SymbolRefAttr getCallee()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::Operation::operand_range getArgOperands()
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
::mlir::FunctionType getFunctionType()
::mlir::MutableOperandRange getOperandsMutable()
::mlir::Operation::operand_range getOperands()
Restricts a template parameter to Op classes that implement the given OpInterface.
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr T checkedCast(U u) noexcept
int64_t fromAPInt(const llvm::APInt &i)