84#include <mlir/Dialect/SCF/IR/SCF.h>
85#include <mlir/Pass/PassManager.h>
86#include <mlir/Transforms/DialectConversion.h>
87#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
88#include <mlir/Transforms/Passes.h>
90#include <llvm/ADT/DenseMapInfo.h>
91#include <llvm/ADT/STLExtras.h>
92#include <llvm/Support/Debug.h>
96#define GEN_PASS_DEF_PODTOSCALARPASS
106#define DEBUG_TYPE "llzk-pod-to-scalar"
112 SmallVector<StringAttr> nameList;
114 RecordChain() =
default;
116 explicit RecordChain(ArrayRef<StringAttr>
names) : nameList(
names.begin(),
names.end()) {}
118 bool operator==(
const RecordChain &other)
const {
return nameList == other.nameList; }
124template <>
struct DenseMapInfo<RecordChain> {
126 return RecordChain {{DenseMapInfo<StringAttr>::getEmptyKey()}};
130 return RecordChain {{DenseMapInfo<StringAttr>::getTombstoneKey()}};
134 return llvm::hash_combine_range(chain.nameList.begin(), chain.nameList.end());
137 static bool isEqual(
const RecordChain &lhs,
const RecordChain &rhs) {
return lhs == rhs; }
147inline static PodType splittablePod(Type t) {
148 if (
PodType pt = dyn_cast<PodType>(t)) {
149 return splittablePod(pt);
157inline static bool containsSplittablePodType(ArrayRef<Type> types) {
158 for (Type t : types) {
159 if (splittablePod(t)) {
168template <
typename T>
static bool containsSplittablePodType(ValueTypeRange<T> types) {
169 for (Type t : types) {
170 if (splittablePod(t)) {
179size_t splitPodTypeTo(Type t, SmallVector<Type> &collect) {
180 if (
PodType pt = splittablePod(t)) {
181 auto records = pt.getRecords();
182 for (RecordAttr record : records) {
183 collect.push_back(record.getType());
185 return records.size();
187 collect.push_back(t);
193template <
typename TypeCollection>
194inline void splitPodTypeTo(
195 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
197 for (Type t : types) {
198 size_t count = splitPodTypeTo(t, collect);
199 if (originalIdxToSize) {
200 originalIdxToSize->push_back(count);
207template <
typename TypeCollection>
208inline SmallVector<Type>
209splitPodType(TypeCollection types, SmallVector<size_t> *originalIdxToSize =
nullptr) {
210 SmallVector<Type> collect;
211 splitPodTypeTo(types, collect, originalIdxToSize);
217genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &rewriter) {
219 llvm::cast<PodType>(podRef.getType()).getRecordMap().lookup(recordName.getValue());
220 return rewriter.create<
ReadPodOp>(loc, resultType, podRef, recordName);
225genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &rewriter) {
226 return rewriter.create<
WritePodOp>(loc, podRef, recordName, value);
230static SmallVector<std::string> getSplitRecordNameSuffixes(Type type) {
231 SmallVector<std::string> suffixes;
232 if (
PodType pt = splittablePod(type)) {
233 suffixes.reserve(pt.getRecords().size());
234 for (RecordAttr record : pt.getRecords()) {
235 StringRef name = record.getName().getValue();
237 result.reserve(name.size() + 1);
238 result.push_back(
'.');
239 result.append(name.data(), name.size());
240 suffixes.push_back(result);
248static void processInputOperand(
249 Location loc, Value operand, SmallVector<Value> &newOperands,
250 ConversionPatternRewriter &rewriter
252 if (
PodType pt = splittablePod(operand.getType())) {
253 for (RecordAttr record : pt.getRecords()) {
254 newOperands.push_back(genRead(loc, operand, record.getName(), rewriter));
257 newOperands.push_back(operand);
263static void processInputOperands(
264 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
265 ConversionPatternRewriter &rewriter
267 SmallVector<Value> newOperands;
268 for (Value v : operands) {
269 processInputOperand(op->getLoc(), v, newOperands, rewriter);
271 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
272 outputOpRef.assign(ValueRange(newOperands));
277inline static void baseTargetSetup(ConversionTarget &target) {
278 target.addLegalDialect<
284 target.addLegalOp<ModuleOp>();
289class NondetToNewPod :
public OpConversionPattern<NonDetOp> {
290 using OpConversionPattern<NonDetOp>::OpConversionPattern;
291 LogicalResult matchAndRewrite(
292 NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter
294 if (
auto pt = dyn_cast<PodType>(nondetOp.getType())) {
295 rewriter.replaceOpWithNewOp<NewPodOp>(nondetOp, pt);
303static LogicalResult step0(ModuleOp modOp) {
304 MLIRContext *ctx = modOp.getContext();
305 RewritePatternSet patterns {ctx};
306 patterns.add<NondetToNewPod>(ctx);
307 ConversionTarget target {*ctx};
309 baseTargetSetup(target);
310 target.addDynamicallyLegalOp<
NonDetOp>([](
NonDetOp op) {
return !isa<PodType>(op.getType()); });
312 return applyFullConversion(modOp, target, std::move(patterns));
316using MemberInfo = std::pair<StringAttr, Type>;
318using LocalMemberReplacementMap = DenseMap<RecordChain, MemberInfo>;
320using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
324getFlattenedMemberName(MLIRContext *ctx, StringAttr memberName, ArrayRef<StringAttr> recordChain) {
325 std::string flatName;
326 llvm::raw_string_ostream os(flatName);
327 os << memberName.getValue();
328 for (StringAttr recordName : recordChain) {
329 os <<
'_' << recordName.getValue();
331 return StringAttr::get(ctx, flatName);
335static void flattenPodMemberIntoLeaves(
337 LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable,
338 ConversionPatternRewriter &rewriter
340 for (RecordAttr record : podTy.
getRecords()) {
341 recordChain.push_back(record.getName());
342 if (
PodType nestedPodTy = dyn_cast<PodType>(record.getType())) {
343 flattenPodMemberIntoLeaves(
344 originalMember, nestedPodTy, recordChain, localRepMapRef, structSymbolTable, rewriter
346 recordChain.pop_back();
350 StringAttr name = getFlattenedMemberName(
351 originalMember.getContext(), originalMember.
getSymNameAttr(), recordChain
353 Type ty = record.getType();
355 originalMember.getLoc(), name, ty, originalMember.
getSignal(), originalMember.
getColumn()
358 localRepMapRef[RecordChain(recordChain)] =
359 std::make_pair(structSymbolTable.insert(newMember), ty);
360 recordChain.pop_back();
368class SplitPodInMemberDefOp :
public OpConversionPattern<MemberDefOp> {
369 SymbolTableCollection &tables;
370 MemberReplacementMap &repMapRef;
373 SplitPodInMemberDefOp(
374 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
376 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
378 inline static bool legal(MemberDefOp op) {
return !splittablePod(op.
getType()); }
380 LogicalResult match(MemberDefOp op)
const override {
return failure(legal(op)); }
383 rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
384 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
386 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.
getSymNameAttr()];
388 PodType podTy = llvm::cast<PodType>(adaptor.getType());
390 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
391 SmallVector<StringAttr> recordChain;
392 flattenPodMemberIntoLeaves(op, podTy, recordChain, localRepMapRef, structSymbolTable, rewriter);
393 rewriter.eraseOp(op);
399step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
400 MLIRContext *ctx = modOp.getContext();
402 RewritePatternSet patterns(ctx);
404 patterns.add<SplitPodInMemberDefOp>(ctx, symTables, memberRepMap);
406 ConversionTarget target(*ctx);
407 baseTargetSetup(target);
408 target.addDynamicallyLegalOp<
MemberDefOp>(SplitPodInMemberDefOp::legal);
410 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 1: split pod-type members\n";);
411 return applyFullConversion(modOp, target, std::move(patterns));
415class SplitInitFromNewPodOp :
public OpConversionPattern<NewPodOp> {
417 using OpConversionPattern<NewPodOp>::OpConversionPattern;
421 LogicalResult match(NewPodOp op)
const override {
return failure(legal(op)); }
423 void rewrite(NewPodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
425 rewriter.setInsertionPointAfter(op);
426 Location loc = op.getLoc();
427 for (
auto [name, init] :
428 llvm::zip_equal(adaptor.getInitializedRecords(), adaptor.getInitialValues())) {
430 rewriter.create<WritePodOp>(loc, op.
getResult(), llvm::cast<StringAttr>(name), init);
433 rewriter.modifyOpInPlace(op, [&op]() {
447class SplitPodInFuncDefOp :
public OpConversionPattern<FuncDefOp> {
449 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
451 inline static bool legal(FuncDefOp op) {
456 LogicalResult match(FuncDefOp op)
const override {
return failure(legal(op)); }
458 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter)
const override {
460 class Impl :
public FunctionTypeConverter {
461 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
462 SplitFunctionNameInfo inputNameInfo;
463 SplitFunctionNameInfo resultNameInfo;
466 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
467 return splitPodType(origTypes, &originalInputIdxToSize);
469 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
470 return splitPodType(origTypes, &originalResultIdxToSize);
472 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
479 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes)
override {
491 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
492 OpBuilder::InsertionGuard guard(rewriter);
493 rewriter.setInsertionPointToStart(&entryBlock);
495 for (
unsigned i = 0; i < entryBlock.getNumArguments();) {
496 Value oldV = entryBlock.getArgument(i);
497 if (PodType pt = splittablePod(oldV.getType())) {
498 Location loc = oldV.getLoc();
500 auto newPod = rewriter.create<NewPodOp>(loc, pt);
501 rewriter.replaceAllUsesWith(oldV, newPod);
503 entryBlock.eraseArgument(i);
506 for (RecordAttr record : pt.getRecords()) {
507 BlockArgument newArg = entryBlock.insertArgument(i, record.getType(), loc);
508 genWrite(loc, newPod, record.getName(), newArg, rewriter);
520 return op.getArgNameAttr(i);
521 }, getSplitRecordNameSuffixes);
523 op.
getResultTypes(), [resultAttrs = op.getAllResultAttrs()](
unsigned i) {
524 return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME);
525 }, getSplitRecordNameSuffixes
529 Impl(op).convert(op, rewriter);
538class SplitPodInReturnOp :
public OpConversionPattern<ReturnOp> {
540 using OpConversionPattern<ReturnOp>::OpConversionPattern;
542 inline static bool legal(ReturnOp op) {
543 return !containsSplittablePodType(op.
getOperands().getTypes());
546 LogicalResult match(ReturnOp op)
const override {
return failure(legal(op)); }
548 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
554static CallOp newCallOpWithSplitResults(
557 OpBuilder::InsertionGuard guard(rewriter);
558 rewriter.setInsertionPointAfter(oldCall);
560 Operation::result_range oldResults = oldCall.getResults();
562 oldCall.getLoc(), splitPodType(oldResults.getTypes()), oldCall, adaptor.
getMapOperands(),
566 auto newResults = newCall.getResults().begin();
567 for (Value oldVal : oldResults) {
568 if (
PodType pt = splittablePod(oldVal.getType())) {
569 Location loc = oldVal.getLoc();
571 auto newPod = rewriter.create<
NewPodOp>(loc, pt);
572 rewriter.replaceAllUsesWith(oldVal, newPod);
575 for (RecordAttr record : pt.getRecords()) {
576 genWrite(loc, newPod, record.getName(), *newResults, rewriter);
580 rewriter.replaceAllUsesWith(oldVal, *newResults);
585 rewriter.eraseOp(oldCall);
596class SplitPodInCallOp :
public OpConversionPattern<CallOp> {
598 using OpConversionPattern<CallOp>::OpConversionPattern;
600 inline static bool legal(CallOp op) {
601 return !containsSplittablePodType(op.
getArgOperands().getTypes()) &&
602 !containsSplittablePodType(op.getResultTypes());
605 LogicalResult match(CallOp op)
const override {
return failure(legal(op)); }
607 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
609 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
610 processInputOperands(
618genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) {
619 Value value = podRef;
620 for (StringAttr attr : recordChain.nameList) {
621 value = genRead(loc, value, attr, rewriter);
627struct RebuildPodReadState {
629 DenseMap<RecordChain, Value> leafValues;
633static Value rebuildFlattenedPodRecord(
634 Location loc, Type recordType, SmallVectorImpl<StringAttr> &recordChain,
635 const DenseMap<RecordChain, Value> &leafValues, ConversionPatternRewriter &rewriter
637 if (
PodType nestedPodTy = dyn_cast<PodType>(recordType)) {
639 for (RecordAttr record : nestedPodTy.getRecords()) {
640 recordChain.push_back(record.getName());
642 rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter);
643 genWrite(loc, nestedPod, record.getName(), recordValue, rewriter);
644 recordChain.pop_back();
649 auto it = leafValues.find(RecordChain(recordChain));
650 assert(it != leafValues.end() &&
"missing flattened POD leaf value");
656 SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain> {
658 using SplitAggregateInMemberRefOp<
659 SplitPodInMemberWriteOp, MemberWriteOp,
void *, RecordChain>::SplitAggregateInMemberRefOp;
661 static bool legal(MemberWriteOp op) {
return !containsSplittablePodType(op.
getVal().getType()); }
663 static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) {
return nullptr; }
666 Location loc,
void *&, RecordChain
id, MemberInfo newMember, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter
669 Value scalarRead = genReadAlongPath(loc, adaptor.getVal(),
id, rewriter);
670 rewriter.create<MemberWriteOp>(
671 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
677class SplitPodInMemberReadOp
679 SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, RecordChain> {
681 using SplitAggregateInMemberRefOp<
682 SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState,
683 RecordChain>::SplitAggregateInMemberRefOp;
685 static bool legal(MemberReadOp op) {
686 return !containsSplittablePodType(op.getResult().getType());
689 static RebuildPodReadState genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) {
690 RebuildPodReadState state;
691 state.pod = rewriter.create<NewPodOp>(op.getLoc(), llvm::cast<PodType>(op.getType()));
692 rewriter.replaceAllUsesWith(op, state.pod);
697 Location loc, RebuildPodReadState &state, RecordChain
id, MemberInfo newMember,
698 OpAdaptor adaptor, ConversionPatternRewriter &rewriter
700 Value scalarRead = rewriter.create<MemberReadOp>(
701 loc, newMember.second, adaptor.getComponent(), newMember.first
703 state.leafValues[id] = scalarRead;
706 static void finalize(
707 MemberReadOp op, RebuildPodReadState &state, OpAdaptor, ConversionPatternRewriter &rewriter
709 auto podTy = llvm::cast<PodType>(op.getType());
710 SmallVector<StringAttr> recordChain;
711 for (RecordAttr record : podTy.
getRecords()) {
712 recordChain.push_back(record.getName());
713 Value recordValue = rebuildFlattenedPodRecord(
714 op.getLoc(), record.getType(), recordChain, state.leafValues, rewriter
716 genWrite(op.getLoc(), state.pod, record.getName(), recordValue, rewriter);
717 recordChain.pop_back();
725step2(ModuleOp modOp, SymbolTableCollection &symTables,
const MemberReplacementMap &memberRepMap) {
726 MLIRContext *ctx = modOp.getContext();
728 RewritePatternSet patterns(ctx);
731 SplitInitFromNewPodOp,
740 SplitPodInMemberWriteOp,
741 SplitPodInMemberReadOp
743 >(ctx, symTables, memberRepMap);
745 ConversionTarget target(*ctx);
746 baseTargetSetup(target);
747 target.addDynamicallyLegalOp<
NewPodOp>(SplitInitFromNewPodOp::legal);
748 target.addDynamicallyLegalOp<
FuncDefOp>(SplitPodInFuncDefOp::legal);
749 target.addDynamicallyLegalOp<
ReturnOp>(SplitPodInReturnOp::legal);
750 target.addDynamicallyLegalOp<
CallOp>(SplitPodInCallOp::legal);
751 target.addDynamicallyLegalOp<
MemberWriteOp>(SplitPodInMemberWriteOp::legal);
752 target.addDynamicallyLegalOp<
MemberReadOp>(SplitPodInMemberReadOp::legal);
754 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 2: update/split other pod ops\n";);
755 return applyFullConversion(modOp, target, std::move(patterns));
759inline static StringAttr getRecordNameAsStringAttr(
ReadPodOp readOp) {
764inline static StringAttr getRecordNameAsStringAttr(
WritePodOp writeOp) {
769inline static bool isSamePodRecord(
ReadPodOp readOp, Value podRef, StringAttr recordName) {
770 return readOp.
getPodRef() == podRef && getRecordNameAsStringAttr(readOp) == recordName;
774inline static bool isSamePodRecord(
WritePodOp writeOp, Value podRef, StringAttr recordName) {
775 return writeOp.
getPodRef() == podRef && getRecordNameAsStringAttr(writeOp) == recordName;
779static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) {
780 return walkContainsMatch<WritePodOp>(op, [&](
WritePodOp writeOp) {
781 return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName);
786static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) {
787 return walkContainsMatch<ReadPodOp>(op, [&](
ReadPodOp readOp) {
788 return isSamePodRecord(readOp, podRef, recordName);
793static bool hasValueUse(Operation &op, Value value) {
794 return walkContainsMatch<Operation *>(op, [&value](Operation *nestedOp) {
795 return llvm::is_contained(nestedOp->getOperands(), value);
800static bool hasEarlierWriteInBlock(
ReadPodOp readOp) {
802 StringAttr recordName = getRecordNameAsStringAttr(readOp);
804 for (Operation &op : *readOp->getBlock()) {
805 if (&op == readOp.getOperation()) {
809 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
810 if (isSamePodRecord(writeOp, podRef, recordName)) {
816 if (hasNestedWriteToRecord(op, podRef, recordName)) {
827static bool isValueDefinedInside(Operation *ancestor, Value value) {
828 if (Operation *defOp = value.getDefiningOp()) {
829 return ancestor->isAncestor(defOp);
832 auto blockArg = llvm::dyn_cast<BlockArgument>(value);
833 Operation *parentOp = blockArg.getOwner()->getParentOp();
834 return parentOp && ancestor->isAncestor(parentOp);
839 auto ifOp = readOp->getParentOfType<scf::IfOp>();
840 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
843 if (hasEarlierWriteInBlock(readOp)) {
847 Block *ifBlock = ifOp->getBlock();
853 StringAttr recordName = getRecordNameAsStringAttr(readOp);
855 for (Operation &op : *ifBlock) {
856 if (&op == ifOp.getOperation()) {
860 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
861 if (isSamePodRecord(writeOp, podRef, recordName)) {
862 replacement = writeOp;
867 if (hasNestedWriteToRecord(op, podRef, recordName)) {
868 replacement =
nullptr;
876class ReplaceIfReadPattern final :
public OpRewritePattern<ReadPodOp> {
878 using OpRewritePattern<ReadPodOp>::OpRewritePattern;
880 LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter)
const override {
881 auto ifOp = readOp->getParentOfType<scf::IfOp>();
882 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
885 if (isValueDefinedInside(ifOp, readOp.
getPodRef()) || hasEarlierWriteInBlock(readOp)) {
889 if (WritePodOp writeOp = findPrecedingWriteForIfRead(readOp)) {
890 rewriter.replaceOp(readOp, writeOp.
getValue());
894 rewriter.setInsertionPoint(ifOp);
897 genRead(readOp.getLoc(), readOp.
getPodRef(), getRecordNameAsStringAttr(readOp), rewriter)
917class FoldIfCarriedPodReadAfterWritePattern final :
public OpRewritePattern<ReadPodOp> {
919 using OpRewritePattern<ReadPodOp>::OpRewritePattern;
921 LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter)
const override {
922 auto podRes = dyn_cast<OpResult>(readOp.
getPodRef());
927 auto ifOp = dyn_cast<scf::IfOp>(podRes.getOwner());
932 auto writeOp = dyn_cast_or_null<WritePodOp>(readOp->getPrevNode());
933 if (!writeOp || getRecordNameAsStringAttr(writeOp) != getRecordNameAsStringAttr(readOp)) {
937 auto valueRes = dyn_cast<OpResult>(writeOp.
getValue());
938 if (!valueRes || valueRes.getOwner() != ifOp.getOperation()) {
943 unsigned podResultIndex = podRes.getResultNumber();
945 auto thenYield = dyn_cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
946 if (!thenYield || thenYield.getOperand(podResultIndex) != carriedPod) {
950 Region &elseRegion = ifOp.getElseRegion();
951 if (Block *elseBlock = elseRegion.empty() ?
nullptr : &elseRegion.front()) {
952 auto elseYield = dyn_cast<scf::YieldOp>(elseBlock->getTerminator());
953 if (!elseYield || elseYield.getOperand(podResultIndex) != carriedPod) {
958 rewriter.replaceOp(readOp, valueRes);
969 StringAttr recordName;
971 WritePodOp thenWrite;
972 WritePodOp elseWrite;
978lookupSlot(SmallVectorImpl<IfWriteSlot> &slots, Value podRef, StringAttr recordName) {
979 for (IfWriteSlot &slot : slots) {
980 if (slot.podRef == podRef && slot.recordName == recordName) {
988static IfWriteSlot &getOrCreateSlot(
989 SmallVectorImpl<IfWriteSlot> &slots, Value podRef, StringAttr recordName, Type type
991 if (IfWriteSlot *slot = lookupSlot(slots, podRef, recordName)) {
994 slots.push_back(IfWriteSlot {podRef, recordName, type,
nullptr,
nullptr, Value()});
999static Block *getElseBlockOrNull(scf::IfOp ifOp) {
1000 return ifOp.getElseRegion().empty() ? nullptr : &ifOp.getElseRegion().front();
1005collectDirectWrites(Block *block,
bool isThenBlock, SmallVectorImpl<IfWriteSlot> &slots) {
1010 for (Operation &op : *block) {
1011 if (op.hasTrait<OpTrait::IsTerminator>()) {
1015 auto writeOp = dyn_cast<WritePodOp>(&op);
1020 IfWriteSlot &slot = getOrCreateSlot(
1021 slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp), writeOp.
getValue().getType()
1024 slot.thenWrite = writeOp;
1026 slot.elseWrite = writeOp;
1035static bool branchSlotCanBeLifted(Block *block, Value podRef, StringAttr recordName) {
1040 bool seenDirectWrite =
false;
1041 for (Operation &op : *block) {
1042 if (op.hasTrait<OpTrait::IsTerminator>()) {
1046 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
1047 if (isSamePodRecord(writeOp, podRef, recordName)) {
1048 seenDirectWrite =
true;
1053 if (hasNestedWriteToRecord(op, podRef, recordName)) {
1056 if (seenDirectWrite && (hasReadFromRecord(op, podRef, recordName) || hasValueUse(op, podRef))) {
1064static bool isLiftedWrite(Operation &op, ArrayRef<IfWriteSlot> slots) {
1065 auto writeOp = dyn_cast<WritePodOp>(&op);
1066 return writeOp && llvm::any_of(slots, [&writeOp](
const IfWriteSlot &slot) {
1067 return isSamePodRecord(writeOp, slot.podRef, slot.recordName);
1072static scf::YieldOp getYieldOp(Block &block) {
1073 auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator());
1074 assert(yieldOp &&
"expected scf.if branch to terminate with scf.yield");
1079static void dropTerminatorIfPresent(Block &block) {
1080 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) {
1081 block.back().erase();
1087moveBranchWithoutLiftedWrites(Block *srcBlock, Block &destBlock, ArrayRef<IfWriteSlot> slots) {
1089 for (
auto it = srcBlock->begin(), end = srcBlock->end(); it != end;) {
1090 Operation &op = *it++;
1091 if (op.hasTrait<OpTrait::IsTerminator>() || isLiftedWrite(op, slots)) {
1094 op.moveBefore(&destBlock, destBlock.end());
1101static void appendYield(
1102 Location loc, Block &block, ValueRange priorYieldValues, ArrayRef<IfWriteSlot> slots,
1103 bool isThenBlock, OpBuilder &builder
1105 SmallVector<Value> yieldValues = llvm::to_vector(priorYieldValues);
1106 llvm::append_range(yieldValues, llvm::map_range(slots, [isThenBlock](
const IfWriteSlot &slot) {
1107 WritePodOp writeOp = isThenBlock ? slot.thenWrite : slot.elseWrite;
1108 return writeOp ? writeOp.
getValue() : slot.incomingValue;
1111 builder.setInsertionPointToEnd(&block);
1112 builder.create<scf::YieldOp>(loc, yieldValues);
1122 StringAttr recordName;
1126 bool matches(Value findPodRef, StringAttr findRecordName)
const {
1127 return this->podRef == findPodRef && this->recordName == findRecordName;
1133lookupLoopSlot(SmallVectorImpl<LoopPodSlot> &slots, Value podRef, StringAttr recordName) {
1134 auto it = llvm::find_if(slots, [&podRef, &recordName](
const LoopPodSlot &slot) {
1135 return slot.matches(podRef, recordName);
1137 return it == slots.end() ? nullptr : &*it;
1141static bool hasLoopSlot(ArrayRef<LoopPodSlot> slots, Value podRef, StringAttr recordName) {
1142 auto it = llvm::find_if(slots, [&podRef, &recordName](
const LoopPodSlot &slot) {
1143 return slot.matches(podRef, recordName);
1145 return it != slots.end();
1149static LoopPodSlot &getOrCreateLoopSlot(
1150 SmallVectorImpl<LoopPodSlot> &slots, Value podRef, StringAttr recordName, Type type
1152 if (LoopPodSlot *slot = lookupLoopSlot(slots, podRef, recordName)) {
1155 slots.push_back(LoopPodSlot {podRef, recordName, type});
1156 return slots.back();
1160static std::optional<size_t>
1161findLoopSlotIndex(ArrayRef<LoopPodSlot> slots, Value podRef, StringAttr recordName) {
1162 for (
auto [idx, slot] : llvm::enumerate(slots)) {
1163 if (slot.podRef == podRef && slot.recordName == recordName) {
1167 return std::nullopt;
1173collectDirectLoopPodSlots(Block &block, Operation *ancestor, SmallVectorImpl<LoopPodSlot> &slots) {
1174 for (Operation &op : block) {
1175 if (
auto readOp = dyn_cast<ReadPodOp>(&op)) {
1176 if (!isValueDefinedInside(ancestor, readOp.
getPodRef())) {
1177 getOrCreateLoopSlot(
1178 slots, readOp.
getPodRef(), getRecordNameAsStringAttr(readOp), readOp.getType()
1184 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
1185 if (!isValueDefinedInside(ancestor, writeOp.
getPodRef())) {
1186 getOrCreateLoopSlot(
1187 slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp),
1196static bool opUsesTrackedPodRefDirectly(Operation &op, ArrayRef<LoopPodSlot> slots) {
1197 return llvm::any_of(op.getOperands(), [&slots](Value operand) {
1198 return llvm::any_of(slots, [&operand](const LoopPodSlot &slot) {
1199 return slot.podRef == operand;
1205static bool hasNestedTrackedPodAccess(Operation &op, ArrayRef<LoopPodSlot> slots) {
1207 .walk([&op, &slots](Operation *nestedOp) {
1208 if (nestedOp == &op) {
1209 return WalkResult::advance();
1212 if (
auto readOp = dyn_cast<ReadPodOp>(nestedOp)) {
1213 if (hasLoopSlot(slots, readOp.
getPodRef(), getRecordNameAsStringAttr(readOp))) {
1214 return WalkResult::interrupt();
1216 return WalkResult::advance();
1219 if (
auto writeOp = dyn_cast<WritePodOp>(nestedOp)) {
1220 if (hasLoopSlot(slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1221 return WalkResult::interrupt();
1224 return WalkResult::advance();
1225 }).wasInterrupted();
1230static bool hasUnliftableLoopPodUses(Block &block, ArrayRef<LoopPodSlot> slots) {
1231 for (Operation &op : block) {
1232 if (isa<ReadPodOp, WritePodOp>(op)) {
1235 if (opUsesTrackedPodRefDirectly(op, slots) || hasNestedTrackedPodAccess(op, slots)) {
1245class LiftPodWritesFromIfBlocksPattern final :
public OpRewritePattern<scf::IfOp> {
1247 using OpRewritePattern<scf::IfOp>::OpRewritePattern;
1249 LogicalResult matchAndRewrite(scf::IfOp ifOp, PatternRewriter &rewriter)
const override {
1250 SmallVector<IfWriteSlot> slots;
1251 Block &thenBlock = *ifOp.thenBlock();
1252 Block *elseBlock = getElseBlockOrNull(ifOp);
1253 collectDirectWrites(&thenBlock,
true, slots);
1254 collectDirectWrites(elseBlock,
false, slots);
1255 if (slots.empty()) {
1259 llvm::erase_if(slots, [&](
const IfWriteSlot &slot) {
1260 return isValueDefinedInside(ifOp, slot.podRef) ||
1261 !branchSlotCanBeLifted(&thenBlock, slot.podRef, slot.recordName) ||
1262 !branchSlotCanBeLifted(elseBlock, slot.podRef, slot.recordName);
1264 if (slots.empty()) {
1268 for (IfWriteSlot &slot : slots) {
1269 if (slot.thenWrite && slot.elseWrite) {
1272 rewriter.setInsertionPoint(ifOp);
1273 slot.incomingValue =
1274 genRead(ifOp.getLoc(), slot.podRef, slot.recordName, rewriter).getResult();
1277 SmallVector<Type> resultTypes = llvm::to_vector(ifOp.getResultTypes());
1278 llvm::append_range(resultTypes, llvm::map_range(slots, [](
auto slot) {
return slot.type; }));
1280 SmallVector<Value> originalThenYields;
1281 if (!ifOp.getResults().empty()) {
1282 scf::YieldOp thenYieldOp = getYieldOp(thenBlock);
1283 originalThenYields.append(thenYieldOp.getOperands().begin(), thenYieldOp.getOperands().end());
1286 SmallVector<Value> originalElseYields;
1287 if (elseBlock && !ifOp.getResults().empty()) {
1288 scf::YieldOp elseYieldOp = getYieldOp(*elseBlock);
1289 originalElseYields.append(elseYieldOp.getOperands().begin(), elseYieldOp.getOperands().end());
1292 rewriter.setInsertionPoint(ifOp);
1293 auto newIf = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes, ifOp.getCondition(),
true);
1294 Block &newThenBlock = *newIf.thenBlock();
1295 Block &newElseBlock = *newIf.elseBlock();
1296 dropTerminatorIfPresent(newThenBlock);
1297 dropTerminatorIfPresent(newElseBlock);
1299 moveBranchWithoutLiftedWrites(&thenBlock, newThenBlock, slots);
1300 moveBranchWithoutLiftedWrites(elseBlock, newElseBlock, slots);
1301 appendYield(ifOp.getLoc(), newThenBlock, originalThenYields, slots,
true, rewriter);
1302 appendYield(ifOp.getLoc(), newElseBlock, originalElseYields, slots,
false, rewriter);
1304 rewriter.setInsertionPointAfter(newIf);
1305 unsigned originalResultCount = ifOp.getNumResults();
1306 for (
auto [idx, slot] : llvm::enumerate(slots)) {
1308 ifOp.getLoc(), slot.podRef, slot.recordName, newIf.getResult(originalResultCount + idx),
1313 rewriter.replaceOp(ifOp, newIf.getResults().take_front(originalResultCount));
1320class LiftPodAccessesFromForLoopPattern final :
public OpRewritePattern<scf::ForOp> {
1322 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1324 LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter)
const override {
1325 Block &body = *forOp.getBody();
1326 SmallVector<LoopPodSlot> slots;
1327 collectDirectLoopPodSlots(body, forOp.getOperation(), slots);
1328 if (slots.empty() || hasUnliftableLoopPodUses(body, slots)) {
1332 Location loc = forOp.getLoc();
1334 SmallVector<Value> newInitArgs = llvm::to_vector(forOp.getInitArgs());
1335 rewriter.setInsertionPoint(forOp);
1336 for (
const LoopPodSlot &slot : slots) {
1337 newInitArgs.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult());
1340 auto newFor = rewriter.create<scf::ForOp>(
1341 loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newInitArgs
1343 newFor->setAttrs(forOp->getAttrs());
1345 Block &newBody = *newFor.getBody();
1346 dropTerminatorIfPresent(newBody);
1349 mapping.map(forOp.getInductionVar(), newFor.getInductionVar());
1350 for (
auto [idx, oldArg] : llvm::enumerate(forOp.getRegionIterArgs())) {
1351 mapping.map(oldArg, newFor.getRegionIterArg(idx));
1354 SmallVector<Value> slotValues = llvm::map_to_vector(
1355 llvm::seq<size_t>(0, slots.size()),
1356 [base =
static_cast<size_t>(forOp.getNumRegionIterArgs()), &newFor](
size_t idx) -> Value {
1357 return newFor.getRegionIterArg(llzk::checkedCast<unsigned>(base + idx));
1361 rewriter.setInsertionPointToEnd(&newBody);
1362 for (Operation &op : body) {
1363 if (
auto yieldOp = dyn_cast<scf::YieldOp>(&op)) {
1364 auto yieldValues = llvm::map_to_vector(yieldOp.getOperands(), [&mapping](Value operand) {
1365 return mapping.lookupOrDefault(operand);
1367 llvm::append_range(yieldValues, slotValues);
1368 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
1372 if (
auto readOp = dyn_cast<ReadPodOp>(&op)) {
1373 if (std::optional<size_t> slotIdx =
1374 findLoopSlotIndex(slots, readOp.
getPodRef(), getRecordNameAsStringAttr(readOp))) {
1375 mapping.map(readOp.
getResult(), slotValues[*slotIdx]);
1380 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
1381 if (std::optional<size_t> slotIdx =
1382 findLoopSlotIndex(slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1383 slotValues[*slotIdx] = mapping.lookupOrDefault(writeOp.
getValue());
1388 rewriter.clone(op, mapping);
1391 rewriter.setInsertionPointAfter(newFor);
1392 for (
auto [idx, slot] : llvm::enumerate(slots)) {
1394 loc, slot.podRef, slot.recordName, newFor.getResult(forOp.getNumResults() + idx), rewriter
1398 rewriter.replaceOp(forOp, newFor.getResults().take_front(forOp.getNumResults()));
/// Lifts pod record accesses out of an scf.while loop.
///
/// A "slot" is a pod reference plus a record name, collected from both the
/// before and after regions by collectDirectLoopPodSlots. Each slot's record is
/// read once in front of the loop (genRead) and then threaded through the loop
/// as an extra loop-carried value:
///  - the replacement scf.while gets one extra init operand and one extra
///    result (of type slot.type) per slot, appended after the original ones;
///  - inside either region, uses of a ReadPodOp matching a slot are remapped to
///    the slot's current value, and a WritePodOp matching a slot sets the
///    slot's current value;
///  - scf.condition / scf.yield forward the current slot values after the
///    original (remapped) operands;
///  - after the loop, each extra result is passed together with the slot's pod
///    reference and record name to a call not visible in this listing
///    (presumably the write-back counterpart of genRead).
/// The original op is replaced by the leading, original-arity results.
///
/// NOTE(review): this listing omits several source lines (the early-exit
/// branch body, the statements after each slot read/write remap, the callee
/// used after the loop). The comments below are hedged wherever they touch
/// those parts.
1405class LiftPodAccessesFromWhileLoopPattern final :
public OpRewritePattern<scf::WhileOp> {
1407 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
1409 LogicalResult matchAndRewrite(scf::WhileOp whileOp, PatternRewriter &rewriter)
const override {
1410 Block &beforeBody = *whileOp.getBeforeBody();
1411 Block &afterBody = *whileOp.getAfterBody();
// Slots from both regions are gathered into a single list, so a slot index
// means the same thing in the before and after regions.
1413 SmallVector<LoopPodSlot> slots;
1414 collectDirectLoopPodSlots(beforeBody, whileOp.getOperation(), slots);
1415 collectDirectLoopPodSlots(afterBody, whileOp.getOperation(), slots);
// Nothing to lift, or some pod use in either region cannot be lifted. The
// branch body is not shown here; presumably it reports a match failure.
1416 if (slots.empty() || hasUnliftableLoopPodUses(beforeBody, slots) ||
1417 hasUnliftableLoopPodUses(afterBody, slots)) {
1421 Location loc = whileOp.getLoc();
// Extend the original inits/result types with one entry per slot. The initial
// slot value is read from the pod right before the original loop.
1423 SmallVector<Value> newInits = llvm::to_vector(whileOp.getInits());
1424 SmallVector<Type> newResultTypes = llvm::to_vector(whileOp.getResultTypes());
1425 rewriter.setInsertionPoint(whileOp);
1426 for (
const LoopPodSlot &slot : slots) {
1427 newInits.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult());
1428 newResultTypes.push_back(slot.type);
// Build the replacement loop without body builders; its regions are populated
// manually below. The original op's attributes are carried over.
1431 auto newWhile = rewriter.create<scf::WhileOp>(loc, newResultTypes, newInits,
nullptr,
nullptr);
1432 newWhile->setAttrs(whileOp->getAttrs());
// Drop any terminator already present in the fresh blocks; the loops below
// emit scf.condition / scf.yield explicitly.
1434 Block &newBeforeBody = *newWhile.getBeforeBody();
1435 Block &newAfterBody = *newWhile.getAfterBody();
1436 dropTerminatorIfPresent(newBeforeBody);
1437 dropTerminatorIfPresent(newAfterBody);
// Map the original before-region arguments onto the leading arguments of the
// new before block. The trailing arguments carry the slot values.
1439 IRMapping beforeMapping;
1440 for (
auto [oldArg, newArg] : llvm::zip_equal(
1441 whileOp.getBeforeArguments(),
1442 newWhile.getBeforeArguments().take_front(whileOp.getBeforeArguments().size())
1444 beforeMapping.map(oldArg, newArg);
// Current value of each slot inside the before region. It is seeded from the
// trailing block arguments (index base + idx) and updated in program order as
// slot writes are encountered below.
1447 SmallVector<Value> beforeSlotValues = llvm::map_to_vector(
1448 llvm::seq<size_t>(0, slots.size()),
1449 [base = whileOp.getBeforeArguments().size(), &newWhile](
size_t idx) -> Value {
1450 return newWhile.getBeforeArguments()[llzk::checkedCast<unsigned>(base + idx)];
// Rebuild the before region op-by-op into the new before block.
1454 rewriter.setInsertionPointToEnd(&newBeforeBody);
1455 for (Operation &op : beforeBody) {
// Terminator: re-create scf.condition with the remapped condition and args,
// followed by the current slot values.
1456 if (
auto conditionOp = dyn_cast<scf::ConditionOp>(&op)) {
1457 SmallVector<Value> conditionArgs =
1458 llvm::map_to_vector(conditionOp.getArgs(), [&beforeMapping](Value a) {
1459 return beforeMapping.lookupOrDefault(a);
1461 llvm::append_range(conditionArgs, beforeSlotValues);
1462 rewriter.create<scf::ConditionOp>(
1463 conditionOp.getLoc(), beforeMapping.lookupOrDefault(conditionOp.getCondition()),
// A read of a lifted slot: later uses of its result map to the slot's current value.
1469 if (
auto readOp = dyn_cast<ReadPodOp>(&op)) {
1470 if (std::optional<size_t> slotIdx =
1471 findLoopSlotIndex(slots, readOp.
getPodRef(), getRecordNameAsStringAttr(readOp))) {
1472 beforeMapping.map(readOp.
getResult(), beforeSlotValues[*slotIdx]);
// A write to a lifted slot: the (remapped) written value becomes the slot's current value.
1477 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
1478 if (std::optional<size_t> slotIdx =
1479 findLoopSlotIndex(slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1480 beforeSlotValues[*slotIdx] = beforeMapping.lookupOrDefault(writeOp.
getValue());
// Clone into the new block using the current mapping. The control flow that
// skips ops handled above is on lines not shown here; presumably only
// unhandled ops reach this point.
1485 rewriter.clone(op, beforeMapping);
// Same scheme for the after region: leading args map to the original ones and
// trailing args seed the per-slot values.
1488 IRMapping afterMapping;
1489 for (
auto [oldArg, newArg] : llvm::zip_equal(
1490 whileOp.getAfterArguments(),
1491 newWhile.getAfterArguments().take_front(whileOp.getAfterArguments().size())
1493 afterMapping.map(oldArg, newArg);
1496 SmallVector<Value> afterSlotValues = llvm::map_to_vector(
1497 llvm::seq<size_t>(0, slots.size()),
1498 [base = whileOp.getAfterArguments().size(), &newWhile](
size_t idx) -> Value {
1499 return newWhile.getAfterArguments()[llzk::checkedCast<unsigned>(base + idx)];
// Rebuild the after region op-by-op into the new after block.
1503 rewriter.setInsertionPointToEnd(&newAfterBody);
1504 for (Operation &op : afterBody) {
// Terminator: re-create scf.yield with the remapped operands followed by the
// current slot values.
1505 if (
auto yieldOp = dyn_cast<scf::YieldOp>(&op)) {
1506 SmallVector<Value> yieldValues =
1507 llvm::map_to_vector(yieldOp.getOperands(), [&afterMapping](Value v) {
1508 return afterMapping.lookupOrDefault(v);
1510 llvm::append_range(yieldValues, afterSlotValues);
1511 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
// A read of a lifted slot: later uses of its result map to the slot's current value.
1515 if (
auto readOp = dyn_cast<ReadPodOp>(&op)) {
1516 if (std::optional<size_t> slotIdx =
1517 findLoopSlotIndex(slots, readOp.
getPodRef(), getRecordNameAsStringAttr(readOp))) {
1518 afterMapping.map(readOp.
getResult(), afterSlotValues[*slotIdx]);
// A write to a lifted slot: the (remapped) written value becomes the slot's current value.
1523 if (
auto writeOp = dyn_cast<WritePodOp>(&op)) {
1524 if (std::optional<size_t> slotIdx =
1525 findLoopSlotIndex(slots, writeOp.
getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1526 afterSlotValues[*slotIdx] = afterMapping.lookupOrDefault(writeOp.
getValue());
// Clone into the new block using the current mapping (presumably only ops not
// handled above reach here; the skipping lines are not shown).
1531 rewriter.clone(op, afterMapping);
// After the loop, the extra result for slot idx (at getNumResults() + idx) is
// passed with the slot's pod reference and record name to a call whose name is
// on a line not shown here. Presumably it writes the final value back to the pod.
1534 rewriter.setInsertionPointAfter(newWhile);
1535 for (
auto [idx, slot] : llvm::enumerate(slots)) {
1537 loc, slot.podRef, slot.recordName, newWhile.getResult(whileOp.getNumResults() + idx),
// Replace the old loop using only the results it originally produced.
1542 rewriter.replaceOp(whileOp, newWhile.getResults().take_front(whileOp.getNumResults()));
/// Greedily applies `patterns` to the first region of `modOp`. The driver
/// config disables both folding and constant CSE. If `changed` is non-null,
/// it is forwarded to applyPatternsGreedily, which reports through it whether
/// the IR was modified.
applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns,
bool *changed =
nullptr) {
1550 return applyPatternsGreedily(
1551 modOp->getRegion(0), std::move(patterns),
1552 GreedyRewriteConfig {.fold = false, .cseConstants = false}, changed
/// Step 3: refactors pod ops inside SCF regions. It greedily applies (via
/// applyGreedily) the if/for/while pod patterns listed below to the whole module.
/// NOTE(review): the line that opens the pattern registration is not shown in
/// this listing; presumably it is `patterns.add<`.
1558static LogicalResult step3(ModuleOp modOp) {
1559 RewritePatternSet patterns(modOp.getContext());
1561 ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, LiftPodAccessesFromForLoopPattern,
1562 LiftPodAccessesFromWhileLoopPattern, FoldIfCarriedPodReadAfterWritePattern>(
1563 patterns.getContext()
1566 LLVM_DEBUG(llvm::dbgs() <<
"Begin step 3: refactor pod ops within SCF regions\n";);
1567 return applyGreedily(modOp, std::move(patterns));
/// Greedily applies only FoldIfCarriedPodReadAfterWritePattern to `modOp`.
/// `changed` records whether the driver modified the IR.
/// NOTE(review): the handling of a failed driver run and the final return
/// statement are on lines not shown in this listing. The caller in
/// runOnOperation treats the result as "something was folded", so presumably
/// `changed` is returned. Confirm against the full source.
1572static bool applyIfCarriedPodReadAfterWritePatterns(ModuleOp modOp) {
1573 RewritePatternSet patterns(modOp.getContext());
1574 patterns.add<FoldIfCarriedPodReadAfterWritePattern>(patterns.getContext());
1576 bool changed =
false;
1577 if (failed(applyGreedily(modOp, std::move(patterns), &changed))) {
/// Recursive "scalarization weight" of a type. For a PodType, the weight of
/// each record's type is added to the result. podAllocScalarizationWeight uses
/// this to measure how much pod structure is left to scalarize.
/// NOTE(review): the handling of non-PodType inputs and the initial value of
/// `weight` are on lines not shown in this listing. Confirm the base case
/// against the full source.
1585static size_t podTypeScalarizationWeight(Type type) {
1586 auto podTy = dyn_cast<PodType>(type);
// Recurse into every record so nested pods contribute their own weight.
1592 for (RecordAttr record : podTy.
getRecords()) {
1593 weight += podTypeScalarizationWeight(record.getType());
/// Sums podTypeScalarizationWeight over the type of every NewPodOp nested in
/// `modOp`. runOnOperation uses it as a progress measure for its SROA/Mem2Reg
/// loop. The loop runs while the weight is nonzero and (presumably) exits when
/// an iteration neither folds anything nor changes the weight.
/// NOTE(review): the declaration of `weight` and the return are not shown in
/// this listing; presumably `weight` starts at 0 and is returned.
1601static size_t podAllocScalarizationWeight(ModuleOp modOp) {
1603 modOp.walk([&weight](
NewPodOp newPodOp) {
1604 weight += podTypeScalarizationWeight(newPodOp.getType());
1611 using Base = PodToScalarPassBase<PassImpl>;
1614 void runOnOperation()
override {
1615 ModuleOp module = getOperation();
1617 if (failed(step0(module))) {
1618 return signalPassFailure();
1621 llvm::dbgs() <<
"After step 0:\n";
1630 SymbolTableCollection symTables;
1631 MemberReplacementMap memberRepMap;
1632 if (failed(step1(module, symTables, memberRepMap))) {
1633 return signalPassFailure();
1636 llvm::dbgs() <<
"After step 1:\n";
1640 if (failed(step2(module, symTables, memberRepMap))) {
1641 return signalPassFailure();
1644 llvm::dbgs() <<
"After step 2:\n";
1649 if (failed(step3(module))) {
1650 return signalPassFailure();
1653 llvm::dbgs() <<
"After step 3:\n";
1661 OpPassManager scalarizePM(ModuleOp::getOperationName());
1666 OpPassManager cleanupPM(ModuleOp::getOperationName());
1667 cleanupPM.addPass(createRemoveDeadValuesPass());
1669 size_t podAllocWeight = podAllocScalarizationWeight(module);
1670 while (podAllocWeight != 0) {
1671 if (failed(runPipeline(scalarizePM, module))) {
1672 signalPassFailure();
1679 bool foldedIfCarriedRead = applyIfCarriedPodReadAfterWritePatterns(module);
1680 if (failed(runPipeline(cleanupPM, module))) {
1681 signalPassFailure();
1688 size_t nextPodAllocWeight = podAllocScalarizationWeight(module);
1689 if (!foldedIfCarriedRead && nextPodAllocWeight == podAllocWeight) {
1692 podAllocWeight = nextPodAllocWeight;
1695 llvm::dbgs() <<
"After SROA+Mem2Reg pipeline:\n";
within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names
Provides SpecializedSROA<AllocOpTy> and SpecializedMem2Reg<AllocOpTy>: pass templates that replicate ...
Common implementation for handling MemberWriteOp and MemberReadOp while destructuring an aggregate ty...
void setPublicAttr(bool newValue=true)
::mlir::StringAttr getSymNameAttr()
::mlir::TypedValue<::mlir::Type > getVal()
::llvm::SmallVector< RangeT > getMapOperands()
::mlir::MutableOperandRange getArgOperandsMutable()
::mlir::Operation::operand_range getArgOperands()
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
::mlir::MutableOperandRange getOperandsMutable()
::mlir::Operation::operand_range getOperands()
::mlir::Operation::operand_range getInitialValues()
::mlir::TypedValue<::llzk::pod::PodType > getResult()
void setInitializedRecordsAttr(::mlir::ArrayAttr attr)
::mlir::MutableOperandRange getInitialValuesMutable()
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::mlir::TypedValue<::mlir::Type > getResult()
::mlir::FlatSymbolRefAttr getRecordNameAttr()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
::mlir::FlatSymbolRefAttr getRecordNameAttr()
::mlir::TypedValue<::mlir::Type > getValue()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
std::unique_ptr< SpecializedMem2Reg< AllocOpTy > > createSpecializedMem2RegPass()
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::unique_ptr< SpecializedSROA< AllocOpTy > > createSpecializedSROAPass()
static RecordChain getTombstoneKey()
static unsigned getHashValue(const RecordChain &chain)
static RecordChain getEmptyKey()
static bool isEqual(const RecordChain &lhs, const RecordChain &rhs)
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes