LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKRedundantReadAndWriteEliminationPass.cpp
Go to the documentation of this file.
1//===-- LLZKRedundantReadAndWriteEliminationPass.cpp ------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
19#include "llzk/Util/Concepts.h"
21
22#include <mlir/IR/BuiltinOps.h>
23
24#include <llvm/ADT/DenseMap.h>
25#include <llvm/ADT/DenseMapInfo.h>
26#include <llvm/ADT/SmallVector.h>
27#include <llvm/Support/Debug.h>
28
29#include <deque>
30#include <memory>
31
32// Include the generated base pass class definitions.
33namespace llzk {
34#define GEN_PASS_DEF_REDUNDANTREADANDWRITEELIMINATIONPASS
36} // namespace llzk
37
38using namespace mlir;
39using namespace llzk;
40using namespace llzk::array;
41using namespace llzk::felt;
42using namespace llzk::function;
43using namespace llzk::component;
44
45#define DEBUG_TYPE "llzk-redundant-read-write-pass"
46
47namespace {
48
51class ReferenceID {
52public:
53 explicit ReferenceID(Value v) {
54 // reserved special pointer values for DenseMapInfo
55 if (v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(1) ||
56 v.getImpl() == reinterpret_cast<mlir::detail::ValueImpl *>(2)) {
57 identifier = v;
58 } else if (auto constVal = dyn_cast_if_present<FeltConstantOp>(v.getDefiningOp())) {
59 identifier = constVal.getValue();
60 } else if (auto constIdxVal = dyn_cast_if_present<arith::ConstantIndexOp>(v.getDefiningOp())) {
61 identifier = llvm::cast<IntegerAttr>(constIdxVal.getValue()).getValue();
62 } else {
63 identifier = v;
64 }
65 }
66 explicit ReferenceID(FlatSymbolRefAttr s) : identifier(s) {}
67 explicit ReferenceID(const APInt &i) : identifier(i) {}
68 explicit ReferenceID(unsigned i) : identifier(APInt(64, i)) {}
69
70 bool isValue() const { return std::holds_alternative<Value>(identifier); }
71 bool isSymbol() const { return std::holds_alternative<FlatSymbolRefAttr>(identifier); }
72 bool isConst() const { return std::holds_alternative<APInt>(identifier); }
73
74 Value getValue() const {
75 ensure(isValue(), "does not hold Value");
76 return std::get<Value>(identifier);
77 }
78
79 FlatSymbolRefAttr getSymbol() const {
80 ensure(isSymbol(), "does not hold symbol");
81 return std::get<FlatSymbolRefAttr>(identifier);
82 }
83
84 APInt getConst() const {
85 ensure(isConst(), "does not hold const");
86 return std::get<APInt>(identifier);
87 }
88
89 void print(raw_ostream &os) const {
90 if (const auto *v = std::get_if<Value>(&identifier)) {
91 if (auto opres = dyn_cast<OpResult>(*v)) {
92 os << '%' << opres.getResultNumber();
93 } else {
94 os << *v;
95 }
96 } else if (const auto *s = std::get_if<FlatSymbolRefAttr>(&identifier)) {
97 os << *s;
98 } else {
99 os << std::get<APInt>(identifier);
100 }
101 }
102
103 friend bool operator==(const ReferenceID &lhs, const ReferenceID &rhs) {
104 return lhs.identifier == rhs.identifier;
105 }
106
107 friend raw_ostream &operator<<(raw_ostream &os, const ReferenceID &id) {
108 id.print(os);
109 return os;
110 }
111
112private:
117 std::variant<FlatSymbolRefAttr, APInt, Value> identifier;
118};
119
120} // namespace
121
122namespace llvm {
123
125template <> struct DenseMapInfo<ReferenceID> {
126 static ReferenceID getEmptyKey() {
127 return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(1)));
128 }
129 static inline ReferenceID getTombstoneKey() {
130 return ReferenceID(mlir::Value(reinterpret_cast<mlir::detail::ValueImpl *>(2)));
131 }
132 static unsigned getHashValue(const ReferenceID &r) {
133 if (r.isValue()) {
134 return hash_value(r.getValue());
135 } else if (r.isSymbol()) {
136 return hash_value(r.getSymbol());
137 }
138 return hash_value(r.getConst());
139 }
140 static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs) { return lhs == rhs; }
141};
142
143} // namespace llvm
144
145namespace {
146
165class ReferenceNode {
166public:
167 template <typename IdType> static std::shared_ptr<ReferenceNode> create(IdType id, Value v) {
168 ReferenceNode n(id, v);
169 // Need the move constructor version since constructor is private
170 return std::make_shared<ReferenceNode>(std::move(n));
171 }
172
175 std::shared_ptr<ReferenceNode> clone(bool withChildren = true) const {
176 ReferenceNode copy(identifier, storedValue);
177 copy.updateLastWrite(lastWrite);
178 if (withChildren) {
179 for (const auto &[id, child] : children) {
180 copy.children[id] = child->clone(withChildren);
181 }
182 }
183 return std::make_shared<ReferenceNode>(std::move(copy));
184 }
185
186 template <typename IdType>
187 std::shared_ptr<ReferenceNode>
188 createChild(IdType id, Value storedVal, const std::shared_ptr<ReferenceNode> &valTree = nullptr) {
189 std::shared_ptr<ReferenceNode> child = create(id, storedVal);
190 child->setCurrentValue(storedVal, valTree);
191 children[child->identifier] = child;
192 return child;
193 }
194
197 template <typename IdType> std::shared_ptr<ReferenceNode> getChild(IdType id) const {
198 auto it = children.find(ReferenceID(id));
199 if (it != children.end()) {
200 return it->second;
201 }
202 return nullptr;
203 }
204
208 template <typename IdType>
209 std::shared_ptr<ReferenceNode> getOrCreateChild(IdType id, Value storedVal = nullptr) {
210 auto it = children.find(ReferenceID(id));
211 if (it != children.end()) {
212 return it->second;
213 }
214 return createChild(id, storedVal);
215 }
216
219 Operation *updateLastWrite(Operation *writeOp) {
220 Operation *old = lastWrite;
221 lastWrite = writeOp;
222 return old;
223 }
224
225 void setCurrentValue(Value v, const std::shared_ptr<ReferenceNode> &valTree = nullptr) {
226 storedValue = v;
227 if (valTree != nullptr) {
228 // Overwrite our current set of children with new children, since we overwrote
229 // the stored value.
230 children = valTree->children;
231 }
232 }
233
234 void invalidateChildren() { children.clear(); }
235
236 bool isLeaf() const { return children.empty(); }
237
238 Value getStoredValue() const { return storedValue; }
239
240 bool hasStoredValue() const { return storedValue != nullptr; }
241
242 void print(raw_ostream &os, int indent = 0) const {
243 os.indent(indent) << '[' << identifier;
244 if (storedValue != nullptr) {
245 os << " => " << storedValue;
246 }
247 os << ']';
248 if (!children.empty()) {
249 os << "{\n";
250 for (const auto &[_, child] : children) {
251 child->print(os, indent + 4);
252 os << '\n';
253 }
254 os.indent(indent) << '}';
255 }
256 }
257
258 [[maybe_unused]]
259 friend raw_ostream &operator<<(raw_ostream &os, const ReferenceNode &r) {
260 r.print(os);
261 return os;
262 }
263
265 friend bool
266 topLevelEq(const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs) {
267 return lhs->identifier == rhs->identifier && lhs->storedValue == rhs->storedValue &&
268 lhs->lastWrite == rhs->lastWrite;
269 }
270
271 friend std::shared_ptr<ReferenceNode> greatestCommonSubtree(
272 const std::shared_ptr<ReferenceNode> &lhs, const std::shared_ptr<ReferenceNode> &rhs
273 ) {
274 if (!topLevelEq(lhs, rhs)) {
275 return nullptr;
276 }
277 auto res = lhs->clone(false); // childless clone
278 // Find common children and recurse
279 for (auto &[id, lhsChild] : lhs->children) {
280 if (auto it = rhs->children.find(id); it != rhs->children.end()) {
281 auto &rhsChild = it->second;
282 if (auto gcs = greatestCommonSubtree(lhsChild, rhsChild)) {
283 res->children[id] = gcs;
284 }
285 }
286 }
287 return res;
288 }
289
290private:
291 ReferenceID identifier;
292 mlir::Value storedValue;
293 Operation *lastWrite;
294 DenseMap<ReferenceID, std::shared_ptr<ReferenceNode>> children;
295
296 template <typename IdType>
297 ReferenceNode(IdType id, Value initialVal)
298 : identifier(std::move(id)), storedValue(initialVal), lastWrite(nullptr), children() {}
299};
300
301using ValueMap = DenseMap<mlir::Value, std::shared_ptr<ReferenceNode>>;
302
303ValueMap intersect(const ValueMap &lhs, const ValueMap &rhs) {
304 ValueMap res;
305 for (const auto &[id, lhsValTree] : lhs) {
306 if (auto it = rhs.find(id); it != rhs.end()) {
307 const auto &rhsValTree = it->second;
308 res[id] = greatestCommonSubtree(lhsValTree, rhsValTree);
309 }
310 }
311 return res;
312}
313
316ValueMap cloneValueMap(const ValueMap &orig) {
317 ValueMap res;
318 for (const auto &[id, tree] : orig) {
319 res[id] = tree->clone();
320 }
321 return res;
322}
323
324class PassImpl : public llzk::impl::RedundantReadAndWriteEliminationPassBase<PassImpl> {
325 using Base = RedundantReadAndWriteEliminationPassBase<PassImpl>;
326 using Base::Base;
327
333 void runOnOperation() override {
334 getOperation().walk([&](FuncDefOp fn) { runOnFunc(fn); });
335 }
336
339 void runOnFunc(FuncDefOp fn) {
340 // Nothing to do for body-less functions.
341 if (fn.getCallableRegion() == nullptr) {
342 return;
343 }
344
345 LLVM_DEBUG(llvm::dbgs() << "Running on " << fn.getName() << '\n');
346
347 // Maps redundant value -> necessary value.
348 DenseMap<Value, Value> replacementMap;
349 // All values created by a new_* operation or from a read*/extract* operation.
350 SmallVector<Value> readVals;
351 // All writes that are either (1) overwritten by subsequent writes or (2)
352 // write a value that is already written.
353 SmallVector<Operation *> redundantWrites;
354
355 ValueMap initState;
356 // Initialize the state to the function arguments.
357 for (auto arg : fn.getArguments()) {
358 initState[arg] = ReferenceNode::create(arg, arg);
359 }
360 // Functions only have a single region
361 (void)runOnRegion(
362 *fn.getCallableRegion(), std::move(initState), replacementMap, readVals, redundantWrites
363 );
364
365 // Now that we have accumulated all necessary state, we perform the optimizations:
366 // - Replace all redundant values.
367 for (auto &[orig, replace] : replacementMap) {
368 LLVM_DEBUG(llvm::dbgs() << "replacing " << orig << " with " << orig << '\n');
369 orig.replaceAllUsesWith(replace);
370 // We save the deletion to the readVals loop to prevent double-free.
371 }
372 // -Remove redundant writes now that it is safe to do so.
373 for (auto *writeOp : redundantWrites) {
374 LLVM_DEBUG(llvm::dbgs() << "erase write: " << *writeOp << '\n');
375 writeOp->erase();
376 }
377 // - Now we do a pass over read values to see if any are now unused.
378 // We do this in reverse order to free up early reads if their users would
379 // be removed.
380 for (auto it = readVals.rbegin(); it != readVals.rend(); it++) {
381 Value readVal = *it;
382 if (readVal.use_empty()) {
383 LLVM_DEBUG(llvm::dbgs() << "erase read: " << readVal << '\n');
384 readVal.getDefiningOp()->erase();
385 }
386 }
387 }
388
389 ValueMap runOnRegion(
390 Region &r, ValueMap &&initState, DenseMap<Value, Value> &replacementMap,
391 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
392 ) {
393 // maps block -> state at the end of the block
394 DenseMap<Block *, ValueMap> endStates;
395 // The first block has no predecessors, so nullptr contains the init state
396 endStates[nullptr] = initState;
397 auto getBlockState = [&endStates](Block *blockPtr) {
398 auto it = endStates.find(blockPtr);
399 ensure(it != endStates.end(), "unknown end state means we have an unsupported backedge");
400 return cloneValueMap(it->second);
401 };
402 std::deque<Block *> frontier;
403 frontier.push_back(&r.front());
404 DenseSet<Block *> visited;
405
406 SmallVector<std::reference_wrapper<const ValueMap>> terminalStates;
407
408 while (!frontier.empty()) {
409 Block *currentBlock = frontier.front();
410 frontier.pop_front();
411 visited.insert(currentBlock);
412
413 // get predecessors
414 ValueMap currentState;
415 auto it = currentBlock->pred_begin();
416 auto itEnd = currentBlock->pred_end();
417 if (it == itEnd) {
418 // get the state for the entry block.
419 currentState = getBlockState(nullptr);
420 } else {
421 currentState = getBlockState(*it);
422 // If we have multiple predecessors, we take a pessimistic view and
423 // set the state as only the intersection of all predecessor states
424 // (e.g., only the common state from an if branch).
425 for (it++; it != itEnd; it++) {
426 currentState = intersect(currentState, getBlockState(*it));
427 }
428 }
429
430 // Run this block, consuming currentState and producing the endState
431 auto endState = runOnBlock(
432 *currentBlock, std::move(currentState), replacementMap, readVals, redundantWrites
433 );
434
435 // Update the end states.
436 // Since we only support the scf dialect, we should never have any
437 // backedges, so we should never already have state for this block.
438 ensure(endStates.find(currentBlock) == endStates.end(), "backedge");
439 endStates[currentBlock] = std::move(endState);
440
441 // add successors to frontier
442 if (currentBlock->hasNoSuccessors()) {
443 terminalStates.push_back(endStates[currentBlock]);
444 } else {
445 for (Block *succ : currentBlock->getSuccessors()) {
446 if (visited.find(succ) == visited.end()) {
447 frontier.push_back(succ);
448 }
449 }
450 }
451 }
452
453 // The final state is the intersection of all possible terminal states.
454 ensure(!terminalStates.empty(), "computed no states");
455 auto finalState = terminalStates.front().get();
456 for (const auto *it = terminalStates.begin() + 1; it != terminalStates.end(); it++) {
457 finalState = intersect(finalState, it->get());
458 }
459 return finalState;
460 }
461
462 ValueMap runOnBlock(
463 Block &b, ValueMap &&state, DenseMap<Value, Value> &replacementMap,
464 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
465 ) {
466 for (Operation &op : b) {
467 runOperation(&op, state, replacementMap, readVals, redundantWrites);
468 // Some operations have regions (e.g., scf.if). These regions must be
469 // traversed and the resulting state(s) are intersected for the final
470 // state of this operation.
471 if (!op.getRegions().empty()) {
472 SmallVector<ValueMap> regionStates;
473 for (Region &region : op.getRegions()) {
474 auto regionState =
475 runOnRegion(region, cloneValueMap(state), replacementMap, readVals, redundantWrites);
476 regionStates.push_back(regionState);
477 }
478
479 ValueMap finalState = regionStates.front();
480 for (const auto *it = regionStates.begin() + 1; it != regionStates.end(); it++) {
481 finalState = intersect(finalState, *it);
482 }
483 state = std::move(finalState);
484 }
485 }
486 return std::move(state);
487 }
488
496 void runOperation(
497 Operation *op, ValueMap &state, DenseMap<Value, Value> &replacementMap,
498 SmallVector<Value> &readVals, SmallVector<Operation *> &redundantWrites
499 ) {
500 // Uses the replacement map to look up values to simplify later replacement.
501 // This avoids having a daisy chain of "replace B with A", "replace C with B",
502 // etc.
503 auto translate = [&replacementMap](Value v) {
504 if (auto it = replacementMap.find(v); it != replacementMap.end()) {
505 return it->second;
506 }
507 return v;
508 };
509
510 // Lookup the value tree in the current state or return nullptr.
511 auto tryGetValTree = [&state](Value v) -> std::shared_ptr<ReferenceNode> {
512 if (auto it = state.find(v); it != state.end()) {
513 return it->second;
514 }
515 return nullptr;
516 };
517
518 // Read a value from an array. This works on both readarr operations (which
519 // return a scalar value) and extractarr operations (which return a subarray).
520 auto doArrayReadLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass readarr) {
521 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(readarr.getArrRef()));
522
523 for (Value origIdx : readarr.getIndices()) {
524 Value idxVal = translate(origIdx);
525 currValTree = currValTree->getOrCreateChild(idxVal);
526 }
527
528 Value resVal = readarr.getResult();
529 if (!currValTree->hasStoredValue()) {
530 currValTree->setCurrentValue(resVal);
531 }
532
533 if (currValTree->getStoredValue() != resVal) {
534 LLVM_DEBUG(
535 llvm::dbgs() << readarr.getOperationName() << ": replace " << resVal << " with "
536 << currValTree->getStoredValue() << '\n'
537 );
538 replacementMap[resVal] = currValTree->getStoredValue();
539 } else {
540 state[resVal] = currValTree;
541 LLVM_DEBUG(
542 llvm::dbgs() << readarr.getOperationName() << ": " << resVal << " => " << *currValTree
543 << '\n'
544 );
545 }
546
547 readVals.push_back(resVal);
548 };
549
550 // Write a scalar value (for writearr) or a subarray value (for insertarr)
551 // to an array. The unique part of this operation relative to others is that
552 // we may receive a variable index (i.e., not a constant). In this case, we
553 // invalidate ajoining parts of the subtree, since it is possible that
554 // the variable index aliases one of the other elements and may or may not
555 // override that value.
556 auto doArrayWriteLike = [&]<HasInterface<ArrayAccessOpInterface> OpClass>(OpClass writearr) {
557 std::shared_ptr<ReferenceNode> currValTree = state.at(translate(writearr.getArrRef()));
558 Value newVal = translate(writearr.getRvalue());
559 std::shared_ptr<ReferenceNode> valTree = tryGetValTree(newVal);
560
561 for (Value origIdx : writearr.getIndices()) {
562 Value idxVal = translate(origIdx);
563 // This write will invalidate all children, since it may reference
564 // any number of them.
565 if (ReferenceID(idxVal).isValue()) {
566 LLVM_DEBUG(llvm::dbgs() << writearr.getOperationName() << ": invalidate alias\n");
567 currValTree->invalidateChildren();
568 }
569 currValTree = currValTree->getOrCreateChild(idxVal);
570 }
571
572 if (currValTree->getStoredValue() == newVal) {
573 LLVM_DEBUG(
574 llvm::dbgs() << writearr.getOperationName() << ": subsequent " << writearr
575 << " is redundant\n"
576 );
577 redundantWrites.push_back(writearr);
578 } else {
579 if (Operation *lastWrite = currValTree->updateLastWrite(writearr)) {
580 LLVM_DEBUG(
581 llvm::dbgs() << writearr.getOperationName() << "writearr: replacing " << lastWrite
582 << " with prior write " << *lastWrite << '\n'
583 );
584 redundantWrites.push_back(lastWrite);
585 }
586 currValTree->setCurrentValue(newVal, valTree);
587 }
588 };
589
590 // struct ops
591 if (auto newStruct = dyn_cast<CreateStructOp>(op)) {
592 // For new values, the "stored value" of the reference is the creation site.
593 auto structVal = ReferenceNode::create(newStruct, newStruct);
594 state[newStruct] = structVal;
595 LLVM_DEBUG(llvm::dbgs() << newStruct.getOperationName() << ": " << *state[newStruct] << '\n');
596 // adding this to readVals
597 readVals.push_back(newStruct);
598 } else if (auto readm = dyn_cast<MemberReadOp>(op)) {
599 auto structVal = state.at(translate(readm.getComponent()));
600 FlatSymbolRefAttr symbol = readm.getMemberNameAttr();
601 Value resVal = translate(readm.getVal());
602 // Check if such a child already exists.
603 if (auto child = structVal->getChild(symbol)) {
604 LLVM_DEBUG(
605 llvm::dbgs() << readm.getOperationName() << ": adding replacement map entry { "
606 << resVal << " => " << child->getStoredValue() << " }\n"
607 );
608 replacementMap[resVal] = child->getStoredValue();
609 } else {
610 // If we have no previous store, we create a new symbolic value for
611 // this location.
612 state[readm] = structVal->createChild(symbol, resVal);
613 LLVM_DEBUG(llvm::dbgs() << readm.getOperationName() << ": " << *state[readm] << '\n');
614 }
615 // specifically add the untranslated value back for removal checks
616 readVals.push_back(readm.getVal());
617 } else if (auto writem = dyn_cast<MemberWriteOp>(op)) {
618 auto structVal = state.at(translate(writem.getComponent()));
619 Value writeVal = translate(writem.getVal());
620 FlatSymbolRefAttr symbol = writem.getMemberNameAttr();
621 auto valTree = tryGetValTree(writeVal);
623 auto child = structVal->getOrCreateChild(symbol);
624 if (child->getStoredValue() == writeVal) {
625 LLVM_DEBUG(
626 llvm::dbgs() << writem.getOperationName() << ": recording redundant write " << writem
627 << '\n'
628 );
629 redundantWrites.push_back(writem);
630 } else {
631 if (auto *lastWrite = child->updateLastWrite(writem)) {
632 LLVM_DEBUG(
633 llvm::dbgs() << writem.getOperationName() << ": recording overwritten write "
634 << *lastWrite << '\n'
635 );
636 redundantWrites.push_back(lastWrite);
637 }
638 child->setCurrentValue(writeVal, valTree);
639 LLVM_DEBUG(
640 llvm::dbgs() << writem.getOperationName() << ": " << *child << " set to " << writeVal
641 << '\n'
642 );
643 }
644 }
645 // array ops
646 else if (auto newArray = dyn_cast<CreateArrayOp>(op)) {
647 auto arrayVal = ReferenceNode::create(newArray, newArray);
648 state[newArray] = arrayVal;
649
650 // If we're given a constructor, we can instantiate elements using
651 // constant indices.
652 unsigned idx = 0;
653 for (auto elem : newArray.getElements()) {
654 Value elemVal = translate(elem);
655 auto valTree = tryGetValTree(elemVal);
656 auto elemChild = arrayVal->createChild(idx, elemVal, valTree);
657 LLVM_DEBUG(
658 llvm::dbgs() << newArray.getOperationName() << ": element " << idx << " initialized to "
659 << *elemChild << '\n'
660 );
661 idx++;
662 }
663
664 readVals.push_back(newArray);
665 } else if (auto readarr = dyn_cast<ReadArrayOp>(op)) {
666 doArrayReadLike(readarr);
667 } else if (auto writearr = dyn_cast<WriteArrayOp>(op)) {
668 doArrayWriteLike(writearr);
669 } else if (auto extractarr = dyn_cast<ExtractArrayOp>(op)) {
670 // Logic is essentially the same as readarr
671 doArrayReadLike(extractarr);
672 } else if (auto insertarr = dyn_cast<InsertArrayOp>(op)) {
673 // Logic is essentially the same as writearr
674 doArrayWriteLike(insertarr);
675 }
677};
678
679} // namespace
void print(llvm::raw_ostream &os) const
::mlir::Region * getCallableRegion()
Required by FunctionOpInterface.
Definition Ops.h.inc:846
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
mlir::Operation * create(MlirOpBuilder cBuilder, MlirLocation cLocation, Args &&...args)
Creates a new operation using an ODS build method.
Definition Builder.h:41
static bool isEqual(const ReferenceID &lhs, const ReferenceID &rhs)