25#include <mlir/IR/Builders.h>
26#include <mlir/IR/SymbolTable.h>
27#include <mlir/Transforms/InliningUtils.h>
29#include <llvm/Support/Debug.h>
35#define GEN_PASS_DECL_COMPUTECONSTRAINTOPRODUCTPASS
36#define GEN_PASS_DEF_COMPUTECONSTRAINTOPRODUCTPASS
40#define DEBUG_TYPE "llzk-compute-constrain-to-product-pass"
46using std::make_unique;
54 if (!computeFunc || !constrainFunc) {
88 ModuleOp
mod = getOperation();
91 SymbolTableCollection tables;
93 getAnalysis<LightweightSignalEquivalenceAnalysis>()
115 OpBuilder funcBuilder(compute);
118 compute.walk([&funcBuilder](Operation *op) {
122 constrain.walk([&funcBuilder](Operation *op) {
131 productFunc->setAttr(
DERIVED_ATTR_NAME, UnitAttr::get(funcBuilder.getContext()));
132 Block *entryBlock = productFunc.addEntryBlock();
133 funcBuilder.setInsertionPointToStart(entryBlock);
140 llvm::SmallVector<Value> args {productFunc.getArguments()};
143 CallOp computeCall = funcBuilder.create<
CallOp>(funcBuilder.getUnknownLoc(), compute, args);
144 args.insert(args.begin(), computeCall->getResult(0));
146 funcBuilder.create<
ReturnOp>(funcBuilder.getUnknownLoc(), computeCall->getResult(0));
149 InlinerInterface inliner(productFunc.getContext());
150 if (failed(inlineCall(inliner, computeCall, compute, &compute.
getBody(),
true))) {
154 if (failed(inlineCall(inliner, constrainCall,
constrain, &
constrain.getBody(),
true))) {
158 computeCall->erase();
159 constrainCall->erase();
173 llvm::SetVector<CallOp> computeCalls, constrainCalls;
174 product.walk([&](
CallOp callOp) {
176 computeCalls.insert(callOp);
178 constrainCalls.insert(callOp);
182 llvm::SetVector<std::pair<CallOp, CallOp>> alignedCalls;
188 llvm::outs() <<
"Asking for equivalence between calls\n"
189 << compute <<
"\nand\n"
191 llvm::outs() <<
"In block:\n\n" << *compute->getBlock() <<
"\n";
196 if (computeStruct != constrainStruct) {
199 if (compute.getNumOperands() == 0) {
202 for (
unsigned i = 0, e = compute->getNumOperands() - 1; i < e; i++) {
203 if (!equivalence.areSignalsEquivalent(compute->getOperand(i),
constrain->getOperand(i + 1))) {
211 for (
auto compute : computeCalls) {
213 auto matches = llvm::filter_to_vector(constrainCalls, [&](
CallOp constrain) {
217 if (matches.size() == 1) {
218 alignedCalls.insert({compute, matches[0]});
219 computeCalls.remove(compute);
220 constrainCalls.remove(matches[0]);
224 if (!computeCalls.empty() && constrainCalls.empty()) {
225 product.emitWarning()
230 for (
auto [compute,
constrain] : alignedCalls) {
233 if (failed(calleeTgt)) {
236 auto newRoot = calleeTgt->get()->getParentOfType<
StructDefOp>();
239 alignFuncs(newRoot, newRoot.getComputeFuncOp(), newRoot.getConstrainFuncOp());
245 OpBuilder callBuilder(compute);
247 callBuilder.getFusedLoc({compute.getLoc(),
constrain.getLoc()}), newProduct,
248 compute.getOperands()
250 compute->replaceAllUsesWith(newCall.getResults());
257 return make_unique<ComputeConstrainToProductPass>();
void runOnOperation() override
std::vector< component::StructDefOp > alignedStructs
function::FuncDefOp alignFuncs(component::StructDefOp root, function::FuncDefOp compute, function::FuncDefOp constrain)
mlir::LogicalResult alignCalls(function::FuncDefOp product)
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
::mlir::SymbolRefAttr getCallee()
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::FunctionType getFunctionType()
void setAllowNonNativeFieldOpsAttr(bool newValue=true)
Add (resp. remove) the allow_non_native_field_ops attribute to (resp. from) the function def.
bool hasAllowNonNativeFieldOpsAttr()
Return true iff the function def has the allow_non_native_field_ops attribute.
static FuncDefOp create(::mlir::Location location, ::llvm::StringRef name, ::mlir::FunctionType type, ::llvm::ArrayRef<::mlir::NamedAttribute > attrs={})
::mlir::Region & getBody()
::mlir::Pass::Option< std::string > rootStruct
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
constexpr char PRODUCT_SOURCE[]
Name of the attribute on aligned product program ops that specifies where they came from.
constexpr char FUNC_NAME_CONSTRAIN[]
LogicalResult alignStartingAt(component::StructDefOp root, SymbolTableCollection &tables, LightweightSignalEquivalenceAnalysis &equivalence)
std::unique_ptr< Pass > createComputeConstrainToProductPass()
constexpr char FUNC_NAME_PRODUCT[]
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
bool isValidRoot(StructDefOp root)