25#include <mlir/IR/BuiltinAttributes.h>
26#include <mlir/Pass/PassManager.h>
28#include <llvm/ADT/SmallVector.h>
35static bool requiresFlattening(ModuleOp moduleOp) {
37 ->walk([&](Operation *op) {
38 if (isa<function::CallOp>(op)) {
39 auto callOp = cast<function::CallOp>(op);
40 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
41 return WalkResult::interrupt();
44 if (op->getDialect()->getNamespace() ==
46 return WalkResult::interrupt();
48 return WalkResult::advance();
54 ModuleOp
mod, SymbolTableCollection &symbolTables,
const Field &moduleField,
57 : moduleOp(
mod), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
61static llvm::Expected<llvm::SmallVector<WitnessVal>> parseArgumentsFromJSON(
64 llvm::SmallVector<WitnessVal> args;
65 const auto *jsonObject = input.getAsObject();
66 const auto *jsonArray = input.getAsArray();
67 if (!jsonObject && !jsonArray) {
68 return makeError(
"inputs JSON must be either an object or an array");
72 for (
unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
73 llvm::StringRef argName;
74 if (std::optional<StringAttr> attr = computeFunc.
getArgNameAttr(i)) {
75 argName = attr->getValue();
77 return makeError(
"JSON object input requires function.arg_name on every main argument");
79 const llvm::json::Value *value = jsonObject->get(argName);
81 return makeError(llvm::Twine(
"missing JSON input field: ") + argName);
87 return parsed.takeError();
89 args.push_back(*parsed);
94 if (jsonArray->size() != computeFunc.getNumArguments()) {
95 return makeError(
"JSON positional input length does not match main compute arity");
97 for (
unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
99 &(*jsonArray)[i], computeFunc.
getArgumentTypes()[i], field, computeFunc.getOperation()
102 return parsed.takeError();
104 args.push_back(*parsed);
112 if (failed(mainDef) || !mainDef.value()) {
113 return makeError(
"module is missing a concrete llzk.main struct");
116 auto computeFunc = mainDef->get().getComputeFuncOp();
118 return makeError(
"main struct is missing @compute");
120 if (computeFunc.getNumResults() != 1) {
121 return makeError(
"main compute must return exactly one value");
124 auto args = parseArgumentsFromJSON(computeFunc, input, field);
126 return args.takeError();
130 auto results = interpreter.
run(computeFunc, *args);
132 return results.takeError();
134 if (results->size() != 1) {
135 return makeError(
"main compute returned unexpected result count");
139 results->front(), computeFunc.
getResultTypes().front(), tables, computeFunc.getOperation(),
147 return inputsJSON.takeError();
153 if (failed(outputBindings)) {
154 return makeError(
"failed to select full witness signals");
157 llvm::SmallVector<llvm::json::Value> serializedSignals;
158 serializedSignals.reserve(outputBindings->size());
161 results->front(), computeFunc.
getResultTypes().front(), binding.path, tables,
162 computeFunc.getOperation()
165 return leafValue.takeError();
171 return serialized.takeError();
173 serializedSignals.push_back(*serialized);
176 llvm::json::Object result;
177 result[
"inputs"] = llvm::json::Value(std::move(*inputsJSON));
179 return llvm::json::Value(std::move(result));
183static llvm::Error preprocessModule(ModuleOp moduleOp,
const WitgenOptions &options) {
185 PassManager pm(moduleOp.getContext());
191 }
else if (requiresFlattening(moduleOp)) {
194 if (failed(pm.run(moduleOp))) {
195 return makeError(
"failed to preprocess LLZK module for llzk-witgen");
197 return llvm::Error::success();
201llvm::Expected<llvm::json::Value>
203 if (
auto err = preprocessModule(moduleOp, options)) {
204 return std::move(err);
208 if (failed(
collectFields(moduleOp.getOperation(), fields))) {
209 return makeError(
"failed to collect fields for llzk-witgen");
211 if (fields.size() != 1) {
212 return makeError(
"llzk-witgen v1 requires exactly one field in the module");
215 SymbolTableCollection tables;
Information about the prime finite field used for the interval analysis.
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::std::optional<::mlir::StringAttr > getArgNameAttr(unsigned index)
Return the function.arg_name attribute for the argument at the given index.
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
static constexpr ::llvm::StringLiteral getDialectNamespace()
Execute one flattened LLZK function body over runtime values.
llvm::Expected< llvm::SmallVector< WitnessVal > > run(llzk::function::FuncDefOp funcOp, mlir::ArrayRef< WitnessVal > args)
Run a function with concrete arguments and return its result values.
Drive witness generation for the concrete llzk.main instance.
Interpreter(mlir::ModuleOp moduleOp, mlir::SymbolTableCollection &tables, const llzk::Field &field, UninitializedBehavior uninitializedBehavior, std::mt19937_64 rng)
Build a driver for one parsed module and validated field.
void setOutputScope(OutputScope newOutputScope)
Select which witness JSON scope this interpreter emits.
llvm::Expected< llvm::json::Value > runMainFromJSON(const llvm::json::Value &input)
Execute the main compute() function using JSON inputs.
std::unique_ptr<::mlir::Pass > createInlineIncludesPass()
std::unique_ptr<::mlir::Pass > createFlatteningPass()
llvm::Expected< llvm::json::Value > runWithExecutionEngine(ModuleOp moduleOp, SymbolTableCollection &tables, const Field &field, const llvm::json::Value &input, const WitgenOptions &options)
Execute witness generation through MLIR lowering and the LLVM execution engine.
llvm::Expected< llvm::json::Object > buildInputsJSONObject(ArrayRef< InputBinding > bindings, ArrayRef< WitnessVal > values, SymbolTableCollection &tables, Operation *origin)
Serialize named input values into a JSON object.
llvm::SmallVector< InputBinding > collectInputBindings(function::FuncDefOp computeFunc)
Collect stable JSON bindings for the main compute inputs.
std::mt19937_64 makeDefaultValueRng(const WitgenOptions &options)
Seed an RNG for random/default witness value materialization.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
llvm::Expected< llvm::json::Value > serializeJSONValue(const WitnessVal &value, Type type, SymbolTableCollection &tables, Operation *origin, SerializationMode mode)
Serialize a supported LLZK runtime value into JSON.
llvm::json::Value buildSignalsJSONObject(ArrayRef< OutputBinding > bindings, ArrayRef< llvm::json::Value > serializedLeaves)
Assemble a nested JSON object from selected witness leaves.
llvm::Expected< WitnessVal > extractValueAtPath(const WitnessVal &root, Type rootType, ArrayRef< std::string > path, SymbolTableCollection &tables, Operation *origin)
Extract one nested runtime leaf by path.
void addWitgenPreparePipeline(OpPassManager &pm, const WitgenOptions &)
llvm::Expected< llvm::json::Value > runWitgen(ModuleOp moduleOp, const llvm::json::Value &input, const WitgenOptions &options)
Run include preprocessing, field validation, and backend execution.
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
llvm::Expected< WitnessVal > parseJSONValue(const llvm::json::Value *json, Type type, const Field &field, Operation *origin)
Parse a supported LLZK input type from JSON.
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
llvm::SmallSet< FieldRef, 2 > FieldSet
Typealias for a set of Fields.
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Describe one selected witness output leaf.
Configure one llzk-witgen execution.
UninitializedBehavior uninitializedBehavior