89#include <mlir/Dialect/SCF/IR/SCF.h>
90#include <mlir/IR/BuiltinOps.h>
91#include <mlir/Pass/PassManager.h>
92#include <mlir/Transforms/DialectConversion.h>
93#include <mlir/Transforms/Passes.h>
95#include <llvm/Support/Debug.h>
101#define GEN_PASS_DEF_ARRAYTOSCALARPASS
111#define DEBUG_TYPE "llzk-array-to-scalar"
116inline ArrayType splittableArray(
ArrayType at) {
return at.hasStaticShape() ? at :
nullptr; }
119inline ArrayType splittableArray(Type t) {
120 if (
ArrayType at = dyn_cast<ArrayType>(t)) {
121 return splittableArray(at);
129inline bool containsSplittableArrayType(ArrayRef<Type> types) {
130 for (Type t : types) {
131 if (splittableArray(t)) {
139template <
typename T>
bool containsSplittableArrayType(ValueTypeRange<T> types) {
140 for (Type t : types) {
141 if (splittableArray(t)) {
150size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
156 collect.push_back(t);
162template <
typename TypeCollection>
163inline void splitArrayTypeTo(
164 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
166 for (Type t : types) {
167 size_t count = splitArrayTypeTo(t, collect);
168 if (originalIdxToSize) {
169 originalIdxToSize->push_back(count);
176template <
typename TypeCollection>
177inline SmallVector<Type>
178splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize =
nullptr) {
179 SmallVector<Type> collect;
180 splitArrayTypeTo(types, collect, originalIdxToSize);
186SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
187 SmallVector<Value> operands;
188 for (Attribute a : index) {
190 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
191 assert(ia && ia.getType().isIndex());
192 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
199genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
200 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
201 return rewriter.create<
WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
205static std::string formatSplitArrayIndexSuffix(ArrayAttr index) {
207 llvm::raw_string_ostream os(suffix);
208 for (Attribute attr : index) {
210 attr.print(os,
true);
217static SmallVector<std::string> getSplitArrayIndexSuffixes(Type type) {
218 SmallVector<std::string> suffixes;
219 if (
ArrayType at = splittableArray(type)) {
221 assert(indices.has_value() &&
"static-shape arrays must provide subelement indices");
222 suffixes.reserve(indices->size());
223 for (ArrayAttr index : *indices) {
224 suffixes.push_back(formatSplitArrayIndexSuffix(index));
231CallOp newCallOpWithSplitResults(
234 OpBuilder::InsertionGuard guard(rewriter);
235 rewriter.setInsertionPointAfter(oldCall);
237 Operation::result_range oldResults = oldCall.getResults();
239 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall, adaptor.
getMapOperands(),
243 auto newResults = newCall.getResults().begin();
244 for (Value oldVal : oldResults) {
245 if (
ArrayType at = splittableArray(oldVal.getType())) {
246 Location loc = oldVal.getLoc();
249 rewriter.replaceAllUsesWith(oldVal, newArray);
255 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
256 for (ArrayAttr subIdx : allIndices.value()) {
257 genWrite(loc, newArray, subIdx, *newResults, rewriter);
261 rewriter.replaceAllUsesWith(oldVal, *newResults);
266 rewriter.eraseOp(oldCall);
273genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
274 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
275 return rewriter.create<
ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
280void processInputOperand(
281 Location loc, Value operand, SmallVector<Value> &newOperands,
282 ConversionPatternRewriter &rewriter
284 if (
ArrayType at = splittableArray(operand.getType())) {
286 assert(indices.has_value() &&
"passed earlier hasStaticShape() check");
287 for (ArrayAttr index : indices.value()) {
288 newOperands.push_back(genRead(loc, operand, index, rewriter));
291 newOperands.push_back(operand);
296void processInputOperands(
297 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
298 ConversionPatternRewriter &rewriter
300 SmallVector<Value> newOperands;
301 for (Value v : operands) {
302 processInputOperand(op->getLoc(), v, newOperands, rewriter);
304 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
305 outputOpRef.assign(ValueRange(newOperands));
311enum Direction : std::uint8_t {
320template <Direction dir>
321inline void rewriteImpl(
323 ConversionPatternRewriter &rewriter
326 Location loc = op.getLoc();
327 MLIRContext *ctx = op.getContext();
336 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
337 for (ArrayAttr indexingTail : subIndices.value()) {
338 SmallVector<Attribute> joined;
339 joined.append(indexAsAttr.begin(), indexAsAttr.end());
340 joined.append(indexingTail.begin(), indexingTail.end());
341 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
343 if constexpr (dir == Direction::SMALL_TO_LARGE) {
344 auto init = genRead(loc, smallArr, indexingTail, rewriter);
345 genWrite(loc, largeArr, fullIndex, init, rewriter);
346 }
else if constexpr (dir == Direction::LARGE_TO_SMALL) {
347 auto init = genRead(loc, largeArr, fullIndex, rewriter);
348 genWrite(loc, smallArr, indexingTail, init, rewriter);
356class SplitInsertArrayOp :
public OpConversionPattern<InsertArrayOp> {
358 using OpConversionPattern<
InsertArrayOp>::OpConversionPattern;
361 return !containsSplittableArrayType(op.
getRvalue().getType());
364 LogicalResult match(
InsertArrayOp op)
const override {
return failure(legal(op)); }
367 rewrite(
InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
369 rewriteImpl<SMALL_TO_LARGE>(
370 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
371 adaptor.getArrRef(), rewriter
373 rewriter.eraseOp(op);
378class SplitExtractArrayOp :
public OpConversionPattern<ExtractArrayOp> {
383 return !containsSplittableArrayType(op.
getResult().getType());
386 LogicalResult match(
ExtractArrayOp op)
const override {
return failure(legal(op)); }
389 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
393 auto newArray = rewriter.replaceOpWithNewOp<
CreateArrayOp>(op, at);
394 rewriteImpl<LARGE_TO_SMALL>(
395 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
402class SplitInitFromCreateArrayOp :
public OpConversionPattern<CreateArrayOp> {
404 using OpConversionPattern<
CreateArrayOp>::OpConversionPattern;
408 LogicalResult match(
CreateArrayOp op)
const override {
return failure(legal(op)); }
411 rewrite(
CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
415 rewriter.setInsertionPointAfter(op);
416 Location loc = op.getLoc();
418 for (
auto [i, init] : llvm::enumerate(adaptor.getElements())) {
420 std::optional<SmallVector<Value>> multiDimIdxVals =
424 assert(multiDimIdxVals.has_value());
432class SplitArrayInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
434 using OpConversionPattern<
FuncDefOp>::OpConversionPattern;
441 LogicalResult match(
FuncDefOp op)
const override {
return failure(legal(op)); }
443 void rewrite(
FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
446 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
451 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
452 return splitArrayType(origTypes, &originalInputIdxToSize);
454 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
455 return splitArrayType(origTypes, &originalResultIdxToSize);
457 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
464 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
476 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
477 OpBuilder::InsertionGuard guard(rewriter);
478 rewriter.setInsertionPointToStart(&entryBlock);
480 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
481 Value oldV = entryBlock.getArgument(i);
482 if (
ArrayType at = splittableArray(oldV.getType())) {
483 Location loc = oldV.getLoc();
486 rewriter.replaceAllUsesWith(oldV, newArray);
488 entryBlock.eraseArgument(i);
493 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
494 for (ArrayAttr subIdx : allIndices.value()) {
495 BlockArgument newArg = entryBlock.insertArgument(i, at.
getElementType(), loc);
496 genWrite(loc, newArray, subIdx, newArg, rewriter);
507 ArrayAttr resultAttrs = op.getAllResultAttrs();
509 return op.getArgNameAttr(i);
510 }, getSplitArrayIndexSuffixes);
512 return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME);
513 }, getSplitArrayIndexSuffixes);
516 Impl(op).convert(op, rewriter);
521class SplitArrayInReturnOp :
public OpConversionPattern<ReturnOp> {
523 using OpConversionPattern<
ReturnOp>::OpConversionPattern;
525 inline static bool legal(
ReturnOp op) {
526 return !containsSplittableArrayType(op.
getOperands().getTypes());
529 LogicalResult match(
ReturnOp op)
const override {
return failure(legal(op)); }
531 void rewrite(
ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
537class SplitArrayInCallOp :
public OpConversionPattern<CallOp> {
539 using OpConversionPattern<
CallOp>::OpConversionPattern;
541 inline static bool legal(
CallOp op) {
542 return !containsSplittableArrayType(op.
getArgOperands().getTypes()) &&
543 !containsSplittableArrayType(op.getResultTypes());
546 LogicalResult match(
CallOp op)
const override {
return failure(legal(op)); }
548 void rewrite(
CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
550 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
551 processInputOperands(
558class ReplaceKnownArrayLengthOp :
public OpConversionPattern<ArrayLengthOp> {
560 using OpConversionPattern<
ArrayLengthOp>::OpConversionPattern;
563 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx,
ArrayType baseArrType) {
564 if (splittableArray(baseArrType)) {
566 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
567 std::optional<int64_t> signedIdx = idxAP.trySExtValue();
568 if (!signedIdx || *signedIdx < 0) {
573 if (idx >= dimSizes.size()) {
576 Attribute dimSizeAttr = dimSizes[idx];
577 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
590 LogicalResult match(
ArrayLengthOp op)
const override {
return failure(legal(op)); }
593 rewrite(
ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
594 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
596 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
597 assert(len.has_value());
598 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op,
llzk::fromAPInt(len.value()));
603using MemberInfo = std::pair<StringAttr, Type>;
605using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
607using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
610class SplitArrayInMemberDefOp :
public OpConversionPattern<MemberDefOp> {
611 SymbolTableCollection &tables;
612 MemberReplacementMap &repMapRef;
615 SplitArrayInMemberDefOp(
616 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
618 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
620 inline static bool legal(
MemberDefOp op) {
return !containsSplittableArrayType(op.
getType()); }
622 LogicalResult match(
MemberDefOp op)
const override {
return failure(legal(op)); }
624 void rewrite(
MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
627 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
632 assert(subIdxs.has_value());
635 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
636 for (ArrayAttr idx : subIdxs.value()) {
643 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
645 rewriter.eraseOp(op);
651 SplitArrayInMemberWriteOp, MemberWriteOp, void *, ArrayAttr> {
657 return !containsSplittableArrayType(op.
getVal().getType());
660 static void *genHeader(
MemberWriteOp, ConversionPatternRewriter &) {
return nullptr; }
663 Location loc,
void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
664 ConversionPatternRewriter &rewriter
666 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
668 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
675class SplitArrayInMemberReadOp
677 SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp, ArrayAttr> {
684 return !containsSplittableArrayType(op.getResult().getType());
689 rewriter.create<
CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
690 rewriter.replaceAllUsesWith(op, newArray);
695 Location loc,
CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
696 ConversionPatternRewriter &rewriter
699 loc, newMember.second, adaptor.getComponent(), newMember.first
701 genWrite(loc, newArray, idx, scalarRead, rewriter);
706static void baseTargetSetup(ConversionTarget &target) {
707 target.addLegalDialect<
713 target.addLegalOp<ModuleOp>();
717class NondetToNewArray :
public OpConversionPattern<NonDetOp> {
718 using OpConversionPattern<
NonDetOp>::OpConversionPattern;
719 LogicalResult matchAndRewrite(
720 NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter
722 if (
auto at = dyn_cast<ArrayType>(nondetOp.getType())) {
731static LogicalResult step0(ModuleOp modOp) {
732 MLIRContext *ctx = modOp.getContext();
733 RewritePatternSet patterns {ctx};
734 patterns.add<NondetToNewArray>(ctx);
735 ConversionTarget target {*ctx};
737 baseTargetSetup(target);
738 target.addDynamicallyLegalOp<
NonDetOp>([](
NonDetOp op) {
return !isa<ArrayType>(op.getType()); });
740 return applyFullConversion(modOp, target, std::move(patterns));
745step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
746 MLIRContext *ctx = modOp.getContext();
748 RewritePatternSet patterns(ctx);
750 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
752 ConversionTarget target(*ctx);
753 baseTargetSetup(target);
754 target.addDynamicallyLegalOp<
MemberDefOp>(SplitArrayInMemberDefOp::legal);
756 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split array-type members\n";);
757 return applyFullConversion(modOp, target, std::move(patterns));
763step2(ModuleOp modOp, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap) {
764 MLIRContext *ctx = modOp.getContext();
766 RewritePatternSet patterns(ctx);
769 SplitInitFromCreateArrayOp,
772 SplitArrayInFuncDefOp,
773 SplitArrayInReturnOp,
775 ReplaceKnownArrayLengthOp
781 SplitArrayInMemberWriteOp,
782 SplitArrayInMemberReadOp
784 >(ctx, symTables, memberRepMap);
786 ConversionTarget target(*ctx);
787 baseTargetSetup(target);
788 target.addDynamicallyLegalOp<
CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
789 target.addDynamicallyLegalOp<
InsertArrayOp>(SplitInsertArrayOp::legal);
790 target.addDynamicallyLegalOp<
ExtractArrayOp>(SplitExtractArrayOp::legal);
791 target.addDynamicallyLegalOp<
FuncDefOp>(SplitArrayInFuncDefOp::legal);
792 target.addDynamicallyLegalOp<
ReturnOp>(SplitArrayInReturnOp::legal);
793 target.addDynamicallyLegalOp<
CallOp>(SplitArrayInCallOp::legal);
794 target.addDynamicallyLegalOp<
ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
795 target.addDynamicallyLegalOp<
MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
796 target.addDynamicallyLegalOp<
MemberReadOp>(SplitArrayInMemberReadOp::legal);
798 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other array ops\n";);
799 return applyFullConversion(modOp, target, std::move(patterns));
808static bool mayWriteToIndex(
WriteArrayOp writeOp, ArrayAttr index) {
809 ArrayAttr writeIndex = getIndexAsAttr(writeOp);
810 return !writeIndex || writeIndex == index;
814static bool hasEarlierWriteInBlock(
ReadArrayOp readOp, ArrayAttr readIndex) {
816 for (Operation &op : *readOp->getBlock()) {
817 if (&op == readOp.getOperation()) {
821 if (
auto writeOp = dyn_cast<WriteArrayOp>(&op)) {
822 if (writeOp.
getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
830 if (writeOp.getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
831 return WalkResult::interrupt();
833 return WalkResult::advance();
834 }).wasInterrupted()) {
842static std::optional<WriteArrayOp> findPrecedingWriteForIfRead(
ReadArrayOp readOp) {
843 ArrayAttr readIndex = getIndexAsAttr(readOp);
849 auto ifOp = readOp->getParentOfType<scf::IfOp>();
850 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
853 if (hasEarlierWriteInBlock(readOp, readIndex)) {
857 Block *ifBlock = ifOp->getBlock();
864 for (Operation &op : *ifBlock) {
865 if (&op == ifOp.getOperation()) {
869 if (
auto writeOp = dyn_cast<WriteArrayOp>(&op)) {
870 if (writeOp.getArrRef() != arrRef) {
874 if (mayWriteToIndex(writeOp, readIndex)) {
875 ArrayAttr writeIndex = getIndexAsAttr(writeOp);
876 replacement = writeIndex == readIndex ? writeOp :
WriteArrayOp();
882 if (op.walk([arrRef, readIndex, &replacement](
WriteArrayOp writeOp) {
883 if (writeOp.getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
884 replacement = WriteArrayOp();
885 return WalkResult::interrupt();
887 return WalkResult::advance();
888 }).wasInterrupted()) {
893 return replacement ? std::make_optional(replacement) : std::nullopt;
898static void step3(ModuleOp modOp) {
899 SmallVector<std::pair<ReadArrayOp, Value>> replacements;
901 if (std::optional<WriteArrayOp> writeOp = findPrecedingWriteForIfRead(readOp)) {
902 replacements.emplace_back(readOp, writeOp->getRvalue());
906 for (
auto [readOp, value] : replacements) {
907 readOp.
getResult().replaceAllUsesWith(value);
914 using Base = ArrayToScalarPassBase<PassImpl>;
917 void runOnOperation()
override {
918 ModuleOp module = getOperation();
920 if (failed(step0(module))) {
921 return signalPassFailure();
924 llvm::dbgs() <<
"After step 0:\n";
933 SymbolTableCollection symTables;
934 MemberReplacementMap memberRepMap;
935 if (failed(step1(module, symTables, memberRepMap))) {
936 return signalPassFailure();
939 llvm::dbgs() <<
"After step 1:\n";
943 if (failed(step2(module, symTables, memberRepMap))) {
944 return signalPassFailure();
947 llvm::dbgs() <<
"After step 2:\n";
954 llvm::dbgs() <<
"After step 3:\n";
958 OpPassManager nestedPM(ModuleOp::getOperationName());
967 RemoveUnusedDiscardableAllocationsPassOptions {
972 nestedPM.addPass(createRemoveDeadValuesPass());
973 if (failed(runPipeline(nestedPM, module))) {
978 llvm::dbgs() <<
"After SROA+Mem2Reg pipeline:\n";
Provides SpecializedSROA<AllocOpTy> and SpecializedMem2Reg<AllocOpTy>: pass templates that replicate ...
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
Common implementation for handling MemberWriteOp and MemberReadOp while destructuring an aggregate ty...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
::mlir::TypedValue<::mlir::IndexType > getDim()
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
static constexpr ::llvm::StringLiteral getOperationName()
::mlir::MutableOperandRange getElementsMutable()
::mlir::Operation::operand_range getElements()
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
::mlir::TypedValue<::mlir::Type > getResult()
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
void setPublicAttr(bool newValue=true)
::mlir::StringAttr getSymNameAttr()
::mlir::TypedValue<::mlir::Type > getVal()
::llvm::SmallVector< RangeT > getMapOperands()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::Operation::operand_range getArgOperands()
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
::mlir::MutableOperandRange getOperandsMutable()
::mlir::Operation::operand_range getOperands()
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
std::unique_ptr<::mlir::Pass > createRemoveUnusedDiscardableAllocationsPass()
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
constexpr T checkedCast(U u) noexcept
std::unique_ptr< SpecializedMem2Reg< AllocOpTy > > createSpecializedMem2RegPass()
int64_t fromAPInt(const llvm::APInt &i)
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::unique_ptr< SpecializedSROA< AllocOpTy > > createSpecializedSROAPass()
Cached function arg/result names and split suffixes used while rewriting a function signature.
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes