LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
WitgenDriver.cpp
Go to the documentation of this file.
1//===-- WitgenDriver.cpp - llzk-witgen driver entrypoints -------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "WitgenDriver.h"
11
12#include "Errors.h"
14#include "Interpreter.h"
15#include "JSON.h"
16#include "WitgenLowering.h"
17#include "WitgenUtils.h"
18#include "WitnessSelection.h"
19
24
25#include <mlir/IR/BuiltinAttributes.h>
26#include <mlir/Pass/PassManager.h>
27
28#include <llvm/ADT/SmallVector.h>
29
30using namespace mlir;
31
32namespace llzk::witgen {
33
/// Report whether walking `moduleOp` hits an operation that interrupts the walk.
///
/// The walk visits the operations nested under `moduleOp` and stops early when either
///   - a `function::CallOp` has template params or a non-empty map-operand list, or
///   - the operation's dialect namespace equals the value on the right-hand side of the
///     `==` comparison below.
/// Returns true iff the walk was interrupted; the caller in this file uses the result to
/// decide whether to take its flattening branch.
// NOTE(review): the right-hand side of the namespace comparison is not visible in this
// listing (one line appears to be elided) — confirm which dialect namespace is checked.
// NOTE(review): `op->getDialect()` is dereferenced without a null check — presumably every
// operation reaching this point belongs to a loaded dialect; verify.
35static bool requiresFlattening(ModuleOp moduleOp) {
36 return moduleOp
37 ->walk([&](Operation *op) {
// Calls carrying template params or map operands interrupt the walk immediately.
38 if (isa<function::CallOp>(op)) {
39 auto callOp = cast<function::CallOp>(op);
40 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
41 return WalkResult::interrupt();
42 }
43 }
// Any operation whose dialect namespace matches the compared value also interrupts.
44 if (op->getDialect()->getNamespace() ==
46 return WalkResult::interrupt();
47 }
48 return WalkResult::advance();
49 }).wasInterrupted();
50}
51
// Parameter list and member-initializer list of what appears to be the Interpreter
// constructor: each argument is stored in the like-named member (moduleOp, tables, field,
// uninitializedBehavior, rng) and the constructor body is empty.
// NOTE(review): the line carrying the constructor's qualified name is not visible in this
// listing — confirm against the original source before editing this definition.
54 ModuleOp mod, SymbolTableCollection &symbolTables, const Field &moduleField,
55 UninitializedBehavior behavior, std::mt19937_64 r
56)
57 : moduleOp(mod), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
58 rng(r) {}
59
61static llvm::Expected<llvm::SmallVector<WitnessVal>> parseArgumentsFromJSON(
62 function::FuncDefOp computeFunc, const llvm::json::Value &input, const Field &field
63) {
64 llvm::SmallVector<WitnessVal> args;
65 const auto *jsonObject = input.getAsObject();
66 const auto *jsonArray = input.getAsArray();
67 if (!jsonObject && !jsonArray) {
68 return makeError("inputs JSON must be either an object or an array");
69 }
70
71 if (jsonObject) {
72 for (unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
73 llvm::StringRef argName;
74 if (std::optional<StringAttr> attr = computeFunc.getArgNameAttr(i)) {
75 argName = attr->getValue();
76 } else {
77 return makeError("JSON object input requires function.arg_name on every main argument");
78 }
79 const llvm::json::Value *value = jsonObject->get(argName);
80 if (!value) {
81 return makeError(llvm::Twine("missing JSON input field: ") + argName);
82 }
83 auto parsed = parseJSONValue(
84 value, computeFunc.getArgumentTypes()[i], field, computeFunc.getOperation()
85 );
86 if (!parsed) {
87 return parsed.takeError();
88 }
89 args.push_back(*parsed);
90 }
91 return args;
92 }
93
94 if (jsonArray->size() != computeFunc.getNumArguments()) {
95 return makeError("JSON positional input length does not match main compute arity");
96 }
97 for (unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
98 auto parsed = parseJSONValue(
99 &(*jsonArray)[i], computeFunc.getArgumentTypes()[i], field, computeFunc.getOperation()
100 );
101 if (!parsed) {
102 return parsed.takeError();
103 }
104 args.push_back(*parsed);
105 }
106 return args;
107}
108
/// Run the main struct's compute function on JSON inputs and return the result as JSON.
///
/// Steps, each of which returns an error on failure:
///   1. resolve the main instance definition via `getMainInstanceDef`;
///   2. fetch its compute function and require exactly one declared result;
///   3. parse `input` into argument values with `parseArgumentsFromJSON`;
///   4. execute the function with a `FunctionInterpreter` and require exactly one result;
///   5. serialize: for `OutputScope::Public` the single result is serialized directly;
///      otherwise an object with an "inputs" entry and a "signals" entry is built.
///
/// \param input  JSON document holding the compute function's arguments.
/// \returns The serialized JSON value, or the first error encountered.
110llvm::Expected<llvm::json::Value> Interpreter::runMainFromJSON(const llvm::json::Value &input) {
111 auto mainDef = getMainInstanceDef(tables, moduleOp.getOperation());
// Reject both a failed lookup and a successful lookup that yields no definition.
112 if (failed(mainDef) || !mainDef.value()) {
113 return makeError("module is missing a concrete llzk.main struct");
114 }
115
116 auto computeFunc = mainDef->get().getComputeFuncOp();
117 if (!computeFunc) {
118 return makeError("main struct is missing @compute");
119 }
// The serialization below reads only the first result type, so exactly one is required.
120 if (computeFunc.getNumResults() != 1) {
121 return makeError("main compute must return exactly one value");
122 }
123
124 auto args = parseArgumentsFromJSON(computeFunc, input, field);
125 if (!args) {
126 return args.takeError();
127 }
128
129 FunctionInterpreter interpreter(moduleOp, tables, field, uninitializedBehavior, rng);
130 auto results = interpreter.run(computeFunc, *args);
131 if (!results) {
132 return results.takeError();
133 }
// Runtime counterpart of the declared-result-count check above.
134 if (results->size() != 1) {
135 return makeError("main compute returned unexpected result count");
136 }
// Public scope: serialize the single result value and return it as-is.
// NOTE(review): the final argument of this call is not visible in this listing (one line
// appears to be elided) — confirm which serialization mode is passed.
137 if (outputScope == OutputScope::Public) {
138 return serializeJSONValue(
139 results->front(), computeFunc.getResultTypes().front(), tables, computeFunc.getOperation(),
141 );
142 }
143
// Any other scope: echo the parsed inputs back alongside the selected signals.
144 auto inputBindings = collectInputBindings(computeFunc);
145 auto inputsJSON = buildInputsJSONObject(inputBindings, *args, tables, computeFunc.getOperation());
146 if (!inputsJSON) {
147 return inputsJSON.takeError();
148 }
149
150 auto outputBindings = collectOutputBindings(
151 mainDef->get(), tables, computeFunc.getOperation(), OutputScope::FullWitness
152 );
153 if (failed(outputBindings)) {
154 return makeError("failed to select full witness signals");
155 }
156
// Extract and serialize one leaf per binding, in binding order, so the serialized list
// stays index-aligned with `outputBindings` when both are passed to buildSignalsJSONObject.
157 llvm::SmallVector<llvm::json::Value> serializedSignals;
158 serializedSignals.reserve(outputBindings->size());
159 for (const OutputBinding &binding : *outputBindings) {
160 auto leafValue = extractValueAtPath(
161 results->front(), computeFunc.getResultTypes().front(), binding.path, tables,
162 computeFunc.getOperation()
163 );
164 if (!leafValue) {
165 return leafValue.takeError();
166 }
167 auto serialized = serializeJSONValue(
168 *leafValue, binding.type, tables, computeFunc.getOperation(), SerializationMode::AllSignals
169 );
170 if (!serialized) {
171 return serialized.takeError();
172 }
173 serializedSignals.push_back(*serialized);
174 }
175
// Final document shape: { "inputs": ..., "signals": ... }.
176 llvm::json::Object result;
177 result["inputs"] = llvm::json::Value(std::move(*inputsJSON));
178 result["signals"] = buildSignalsJSONObject(*outputBindings, serializedSignals);
179 return llvm::json::Value(std::move(result));
180}
181
/// Assemble a pass pipeline from `options` and run it over `moduleOp`.
///
/// \param moduleOp  Module the pass manager is run on.
/// \param options   Supplies `inlineIncludes` and `backend`, which select the pipeline.
/// \returns `llvm::Error::success()` when the pass manager run succeeds, otherwise an
///          error stating that preprocessing failed.
183static llvm::Error preprocessModule(ModuleOp moduleOp, const WitgenOptions &options) {
184 // normalizeCallOpProperties(moduleOp);
185 PassManager pm(moduleOp.getContext());
// NOTE(review): the body of this branch is not visible in this listing (one line appears
// to be elided) — presumably it adds the include-inlining pass; confirm.
186 if (options.inlineIncludes) {
188 }
// The execution-engine backend gets the witgen prepare pipeline; any other backend takes
// the second branch only when requiresFlattening(moduleOp) reports true.
// NOTE(review): the body of the `else if` branch is likewise not visible here —
// presumably it adds the flattening pass; confirm.
189 if (options.backend == Backend::ExecutionEngine) {
190 addWitgenPreparePipeline(pm, options);
191 } else if (requiresFlattening(moduleOp)) {
193 }
194 if (failed(pm.run(moduleOp))) {
195 return makeError("failed to preprocess LLZK module for llzk-witgen");
196 }
197 return llvm::Error::success();
198}
199
201llvm::Expected<llvm::json::Value>
202runWitgen(ModuleOp moduleOp, const llvm::json::Value &input, const WitgenOptions &options) {
203 if (auto err = preprocessModule(moduleOp, options)) {
204 return std::move(err);
205 }
206
207 FieldSet fields;
208 if (failed(collectFields(moduleOp.getOperation(), fields))) {
209 return makeError("failed to collect fields for llzk-witgen");
210 }
211 if (fields.size() != 1) {
212 return makeError("llzk-witgen v1 requires exactly one field in the module");
213 }
214
215 SymbolTableCollection tables;
216 if (options.backend == Backend::ExecutionEngine) {
217 return runWithExecutionEngine(moduleOp, tables, *fields.begin(), input, options);
218 }
219 Interpreter interpreter(
220 moduleOp, tables, *fields.begin(), options.uninitializedBehavior, makeDefaultValueRng(options)
221 );
222 interpreter.setOutputScope(options.outputScope);
223 return interpreter.runMainFromJSON(input);
224}
225
226} // namespace llzk::witgen
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:850
::std::optional<::mlir::StringAttr > getArgNameAttr(unsigned index)
Return the function.arg_name attribute for the argument at the given index.
Definition Ops.cpp:297
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:854
static constexpr ::llvm::StringLiteral getDialectNamespace()
Definition Dialect.h.inc:20
Execute one flattened LLZK function body over runtime values.
Definition Interpreter.h:25
llvm::Expected< llvm::SmallVector< WitnessVal > > run(llzk::function::FuncDefOp funcOp, mlir::ArrayRef< WitnessVal > args)
Run a function with concrete arguments and return its result values.
Drive witness generation for the concrete llzk.main instance.
Interpreter(mlir::ModuleOp moduleOp, mlir::SymbolTableCollection &tables, const llzk::Field &field, UninitializedBehavior uninitializedBehavior, std::mt19937_64 rng)
Build a driver for one parsed module and validated field.
void setOutputScope(OutputScope newOutputScope)
Select which witness JSON scope this interpreter emits.
llvm::Expected< llvm::json::Value > runMainFromJSON(const llvm::json::Value &input)
Execute the main compute() function using JSON inputs.
std::unique_ptr<::mlir::Pass > createInlineIncludesPass()
std::unique_ptr<::mlir::Pass > createFlatteningPass()
llvm::Expected< llvm::json::Value > runWithExecutionEngine(ModuleOp moduleOp, SymbolTableCollection &tables, const Field &field, const llvm::json::Value &input, const WitgenOptions &options)
Execute witness generation through MLIR lowering and the LLVM execution engine.
llvm::Expected< llvm::json::Object > buildInputsJSONObject(ArrayRef< InputBinding > bindings, ArrayRef< WitnessVal > values, SymbolTableCollection &tables, Operation *origin)
Serialize named input values into a JSON object.
Definition JSON.cpp:401
llvm::SmallVector< InputBinding > collectInputBindings(function::FuncDefOp computeFunc)
Collect stable JSON bindings for the main compute inputs.
std::mt19937_64 makeDefaultValueRng(const WitgenOptions &options)
Seed an RNG for random/default witness value materialization.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
llvm::Expected< llvm::json::Value > serializeJSONValue(const WitnessVal &value, Type type, SymbolTableCollection &tables, Operation *origin, SerializationMode mode)
Serialize a supported LLZK runtime value into JSON.
Definition JSON.cpp:300
llvm::json::Value buildSignalsJSONObject(ArrayRef< OutputBinding > bindings, ArrayRef< llvm::json::Value > serializedLeaves)
Assemble a nested JSON object from selected witness leaves.
llvm::Expected< WitnessVal > extractValueAtPath(const WitnessVal &root, Type rootType, ArrayRef< std::string > path, SymbolTableCollection &tables, Operation *origin)
Extract one nested runtime leaf by path.
Definition JSON.cpp:422
void addWitgenPreparePipeline(OpPassManager &pm, const WitgenOptions &)
llvm::Expected< llvm::json::Value > runWitgen(ModuleOp moduleOp, const llvm::json::Value &input, const WitgenOptions &options)
Run include preprocessing, field validation, and backend execution.
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
Definition ValueModel.h:55
llvm::Expected< WitnessVal > parseJSONValue(const llvm::json::Value *json, Type type, const Field &field, Operation *origin)
Parse a supported LLZK input type from JSON.
Definition JSON.cpp:263
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
Definition Errors.h:18
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
llvm::SmallSet< FieldRef, 2 > FieldSet
Type alias for a set of Fields.
Definition Field.h:159
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Definition Field.cpp:264
Describe one selected witness output leaf.
Configure one llzk-witgen execution.
UninitializedBehavior uninitializedBehavior