35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51#include <mlir/Transforms/WalkPatternRewriteDriver.h>
53#include <llvm/ADT/APInt.h>
54#include <llvm/ADT/DenseMap.h>
55#include <llvm/ADT/DepthFirstIterator.h>
56#include <llvm/ADT/STLExtras.h>
57#include <llvm/ADT/SmallVector.h>
58#include <llvm/ADT/TypeSwitch.h>
59#include <llvm/Support/Debug.h>
63#define GEN_PASS_DECL_FLATTENINGPASS
64#define GEN_PASS_DEF_FLATTENINGPASS
70#define DEBUG_TYPE "llzk-flatten"
84static void reportDelayedDiagnostics(
CallOp caller, SmallVector<Diagnostic> &&diagnostics) {
85 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
86 for (Diagnostic &diag : diagnostics) {
88 for (Diagnostic ¬e : diag.getNotes()) {
89 assert(note.getNotes().empty() &&
"notes cannot have notes attached");
90 if (llvm::isa<UnknownLoc>(note.getLocation())) {
91 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
95 engine.emit(std::move(diag));
99class ConversionTracker {
104 DenseMap<StructType, StructType> structInstantiations;
106 DenseMap<StructType, StructType> reverseInstantiations;
109 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
112 bool isModified()
const {
return modified; }
113 void resetModifiedFlag() { modified =
false; }
114 void updateModifiedFlag(
bool currStepModified) { modified |= currStepModified; }
119 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
120 if (forwardResult.second) {
123 assert(!reverseInstantiations.contains(newType));
124 reverseInstantiations[newType] = oldType;
129 assert(forwardResult.first->getSecond() == newType);
131 assert(reverseInstantiations.lookup(newType) == oldType);
133 assert(structInstantiations.size() == reverseInstantiations.size());
137 std::optional<StructType> getInstantiation(
StructType oldType)
const {
138 auto cachedResult = structInstantiations.find(oldType);
139 if (cachedResult != structInstantiations.end()) {
140 return cachedResult->second;
146 DenseSet<SymbolRefAttr> getInstantiatedStructNames()
const {
147 DenseSet<SymbolRefAttr> instantiatedNames;
148 for (
const auto &[origRemoteTy, _] : structInstantiations) {
149 instantiatedNames.insert(origRemoteTy.getNameRef());
151 return instantiatedNames;
155 auto res = delayedDiagnostics.find(newType);
156 if (res != delayedDiagnostics.end()) {
157 ::reportDelayedDiagnostics(caller, std::move(res->second));
162 delayedDiagnostics.erase(newType);
166 SmallVector<Diagnostic> &delayedDiagnosticSet(
StructType newType) {
167 return delayedDiagnostics[newType];
172 bool isLegalConversion(Type oldType, Type newType,
const char *patName)
const {
173 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
175 if (
StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
178 if (this->structInstantiations.lookup(oldStructType) == nTy) {
185 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
186 if (
auto preImage = this->reverseInstantiations.lookup(newStructType)) {
199 llvm::dbgs() <<
"[" << patName <<
"] Cannot replace old type " << oldType
200 <<
" with new type " << newType
201 <<
" because it does not define a compatible and more concrete type.\n";
206 template <
typename T,
typename U>
207 inline bool areLegalConversions(T oldTypes, U newTypes,
const char *patName)
const {
209 llvm::zip_equal(oldTypes, newTypes), [
this, &patName](std::tuple<Type, Type> oldThenNew) {
210 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
216template <
typename Impl,
typename Op,
typename... HandledAttrs>
217class SymbolUserHelper :
public OpConversionPattern<Op> {
219 const DenseMap<Attribute, Attribute> ¶mNameToValue;
222 TypeConverter &converter, MLIRContext *ctx,
unsigned patternBenefit,
223 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
225 : OpConversionPattern<Op>(converter, ctx, patternBenefit),
226 paramNameToValue(paramNameToInstantiatedValue) {}
229 using OpAdaptor =
typename mlir::OpConversionPattern<Op>::OpAdaptor;
231 virtual Attribute getNameAttr(Op)
const = 0;
233 virtual LogicalResult handleDefaultRewrite(
234 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
236 return op->emitOpError().append(
"expected value with type ", op.getType(),
" but found ", a);
240 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter)
const override {
241 LLVM_DEBUG(llvm::dbgs() <<
"[SymbolUserHelper] op: " << op <<
'\n');
242 auto res = this->paramNameToValue.find(getNameAttr(op));
243 if (res == this->paramNameToValue.end()) {
244 LLVM_DEBUG(llvm::dbgs() <<
"[SymbolUserHelper] no instantiation for " << op <<
'\n');
247 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
248 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
250 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
251 return static_cast<const Impl *
>(
this)->handleRewrite(res->first, op, adaptor, rewriter, a);
255 return TS.Default([&](Attribute a) {
256 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
262class ClonedBodyConstReadOpPattern
263 :
public SymbolUserHelper<
264 ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
265 SmallVector<Diagnostic> &diagnostics;
268 SymbolUserHelper<ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
271 ClonedBodyConstReadOpPattern(
272 TypeConverter &converter, MLIRContext *ctx,
273 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue,
274 SmallVector<Diagnostic> &instantiationDiagnostics
277 : super(converter, ctx, 1, paramNameToInstantiatedValue),
278 diagnostics(instantiationDiagnostics) {}
280 Attribute getNameAttr(ConstReadOp op)
const override {
return op.
getConstNameAttr(); }
282 LogicalResult handleRewrite(
283 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
285 APInt attrValue = a.getValue();
286 Type origResTy = op.getType();
287 if (FeltType ty = llvm::dyn_cast<FeltType>(origResTy)) {
289 rewriter, op, FeltConstAttr::get(getContext(), attrValue, ty)
294 if (llvm::isa<IndexType>(origResTy)) {
299 if (origResTy.isSignlessInteger(1)) {
301 if (attrValue.isZero()) {
305 if (!attrValue.isOne()) {
306 Location opLoc = op.getLoc();
307 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
309 if (getContext()->shouldPrintOpOnDiagnostic()) {
310 diag.attachNote(opLoc) <<
"see current operation: " << *op;
312 diag.attachNote(UnknownLoc::get(getContext()))
314 <<
"\" for this call";
315 diagnostics.push_back(std::move(diag));
320 return op->emitOpError().append(
"unexpected result type ", origResTy);
323 LogicalResult handleRewrite(
324 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
333struct MatchFailureListener :
public RewriterBase::Listener {
334 bool hadFailure =
false;
336 ~MatchFailureListener()
override {}
338 void notifyMatchFailure(Location loc, function_ref<
void(Diagnostic &)> reasonCallback)
override {
341 InFlightDiagnostic diag = emitError(loc);
342 reasonCallback(*diag.getUnderlyingDiagnostic());
348applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
349 bool currStepModified =
false;
350 MatchFailureListener failureListener;
351 LogicalResult result = applyPatternsGreedily(
352 modOp->getRegion(0), std::move(patterns),
353 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
356 tracker.updateModifiedFlag(currStepModified);
357 return failure(result.failed() || failureListener.hadFailure);
360template <
bool AllowStructParams = true>
bool isConcreteAttr(Attribute a) {
361 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
364 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
374static std::optional<Attribute>
375evaluateExpr(
TemplateExprOp exprOp,
const DenseMap<Attribute, Attribute> ¶mNameToConcrete) {
377 DenseMap<Value, Attribute> valueMap;
379 if (
auto yieldOp = llvm::dyn_cast<YieldOp>(bodyOp)) {
380 auto it = valueMap.find(yieldOp.getVal());
381 return it != valueMap.end() ? std::make_optional(it->second) : std::nullopt;
384 if (
auto constReadOp = llvm::dyn_cast<ConstReadOp>(bodyOp)) {
385 auto it = paramNameToConcrete.find(constReadOp.getConstNameAttr());
386 if (it == paramNameToConcrete.end()) {
391 Attribute val = it->second;
392 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(val)) {
393 if (
auto feltTy = llvm::dyn_cast<FeltType>(constReadOp.getResult().getType())) {
394 val = FeltConstAttr::get(bodyOp.getContext(), intAttr.getValue(), feltTy);
397 valueMap[constReadOp.getResult()] = val;
402 SmallVector<Attribute> operandAttrs;
403 operandAttrs.reserve(bodyOp.getNumOperands());
404 for (Value operand : bodyOp.getOperands()) {
405 auto it = valueMap.find(operand);
406 if (it == valueMap.end()) {
409 operandAttrs.push_back(it->second);
413 SmallVector<OpFoldResult> foldResults;
414 if (succeeded(bodyOp.fold(operandAttrs, foldResults)) &&
415 foldResults.size() == bodyOp.getNumResults()) {
416 for (
auto [result, fr] : llvm::zip_equal(bodyOp.getResults(), foldResults)) {
417 if (Attribute a = llvm::dyn_cast<Attribute>(fr)) {
418 valueMap[result] = a;
432evaluateTemplateExprs(
TemplateOp templateOp, DenseMap<Attribute, Attribute> ¶mNameToConcrete) {
434 llvm::dbgs() <<
"[evaluateTemplateExprs] before: " <<
debug::toStringList(paramNameToConcrete)
438 std::optional<Attribute> result = evaluateExpr(exprOp, paramNameToConcrete);
439 if (result.has_value()) {
440 auto exprNameAttr = FlatSymbolRefAttr::get(exprOp.
getSymNameAttr());
441 paramNameToConcrete.try_emplace(exprNameAttr, *result);
443 llvm::dbgs() <<
"[evaluateTemplateExprs] expr @" << exprOp.
getSymName()
444 <<
" evaluated to " << *result <<
'\n'
449 llvm::dbgs() <<
"[evaluateTemplateExprs] after: " <<
debug::toStringList(paramNameToConcrete)
456static inline bool tableOffsetIsntSymbol(
MemberReadOp op) {
457 return !llvm::isa_and_present<SymbolRefAttr>(op.
getTableOffset().value_or(
nullptr));
463 ConversionTracker &tracker_;
465 SymbolTableCollection symTables;
466 bool reportMissing =
true;
468 class MappedTypeConverter :
public TypeConverter {
471 const DenseMap<Attribute, Attribute> ¶mNameToValue;
473 inline Attribute convertIfPossible(Attribute a)
const {
474 auto res = this->paramNameToValue.find(a);
475 return (res != this->paramNameToValue.end()) ? res->second : a;
482 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
484 : TypeConverter(), origTy(originalType), newTy(newType),
485 paramNameToValue(paramNameToInstantiatedValue) {
487 addConversion([](Type inputTy) {
return inputTy; });
490 LLVM_DEBUG(llvm::dbgs() <<
"[MappedTypeConverter] convert " << inputTy <<
'\n');
493 if (inputTy == this->origTy) {
497 if (ArrayAttr inputTyParams = inputTy.getParams()) {
498 SmallVector<Attribute> updated;
499 for (Attribute a : inputTyParams) {
500 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
501 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
503 updated.push_back(convertIfPossible(a));
507 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
514 addConversion([
this](
ArrayType inputTy) {
516 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
517 if (!dimSizes.empty()) {
518 SmallVector<Attribute> updated;
519 for (Attribute a : dimSizes) {
520 updated.push_back(convertIfPossible(a));
522 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
528 addConversion([
this](
TypeVarType inputTy) -> Type {
530 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
531 Type convertedType = tyAttr.getValue();
536 return convertedType;
544 class ClonedStructMemberReadOpPattern
545 :
public SymbolUserHelper<
546 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
548 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
551 ClonedStructMemberReadOpPattern(
552 TypeConverter &converter, MLIRContext *ctx,
553 const DenseMap<Attribute, Attribute> ¶mNameToInstantiatedValue
556 : super(converter, ctx, 1, paramNameToInstantiatedValue) {}
562 template <
typename Attr>
563 LogicalResult handleRewrite(
564 Attribute,
MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
566 rewriter.modifyOpInPlace(op, [&]() {
573 LogicalResult matchAndRewrite(
574 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
577 llvm::dbgs() <<
"[ClonedStructMemberReadOpPattern] MemberReadOp: " << op <<
'\n';
579 if (tableOffsetIsntSymbol(op)) {
583 return super::matchAndRewrite(op, adaptor, rewriter);
587 FailureOr<StructType> genClone(
StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
588 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] attempting clone of " << typeAtCaller <<
'\n');
590 FailureOr<SymbolLookupResult<StructDefOp>> r =
591 typeAtCaller.
getDefinition(symTables, rootMod, reportMissing);
593 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: cannot find StructDefOp \n");
596 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] found definition\n";);
600 MLIRContext *ctx = origStruct.getContext();
603 DenseMap<Attribute, Attribute> paramNameToConcrete;
607 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
610 SmallVector<Attribute> remainingNames;
612 ArrayAttr reducedCallerParams =
nullptr;
614 ArrayAttr paramNames = typeAtDef.
getParams();
618 assert(paramNames.size() == typeAtCallerParams.size());
620 SmallVector<Attribute> nonConcreteParams;
621 for (
size_t i = 0, e = paramNames.size(); i < e; ++i) {
622 Attribute next = typeAtCallerParams[i];
623 if (isConcreteAttr<false>(next)) {
624 paramNameToConcrete[paramNames[i]] = next;
625 attrsForInstantiatedNameSuffix.push_back(next);
627 remainingNames.push_back(paramNames[i]);
628 nonConcreteParams.push_back(next);
629 attrsForInstantiatedNameSuffix.push_back(
nullptr);
633 assert(remainingNames.size() == nonConcreteParams.size());
634 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
635 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
637 if (paramNameToConcrete.empty()) {
638 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: no concrete params \n");
641 if (!remainingNames.empty()) {
642 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
647 SmallVector<FlatSymbolRefAttr> typeAtCallerSymPieces =
getPieces(typeAtCaller.
getNameRef());
648 typeAtCallerSymPieces.pop_back();
651 typeAtCallerSymPieces.back().getValue().str(), attrsForInstantiatedNameSuffix
656 assert(parentTemplate &&
"parameterized struct must be nested in a TemplateOp");
658 assert(parentModule &&
"TemplateOp must be nested in a ModuleOp");
662 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
666 if (remainingNames.empty()) {
669 (templateNameWithAttrs + mlir::Twine(
'_') + newStruct.
getSymName()).str()
673 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(parentTemplate));
675 typeAtCallerSymPieces.pop_back();
678 TemplateOp newTemplate = parentTemplate.cloneWithoutRegions();
679 newTemplate.
setSymName(templateNameWithAttrs);
680 assert(newTemplate->getNumRegions() > 0 &&
"region exists");
684 for (Attribute name : remainingNames) {
685 FlatSymbolRefAttr nameSym = llvm::dyn_cast<FlatSymbolRefAttr>(name);
686 assert(nameSym &&
"expected FlatSymbolRefAttr");
688 Operation *symOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
689 assert(symOp &&
"symbol must exist");
690 newTemplate.insert(newTemplate.begin(), symOp->clone());
695 symTables.getSymbolTable(newTemplate).insert(newStruct);
696 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
700 typeAtCallerSymPieces.back() = FlatSymbolRefAttr::get(newTemplate.
getSymNameAttr());
706 typeAtCallerSymPieces.push_back(
707 FlatSymbolRefAttr::get(newLocalType.
getNameRef().getLeafReference())
712 llvm::dbgs() <<
"[StructCloner] original def type: " << typeAtDef <<
'\n';
713 llvm::dbgs() <<
"[StructCloner] cloned def type: " << newStruct.
getType() <<
'\n';
714 llvm::dbgs() <<
"[StructCloner] original remote type: " << typeAtCaller <<
'\n';
715 llvm::dbgs() <<
"[StructCloner] cloned local type: " << newLocalType <<
'\n';
716 llvm::dbgs() <<
"[StructCloner] cloned remote type: " << newRemoteType <<
'\n';
722 MappedTypeConverter tyConv(typeAtDef, newStruct.
getType(), paramNameToConcrete);
723 ConversionTarget target =
727 return !paramNameToConcrete.contains(op.getConstNameAttr());
731 patterns.add<ClonedBodyConstReadOpPattern>(
732 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
734 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
735 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
736 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] instantiating body of struct failed \n");
739 return newRemoteType;
743 StructCloner(ConversionTracker &tracker, ModuleOp root)
744 : tracker_(tracker), rootMod(root), symTables() {}
// Entry point for cloning: logs `orig` and, when `orig` carries a parameter list, forwards
// the type together with its parameter values to genClone() and returns that result.
746 FailureOr<StructType> createInstantiatedClone(
StructType orig) {
747 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] orig: " << orig <<
'\n');
// The branch is taken only when getParams() yields a non-null ArrayAttr.
748 if (ArrayAttr params = orig.
getParams()) {
749 return genClone(orig, params.getValue());
// NOTE(review): the statement following this log message (the result produced when no
// parameter list is present) is not visible in this excerpt -- confirm against the file.
751 LLVM_DEBUG(llvm::dbgs() <<
"[StructCloner] skip: nullptr for params \n");
755 void enableReportMissing() { reportMissing =
true; }
757 void disableReportMissing() { reportMissing =
false; }
760class DisableReportMissing;
762class ParameterizedStructUseTypeConverter :
public TypeConverter {
763 ConversionTracker &tracker_;
766 friend DisableReportMissing;
769 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
770 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
772 addConversion([](Type inputTy) {
return inputTy; });
776 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] attempting conversion of "
780 if (
auto opt = tracker_.getInstantiation(inputTy)) {
786 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
787 if (failed(cloneRes)) {
792 llvm::dbgs() <<
"[ParameterizedStructUseTypeConverter] instantiating " << inputTy
793 <<
" as " << newTy <<
'\n'
795 tracker_.recordInstantiation(inputTy, newTy);
799 addConversion([
this](
ArrayType inputTy) {
800 return inputTy.cloneWith(convertType(inputTy.getElementType()));
805class CallStructFuncPattern :
public OpConversionPattern<CallOp> {
806 ConversionTracker &tracker_;
809 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
811 : OpConversionPattern<CallOp>(converter, ctx, 1), tracker_(tracker) {}
813 LogicalResult matchAndRewrite(
814 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
816 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] CallOp: " << op <<
'\n');
819 SmallVector<Type> newResultTypes;
820 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
821 return op->emitError(
"Could not convert Op result types.");
824 llvm::dbgs() <<
"[CallStructFuncPattern] newResultTypes: "
834 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
835 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
836 tracker_.reportDelayedDiagnostics(newStTy, op);
840 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] newStTy: " << newStTy <<
'\n');
841 calleeAttr =
appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
845 LLVM_DEBUG(llvm::dbgs() <<
"[CallStructFuncPattern] replaced " << op);
847 rewriter, op, newResultTypes, calleeAttr, adapter.
getMapOperands(),
851 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
857class MemberDefOpPattern :
public OpConversionPattern<MemberDefOp> {
859 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
861 : OpConversionPattern<MemberDefOp>(converter, ctx, 1) {}
863 LogicalResult matchAndRewrite(
864 MemberDefOp op, OpAdaptor , ConversionPatternRewriter &rewriter
866 LLVM_DEBUG(llvm::dbgs() <<
"[MemberDefOpPattern] MemberDefOp: " << op <<
'\n');
868 Type oldMemberType = op.
getType();
869 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
870 if (oldMemberType == newMemberType) {
873 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.
setType(newMemberType); });
881 ParameterizedStructUseTypeConverter &tyConv;
884 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
886 void checkStarted()
override { tyConv.cloner.disableReportMissing(); }
888 void checkEnded(
bool)
override { tyConv.cloner.enableReportMissing(); }
891LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
892 MLIRContext *ctx = modOp.getContext();
893 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
894 DisableReportMissing drm(tyConv);
897 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
898 return applyPartialConversion(modOp, target, std::move(patterns));
907class FuncInstTypeConverter :
public TypeConverter {
908 DenseMap<Attribute, Attribute> paramNameToValue;
910 Attribute convertIfPossible(Attribute a)
const {
911 auto res = paramNameToValue.find(a);
912 return (res != paramNameToValue.end()) ? res->second : a;
916 explicit FuncInstTypeConverter(DenseMap<Attribute, Attribute> paramNameToConcrete)
917 : TypeConverter(), paramNameToValue(std::move(paramNameToConcrete)) {
918 addConversion([](Type t) {
return t; });
920 addConversion([
this](
TypeVarType inputTy) -> Type {
921 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.
getNameRef()))) {
922 Type convertedType = tyAttr.getValue();
924 return convertedType;
930 addConversion([
this](
ArrayType inputTy) {
931 SmallVector<Attribute> updated;
932 bool changed =
false;
933 for (Attribute a : inputTy.getDimensionSizes()) {
934 Attribute converted = convertIfPossible(a);
935 updated.push_back(converted);
936 if (converted != a) {
940 Type newElemTy = this->convertType(inputTy.getElementType());
941 if (!changed && newElemTy == inputTy.getElementType()) {
948 if (ArrayAttr params = inputTy.getParams()) {
949 SmallVector<Attribute> updated;
950 bool changed =
false;
951 for (Attribute a : params) {
952 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
953 Type newTy = this->convertType(ta.getValue());
954 if (newTy != ta.getValue()) {
955 updated.push_back(TypeAttr::get(newTy));
960 Attribute converted = convertIfPossible(a);
961 if (converted != a) {
962 updated.push_back(converted);
967 updated.push_back(a);
971 inputTy.
getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
979 bool containsParam(Attribute nameAttr)
const {
return paramNameToValue.contains(nameAttr); }
980 const DenseMap<Attribute, Attribute> &getParamMap()
const {
return paramNameToValue; }
983class InstantiateFuncAtCallOp final :
public OpRewritePattern<CallOp> {
984 ConversionTracker &tracker_;
987 InstantiateFuncAtCallOp(MLIRContext *ctx, ConversionTracker &tracker)
988 : OpRewritePattern<CallOp>(ctx), tracker_(tracker) {}
990 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
991 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateFuncAtCallOp] op: " << op <<
'\n');
994 SymbolTableCollection symTables;
995 FailureOr<SymbolLookupResult<FuncDefOp>> callTgtOpt = op.
getCalleeTarget(symTables);
996 if (failed(callTgtOpt)) {
997 return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
998 diag <<
"could not find target function for call";
1004 TemplateOp parentTemplate = llvm::dyn_cast<TemplateOp>(callTgt->getParentOp());
1005 if (!parentTemplate) {
1009 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] target function in template "
1021 if (failed(unifyResult)) {
1022 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1023 diag.append(
"target function type does not unify with call type ")
1025 .attachNote(callTgt.getLoc())
1026 .append(
"target function declared here");
1030 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unifications of types: "
1035 DenseMap<Attribute, Attribute> paramNameToConcrete;
1040 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] TemplateParamsAttr: " << callParams <<
'\n'
1043 for (
auto paramOp : realParams) {
1044 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1045 auto it = unifyResult->find({paramName,
Side::RHS});
1046 if (it == unifyResult->end()) {
1048 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1053 Attribute inferredVal = it->second;
1054 if (!isConcreteAttr(inferredVal)) {
1056 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1057 <<
"': not concrete, " << inferredVal <<
'\n'
1064 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1065 <<
"': incompatible with specified param type. MUST FAIL!\n"
1067 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1068 diag.append(
"inferred value for parameter '")
1070 .append(
"' is incompatible with specified param type")
1071 .attachNote(paramOp.getLoc())
1072 .append(
"template parameter declared here");
1075 paramNameToConcrete[paramName] = inferredVal;
1080 assert((callParams.size() == llvm::range_size(realParams)) &&
"per CallOpVerifier");
1082 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1083 diag.append(
"incompatible with specified param type(s)");
1087 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1088 diag.append(
"incompatible with inferred param value(s)");
1092 for (
auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) {
1093 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1094 if (!isConcreteAttr(attr)) {
1096 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] unification for param '" << paramName
1097 <<
"': not concrete, " << attr <<
'\n'
1101 paramNameToConcrete[paramName] = attr;
1105 if (paramNameToConcrete.empty()) {
1106 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateFuncAtCallOp] skip: no concrete params\n");
1112 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
1115 SmallVector<Attribute> remainingNames;
1116 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
1118 auto it = paramNameToConcrete.find(paramName);
1119 if (it != paramNameToConcrete.end()) {
1120 attrsForInstantiatedNameSuffix.push_back(it->second);
1122 attrsForInstantiatedNameSuffix.push_back(
nullptr);
1123 remainingNames.push_back(paramName);
1127 MLIRContext *ctx = op.getContext();
1129 assert(parentModule &&
"TemplateOp must be nested in a ModuleOp");
1134 parentTemplate.
getSymName().str(), attrsForInstantiatedNameSuffix
1142 auto applyBodyConversions = [&](
FuncDefOp newFunc) -> LogicalResult {
1143 FuncInstTypeConverter tyConv(paramNameToConcrete);
1147 return !tyConv.containsParam(p.getConstNameAttr());
1149 SmallVector<Diagnostic> delayedDiagnostics;
1151 bodyPatterns.add<ClonedBodyConstReadOpPattern>(
1152 tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics
1154 if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) {
1158 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] instantiated clone: " << newFunc <<
'\n'
1160 ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics));
1163 SymbolTableCollection tables;
1164 WalkResult res = newFunc.walk([&tables](
CallOp nestedCall) {
1167 return failure(res.wasInterrupted());
1171 assert(symPieces.size() >= 2 &&
"callee must include at least template and function names");
1172 if (remainingNames.empty()) {
1175 std::string newFuncName =
1176 (mlir::Twine(templateNameWithAttrs) +
"_" + callTgt.
getSymName()).str();
1177 StringRef actualNewFuncName = newFuncName;
1178 if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) {
1182 symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate));
1185 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] created full instantiation function: "
1186 << actualNewFuncName <<
'\n'
1188 if (failed(applyBodyConversions(newFunc))) {
1190 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] body conversion failed for "
1191 << actualNewFuncName <<
'\n'
1194 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1195 diag.append(
"failure while creating instantiated function '", actualNewFuncName,
'\'');
1200 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] reusing full instantiation function: "
1201 << actualNewFuncName <<
'\n'
1207 symPieces.pop_back();
1208 symPieces.pop_back();
1209 symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName)));
1216 if (Operation *existing =
1217 symTables.getSymbolTable(parentModule).lookup(templateNameWithAttrs)) {
1218 newTemplate = llvm::dyn_cast<TemplateOp>(existing);
1222 newTemplate = parentTemplate.cloneWithoutRegions();
1223 newTemplate.
setSymName(templateNameWithAttrs);
1224 assert(newTemplate->getNumRegions() > 0 &&
"region exists");
1228 Block &newTemplateBody = newTemplate.
getBodyRegion().front();
1229 for (Attribute name : remainingNames) {
1230 FlatSymbolRefAttr nameSym = llvm::cast<FlatSymbolRefAttr>(name);
1231 Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
1232 assert(paramOp &&
"symbol must exist");
1233 newTemplateBody.push_back(paramOp->clone());
1238 if (failed(applyBodyConversions(newFunc))) {
1239 StringRef newFuncName = newFunc.
getSymName();
1241 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] body conversion failed for "
1242 << newFuncName <<
'\n'
1244 newTemplate->erase();
1245 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1246 diag.append(
"failure while creating instantiated function '", newFuncName,
'\'');
1252 symTables.getSymbolTable(newTemplate).insert(newFunc);
1253 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
1255 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] created partial instantiation template: "
1260 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] reusing partial instantiation template: "
1267 symPieces.pop_back();
1268 symPieces.pop_back();
1269 symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.
getSymNameAttr()));
1270 symPieces.push_back(FlatSymbolRefAttr::get(callTgt.
getSymNameAttr()));
1274 rewriter.modifyOpInPlace(op, [&op, &symPieces]() {
1278 llvm::dbgs() <<
"[InstantiateFuncAtCallOp] updating callee from " << op.
getCalleeAttr()
1279 <<
" to " << newCalleeAttr <<
'\n';
1285 tracker_.updateModifiedFlag(
true);
1290LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1291 MLIRContext *ctx = modOp.getContext();
1292 RewritePatternSet patterns(ctx);
1293 patterns.add<InstantiateFuncAtCallOp>(ctx, tracker);
1294 MatchFailureListener failureListener;
1295 walkAndApplyPatterns(modOp, std::move(patterns), &failureListener);
1296 return failure(failureListener.hadFailure);
1304template <HasInterface<LoopLikeOpInterface> OpClass>
1305class LoopUnrollPattern :
public OpRewritePattern<OpClass> {
1307 using OpRewritePattern<OpClass>::OpRewritePattern;
// Rewrites a loop for which getConstantTripCount() produces a value.
// NOTE(review): the conditions that choose between the three actions below, and the result
// returned when no constant trip count is available, are not visible in this excerpt.
1309 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter)
const override {
1310 if (
auto maybeConstant = getConstantTripCount(loopOp)) {
// Action: remove the loop operation through the rewriter.
1313 rewriter.eraseOp(loopOp);
// Action: delegate to the loop's own single-iteration promotion.
1316 return loopOp.promoteIfSingleIteration(rewriter);
// Action: unroll the loop by a factor of `tripCount`.
1318 return loopUnrollByFactor(loopOp,
tripCount);
1326 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
1327 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
1328 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
1329 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
1330 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
1331 return std::nullopt;
1333 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
1337LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1338 MLIRContext *ctx = modOp.getContext();
1339 RewritePatternSet patterns(ctx);
1340 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
1341 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
1343 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1351std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
1352 SmallVector<int64_t> res;
1353 for (OpFoldResult ofr : ofrs) {
1354 std::optional<int64_t> cv = getConstantIntValue(ofr);
1355 if (!cv.has_value()) {
1356 return std::nullopt;
1358 res.push_back(cv.value());
1363struct AffineMapFolder {
1365 OperandRangeRange mapOpGroups;
1366 DenseI32ArrayAttr dimsPerGroup;
1367 ArrayRef<Attribute> paramsOfStructTy;
1371 SmallVector<SmallVector<Value>> mapOpGroups;
1372 SmallVector<int32_t> dimsPerGroup;
1373 SmallVector<Attribute> paramsOfStructTy;
1376 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
1377 return llvm::map_to_vector(out.mapOpGroups, [](
const SmallVector<Value> &grp) {
1378 return ValueRange(grp);
1382 static LogicalResult
1383 fold(PatternRewriter &rewriter,
const Input &in, Output &out, Operation *op,
const char *aspect) {
1384 if (in.mapOpGroups.empty()) {
1389 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
1390 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
1393 for (Attribute sizeAttr : in.paramsOfStructTy) {
1394 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
1395 ValueRange currMapOps = in.mapOpGroups[idx++];
1400 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
1402 llvm::dbgs() <<
"[AffineMapFolder] currMapOps as fold results: "
1405 if (
auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
1406 SmallVector<Attribute> result;
1407 bool hasPoison =
false;
1408 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
1409 return rewriter.getIndexAttr(v);
1411 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
1416 "Cannot fold affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
1417 " due to divide by 0 or modulus with negative divisor"
1422 if (failed(foldResult)) {
1426 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
" failed"
1431 if (result.size() != 1) {
1435 "Folding affine_map for ", aspect,
" ", out.paramsOfStructTy.size(),
1436 " produced ", result.size(),
" results but expected 1"
1441 assert(!llvm::isa<AffineMapAttr>(result[0]) &&
"not converted");
1442 out.paramsOfStructTy.push_back(result[0]);
1446 out.mapOpGroups.emplace_back(currMapOps);
1447 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]);
1450 out.paramsOfStructTy.push_back(sizeAttr);
1452 assert(idx == in.mapOpGroups.size() &&
"all affine_map not processed");
1454 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
1455 "produced wrong number of dimensions"
1463class InstantiateAtCreateArrayOp final :
public OpRewritePattern<CreateArrayOp> {
1465 ConversionTracker &tracker_;
1468 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
1469 : OpRewritePattern(ctx), tracker_(tracker) {}
1471 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1474 AffineMapFolder::Output out;
1475 AffineMapFolder::Input in = {
1480 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"array dimension"))) {
1485 if (newResultType == oldResultType) {
1489 assert(tracker_.isLegalConversion(oldResultType, newResultType,
"InstantiateAtCreateArrayOp"));
1491 llvm::dbgs() <<
"[InstantiateAtCreateArrayOp] instantiating " << oldResultType <<
" as "
1492 << newResultType <<
" in \"" << op <<
"\"\n"
1495 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
1502class InstantiateAtCallOpCompute final :
public OpRewritePattern<CallOp> {
1503 ConversionTracker &tracker_;
1506 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
1507 : OpRewritePattern(ctx), tracker_(tracker) {}
1509 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1514 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] target: " << op.
getCallee() <<
'\n');
1516 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy <<
'\n');
1517 ArrayAttr params = oldRetTy.
getParams();
1523 AffineMapFolder::Output out;
1524 AffineMapFolder::Input in = {
1529 if (!in.mapOpGroups.empty()) {
1531 if (failed(AffineMapFolder::fold(rewriter, in, out, op,
"struct parameter"))) {
1535 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
1541 if (callArgTypes.empty()) {
1545 SymbolTableCollection tables;
1547 if (failed(lookupRes)) {
1550 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1554 llvm::dbgs() <<
"[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1555 "result type params: "
1561 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] newRetTy: " << newRetTy <<
'\n');
1562 if (newRetTy == oldRetTy) {
1569 if (!tracker_.isLegalConversion(oldRetTy, newRetTy,
"InstantiateAtCallOpCompute")) {
1570 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1572 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1573 ", but found ", oldRetTy
1577 LLVM_DEBUG(llvm::dbgs() <<
"[InstantiateAtCallOpCompute] replaced " << op);
1579 rewriter, op, TypeRange {newRetTy}, op.
getCallee(),
1580 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.
getArgOperands()
1583 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1590 inline LogicalResult instantiateViaTargetType(
1591 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1592 OperandRange::type_range callArgTypes,
FuncDefOp targetFunc
1597 assert(in.paramsOfStructTy.size() == targetResTyParams.size());
1599 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1605 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1607 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']' <<
" target func arg types: "
1609 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1611 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1618 assert(unifies &&
"should have been checked by verifiers");
1621 llvm::dbgs() <<
'[' << __FUNCTION__ <<
']'
1630 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1631 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1632 [&unifications](std::tuple<Attribute, Attribute> p) {
1633 Attribute fromCall = std::get<1>(p);
1636 if (!isConcreteAttr<>(fromCall)) {
1637 Attribute fromTgt = std::get<0>(p);
1639 llvm::dbgs() <<
"[instantiateViaTargetType] fromCall = " << fromCall <<
'\n';
1640 llvm::dbgs() <<
"[instantiateViaTargetType] fromTgt = " << fromTgt <<
'\n';
1642 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1643 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1644 if (it != unifications.end()) {
1645 Attribute unifiedAttr = it->second;
1647 llvm::dbgs() <<
"[instantiateViaTargetType] unifiedAttr = " << unifiedAttr <<
'\n';
1649 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1658 out.paramsOfStructTy = newReturnStructParams;
1659 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() &&
"post-condition");
1660 assert(out.mapOpGroups.empty() &&
"post-condition");
1661 assert(out.dimsPerGroup.empty() &&
"post-condition");
1666LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1667 MLIRContext *ctx = modOp.getContext();
1668 RewritePatternSet patterns(ctx);
1670 InstantiateAtCreateArrayOp,
1671 InstantiateAtCallOpCompute
1674 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1682class UpdateNewArrayElemFromWrite final :
public OpRewritePattern<CreateArrayOp> {
1683 ConversionTracker &tracker_;
1686 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1687 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1689 LogicalResult matchAndRewrite(
CreateArrayOp op, PatternRewriter &rewriter)
const override {
1691 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1692 assert(createResultType &&
"CreateArrayOp must produce ArrayType");
1697 Type newResultElemType =
nullptr;
1698 for (Operation *user : createResult.getUsers()) {
1699 if (
WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1700 if (writeOp.getArrRef() != createResult) {
1703 Type writeRValueType = writeOp.getRvalue().getType();
1704 if (writeRValueType == oldResultElemType) {
1707 if (newResultElemType && newResultElemType != writeRValueType) {
1710 <<
"[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1711 << newResultElemType <<
" vs " << writeRValueType <<
'\n'
1715 newResultElemType = writeRValueType;
1718 if (!newResultElemType) {
1722 if (!tracker_.isLegalConversion(
1723 oldResultElemType, newResultElemType,
"UpdateNewArrayElemFromWrite"
1728 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1730 llvm::dbgs() <<
"[UpdateNewArrayElemFromWrite] updated result type of " << op <<
'\n'
1738LogicalResult updateArrayElemFromArrAccessOp(
1740 PatternRewriter &rewriter
1747 if (oldArrType == newArrType ||
1748 !tracker.isLegalConversion(oldArrType, newArrType,
"updateArrayElemFromArrAccessOp")) {
1751 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.
getArrRef().setType(newArrType); });
1753 llvm::dbgs() <<
"[updateArrayElemFromArrAccessOp] updated base array type in " << op <<
'\n'
1760class UpdateArrayElemFromArrWrite final :
public OpRewritePattern<WriteArrayOp> {
1761 ConversionTracker &tracker_;
1764 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1765 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1767 LogicalResult matchAndRewrite(
WriteArrayOp op, PatternRewriter &rewriter)
const override {
1768 return updateArrayElemFromArrAccessOp(op, op.
getRvalue().getType(), tracker_, rewriter);
1772class UpdateArrayElemFromArrRead final :
public OpRewritePattern<ReadArrayOp> {
1773 ConversionTracker &tracker_;
1776 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1777 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1779 LogicalResult matchAndRewrite(
ReadArrayOp op, PatternRewriter &rewriter)
const override {
1780 return updateArrayElemFromArrAccessOp(op, op.
getResult().getType(), tracker_, rewriter);
1785class UpdateMemberDefTypeFromWrite final :
public OpRewritePattern<MemberDefOp> {
1786 ConversionTracker &tracker_;
1789 UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1790 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1792 LogicalResult matchAndRewrite(
MemberDefOp op, PatternRewriter &rewriter)
const override {
1795 assert(parentRes &&
"MemberDefOp parent is always StructDefOp");
1799 Type newType =
nullptr;
1801 std::optional<Location> newTypeLoc = std::nullopt;
1802 for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
1803 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
1804 Type writeToType = writeOp.getVal().getType();
1805 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] checking " << writeOp <<
'\n');
1808 newType = writeToType;
1809 newTypeLoc = writeOp.getLoc();
1810 }
else if (writeToType != newType) {
1816 if (!tracker_.isLegalConversion(writeToType, newType,
"UpdateMemberDefTypeFromWrite")) {
1817 if (tracker_.isLegalConversion(
1818 newType, writeToType,
"UpdateMemberDefTypeFromWrite"
1821 newType = writeToType;
1822 newTypeLoc = writeOp.getLoc();
1825 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1829 "' with different value types"
1832 diag.attachNote(newTypeLoc).append(
"type written here is ", newType);
1834 diag.attachNote(writeOp.getLoc()).append(
"type written here is ", writeToType);
1842 if (!newType || newType == op.
getType()) {
1845 if (!tracker_.isLegalConversion(op.
getType(), newType,
"UpdateMemberDefTypeFromWrite")) {
1848 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.
setType(newType); });
1849 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateMemberDefTypeFromWrite] updated type of " << op <<
'\n');
1856SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1857 SmallVector<std::unique_ptr<Region>> newRegions;
1858 for (Region ®ion : op->getRegions()) {
1859 auto newRegion = std::make_unique<Region>();
1860 newRegion->takeBody(region);
1861 newRegions.push_back(std::move(newRegion));
1870class UpdateInferredResultTypes final :
public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1871 ConversionTracker &tracker_;
1874 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1875 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1877 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter)
const override {
1878 SmallVector<Type, 1> inferredResultTypes;
1879 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1880 LogicalResult result = retTypeFn.inferReturnTypes(
1881 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1882 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1884 if (failed(result)) {
1887 if (op->getResultTypes() == inferredResultTypes) {
1890 if (!tracker_.areLegalConversions(
1891 op->getResultTypes(), inferredResultTypes,
"UpdateInferredResultTypes"
1897 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateInferredResultTypes] replaced " << *op);
1898 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1899 Operation *newOp = rewriter.create(
1900 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1901 op->getAttrs(), op->getSuccessors(), newRegions
1903 rewriter.replaceOp(op, newOp);
1904 LLVM_DEBUG(llvm::dbgs() <<
" with " << *newOp <<
'\n');
1910class UpdateFuncTypeFromReturn final :
public OpRewritePattern<FuncDefOp> {
1911 ConversionTracker &tracker_;
1914 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1915 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1917 LogicalResult matchAndRewrite(
FuncDefOp op, PatternRewriter &rewriter)
const override {
1918 Region &body = op.getFunctionBody();
1922 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1923 assert(retOp &&
"final op in body region must be return");
1924 OperandRange::type_range tyFromReturnOp = retOp.
getOperands().getTypes();
1927 if (oldFuncTy.getResults() == tyFromReturnOp) {
1930 if (!tracker_.areLegalConversions(
1931 oldFuncTy.getResults(), tyFromReturnOp,
"UpdateFuncTypeFromReturn"
1936 rewriter.modifyOpInPlace(op, [&]() {
1937 op.
setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1940 llvm::dbgs() <<
"[UpdateFuncTypeFromReturn] changed " << op.
getSymName() <<
" from "
1951class UpdateFreeFuncCallOpTypes final :
public OpRewritePattern<CallOp> {
1952 ConversionTracker &tracker_;
1955 UpdateFreeFuncCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1956 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1958 LogicalResult matchAndRewrite(
CallOp op, PatternRewriter &rewriter)
const override {
1959 SymbolTableCollection tables;
1961 if (failed(lookupRes)) {
1964 FuncDefOp targetFunc = lookupRes->get();
1969 if (op.getResultTypes() == targetFunc.
getFunctionType().getResults()) {
1972 if (!tracker_.areLegalConversions(
1974 "UpdateFreeFuncCallOpTypes"
1979 LLVM_DEBUG(llvm::dbgs() <<
"[UpdateFreeFuncCallOpTypes] replaced " << op);
1982 LLVM_DEBUG(llvm::dbgs() <<
" with " << newOp <<
'\n');
1989LogicalResult updateMemberRefValFromMemberDef(
1992 SymbolTableCollection tables;
1997 Type oldResultType = op.
getVal().getType();
1998 Type newResultType = def->get().getType();
1999 if (oldResultType == newResultType ||
2000 !tracker.isLegalConversion(oldResultType, newResultType,
"updateMemberRefValFromMemberDef")) {
2003 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.
getVal().setType(newResultType); });
2005 llvm::dbgs() <<
"[updateMemberRefValFromMemberDef] updated value type in " << op <<
'\n'
2013class UpdateMemberReadValFromDef final :
public OpRewritePattern<MemberReadOp> {
2014 ConversionTracker &tracker_;
2017 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2018 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2020 LogicalResult matchAndRewrite(
MemberReadOp op, PatternRewriter &rewriter)
const override {
2021 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2026class UpdateMemberWriteValFromDef final :
public OpRewritePattern<MemberWriteOp> {
2027 ConversionTracker &tracker_;
2030 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2031 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2033 LogicalResult matchAndRewrite(
MemberWriteOp op, PatternRewriter &rewriter)
const override {
2034 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2038LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
2039 MLIRContext *ctx = modOp.getContext();
2040 RewritePatternSet patterns(ctx);
2045 UpdateInferredResultTypes,
2047 UpdateFreeFuncCallOpTypes,
2048 UpdateFuncTypeFromReturn,
2049 UpdateNewArrayElemFromWrite,
2050 UpdateArrayElemFromArrRead,
2051 UpdateArrayElemFromArrWrite,
2052 UpdateMemberDefTypeFromWrite,
2053 UpdateMemberReadValFromDef,
2054 UpdateMemberWriteValFromDef
2057 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
2065 SymbolTableCollection tables;
2068 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
2076struct FromKeepSet :
public CleanupBase {
2077 using CleanupBase::CleanupBase;
2082 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
2084 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
2086 rootMod.walk([&roots](Operation *op) {
2090 if (!fdef.isInStruct()) {
2099 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
2100 for (
size_t i = 0; i < roots.size(); ++i) {
2101 SymbolOpInterface keepRoot = roots[i];
2102 LLVM_DEBUG({ llvm::dbgs() <<
"[EraseUnreachable] root: " << keepRoot <<
'\n'; });
2104 assert(keepRootNode &&
"every struct def must be in the def tree");
2105 for (
const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
2107 llvm::dbgs() <<
"[EraseUnreachable] can reach: " << reachableDefNode->getOp() <<
'\n';
2109 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
2114 if (
const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
2116 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
2118 llvm::dbgs() <<
"[EraseUnreachable] uses symbol: "
2119 << usedSymbolNode->getSymbolPath() <<
'\n';
2123 if (usedSymbolNode->isTemplateSymbolBinding()) {
2127 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
2128 if (failed(lookupRes)) {
2129 LLVM_DEBUG(useGraph.dumpToDotFile());
2133 if (lookupRes->viaInclude()) {
2136 if (
StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
2137 bool insertRes = roots.insert(asStruct);
2141 llvm::dbgs() <<
"[EraseUnreachable] found another root: " << asStruct <<
'\n';
2151 rootMod.walk([
this, &symbolsToKeep](
StructDefOp op) {
2154 if (!symbolsToKeep.contains(n)) {
2155 LLVM_DEBUG(llvm::dbgs() <<
"[EraseUnreachable] removing: " << op.getSymName() <<
'\n');
2159 return WalkResult::skip();
2166struct FromEraseSet :
public CleanupBase {
2171 DenseSet<SymbolRefAttr> &&tryToErasePaths
2173 : CleanupBase(root, symDefTree, symUseGraph) {
2175 for (SymbolRefAttr path : tryToErasePaths) {
2176 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] path to erase: " << path <<
'\n';);
2177 Operation *lookupFrom = rootMod.getOperation();
2179 assert(succeeded(res) &&
"inputs must be valid StructDefOp references");
2180 if (!res->viaInclude()) {
2181 auto op = res->get();
2182 LLVM_DEBUG(llvm::dbgs() <<
"[FromEraseSet] added op to the erase set: " << op <<
'\n';);
2183 tryToErase.insert(op);
2186 llvm::dbgs() <<
"[FromEraseSet] ignored op because it comes from an include: "
2187 << res->get() <<
'\n';
2193 LogicalResult eraseUnusedStructs() {
2196 collectSafeToErase(sd);
2201 for (
auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
2202 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
2203 visitedPlusSafetyResult.erase(it);
2206 for (
auto &[sym, _] : visitedPlusSafetyResult) {
2207 LLVM_DEBUG(llvm::dbgs() <<
"[EraseIfUnused] removing: " << sym.getNameAttr() <<
'\n');
// Read-only accessor for the `tryToErase` member (the set of StructDefOp that
// the constructor collected from the erase paths).
2213 const DenseSet<StructDefOp> &getTryToEraseSet()
const {
return tryToErase; }
2217 DenseSet<StructDefOp> tryToErase;
2221 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
2223 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
2227 bool collectSafeToErase(SymbolOpInterface
check) {
2231 auto visited = visitedPlusSafetyResult.find(
check);
2232 if (visited != visitedPlusSafetyResult.end()) {
2233 return visited->second;
2238 if (!tryToErase.contains(sd)) {
2239 visitedPlusSafetyResult[
check] =
false;
2246 visitedPlusSafetyResult[
check] =
true;
2251 if (collectSafeToErase(defTree.lookupNode(
check))) {
2252 const auto *useNode = useGraph.lookupNode(
check);
2253 assert(useNode || (llvm::isa<ModuleOp, TemplateOp>(
check.getOperation())));
2254 if (!useNode || collectSafeToErase(useNode)) {
2260 visitedPlusSafetyResult[
check] =
false;
2268 if (SymbolOpInterface checkOp = p->getOp()) {
2269 return collectSafeToErase(checkOp);
2279 if (SymbolOpInterface checkOp = cachedLookup(p)) {
2280 if (!collectSafeToErase(checkOp)) {
2293 assert(node &&
"must provide a node");
2295 auto fromCache = lookupCache.find(node);
2296 if (fromCache != lookupCache.end()) {
2297 return fromCache->second;
2301 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
2302 assert(lookupRes->get() !=
nullptr &&
"lookup must return an Operation");
2307 SymbolOpInterface actualRes =
2308 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
2310 lookupCache[node] = actualRes;
2311 assert((!actualRes == lookupRes->viaInclude()) &&
"not found iff included");
2320 void runOnOperation()
override {
2321 ModuleOp modOp = getOperation();
2322 if (failed(runOn(modOp))) {
2325 llvm::dbgs() <<
"=====================================================================\n";
2326 llvm::dbgs() <<
" Dumping module after failure of pass " <<
DEBUG_TYPE <<
'\n';
2327 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2328 llvm::dbgs() <<
"=====================================================================\n";
2330 signalPassFailure();
2334 inline LogicalResult runOn(ModuleOp modOp) {
2339 if (cleanupMode == StructCleanupMode::MainAsRoot) {
2340 if (failed(eraseUnreachableFromMainStruct(modOp))) {
2348 OpPassManager universalCleanup(ModuleOp::getOperationName());
2353 if (failed(runPipeline(universalCleanup, modOp))) {
2357 ConversionTracker tracker;
2358 unsigned loopCount = 0;
2361 if (loopCount > iterationLimit) {
2362 llvm::errs() <<
DEBUG_TYPE <<
" exceeded the limit of " << iterationLimit
2363 <<
" iterations!\n";
2366 tracker.resetModifiedFlag();
2369 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2370 <<
")] Running step 1: struct instantiation\n";
2375 if (failed(Step1A_InstantiateStructs::run(modOp, tracker))) {
2376 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating structs in templates\n";
2380 if (failed(Step1B_InstantiateFunctions::run(modOp, tracker))) {
2381 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating functions in templates\n";
2386 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2387 <<
")] Running step 2: loop unrolling\n";
2390 if (failed(Step2_Unroll::run(modOp, tracker))) {
2391 llvm::errs() <<
DEBUG_TYPE <<
" failed while unrolling loops\n";
2396 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2397 <<
")] Running step 3: affine maps instantiation\n";
2400 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
2401 llvm::errs() <<
DEBUG_TYPE <<
" failed while instantiating `affine_map` parameters\n";
2406 llvm::dbgs() <<
"[FlatteningPass(count=" << loopCount
2407 <<
")] Running step 4: type propagation\n";
2410 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
2411 llvm::errs() <<
DEBUG_TYPE <<
" failed while propagating instantiated types\n";
2415 LLVM_DEBUG(
if (tracker.isModified()) {
2416 llvm::dbgs() <<
"=====================================================================\n";
2417 llvm::dbgs() <<
" Dumping module between iterations of " << DEBUG_TYPE <<
'\n';
2418 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2419 llvm::dbgs() <<
"=====================================================================\n";
2421 }
while (tracker.isModified());
2424 if (failed(cleanupSwitch(modOp, tracker))) {
2428 if (failed(runPipeline(universalCleanup, modOp))) {
// Dispatches the "step 5: cleanup" phase according to `cleanupMode`, logging
// the selected mode under LLVM_DEBUG before delegating to the matching helper.
// `tracker` is only consumed by the Preimage mode.
2435 LogicalResult cleanupSwitch(ModuleOp modOp,
const ConversionTracker &tracker) {
2436 LLVM_DEBUG({ llvm::dbgs() <<
"[FlatteningPass] Running step 5: cleanup "; });
2437 switch (cleanupMode) {
// MainAsRoot: erase what is unreachable from the main struct; the second
// argument (`emitWarning`) is passed as false here.
2438 case StructCleanupMode::MainAsRoot:
2439 LLVM_DEBUG(llvm::dbgs() <<
"(main as root mode)\n");
2440 return eraseUnreachableFromMainStruct(modOp,
false);
// ConcreteAsRoot: delegate to eraseUnreachableFromConcreteStructs().
2441 case StructCleanupMode::ConcreteAsRoot:
2442 LLVM_DEBUG(llvm::dbgs() <<
"(concrete structs mode)\n");
2443 return eraseUnreachableFromConcreteStructs(modOp);
// Preimage: delegate to erasePreimageOfInstantiations(), which uses the tracker.
2444 case StructCleanupMode::Preimage:
2445 LLVM_DEBUG(llvm::dbgs() <<
"(preimage mode)\n");
2446 return erasePreimageOfInstantiations(modOp, tracker);
// NOTE(review): the case label and return for this branch are absent from the
// listing; judging by the log text it is presumably the "disabled" mode.
2448 LLVM_DEBUG(llvm::dbgs() <<
"(disabled)\n");
2454 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod,
const ConversionTracker &tracker) {
2459 Step5_Cleanup::FromEraseSet cleaner(
2460 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
2461 tracker.getInstantiatedStructNames()
2463 LogicalResult res = cleaner.eraseUnusedStructs();
2464 if (succeeded(res)) {
2465 LLVM_DEBUG(llvm::dbgs() <<
"[Cleanup(preimage)] success\n";);
2467 const SymbolUseGraph *useGraph =
nullptr;
2468 rootMod->walk([
this, &cleaner, &useGraph](StructDefOp op) {
2469 if (cleaner.getTryToEraseSet().contains(op)) {
2472 useGraph = &getAnalysis<SymbolUseGraph>();
2475 if (useGraph->lookupNode(op)->hasPredecessor()) {
2476 op.emitWarning(
"Parameterized struct still has uses!").report();
2479 return WalkResult::skip();
2482 LLVM_DEBUG(llvm::dbgs() <<
"[Cleanup(preimage)] failed\n";);
2487 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
2488 SmallVector<StructDefOp> roots;
2489 rootMod.walk([&roots](StructDefOp op) {
2491 roots.push_back(op);
2493 return WalkResult::skip();
2496 Step5_Cleanup::FromKeepSet cleaner(
2497 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2499 return cleaner.eraseUnreachableFrom(roots);
2502 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod,
bool emitWarning =
true) {
2503 Step5_Cleanup::FromKeepSet cleaner(
2504 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2506 FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
2508 if (failed(mainOpt)) {
2511 SymbolLookupResult<StructDefOp>
main = mainOpt.value();
2512 if (emitWarning && !
main) {
2516 rootMod.emitWarning()
2518 "using option '", cleanupMode.getArgStr(),
'=',
2520 MAIN_ATTR_NAME,
"\" attribute on the top-level module may remove all structs!"
2524 return cleaner.eraseUnreachableFrom(
2525 main ? ArrayRef<StructDefOp> {*
main} : ArrayRef<StructDefOp> {}
2533 return std::make_unique<FlatteningPass>();
Common private implementation for poly dialect passes.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
Builds a graph structure representing the relationships between symbols and their uses.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::OperandRangeRange getMapOperands()
::mlir::TypedValue<::mlir::Type > getResult()
::mlir::TypedValue<::mlir::Type > getRvalue()
static constexpr ::llvm::StringLiteral getOperationName()
void setType(::mlir::Type attrValue)
::std::optional<::mlir::Attribute > getTableOffset()
void setTableOffsetAttr(::mlir::Attribute attr)
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
static constexpr ::llvm::StringLiteral getOperationName()
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::StringRef getSymName()
void setSymName(::llvm::StringRef attrValue)
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::ArrayAttr getParams() const
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
::mlir::SymbolRefAttr getCalleeAttr()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
void setTemplateParamsAttr(::mlir::ArrayAttr attr)
::mlir::Operation::operand_range getArgOperands()
::mlir::ArrayAttr getTemplateParamsAttr()
::mlir::OperandRangeRange getMapOperands()
void setCalleeAttr(::mlir::SymbolRefAttr attr)
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unification between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::FunctionType getFunctionType()
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
::llvm::StringRef getSymName()
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
bool isInStruct()
Return true iff the function is within a StructDefOp.
void setFunctionType(::mlir::FunctionType attrValue)
void setSymName(::llvm::StringRef attrValue)
::mlir::StringAttr getSymNameAttr()
::mlir::Operation::operand_range getOperands()
::mlir::FlatSymbolRefAttr getConstNameAttr()
::mlir::StringAttr getSymNameAttr()
::mlir::Region & getInitializerRegion()
::llvm::StringRef getSymName()
::mlir::Region & getBodyRegion()
::llvm::SmallVector<::mlir::Attribute > getConstNames()
Return the names of all ops of type OpT within the body region in the order they are defined.
::mlir::StringAttr getSymNameAttr()
void setSymName(::llvm::StringRef attrValue)
::llvm::StringRef getSymName()
inline ::llvm::iterator_range<::mlir::Region::op_iterator< OpT > > getConstOps()
Return ops of type OpT within the body region.
::mlir::FlatSymbolRefAttr getNameRef() const
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
std::unique_ptr< mlir::Pass > createFlatteningPass()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
std::unique_ptr< mlir::Pass > createEmptyTemplateRemoval()
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
bool isConcreteType(Type type, bool allowStructParams)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool isDynamic(IntegerAttr intAttr)
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
int64_t fromAPInt(const llvm::APInt &i)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
constexpr char MAIN_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.