20#include <mlir/IR/Operation.h>
22#include <llvm/ADT/SmallString.h>
23#include <llvm/ADT/TypeSwitch.h>
24#include <llvm/Support/FormatVariadic.h>
25#include <llvm/Support/MathExtras.h>
26#include <llvm/Support/raw_ostream.h>
// Render a JSON value as text for "value mismatch" messages. The "{0:2}"
// format option pretty-prints the value with an indentation width of 2.
33static std::string renderJSON(
const llvm::json::Value &value) {
34 return llvm::formatv(
"{0:2}", value).str();
// Name the kind of a JSON value for "type mismatch" messages.
// Kinds are tested in order: null, boolean, number, string, array, object.
// Each getAs*() accessor returns an optional or a pointer, so its truthiness
// means "the value is of this kind". The name returned for each kind is set
// in the corresponding branch body.
37static llvm::StringRef jsonKind(
const llvm::json::Value &value) {
38 if (value.getAsNull()) {
41 if (value.getAsBoolean().has_value()) {
44 if (value.getAsNumber().has_value()) {
47 if (value.getAsString()) {
50 if (value.getAsArray()) {
53 if (value.getAsObject()) {
// Build the mismatch path for object member `key` nested under `path`.
// The result starts as a copy of `path`, and `key` is appended to it.
// NOTE(review): document the separator used between `path` and `key`
// (e.g. "a.b") — confirm against the append statements.
59static std::string appendObjectPath(llvm::StringRef path, llvm::StringRef key) {
60 llvm::SmallString<64> out(path);
63 return std::string(out);
// Build the mismatch path for array element `index` nested under `path` by
// appending "[index]", e.g. ("a", 3) -> "a[3]".
66static std::string appendIndexPath(llvm::StringRef path,
size_t index) {
67 llvm::SmallString<64> out(path);
// The stream writes directly into `out`.
68 llvm::raw_svector_ostream os(out);
69 os <<
'[' << index <<
']';
70 return std::string(out);
// Record one mismatch at `path`, described by `message`, in `out`.
// NOTE(review): presumably builds a JSONMismatch from (path, message) and
// appends it — confirm against the body.
73static void pushMismatch(
74 llvm::SmallVectorImpl<JSONMismatch> &out, llvm::StringRef path,
const llvm::Twine &message
// Compare two JSON objects key by key:
//  - keys present in both objects are compared recursively via diffJSON;
//  - keys in `expected` but not in `actual` are reported as "missing key";
//  - keys in `actual` but not in `expected` are reported as "unexpected key".
// NOTE(review): llvm::json::Object is hash-based, so mismatches are reported
// in its iteration order, not in document order.
79static void diffObjects(
80 const llvm::json::Object &expected,
const llvm::json::Object &actual,
81 llvm::SmallVectorImpl<JSONMismatch> &out, llvm::StringRef path
83 for (
const auto &kv : expected) {
84 if (
const llvm::json::Value *actualValue = actual.get(kv.first)) {
85 diffJSON(kv.second, *actualValue, out, appendObjectPath(path, kv.first));
// The expected key has no counterpart in `actual`.
88 pushMismatch(out, appendObjectPath(path, kv.first),
"missing key");
// Second pass: keys that appear only in `actual`.
90 for (
const auto &kv : actual) {
91 if (!expected.get(kv.first)) {
92 pushMismatch(out, appendObjectPath(path, kv.first),
"unexpected key");
// Compare two JSON arrays.
// A length difference is reported once at `path` as an "array length
// mismatch". The elements at indices present in both arrays are then
// compared recursively. Extra trailing elements on either side are not
// diffed individually.
97static void diffArrays(
98 const llvm::json::Array &expected,
const llvm::json::Array &actual,
99 llvm::SmallVectorImpl<JSONMismatch> &out, llvm::StringRef path
101 if (expected.size() != actual.size()) {
104 llvm::Twine(
"array length mismatch: expected ") + llvm::Twine(expected.size()) +
", got " +
105 llvm::Twine(actual.size())
// Only the common prefix can be compared element-wise.
108 size_t shared = std::min(expected.size(), actual.size());
109 for (
size_t i = 0; i < shared; ++i) {
110 diffJSON(expected[i], actual[i], out, appendIndexPath(path, i));
// Convert a JSON value to int64_t.
// Accepts either a JSON integer or a string parsed as a base-10 integer.
// Any other value produces an "expected integer-compatible JSON value" error.
117static llvm::Expected<int64_t> jsonToInt(
const llvm::json::Value *json) {
118 if (std::optional<int64_t> integer = json->getAsInteger()) {
121 if (std::optional<llvm::StringRef> str = json->getAsString()) {
// StringRef::getAsInteger returns true on *failure*, hence the negation.
123 if (!str->getAsInteger(10, value)) {
127 return makeError(
"expected integer-compatible JSON value");
// Convert a JSON value to a field element.
//  - A string is accepted; the error text below describes it as a decimal string.
//  - A JSON integer is accepted and reduced into `field`.
//  - Any other value is an error.
131static llvm::Expected<llvm::DynamicAPInt>
132jsonToFelt(
const llvm::json::Value *json,
const Field &field) {
133 if (std::optional<llvm::StringRef> str = json->getAsString()) {
136 if (std::optional<int64_t> integer = json->getAsInteger()) {
137 return field.reduce(*integer);
139 return makeError(
"expected felt value as JSON integer or decimal string");
// Parse a (possibly nested) JSON array into a flat ArrayValue of `type`.
// Each call handles one dimension, `dimIndex`, and the JSON array's length
// must equal that static dimension.
//  - At the innermost dimension, each element is parsed with the array's
//    element type.
//  - At outer dimensions, each element is parsed recursively, and the
//    flattened elements of each sub-array are appended in order. The result
//    is a row-major flat layout.
143static llvm::Expected<WitnessVal> parseJSONArray(
144 const llvm::json::Value *json, array::ArrayType type,
const Field &field, Operation *origin,
147 const auto *jsonArray = json->getAsArray();
152 llvm::ArrayRef<int64_t> shape = type.getShape();
153 if (dimIndex >= shape.size()) {
158 return expectedSize.takeError();
160 if (jsonArray->size() != *expectedSize) {
161 return makeError(
"JSON array length does not match LLZK array dimension");
164 auto arrayValue = std::make_shared<ArrayValue>();
165 arrayValue->type = type;
// Innermost dimension: the JSON elements are scalars of the element type.
166 if (dimIndex == shape.size() - 1) {
167 arrayValue->elements.reserve(jsonArray->size());
168 for (
const llvm::json::Value &elem : *jsonArray) {
169 auto parsed =
parseJSONValue(&elem, type.getElementType(), field, origin);
171 return parsed.takeError();
173 arrayValue->elements.push_back(*parsed);
// Outer dimension: parse each sub-array and splice in its flat elements.
// NOTE(review): this reserves only the outer extent, but the loop appends
// every element of every sub-array. Reserving the product of the remaining
// dimensions would avoid regrowth.
178 arrayValue->elements.reserve(jsonArray->size());
179 for (
const llvm::json::Value &elem : *jsonArray) {
180 auto parsed = parseJSONArray(&elem, type, field, origin, dimIndex + 1);
182 return parsed.takeError();
184 auto subArray =
asArray(*parsed);
186 return subArray.takeError();
188 for (
const WitnessVal &subElem : (*subArray)->elements) {
189 arrayValue->elements.push_back(subElem);
// Serialize a field element as a JSON string, rendered through a
// raw_string_ostream. This is the string form that jsonToFelt accepts.
// NOTE(review): presumably decimal, and a string so that values wider than
// 64 bits are not truncated — confirm against the streaming statement.
196static llvm::Expected<llvm::json::Value> feltToJSON(
const llvm::DynamicAPInt &value) {
197 std::string rendered;
198 llvm::raw_string_ostream os(rendered);
200 return llvm::json::Value(os.str());
// Serialize a flat, row-major ArrayValue back into nested JSON arrays that
// follow `type`'s shape.
//  - `dimIndex` is the dimension currently being emitted.
//  - `flatOffset` is the flat index of the first element of the current
//    sub-array.
// All offset arithmetic uses LLVM's saturating helpers with an overflow flag,
// so a size_t overflow is reported as an error instead of silently wrapping.
204static llvm::Expected<llvm::json::Value> serializeJSONArray(
205 const ArrayValueRef &arrayValue, array::ArrayType type, SymbolTableCollection &tables,
206 Operation *origin,
SerializationMode mode,
size_t dimIndex = 0,
size_t flatOffset = 0
208 llvm::json::Array jsonArray;
209 llvm::ArrayRef<int64_t> shape = type.getShape();
212 return dimSize.takeError();
// Innermost dimension: emit the consecutive leaf elements starting at flatOffset.
214 if (dimIndex == shape.size() - 1) {
215 for (
size_t i = 0; i < *dimSize; ++i) {
216 bool overflow =
false;
// NOTE(review): no check that elementOffset < arrayValue->elements.size()
// is visible. Confirm that the flat storage length always equals the
// product of the shape.
217 size_t elementOffset = llvm::SaturatingAdd(flatOffset, i, &overflow);
219 return makeError(
"JSON array output flat index would overflow size_t");
222 arrayValue->elements[elementOffset], type.getElementType(), tables, origin, mode
225 return elem.takeError();
227 jsonArray.push_back(*elem);
229 return llvm::json::Value(std::move(jsonArray));
// Flat elements spanned by one step of this dimension, i.e. the product of
// all inner dimensions.
232 size_t subArraySize = 1;
233 for (
size_t i = dimIndex + 1; i < shape.size(); ++i) {
236 return nextDimSize.takeError();
238 bool overflow =
false;
239 subArraySize = llvm::SaturatingMultiply(subArraySize, *nextDimSize, &overflow);
241 return makeError(
"JSON array output sub-array size would overflow size_t");
245 for (
size_t i = 0; i < *dimSize; ++i) {
246 bool overflow =
false;
// Flat start of sub-array i: i * subArraySize + flatOffset.
247 size_t nextOffset = llvm::SaturatingMultiplyAdd(i, subArraySize, flatOffset, &overflow);
249 return makeError(
"JSON array output flat offset would overflow size_t");
252 serializeJSONArray(arrayValue, type, tables, origin, mode, dimIndex + 1, nextOffset);
254 return subArray.takeError();
256 jsonArray.push_back(*subArray);
258 return llvm::json::Value(std::move(jsonArray));
// Parse a JSON input value for the LLZK type `type`:
//  - felt: via jsonToFelt;
//  - array: via parseJSONArray;
//  - index: via jsonToInt;
//  - i1: from a JSON boolean or an integer-compatible value (nonzero -> true).
// POD and struct inputs, non-i1 integers, and all other types are rejected
// with an error.
262llvm::Expected<WitnessVal>
264 return llvm::TypeSwitch<Type, llvm::Expected<WitnessVal>>(type)
265 .Case([&](
felt::FeltType) -> llvm::Expected<WitnessVal> {
return jsonToFelt(json, field); })
267 return parseJSONArray(json, arrayType, field, origin);
270 return makeError(
"pod JSON inputs are not supported in llzk-witgen v1");
273 return makeError(
"struct JSON inputs are not supported in llzk-witgen v1");
275 .Case([&](IndexType) -> llvm::Expected<WitnessVal> {
276 auto integer = jsonToInt(json);
278 return integer.takeError();
282 .Case([&](IntegerType intType) -> llvm::Expected<WitnessVal> {
283 if (intType.getWidth() == 1) {
284 if (std::optional<bool> boolValue = json->getAsBoolean()) {
// Not a JSON boolean: accept any integer-compatible value (nonzero -> true).
287 auto integer = jsonToInt(json);
289 return integer.takeError();
291 return *integer != 0;
293 return makeError(
"only i1 integer JSON inputs are supported");
294 }).Default([&](Type) -> llvm::Expected<WitnessVal> {
295 return makeError(
"unsupported input type in llzk-witgen");
301 const WitnessVal &value, Type type, SymbolTableCollection &tables, Operation *origin,
// Dispatch on the LLZK type to choose the JSON shape:
//  - felt: string (feltToJSON);
//  - array: nested arrays (serializeJSONArray);
//  - POD: object keyed by record name;
//  - index: JSON number;
//  - i1: JSON boolean;
//  - struct: object of its members, filtered by public-ness and `mode`.
// NOTE(review): document exactly which members each SerializationMode keeps.
304 return llvm::TypeSwitch<Type, llvm::Expected<llvm::json::Value>>(type)
306 auto feltValue =
asFelt(value);
308 return feltValue.takeError();
310 return feltToJSON(*feltValue);
312 .Case([&](
array::ArrayType arrayType) -> llvm::Expected<llvm::json::Value> {
313 auto arrayValue =
asArray(value);
315 return arrayValue.takeError();
317 return serializeJSONArray(*arrayValue, arrayType, tables, origin, mode);
319 .Case([&](
pod::PodType podType) -> llvm::Expected<llvm::json::Value> {
320 auto podValue =
asPod(value);
322 return podValue.takeError();
// Emit every record declared by the POD type. A record missing from the
// runtime value is an error.
324 llvm::json::Object result;
325 for (pod::RecordAttr record : podType.
getRecords()) {
326 auto it = (*podValue)->records.find(record.getName().getValue());
327 if (it == (*podValue)->records.end()) {
328 return makeError(
"missing POD record during JSON serialization");
330 auto serialized =
serializeJSONValue(it->second, record.getType(), tables, origin, mode);
332 return serialized.takeError();
334 result[record.getName().getValue()] = *serialized;
336 return llvm::json::Value(std::move(result));
341 return structValue.takeError();
344 if (failed(defLookup)) {
345 return makeError(
"could not resolve struct type during JSON serialization");
347 llvm::json::Object result;
349 auto it = (*structValue)->members.find(member.getSymName());
350 if (it == (*structValue)->members.end()) {
351 return makeError(
"missing struct member during JSON serialization");
// Member filtering: non-public members and (per the condition below)
// non-struct-typed members are treated specially.
355 if (!member.hasPublicAttr()) {
360 !isa<component::StructType>(member.getType())) {
365 auto serialized =
serializeJSONValue(it->second, member.getType(), tables, origin, mode);
367 return serialized.takeError();
// Serialized members that are not objects, or are empty objects, take the
// branch below.
// NOTE(review): confirm whether that branch drops them from `result`.
370 auto *
object = serialized->getAsObject();
371 if (!
object || object->empty()) {
375 result[member.getSymName()] = *serialized;
377 return llvm::json::Value(std::move(result));
379 .Case([&](IndexType) -> llvm::Expected<llvm::json::Value> {
380 auto indexValue =
asIndex(value);
382 return indexValue.takeError();
384 return llvm::json::Value(*indexValue);
386 .Case([&](IntegerType intType) -> llvm::Expected<llvm::json::Value> {
387 if (intType.getWidth() != 1) {
388 return makeError(
"only i1 integer JSON serialization is supported");
390 auto boolValue =
asBool(value);
392 return boolValue.takeError();
394 return llvm::json::Value(*boolValue);
395 }).Default([&](Type) -> llvm::Expected<llvm::json::Value> {
396 return makeError(
"unsupported output type in llzk-witgen");
402 ArrayRef<InputBinding> bindings, ArrayRef<WitnessVal> values, SymbolTableCollection &tables,
// Serialize each (binding, value) pair and store it under binding.name.
// The lengths are checked first because llvm::zip would otherwise stop
// silently at the shorter list.
405 if (bindings.size() != values.size()) {
406 return makeError(
"input binding count mismatch during witness JSON assembly");
409 llvm::json::Object result;
410 for (
auto [binding, value] : llvm::zip(bindings, values)) {
414 return serialized.takeError();
416 result[binding.name] = *serialized;
423 const WitnessVal &root, Type rootType, ArrayRef<std::string> path,
424 SymbolTableCollection &tables, Operation *origin
// Resolve one path component per call and recurse with the rest of the path:
//  - struct root: find the member named path.front() in the resolved struct
//    definition;
//  - POD root: find the record named path.front().
// A name not declared by the type is "unknown". A declared name with no
// runtime value is "missing". Remaining components on a non-aggregate root
// are an error.
430 if (
auto structType = dyn_cast<component::StructType>(rootType)) {
433 return structValue.takeError();
435 auto defLookup = structType.getDefinition(tables, origin);
436 if (failed(defLookup)) {
437 return makeError(
"could not resolve struct type while extracting witness value");
440 if (member.getSymName() != path.front()) {
443 auto it = (*structValue)->members.find(member.getSymName());
444 if (it == (*structValue)->members.end()) {
445 return makeError(
"missing struct member while extracting witness value");
447 return extractValueAtPath(it->second, member.getType(), path.drop_front(), tables, origin);
449 return makeError(
"unknown struct member while extracting witness value");
452 if (
auto podType = dyn_cast<pod::PodType>(rootType)) {
453 auto podValue =
asPod(root);
455 return podValue.takeError();
457 for (pod::RecordAttr record : podType.getRecords()) {
458 if (record.getName().getValue() != path.front()) {
461 auto it = (*podValue)->records.find(record.getName().getValue());
462 if (it == (*podValue)->records.end()) {
463 return makeError(
"missing POD record while extracting witness value");
465 return extractValueAtPath(it->second, record.getType(), path.drop_front(), tables, origin);
467 return makeError(
"unknown POD record while extracting witness value");
470 return makeError(
"extra witness path components for non-aggregate value");
474 const llvm::json::Value &expected,
const llvm::json::Value &actual,
475 llvm::SmallVectorImpl<JSONMismatch> &out, llvm::StringRef path
// Compare kinds first. A kind difference is recorded as a single
// "type mismatch" entry at `path`.
477 if (expected.kind() != actual.kind()) {
480 llvm::Twine(
"type mismatch: expected ") + jsonKind(expected) +
", got " + jsonKind(actual)
// Past the kind check, the kinds are assumed equal (NOTE(review): relies on
// the mismatch branch above returning). That is why actual.getAsObject() and
// actual.getAsArray() can be dereferenced below.
485 if (
const auto *expectedObject = expected.getAsObject()) {
486 diffObjects(*expectedObject, *actual.getAsObject(), out, path);
489 if (
const auto *expectedArray = expected.getAsArray()) {
490 diffArrays(*expectedArray, *actual.getAsArray(), out, path);
// Scalars: equal values produce no entry. Otherwise a "value mismatch" is
// recorded, with both values pretty-printed.
493 if (expected == actual) {
499 llvm::Twine(
"value mismatch: expected ") + renderJSON(expected) +
", got " +
// Write one line per mismatch, formatted as "<path>: <message>".
506 os << mismatch.path <<
": " << mismatch.message <<
'\n';
Information about the prime finite field used for the interval analysis.
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
llvm::Expected< llvm::json::Object > buildInputsJSONObject(ArrayRef< InputBinding > bindings, ArrayRef< WitnessVal > values, SymbolTableCollection &tables, Operation *origin)
Serialize named input values into a JSON object.
llvm::Expected< PodValueRef > asPod(const WitnessVal &value)
Require a POD value from the runtime variant.
llvm::Expected< llvm::json::Value > serializeJSONValue(const WitnessVal &value, Type type, SymbolTableCollection &tables, Operation *origin, SerializationMode mode)
Serialize a supported LLZK runtime value into JSON.
llvm::Expected< bool > asBool(const WitnessVal &value)
Require a boolean value from the runtime variant.
llvm::Expected< WitnessVal > extractValueAtPath(const WitnessVal &root, Type rootType, ArrayRef< std::string > path, SymbolTableCollection &tables, Operation *origin)
Extract one nested runtime leaf by path.
void diffJSON(const llvm::json::Value &expected, const llvm::json::Value &actual, llvm::SmallVectorImpl< JSONMismatch > &out, llvm::StringRef path)
Compare two JSON values structurally and append any mismatches to out.
SerializationMode
Select how struct values are filtered during JSON serialization.
std::shared_ptr< ArrayValue > ArrayValueRef
Shared runtime storage for LLZK array values.
llvm::Expected< int64_t > asIndex(const WitnessVal &value)
Require an index value from the runtime variant.
bool memberIsSignal(component::StructDefOp owner, component::MemberDefOp member)
Return true iff the member is considered a witness signal.
llvm::Expected< size_t > checkedShapeDimToSize(int64_t dim, llvm::StringRef context)
Convert one static dimension to size_t, rejecting dynamic or invalid sizes.
llvm::Expected< WitnessVal > parseJSONValue(const llvm::json::Value *json, Type type, const Field &field, Operation *origin)
Parse a supported LLZK input type from JSON.
void printJSONMismatches(llvm::raw_ostream &os, llvm::ArrayRef< JSONMismatch > mismatches)
Render one human-readable mismatch report.
std::variant< std::monostate, bool, int64_t, llvm::DynamicAPInt, ArrayValueRef, PodValueRef, StructValueRef > WitnessVal
Runtime value representation used by the tool-local interpreter.
llvm::Expected< llvm::DynamicAPInt > asFelt(const WitnessVal &value)
Require a felt value from the runtime variant.
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
llvm::Expected< StructValueRef > asStruct(const WitnessVal &value)
Require a struct value from the runtime variant.
llvm::Expected< ArrayValueRef > asArray(const WitnessVal &value)
Require an array value from the runtime variant.
DynamicAPInt toDynamicAPInt(StringRef str)
One structured JSON mismatch between expected and actual witgen output.