25#include <mlir/IR/Builders.h>
26#include <mlir/IR/SymbolTable.h>
27#include <mlir/Transforms/InliningUtils.h>
29#include <llvm/Support/Debug.h>
35#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
39#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
52 OpBuilder funcBuilder(compute);
55 compute.walk([&funcBuilder](Operation *op) {
69 Block *entryBlock = productFunc.addEntryBlock();
70 funcBuilder.setInsertionPointToStart(entryBlock);
77 llvm::SmallVector<Value> args {productFunc.getArguments()};
80 CallOp computeCall = funcBuilder.create<
CallOp>(funcBuilder.getUnknownLoc(), compute, args);
81 args.insert(args.begin(), computeCall->getResult(0));
83 funcBuilder.create<
ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
86 InlinerInterface inliner(productFunc.getContext());
87 if (failed(inlineCall(inliner, computeCall, compute, &compute.
getBody(),
true))) {
91 if (failed(inlineCall(inliner, constrainCall,
constrain, &
constrain.getBody(),
true))) {
96 constrainCall->erase();
110 llvm::SetVector<CallOp> computeCalls, constrainCalls;
111 product.walk([&](
CallOp callOp) {
113 computeCalls.insert(callOp);
115 constrainCalls.insert(callOp);
119 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
125 llvm::outs() <<
"Asking for equivalence between calls\n"
126 << compute <<
"\nand\n"
128 llvm::outs() <<
"In block:\n\n" << *compute->getBlock() <<
'\n';
133 if (computeStruct != constrainStruct) {
136 if (compute.getNumOperands() == 0) {
139 for (
unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
148 for (
auto compute : computeCalls) {
150 auto matches = llvm::filter_to_vector(constrainCalls, [&](
CallOp constrain) {
154 if (matches.size() == 1) {
155 alignedCalls.insert({compute, matches[0]});
156 computeCalls.remove(compute);
157 constrainCalls.remove(matches[0]);
161 if (!computeCalls.empty() && constrainCalls.empty()) {
162 product.emitWarning()
167 for (
auto [compute, constrain] : alignedCalls) {
170 if (failed(calleeTgt)) {
173 auto newRoot = calleeTgt->get()->getParentOfType<StructDefOp>();
175 FuncDefOp newProduct =
176 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
182 OpBuilder callBuilder(compute);
183 CallOp newCall = callBuilder.create<CallOp>(
184 callBuilder.getFusedLoc({compute.getLoc(), constrain.getLoc()}), newProduct,
185 compute.getOperands()
187 compute->replaceAllUsesWith(newCall.getResults());
199 if (!computeFunc || !constrainFunc) {
212LogicalResult alignStartingAt(
216 if (!isValidRoot(root)) {
228class PassImpl :
public llzk::impl::ComputeConstrainToProductPassBase<PassImpl> {
229 using Base = ComputeConstrainToProductPassBase<PassImpl>;
232 void runOnOperation()
override {
233 ModuleOp
mod = getOperation();
236 SymbolTableCollection tables;
237 LightweightSignalEquivalenceAnalysis equivalence {
238 getAnalysis<LightweightSignalEquivalenceAnalysis>()
242 mod.walk([&root,
this](StructDefOp structDef) {
248 if (failed(alignStartingAt(root, tables, equivalence))) {
bool areSignalsEquivalent(mlir::Value v1, mlir::Value v2)
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::FunctionType getFunctionType()
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
constexpr char FUNC_NAME_CONSTRAIN[]
constexpr char FUNC_NAME_PRODUCT[]
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...