25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
37#include <llvm/Support/Debug.h>
41#define DEBUG_TYPE "poly-dialect-shared"
48static struct OpClassesWithStructTypes {
74 WithGeneralBuilder {};
80 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
82} OpClassesWithStructTypes;
84template <
typename I,
typename NextOpClass,
typename... OtherOpClasses>
85inline void applyToMoreTypes(I inserter) {
86 std::apply(inserter, std::tuple<NextOpClass, OtherOpClasses...> {});
// No-op terminal overload of applyToMoreTypes. The variadic overload requires
// at least one Op class (its NextOpClass template parameter), so when a caller
// expands an empty AdditionalOpClasses pack, e.g.
// applyToMoreTypes<decltype(inserter)>(inserter), only this overload is viable
// and there are no extra Op classes to hand to `inserter`.
88template <
typename I>
inline void applyToMoreTypes(I inserter) {}
90inline bool defaultLegalityCheck(
const mlir::TypeConverter &tyConv, mlir::Operation *op) {
92 if (!tyConv.isLegal(op)) {
97 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
98 for (mlir::NamedAttribute n : dictAttr.getValue()) {
99 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
100 mlir::Type t = tyAttr.getValue();
101 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
102 if (!tyConv.isSignatureLegal(funcTy)) {
106 if (!tyConv.isLegal(t)) {
116template <
typename Check>
inline bool runCheck(mlir::Operation *op, Check
check) {
117 if (
auto specificOp =
118 llvm::dyn_cast_if_present<
typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
119 return check(specificOp);
128template <
typename OpClass,
typename Rewriter,
typename... Args>
130 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
132 newOp->setDiscardableAttrs(attrs);
139template <
typename OpClass>
142 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
145 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
147 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
150 mlir::SmallVector<mlir::Type> newResultTypes;
151 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
152 return op->emitError(
"Could not convert Op result types.");
157 adaptor.getAttributes().empty() ||
159 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
160 return d.contains(a.getName());
165 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
166 for (mlir::NamedAttribute &n : newAttrs) {
167 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
168 if (mlir::Type newType = converter->convertType(t.getValue())) {
169 n.setValue(mlir::TypeAttr::get(newType));
171 return op->emitError().append(
"Could not convert type in attribute: ", t);
177 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
178 mlir::ArrayRef(newAttrs)
180 return mlir::success();
185 :
public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
190 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
191 return mlir::success();
193 return op->emitError(
"Could not convert Op result type.");
200 mlir::Type newType = getTypeConverter()->convertType(op.getType());
202 llvm::isa<llzk::array::ArrayType>(newType) &&
"CreateArrayOp must produce ArrayType result"
207 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getElements()
211 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.
getMapOperands(),
226 mlir::SmallVector<mlir::Type> newResultTypes;
227 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
228 return op->emitError(
"Could not convert Op result types.");
231 rewriter, op, newResultTypes, op.
getCalleeAttr(), adapter.getMapOperands(),
234 return mlir::success();
242template <
typename... AdditionalOpClasses>
244 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
246 mlir::RewritePatternSet patterns(ctx);
247 auto inserter = [&](
auto... opClasses) {
250 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
251 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
255 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
258 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
284template <
typename... AdditionalOpClasses,
typename... AdditionalChecks>
286 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
290 tyConv, ctx, empty, (std::forward<AdditionalChecks>(checks))...
300template <
typename... AdditionalOpClasses,
typename... AdditionalChecks>
303 AdditionalChecks &&...checks
306 auto inserter = [&](
auto... opClasses) {
307 target.addDynamicallyLegalOp<
decltype(opClasses)...>([&cb, &tyConv,
308 &checks...](mlir::Operation *op) {
310 llvm::dbgs() <<
"[newConverterDefinedTarget] checking legality of ";
315 defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
318 LLVM_DEBUG(
if (legality) { llvm::dbgs() <<
"[newConverterDefinedTarget] is legal\n"; }
else {
319 llvm::dbgs() <<
"[newConverterDefinedTarget] is not legal\n";
324 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
325 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
326 applyToMoreTypes<
decltype(inserter), AdditionalOpClasses...>(inserter);
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::OperandRangeRange getMapOperands()
::mlir::Operation::operand_range getElements()
::mlir::SymbolRefAttr getCalleeAttr()
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
void checkEnded(bool) override
void checkStarted() override
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
virtual void checkStarted()=0
virtual void checkEnded(bool outcome)=0
virtual ~LegalityCheckCallback()=default
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
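Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on the given TypeConverter and any additional checks, reporting each legality check to the given LegalityCheckCallback.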
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
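Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStructTypes.WithGeneralBuilder and the given AdditionalOpClasses, plus the FuncDefOp signature and SCF structural type conversion patterns.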
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on the given TypeConverter and any additional checks.
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i.e., attributes that are not inherent to the Op) from the replaced Op to the newly created Op.
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)