23#include <mlir/IR/Builders.h>
24#include <mlir/Transforms/InliningUtils.h>
26#include <llvm/Support/Debug.h>
32#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
33#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
37#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
43using std::make_unique;
51 if (!computeFunc || !constrainFunc) {
76 s.getComputeFuncOp()->erase();
77 s.getConstrainFuncOp()->erase();
88 ModuleOp
mod = getOperation();
91 SymbolTableCollection tables;
93 getAnalysis<LightweightSignalEquivalenceAnalysis>()
110 OpBuilder funcBuilder(compute);
113 compute.walk([&funcBuilder](Operation *op) {
117 constrain.walk([&funcBuilder](Operation *op) {
126 Block *entryBlock = productFunc.addEntryBlock();
127 funcBuilder.setInsertionPointToStart(entryBlock);
130 llvm::SmallVector<Value> args {productFunc.getArguments()};
133 CallOp computeCall = funcBuilder.create<
CallOp>(funcBuilder.getUnknownLoc(), compute, args);
134 args.insert(args.begin(), computeCall->getResult(0));
136 funcBuilder.create<
ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
139 InlinerInterface inliner(productFunc.getContext());
140 if (failed(inlineCall(inliner, computeCall, compute, &compute.
getBody(),
true))) {
144 if (failed(inlineCall(inliner, constrainCall,
constrain, &
constrain.getBody(),
true))) {
148 computeCall->erase();
149 constrainCall->erase();
163 llvm::SetVector<CallOp> computeCalls, constrainCalls;
164 product.walk([&](
CallOp callOp) {
166 computeCalls.insert(callOp);
168 constrainCalls.insert(callOp);
172 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
178 llvm::outs() <<
"Asking for equivalence between calls\n"
179 << compute <<
"\nand\n"
181 llvm::outs() <<
"In block:\n\n" << *compute->getBlock() <<
"\n";
186 if (computeStruct != constrainStruct) {
189 for (
unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
190 if (!equivalence.areSignalsEquivalent(compute->getOperand(i),
constrain->getOperand(i + 1))) {
198 for (
auto compute : computeCalls) {
200 auto matches = llvm::filter_to_vector(constrainCalls, [&](
CallOp constrain) {
204 if (matches.size() == 1) {
205 alignedCalls.insert({compute, matches[0]});
206 computeCalls.remove(compute);
207 constrainCalls.remove(matches[0]);
212 if (!computeCalls.empty() && constrainCalls.empty()) {
213 product->emitError() <<
"failed to align some @" <<
FUNC_NAME_COMPUTE <<
" and @"
218 for (
auto [compute,
constrain] : alignedCalls) {
223 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
229 OpBuilder callBuilder(compute);
231 callBuilder.getFusedLoc({compute.getLoc(),
constrain.getLoc()}), newProduct,
232 compute.getOperands()
234 compute->replaceAllUsesWith(newCall.getResults());
243 return make_unique<ComputeConstrainToProductPass>();
void runOnOperation() override
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::FunctionType getFunctionType()
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
::mlir::Pass::Option< std::string > rootStruct
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
constexpr char FUNC_NAME_CONSTRAIN[]
LogicalResult alignStartingAt(component::StructDefOp root, SymbolTableCollection &tables, LightweightSignalEquivalenceAnalysis &equivalence)
std::unique_ptr< Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
bool isValidRoot(StructDefOp root)
ExpressionValue mod(llvm::SMTSolverRef solver, const ExpressionValue &lhs, const ExpressionValue &rhs)