22#include <mlir/IR/BuiltinOps.h>
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
53 explicit ReferenceID(Value v) {
55 if (v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(1) ||
56 v.getImpl() ==
reinterpret_cast<mlir::detail::ValueImpl *
>(2)) {
58 }
else if (
auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
59 identifier = constVal.getValue();
60 }
else if (
auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
61 identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
66 explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
67 explicit ReferenceID(
const APInt &i) : identifier(i) {}
68 explicit ReferenceID(
unsigned i) : identifier(APInt(64, i)) {}
70 bool isValue()
const {
return std::holds_alternative<Value>(identifier); }
71 bool isSymbol()
const {
return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
72 bool isConst()
const {
return std::holds_alternative<APInt>(identifier); }
74 Value getValue()
const {
75 ensure(isValue(),
"does not hold Value");
76 return std::get<Value>(identifier);
79 FlatSymbolRefAttr getSymbol()
const {
80 ensure(isSymbol(),
"does not hold symbol");
81 return std::get<FlatSymbolRefAttr>(identifier);
84 APInt getConst()
const {
85 ensure(isConst(),
"does not hold const");
86 return std::get<APInt>(identifier);
89 void print(raw_ostream &os)
const {
90 if (
const auto *v = std::get_if<Value>(&identifier)) {
91 if (
auto opres = dyn_cast<OpResult>(*v)) {
92 os <<
'%' << opres.getResultNumber();
96 }
else if (
const auto *s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
99 os << std::get<APInt>(identifier);
103 friend bool operator==(
const ReferenceID &lhs,
const ReferenceID &rhs) {
104 return lhs.identifier == rhs.identifier;
107 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceID &
id) {
117 std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
125template <>
struct DenseMapInfo<ReferenceID> {
127 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(1)));
130 return ReferenceID(mlir::Value(
reinterpret_cast<mlir::detail::ValueImpl *
>(2)));
134 return hash_value(r.getValue());
135 }
else if (r.isSymbol()) {
136 return hash_value(r.getSymbol());
138 return hash_value(r.getConst());
140 static bool isEqual(
const ReferenceID &lhs,
const ReferenceID &rhs) {
return lhs == rhs; }
167 template <
typename IdType>
static std::shared_ptr<ReferenceNode>
create(IdType
id, Value v) {
168 ReferenceNode n(
id, v);
170 return std::make_shared<ReferenceNode>(std::move(n));
175 std::shared_ptr<ReferenceNode> clone(
bool withChildren =
true)
const {
176 ReferenceNode copy(identifier, storedValue);
177 copy.updateLastWrite(lastWrite);
179 for (
const auto &[
id, child] : children) {
180 copy.children[id] = child->clone(withChildren);
183 return std::make_shared<ReferenceNode>(std::move(copy));
186 template <
typename IdType>
187 std::shared_ptr<ReferenceNode>
188 createChild(IdType
id, Value storedVal,
const std::shared_ptr<ReferenceNode> &valTree =
nullptr) {
189 std::shared_ptr<ReferenceNode> child =
create(
id, storedVal);
190 child->setCurrentValue(storedVal, valTree);
191 children[child->identifier] = child;
197 template <
typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType
id)
const {
198 auto it = children.find(ReferenceID(
id));
199 if (it != children.end()) {
208 template <
typename IdType>
209 std::shared_ptr<ReferenceNode> getOrCreateChild(IdType
id, Value storedVal =
nullptr) {
210 auto it = children.find(ReferenceID(
id));
211 if (it != children.end()) {
214 return createChild(
id, storedVal);
219 Operation *updateLastWrite(Operation *writeOp) {
220 Operation *old = lastWrite;
225 void setCurrentValue(Value v,
const std::shared_ptr<ReferenceNode> &valTree =
nullptr) {
227 if (valTree !=
nullptr) {
230 children = valTree->children;
234 void invalidateChildren() { children.clear(); }
236 bool isLeaf()
const {
return children.empty(); }
238 Value getStoredValue()
const {
return storedValue; }
240 bool hasStoredValue()
const {
return storedValue !=
nullptr; }
242 void print(raw_ostream &os,
int indent = 0)
const {
243 os.indent(indent) <<
'[' << identifier;
244 if (storedValue !=
nullptr) {
245 os <<
" => " << storedValue;
248 if (!children.empty()) {
250 for (
const auto &[_, child] : children) {
251 child->print(os, indent + 4);
254 os.indent(indent) <<
'}';
259 friend raw_ostream &
operator<<(raw_ostream &os,
const ReferenceNode &r) {
266 topLevelEq(
const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs) {
267 return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
268 lhs->lastWrite == rhs->lastWrite;
271 friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
272 const std::shared_ptr<ReferenceNode> &lhs,
const std::shared_ptr<ReferenceNode> &rhs
274 if (!topLevelEq(lhs, rhs)) {
277 auto res = lhs->clone(
false);
279 for (
auto &[
id, lhsChild] : lhs->children) {
280 if (
auto it = rhs->children.find(
id); it != rhs->children.end()) {
281 auto &rhsChild = it->second;
282 if (
auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
283 res->children[id] = gcs;
291 ReferenceID identifier;
292 mlir::Value storedValue;
293 Operation *lastWrite;
294 DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;
296 template <
typename IdType>
297 ReferenceNode(IdType
id, Value initialVal)
298 : identifier(std::move(id)), storedValue(initialVal), lastWrite(nullptr), children() {}
301using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
303ValueMap intersect(
const ValueMap &lhs,
const ValueMap &rhs) {
305 for (
const auto &[
id, lhsValTree] : lhs) {
306 if (
auto it = rhs.find(
id); it != rhs.end()) {
307 const auto &rhsValTree = it->second;
308 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
316ValueMap cloneValueMap(
const ValueMap &orig) {
318 for (
const auto &[
id, tree] : orig) {
319 res[id] = tree->clone();
325 using Base = RedundantReadAndWriteEliminationPassBase<PassImpl>;
333 void runOnOperation()
override {
334 getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
339 void runOnFunc(FuncDefOp fn) {
345 LLVM_DEBUG(llvm::dbgs() <<
"Running on " << fn.getName() <<
'\n');
348 DenseMap<Value, Value> replacementMap;
350 SmallVector<Value> readVals;
353 SmallVector<Operation *> redundantWrites;
357 for (
auto arg : fn.getArguments()) {
358 initState[arg] = ReferenceNode::create(arg, arg);
362 *fn.
getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
367 for (
auto &[orig, replace] : replacementMap) {
368 LLVM_DEBUG(llvm::dbgs() <<
"replacing " << orig <<
" with " << orig <<
'\n');
369 orig.replaceAllUsesWith(replace);
373 for (
auto *writeOp : redundantWrites) {
374 LLVM_DEBUG(llvm::dbgs() <<
"erase write: " << *writeOp <<
'\n');
380 for (
auto it = readVals.rbegin(); it != readVals.rend(); it++) {
382 if (readVal.use_empty()) {
383 LLVM_DEBUG(llvm::dbgs() <<
"erase read: " << readVal <<
'\n');
384 readVal.getDefiningOp()->erase();
389 ValueMap runOnRegion(
390 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
391 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
394 DenseMap<Block *, ValueMap> endStates;
396 endStates[
nullptr] = initState;
397 auto getBlockState = [&endStates](Block *blockPtr) {
398 auto it = endStates.find(blockPtr);
399 ensure(it != endStates.end(),
"unknown end state means we have an unsupported backedge");
400 return cloneValueMap(it->second);
402 std::deque<Block *> frontier;
403 frontier.push_back(&r.front());
404 DenseSet<Block *> visited;
406 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
408 while (!frontier.empty()) {
409 Block *currentBlock = frontier.front();
410 frontier.pop_front();
411 visited.insert(currentBlock);
414 ValueMap currentState;
415 auto it = currentBlock->pred_begin();
416 auto itEnd = currentBlock->pred_end();
419 currentState = getBlockState(
nullptr);
421 currentState = getBlockState(*it);
425 for (it++; it != itEnd; it++) {
426 currentState = intersect(currentState, getBlockState(*it));
431 auto endState = runOnBlock(
432 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
438 ensure(endStates.find(currentBlock) == endStates.end(),
"backedge");
439 endStates[currentBlock] = std::move(endState);
442 if (currentBlock->hasNoSuccessors()) {
443 terminalStates.push_back(endStates[currentBlock]);
445 for (Block *succ : currentBlock->getSuccessors()) {
446 if (visited.find(succ) == visited.end()) {
447 frontier.push_back(succ);
454 ensure(!terminalStates.empty(),
"computed no states");
455 auto finalState = terminalStates.front().get();
456 for (
const auto *it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
457 finalState = intersect(finalState, it->get());
463 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
464 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
466 for (Operation &op : b) {
467 runOperation(&op, state, replacementMap, readVals, redundantWrites);
471 if (!op.getRegions().empty()) {
472 SmallVector<ValueMap> regionStates;
473 for (Region ®ion : op.getRegions()) {
475 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
476 regionStates.push_back(regionState);
479 ValueMap finalState = regionStates.front();
480 for (
const auto *it = regionStates.begin() + 1; it != regionStates.end(); it++) {
481 finalState = intersect(finalState, *it);
483 state = std::move(finalState);
486 return std::move(state);
497 Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
498 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
503 auto translate = [&replacementMap](Value v) {
504 if (
auto it = replacementMap.find(v); it != replacementMap.end()) {
511 auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
512 if (
auto it = state.find(v); it != state.end()) {
520 auto doArrayReadLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass readarr) {
521 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));
523 for (Value origIdx : readarr.getIndices()) {
524 Value idxVal = translate(origIdx);
525 currValTree = currValTree->getOrCreateChild(idxVal);
528 Value resVal = readarr.getResult();
529 if (!currValTree->hasStoredValue()) {
530 currValTree->setCurrentValue(resVal);
533 if (currValTree->getStoredValue() != resVal) {
535 llvm::dbgs() << readarr.getOperationName() <<
": replace " << resVal <<
" with "
536 << currValTree->getStoredValue() <<
'\n'
538 replacementMap[resVal] = currValTree->getStoredValue();
540 state[resVal] = currValTree;
542 llvm::dbgs() << readarr.getOperationName() <<
": " << resVal <<
" => " << *currValTree
547 readVals.push_back(resVal);
556 auto doArrayWriteLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass writearr) {
557 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
558 Value newVal = translate(writearr.getRvalue());
559 std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);
561 for (Value origIdx : writearr.getIndices()) {
562 Value idxVal = translate(origIdx);
565 if (ReferenceID(idxVal).isValue()) {
566 LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() <<
": invalidate alias\n");
567 currValTree->invalidateChildren();
569 currValTree = currValTree->getOrCreateChild(idxVal);
572 if (currValTree->getStoredValue() == newVal) {
574 llvm::dbgs() << writearr.getOperationName() <<
": subsequent " << writearr
577 redundantWrites.push_back(writearr);
579 if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
581 llvm::dbgs() << writearr.getOperationName() <<
"writearr: replacing " << lastWrite
582 <<
" with prior write " << *lastWrite <<
'\n'
584 redundantWrites.push_back(lastWrite);
586 currValTree->setCurrentValue(newVal, valTree);
591 if (
auto newStruct = dyn_cast<CreateStructOp>(op)) {
593 auto structVal = ReferenceNode::create(newStruct, newStruct);
594 state[newStruct] = structVal;
595 LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() <<
": " << *state[newStruct] <<
'\n');
597 readVals.push_back(newStruct);
598 }
else if (
auto readm = dyn_cast<MemberReadOp>(op)) {
599 auto structVal = state.at(translate(readm.getComponent()));
600 FlatSymbolRefAttr symbol = readm.getMemberNameAttr();
601 Value resVal = translate(readm.getVal());
603 if (
auto child = structVal->getChild(symbol)) {
605 llvm::dbgs() << readm.getOperationName() <<
": adding replacement map entry { "
606 << resVal <<
" => " << child->getStoredValue() <<
" }\n"
608 replacementMap[resVal] = child->getStoredValue();
612 state[readm] = structVal->createChild(symbol, resVal);
613 LLVM_DEBUG(llvm::dbgs() << readm.getOperationName() <<
": " << *state[readm] <<
'\n');
616 readVals.push_back(readm.getVal());
617 }
else if (
auto writem = dyn_cast<MemberWriteOp>(op)) {
618 auto structVal = state.at(translate(writem.getComponent()));
619 Value writeVal = translate(writem.getVal());
620 FlatSymbolRefAttr symbol = writem.getMemberNameAttr();
621 auto valTree = tryGetValTree(writeVal);
623 auto child = structVal->getOrCreateChild(symbol);
624 if (child->getStoredValue() == writeVal) {
626 llvm::dbgs() << writem.getOperationName() <<
": recording redundant write " << writem
629 redundantWrites.push_back(writem);
631 if (
auto *lastWrite = child->updateLastWrite(writem)) {
633 llvm::dbgs() << writem.getOperationName() <<
": recording overwritten write "
634 << *lastWrite <<
'\n'
636 redundantWrites.push_back(lastWrite);
638 child->setCurrentValue(writeVal, valTree);
640 llvm::dbgs() << writem.getOperationName() <<
": " << *child <<
" set to " << writeVal
646 else if (
auto newArray = dyn_cast<CreateArrayOp>(op)) {
647 auto arrayVal = ReferenceNode::create(newArray, newArray);
648 state[newArray] = arrayVal;
653 for (
auto elem : newArray.getElements()) {
654 Value elemVal = translate(elem);
655 auto valTree = tryGetValTree(elemVal);
656 auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
658 llvm::dbgs() << newArray.getOperationName() <<
": element " << idx <<
" initialized to "
659 << *elemChild <<
'\n'
664 readVals.push_back(newArray);
665 }
else if (
auto readarr = dyn_cast<ReadArrayOp>(op)) {
666 doArrayReadLike(readarr);
667 }
else if (
auto writearr = dyn_cast<WriteArrayOp>(op)) {
668 doArrayWriteLike(writearr);
669 }
else if (
auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
671 doArrayReadLike(extractarr);
672 }
else if (
auto insertarr = dyn_cast<InsertArrayOp>(op)) {
674 doArrayWriteLike(insertarr);
void print(llvm::raw_ostream &os) const
::mlir::Region * getCallableRegion()
Required by FunctionOpInterface.
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)
static ReferenceID getEmptyKey()
static ReferenceID getTombstoneKey()
static unsigned getHashValue(const ReferenceID &r)