24#include <mlir/Dialect/SCF/Utils/Utils.h>
26#include <llvm/Support/Debug.h>
27#include <llvm/Support/SMTAPI.h>
33#define GEN_PASS_DECL_FUSEPRODUCTLOOPSPASS
34#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
46 mlir::ModuleOp
mod = getOperation();
57static inline bool isConstOrStructParam(mlir::Value val) {
59 return val.getDefiningOp<mlir::arith::ConstantIndexOp>() ||
63llvm::SMTExprRef
mkExpr(mlir::Value value, llvm::SMTSolver *solver) {
64 if (
auto constOp = value.getDefiningOp<mlir::arith::ConstantIndexOp>()) {
65 return solver->mkBitvector(llvm::APSInt::get(constOp.value()),
INDEX_WIDTH);
68 return solver->mkSymbol(
69 std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(
INDEX_WIDTH)
72 assert(
false &&
"unsupported: checking non-constant trip counts");
76llvm::SMTExprRef
tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver) {
77 const auto *one = solver->mkBitvector(llvm::APSInt::get(1),
INDEX_WIDTH);
78 return solver->mkBVSDiv(
81 solver->mkBVSub(
mkExpr(op.getUpperBound(), solver),
mkExpr(op.getLowerBound(), solver))
83 mkExpr(op.getStep(), solver)
87static inline bool canLoopsBeFused(mlir::scf::ForOp a, mlir::scf::ForOp b) {
94 if (a->getParentRegion() != b->getParentRegion()) {
115 auto tripCountA = mlir::constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
116 auto tripCountB = mlir::constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
117 if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
121 if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
122 !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
123 !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
127 llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
128 solver->addConstraint( solver->mkNot(
132 return !*solver->check();
137 llvm::SmallVector<mlir::scf::ForOp> witnessLoops, constraintLoops;
138 body.walk<mlir::WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](mlir::scf::ForOp forOp) {
140 return mlir::WalkResult::skip();
142 auto productSource = forOp->getAttrOfType<mlir::StringAttr>(
PRODUCT_SOURCE);
144 witnessLoops.push_back(forOp);
146 constraintLoops.push_back(forOp);
149 return mlir::WalkResult::skip();
155 witnessLoops, constraintLoops, canLoopsBeFused
159 if (mlir::failed(fusionCandidates)) {
160 return mlir::failure();
164 mlir::IRRewriter rewriter {context};
165 for (
auto [w, c] : *fusionCandidates) {
166 auto fusedLoop = mlir::fuseIndependentSiblingForLoops(w, c, rewriter);
167 fusedLoop->setAttr(
PRODUCT_SOURCE, rewriter.getAttr<mlir::StringAttr>(
"fused"));
170 return mlir::failure();
173 return mlir::success();
177 return std::make_unique<FuseProductLoopsPass>();
void runOnOperation() override
bool isStructProduct()
Return true iff the function is within a StructDefOp and named FUNC_NAME_PRODUCT.
llvm::FailureOr< llvm::SetVector< std::pair< ValueT, ValueT > > > getMatchingPairs(llvm::ArrayRef< ValueT > as, llvm::ArrayRef< ValueT > bs, FnT doesMatch, bool allowPartial=true)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver)
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
constexpr char FUNC_NAME_CONSTRAIN[]
mlir::LogicalResult fuseMatchingLoopPairs(mlir::Region &body, mlir::MLIRContext *context)
Identify pairs of scf.for loops that can be fused, fuse them, and then recurse to fuse nested loops.
llvm::SMTExprRef mkExpr(mlir::Value value, llvm::SMTSolver *solver)
constexpr int INDEX_WIDTH
std::unique_ptr< mlir::Pass > createFuseProductLoopsPass()