LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKRedundantOperationEliminationPass.cpp
Go to the documentation of this file.
1//===-- LLZKRedundantOperationEliminationPass.cpp ---------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
22
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Dominance.h>

#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/DenseSet.h>
#include <llvm/ADT/Hashing.h>
#include <llvm/ADT/PostOrderIterator.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVector.h>

#include <algorithm>
#include <deque>
33
34// Include the generated base pass class definitions.
35namespace llzk {
36#define GEN_PASS_DEF_REDUNDANTOPERATIONELIMINATIONPASS
38} // namespace llzk
39
40using namespace mlir;
41using namespace llzk;
42using namespace llzk::boolean;
43using namespace llzk::component;
44using namespace llzk::constrain;
45using namespace llzk::function;
46
47#define DEBUG_TYPE "llzk-duplicate-op-elim"
48
49namespace {
50
51static auto EMPTY_OP_KEY = reinterpret_cast<Operation *>(1);
52static auto TOMBSTONE_OP_KEY = reinterpret_cast<Operation *>(2);
53
54// Maps original -> replacement value
55using TranslationMap = DenseMap<Value, Value>;
56
60class OperationComparator {
61public:
62 explicit OperationComparator(Operation *o) : op(o) {
63 if (op != EMPTY_OP_KEY && op != TOMBSTONE_OP_KEY) {
64 operands = SmallVector<Value>(op->getOperands());
65 }
66 }
67
68 OperationComparator(Operation *o, const TranslationMap &m) : op(o) {
69 for (auto operand : op->getOperands()) {
70 if (auto it = m.find(operand); it != m.end()) {
71 operands.push_back(it->second);
72 } else {
73 operands.push_back(operand);
74 }
75 }
76 }
77
78 Operation *getOp() const { return op; }
79
80 const SmallVector<Value> &getOperands() const { return operands; }
81
82 bool isCommutative() const { return op->hasTrait<OpTrait::IsCommutative>(); }
83
84 friend bool operator==(const OperationComparator &lhs, const OperationComparator &rhs) {
85 if (lhs.op == EMPTY_OP_KEY || rhs.op == EMPTY_OP_KEY || lhs.op == TOMBSTONE_OP_KEY ||
86 rhs.op == TOMBSTONE_OP_KEY) {
87 return lhs.op == rhs.op;
88 }
89
90 if (lhs.op->getName() != rhs.op->getName()) {
91 return false;
92 }
93
94 // uninterested in operating over control-flow ops
95 auto dialectName = lhs.op->getDialect()->getNamespace();
96 if (dialectName == scf::SCFDialect::getDialectNamespace()) {
97 return false;
98 }
99
100 // This may be overly restrictive in some cases, but without knowing what
101 // potential future attributes we may have, it's safer to assume that
102 // unequal attributes => unequal operations.
103 // This covers constant operations too, as the constant is an attribute,
104 // not an operand.
105 if (lhs.op->getAttrs() != rhs.op->getAttrs()) {
106 return false;
107 }
108 // For commutative operations, just check if the operands contain the same set in any order
109 if (lhs.isCommutative()) {
110 ensure(
111 lhs.operands.size() == 2 && rhs.operands.size() == 2,
112 "No known commutative ops have more than two arguments"
113 );
114 return (lhs.operands[0] == rhs.operands[0] && lhs.operands[1] == rhs.operands[1]) ||
115 (lhs.operands[0] == rhs.operands[1] && lhs.operands[1] == rhs.operands[0]);
116 }
117
118 // The default case requires an exact match per argument
119 return lhs.operands == rhs.operands;
120 }
121
122private:
123 Operation *op;
124 SmallVector<Value> operands;
125};
126
127} // namespace
128
129namespace llvm {
130
131template <> struct DenseMapInfo<OperationComparator> {
132 static OperationComparator getEmptyKey() { return OperationComparator(EMPTY_OP_KEY); }
133 static inline OperationComparator getTombstoneKey() {
134 return OperationComparator(TOMBSTONE_OP_KEY);
135 }
136 static unsigned getHashValue(const OperationComparator &oc) {
137 if (oc.getOp() == EMPTY_OP_KEY || oc.getOp() == TOMBSTONE_OP_KEY) {
138 return hash_value(oc.getOp());
139 }
140 // Just hash on name to force more thorough equality checks by operation type.
141 return hash_value(oc.getOp()->getName());
142 }
143 static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs) {
144 return lhs == rhs;
145 }
146};
147
148} // namespace llvm
149
150namespace {
151
152class PassImpl : public llzk::impl::RedundantOperationEliminationPassBase<PassImpl> {
153 using Base = RedundantOperationEliminationPassBase<PassImpl>;
154 using Base::Base;
155
156 void runOnOperation() override {
157 SymbolTableCollection symbolTables;
158 // Traverse functions from the bottom of the call graph up.
159 // This way, we may create empty constrain functions to which we can eliminate
160 // calls.
161 auto &cga = getAnalysis<CallGraphAnalysis>();
162 const llzk::CallGraph *callGraph = &cga.getCallGraph();
163 for (auto it = llvm::po_begin(callGraph); it != llvm::po_end(callGraph); ++it) {
164 const llzk::CallGraphNode *node = *it;
165 if (!node->isExternal()) {
166 runOnFunc(symbolTables, node->getCalledFunction());
167 }
168 }
169 }
170
171 bool isPurposelessConstrainFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
172 if (!fn.isStructConstrain()) {
173 return false;
174 }
175
176 bool res = true;
177 fn.walk([&](Operation *op) {
178 if (isa<EmitEqualityOp, EmitContainmentOp, AssertOp>(op)) {
179 res = false;
180 return WalkResult::interrupt();
181 } else if (auto callOp = dyn_cast<CallOp>(op);
182 callOp && !callsPurposelessConstrainFunc(symbolTables, callOp)) {
183 res = false;
184 return WalkResult::interrupt();
185 }
186 return WalkResult::advance();
187 });
188 return res;
189 }
190
191 bool callsPurposelessConstrainFunc(SymbolTableCollection &symbolTables, CallOp call) {
192 auto callLookup = resolveCallable<FuncDefOp>(symbolTables, call);
193 return succeeded(callLookup) && isPurposelessConstrainFunc(symbolTables, callLookup->get());
194 }
195
196 void runOnFunc(SymbolTableCollection &symbolTables, FuncDefOp fn) {
197 TranslationMap map;
198 SmallVector<Operation *> redundantOps;
199 DenseSet<OperationComparator> uniqueOps;
200 DominanceInfo domInfo(fn);
201
202 auto unnecessaryOpCheck = [&](Operation *op) -> bool {
203 if (auto emiteq = dyn_cast<EmitEqualityOp>(op);
204 emiteq && emiteq.getLhs() == emiteq.getRhs()) {
205 redundantOps.push_back(op);
206 return true;
207 }
208
209 if (auto callOp = dyn_cast<CallOp>(op);
210 callOp && callsPurposelessConstrainFunc(symbolTables, callOp)) {
211 redundantOps.push_back(op);
212 return true;
213 }
214 return false;
215 };
216
217 fn.walk([&](Operation *op) {
218 // Case 1: The operation itself is unnecessary.
219 if (unnecessaryOpCheck(op)) {
220 return WalkResult::advance();
221 }
222
223 // Case 2: An equivalent operation A has already been performed before
224 // the current operation B and A dominates B.
225 if (!isa<NonDetOp>(op)) {
226 OperationComparator comp(op, map);
227 if (auto it = uniqueOps.find(comp);
228 it != uniqueOps.end() && domInfo.dominates(it->getOp(), op)) {
229 redundantOps.push_back(op);
230 for (unsigned opNum = 0; opNum < op->getNumResults(); opNum++) {
231 map[op->getResult(opNum)] = it->getOp()->getResult(opNum);
232 }
233 } else {
234 uniqueOps.insert(comp);
235 }
236 }
237
238 return WalkResult::advance();
239 });
240
241 // Track the operands of removed ops.
242 std::deque<Value> operands;
243
244 for (auto *op : redundantOps) {
245 LLVM_DEBUG(llvm::dbgs() << "Removing op: " << *op << '\n');
246 for (auto result : op->getResults()) {
247 if (!result.getUsers().empty()) {
248 auto it = map.find(result);
249 ensure(
250 it != map.end(), "failed to find a replacement value for redundant operation result"
251 );
252 LLVM_DEBUG(llvm::dbgs() << "Replacing " << it->first << " with " << it->second << '\n');
253 result.replaceAllUsesWith(it->second);
254 }
255 }
256 for (Value operand : op->getOperands()) {
257 operands.push_back(operand);
258 }
259 op->erase();
260 }
261
262 // Check if any of the operands are unused. If so, remove them, and check
263 // their operands until all operands have been checked.
264
265 // Make sure operands aren't freed multiple times
266 DenseSet<Value> checkedOperands;
267 while (!operands.empty()) {
268 Value operand = operands.front();
269 operands.pop_front();
270 checkedOperands.insert(operand);
271
272 // We only want to remove operands that are defined by an operation and
273 // are not block arguments.
274 if (auto *op = operand.getDefiningOp(); op && operand.getUsers().empty()) {
275 for (auto parentOperand : op->getOperands()) {
276 if (checkedOperands.find(parentOperand) == checkedOperands.end()) {
277 operands.push_back(parentOperand);
278 }
279 }
280 LLVM_DEBUG(llvm::dbgs() << "Removing unused operand: " << operand << '\n');
281 op->erase();
282 }
283 }
284 }
285};
286
287} // namespace
bool isExternal() const
Returns true if this node is an external node.
Definition CallGraph.cpp:39
llzk::function::FuncDefOp getCalledFunction() const
Returns the called function that the callable region represents.
Definition CallGraph.cpp:48
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
Definition Ops.h.inc:882
void ensure(bool condition, const llvm::Twine &errMsg)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
static unsigned getHashValue(const OperationComparator &oc)
static bool isEqual(const OperationComparator &lhs, const OperationComparator &rhs)