23#include <mlir/Dialect/Arith/IR/Arith.h>
24#include <mlir/Dialect/SCF/IR/SCF.h>
25#include <mlir/IR/BuiltinOps.h>
26#include <mlir/IR/Dominance.h>
28#include <llvm/ADT/DenseMap.h>
29#include <llvm/ADT/PostOrderIterator.h>
30#include <llvm/ADT/SmallVector.h>
36#define GEN_PASS_DEF_REDUNDANTOPERATIONELIMINATIONPASS
47#define DEBUG_TYPE "llzk-duplicate-op-elim"
51static auto EMPTY_OP_KEY =
reinterpret_cast<Operation *
>(1);
52static auto TOMBSTONE_OP_KEY =
reinterpret_cast<Operation *
>(2);
60class OperationComparator {
62 explicit OperationComparator(Operation *o) : op(o) {
63 if (op != EMPTY_OP_KEY && op != TOMBSTONE_OP_KEY) {
64 operands = SmallVector<Value>(op->getOperands());
68 OperationComparator(Operation *o,
const TranslationMap &m) : op(o) {
69 for (
auto operand : op->getOperands()) {
70 if (
auto it = m.find(operand); it != m.end()) {
71 operands.push_back(it->second);
73 operands.push_back(operand);
78 Operation *getOp()
const {
return op; }
80 const SmallVector<Value> &getOperands()
const {
return operands; }
82 bool isCommutative()
const {
return op->hasTrait<OpTrait::IsCommutative>(); }
84 friend bool operator==(
const OperationComparator &lhs,
const OperationComparator &rhs) {
85 if (lhs.op == EMPTY_OP_KEY || rhs.op == EMPTY_OP_KEY || lhs.op == TOMBSTONE_OP_KEY ||
86 rhs.op == TOMBSTONE_OP_KEY) {
87 return lhs.op == rhs.op;
90 if (lhs.op->getName() != rhs.op->getName()) {
95 auto dialectName = lhs.op->getDialect()->getNamespace();
96 if (dialectName == scf::SCFDialect::getDialectNamespace()) {
105 if (lhs.op->getAttrs() != rhs.op->getAttrs()) {
109 if (lhs.isCommutative()) {
111 lhs.operands.size() == 2 && rhs.operands.size() == 2,
112 "No known commutative ops have more than two arguments"
114 return (lhs.operands[0] == rhs.operands[0] && lhs.operands[1] == rhs.operands[1]) ||
115 (lhs.operands[0] == rhs.operands[1] && lhs.operands[1] == rhs.operands[0]);
119 return lhs.operands == rhs.operands;
124 SmallVector<Value> operands;
131template <>
struct DenseMapInfo<OperationComparator> {
132 static OperationComparator
getEmptyKey() {
return OperationComparator(EMPTY_OP_KEY); }
134 return OperationComparator(TOMBSTONE_OP_KEY);
137 if (oc.getOp() == EMPTY_OP_KEY || oc.getOp() == TOMBSTONE_OP_KEY) {
138 return hash_value(oc.getOp());
141 return hash_value(oc.getOp()->getName());
143 static bool isEqual(
const OperationComparator &lhs,
const OperationComparator &rhs) {
153 using Base = RedundantOperationEliminationPassBase<PassImpl>;
156 void runOnOperation()
override {
157 SymbolTableCollection symbolTables;
161 auto &cga = getAnalysis<CallGraphAnalysis>();
162 const llzk::CallGraph *callGraph = &cga.getCallGraph();
163 for (
auto it = llvm::po_begin(callGraph); it != llvm::po_end(callGraph); ++it) {
164 const llzk::CallGraphNode *node = *it;
171 bool isPurposelessConstrainFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
177 fn.walk([&](Operation *op) {
178 if (isa<EmitEqualityOp, EmitContainmentOp, AssertOp>(op)) {
180 return WalkResult::interrupt();
181 }
else if (
auto callOp = dyn_cast<CallOp>(op);
182 callOp && !callsPurposelessConstrainFunc(symbolTables, callOp)) {
184 return WalkResult::interrupt();
186 return WalkResult::advance();
191 bool callsPurposelessConstrainFunc(SymbolTableCollection &symbolTables, CallOp call) {
193 return succeeded(callLookup) && isPurposelessConstrainFunc(symbolTables, callLookup->get());
196 void runOnFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
198 SmallVector<Operation *> redundantOps;
199 DenseSet<OperationComparator> uniqueOps;
200 DominanceInfo domInfo(fn);
202 auto unnecessaryOpCheck = [&](Operation *op) ->
bool {
203 if (
auto emiteq = dyn_cast<EmitEqualityOp>(op);
204 emiteq && emiteq.getLhs() == emiteq.getRhs()) {
205 redundantOps.push_back(op);
209 if (
auto callOp = dyn_cast<CallOp>(op);
210 callOp && callsPurposelessConstrainFunc(symbolTables, callOp)) {
211 redundantOps.push_back(op);
217 fn.walk([&](Operation *op) {
219 if (unnecessaryOpCheck(op)) {
220 return WalkResult::advance();
225 if (!isa<NonDetOp>(op)) {
226 OperationComparator comp(op, map);
227 if (
auto it = uniqueOps.find(comp);
228 it != uniqueOps.end() && domInfo.dominates(it->getOp(), op)) {
229 redundantOps.push_back(op);
230 for (
unsigned opNum = 0; opNum < op->getNumResults(); opNum++) {
231 map[op->getResult(opNum)] = it->getOp()->getResult(opNum);
234 uniqueOps.insert(comp);
238 return WalkResult::advance();
242 std::deque<Value> operands;
244 for (
auto *op : redundantOps) {
245 LLVM_DEBUG(llvm::dbgs() <<
"Removing op: " << *op <<
'\n');
246 for (
auto result : op->getResults()) {
247 if (!result.getUsers().empty()) {
248 auto it = map.find(result);
250 it != map.end(),
"failed to find a replacement value for redundant operation result"
252 LLVM_DEBUG(llvm::dbgs() <<
"Replacing " << it->first <<
" with " << it->second <<
'\n');
253 result.replaceAllUsesWith(it->second);
256 for (Value operand : op->getOperands()) {
257 operands.push_back(operand);
266 DenseSet<Value> checkedOperands;
267 while (!operands.empty()) {
268 Value operand = operands.front();
269 operands.pop_front();
270 checkedOperands.insert(operand);
274 if (
auto *op = operand.getDefiningOp(); op && operand.getUsers().empty()) {
275 for (
auto parentOperand : op->getOperands()) {
276 if (checkedOperands.find(parentOperand) == checkedOperands.end()) {
277 operands.push_back(parentOperand);
280 LLVM_DEBUG(llvm::dbgs() <<
"Removing unused operand: " << operand <<
'\n');
bool isExternal() const
Returns true if this node is an external node.
llzk::function::FuncDefOp getCalledFunction() const
Returns the called function that the callable region represents.
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
void ensure(bool condition, const llvm::Twine &errMsg)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
static OperationComparator getTombstoneKey()
static unsigned getHashValue(const OperationComparator &oc)
static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs)
static OperationComparator getEmptyKey()