26#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
27#include <mlir/Conversion/ArithToLLVM/ArithToLLVM.h>
28#include <mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h>
29#include <mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h>
30#include <mlir/Conversion/IndexToLLVM/IndexToLLVM.h>
31#include <mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h>
32#include <mlir/Conversion/Passes.h>
33#include <mlir/Conversion/UBToLLVM/UBToLLVM.h>
34#include <mlir/Dialect/MemRef/Transforms/Passes.h>
35#include <mlir/Dialect/Utils/IndexingUtils.h>
36#include <mlir/ExecutionEngine/CRunnerUtils.h>
37#include <mlir/ExecutionEngine/ExecutionEngine.h>
38#include <mlir/Pass/PassManager.h>
39#include <mlir/Target/LLVMIR/Dialect/All.h>
40#include <mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h>
41#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
42#include <mlir/Transforms/Passes.h>
44#include <llvm/ADT/APInt.h>
45#include <llvm/Support/Endian.h>
46#include <llvm/Support/MathExtras.h>
47#include <llvm/Support/TargetSelect.h>
48#include <llvm/Support/raw_ostream.h>
// Bit width of one felt element; createBufferPack sets this from field.bitWidth().
61 unsigned feltBitWidth = 0;
// Static extent of every dimension, as returned by getBoundaryShape ({1} for a scalar felt).
63 std::vector<int64_t> shape;
// One stride per dimension, derived from `shape` by computeStaticStrides.
64 std::vector<int64_t> strides;
// Raw descriptor bytes written by buildDescriptor: two pointers, an int64 offset, then the
// sizes and the strides.
65 std::vector<uint8_t> descriptor;
// Backing bytes for the element data; both pointers stored in `descriptor` point here.
66 std::vector<uint8_t> storage;
/// Return the number of whole bytes needed to store a value of `bitWidth` bits.
///
/// The result is `bitWidth / 8` rounded up. It is computed as quotient plus a
/// remainder correction instead of `(bitWidth + 7) / 8`: the latter wraps around
/// in unsigned arithmetic once `bitWidth` is within 7 of UINT_MAX and would then
/// report a byte count that is far too small.
///
/// @param bitWidth width of the value in bits; 0 yields 0.
/// @return ceil(bitWidth / 8) as a size_t.
static size_t getElementBytes(unsigned bitWidth) {
  const unsigned wholeBytes = bitWidth / 8U;
  const unsigned partialByte = (bitWidth % 8U) != 0U ? 1U : 0U;
  // Widen before adding so the sum cannot wrap even on the largest inputs.
  return static_cast<size_t>(wholeBytes) + partialByte;
}
/// Determine the static shape used to marshal `type` across the main boundary.
///
/// A felt is treated as a one-element shape. An array type is accepted only when its
/// element type is a felt and its shape is static; its dimensions are then returned
/// unchanged. The three string literals below are the messages for the rejection cases
/// (non-felt element type, non-static shape, any other type).
///
/// NOTE(review): the statements that turn those messages into errors are not visible in
/// this excerpt; presumably they go through makeError -- confirm against the full file.
73static llvm::Expected<std::vector<int64_t>> getBoundaryShape(Type type) {
// Scalar felt: represented as a single-element, rank-1 shape.
74 if (isa<felt::FeltType>(type)) {
75 return std::vector<int64_t> {1};
// Array: only felt elements with a fully static shape are supported.
77 if (
auto arrayType = dyn_cast<array::ArrayType>(type)) {
78 if (!isa<felt::FeltType>(arrayType.getElementType())) {
80 "execution-engine backend only supports arrays of felt values at the main boundary"
83 if (!arrayType.hasStaticShape()) {
85 "execution-engine backend only supports statically shaped arrays at the main boundary"
// Copy the array's dimensions into an owned vector.
88 return std::vector<int64_t>(arrayType.getShape().begin(), arrayType.getShape().end());
91 "execution-engine backend only supports felt and array<...xfelt> main boundaries"
/// Compute one stride per dimension for a static `shape`.
///
/// Every dimension is checked first and a failed check is returned as the error; the
/// strides themselves come from mlir::computeStrides and are copied into a std::vector.
///
/// NOTE(review): the check that produces `checkedDim` is elided from this excerpt;
/// presumably it is checkedShapeDimToSize -- confirm against the full file.
96static llvm::Expected<std::vector<int64_t>> computeStaticStrides(ArrayRef<int64_t> shape) {
// Validate each dimension before computing any stride.
97 for (int64_t dim : shape) {
100 return checkedDim.takeError();
104 auto strides = mlir::computeStrides(shape);
105 return std::vector<int64_t>(strides.begin(), strides.end());
/// Serialize the descriptor for `buffer` into `buffer.descriptor`.
///
/// Byte layout written here, in order:
///   1. pointer to `buffer.storage` data
///   2. the same pointer again
///   3. an int64 offset of 0
///   4. every entry of `buffer.shape` as int64
///   5. every entry of `buffer.strides` as int64
///
/// The total size is computed with DynamicAPInt arithmetic and converted through a checked
/// conversion before the vector is resized; any conversion failure is returned as the error.
///
/// `buffer.storage` must already have its final size, because its data pointer is captured
/// here by value.
///
/// NOTE(review): this looks like the MLIR strided memref descriptor layout (allocated
/// pointer, aligned pointer, offset, sizes, strides) -- confirm. The statements that compute
/// `rank` and `checkedTotalSize` are elided from this excerpt.
109static llvm::Error buildDescriptor(BufferPack &buffer) {
112 return rank.takeError();
// Fixed part: the two pointer slots.
114 auto descriptorSize = llvm::DynamicAPInt(
sizeof(
void *)) * 2;
// Variable part: one int64 for the offset plus `rank` sizes and `rank` strides.
115 auto shapeAndStrideCount = llvm::DynamicAPInt(1) + *rank + *rank;
116 auto dynamicPart = llvm::DynamicAPInt(
sizeof(int64_t)) * shapeAndStrideCount;
117 auto totalSize = descriptorSize + dynamicPart;
118 auto checkedTotalSize =
120 if (!checkedTotalSize) {
121 return checkedTotalSize.takeError();
123 buffer.descriptor.resize(*checkedTotalSize);
// `cursor` walks the descriptor bytes; memcpy is used so no alignment is assumed.
124 uint8_t *cursor = buffer.descriptor.data();
125 uint8_t *base = buffer.storage.data();
// First pointer slot.
126 std::memcpy(cursor,
static_cast<const void *
>(&base),
sizeof(
void *));
127 cursor +=
sizeof(
void *);
// Second pointer slot: same address as the first.
128 std::memcpy(cursor,
static_cast<const void *
>(&base),
sizeof(
void *));
129 cursor +=
sizeof(
void *);
// Offset slot: always zero.
130 const int64_t offset = 0;
131 std::memcpy(cursor, &offset,
sizeof(int64_t));
132 cursor +=
sizeof(int64_t);
// Sizes, one int64 per dimension.
133 for (int64_t size : buffer.shape) {
134 std::memcpy(cursor, &size,
sizeof(int64_t));
135 cursor +=
sizeof(int64_t);
// Strides, one int64 per dimension.
137 for (int64_t stride : buffer.strides) {
138 std::memcpy(cursor, &stride,
sizeof(int64_t));
139 cursor +=
sizeof(int64_t);
141 return llvm::Error::success();
/// Build a fully initialized BufferPack for one boundary value of type `type`.
///
/// Steps, in order: resolve the boundary shape, record the felt bit width and per-element
/// byte size from `field`, compute strides, size the storage as
/// elementCount * elemBytes (rejecting size_t overflow), then write the descriptor.
/// Any failing step is returned as the error.
///
/// @param type  the type being marshalled; recorded as `originalType`.
/// @param field supplies the felt bit width through bitWidth().
///
/// NOTE(review): the `BufferPack buffer;` declaration, the guards around each takeError(),
/// the computation of `elementCount`, and the final return are elided from this excerpt.
145static llvm::Expected<BufferPack> createBufferPack(Type type,
const Field &field) {
146 auto shape = getBoundaryShape(type);
148 return shape.takeError();
151 buffer.originalType = type;
152 buffer.feltBitWidth = field.bitWidth();
153 buffer.elemBytes = getElementBytes(buffer.feltBitWidth);
154 buffer.shape = std::move(*shape);
155 auto strides = computeStaticStrides(buffer.shape);
157 return strides.takeError();
159 buffer.strides = std::move(*strides);
162 return elementCount.takeError();
// SaturatingMultiply reports overflow through the flag rather than wrapping.
164 bool overflow =
false;
165 size_t storageBytes = llvm::SaturatingMultiply(*elementCount, buffer.elemBytes, &overflow);
167 return makeError(
"execution-engine buffer storage would overflow size_t");
// Storage must reach its final size before buildDescriptor captures its data pointer.
169 buffer.storage.resize(storageBytes);
170 if (
auto error = buildDescriptor(buffer)) {
171 return std::move(error);
// Write one felt `value` into `buffer.storage` at element index `flatIndex`.
//
// The byte offset is flatIndex * buffer.elemBytes, with size_t overflow reported as an
// error. The value is converted to an APInt of buffer.feltBitWidth bits and stored with
// llvm::StoreIntToMemory using `*elemBytesU` bytes.
//
// NOTE(review): the return-type line, the overflow guard, and the computation of
// `elemBytesU` are elided from this excerpt. No bounds check of `flatIndex` against the
// storage size is visible here -- confirm callers guarantee it.
178storeElement(BufferPack &buffer,
size_t flatIndex,
const llvm::DynamicAPInt &value) {
179 bool overflow =
false;
180 size_t byteOffset = llvm::SaturatingMultiply(flatIndex, buffer.elemBytes, &overflow);
182 return makeError(
"execution-engine buffer store offset would overflow size_t");
186 return elemBytesU.takeError();
188 llvm::APInt raw =
toAPInt(value, buffer.feltBitWidth);
189 llvm::StoreIntToMemory(raw, buffer.storage.data() + byteOffset, *elemBytesU);
190 return llvm::Error::success();
/// Read the felt stored at element index `flatIndex` of `buffer.storage`.
///
/// Mirrors storeElement: the byte offset is flatIndex * buffer.elemBytes (size_t overflow
/// is an error), and `*elemBytesU` bytes are loaded with llvm::LoadIntFromMemory into an
/// APInt of buffer.feltBitWidth bits.
///
/// NOTE(review): the overflow guard, the computation of `elemBytesU`, and the final
/// conversion/return of the loaded value are elided from this excerpt.
194static llvm::Expected<llvm::DynamicAPInt> loadElement(
const BufferPack &buffer,
size_t flatIndex) {
195 bool overflow =
false;
196 size_t byteOffset = llvm::SaturatingMultiply(flatIndex, buffer.elemBytes, &overflow);
198 return makeError(
"execution-engine buffer load offset would overflow size_t");
202 return elemBytesU.takeError();
// Start from a zero value of the felt's width; the load fills it in.
204 llvm::APInt raw(buffer.feltBitWidth, 0);
205 llvm::LoadIntFromMemory(raw, buffer.storage.data() + byteOffset, *elemBytesU);
/// Copy a parsed input `value` into the storage of `buffer`.
///
/// When the buffer's original type is a felt, `value` must be a felt and is stored at
/// index 0. Otherwise `value` must be an array whose element count equals the buffer's
/// element count; each element must itself be a felt and is stored at its position in
/// the `elements` list.
///
/// Returns the first error from asFelt/asArray/storeElement, or an
/// "input array element count mismatch" error.
///
/// NOTE(review): the computation of `elementCount` and the guards around each takeError()
/// are elided from this excerpt.
210static llvm::Error fillInputBuffer(BufferPack &buffer,
const WitnessVal &value) {
// Scalar case: a single felt at element index 0.
211 if (isa<felt::FeltType>(buffer.originalType)) {
212 auto feltValue =
asFelt(value);
214 return feltValue.takeError();
216 return storeElement(buffer, 0, *feltValue);
// Array case.
219 auto arrayValue =
asArray(value);
221 return arrayValue.takeError();
225 return elementCount.takeError();
// The flat element list must match the buffer's element count exactly.
227 if ((*arrayValue)->elements.size() != *elementCount) {
228 return makeError(
"input array element count mismatch");
230 for (
size_t i = 0; i < (*arrayValue)->elements.size(); ++i) {
231 auto feltValue =
asFelt((*arrayValue)->elements[i]);
233 return feltValue.takeError();
235 if (
auto err = storeElement(buffer, i, *feltValue)) {
239 return llvm::Error::success();
/// Load the element at `flatIndex` and return it as a JSON string value.
///
/// The felt is rendered by streaming the DynamicAPInt into a std::string, so the JSON
/// value is a string rather than a JSON number. Load errors are propagated.
243static llvm::Expected<llvm::json::Value>
244feltElementToJSON(
const BufferPack &buffer,
size_t flatIndex) {
245 auto element = loadElement(buffer, flatIndex);
247 return element.takeError();
249 std::string rendered;
// The temporary stream writes straight into `rendered`.
250 llvm::raw_string_ostream(rendered) << *element;
251 return llvm::json::Value(rendered);
/// Recursively render dimension `dimIndex` of `buffer` as a nested JSON array.
///
/// @param dimIndex   index into buffer.shape of the dimension being rendered.
/// @param flatOffset flat element index where this sub-array starts.
///
/// For the last dimension, each entry is the felt at flatOffset + i. For outer dimensions,
/// entry i is the recursive rendering of the next dimension starting at
/// i * subArraySize + flatOffset, where `subArraySize` is derived from the remaining
/// dimensions (buffer.shape with the first dimIndex + 1 entries dropped). All index
/// arithmetic uses saturating helpers and reports overflow as an error.
///
/// NOTE(review): the computations of `dimSize` and `subArraySize` are elided from this
/// excerpt; `subArraySize` presumably comes from getStaticShapeElementCount -- confirm.
255static llvm::Expected<llvm::json::Value>
256bufferToJSONArray(
const BufferPack &buffer,
size_t dimIndex,
size_t flatOffset) {
// Guard the `dimIndex + 1` computations below against wrapping.
257 if (dimIndex == SIZE_MAX) {
258 return makeError(
"execution-engine JSON output would overflow size_t");
262 return dimSize.takeError();
// Innermost dimension: emit the felt leaves directly.
264 if (dimIndex + 1 == buffer.shape.size()) {
265 llvm::json::Array result;
266 for (
size_t i = 0; i < *dimSize; ++i) {
267 bool overflow =
false;
268 size_t elementOffset = llvm::SaturatingAdd(i, flatOffset, &overflow);
270 return makeError(
"execution-engine JSON output would overflow size_t");
272 auto element = feltElementToJSON(buffer, elementOffset);
274 return element.takeError();
276 result.push_back(*element);
278 return llvm::json::Value(std::move(result));
282 llvm::ArrayRef<int64_t>(buffer.shape).drop_front(dimIndex + 1),
"execution-engine JSON output"
285 return subArraySize.takeError();
// Outer dimension: recurse into each sub-array.
288 llvm::json::Array result;
289 for (
size_t i = 0; i < *dimSize; ++i) {
290 bool overflow =
false;
// nextOffset = i * subArraySize + flatOffset, with overflow detection.
291 size_t nextOffset = llvm::SaturatingMultiplyAdd(i, *subArraySize, flatOffset, &overflow);
293 return makeError(
"execution-engine JSON output would overflow size_t");
295 auto subArray = bufferToJSONArray(buffer, dimIndex + 1, nextOffset);
297 return subArray.takeError();
299 result.push_back(*subArray);
301 return llvm::json::Value(std::move(result));
/// Render the contents of `buffer` as JSON.
///
/// A buffer whose original type is a felt becomes the single element at index 0;
/// anything else is rendered as a nested array starting at dimension 0, offset 0.
305static llvm::Expected<llvm::json::Value> bufferToJSON(
const BufferPack &buffer) {
306 if (isa<felt::FeltType>(buffer.originalType)) {
307 return feltElementToJSON(buffer, 0);
309 return bufferToJSONArray(buffer, 0, 0);
/// Clone the input module and run a lowering pipeline over the clone.
///
/// The original module is left untouched: all passes run on `cloned`. A pipeline failure
/// is reported as "failed to lower LLZK compute IR to execution-engine core dialects".
///
/// NOTE(review): the parameter list, the passes added to `pm`, and the success return are
/// elided from this excerpt. The only visible call site passes
/// (moduleOp, options.outputScope, options) -- confirm against the full file.
313static llvm::Expected<OwningOpRef<ModuleOp>> buildExecutionEngineModule(
316 OwningOpRef<ModuleOp> cloned = cast<ModuleOp>(moduleOp->clone());
317 PassManager pm(cloned->getContext());
320 if (failed(pm.run(*cloned))) {
321 return makeError(
"failed to lower LLZK compute IR to execution-engine core dialects");
/// Lower `moduleOp` in place to the LLVM dialect.
///
/// Pipeline order: canonicalize + CSE, expand strided metadata, lower affine, SCF to CF,
/// canonicalize + CSE again, convert to LLVM, then reconcile unrealized casts.
///
/// @return success, or an error if the pass pipeline fails.
327static llvm::Error finalizeExecutionEngineModule(ModuleOp moduleOp) {
328 PassManager pm(moduleOp.getContext());
// Initial cleanup.
329 pm.addPass(mlir::createCanonicalizerPass());
330 pm.addPass(mlir::createCSEPass());
// Structural lowerings ahead of the LLVM conversion.
331 pm.addPass(mlir::memref::createExpandStridedMetadataPass());
332 pm.addPass(mlir::createLowerAffinePass());
333 pm.addPass(mlir::createConvertSCFToCFPass());
// Second cleanup after the lowerings above.
334 pm.addPass(mlir::createCanonicalizerPass());
335 pm.addPass(mlir::createCSEPass());
// Final conversion to the LLVM dialect, then cast reconciliation.
336 pm.addPass(mlir::createConvertToLLVMPass());
337 pm.addPass(mlir::createReconcileUnrealizedCastsPass());
338 if (failed(pm.run(moduleOp))) {
339 return makeError(
"failed to lower execution-engine module to LLVM dialect");
341 return llvm::Error::success();
/// Print `title`, a colon and newline, the module, and a trailing newline to llvm::errs().
///
/// NOTE(review): the use of `enabled` is elided from this excerpt; presumably the function
/// returns early when it is false -- confirm against the full file.
345static void maybeDumpModule(ModuleOp moduleOp,
bool enabled, llvm::StringRef title) {
349 llvm::errs() << title <<
":\n";
350 moduleOp.print(llvm::errs());
351 llvm::errs() <<
'\n';
// Execute witness generation through MLIR lowering and the LLVM execution engine.
//
// Visible phases, in order:
//   1. Check that the main struct lookup succeeded and that it has @compute.
//   2. Parse `input` (JSON object keyed by function.arg_name, or positional JSON array)
//      into one WitnessVal per compute argument.
//   3. Create and fill one BufferPack per input, and create one per selected output.
//   4. Lower a module through buildExecutionEngineModule and
//      finalizeExecutionEngineModule, optionally dumping it after each stage.
//   5. Create the execution engine, register memrefCopy, and invoke
//      "_mlir_ciface___llzk_witgen_main" with input descriptors followed by output
//      descriptors.
//   6. Serialize the output buffers and assemble the result JSON object.
//
// Every failure is returned as an llvm::Error.
//
// NOTE(review): the first line of the signature and several statements (main-struct
// lookup, compute-function lookup, branch selection on jsonObject/jsonArray, the output
// binding call, and parts of the result assembly) are elided from this excerpt.
358 ModuleOp moduleOp, SymbolTableCollection &tables,
const Field &field,
// The main struct lookup must both succeed and yield a definition.
362 if (failed(mainDef) || !mainDef.value()) {
363 return makeError(
"module is missing a concrete llzk.main struct");
367 return makeError(
"main struct is missing @compute");
// Parse the JSON input into one value per compute argument.
378 auto parsedArgs = [&]() -> llvm::Expected<llvm::SmallVector<WitnessVal>> {
379 llvm::SmallVector<WitnessVal> args;
380 const auto *jsonObject = input.getAsObject();
381 const auto *jsonArray = input.getAsArray();
382 if (!jsonObject && !jsonArray) {
383 return makeError(
"inputs JSON must be either an object or an array");
// Named inputs: each argument is looked up in the object by its function.arg_name.
386 for (
unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
387 std::optional<StringAttr> argName = computeFunc.
getArgNameAttr(i);
389 return makeError(
"JSON object input requires function.arg_name on every main argument");
391 const llvm::json::Value *value = jsonObject->get(argName->getValue());
393 return makeError(llvm::Twine(
"missing JSON input field: ") + argName->getValue());
399 return parsed.takeError();
401 args.push_back(*parsed);
// Positional inputs: the array must supply exactly one entry per compute argument.
405 if (jsonArray->size() != computeFunc.getNumArguments()) {
406 return makeError(
"JSON positional input length does not match main compute arity");
408 for (
unsigned i = 0; i < computeFunc.getNumArguments(); ++i) {
410 &(*jsonArray)[i], computeFunc.
getArgumentTypes()[i], field, computeFunc.getOperation()
413 return parsed.takeError();
415 args.push_back(*parsed);
420 return parsedArgs.takeError();
425 mainDef->get(), tables, computeFunc.getOperation(), options.
outputScope
// `outputs` are the bindings selected for options.outputScope by the call above
// (presumably collectOutputBindings; the call line is elided -- confirm).
427 if (failed(outputs)) {
428 return makeError(
"failed to select witness outputs for execution-engine mode");
// One buffer per compute argument, filled from the parsed input values.
431 llvm::SmallVector<BufferPack> inputBuffers;
432 for (
auto [argType, parsed] : llvm::zip(computeFunc.
getArgumentTypes(), *parsedArgs)) {
433 auto buffer = createBufferPack(argType, field);
435 return buffer.takeError();
437 if (
auto err = fillInputBuffer(*buffer, parsed)) {
438 return std::move(err);
440 inputBuffers.push_back(std::move(*buffer));
// One buffer per selected output; these are created here and not filled on the host.
443 llvm::SmallVector<BufferPack> outputBuffers;
445 auto buffer = createBufferPack(output.type, field);
447 return buffer.takeError();
449 outputBuffers.push_back(std::move(*buffer));
// Lower the module to the core dialects.
452 auto loweredModule = buildExecutionEngineModule(moduleOp, options.
outputScope, options);
453 if (!loweredModule) {
454 return loweredModule.takeError();
// Register the ConvertToLLVM interfaces used by the convert-to-LLVM pass.
457 DialectRegistry registry;
458 mlir::arith::registerConvertArithToLLVMInterface(registry);
459 mlir::cf::registerConvertControlFlowToLLVMInterface(registry);
460 mlir::registerConvertFuncToLLVMInterface(registry);
461 mlir::index::registerConvertIndexToLLVMInterface(registry);
462 mlir::registerConvertMemRefToLLVMInterface(registry);
463 mlir::ub::registerConvertUBToLLVMInterface(registry);
464 (*loweredModule)->getContext()->appendDialectRegistry(registry);
465 (*loweredModule)->getContext()->loadAllAvailableDialects();
// Optional dump of the core-dialect module, then lowering to the LLVM dialect.
467 maybeDumpModule(**loweredModule, options.
dumpJITCore,
"llzk-witgen JIT core");
468 if (
auto err = finalizeExecutionEngineModule(**loweredModule)) {
469 return std::move(err);
471 maybeDumpModule(**loweredModule, options.
dumpJITLLVM,
"llzk-witgen JIT LLVM");
// Native target setup and MLIR-to-LLVM-IR translation registration.
473 llvm::InitializeNativeTarget();
474 llvm::InitializeNativeTargetAsmPrinter();
475 mlir::registerBuiltinDialectTranslation(*(*loweredModule)->getContext());
476 mlir::registerLLVMDialectTranslation(*(*loweredModule)->getContext());
478 auto maybeEngine = mlir::ExecutionEngine::create(loweredModule->get());
480 return maybeEngine.takeError();
// Expose the host memrefCopy function to the JIT under the symbol name "memrefCopy".
482 (*maybeEngine)->registerSymbols([](llvm::orc::MangleAndInterner interner) {
483 llvm::orc::SymbolMap symbolMap;
484 symbolMap[interner(
"memrefCopy")] = {
485 llvm::orc::ExecutorAddr::fromPtr(&memrefCopy), llvm::JITSymbolFlags::Exported
// Argument order: all input descriptors first, then all output descriptors.
490 llvm::SmallVector<void *> descriptorPtrs;
491 descriptorPtrs.reserve(inputBuffers.size() + outputBuffers.size());
492 for (BufferPack &buffer : inputBuffers) {
493 descriptorPtrs.push_back(buffer.descriptor.data());
495 for (BufferPack &buffer : outputBuffers) {
496 descriptorPtrs.push_back(buffer.descriptor.data());
// invokePacked receives the address of each argument slot, so pass pointers to the
// descriptor pointers; descriptorPtrs must stay alive and unmodified until the call returns.
499 llvm::SmallVector<void *> packedArgs;
500 packedArgs.reserve(descriptorPtrs.size());
501 for (
void *&descriptorPtr : descriptorPtrs) {
502 packedArgs.push_back(
static_cast<void *
>(&descriptorPtr));
505 if (
auto err = (*maybeEngine)->invokePacked(
"_mlir_ciface___llzk_witgen_main", packedArgs)) {
506 return std::move(err);
// Read the output buffers back as JSON after the call.
509 llvm::SmallVector<llvm::json::Value> serializedOutputs;
510 serializedOutputs.reserve(outputBuffers.size());
511 for (
const BufferPack &buffer : outputBuffers) {
512 auto serialized = bufferToJSON(buffer);
514 return serialized.takeError();
516 serializedOutputs.push_back(*serialized);
526 return inputsJSON.takeError();
// Assemble the result object; only the "inputs" entry is visible in this excerpt.
528 llvm::json::Object result;
529 result[
"inputs"] = llvm::json::Value(std::move(*inputsJSON));
531 return llvm::json::Value(std::move(result));
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
::std::optional<::mlir::StringAttr > getArgNameAttr(unsigned index)
Return the function.arg_name attribute for the argument at the given index.
Drive witness generation for the concrete llzk.main instance.
void setOutputScope(OutputScope newOutputScope)
Select which witness JSON scope this interpreter emits.
llvm::Expected< llvm::json::Value > runMainFromJSON(const llvm::json::Value &input)
Execute the main compute() function using JSON inputs.
llvm::Expected< llvm::json::Value > runWithExecutionEngine(ModuleOp moduleOp, SymbolTableCollection &tables, const Field &field, const llvm::json::Value &input, const WitgenOptions &options)
Execute witness generation through MLIR lowering and the LLVM execution engine.
llvm::Expected< T > checkedCast(U u)
llvm::Expected< llvm::json::Object > buildInputsJSONObject(ArrayRef< InputBinding > bindings, ArrayRef< WitnessVal > values, SymbolTableCollection &tables, Operation *origin)
Serialize named input values into a JSON object.
llvm::SmallVector< InputBinding > collectInputBindings(function::FuncDefOp computeFunc)
Collect stable JSON bindings for the main compute inputs.
std::mt19937_64 makeDefaultValueRng(const WitgenOptions &options)
Seed an RNG for random/default witness value materialization.
OutputScope
Select the JSON scope emitted by llzk-witgen.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
std::unique_ptr< Pass > createCreateWitgenEntryPass(bool emitFullWitness)
Create the pass that synthesizes the stable llzk-witgen JIT entry wrapper.
llvm::json::Value buildSignalsJSONObject(ArrayRef< OutputBinding > bindings, ArrayRef< llvm::json::Value > serializedLeaves)
Assemble a nested JSON object from selected witness leaves.
std::unique_ptr< Pass > createLowerComputeToCorePass(const WitgenOptions &options)
Create the pass that lowers supported LLZK compute IR into core MLIR dialects suitable for LLVM lowering.
llvm::Expected< size_t > getStaticShapeElementCount(llvm::ArrayRef< int64_t > shape, llvm::StringRef context)
Return the static element count for one shape, rejecting dynamic sizes.
llvm::Expected< size_t > checkedDynamicAPIntToSize(const llvm::DynamicAPInt &value, llvm::StringRef context)
Convert a DynamicAPInt into size_t after validating its range.
llvm::Expected< size_t > checkedShapeDimToSize(int64_t dim, llvm::StringRef context)
Convert one static dimension to size_t, rejecting dynamic or invalid sizes.
llvm::Expected< WitnessVal > parseJSONValue(const llvm::json::Value *json, Type type, const Field &field, Operation *origin)
Parse a supported LLZK input type from JSON.
std::variant< std::monostate, bool, int64_t, llvm::DynamicAPInt, ArrayValueRef, PodValueRef, StructValueRef > WitnessVal
Runtime value representation used by the tool-local interpreter.
llvm::Expected< llvm::DynamicAPInt > asFelt(const WitnessVal &value)
Require a felt value from the runtime variant.
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
llvm::Expected< ArrayValueRef > asArray(const WitnessVal &value)
Require an array value from the runtime variant.
DynamicAPInt toDynamicAPInt(StringRef str)
APInt toAPInt(const DynamicAPInt &val, unsigned bitWidth)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
Describe one selected witness output leaf.
Configure one llzk-witgen execution.
UninitializedBehavior uninitializedBehavior