// Returns true when `op` has no table offset, or its table offset is not a
// SymbolRefAttr. `value_or(nullptr)` turns an absent offset into a null
// attribute, which isa_and_present treats as "not a SymbolRefAttr".
253static inline bool tableOffsetIsntSymbol(
MemberReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
// StructCloner state (the enclosing class header is not in this excerpt):
// - tracker_: shared ConversionTracker; genClone obtains the delayed
//   diagnostic set for each new clone from it.
// - symTables: symbol-table cache, passed to StructType::getDefinition and
//   used to insert cloned structs into their parent module.
// - reportMissing: forwarded to getDefinition; toggled by
//   enableReportMissing()/disableReportMissing().
260 ConversionTracker &tracker_;
262 SymbolTableCollection symTables;
263 bool reportMissing =
true;
// Type converter used to rewrite the body of a cloned struct:
// - `origTy` (the struct type being cloned) is special-cased; that branch's
//   body is not visible in this excerpt (presumably it yields `newTy`).
// - Other struct types (inferred from getParams/getNameRef; the lambda
//   signature is not visible) get their parameters rewritten: TypeAttr
//   params are converted recursively, all others go through
//   convertIfPossible.
// - ArrayType dimension sizes go through convertIfPossible and the element
//   type is converted recursively.
// - A TypeVarType whose name maps to a TypeAttr becomes the wrapped type.
// paramNameToValue is stored by reference, so the map handed to the
// constructor must outlive this converter.
265 class MappedTypeConverter :
public TypeConverter {
268 const DenseMap<Attribute, Attribute> ¶mNameToValue;
// Returns the value mapped to `a`, or `a` itself when no mapping exists.
270 inline Attribute convertIfPossible(Attribute a)
const {
271 auto res = this->paramNameToValue.find(a);
272 return (res != this->paramNameToValue.end()) ? res->second : a;
279 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
281 : TypeConverter(), origTy(originalType), newTy(newType),
282 paramNameToValue(paramNameToInstantiatedValue) {
// Identity fallback. MLIR's TypeConverter tries the most recently added
// conversion first, so this one only applies when no later one matches.
284 addConversion([](Type inputTy) {
return inputTy; });
287 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
// Special case: the struct type that is being cloned.
290 if (inputTy == this->origTy) {
// Rebuild the parameter list: nested types are converted recursively and
// all other attributes are substituted via convertIfPossible.
294 if (ArrayAttr inputTyParams = inputTy.getParams()) {
295 SmallVector<Attribute> updated;
296 for (Attribute a : inputTyParams) {
297 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
298 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
300 updated.push_back(convertIfPossible(a));
304 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
// Arrays: substitute dimension-size parameters and convert the element type.
311 addConversion([
this](
ArrayType inputTy) {
313 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
314 if (!dimSizes.empty()) {
315 SmallVector<Attribute> updated;
316 for (Attribute a : dimSizes) {
317 updated.push_back(convertIfPossible(a));
319 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
// Type variables: use the concrete type when the variable's name maps to a
// TypeAttr.
325 addConversion([
this](
TypeVarType inputTy) -> Type {
327 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
328 Type convertedType = tyAttr.getValue();
333 return convertedType;
// CRTP base for conversion patterns on ops that refer to a struct parameter
// by name. Impl supplies getNameAttr(op) plus a handleRewrite overload for
// each attribute type listed in HandledAttrs.
// matchAndRewrite looks the name up in paramNameToValue. When a mapping
// exists, the mapped attribute is dispatched with llvm::TypeSwitch:
// - each HandledAttrs type goes to Impl::handleRewrite;
// - anything else goes to handleDefaultRewrite (virtual; the base version
//   emits an "expected value with type ... but found ..." op error).
// An unmapped name is only logged here; the return on that path is not
// visible in this excerpt.
// paramNameToValue is stored by reference and must outlive the pattern.
341 template <
typename Impl,
typename Op,
typename... HandledAttrs>
342 class SymbolUserHelper :
public OpConversionPattern<Op> {
344 const DenseMap<Attribute, Attribute> ¶mNameToValue;
347 TypeConverter &converter, MLIRContext *ctx,
unsigned Benefit,
348 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
350 : OpConversionPattern<Op>(converter, ctx, Benefit),
351 paramNameToValue(paramNameToInstantiatedValue) {}
354 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
// Returns the attribute naming the parameter `op` refers to; this is the
// lookup key into paramNameToValue.
356 virtual Attribute getNameAttr(Op)
const = 0;
// Fallback for mapped values whose attribute type is not in HandledAttrs.
358 virtual LogicalResult handleDefaultRewrite(
359 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
361 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
365 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
366 LLVM_DEBUG(llvm::dbgs() <<
"[SymbolUserHelper] op: " << op <<
'\n');
367 auto res = this->paramNameToValue.find(getNameAttr(op));
368 if (res == this->paramNameToValue.end()) {
369 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] no instantiation for " << op <<
'\n');
// Dispatch on the mapped attribute. One Case is registered per HandledAttrs
// type via a comma fold expression over the pack; `ptr` follows the switch
// returned by each Case() call. The switch is then finished with Default.
372 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
373 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
375 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
376 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
380 return TS.Default([&](Attribute a) {
381 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
// Rewrites ConstReadOps in a cloned struct body whose constant name has a
// concrete instantiated value (IntegerAttr or FeltConstAttr). Registered
// with benefit 2. Warnings produced while rewriting are appended to
// `diagnostics` (a caller-owned list held by reference) rather than being
// emitted immediately.
387 class ClonedStructConstReadOpPattern
388 :
public SymbolUserHelper<
389 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
390 SmallVector<Diagnostic> &diagnostics;
393 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
396 ClonedStructConstReadOpPattern(
397 TypeConverter &converter, MLIRContext *ctx,
398 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
399 SmallVector<Diagnostic> &instantiationDiagnostics
403 : super(converter, ctx, 2, paramNameToInstantiatedValue),
404 diagnostics(instantiationDiagnostics) {}
// IntegerAttr value: the replacement depends on the op's result type
// (felt, index, or i1). `sym` is the parameter name and is mentioned in the
// i1 warning note.
408 LogicalResult handleRewrite(
409 Attribute sym,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
411 APInt attrValue = a.getValue();
412 Type origResTy = op.getType();
// felt result: the replacement is built from a FeltConstAttr holding the value.
413 if (llvm::isa<FeltType>(origResTy)) {
415 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
420 if (llvm::isa<IndexType>(origResTy)) {
// i1 result: zero is handled first (body not visible here). Any other value
// that is not one is recorded as a Warning in `diagnostics`.
425 if (origResTy.isSignlessInteger(1)) {
427 if (attrValue.isZero()) {
431 if (!attrValue.isOne()) {
432 Location opLoc = op.getLoc();
433 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
435 if (getContext()->shouldPrintOpOnDiagnostic()) {
436 diag.attachNote(opLoc) <<
"see current operation: " << *op;
438 diag.attachNote(UnknownLoc::get(getContext()))
440 << sym <<
"\" for this call";
441 diagnostics.push_back(std::move(diag));
// Error path for result types not handled above.
446 return op->emitOpError().append(
"unexpected result type ", origResTy);
449 LogicalResult handleRewrite(
450 Attribute,
ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
// Rewrites MemberReadOps in a cloned struct body in place when the parameter
// they refer to has a concrete instantiated value (IntegerAttr or
// FeltConstAttr). The lookup key is presumably the table-offset symbol;
// getNameAttr is not visible in this excerpt. Registered with benefit 2.
// Reads whose table offset is absent or not a symbol take a separate early
// path in matchAndRewrite (body not visible). All other reads go through
// SymbolUserHelper::matchAndRewrite.
457 class ClonedStructMemberReadOpPattern
458 :
public SymbolUserHelper<
459 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
461 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
464 ClonedStructMemberReadOpPattern(
465 TypeConverter &converter, MLIRContext *ctx,
466 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
471 : super(converter, ctx, 2, paramNameToInstantiatedValue) {}
// One template serves both handled attribute types; the op is modified in
// place.
477 template <
typename Attr>
478 LogicalResult handleRewrite(
479 Attribute,
MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
481 rewriter.modifyOpInPlace(op, [&]() {
488 LogicalResult matchAndRewrite(
489 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
492 llvm::dbgs() <<
"[ClonedStructMemberReadOpPattern] MemberReadOp: " << op <<
'\n';
// Reads without a symbolic table offset bypass the parameter lookup.
494 if (tableOffsetIsntSymbol(op)) {
498 return super::matchAndRewrite(op, adaptor, rewriter);
// Clones the StructDefOp referenced by `typeAtCaller`, substituting every
// parameter whose caller-side value is concrete (isConcreteAttr<false>).
// Steps:
// - Parameters that stay symbolic are collected in
//   reducedParamNameList/reducedCallerParams (presumably the clone's new
//   parameter list; the creation lines are not visible).
// - The clone is inserted into the parent module's symbol table at the
//   original struct's position.
// - Its body is rewritten with applyFullConversion, using
//   MappedTypeConverter and the Cloned*Pattern patterns.
// - The caller-side type of the clone (newRemoteType) is returned.
// The function logs and skips when the definition cannot be found or no
// parameter is concrete; the values returned on those paths are not visible.
502 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
503 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] attempting clone of " << typeAtCaller <<
'\n');
505 FailureOr<SymbolLookupResult<StructDefOp>> r =
506 typeAtCaller.
getDefinition(symTables, rootMod, reportMissing);
508 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
511 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] found definition\n";);
515 MLIRContext *ctx = origStruct.getContext();
// Parameter name (from the definition) -> concrete value (from the caller).
518 DenseMap<Attribute, Attribute> paramNameToConcrete;
// One entry per parameter: the concrete value, or nullptr when the
// parameter stays symbolic. It is passed along with the original leaf name,
// presumably to build the clone's name.
522 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
// Names/values of parameters that remain symbolic; left null when every
// parameter is concrete.
525 ArrayAttr reducedParamNameList =
nullptr;
527 ArrayAttr reducedCallerParams =
nullptr;
529 ArrayAttr paramNames = typeAtDef.
getParams();
533 assert(paramNames.size() == typeAtCallerParams.size());
535 SmallVector<Attribute> remainingNames;
536 SmallVector<Attribute> nonConcreteParams;
537 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
538 Attribute next = typeAtCallerParams[i];
539 if (isConcreteAttr<false>(next)) {
540 paramNameToConcrete[paramNames[i]] = next;
541 attrsForInstantiatedNameSuffix.push_back(next);
543 remainingNames.push_back(paramNames[i]);
544 nonConcreteParams.push_back(next);
545 attrsForInstantiatedNameSuffix.push_back(
nullptr);
549 assert(remainingNames.size() == nonConcreteParams.size());
550 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
551 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
// Nothing to instantiate when no parameter is concrete.
553 if (paramNameToConcrete.empty()) {
554 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
557 if (!remainingNames.empty()) {
558 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
559 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
569 typeAtCaller.
getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
// Register the clone in the parent module's symbol table, positioned at the
// original struct.
575 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>();
576 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
// Build the caller-side symbol path for the clone. If the caller's path has
// nested references, its last leaf is replaced with the clone's name.
// Otherwise the root reference is replaced (the else branch is partly
// outside this excerpt).
580 auto typeAtCallerSym = typeAtCaller.
getNameRef();
582 SmallVector<FlatSymbolRefAttr> newLeafs(typeAtCallerSym.getNestedReferences());
583 auto rootSym = typeAtCallerSym.getRootReference();
584 if (!newLeafs.empty()) {
586 newLeafs.back() = FlatSymbolRefAttr::get(newLocalType.
getNameRef().getLeafReference());
589 rootSym = newLocalType.
getNameRef().getLeafReference();
594 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
595 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
596 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
597 llvm::dbgs() <<
"[StructCloner] cloned local type: " << newLocalType <<
'\n';
598 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
// Rewrite the clone's body: substitute concrete parameter values in types,
// ConstReadOps, and MemberReadOps.
604 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
// Ops whose const name has a concrete value are illegal, so the patterns
// below must rewrite them. The op kind is declared on lines not visible;
// presumably ConstReadOp.
605 ConversionTarget target =
609 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
// Warnings are queued in the tracker's delayed-diagnostic set for the new
// type instead of being emitted here.
613 patterns.add<ClonedStructConstReadOpPattern>(
614 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
616 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
617 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
618 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
621 return newRemoteType;
// Creates a cloner that uses `root` as the lookup root for struct
// definitions and `tracker` for delayed diagnostics. The symbol-table cache
// starts empty.
625 StructCloner(ConversionTracker &tracker, ModuleOp root)
626 : tracker_(tracker), rootMod(root), symTables() {}
// Clones the struct referenced by `orig`, substituting its concrete
// parameter values (see genClone). Types without a parameter list are only
// logged; the return on that path is not visible in this excerpt.
628 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
629 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
630 if (ArrayAttr params = orig.
getParams()) {
631 return genClone(orig, params.getValue());
633 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
// Control the `reportMissing` flag that genClone passes to
// StructType::getDefinition when looking up the struct to clone.
637 void enableReportMissing() { reportMissing =
true; }
639 void disableReportMissing() { reportMissing =
false; }
// Forward declaration so DisableReportMissing can be befriended below.
642class DisableReportMissing;
// Type converter that maps uses of parameterized struct types to
// instantiated clones:
// - The tracker is consulted first for an existing instantiation
//   (getInstantiation).
// - Otherwise StructCloner creates a clone and the new type is recorded
//   with recordInstantiation.
// - Array types are converted by converting their element type.
// - Any other type is left unchanged (identity fallback).
// DisableReportMissing is a friend so it can reach `cloner`.
644class ParameterizedStructUseTypeConverter :
public TypeConverter {
645 ConversionTracker &tracker_;
648 friend DisableReportMissing;
651 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
652 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
// Identity fallback; conversions added later are tried before this one.
654 addConversion([](Type inputTy) {
return inputTy; });
658 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] attempting conversion of "
662 if (
auto opt = tracker_.getInstantiation(inputTy)) {
668 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
669 if (failed(cloneRes)) {
674 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
675 <<
" as " << newTy <<
'\n'
677 tracker_.recordInstantiation(inputTy, newTy);
// Arrays: keep the shape and convert only the element type.
681 addConversion([
this](
ArrayType inputTy) {
682 return inputTy.cloneWith(convertType(inputTy.getElementType()));
// Updates CallOps whose result types change under the type converter
// (registered with benefit 2):
// - The result types are converted; failure emits "Could not convert Op
//   result types.".
// - When a struct result was instantiated, the callee symbol is rebuilt
//   under the new struct's name (appendLeaf).
// - On one of the two visible branches, delayed diagnostics recorded for
//   that instantiation are reported against this call.
// - The op is then replaced (presumably by a new CallOp) built from the
//   converted result types and the adaptor's map operands.
687class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
688 ConversionTracker &tracker_;
691 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
694 : OpConversionPattern<CallOp>(converter, ctx, 2), tracker_(tracker) {}
696 LogicalResult matchAndRewrite(
697 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
699 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
702 SmallVector<Type> newResultTypes;
703 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
704 return op->emitError(
"Could not convert Op result types.");
707 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
717 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
718 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
// Diagnostics queued while instantiating newStTy are surfaced at this call.
719 tracker_.reportDelayedDiagnostics(newStTy, op);
723 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
724 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
728 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
730 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
734 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
// Updates a MemberDefOp's type in place to its converted type (benefit 2).
// The unnamed ConversionTracker parameter is unused. It lets this pattern be
// registered with the same constructor arguments as CallStructFuncPattern.
// Ops whose type does not change take an early path (body not visible).
740class MemberDefOpPattern :
public OpConversionPattern<MemberDefOp> {
742 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
745 : OpConversionPattern<MemberDefOp>(converter, ctx, 2) {}
747 LogicalResult matchAndRewrite(
748 MemberDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
750 LLVM_DEBUG(llvm::dbgs() <<
"[MemberDefOpPattern] MemberDefOp: " << op <<
'\n');
752 Type oldMemberType = op.
getType();
753 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
754 if (oldMemberType == newMemberType) {
758 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.
setType(newMemberType); });
// DisableReportMissing body; the class header and its base class are not in
// this excerpt. It overrides checkStarted/checkEnded to turn off the
// cloner's missing-definition reporting while a check runs and to turn it
// back on when the check ends. The referenced converter must outlive this
// object.
766 ParameterizedStructUseTypeConverter &tyConv;
769 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
771 void checkStarted()
override { tyConv.cloner.disableReportMissing(); }
773 void checkEnded(
bool)
override { tyConv.cloner.enableReportMissing(); }
// Converts uses of parameterized struct types in `modOp` to their
// instantiated clones with applyPartialConversion. It uses
// ParameterizedStructUseTypeConverter plus the CallStructFuncPattern and
// MemberDefOpPattern patterns. `drm` suppresses missing-definition reports
// during checks; how it is attached (presumably via `target`) is on lines
// not visible in this excerpt.
776LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
777 MLIRContext *ctx = modOp.getContext();
778 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
779 DisableReportMissing drm(tyConv);
782 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
783 return applyPartialConversion(modOp, target, std::move(patterns));
// Collects the constant integer value of each OpFoldResult in `ofrs`, in
// order. When an entry is not a constant, the early-exit path is taken; its
// body is not visible, but given the return type it presumably returns
// std::nullopt.
838std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
839 SmallVector<int64_t> res;
840 for (OpFoldResult ofr : ofrs) {
841 std::optional<int64_t> cv = getConstantIntValue(ofr);
842 if (!cv.has_value()) {
845 res.push_back(cv.value());
// Folds affine_map entries of a parameter list (struct type parameters or
// array dimensions) whose map operands are all constant.
850struct AffineMapFolder {
// Input fields (nested struct header not visible): non-owning views of the
// op's map operand groups, the dimension count per group, and the
// parameter list.
852 OperandRangeRange mapOpGroups;
853 DenseI32ArrayAttr dimsPerGroup;
854 ArrayRef<Attribute> paramsOfStructTy;
// Output fields (nested struct header not visible): owned results filled in
// by fold(). They hold the rewritten parameter list plus the operand groups
// and dimension counts of maps that were not folded.
858 SmallVector<SmallVector<Value>> mapOpGroups;
859 SmallVector<int32_t> dimsPerGroup;
860 SmallVector<Attribute> paramsOfStructTy;
// Exposes out.mapOpGroups as ValueRanges for building the replacement op.
// Each returned ValueRange points into one of `out`'s vectors without owning
// it, so `out` is taken by const reference. With a by-value parameter the
// ranges would point into a copy whose lifetime is implementation-defined
// and may end when this function returns, and every call would also
// deep-copy all operand groups. Callers must keep `out` alive while the
// result is in use.
863 static inline SmallVector<ValueRange> getConvertedMapOpGroups(const Output &out) {
864 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
865 return ValueRange(grp);
// Walks `in.paramsOfStructTy` and writes the rewritten list to `out`.
// - Each AffineMapAttr consumes the next operand group of `in.mapOpGroups`.
//   If all operands of that group are constant, the map is folded to a
//   single attribute (index constants are used as fold inputs).
// - Otherwise the map's operand group and dimension count are carried over
//   to `out`.
// - Poison (divide by 0 / modulus with a negative divisor), a failed fold,
//   or a result count other than 1 are reported with messages that include
//   `aspect` and the current output index.
// The empty-input early path (body not visible) is taken when there are no
// map operand groups.
870 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
871 if (in.mapOpGroups.empty()) {
// At most one operand group per parameter, and one dimension count per group.
876 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
877 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
880 for (Attribute sizeAttr : in.paramsOfStructTy) {
881 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
882 ValueRange currMapOps = in.mapOpGroups[idx++];
887 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
889 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
// All operands constant: fold the map with index-attribute inputs.
892 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
893 SmallVector<Attribute> result;
894 bool hasPoison =
false;
895 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
896 return rewriter.getIndexAttr(v);
898 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
903 "Cannot fold affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
904 " due to divide by 0 or modulus with negative divisor"
909 if (failed(foldResult)) {
913 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" failed"
// A single parameter/dimension must fold to exactly one value.
918 if (result.size() != 1) {
922 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
923 " produced ", result.size(),
" results but expected 1"
928 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
929 out.paramsOfStructTy.push_back(result[0]);
// Keep the operand group and dimension count of a map that stays unfolded.
933 out.mapOpGroups.emplace_back(currMapOps);
934 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
// Copy the original attribute when it was not replaced by a folded value
// (the surrounding braces are not visible here).
937 out.paramsOfStructTy.push_back(sizeAttr);
939 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
941 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
942 "produced wrong number of dimensions"
// Folds affine_map array dimensions of a CreateArrayOp's result type whose
// map operands are constant. When the resulting type differs, the op is
// replaced with one of the new type, passing the remaining (unfolded) map
// operand groups and their dimension counts. The type change is asserted to
// be a legal conversion.
950class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
952 ConversionTracker &tracker_;
955 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
956 : OpRewritePattern(ctx), tracker_(tracker) {}
958 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
961 AffineMapFolder::Output out;
962 AffineMapFolder::Input in = {
967 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
972 if (newResultType == oldResultType) {
977 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
979 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
980 << newResultType <<
" in \"" << op <<
"\"\n"
983 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
// Instantiates the struct result type of a CallOp. The exact match
// condition is on lines not visible; the name suggests calls to a struct's
// compute function.
// - First, affine_map parameters with constant operands are folded.
// - Then, if the call has arguments, the callee's signature is unified with
//   the call's argument types to propagate concrete values into result
//   parameters that are still symbolic.
// - If the resulting type differs and the tracker deems the conversion
//   legal, the call is replaced. Otherwise a "result type mismatch" match
//   failure is reported.
990class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
991 ConversionTracker &tracker_;
994 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
995 : OpRewritePattern(ctx), tracker_(tracker) {}
997 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1002 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
1004 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
1005 ArrayAttr params = oldRetTy.
getParams();
1011 AffineMapFolder::Output out;
1012 AffineMapFolder::Input in = {
// Step 1: fold affine_map result-type parameters.
1017 if (!in.mapOpGroups.empty()) {
1019 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
1023 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
// Step 2: propagate concrete values from the callee via unification
// (branch for calls without arguments not visible).
1029 if (callArgTypes.empty()) {
1033 SymbolTableCollection tables;
1035 if (failed(lookupRes)) {
1038 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1042 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1043 "result type params: "
1049 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
1050 if (newRetTy == oldRetTy) {
1058 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
1059 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1061 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1062 ", but found ", oldRetTy
1066 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1068 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1069 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1072 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
// Fills out.paramsOfStructTy from the callee's declared result type.
// - When every call-site parameter is already concrete, the early path is
//   taken (body not visible).
// - Otherwise the callee's signature is unified with `callArgTypes`; success
//   is asserted as already guaranteed by verifiers.
// - For each non-concrete call-site parameter, the corresponding callee
//   parameter (asserted to be a SymbolRefAttr) is looked up on the LHS side
//   of the unifications. A concrete unified value is presumably used in its
//   place.
// Post-conditions: same parameter count as `in`, and no map operand groups
// or dimension counts in `out`.
1079 inline LogicalResult instantiateViaTargetType(
1080 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1081 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1086 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1088 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1094 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1096 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1098 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1100 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1107 assert(unifies &&
"should have been checked by verifiers");
1110 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1119 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1120 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1121 [&unifications](std::tuple<Attribute, Attribute> p) {
1122 Attribute fromCall = std::get<1>(p);
1125 if (!isConcreteAttr<>(fromCall)) {
1126 Attribute fromTgt = std::get<0>(p);
1128 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1129 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1131 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1132 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1133 if (it != unifications.end()) {
1134 Attribute unifiedAttr = it->second;
1136 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1138 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1147 out.paramsOfStructTy = newReturnStructParams;
1148 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1149 assert(out.mapOpGroups.empty() &&
"post-condition");
1150 assert(out.dimsPerGroup.empty() &&
"post-condition");
// Runs the affine-map instantiation patterns (InstantiateAtCreateArrayOp,
// InstantiateAtCallOpCompute, and any others added on lines not visible)
// over `modOp` via applyAndFoldGreedily.
1155LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1156 MLIRContext *ctx = modOp.getContext();
1157 RewritePatternSet patterns(ctx);
1159 InstantiateAtCreateArrayOp,
1160 InstantiateAtCallOpCompute
1163 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
// Infers a CreateArrayOp's element type from the WriteArrayOps that write
// into its result (benefit 3).
// - Writes whose array operand is not this result, and writes of the old
//   element type, take skip paths (bodies not visible).
// - The remaining writes must agree on one value type; a conflict is logged
//   as "multiple possible element types".
// - If the tracker allows converting the old element type to the new one,
//   the result Value's type is changed in place. This changes the type seen
//   by every user of that value.
1171class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1172 ConversionTracker &tracker_;
1175 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1176 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1178 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1180 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1181 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1186 Type newResultElemType =
nullptr;
1187 for (Operation *user : createResult.getUsers()) {
1188 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1189 if (writeOp.getArrRef() != createResult) {
1192 Type writeRValueType = writeOp.getRvalue().getType();
1193 if (writeRValueType == oldResultElemType) {
1196 if (newResultElemType && newResultElemType != writeRValueType) {
1199 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1200 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1204 newResultElemType = writeRValueType;
1207 if (!newResultElemType) {
1211 if (!tracker_.isLegalConversion(
1212 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1217 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1219 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
// Shared helper for the array read/write patterns. It sets the type of the
// accessed array Value (op.getArrRef()) in place to `newArrType`; how that
// type is derived from the accessed element type is on lines not visible.
// Nothing changes when the type is unchanged or the tracker rejects the
// conversion. setType on the Value affects every user of that array value.
1227LogicalResult updateArrayElemFromArrAccessOp(
1229 PatternRewriter &rewriter
1236 if (oldArrType == newArrType ||
1237 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1240 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1242 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
// Updates the written array's type from the type of the value stored by a
// WriteArrayOp (benefit 3); delegates to updateArrayElemFromArrAccessOp.
1249class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1250 ConversionTracker &tracker_;
1253 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1254 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1256 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1257 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
// Updates the read array's type from the result type of a ReadArrayOp
// (benefit 3); delegates to updateArrayElemFromArrAccessOp.
1261class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1262 ConversionTracker &tracker_;
1265 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1266 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1268 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1269 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
// Infers a MemberDefOp's type from the MemberWriteOps among its symbol uses
// (collected on lines not visible); benefit 3.
// - When two writes store different types, the pattern keeps the type that
//   the other can legally be converted into.
// - If neither direction is legal, it fails with a note at each conflicting
//   write.
// - The op's type is updated in place when the chosen type differs from the
//   current one and the conversion from the current type is legal.
1274class UpdateMemberDefTypeFromWrite final :
public OpRewritePattern<MemberDefOp> {
1275 ConversionTracker &tracker_;
1278 UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1279 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1281 LogicalResult matchAndRewrite(
MemberDefOp op, PatternRewriter &rewriter)
const override {
1284 assert(succeeded(parentRes) &&
"MemberDefOp parent is always StructDefOp");
// Candidate type and the location of the write that established it (used
// in the conflict diagnostic).
1288 Type newType =
nullptr;
1290 std::optional<Location> newTypeLoc = std::nullopt;
1291 for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
1292 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
1293 Type writeToType = writeOp.getVal().getType();
1294 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] checking " << writeOp <<
'\n');
1297 newType = writeToType;
1298 newTypeLoc = writeOp.getLoc();
1299 }
else if (writeToType != newType) {
// Keep the current candidate if this write's type converts into it;
// switch to this write's type if the candidate converts into it; otherwise
// the writes conflict.
1305 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateMemberDefTypeFromWrite")) {
1306 if (tracker_.isLegalConversion(
1307 newType, writeToType,
"UpdateMemberDefTypeFromWrite"
1310 newType = writeToType;
1311 newTypeLoc = writeOp.getLoc();
1314 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1318 "' with different value types"
1321 diag.attachNote(*newTypeLoc).append(
"type written here is ", newType);
1323 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1331 if (!newType || newType == op.
getType()) {
1335 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateMemberDefTypeFromWrite")) {
1338 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1339 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] updated type of " << op <<
'\n');
// Moves the body of each of `op`'s regions into a newly allocated Region and
// collects them in order, for handing to a replacement op. Region::takeBody
// leaves `op`'s regions empty.
1346SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1347 SmallVector<std::unique_ptr<Region>> newRegions;
1348 for (Region ®ion : op->getRegions()) {
1349 auto newRegion = std::make_unique<Region>();
1350 newRegion->takeBody(region);
1351 newRegions.push_back(std::move(newRegion));
// For ops with the InferTypeOpAdaptor trait, re-runs return-type inference
// (benefit 6). When the inferred types differ from the current result
// types and the tracker allows the conversions, the op is recreated with
// the inferred types and replaces the original. The new op keeps the same
// name, operands, attributes, and successors, and receives the original
// regions by move.
1360class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1361 ConversionTracker &tracker_;
1364 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1365 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1367 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1368 SmallVector<Type, 1> inferredResultTypes;
1369 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1370 LogicalResult result = retTypeFn.inferReturnTypes(
1371 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1372 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1374 if (failed(result)) {
1377 if (op->getResultTypes() == inferredResultTypes) {
1381 if (!tracker_.areLegalConversions(
1382 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1388 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
// Regions are moved out of `op` (not cloned) before building the
// replacement.
1389 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1390 Operation *newOp = rewriter.create(
1391 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1392 op->getAttrs(), op->getSuccessors(), newRegions
1394 rewriter.replaceOp(op, newOp);
1395 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
// Sets a FuncDefOp's result types to the operand types of the ReturnOp that
// terminates the last block of its body (benefit 3). This happens only when
// they differ and the tracker deems the conversions legal. The input types
// are kept and the function type is updated in place.
1401class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1402 ConversionTracker &tracker_;
1405 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1406 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1408 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1409 Region &body = op.getFunctionBody();
1413 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1414 assert(retOp &&
"final op in body region must be return");
1415 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1418 if (oldFuncTy.getResults() == tyFromReturnOp) {
1422 if (!tracker_.areLegalConversions(
1423 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1428 rewriter.modifyOpInPlace(op, [&]() {
1429 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1432 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
// Aligns a CallOp's result types with the declared result types of the
// FuncDefOp it calls (looked up by symbol; benefit 3). The call is replaced
// when the types differ and the tracker deems the conversions legal. The
// name suggests this targets calls to global (non-struct) functions; that
// filter is on lines not visible in this excerpt.
1443class UpdateGlobalCallOpTypes final :
public OpRewritePattern<CallOp> {
1444 ConversionTracker &tracker_;
1447 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1448 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1450 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1451 SymbolTableCollection tables;
1453 if (failed(lookupRes)) {
1456 FuncDefOp targetFunc = lookupRes->get();
1461 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1465 if (!tracker_.areLegalConversions(
1467 "UpdateGlobalCallOpTypes"
1472 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateGlobalCallOpTypes] replaced " << op);
1475 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
// Shared helper for the member read/write patterns. It sets the type of the
// op's value (op.getVal()) in place to the type of the referenced
// MemberDefOp (looked up on lines not visible). Nothing changes when the
// types match or the tracker rejects the conversion. setType on the Value
// affects every user of that value.
1482LogicalResult updateMemberRefValFromMemberDef(
1485 SymbolTableCollection tables;
1490 Type oldResultType = op.
getVal().getType();
1491 Type newResultType = def->get().getType();
1492 if (oldResultType == newResultType ||
1493 !tracker.isLegalConversion(oldResultType, newResultType,
"updateMemberRefValFromMemberDef")) {
1496 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
1498 llvm::dbgs() <<
"[updateMemberRefValFromMemberDef] updated value type in " << op <<
'\n'
// Propagates a MemberDefOp's type to the value of a MemberReadOp that
// references it (benefit 3); delegates to updateMemberRefValFromMemberDef.
1506class UpdateMemberReadValFromDef final :
public OpRewritePattern<MemberReadOp> {
1507 ConversionTracker &tracker_;
1510 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1511 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1513 LogicalResult matchAndRewrite(
MemberReadOp op, PatternRewriter &rewriter)
const override {
1514 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
// Propagates a MemberDefOp's type to the value of a MemberWriteOp that
// references it (benefit 3); delegates to updateMemberRefValFromMemberDef.
1519class UpdateMemberWriteValFromDef final :
public OpRewritePattern<MemberWriteOp> {
1520 ConversionTracker &tracker_;
1523 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1524 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1526 LogicalResult matchAndRewrite(
MemberWriteOp op, PatternRewriter &rewriter)
const override {
1527 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
// Runs the type-propagation patterns over `modOp` via applyAndFoldGreedily.
// The visible patterns cover inferred result types, call/function result
// types, array element types, and member definition/value types; some
// registrations are on lines not visible.
1531LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1532 MLIRContext *ctx = modOp.getContext();
1533 RewritePatternSet patterns(ctx);
1538 UpdateInferredResultTypes,
1540 UpdateGlobalCallOpTypes,
1541 UpdateFuncTypeFromReturn,
1542 UpdateNewArrayElemFromWrite,
1543 UpdateArrayElemFromArrRead,
1544 UpdateArrayElemFromArrWrite,
1545 UpdateMemberDefTypeFromWrite,
1546 UpdateMemberReadValFromDef,
1547 UpdateMemberWriteValFromDef
1550 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
// CleanupBase fragment (class header not visible). It holds a symbol-table
// cache for lookups, plus the root module, symbol def tree, and symbol use
// graph supplied at construction.
1558 SymbolTableCollection tables;
1561 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
// Erases StructDefOps that cannot be reached from the `keep` structs.
// - Roots start as `keep`, plus (presumably) FuncDefOps that are not inside
//   a struct; the walk body is partly not visible.
// - For each root, the def tree is walked depth-first. For each reachable
//   def, the symbols it uses are walked depth-first in the use graph.
// - A used symbol that resolves to a StructDefOp becomes an additional
//   root. Struct parameters and symbols found via an include take separate
//   paths (bodies not visible).
// - Finally, each StructDefOp whose use-graph node (presumably `n`) was never
//   reached is removed.
1569struct FromKeepSet :
public CleanupBase {
1570 using CleanupBase::CleanupBase;
1575 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1577 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1579 rootMod.walk([&roots](Operation *op) {
1583 if (!fdef.isInStruct()) {
// Visited set shared by every depth_first_ext walk, so each use-graph node
// is expanded at most once. It doubles as the set of symbols to keep.
1592 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
// Index-based loop on purpose: `roots` grows while it is being iterated.
1593 for (
size_t i = 0; i < roots.size(); ++i) {
1594 SymbolOpInterface keepRoot = roots[i];
1595 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
1597 assert(keepRootNode &&
"every struct def must be in the def tree");
1598 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1600 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
1602 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1607 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1609 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1611 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
1612 << usedSymbolNode->getSymbolPath() <<
'\n';
1616 if (usedSymbolNode->isStructParam()) {
1620 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1621 if (failed(lookupRes)) {
1622 LLVM_DEBUG(useGraph.dumpToDotFile());
1626 if (lookupRes->viaInclude()) {
// A used struct becomes a new root; it is visited later in the outer loop.
1629 if (
StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1630 bool insertRes = roots.insert(asStruct);
1634 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asStruct <<
'\n';
// Remove every struct whose use-graph node was not reached above.
1644 rootMod.walk([
this, &symbolsToKeep](
StructDefOp op) {
1647 if (!symbolsToKeep.contains(n)) {
1648 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << op.getSymName() <<
'\n');
1652 return WalkResult::skip();
1659struct FromEraseSet :
public CleanupBase {
1664 DenseSet<SymbolRefAttr> &&tryToErasePaths
1666 : CleanupBase(root, symDefTree, symUseGraph) {
1668 for (SymbolRefAttr path : tryToErasePaths) {
1669 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] path to erase: " << path <<
'\n';);
1670 Operation *lookupFrom = rootMod.getOperation();
1672 assert(succeeded(res) &&
"inputs must be valid StructDefOp references");
1673 if (!res->viaInclude()) {
1674 auto op = res->get();
1675 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] added op to the erase set: " << op <<
'\n';);
1676 tryToErase.insert(op);
1679 llvm::dbgs() <<
"[FromEraseSet] ignored op because it comes from an include: "
1680 << res->get() <<
'\n';
1686 LogicalResult eraseUnusedStructs() {
1689 collectSafeToErase(sd);
1694 for (
auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1695 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1696 visitedPlusSafetyResult.erase(it);
1699 for (
auto &[sym, _] : visitedPlusSafetyResult) {
1700 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
1706 const DenseSet<StructDefOp> &getTryToEraseSet()
const {
return tryToErase; }
1710 DenseSet<StructDefOp> tryToErase;
1714 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
1716 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1720 bool collectSafeToErase(SymbolOpInterface
check) {
1724 auto visited = visitedPlusSafetyResult.find(
check);
1725 if (visited != visitedPlusSafetyResult.end()) {
1726 return visited->second;
1731 if (!tryToErase.contains(sd)) {
1732 visitedPlusSafetyResult[
check] =
false;
1739 visitedPlusSafetyResult[
check] =
true;
1743 if (collectSafeToErase(defTree.lookupNode(
check))) {
1744 auto useNode = useGraph.lookupNode(
check);
1745 assert(useNode || llvm::isa<ModuleOp>(
check.getOperation()));
1746 if (!useNode || collectSafeToErase(useNode)) {
1752 visitedPlusSafetyResult[
check] =
false;
1760 if (SymbolOpInterface checkOp = p->getOp()) {
1761 return collectSafeToErase(checkOp);
1771 if (SymbolOpInterface checkOp = cachedLookup(p)) {
1772 if (!collectSafeToErase(checkOp)) {
1785 assert(node &&
"must provide a node");
1787 auto fromCache = lookupCache.find(node);
1788 if (fromCache != lookupCache.end()) {
1789 return fromCache->second;
1793 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
1794 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
1799 SymbolOpInterface actualRes =
1800 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1802 lookupCache[node] = actualRes;
1803 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");