36#include <mlir/Dialect/Affine/IR/AffineOps.h>
37#include <mlir/Dialect/Affine/LoopUtils.h>
38#include <mlir/Dialect/Arith/IR/Arith.h>
39#include <mlir/Dialect/SCF/IR/SCF.h>
40#include <mlir/Dialect/SCF/Utils/Utils.h>
41#include <mlir/Dialect/Utils/StaticValueUtils.h>
42#include <mlir/IR/Attributes.h>
43#include <mlir/IR/BuiltinAttributes.h>
44#include <mlir/IR/BuiltinOps.h>
45#include <mlir/IR/BuiltinTypes.h>
46#include <mlir/Interfaces/InferTypeOpInterface.h>
47#include <mlir/Pass/PassManager.h>
48#include <mlir/Support/LLVM.h>
49#include <mlir/Support/LogicalResult.h>
50#include <mlir/Transforms/DialectConversion.h>
51#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
52#include <mlir/Transforms/WalkPatternRewriteDriver.h>
54#include <llvm/ADT/APInt.h>
55#include <llvm/ADT/DenseMap.h>
56#include <llvm/ADT/DepthFirstIterator.h>
57#include <llvm/ADT/STLExtras.h>
58#include <llvm/ADT/SmallVector.h>
59#include <llvm/ADT/TypeSwitch.h>
60#include <llvm/Support/Debug.h>
64#define GEN_PASS_DEF_FLATTENINGPASS
70#define DEBUG_TYPE "llzk-flatten"
84static void reportDelayedDiagnostics(
CallOp caller, SmallVector<Diagnostic> &&diagnostics) {
85 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
86 for (Diagnostic &diag : diagnostics) {
88 for (Diagnostic ¬e : diag.getNotes()) {
89 assert(note.getNotes().empty() &&
"notes cannot have notes attached");
90 if (llvm::isa<UnknownLoc>(note.getLocation())) {
91 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
95 engine.emit(std::move(diag));
99class ConversionTracker {
104 DenseMap<StructType, StructType> structInstantiations;
106 DenseMap<StructType, StructType> reverseInstantiations;
108 DenseSet<SymbolRefAttr> funcInstantiations;
111 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
114 bool isModified()
const {
return modified; }
115 void resetModifiedFlag() { modified =
false; }
116 void updateModifiedFlag(
bool currStepModified) { modified |= currStepModified; }
121 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
122 if (forwardResult.second) {
125 assert(!reverseInstantiations.contains(newType));
126 reverseInstantiations[newType] = oldType;
131 assert(forwardResult.first->getSecond() == newType);
133 assert(reverseInstantiations.lookup(newType) == oldType);
135 assert(structInstantiations.size() == reverseInstantiations.size());
139 std::optional<StructType> getInstantiation(
StructType oldType)
const {
140 auto cachedResult = structInstantiations.find(oldType);
141 if (cachedResult != structInstantiations.end()) {
142 return cachedResult->second;
148 void recordInstantiation(SymbolRefAttr funcName) {
149 funcInstantiations.insert(funcName);
154 DenseSet<SymbolRefAttr> getInstantiatedDefinitionNames()
const {
155 DenseSet<SymbolRefAttr> instantiatedNames = funcInstantiations;
156 for (
const auto &[origRemoteTy, _] : structInstantiations) {
157 instantiatedNames.insert(origRemoteTy.getNameRef());
159 return instantiatedNames;
163 auto res = delayedDiagnostics.find(newType);
164 if (res != delayedDiagnostics.end()) {
165 ::reportDelayedDiagnostics(caller, std::move(res->second));
170 delayedDiagnostics.erase(newType);
174 SmallVector<Diagnostic> &delayedDiagnosticSet(
StructType newType) {
175 return delayedDiagnostics[newType];
180 bool isLegalConversion(Type oldType, Type newType,
const char *patName)
const {
181 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
183 if (
StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
186 if (this->structInstantiations.lookup(oldStructType) == nTy) {
193 if (
StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
194 if (
auto preImage = this->reverseInstantiations.lookup(newStructType)) {
207 llvm::dbgs() <<
"[" << patName <<
"] Cannot replace old type " << oldType
208 <<
" with new type " << newType
209 <<
" because it does not define a compatible and more concrete type.\n";
214 template <
typename T,
typename U>
215 inline bool areLegalConversions(T oldTypes, U newTypes,
const char *patName)
const {
217 llvm::zip_equal(oldTypes, newTypes), [
this, &patName](std::tuple<Type, Type> oldThenNew) {
218 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
224template <
typename Impl,
typename Op,
typename... HandledAttrs>
225class SymbolUserHelper :
public OpConversionPattern<Op> {
227 const DenseMap<Attribute, Attribute> ¶mNameToValue;
230 TypeConverter &converter, MLIRContext *ctx,
unsigned patternBenefit,
231 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
233 : OpConversionPattern<Op>(converter, ctx, patternBenefit),
234 paramNameToValue(paramNameToInstantiatedValue) {}
237 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
239 virtual Attribute getNameAttr(Op)
const = 0;
241 virtual LogicalResult handleDefaultRewrite(
242 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
244 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
248 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
249 LLVM_DEBUG(llvm::dbgs() <<
"[SymbolUserHelper] op: " << op <<
'\n');
250 auto res = this->paramNameToValue.find(getNameAttr(op));
251 if (res == this->paramNameToValue.end()) {
252 LLVM_DEBUG(llvm::dbgs() <<
"[SymbolUserHelper] no instantiation for " << op <<
'\n');
255 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
256 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
258 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
259 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
263 return TS.Default([&](Attribute a) {
264 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
270class ClonedBodyConstReadOpPattern
271 :
public SymbolUserHelper<
272 ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
273 SmallVector<Diagnostic> &diagnostics;
276 SymbolUserHelper<ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
279 ClonedBodyConstReadOpPattern(
280 TypeConverter &converter, MLIRContext *ctx,
281 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
282 SmallVector<Diagnostic> &instantiationDiagnostics
285 : super(converter, ctx, 1, paramNameToInstantiatedValue),
286 diagnostics(instantiationDiagnostics) {}
288 Attribute getNameAttr(ConstReadOp op)
const override {
return op.
getConstNameAttr(); }
290 LogicalResult handleRewrite(
291 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
293 APInt attrValue = a.getValue();
294 Type origResTy = op.getType();
295 if (FeltType ty = llvm::dyn_cast<FeltType>(origResTy)) {
297 rewriter, op, FeltConstAttr::get(getContext(), attrValue, ty)
302 if (llvm::isa<IndexType>(origResTy)) {
307 if (origResTy.isSignlessInteger(1)) {
309 if (attrValue.isZero()) {
313 if (!attrValue.isOne()) {
314 Location opLoc = op.getLoc();
315 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
317 if (getContext()->shouldPrintOpOnDiagnostic()) {
318 diag.attachNote(opLoc) <<
"see current operation: " << *op;
320 diag.attachNote(UnknownLoc::get(getContext()))
322 <<
"\" for this call";
323 diagnostics.push_back(std::move(diag));
328 return op->emitOpError().append(
"unexpected result type ", origResTy);
331 LogicalResult handleRewrite(
332 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
341struct MatchFailureListener :
public RewriterBase::Listener {
342 bool hadFailure =
false;
344 ~MatchFailureListener()
override {}
346 void notifyMatchFailure(Location loc, function_ref<
void(Diagnostic &)> reasonCallback)
override {
349 InFlightDiagnostic diag = emitError(loc);
350 reasonCallback(*diag.getUnderlyingDiagnostic());
356applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
357 bool currStepModified =
false;
358 MatchFailureListener failureListener;
359 LogicalResult result = applyPatternsGreedily(
360 modOp->getRegion(0), std::move(patterns),
361 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
364 tracker.updateModifiedFlag(currStepModified);
365 return failure(result.failed() || failureListener.hadFailure);
368template <
bool AllowStructParams = true>
bool isConcreteAttr(Attribute a) {
369 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
372 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
379convertCalleeSymRefs(SymbolRefAttr callee,
const DenseMap<Attribute, Attribute> ¶mNameToValue) {
380 auto it = paramNameToValue.find(FlatSymbolRefAttr::get(callee.getRootReference()));
381 if (it == paramNameToValue.end()) {
385 auto tyAttr = llvm::dyn_cast<TypeAttr>(it->second);
390 auto structTy = llvm::dyn_cast<StructType>(tyAttr.getValue());
395 SmallVector<FlatSymbolRefAttr> newPieces =
getPieces(structTy.getNameRef());
396 llvm::append_range(newPieces, callee.getNestedReferences());
401convertCalleesInPlace(Operation *op,
const DenseMap<Attribute, Attribute> ¶mNameToValue) {
402 op->walk([¶mNameToValue](
CallOp callOp) {
407static bool calleeReferencesTemplateParam(
CallOp op) {
409 if (!callee || callee.getNestedReferences().size() != 1) {
413 if (!parentTemplate) {
423static std::optional<Attribute>
424evaluateExpr(
TemplateExprOp exprOp,
const DenseMap<Attribute, Attribute> ¶mNameToConcrete) {
426 DenseMap<Value, Attribute> valueMap;
428 if (
auto yieldOp = llvm::dyn_cast<YieldOp>(bodyOp)) {
429 auto it = valueMap.find(yieldOp.getVal());
430 return it != valueMap.end() ? std::make_optional(it->second) : std::nullopt;
433 if (
auto constReadOp = llvm::dyn_cast<ConstReadOp>(bodyOp)) {
434 auto it = paramNameToConcrete.find(constReadOp.getConstNameAttr());
435 if (it == paramNameToConcrete.end()) {
440 Attribute val = it->second;
441 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(val)) {
442 if (
auto feltTy = llvm::dyn_cast<FeltType>(constReadOp.getResult().getType())) {
443 val = FeltConstAttr::get(bodyOp.getContext(), intAttr.getValue(), feltTy);
446 valueMap[constReadOp.getResult()] = val;
451 SmallVector<Attribute> operandAttrs;
452 operandAttrs.reserve(bodyOp.getNumOperands());
453 for (Value operand : bodyOp.getOperands()) {
454 auto it = valueMap.find(operand);
455 if (it == valueMap.end()) {
458 operandAttrs.push_back(it->second);
462 SmallVector<OpFoldResult> foldResults;
463 if (succeeded(bodyOp.fold(operandAttrs, foldResults)) &&
464 foldResults.size() == bodyOp.getNumResults()) {
465 for (
auto [result, fr] : llvm::zip_equal(bodyOp.getResults(), foldResults)) {
466 if (Attribute a = llvm::dyn_cast<Attribute>(fr)) {
467 valueMap[result] = a;
481evaluateTemplateExprs(
TemplateOp templateOp, DenseMap<Attribute, Attribute> ¶mNameToConcrete) {
483 llvm::dbgs() <<
"[evaluateTemplateExprs] before: " <<
debug::toStringList(paramNameToConcrete)
487 std::optional<Attribute> result = evaluateExpr(exprOp, paramNameToConcrete);
488 if (result.has_value()) {
489 auto exprNameAttr = FlatSymbolRefAttr::get(exprOp.
getSymNameAttr());
490 paramNameToConcrete.try_emplace(exprNameAttr, *result);
492 llvm::dbgs() <<
"[evaluateTemplateExprs] expr @" << exprOp.
getSymName()
493 <<
" evaluated to " << *result <<
'\n'
498 llvm::dbgs() <<
"[evaluateTemplateExprs] after: " <<
debug::toStringList(paramNameToConcrete)
505static inline bool tableOffsetIsntSymbol(
MemberReadOp op) {
506 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
512 ConversionTracker &tracker_;
514 SymbolTableCollection symTables;
515 bool reportMissing =
true;
517 class MappedTypeConverter :
public TypeConverter {
520 const DenseMap<Attribute, Attribute> ¶mNameToValue;
522 inline Attribute convertIfPossible(Attribute a)
const {
523 auto res = this->paramNameToValue.find(a);
524 return (res != this->paramNameToValue.end()) ? res->second : a;
531 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
533 : TypeConverter(), origTy(originalType), newTy(newType),
534 paramNameToValue(paramNameToInstantiatedValue) {
536 addConversion([](Type inputTy) {
return inputTy; });
539 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
542 if (inputTy == this->origTy) {
546 if (ArrayAttr inputTyParams = inputTy.getParams()) {
547 SmallVector<Attribute> updated;
548 for (Attribute a : inputTyParams) {
549 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
550 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
552 updated.push_back(convertIfPossible(a));
556 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
563 addConversion([
this](
ArrayType inputTy) {
565 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
566 if (!dimSizes.empty()) {
567 SmallVector<Attribute> updated;
568 for (Attribute a : dimSizes) {
569 updated.push_back(convertIfPossible(a));
571 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
577 addConversion([
this](
TypeVarType inputTy) -> Type {
579 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
580 Type convertedType = tyAttr.getValue();
585 return convertedType;
593 class ClonedStructMemberReadOpPattern
594 :
public SymbolUserHelper<
595 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
597 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
600 ClonedStructMemberReadOpPattern(
601 TypeConverter &converter, MLIRContext *ctx,
602 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
605 : super(converter, ctx, 1, paramNameToInstantiatedValue) {}
611 template <
typename Attr>
612 LogicalResult handleRewrite(
613 Attribute,
MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
615 rewriter.modifyOpInPlace(op, [&]() {
622 LogicalResult matchAndRewrite(
623 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
626 llvm::dbgs() <<
"[ClonedStructMemberReadOpPattern] MemberReadOp: " << op <<
'\n';
628 if (tableOffsetIsntSymbol(op)) {
632 return super::matchAndRewrite(op, adaptor, rewriter);
636 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
637 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] attempting clone of " << typeAtCaller <<
'\n');
639 FailureOr<SymbolLookupResult<StructDefOp>> r =
640 typeAtCaller.
getDefinition(symTables, rootMod, reportMissing);
642 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
645 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] found definition\n";);
649 MLIRContext *ctx = origStruct.getContext();
652 DenseMap<Attribute, Attribute> paramNameToConcrete;
656 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
659 SmallVector<Attribute> remainingNames;
661 ArrayAttr reducedCallerParams =
nullptr;
663 ArrayAttr paramNames = typeAtDef.
getParams();
667 assert(paramNames.size() == typeAtCallerParams.size());
669 SmallVector<Attribute> nonConcreteParams;
670 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
671 Attribute next = typeAtCallerParams[i];
672 if (isConcreteAttr<false>(next)) {
673 paramNameToConcrete[paramNames[i]] = next;
674 attrsForInstantiatedNameSuffix.push_back(next);
676 remainingNames.push_back(paramNames[i]);
677 nonConcreteParams.push_back(next);
678 attrsForInstantiatedNameSuffix.push_back(
nullptr);
682 assert(remainingNames.size() == nonConcreteParams.size());
683 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
684 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
686 if (paramNameToConcrete.empty()) {
687 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
690 if (!remainingNames.empty()) {
691 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
696 SmallVector<FlatSymbolRefAttr> typeAtCallerSymPieces =
getPieces(typeAtCaller.
getNameRef());
697 typeAtCallerSymPieces.pop_back();
700 typeAtCallerSymPieces.back().getValue().str(), attrsForInstantiatedNameSuffix
705 assert(parentTemplate &&
"parameterized struct must be nested in a TemplateOp");
707 assert(parentModule &&
"TemplateOp must be nested in a ModuleOp");
711 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
715 convertCalleesInPlace(newStruct, paramNameToConcrete);
716 if (remainingNames.empty()) {
719 (templateNameWithAttrs + mlir::Twine(
'_') + newStruct.
getSymName()).str()
723 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(parentTemplate));
725 typeAtCallerSymPieces.pop_back();
728 TemplateOp newTemplate = parentTemplate.cloneWithoutRegions();
729 newTemplate.
setSymName(templateNameWithAttrs);
730 assert(newTemplate->getNumRegions() > 0 &&
"region exists");
734 for (Attribute name : remainingNames) {
735 FlatSymbolRefAttr nameSym = llvm::dyn_cast<FlatSymbolRefAttr>(name);
736 assert(nameSym &&
"expected FlatSymbolRefAttr");
738 Operation *symOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
739 assert(symOp &&
"symbol must exist");
740 newTemplate.insert(newTemplate.begin(), symOp->clone());
745 symTables.getSymbolTable(newTemplate).insert(newStruct);
746 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
750 typeAtCallerSymPieces.back() = FlatSymbolRefAttr::get(newTemplate.
getSymNameAttr());
756 typeAtCallerSymPieces.push_back(
757 FlatSymbolRefAttr::get(newLocalType.
getNameRef().getLeafReference())
762 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
763 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
764 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
765 llvm::dbgs() <<
"[StructCloner] cloned local type: " << newLocalType <<
'\n';
766 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
772 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
773 ConversionTarget target =
777 return !paramNameToConcrete.contains(op.getConstNameAttr());
781 patterns.add<ClonedBodyConstReadOpPattern>(
782 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
784 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
785 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
786 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
789 return newRemoteType;
793 StructCloner(ConversionTracker &tracker, ModuleOp root)
794 : tracker_(tracker), rootMod(root), symTables() {}
796 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
797 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
798 if (ArrayAttr params = orig.
getParams()) {
799 return genClone(orig, params.getValue());
801 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
805 void enableReportMissing() { reportMissing =
true; }
807 void disableReportMissing() { reportMissing =
false; }
810class DisableReportMissing;
812class ParameterizedStructUseTypeConverter :
public TypeConverter {
813 ConversionTracker &tracker_;
816 friend DisableReportMissing;
819 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
820 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
822 addConversion([](Type inputTy) {
return inputTy; });
826 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] attempting conversion of "
830 if (
auto opt = tracker_.getInstantiation(inputTy)) {
836 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
837 if (failed(cloneRes)) {
842 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
843 <<
" as " << newTy <<
'\n'
845 tracker_.recordInstantiation(inputTy, newTy);
849 addConversion([
this](
ArrayType inputTy) {
850 return inputTy.cloneWith(convertType(inputTy.getElementType()));
855class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
856 ConversionTracker &tracker_;
859 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
861 : OpConversionPattern<CallOp>(converter, ctx, 1), tracker_(tracker) {}
863 LogicalResult matchAndRewrite(
864 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
866 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
869 SmallVector<Type> newResultTypes;
870 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
871 return op->emitError(
"Could not convert Op result types.");
874 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
884 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
885 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
886 tracker_.reportDelayedDiagnostics(newStTy, op);
890 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
891 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
895 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
897 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
901 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
907class MemberDefOpPattern :
public OpConversionPattern<MemberDefOp> {
909 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
911 : OpConversionPattern<MemberDefOp>(converter, ctx, 1) {}
913 LogicalResult matchAndRewrite(
914 MemberDefOp op, OpAdaptor , ConversionPatternRewriter &rewriter
916 LLVM_DEBUG(llvm::dbgs() <<
"[MemberDefOpPattern] MemberDefOp: " << op <<
'\n');
918 Type oldMemberType = op.
getType();
919 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
920 if (oldMemberType == newMemberType) {
923 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.
setType(newMemberType); });
931 ParameterizedStructUseTypeConverter &tyConv;
934 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
936 void checkStarted()
override { tyConv.cloner.disableReportMissing(); }
938 void checkEnded(
bool)
override { tyConv.cloner.enableReportMissing(); }
941LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
942 MLIRContext *ctx = modOp.getContext();
943 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
944 DisableReportMissing drm(tyConv);
947 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
948 return applyPartialConversion(modOp, target, std::move(patterns));
957class FuncInstTypeConverter :
public TypeConverter {
958 DenseMap<Attribute, Attribute> paramNameToValue;
960 Attribute convertIfPossible(Attribute a)
const {
961 auto res = paramNameToValue.find(a);
962 return (res != paramNameToValue.end()) ? res->second : a;
966 explicit FuncInstTypeConverter(DenseMap<Attribute, Attribute> paramNameToConcrete)
967 : TypeConverter(), paramNameToValue(std::move(paramNameToConcrete)) {
968 addConversion([](Type t) {
return t; });
970 addConversion([
this](
TypeVarType inputTy) -> Type {
971 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.
getNameRef()))) {
972 Type convertedType = tyAttr.getValue();
974 return convertedType;
980 addConversion([
this](
ArrayType inputTy) {
981 SmallVector<Attribute> updated;
982 bool changed =
false;
983 for (Attribute a : inputTy.getDimensionSizes()) {
984 Attribute converted = convertIfPossible(a);
985 updated.push_back(converted);
986 if (converted != a) {
990 Type newElemTy = this->convertType(inputTy.getElementType());
991 if (!changed && newElemTy == inputTy.getElementType()) {
998 if (ArrayAttr params = inputTy.getParams()) {
999 SmallVector<Attribute> updated;
1000 bool changed =
false;
1001 for (Attribute a : params) {
1002 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
1003 Type newTy = this->convertType(ta.getValue());
1004 if (newTy != ta.getValue()) {
1005 updated.push_back(TypeAttr::get(newTy));
1010 Attribute converted = convertIfPossible(a);
1011 if (converted != a) {
1012 updated.push_back(converted);
1017 updated.push_back(a);
1021 inputTy.
getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
1029 bool containsParam(Attribute nameAttr)
const {
return paramNameToValue.contains(nameAttr); }
1030 const DenseMap<Attribute, Attribute> &getParamMap()
const {
return paramNameToValue; }
1033class InstantiateFuncAtCallOp final :
public OpRewritePattern<CallOp> {
1034 ConversionTracker &tracker_;
1037 InstantiateFuncAtCallOp(MLIRContext *ctx, ConversionTracker &tracker)
1038 : OpRewritePattern<CallOp>(ctx), tracker_(tracker) {}
1040 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1041 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateFuncAtCallOp] op: " << op <<
'\n');
1043 if (calleeReferencesTemplateParam(op)) {
1048 SymbolTableCollection symTables;
1049 FailureOr<SymbolLookupResult<FuncDefOp>> callTgtOpt = op.
getCalleeTarget(symTables);
1050 if (failed(callTgtOpt)) {
1051 return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
1052 diag <<
"could not find target function for call";
1058 TemplateOp parentTemplate = llvm::dyn_cast<TemplateOp>(callTgt->getParentOp());
1059 if (!parentTemplate) {
1063 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] target function in template "
1075 if (failed(unifyResult)) {
1076 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1077 diag.append(
"target function type does not unify with call type ")
1079 .attachNote(callTgt.getLoc())
1080 .append(
"target function declared here");
1084 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unifications of types: "
1089 DenseMap<Attribute, Attribute> paramNameToConcrete;
1094 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] TemplateParamsAttr: " << callParams <<
'\n'
1097 for (
auto paramOp : realParams) {
1098 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1099 auto it = unifyResult->find({paramName,
Side::RHS});
1100 if (it == unifyResult->end()) {
1102 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1107 Attribute inferredVal = it->second;
1108 if (!isConcreteAttr(inferredVal)) {
1110 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1111 <<
"': not concrete, " << inferredVal <<
'\n'
1118 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1119 <<
"': incompatible with specified param type. MUST FAIL!\n"
1121 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1122 diag.append(
"inferred value for parameter '")
1124 .append(
"' is incompatible with specified param type")
1125 .attachNote(paramOp.getLoc())
1126 .append(
"template parameter declared here");
1129 paramNameToConcrete[paramName] = inferredVal;
1134 assert((callParams.size() == llvm::range_size(realParams)) &&
"per CallOpVerifier");
1136 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1137 diag.append(
"incompatible with specified param type(s)");
1141 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1142 diag.append(
"incompatible with inferred param value(s)");
1146 for (
auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) {
1147 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1148 if (!isConcreteAttr(attr)) {
1150 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1151 <<
"': not concrete, " << attr <<
'\n'
1155 paramNameToConcrete[paramName] = attr;
1159 if (paramNameToConcrete.empty()) {
1160 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateFuncAtCallOp] skip: no concrete params\n");
1166 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
1169 SmallVector<Attribute> remainingNames;
1170 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
1172 auto it = paramNameToConcrete.find(paramName);
1173 if (it != paramNameToConcrete.end()) {
1174 attrsForInstantiatedNameSuffix.push_back(it->second);
1176 attrsForInstantiatedNameSuffix.push_back(
nullptr);
1177 remainingNames.push_back(paramName);
1181 MLIRContext *ctx = op.getContext();
1183 assert(parentModule &&
"TemplateOp must be nested in a ModuleOp");
1188 parentTemplate.
getSymName().str(), attrsForInstantiatedNameSuffix
1196 auto applyBodyConversions = [&](
FuncDefOp newFunc) -> LogicalResult {
1197 FuncInstTypeConverter tyConv(paramNameToConcrete);
1201 return !tyConv.containsParam(p.getConstNameAttr());
1203 SmallVector<Diagnostic> delayedDiagnostics;
1205 bodyPatterns.add<ClonedBodyConstReadOpPattern>(
1206 tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics
1208 if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) {
1212 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] instantiated clone: " << newFunc <<
'\n'
1214 ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics));
1217 SymbolTableCollection tables;
1218 WalkResult res = newFunc.walk([&tables](
CallOp nestedCall) {
1221 return failure(res.wasInterrupted());
1225 assert(symPieces.size() >= 2 &&
"callee must include at least template and function names");
1227 if (remainingNames.empty()) {
1230 std::string newFuncName =
1231 (mlir::Twine(templateNameWithAttrs) +
"_" + callTgt.
getSymName()).str();
1232 StringRef actualNewFuncName = newFuncName;
1233 if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) {
1236 convertCalleesInPlace(newFunc, paramNameToConcrete);
1238 symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate));
1241 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] created full instantiation function: "
1242 << actualNewFuncName <<
'\n'
1244 if (failed(applyBodyConversions(newFunc))) {
1246 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] body conversion failed for "
1247 << actualNewFuncName <<
'\n'
1250 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1251 diag.append(
"failure while creating instantiated function '", actualNewFuncName,
'\'');
1256 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] reusing full instantiation function: "
1257 << actualNewFuncName <<
'\n'
1263 symPieces.pop_back();
1264 symPieces.pop_back();
1265 symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName)));
1272 if (Operation *existing =
1273 symTables.getSymbolTable(parentModule).lookup(templateNameWithAttrs)) {
1274 newTemplate = llvm::dyn_cast<TemplateOp>(existing);
1278 newTemplate = parentTemplate.cloneWithoutRegions();
1279 newTemplate.
setSymName(templateNameWithAttrs);
1280 assert(newTemplate->getNumRegions() > 0 &&
"region exists");
1284 Block &newTemplateBody = newTemplate.
getBodyRegion().front();
1285 for (Attribute name : remainingNames) {
1286 FlatSymbolRefAttr nameSym = llvm::cast<FlatSymbolRefAttr>(name);
1287 Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
1288 assert(paramOp &&
"symbol must exist");
1289 newTemplateBody.push_back(paramOp->clone());
1294 convertCalleesInPlace(newFunc, paramNameToConcrete);
1298 symTables.getSymbolTable(newTemplate).insert(newFunc);
1299 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
1300 if (failed(applyBodyConversions(newFunc))) {
1301 StringRef newFuncName = newFunc.
getSymName();
1303 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] body conversion failed for "
1304 << newFuncName <<
'\n'
1306 newTemplate->erase();
1307 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1308 diag.append(
"failure while creating instantiated function '", newFuncName,
'\'');
1313 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] created partial instantiation template: "
1318 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] reusing partial instantiation template: "
1325 symPieces.pop_back();
1326 symPieces.pop_back();
1327 symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.
getSymNameAttr()));
1328 symPieces.push_back(FlatSymbolRefAttr::get(callTgt.
getSymNameAttr()));
1331 tracker_.recordInstantiation(originalCalleeAttr);
1334 rewriter.modifyOpInPlace(op, [&op, &symPieces]() {
1338 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] updating callee from " << op.
getCalleeAttr()
1339 <<
" to " << newCalleeAttr <<
'\n';
1345 tracker_.updateModifiedFlag(
true);
1350LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1351 MLIRContext *ctx = modOp.getContext();
1352 RewritePatternSet patterns(ctx);
1353 patterns.add<InstantiateFuncAtCallOp>(ctx, tracker);
1354 MatchFailureListener failureListener;
1355 walkAndApplyPatterns(modOp, std::move(patterns), &failureListener);
1356 return failure(failureListener.hadFailure);
1364template <HasInterface<LoopLikeOpInterface> OpClass>
1365class LoopUnrollPattern :
public OpRewritePattern<OpClass> {
1367 using OpRewritePattern<OpClass>::OpRewritePattern;
1369 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter)
const override {
1370 if (
auto maybeConstant = getConstantTripCount(loopOp)) {
1371 uint64_t tripCount = *maybeConstant;
1372 if (tripCount == 0) {
1373 rewriter.eraseOp(loopOp);
1375 }
else if (tripCount == 1) {
1376 return loopOp.promoteIfSingleIteration(rewriter);
1378 return loopUnrollByFactor(loopOp, tripCount);
1386 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
1387 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
1388 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
1389 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
1390 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
1391 return std::nullopt;
1393 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
1397LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1398 MLIRContext *ctx = modOp.getContext();
1399 RewritePatternSet patterns(ctx);
1400 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
1401 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
1403 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1411std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
1412 SmallVector<int64_t> res;
1413 for (OpFoldResult ofr : ofrs) {
1414 std::optional<int64_t> cv = getConstantIntValue(ofr);
1415 if (!cv.has_value()) {
1416 return std::nullopt;
1418 res.push_back(cv.value());
1423struct AffineMapFolder {
1425 OperandRangeRange mapOpGroups;
1426 DenseI32ArrayAttr dimsPerGroup;
1427 ArrayRef<Attribute> paramsOfStructTy;
1431 SmallVector<SmallVector<Value>> mapOpGroups;
1432 SmallVector<int32_t> dimsPerGroup;
1433 SmallVector<Attribute> paramsOfStructTy;
1436 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
1437 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
1438 return ValueRange(grp);
1442 static LogicalResult
1443 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
1444 if (in.mapOpGroups.empty()) {
1449 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
1450 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
1453 for (Attribute sizeAttr : in.paramsOfStructTy) {
1454 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
1455 ValueRange currMapOps = in.mapOpGroups[idx++];
1460 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
1462 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
1465 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
1466 SmallVector<Attribute> result;
1467 bool hasPoison =
false;
1468 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
1469 return rewriter.getIndexAttr(v);
1471 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
1476 "Cannot fold affine_map for ", aspect,
' ', out.paramsOfStructTy.size(),
1477 " due to divide by 0 or modulus with negative divisor"
1482 if (failed(foldResult)) {
1486 "Folding affine_map for ", aspect,
' ', out.paramsOfStructTy.size(),
" failed"
1491 if (result.size() != 1) {
1495 "Folding affine_map for ", aspect,
' ', out.paramsOfStructTy.size(),
1496 " produced ", result.size(),
" results but expected 1"
1501 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
1502 out.paramsOfStructTy.push_back(result[0]);
1506 out.mapOpGroups.emplace_back(currMapOps);
1507 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
1510 out.paramsOfStructTy.push_back(sizeAttr);
1512 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
1514 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
1515 "produced wrong number of dimensions"
1523class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
1525 ConversionTracker &tracker_;
1528 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
1529 : OpRewritePattern(ctx), tracker_(tracker) {}
1531 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1534 AffineMapFolder::Output out;
1535 AffineMapFolder::Input in = {
1540 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
1545 if (newResultType == oldResultType) {
1549 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
1551 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
1552 << newResultType <<
" in \"" << op <<
"\"\n"
1555 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
1562class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
1563 ConversionTracker &tracker_;
1566 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
1567 : OpRewritePattern(ctx), tracker_(tracker) {}
1569 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1574 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
1576 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
1577 ArrayAttr params = oldRetTy.
getParams();
1583 AffineMapFolder::Output out;
1584 AffineMapFolder::Input in = {
1589 if (!in.mapOpGroups.empty()) {
1591 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
1595 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
1601 if (callArgTypes.empty()) {
1605 if (calleeReferencesTemplateParam(op)) {
1608 SymbolTableCollection tables;
1610 if (failed(lookupRes)) {
1613 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1617 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1618 "result type params: "
1624 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
1625 if (newRetTy == oldRetTy) {
1632 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
1633 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1635 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1636 ", but found ", oldRetTy
1640 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1642 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1643 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1646 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1653 inline LogicalResult instantiateViaTargetType(
1654 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1655 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1660 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1662 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1668 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1670 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1672 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1674 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1681 assert(unifies &&
"should have been checked by verifiers");
1684 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1693 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1694 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1695 [&unifications](std::tuple<Attribute, Attribute> p) {
1696 Attribute fromCall = std::get<1>(p);
1699 if (!isConcreteAttr<>(fromCall)) {
1700 Attribute fromTgt = std::get<0>(p);
1702 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1703 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1705 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1706 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1707 if (it != unifications.end()) {
1708 Attribute unifiedAttr = it->second;
1710 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1712 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1721 out.paramsOfStructTy = newReturnStructParams;
1722 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1723 assert(out.mapOpGroups.empty() &&
"post-condition");
1724 assert(out.dimsPerGroup.empty() &&
"post-condition");
1729LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1730 MLIRContext *ctx = modOp.getContext();
1731 RewritePatternSet patterns(ctx);
1733 InstantiateAtCreateArrayOp,
1734 InstantiateAtCallOpCompute
1737 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1745class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1746 ConversionTracker &tracker_;
1749 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1750 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1752 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1754 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1755 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1760 Type newResultElemType =
nullptr;
1761 for (Operation *user : createResult.getUsers()) {
1762 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1763 if (writeOp.getArrRef() != createResult) {
1766 Type writeRValueType = writeOp.getRvalue().getType();
1767 if (writeRValueType == oldResultElemType) {
1770 if (newResultElemType && newResultElemType != writeRValueType) {
1773 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1774 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1778 newResultElemType = writeRValueType;
1781 if (!newResultElemType) {
1785 if (!tracker_.isLegalConversion(
1786 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1791 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1793 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
1801LogicalResult updateArrayElemFromArrAccessOp(
1803 PatternRewriter &rewriter
1810 if (oldArrType == newArrType ||
1811 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1814 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1816 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
1823class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1824 ConversionTracker &tracker_;
1827 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1828 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1830 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1831 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
1835class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1836 ConversionTracker &tracker_;
1839 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1840 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1842 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1843 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
1848class UpdateMemberDefTypeFromWrite final :
public OpRewritePattern<MemberDefOp> {
1849 ConversionTracker &tracker_;
1852 UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1853 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1855 LogicalResult matchAndRewrite(
MemberDefOp op, PatternRewriter &rewriter)
const override {
1858 assert(parentRes &&
"MemberDefOp parent is always StructDefOp");
1862 Type newType =
nullptr;
1864 std::optional<Location> newTypeLoc = std::nullopt;
1865 for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
1866 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
1867 Type writeToType = writeOp.getVal().getType();
1868 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] checking " << writeOp <<
'\n');
1871 newType = writeToType;
1872 newTypeLoc = writeOp.getLoc();
1873 }
else if (writeToType != newType) {
1879 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateMemberDefTypeFromWrite")) {
1880 if (tracker_.isLegalConversion(
1881 newType, writeToType,
"UpdateMemberDefTypeFromWrite"
1884 newType = writeToType;
1885 newTypeLoc = writeOp.getLoc();
1888 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1892 "' with different value types"
1895 diag.attachNote(newTypeLoc).append(
"type written here is ", newType);
1897 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1905 if (!newType || newType == op.
getType()) {
1908 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateMemberDefTypeFromWrite")) {
1911 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1912 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] updated type of " << op <<
'\n');
1919SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1920 SmallVector<std::unique_ptr<Region>> newRegions;
1921 for (Region ®ion : op->getRegions()) {
1922 auto newRegion = std::make_unique<Region>();
1923 newRegion->takeBody(region);
1924 newRegions.push_back(std::move(newRegion));
1933class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1934 ConversionTracker &tracker_;
1937 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1938 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1940 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1941 SmallVector<Type, 1> inferredResultTypes;
1942 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1943 LogicalResult result = retTypeFn.inferReturnTypes(
1944 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1945 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1947 if (failed(result)) {
1950 if (op->getResultTypes() == inferredResultTypes) {
1953 if (!tracker_.areLegalConversions(
1954 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1960 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1961 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1962 Operation *newOp = rewriter.create(
1963 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1964 op->getAttrs(), op->getSuccessors(), newRegions
1966 rewriter.replaceOp(op, newOp);
1967 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
1973class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1974 ConversionTracker &tracker_;
1977 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1978 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1980 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1981 Region &body = op.getFunctionBody();
1985 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1986 assert(retOp &&
"final op in body region must be return");
1987 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1990 if (oldFuncTy.getResults() == tyFromReturnOp) {
1993 if (!tracker_.areLegalConversions(
1994 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1999 rewriter.modifyOpInPlace(op, [&]() {
2000 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
2003 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
2014class UpdateFreeFuncCallOpTypes final :
public OpRewritePattern<CallOp> {
2015 ConversionTracker &tracker_;
2018 UpdateFreeFuncCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
2019 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2021 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
2022 if (calleeReferencesTemplateParam(op)) {
2025 SymbolTableCollection tables;
2027 if (failed(lookupRes)) {
2030 FuncDefOp targetFunc = lookupRes->get();
2035 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
2038 if (!tracker_.areLegalConversions(
2040 "UpdateFreeFuncCallOpTypes"
2045 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFreeFuncCallOpTypes] replaced " << op);
2048 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
2055LogicalResult updateMemberRefValFromMemberDef(
2058 SymbolTableCollection tables;
2063 Type oldResultType = op.
getVal().getType();
2064 Type newResultType = def->get().getType();
2065 if (oldResultType == newResultType ||
2066 !tracker.isLegalConversion(oldResultType, newResultType,
"updateMemberRefValFromMemberDef")) {
2069 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
2071 llvm::dbgs() <<
"[updateMemberRefValFromMemberDef] updated value type in " << op <<
'\n'
2079class UpdateMemberReadValFromDef final :
public OpRewritePattern<MemberReadOp> {
2080 ConversionTracker &tracker_;
2083 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2084 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2086 LogicalResult matchAndRewrite(
MemberReadOp op, PatternRewriter &rewriter)
const override {
2087 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2092class UpdateMemberWriteValFromDef final :
public OpRewritePattern<MemberWriteOp> {
2093 ConversionTracker &tracker_;
2096 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2097 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2099 LogicalResult matchAndRewrite(
MemberWriteOp op, PatternRewriter &rewriter)
const override {
2100 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2104LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
2105 MLIRContext *ctx = modOp.getContext();
2106 RewritePatternSet patterns(ctx);
2111 UpdateInferredResultTypes,
2113 UpdateFreeFuncCallOpTypes,
2114 UpdateFuncTypeFromReturn,
2115 UpdateNewArrayElemFromWrite,
2116 UpdateArrayElemFromArrRead,
2117 UpdateArrayElemFromArrWrite,
2118 UpdateMemberDefTypeFromWrite,
2119 UpdateMemberReadValFromDef,
2120 UpdateMemberWriteValFromDef
2123 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
2131 SymbolTableCollection tables;
2134 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
2142struct FromKeepSet :
public CleanupBase {
2143 using CleanupBase::CleanupBase;
2147 static bool hasTemplateSymbolBindings(Operation *op) {
2148 if (
StructDefOp sdef = llvm::dyn_cast<StructDefOp>(op)) {
2149 return sdef.hasTemplateSymbolBindings();
2151 if (llvm::isa<function::FuncDefOp>(op)) {
2160 static bool isErasableDefinition(Operation *op) {
2161 if (llvm::isa<StructDefOp>(op)) {
2165 return !fdef.isInStruct();
2173 LogicalResult eraseUnreachableFrom(ArrayRef<SymbolOpInterface> keep) {
2175 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
2182 DenseSet<Operation *> defsToKeep;
2183 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
2184 for (
size_t i = 0; i < roots.size(); ++i) {
2185 SymbolOpInterface keepRoot = roots[i];
2186 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
2188 assert(keepRootNode &&
"every symbol def must be in the def tree");
2189 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
2191 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
2193 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
2194 if (isErasableDefinition(reachableDef.getOperation())) {
2195 defsToKeep.insert(reachableDef.getOperation());
2201 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
2203 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
2205 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
2206 << usedSymbolNode->getSymbolPath() <<
'\n';
2210 if (usedSymbolNode->isTemplateSymbolBinding()) {
2215 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
2216 if (failed(lookupRes)) {
2217 LLVM_DEBUG(useGraph.dumpToDotFile());
2221 if (lookupRes->viaInclude()) {
2224 Operation *usedOp = lookupRes->get();
2225 if (isErasableDefinition(usedOp)) {
2226 SymbolOpInterface asSymbol = llvm::cast<SymbolOpInterface>(usedOp);
2227 bool insertRes = roots.insert(asSymbol);
2231 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asSymbol <<
'\n';
2241 SmallVector<SymbolOpInterface> toErase;
2242 rootMod.walk([
this, &defsToKeep, &symbolsToKeep, &toErase](Operation *op) {
2243 if (!isErasableDefinition(op) || defsToKeep.contains(op)) {
2246 SymbolOpInterface symOp = llvm::cast<SymbolOpInterface>(op);
2248 if (!n || !symbolsToKeep.contains(n)) {
2249 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << symOp.getNameAttr() <<
'\n');
2250 toErase.push_back(symOp);
2253 for (SymbolOpInterface symOp : toErase) {
2261struct FromEraseSet :
public CleanupBase {
2266 DenseSet<SymbolRefAttr> &&tryToErasePaths
2268 : CleanupBase(root, symDefTree, symUseGraph) {
2270 for (SymbolRefAttr path : tryToErasePaths) {
2271 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] path to erase: " << path <<
'\n';);
2272 Operation *lookupFrom = rootMod.getOperation();
2274 assert(succeeded(res) &&
"inputs must be valid symbol references");
2275 assert(FromKeepSet::isErasableDefinition(res->get()) &&
"inputs must be cleanup candidates");
2276 if (!res->viaInclude()) {
2277 SymbolOpInterface op = llvm::cast<SymbolOpInterface>(res->get());
2278 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] added op to the erase set: " << op <<
'\n';);
2279 tryToErase.insert(op);
2282 llvm::dbgs() <<
"[FromEraseSet] ignored op because it comes from an include: "
2283 << res->get() <<
'\n';
2289 LogicalResult eraseUnusedDefinitions() {
2291 for (SymbolOpInterface sym : tryToErase) {
2292 collectSafeToErase(sym);
2296 for (
auto &it : llvm::make_early_inc_range(visitedPlusSafetyResult)) {
2297 if (!it.second || !tryToErase.contains(it.first)) {
2298 visitedPlusSafetyResult.erase(it.first);
2301 for (
auto &[sym, _] : visitedPlusSafetyResult) {
2302 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
2308 const DenseSet<SymbolOpInterface> &getTryToEraseSet()
const {
return tryToErase; }
2312 DenseSet<SymbolOpInterface> tryToErase;
2316 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
2318 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
2322 bool collectSafeToErase(SymbolOpInterface
check) {
2326 auto visited = visitedPlusSafetyResult.find(
check);
2327 if (visited != visitedPlusSafetyResult.end()) {
2328 return visited->second;
2332 if (FromKeepSet::isErasableDefinition(
check.getOperation()) && !tryToErase.contains(
check)) {
2333 visitedPlusSafetyResult[
check] =
false;
2339 visitedPlusSafetyResult[
check] =
true;
2344 if (collectSafeToErase(defTree.lookupNode(
check))) {
2345 const auto *useNode = useGraph.lookupNode(
check);
2346 assert(useNode || (llvm::isa<ModuleOp, TemplateOp>(
check.getOperation())));
2347 if (!useNode || collectSafeToErase(useNode)) {
2353 visitedPlusSafetyResult[
check] =
false;
2361 if (SymbolOpInterface checkOp = p->getOp()) {
2362 return collectSafeToErase(checkOp);
2372 if (SymbolOpInterface checkOp = cachedLookup(p)) {
2373 if (!collectSafeToErase(checkOp)) {
2386 assert(node &&
"must provide a node");
2388 auto fromCache = lookupCache.find(node);
2389 if (fromCache != lookupCache.end()) {
2390 return fromCache->second;
2394 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
2395 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
2400 SymbolOpInterface actualRes =
2401 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
2403 lookupCache[node] = actualRes;
2404 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");
2412 using Base = FlatteningPassBase<PassImpl>;
2415 void runOnOperation()
override {
2416 ModuleOp modOp = getOperation();
2417 if (failed(runOn(modOp))) {
2420 llvm::dbgs() <<
"=====================================================================\n";
2421 llvm::dbgs() <<
" Dumping module after failure of pass " <<
DEBUG_TYPE <<
'\n';
2422 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2423 llvm::dbgs() <<
"=====================================================================\n";
2425 signalPassFailure();
2429 inline LogicalResult runOn(ModuleOp modOp) {
2434 if (cleanupMode == FlatteningCleanupMode::MainAsRoot) {
2435 if (failed(eraseUnreachableFromMainStruct(modOp))) {
2443 OpPassManager universalCleanup(ModuleOp::getOperationName());
2448 if (failed(runPipeline(universalCleanup, modOp))) {
2452 ConversionTracker tracker;
2453 unsigned loopCount = 0;
2456 if (loopCount > iterationLimit) {
2457 llvm::errs() <<
DEBUG_TYPE <<
" exceeded the limit of " << iterationLimit
2458 <<
" iterations!\n";
2461 tracker.resetModifiedFlag();
2464 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2465 <<
")] Running step 1: struct instantiation\n";
2470 if (failed(Step1A_InstantiateStructs::run(modOp, tracker))) {
2471 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating structs in templates\n";
2475 if (failed(Step1B_InstantiateFunctions::run(modOp, tracker))) {
2476 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating functions in templates\n";
2481 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2482 <<
")] Running step 2: loop unrolling\n";
2485 if (failed(Step2_Unroll::run(modOp, tracker))) {
2486 llvm::errs() <<
DEBUG_TYPE <<
" failed while unrolling loops\n";
2491 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2492 <<
")] Running step 3: affine maps instantiation\n";
2495 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
2496 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating `affine_map` parameters\n";
2501 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2502 <<
")] Running step 4: type propagation\n";
2505 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
2506 llvm::errs() <<
DEBUG_TYPE <<
" failed while propagating instantiated types\n";
2510 LLVM_DEBUG(
if (tracker.isModified()) {
2511 llvm::dbgs() <<
"=====================================================================\n";
2512 llvm::dbgs() <<
" Dumping module between iterations of " << DEBUG_TYPE <<
'\n';
2513 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2514 llvm::dbgs() <<
"=====================================================================\n";
2516 }
while (tracker.isModified());
2519 if (failed(cleanupSwitch(modOp, tracker))) {
2523 if (failed(runPipeline(universalCleanup, modOp))) {
2527 OpPassManager allocationCleanup(ModuleOp::getOperationName());
2529 RemoveUnusedDiscardableAllocationsPassOptions {
2533 return runPipeline(allocationCleanup, modOp);
2537 LogicalResult cleanupSwitch(ModuleOp modOp,
const ConversionTracker &tracker) {
2538 LLVM_DEBUG({ llvm::dbgs() <<
"[FlatteningPass] Running step 5: cleanup "; });
2539 switch (cleanupMode) {
2540 case FlatteningCleanupMode::MainAsRoot:
2541 LLVM_DEBUG(llvm::dbgs() <<
"(main as root mode)\n");
2542 return eraseUnreachableFromMainStruct(modOp,
false);
2543 case FlatteningCleanupMode::ConcreteAsRoot:
2544 LLVM_DEBUG(llvm::dbgs() <<
"(concrete definitions mode)\n");
2545 return eraseUnreachableFromConcreteDefinitions(modOp);
2546 case FlatteningCleanupMode::Preimage:
2547 LLVM_DEBUG(llvm::dbgs() <<
"(preimage mode)\n");
2548 return erasePreimageOfInstantiations(modOp, tracker);
2550 LLVM_DEBUG(llvm::dbgs() <<
"(disabled)\n");
2556 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod,
const ConversionTracker &tracker) {
2561 Step5_Cleanup::FromEraseSet cleaner(
2562 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
2563 tracker.getInstantiatedDefinitionNames()
2565 LogicalResult res = cleaner.eraseUnusedDefinitions();
2566 if (succeeded(res)) {
2567 LLVM_DEBUG(llvm::dbgs() <<
"[Cleanup(preimage)] success\n";);
2569 const SymbolUseGraph *useGraph =
nullptr;
2570 rootMod->walk([
this, &cleaner, &useGraph](Operation *walkedOp) {
2571 SymbolOpInterface op = llvm::dyn_cast<SymbolOpInterface>(walkedOp);
2572 if (!op || !cleaner.getTryToEraseSet().contains(op)) {
2577 useGraph = &getAnalysis<SymbolUseGraph>();
2580 if (useGraph->lookupNode(op)->hasPredecessor()) {
2581 op.emitWarning(
"Parameterized definition still has uses!").report();
2585 LLVM_DEBUG(llvm::dbgs() <<
"[Cleanup(preimage)] failed\n";);
2590 LogicalResult eraseUnreachableFromConcreteDefinitions(ModuleOp rootMod) {
2591 SmallVector<SymbolOpInterface> roots;
2592 rootMod.walk([&roots](Operation *op) {
2593 if (Step5_Cleanup::FromKeepSet::isErasableDefinition(op) &&
2594 !Step5_Cleanup::FromKeepSet::hasTemplateSymbolBindings(op)) {
2595 roots.push_back(llvm::cast<SymbolOpInterface>(op));
2599 Step5_Cleanup::FromKeepSet cleaner(
2600 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2602 return cleaner.eraseUnreachableFrom(roots);
2605 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod,
bool emitWarning =
true) {
2606 Step5_Cleanup::FromKeepSet cleaner(
2607 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2609 FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
2611 if (failed(mainOpt)) {
2614 SymbolLookupResult<StructDefOp>
main = mainOpt.value();
2615 if (emitWarning && !
main) {
2618 rootMod.emitWarning()
2620 "using option '", cleanupMode.getArgStr(),
'=',
2623 "\" attribute on the top-level module may remove all cleanup-candidate definitions!"
2627 SmallVector<SymbolOpInterface> roots;
2629 roots.push_back(*
main);
2631 return cleaner.eraseUnreachableFrom(roots);
Common private implementation for poly dialect passes.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
Builds a graph structure representing the relationships between symbols and their uses.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
static constexpr ::llvm::StringLiteral getOperationName()
::mlir::OperandRangeRange getMapOperands()
::mlir::TypedValue<::mlir::Type > getResult()
::mlir::TypedValue<::mlir::Type > getRvalue()
static constexpr ::llvm::StringLiteral getOperationName()
void setType(::mlir::Type attrValue)
::std::optional<::mlir::Attribute > getTableOffset()
void setTableOffsetAttr(::mlir::Attribute attr)
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
static constexpr ::llvm::StringLiteral getOperationName()
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::StringRef getSymName()
void setSymName(::llvm::StringRef attrValue)
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::ArrayAttr getParams() const
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
::mlir::SymbolRefAttr getCalleeAttr()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
void setTemplateParamsAttr(::mlir::ArrayAttr attr)
::mlir::Operation::operand_range getArgOperands()
::mlir::ArrayAttr getTemplateParamsAttr()
::mlir::OperandRangeRange getMapOperands()
void setCalleeAttr(::mlir::SymbolRefAttr attr)
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unification between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::FunctionType getFunctionType()
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
::llvm::StringRef getSymName()
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
bool isInStruct()
Return true iff the function is within a StructDefOp.
void setFunctionType(::mlir::FunctionType attrValue)
void setSymName(::llvm::StringRef attrValue)
::mlir::StringAttr getSymNameAttr()
::mlir::Operation::operand_range getOperands()
::mlir::FlatSymbolRefAttr getConstNameAttr()
::mlir::StringAttr getSymNameAttr()
::mlir::Region & getInitializerRegion()
::llvm::StringRef getSymName()
::mlir::Region & getBodyRegion()
::llvm::SmallVector<::mlir::Attribute > getConstNames()
Return the names of all ops of type OpT within the body region in the order they are defined.
bool hasConstNamed(::mlir::StringRef find)
Return true if there is an op of type OpT with the given name within the body region.
::mlir::StringAttr getSymNameAttr()
void setSymName(::llvm::StringRef attrValue)
::llvm::StringRef getSymName()
inline ::llvm::iterator_range<::mlir::Region::op_iterator< OpT > > getConstOps()
Return ops of type OpT within the body region.
::mlir::FlatSymbolRefAttr getNameRef() const
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
std::unique_ptr<::mlir::Pass > createEmptyTemplateRemovalPass()
::llvm::StringRef stringifyFlatteningCleanupMode(FlatteningCleanupMode val)
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
std::unique_ptr<::mlir::Pass > createRemoveUnusedDiscardableAllocationsPass()
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
bool isConcreteType(Type type, bool allowStructParams)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool isDynamic(IntegerAttr intAttr)
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
int64_t fromAPInt(const llvm::APInt &i)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
constexpr char MAIN_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.