22#include <mlir/Dialect/SCF/Utils/Utils.h>
24#include <llvm/Support/Debug.h>
25#include <llvm/Support/SMTAPI.h>
31#define GEN_PASS_DEF_FUSEPRODUCTLOOPSPASS
// Bit width passed to the SMT solver for the bitvector constants and symbols built in this
// file (see the mkBitvector / getBitvectorSort calls in mkExpr and tripCount).
41constexpr int INDEX_WIDTH = 64;
// Predicate that canLoopsBeFused applies to every loop bound and step before it builds the
// SMT query. A value defined by an arith::ConstantIndexOp qualifies; the remainder of the
// disjunction is cut off in this excerpt.
// NOTE(review): going by the name, the missing operand presumably accepts a struct-parameter
// read -- confirm against the full file.
43static inline bool isConstOrStructParam(Value val) {
45 return val.getDefiningOp<arith::ConstantIndexOp>() ||
// Translates an SSA `value` into an SMT expression created through `solver`.
//   - Defined by arith::ConstantIndexOp: an INDEX_WIDTH-bit bitvector holding the constant.
//   - `polyReadOp` case (its guarding condition is not visible in this excerpt): a symbol of
//     INDEX_WIDTH-bit bitvector sort named after polyReadOp.getConstName().
//   - Anything else reaches the assertion at the end.
49static llvm::SMTExprRef mkExpr(Value value, llvm::SMTSolver *solver) {
// Constant case: encode the literal directly.
50 if (
auto constOp = value.getDefiningOp<arith::ConstantIndexOp>()) {
51 return solver->mkBitvector(llvm::APSInt::get(constOp.value()), INDEX_WIDTH);
// The temporary std::string lives until the end of the full mkSymbol(...) expression, so the
// c_str() pointer stays valid for the duration of the call.
54 return solver->mkSymbol(
55 std::string {polyReadOp.getConstName()}.c_str(), solver->getBitvectorSort(INDEX_WIDTH)
// Fallthrough for values handled by neither branch above.
// NOTE(review): assert is compiled out under NDEBUG; what the function returns after this
// point is not visible in this excerpt -- confirm a release build cannot fall off the end.
58 assert(
false &&
"unsupported: checking non-constant trip counts");
// Builds, through `solver`, the SMT expression that canLoopsBeFused compares between two loops.
// The visible pieces are a signed bitvector division (mkBVSDiv) with the loop step as one
// operand and an expression containing (upperBound - lowerBound) as the other.
// NOTE(review): the lines between the mkBVSDiv call and its visible operands are missing from
// this excerpt, so how the constant `one` enters the dividend cannot be confirmed here.
62static llvm::SMTExprRef tripCount(scf::ForOp op, llvm::SMTSolver *solver) {
// Bitvector literal 1 of INDEX_WIDTH bits; its use is on a line not shown.
63 const auto *one = solver->mkBitvector(llvm::APSInt::get(1), INDEX_WIDTH);
64 return solver->mkBVSDiv(
67 solver->mkBVSub(mkExpr(op.getUpperBound(), solver), mkExpr(op.getLowerBound(), solver))
69 mkExpr(op.getStep(), solver)
// Decides whether the two scf.for loops `a` and `b` are fusion candidates. The visible checks,
// in order, are:
//   1. compare the parent regions of the two loops;
//   2. compare the constant trip counts when constantTripCount yields a value for both;
//   3. require all six bounds/steps to satisfy isConstOrStructParam;
//   4. ask a Z3-backed solver whether the symbolic trip counts can differ.
// The bodies of checks 1-3 (what each returns) are missing from this excerpt.
73static inline bool canLoopsBeFused(scf::ForOp a, scf::ForOp b) {
80 if (a->getParentRegion() != b->getParentRegion()) {
// Fast path: both trip counts are compile-time constants and equal.
101 auto tripCountA = constantTripCount(a.getLowerBound(), a.getUpperBound(), a.getStep());
102 auto tripCountB = constantTripCount(b.getLowerBound(), b.getUpperBound(), b.getStep());
103 if (tripCountA.has_value() && tripCountB.has_value() && *tripCountA == *tripCountB) {
// Guard for the SMT encoding: mkExpr only handles values accepted by isConstOrStructParam.
107 if (!isConstOrStructParam(a.getLowerBound()) || !isConstOrStructParam(a.getUpperBound()) ||
108 !isConstOrStructParam(a.getStep()) || !isConstOrStructParam(b.getLowerBound()) ||
109 !isConstOrStructParam(b.getUpperBound()) || !isConstOrStructParam(b.getStep())) {
// Search for a counterexample by asserting tripCount(a) != tripCount(b).
113 llvm::SMTSolverRef solver = llvm::CreateZ3Solver();
114 solver->addConstraint( solver->mkNot(
115 solver->mkEqual(tripCount(a, solver.get()), tripCount(b, solver.get()))
// A false check() result means the inequality is unsatisfiable, so this returns true exactly
// when the solver reports that the trip counts cannot differ.
// NOTE(review): the result of check() is dereferenced without first testing that the solver
// produced an answer -- confirm it cannot come back empty (e.g. unknown/timeout).
118 return !*solver->check();
// Fuses matching pairs of scf.for loops found in `body`.
// A pre-order walk sorts loops into `witnessLoops` and `constraintLoops`; the branch conditions
// are not visible in this excerpt, but they sit next to a read of the PRODUCT_SOURCE string
// attribute. Candidate pairs are then selected with canLoopsBeFused as the matching predicate,
// each pair is fused, the fused loop is tagged PRODUCT_SOURCE = "fused", and the function
// recurses into the fused loop's body.
// Returns a LogicalResult; the bodies of the failure branches are outside this excerpt.
121static LogicalResult fuseMatchingLoopPairs(Region &body, MLIRContext *context) {
123 llvm::SmallVector<scf::ForOp> witnessLoops, constraintLoops;
124 body.walk<WalkOrder::PreOrder>([&witnessLoops, &constraintLoops](scf::ForOp forOp) {
// skip(): in a pre-order walk the nested regions of this op are not visited.
126 return WalkResult::skip();
128 auto productSource = forOp->getAttrOfType<StringAttr>(
PRODUCT_SOURCE);
130 witnessLoops.push_back(forOp);
132 constraintLoops.push_back(forOp);
// Same as above: do not descend into the loop that was just examined.
135 return WalkResult::skip();
141 witnessLoops, constraintLoops, canLoopsBeFused
145 if (failed(fusionCandidates)) {
150 IRRewriter rewriter {context};
151 for (
auto [w, c] : *fusionCandidates) {
// NOTE(review): fuseIndependentSiblingForLoops presumably comes from the SCF utilities header
// included at the top of the file -- confirm.
152 auto fusedLoop = fuseIndependentSiblingForLoops(w, c, rewriter);
// Mark the result with the literal tag "fused" under the PRODUCT_SOURCE attribute.
153 fusedLoop->setAttr(
PRODUCT_SOURCE, rewriter.getAttr<StringAttr>(
"fused"));
// Recurse so that loop pairs nested inside the newly fused body are considered as well.
155 if (failed(fuseMatchingLoopPairs(fusedLoop.getBodyRegion(), context))) {
163 using Base = FuseProductLoopsPassBase<PassImpl>;
// Pass entry point: walks the module and, for each function::FuncDefOp the walk reaches, runs
// fuseMatchingLoopPairs on the function body using this pass's MLIRContext.
// What happens when fusion reports failure is not visible in this excerpt.
166 void runOnOperation()
override {
167 ModuleOp
mod = getOperation();
168 mod.walk([
this](function::FuncDefOp funcDef) {
170 if (failed(fuseMatchingLoopPairs(funcDef.getFunctionBody(), &getContext()))) {
bool isStructProduct()
Return true iff the function is within a StructDefOp and named FUNC_NAME_PRODUCT.
llvm::FailureOr< llvm::SetVector< std::pair< ValueT, ValueT > > > getMatchingPairs(llvm::ArrayRef< ValueT > as, llvm::ArrayRef< ValueT > bs, FnT doesMatch, bool allowPartial=true)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
constexpr char FUNC_NAME_CONSTRAIN[]