22#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
23#include <mlir/Analysis/DataFlow/SparseAnalysis.h>
24#include <mlir/Analysis/DataFlowFramework.h>
25#include <mlir/Dialect/SCF/IR/SCF.h>
26#include <mlir/IR/Attributes.h>
27#include <mlir/IR/Operation.h>
28#include <mlir/IR/Region.h>
29#include <mlir/IR/SymbolTable.h>
30#include <mlir/IR/Value.h>
31#include <mlir/IR/ValueRange.h>
32#include <mlir/Interfaces/CallInterfaces.h>
33#include <mlir/Interfaces/ControlFlowInterfaces.h>
34#include <mlir/Support/LLVM.h>
36#include <llvm/ADT/STLExtras.h>
37#include <llvm/Support/Casting.h>
43using namespace mlir::dataflow;
53 : DataFlowAnalysis(s) {
54 registerAnchorKind<CFGEdge>();
60 for (Region ®ion : top->getRegions()) {
64 for (Value argument : region.front().getArguments()) {
69 return initializeRecursively(top);
72LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
75 if (failed(visitOperation(op))) {
79 for (Region ®ion : op->getRegions()) {
80 for (Block &block : region) {
81 getOrCreate<Executable>(getProgramPointBefore(&block))->blockContentSubscribe(
this);
84 for (Operation &containedOp : block) {
85 if (failed(initializeRecursively(&containedOp))) {
96 if (!point->isBlockStart()) {
97 return visitOperation(point->getPrevOp());
99 visitBlock(point->getBlock());
103LogicalResult AbstractSparseForwardDataFlowAnalysis::visitOperation(Operation *op) {
109 if (op->getBlock() !=
nullptr &&
110 !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) {
115 SmallVector<AbstractSparseLattice *> resultLattices;
116 resultLattices.reserve(op->getNumResults());
117 for (Value result : op->getResults()) {
119 resultLattices.push_back(resultLattice);
123 if (
auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
124 visitRegionSuccessors(
125 getProgramPointAfter(branch), branch,
126 RegionBranchPoint::parent(), resultLattices
132 SmallVector<const AbstractSparseLattice *> operandLattices;
133 operandLattices.reserve(op->getNumOperands());
134 for (Value operand : op->getOperands()) {
136 operandLattice->useDefSubscribe(
this);
137 operandLattices.push_back(operandLattice);
140 if (
auto call = dyn_cast<CallOpInterface>(op)) {
143 auto callable = dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
144 if (!getSolverConfig().isInterprocedural() || (callable && !callable.getCallableRegion())) {
151 const auto *predecessors =
152 getOrCreateFor<PredecessorState>(getProgramPointAfter(op), getProgramPointAfter(call));
155 if (!predecessors->allPredecessorsKnown()) {
159 for (Operation *predecessor : predecessors->getKnownPredecessors()) {
160 for (
auto &&[operand, resLattice] : llvm::zip(predecessor->getOperands(), resultLattices)) {
171void AbstractSparseForwardDataFlowAnalysis::visitBlock(Block *block) {
173 if (block->getNumArguments() == 0) {
178 if (!getOrCreate<Executable>(getProgramPointBefore(block))->isLive()) {
183 SmallVector<AbstractSparseLattice *> argLattices;
184 argLattices.reserve(block->getNumArguments());
185 for (BlockArgument argument : block->getArguments()) {
187 argLattices.push_back(argLattice);
192 if (block->isEntryBlock()) {
194 auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
195 if (callable && callable.getCallableRegion() == block->getParent()) {
196 const auto *callsites = getOrCreateFor<PredecessorState>(
197 getProgramPointBefore(block), getProgramPointAfter(callable)
201 if (!callsites->allPredecessorsKnown() || !getSolverConfig().isInterprocedural()) {
204 for (Operation *callsite : callsites->getKnownPredecessors()) {
205 auto call = cast<CallOpInterface>(callsite);
206 for (
auto it : llvm::zip(call.getArgOperands(), argLattices)) {
216 if (
auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
217 return visitRegionSuccessors(
218 getProgramPointBefore(block), branch, block->getParent(), argLattices
224 block->getParentOp(), RegionSuccessor(block->getParent()), argLattices, 0
229 for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
230 Block *predecessor = *it;
234 auto *edgeExecutable = getOrCreate<Executable>(getLatticeAnchor<CFGEdge>(predecessor, block));
235 edgeExecutable->blockContentSubscribe(
this);
236 if (!edgeExecutable->isLive()) {
241 if (
auto branch = dyn_cast<BranchOpInterface>(predecessor->getTerminator())) {
242 SuccessorOperands operands = branch.getSuccessorOperands(it.getSuccessorIndex());
243 for (
auto [idx, lattice] : llvm::enumerate(argLattices)) {
244 if (Value operand = operands[idx]) {
258void AbstractSparseForwardDataFlowAnalysis::visitRegionSuccessors(
259 ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint successor,
260 ArrayRef<AbstractSparseLattice *> lattices
262 const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
263 assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
265 for (Operation *op : predecessors->getKnownPredecessors()) {
267 std::optional<OperandRange> operands;
271 operands = branch.getEntrySuccessorOperands(successor);
273 }
else if (
auto regionTerminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
274 operands = regionTerminator.getSuccessorOperands(successor);
282 ValueRange inputs = predecessors->getSuccessorInputs(op);
284 inputs.size() == operands->size() &&
285 "expected the same number of successor inputs as operands"
288 unsigned firstIndex = 0;
289 if (inputs.size() != lattices.size()) {
290 if (!point->isBlockStart()) {
291 if (!inputs.empty()) {
292 firstIndex = cast<OpResult>(inputs.front()).getResultNumber();
295 branch, RegionSuccessor(branch->getResults().slice(firstIndex, inputs.size())),
299 if (!inputs.empty()) {
300 firstIndex = cast<BlockArgument>(inputs.front()).getArgNumber();
302 Region *region = point->getBlock()->getParent();
305 RegionSuccessor(region, region->getArguments().slice(firstIndex, inputs.size())),
311 for (
auto it : llvm::zip(*operands, lattices.drop_front(firstIndex))) {
320 addDependency(state, point);
325 ArrayRef<AbstractSparseLattice *> lattices
335 propagateIfChanged(lhs, lhs->join(rhs));
This file implements sparse data-flow analysis using the data-flow analysis framework.
mlir::LogicalResult visit(mlir::ProgramPoint *point) override
Visit a program point.
mlir::LogicalResult initialize(mlir::Operation *top) override
Initialize the analysis by visiting every owner of an SSA value: all operations and blocks.
virtual mlir::LogicalResult visitOperationImpl(mlir::Operation *op, mlir::ArrayRef< const AbstractSparseLattice * > operandLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The operation transfer function.
virtual AbstractSparseLattice * getLatticeElement(mlir::Value value)=0
Get the lattice element of a value.
AbstractSparseForwardDataFlowAnalysis(mlir::DataFlowSolver &solver)
void setAllToEntryStates(mlir::ArrayRef< AbstractSparseLattice * > lattices)
const AbstractSparseLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
Get a read-only lattice element for a value and add it as a dependency to a program point.
void join(AbstractSparseLattice *lhs, const AbstractSparseLattice &rhs)
Join the lattice element and propagate an update if it changed.
virtual void setToEntryState(AbstractSparseLattice *lattice)=0
Set the given lattice element(s) at control flow entry point(s).
virtual void visitNonControlFlowArgumentsImpl(mlir::Operation *op, const mlir::RegionSuccessor &successor, mlir::ArrayRef< AbstractSparseLattice * > argLattices, unsigned firstIndex)=0
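Given an operation with region control-flow, a region successor, and the lattices of that successor's inputs (starting at `firstIndex`), compute the lattice values for the inputs that are not accounted for by the operands forwarded through the branching control flow.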
virtual void visitExternalCallImpl(mlir::CallOpInterface call, mlir::ArrayRef< const AbstractSparseLattice * > argumentLattices, mlir::ArrayRef< AbstractSparseLattice * > resultLattices)=0
The transfer function for calls to external functions.
mlir::dataflow::AbstractSparseLattice AbstractSparseLattice