26#include <mlir/Dialect/Arith/IR/Arith.h>
27#include <mlir/Dialect/SCF/IR/SCF.h>
28#include <mlir/Dialect/Utils/IndexingUtils.h>
29#include <mlir/IR/Operation.h>
31#include <llvm/ADT/STLExtras.h>
32#include <llvm/ADT/SmallVector.h>
33#include <llvm/Support/MathExtras.h>
45static bool usesUnsignedCmp(scf::ForOp forOp) {
46 if (
auto boolAttr = forOp->getAttrOfType<BoolAttr>(
"unsignedCmp")) {
47 return boolAttr.getValue();
49 return forOp->hasAttr(
"unsignedCmp");
54 bool terminated =
false;
55 llvm::SmallVector<WitnessVal> values;
59llvm::Expected<size_t> checkedLinearize(
60 llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> indices, llvm::StringRef context
62 if (shape.size() != indices.size()) {
63 return makeError(
"wrong number of array indices");
65 for (
auto [idx, dim] : llvm::zip_equal(indices, shape)) {
66 if (idx < 0 || dim < 0 || idx >= dim) {
70 auto strides = mlir::computeStrides(shape);
78 ModuleOp module, SymbolTableCollection &symbolTables,
const Field &moduleField,
81 : moduleOp(module), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
87class InvocationInterpreter {
90 InvocationInterpreter(
91 ModuleOp module, SymbolTableCollection &symbolTables,
const Field &moduleField,
94 : moduleOp(module), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
98 llvm::Expected<llvm::SmallVector<WitnessVal>>
100 if (funcOp.isExternal()) {
101 return makeError(
"extern functions are not supported in llzk-witgen");
103 if (!funcOp.
getBody().hasOneBlock()) {
104 return makeError(
"multi-block functions are not supported in llzk-witgen");
106 if (funcOp.getNumArguments() != args.size()) {
107 return makeError(
"wrong number of arguments passed to function");
110 llvm::DenseMap<mlir::Value, WitnessVal> scope;
111 Block &entry = funcOp.
getBody().front();
112 for (
auto [arg, value] : llvm::zip(entry.getArguments(), args)) {
116 auto result = runBlock(entry, scope);
118 return result.takeError();
120 return result->values;
125 SymbolTableCollection &tables;
128 std::mt19937_64 &rng;
131 llvm::Expected<BlockResult>
132 runBlock(Block &block, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
133 for (Operation &op : block) {
134 auto handled = runOperation(op, scope);
136 return handled.takeError();
138 if (handled->terminated) {
142 return BlockResult {};
146 llvm::Expected<BlockResult> runRegion(
147 Region ®ion, ArrayRef<WitnessVal> args, llvm::DenseMap<mlir::Value, WitnessVal> scope
149 if (!region.hasOneBlock()) {
150 return makeError(
"multi-block regions are not supported in llzk-witgen");
152 Block &block = region.front();
153 if (block.getNumArguments() != args.size()) {
154 return makeError(
"region argument count mismatch");
156 for (
auto [arg, value] : llvm::zip(block.getArguments(), args)) {
159 return runBlock(block, scope);
163 llvm::Expected<WitnessVal>
164 lookup(mlir::Value value, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
165 auto it = scope.find(value);
166 if (it == scope.end()) {
167 return makeError(
"failed to find SSA value during interpretation");
173 llvm::Expected<llvm::SmallVector<WitnessVal>>
174 collectOperands(OperandRange operands, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
175 llvm::SmallVector<WitnessVal> values;
176 values.reserve(operands.size());
177 for (mlir::Value operand : operands) {
178 auto value = lookup(operand, scope);
180 return value.takeError();
182 values.push_back(*value);
188 llvm::Expected<BlockResult>
189 runOperation(Operation &op, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
190 if (
auto returnOp = dyn_cast<function::ReturnOp>(op)) {
191 auto values = collectOperands(returnOp.getOperands(), scope);
193 return values.takeError();
195 return BlockResult {
true, std::move(*values)};
197 if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
198 auto values = collectOperands(yieldOp.getOperands(), scope);
200 return values.takeError();
202 return BlockResult {
true, std::move(*values)};
204 if (
auto conditionOp = dyn_cast<scf::ConditionOp>(op)) {
205 auto values = collectOperands(conditionOp.getOperands(), scope);
207 return values.takeError();
209 return BlockResult {
true, std::move(*values)};
212 auto bind = [&](ArrayRef<WitnessVal> results) -> llvm::Expected<BlockResult> {
213 if (results.size() != op.getNumResults()) {
214 return makeError(
"internal result count mismatch");
216 for (
auto [result, value] : llvm::zip(op.getResults(), results)) {
217 scope[result] = value;
219 return BlockResult {};
222 if (
auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
223 Attribute valueAttr = constantOp.getValue();
224 if (
auto integerAttr = dyn_cast<IntegerAttr>(valueAttr)) {
225 if (integerAttr.getType().isInteger(1)) {
226 return bind({
WitnessVal(integerAttr.getValue().getBoolValue())});
228 return bind({
WitnessVal(integerAttr.getValue().getSExtValue())});
230 return makeError(
"unsupported arith.constant value");
233 if (
auto nondetOp = dyn_cast<llzk::NonDetOp>(op)) {
235 nondetOp.getType(), tables, nondetOp.getOperation(), field, uninitializedBehavior, &rng
238 return value.takeError();
240 return bind({*value});
243 if (
auto assertOp = dyn_cast<boolean::AssertOp>(op)) {
244 auto condition = lookup(assertOp.getCondition(), scope);
246 return condition.takeError();
248 auto boolValue =
asBool(*condition);
250 return boolValue.takeError();
253 std::string msg =
"bool.assert failed";
254 if (
auto attr = assertOp.getMsg()) {
259 return BlockResult {};
262 if (
auto andOp = dyn_cast<boolean::AndBoolOp>(op)) {
263 auto lhsValue = lookup(andOp.getLhs(), scope);
264 auto rhsValue = lookup(andOp.getRhs(), scope);
266 return lhsValue.takeError();
269 return rhsValue.takeError();
271 auto lhs =
asBool(*lhsValue);
273 return lhs.takeError();
275 auto rhs =
asBool(*rhsValue);
277 return rhs.takeError();
281 if (
auto orOp = dyn_cast<boolean::OrBoolOp>(op)) {
282 auto lhsValue = lookup(orOp.getLhs(), scope);
283 auto rhsValue = lookup(orOp.getRhs(), scope);
285 return lhsValue.takeError();
288 return rhsValue.takeError();
290 auto lhs =
asBool(*lhsValue);
292 return lhs.takeError();
294 auto rhs =
asBool(*rhsValue);
296 return rhs.takeError();
300 if (
auto xorOp = dyn_cast<boolean::XorBoolOp>(op)) {
301 auto lhsValue = lookup(xorOp.getLhs(), scope);
302 auto rhsValue = lookup(xorOp.getRhs(), scope);
304 return lhsValue.takeError();
307 return rhsValue.takeError();
309 auto lhs =
asBool(*lhsValue);
311 return lhs.takeError();
313 auto rhs =
asBool(*rhsValue);
315 return rhs.takeError();
319 if (
auto notOp = dyn_cast<boolean::NotBoolOp>(op)) {
320 auto operand = lookup(
notOp.getOperand(), scope);
322 return operand.takeError();
324 auto boolValue =
asBool(*operand);
326 return boolValue.takeError();
330 if (
auto cmpOp = dyn_cast<boolean::CmpOp>(op)) {
331 auto lhs = lookup(cmpOp.getLhs(), scope);
332 auto rhs = lookup(cmpOp.getRhs(), scope);
334 return lhs.takeError();
337 return rhs.takeError();
339 auto lhsValue =
asFelt(*lhs);
341 return lhsValue.takeError();
343 auto rhsValue =
asFelt(*rhs);
345 return rhsValue.takeError();
348 switch (cmpOp.getPredicate()) {
350 result = *lhsValue == *rhsValue;
353 result = *lhsValue != *rhsValue;
356 result = *lhsValue < *rhsValue;
359 result = *lhsValue <= *rhsValue;
362 result = *lhsValue > *rhsValue;
365 result = *lhsValue >= *rhsValue;
371 if (
auto feltConst = dyn_cast<felt::FeltConstantOp>(op)) {
372 return bind({
WitnessVal(field.reduce(feltConst.getValue().getValue()))});
375 auto handleBinaryFelt = [&](
auto feltOp,
auto fn) -> llvm::Expected<BlockResult> {
376 auto lhsValue = lookup(feltOp.getLhs(), scope);
377 auto rhsValue = lookup(feltOp.getRhs(), scope);
379 return lhsValue.takeError();
382 return rhsValue.takeError();
384 auto lhs =
asFelt(*lhsValue);
386 return lhs.takeError();
388 auto rhs =
asFelt(*rhsValue);
390 return rhs.takeError();
392 return bind({
WitnessVal(field.reduce(fn(*lhs, *rhs)))});
395 if (
auto addOp = dyn_cast<felt::AddFeltOp>(op)) {
396 return handleBinaryFelt(addOp, [](
const auto &lhs,
const auto &rhs) {
return lhs + rhs; });
398 if (
auto powOp = dyn_cast<felt::PowFeltOp>(op)) {
399 return handleBinaryFelt(powOp, [&](
const auto &lhs,
const auto &rhs) {
400 return modExp(lhs, rhs, field.prime());
403 if (
auto andOp = dyn_cast<felt::AndFeltOp>(op)) {
404 return handleBinaryFelt(andOp, [](
const auto &lhs,
const auto &rhs) {
return lhs & rhs; });
406 if (
auto orOp = dyn_cast<felt::OrFeltOp>(op)) {
407 return handleBinaryFelt(orOp, [](
const auto &lhs,
const auto &rhs) {
return lhs | rhs; });
409 if (
auto xorOp = dyn_cast<felt::XorFeltOp>(op)) {
410 return handleBinaryFelt(xorOp, [](
const auto &lhs,
const auto &rhs) {
return lhs ^ rhs; });
412 if (
auto subOp = dyn_cast<felt::SubFeltOp>(op)) {
413 return handleBinaryFelt(subOp, [](
const auto &lhs,
const auto &rhs) {
return lhs - rhs; });
415 if (
auto mulOp = dyn_cast<felt::MulFeltOp>(op)) {
416 return handleBinaryFelt(mulOp, [](
const auto &lhs,
const auto &rhs) {
return lhs * rhs; });
418 if (
auto divOp = dyn_cast<felt::DivFeltOp>(op)) {
419 return handleBinaryFelt(divOp, [&](
const auto &lhs,
const auto &rhs) {
420 return lhs * field.inv(rhs);
423 if (
auto uintDivOp = dyn_cast<felt::UnsignedIntDivFeltOp>(op)) {
424 auto lhsValue = lookup(uintDivOp.getLhs(), scope);
425 auto rhsValue = lookup(uintDivOp.getRhs(), scope);
427 return lhsValue.takeError();
430 return rhsValue.takeError();
432 auto lhs =
asFelt(*lhsValue);
434 return lhs.takeError();
436 auto rhs =
asFelt(*rhsValue);
438 return rhs.takeError();
441 return makeError(
"felt.uintdiv divisor must be non-zero");
445 if (
auto sintDivOp = dyn_cast<felt::SignedIntDivFeltOp>(op)) {
446 auto lhsValue = lookup(sintDivOp.getLhs(), scope);
447 auto rhsValue = lookup(sintDivOp.getRhs(), scope);
449 return lhsValue.takeError();
452 return rhsValue.takeError();
454 auto lhs =
asFelt(*lhsValue);
456 return lhs.takeError();
458 auto rhs =
asFelt(*rhsValue);
460 return rhs.takeError();
463 return makeError(
"felt.sintdiv divisor must be non-zero");
465 return bind({
WitnessVal(field.reduce(field.toSigned(*lhs) / field.toSigned(*rhs)))});
467 if (
auto umodOp = dyn_cast<felt::UnsignedModFeltOp>(op)) {
468 auto lhsValue = lookup(umodOp.getLhs(), scope);
469 auto rhsValue = lookup(umodOp.getRhs(), scope);
471 return lhsValue.takeError();
474 return rhsValue.takeError();
476 auto lhs =
asFelt(*lhsValue);
478 return lhs.takeError();
480 auto rhs =
asFelt(*rhsValue);
482 return rhs.takeError();
485 return makeError(
"felt.umod divisor must be non-zero");
489 if (
auto smodOp = dyn_cast<felt::SignedModFeltOp>(op)) {
490 auto lhsValue = lookup(smodOp.getLhs(), scope);
491 auto rhsValue = lookup(smodOp.getRhs(), scope);
493 return lhsValue.takeError();
496 return rhsValue.takeError();
498 auto lhs =
asFelt(*lhsValue);
500 return lhs.takeError();
502 auto rhs =
asFelt(*rhsValue);
504 return rhs.takeError();
507 return makeError(
"felt.smod divisor must be non-zero");
509 return bind({
WitnessVal(field.reduce(field.toSigned(*lhs) % field.toSigned(*rhs)))});
511 if (
auto shrOp = dyn_cast<felt::ShrFeltOp>(op)) {
512 auto lhsValue = lookup(shrOp.getLhs(), scope);
513 auto rhsValue = lookup(shrOp.getRhs(), scope);
515 return lhsValue.takeError();
518 return rhsValue.takeError();
520 auto lhs =
asFelt(*lhsValue);
522 return lhs.takeError();
524 auto rhs =
asFelt(*rhsValue);
526 return rhs.takeError();
528 llvm::DynamicAPInt result(0);
529 if (*rhs < llvm::DynamicAPInt(field.bitWidth())) {
530 result = *lhs >> *rhs;
534 if (
auto shlOp = dyn_cast<felt::ShlFeltOp>(op)) {
535 auto lhsValue = lookup(shlOp.getLhs(), scope);
536 auto rhsValue = lookup(shlOp.getRhs(), scope);
538 return lhsValue.takeError();
541 return rhsValue.takeError();
543 auto lhs =
asFelt(*lhsValue);
545 return lhs.takeError();
547 auto rhs =
asFelt(*rhsValue);
549 return rhs.takeError();
551 llvm::DynamicAPInt two(2);
555 return bind({
WitnessVal(field.reduce(*lhs *
modExp(two, *rhs, field.prime())))});
557 if (
auto negOp = dyn_cast<felt::NegFeltOp>(op)) {
558 auto operand = lookup(negOp.getOperand(), scope);
560 return operand.takeError();
562 auto feltValue =
asFelt(*operand);
564 return feltValue.takeError();
566 return bind({
WitnessVal(field.reduce(-*feltValue))});
568 if (
auto invOp = dyn_cast<felt::InvFeltOp>(op)) {
569 auto operand = lookup(invOp.getOperand(), scope);
571 return operand.takeError();
573 auto feltValue =
asFelt(*operand);
575 return feltValue.takeError();
577 return bind({
WitnessVal(field.inv(*feltValue))});
579 if (
auto notOp = dyn_cast<felt::NotFeltOp>(op)) {
580 auto operand = lookup(
notOp.getOperand(), scope);
582 return operand.takeError();
584 auto feltValue =
asFelt(*operand);
586 return feltValue.takeError();
588 llvm::DynamicAPInt maxMask =
589 (llvm::DynamicAPInt(1) << llvm::DynamicAPInt(field.bitWidth())) - llvm::DynamicAPInt(1);
590 return bind({
WitnessVal(field.reduce(maxMask ^ *feltValue))});
594 if (
auto intToFeltOp = dyn_cast<cast::IntToFeltOp>(op)) {
595 auto operand = lookup(intToFeltOp.getValue(), scope);
597 return operand.takeError();
599 if (std::holds_alternative<bool>(*operand)) {
600 return bind({
WitnessVal(field.reduce(std::get<bool>(*operand) ? 1 : 0))});
602 auto integer =
asIndex(*operand);
604 return integer.takeError();
606 return bind({
WitnessVal(field.reduce(*integer))});
610 if (
auto feltToIndexOp = dyn_cast<cast::FeltToIndexOp>(op)) {
611 auto operand = lookup(feltToIndexOp.getValue(), scope);
613 return operand.takeError();
615 auto feltValue =
asFelt(*operand);
617 return feltValue.takeError();
619 auto &felt = *feltValue;
620 if (felt < 0 || felt > std::numeric_limits<int64_t>::max()) {
621 return makeError(
"felt value does not fit in index");
626 if (
auto structNewOp = dyn_cast<component::CreateStructOp>(op)) {
628 structNewOp.getType(), tables, structNewOp.getOperation(), field, uninitializedBehavior,
632 return value.takeError();
634 return bind({*value});
636 if (
auto readMemberOp = dyn_cast<component::MemberReadOp>(op)) {
637 auto componentValue = lookup(readMemberOp.getComponent(), scope);
638 if (!componentValue) {
639 return componentValue.takeError();
641 auto structValue =
asStruct(*componentValue);
643 return structValue.takeError();
645 auto it = (*structValue)->members.find(readMemberOp.getMemberName());
646 if (it == (*structValue)->members.end()) {
647 return makeError(
"missing struct member");
649 return bind({it->second});
651 if (
auto writeMemberOp = dyn_cast<component::MemberWriteOp>(op)) {
652 auto componentValue = lookup(writeMemberOp.getComponent(), scope);
653 auto memberValue = lookup(writeMemberOp.getVal(), scope);
654 if (!componentValue) {
655 return componentValue.takeError();
658 return memberValue.takeError();
660 auto structValue =
asStruct(*componentValue);
662 return structValue.takeError();
664 (*structValue)->members[writeMemberOp.getMemberName()] = *memberValue;
665 return BlockResult {};
668 if (
auto newPodOp = dyn_cast<pod::NewPodOp>(op)) {
670 newPodOp.getType(), tables, newPodOp.getOperation(), field, uninitializedBehavior, &rng
673 return podValue.takeError();
675 auto podRef =
asPod(*podValue);
677 return podRef.takeError();
679 auto initValues = newPodOp.getInitializedRecordValues();
680 for (pod::RecordValue init : initValues) {
681 auto value = lookup(init.value, scope);
683 return value.takeError();
685 (*podRef)->records[init.name] = *value;
687 return bind({*podRef});
689 if (
auto readPodOp = dyn_cast<pod::ReadPodOp>(op)) {
690 auto podValue = lookup(readPodOp.getPodRef(), scope);
692 return podValue.takeError();
694 auto podRef =
asPod(*podValue);
696 return podRef.takeError();
698 auto it = (*podRef)->records.find(readPodOp.getRecordName());
699 if (it == (*podRef)->records.end()) {
702 return bind({it->second});
704 if (
auto writePodOp = dyn_cast<pod::WritePodOp>(op)) {
705 auto podValue = lookup(writePodOp.getPodRef(), scope);
706 auto recordValue = lookup(writePodOp.getValue(), scope);
708 return podValue.takeError();
711 return recordValue.takeError();
713 auto podRef =
asPod(*podValue);
715 return podRef.takeError();
717 (*podRef)->records[writePodOp.getRecordName()] = *recordValue;
718 return BlockResult {};
721 if (
auto arrayNewOp = dyn_cast<array::CreateArrayOp>(op)) {
722 auto arrayValue = std::make_shared<ArrayValue>();
723 arrayValue->type = arrayNewOp.getType();
724 if (arrayNewOp.getElements().empty()) {
728 return elementCount.takeError();
730 arrayValue->elements.reserve(*elementCount);
731 for (
size_t i = 0; i < *elementCount; ++i) {
733 arrayValue->type.getElementType(), tables, arrayNewOp.getOperation(), field,
734 uninitializedBehavior, &rng
737 return elem.takeError();
739 arrayValue->elements.push_back(*elem);
742 auto values = collectOperands(arrayNewOp.getElements(), scope);
744 return values.takeError();
746 arrayValue->elements.assign(values->begin(), values->end());
748 return bind({arrayValue});
750 if (
auto readArrayOp = dyn_cast<array::ReadArrayOp>(op)) {
751 auto arrayValue = lookup(readArrayOp.getArrRef(), scope);
753 return arrayValue.takeError();
755 auto arrayRef =
asArray(*arrayValue);
757 return arrayRef.takeError();
759 llvm::SmallVector<int64_t> indices;
760 for (mlir::Value indexVal : readArrayOp.getIndices()) {
761 auto value = lookup(indexVal, scope);
763 return value.takeError();
767 return index.takeError();
769 indices.push_back(*index);
772 checkedLinearize((*arrayRef)->type.getShape(), indices,
"array index out of bounds");
774 return offset.takeError();
776 return bind({(*arrayRef)->elements[*offset]});
778 if (
auto writeArrayOp = dyn_cast<array::WriteArrayOp>(op)) {
779 auto arrayValue = lookup(writeArrayOp.getArrRef(), scope);
780 auto rvalue = lookup(writeArrayOp.getRvalue(), scope);
782 return arrayValue.takeError();
785 return rvalue.takeError();
787 auto arrayRef =
asArray(*arrayValue);
789 return arrayRef.takeError();
791 llvm::SmallVector<int64_t> indices;
792 for (mlir::Value indexVal : writeArrayOp.getIndices()) {
793 auto value = lookup(indexVal, scope);
795 return value.takeError();
799 return index.takeError();
801 indices.push_back(*index);
804 checkedLinearize((*arrayRef)->type.getShape(), indices,
"array index out of bounds");
806 return offset.takeError();
808 (*arrayRef)->elements[*offset] = *rvalue;
809 return BlockResult {};
811 if (
auto extractArrayOp = dyn_cast<array::ExtractArrayOp>(op)) {
812 auto arrayValue = lookup(extractArrayOp.getArrRef(), scope);
814 return arrayValue.takeError();
816 auto arrayRef =
asArray(*arrayValue);
818 return arrayRef.takeError();
820 llvm::SmallVector<int64_t> indices;
821 for (mlir::Value indexVal : extractArrayOp.getIndices()) {
822 auto value = lookup(indexVal, scope);
824 return value.takeError();
828 return index.takeError();
830 indices.push_back(*index);
832 llvm::ArrayRef<int64_t> shape = (*arrayRef)->type.getShape();
833 if (indices.size() >= shape.size()) {
834 return makeError(
"array.extract indices exceed array rank");
839 return subArraySize.takeError();
842 checkedLinearize(shape.take_front(indices.size()), indices,
"array index out of bounds");
844 return prefixOffset.takeError();
846 bool baseOverflow =
false;
847 size_t base = llvm::SaturatingMultiply(*prefixOffset, *subArraySize, &baseOverflow);
849 return makeError(
"array.extract element offset would overflow size_t");
851 auto subArray = std::make_shared<ArrayValue>();
852 subArray->type = extractArrayOp.getType();
853 subArray->elements.reserve(*subArraySize);
854 for (
size_t i = 0; i < *subArraySize; ++i) {
855 bool overflow =
false;
856 size_t elementOffset = llvm::SaturatingAdd(base, i, &overflow);
858 return makeError(
"array.extract element offset would overflow size_t");
860 subArray->elements.push_back((*arrayRef)->elements[elementOffset]);
862 return bind({subArray});
864 if (
auto insertArrayOp = dyn_cast<array::InsertArrayOp>(op)) {
865 auto arrayValue = lookup(insertArrayOp.getArrRef(), scope);
866 auto subArrayValue = lookup(insertArrayOp.getRvalue(), scope);
868 return arrayValue.takeError();
870 if (!subArrayValue) {
871 return subArrayValue.takeError();
873 auto arrayRef =
asArray(*arrayValue);
874 auto subArrayRef =
asArray(*subArrayValue);
876 return arrayRef.takeError();
879 return subArrayRef.takeError();
881 llvm::SmallVector<int64_t> indices;
882 for (mlir::Value indexVal : insertArrayOp.getIndices()) {
883 auto value = lookup(indexVal, scope);
885 return value.takeError();
889 return index.takeError();
891 indices.push_back(*index);
893 llvm::ArrayRef<int64_t> shape = (*arrayRef)->type.getShape();
894 size_t subArraySize = (*subArrayRef)->elements.size();
896 checkedLinearize(shape.take_front(indices.size()), indices,
"array index out of bounds");
898 return prefixOffset.takeError();
900 bool baseOverflow =
false;
901 size_t base = llvm::SaturatingMultiply(*prefixOffset, subArraySize, &baseOverflow);
903 return makeError(
"array.insert element offset would overflow size_t");
905 for (
size_t i = 0; i < subArraySize; ++i) {
906 bool overflow =
false;
907 size_t elementOffset = llvm::SaturatingAdd(base, i, &overflow);
909 return makeError(
"array.insert element offset would overflow size_t");
911 (*arrayRef)->elements[elementOffset] = (*subArrayRef)->elements[i];
913 return BlockResult {};
915 if (
auto arrayLenOp = dyn_cast<array::ArrayLengthOp>(op)) {
916 auto dimValue = lookup(arrayLenOp.getDim(), scope);
918 return dimValue.takeError();
922 return dim.takeError();
924 llvm::ArrayRef<int64_t> shape = arrayLenOp.getArrRefType().getShape();
927 return dimIndex.takeError();
929 if (*dimIndex >= shape.size()) {
930 return makeError(
"array.len dimension out of bounds");
935 if (
auto callOp = dyn_cast<function::CallOp>(op)) {
936 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
937 return makeError(
"templated or affine-instantiated calls are not supported in llzk-witgen");
940 if (failed(callee)) {
941 return makeError(
"could not resolve called function");
943 auto args = collectOperands(callOp.getArgOperands(), scope);
945 return args.takeError();
947 auto results = run(callee->get(), *args);
949 return results.takeError();
951 return bind(*results);
954 auto handleBinaryIndex = [&](
auto arithOp,
auto fn) -> llvm::Expected<BlockResult> {
955 auto lhs = lookup(arithOp.getLhs(), scope);
956 auto rhs = lookup(arithOp.getRhs(), scope);
958 return lhs.takeError();
961 return rhs.takeError();
965 return lhsValue.takeError();
969 return rhsValue.takeError();
971 return bind({
WitnessVal(fn(*lhsValue, *rhsValue))});
974 if (
auto addIOp = dyn_cast<arith::AddIOp>(op)) {
975 return handleBinaryIndex(addIOp, [](int64_t lhs, int64_t rhs) {
return lhs + rhs; });
977 if (
auto subIOp = dyn_cast<arith::SubIOp>(op)) {
978 return handleBinaryIndex(subIOp, [](int64_t lhs, int64_t rhs) {
return lhs - rhs; });
980 if (
auto mulIOp = dyn_cast<arith::MulIOp>(op)) {
981 return handleBinaryIndex(mulIOp, [](int64_t lhs, int64_t rhs) {
return lhs * rhs; });
983 if (
auto divUIOp = dyn_cast<arith::DivUIOp>(op)) {
984 return handleBinaryIndex(divUIOp, [](int64_t lhs, int64_t rhs) {
986 auto divRes =
static_cast<uint64_t
>(lhs) /
static_cast<uint64_t
>(rhs);
987 return static_cast<int64_t
>(divRes);
990 if (
auto cmpIOp = dyn_cast<arith::CmpIOp>(op)) {
991 return handleBinaryIndex(cmpIOp, [&cmpIOp](int64_t lhs, int64_t rhs) ->
bool {
992 switch (cmpIOp.getPredicate()) {
993 case arith::CmpIPredicate::eq:
995 case arith::CmpIPredicate::ne:
997 case arith::CmpIPredicate::slt:
999 case arith::CmpIPredicate::sle:
1001 case arith::CmpIPredicate::sgt:
1003 case arith::CmpIPredicate::sge:
1006 case arith::CmpIPredicate::ult:
1007 return static_cast<uint64_t
>(lhs) <
static_cast<uint64_t
>(rhs);
1008 case arith::CmpIPredicate::ule:
1009 return static_cast<uint64_t
>(lhs) <=
static_cast<uint64_t
>(rhs);
1010 case arith::CmpIPredicate::ugt:
1011 return static_cast<uint64_t
>(lhs) >
static_cast<uint64_t
>(rhs);
1012 case arith::CmpIPredicate::uge:
1013 return static_cast<uint64_t
>(lhs) >=
static_cast<uint64_t
>(rhs);
1015 llvm_unreachable(
"unknown comparison predicate");
1019 if (
auto selectOp = dyn_cast<arith::SelectOp>(op)) {
1020 auto cond = lookup(selectOp.getCondition(), scope);
1021 auto trueValue = lookup(selectOp.getTrueValue(), scope);
1022 auto falseValue = lookup(selectOp.getFalseValue(), scope);
1024 return cond.takeError();
1027 return trueValue.takeError();
1030 return falseValue.takeError();
1032 auto condition =
asBool(*cond);
1034 return condition.takeError();
1036 return bind({*condition ? *trueValue : *falseValue});
1039 if (
auto ifOp = dyn_cast<scf::IfOp>(op)) {
1040 auto cond = lookup(ifOp.getCondition(), scope);
1042 return cond.takeError();
1044 auto condition =
asBool(*cond);
1046 return condition.takeError();
1048 if (!*condition && ifOp.getNumResults() == 0 && ifOp.getElseRegion().empty()) {
1051 Region ®ion = *condition ? ifOp.getThenRegion() : ifOp.getElseRegion();
1052 auto result = runRegion(region, {}, scope);
1054 return result.takeError();
1056 return bind(result->values);
1059 if (
auto forOp = dyn_cast<scf::ForOp>(op)) {
1060 auto lowerBoundValue = lookup(forOp.getLowerBound(), scope);
1061 auto upperBoundValue = lookup(forOp.getUpperBound(), scope);
1062 auto stepValue = lookup(forOp.getStep(), scope);
1063 if (!lowerBoundValue) {
1064 return lowerBoundValue.takeError();
1066 if (!upperBoundValue) {
1067 return upperBoundValue.takeError();
1070 return stepValue.takeError();
1072 auto lowerBound =
asIndex(*lowerBoundValue);
1074 return lowerBound.takeError();
1076 auto upperBound =
asIndex(*upperBoundValue);
1078 return upperBound.takeError();
1080 auto step =
asIndex(*stepValue);
1082 return step.takeError();
1084 auto iterValuesOrErr = collectOperands(forOp.getInitArgs(), scope);
1085 if (!iterValuesOrErr) {
1086 return iterValuesOrErr.takeError();
1088 llvm::SmallVector<WitnessVal> iterValues = std::move(*iterValuesOrErr);
1090 if (usesUnsignedCmp(forOp)) {
1092 auto lowerBoundUIntValue =
static_cast<uint64_t
>(*lowerBound);
1093 auto upperBoundUIntValue =
static_cast<uint64_t
>(*upperBound);
1094 auto stepUInt =
static_cast<uint64_t
>(*step);
1095 for (uint64_t iv = lowerBoundUIntValue, ub = upperBoundUIntValue, unsignedStep = stepUInt;
1096 iv < ub; iv += unsignedStep) {
1099 return signedIV.takeError();
1101 llvm::SmallVector<WitnessVal> regionArgs;
1103 regionArgs.append(iterValues.begin(), iterValues.end());
1104 auto result = runRegion(forOp.getRegion(), regionArgs, scope);
1106 return result.takeError();
1108 iterValues = std::move(result->values);
1111 for (int64_t iv = *lowerBound; iv < *upperBound; iv += *step) {
1112 llvm::SmallVector<WitnessVal> regionArgs;
1114 regionArgs.append(iterValues.begin(), iterValues.end());
1115 auto result = runRegion(forOp.getRegion(), regionArgs, scope);
1117 return result.takeError();
1119 iterValues = std::move(result->values);
1122 return bind(iterValues);
1125 if (
auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1126 auto iterValuesOrErr = collectOperands(whileOp.getInits(), scope);
1127 if (!iterValuesOrErr) {
1128 return iterValuesOrErr.takeError();
1130 llvm::SmallVector<WitnessVal> iterValues = std::move(*iterValuesOrErr);
1132 auto beforeResult = runRegion(whileOp.getBefore(), iterValues, scope);
1133 if (!beforeResult) {
1134 return beforeResult.takeError();
1136 if (!beforeResult->terminated) {
1137 return makeError(
"scf.while before region must terminate with scf.condition");
1139 if (beforeResult->values.empty()) {
1140 return makeError(
"scf.while before region did not produce a condition");
1143 auto condition =
asBool(beforeResult->values.front());
1145 return condition.takeError();
1148 llvm::SmallVector<WitnessVal> nextValues;
1149 nextValues.append(beforeResult->values.begin() + 1, beforeResult->values.end());
1151 return bind(nextValues);
1154 auto afterResult = runRegion(whileOp.getAfter(), nextValues, scope);
1156 return afterResult.takeError();
1158 if (!afterResult->terminated) {
1159 return makeError(
"scf.while after region must terminate with scf.yield");
1161 iterValues = std::move(afterResult->values);
1165 return makeError(llvm::Twine(
"unsupported op in llzk-witgen: ") + op.getName().getStringRef());
1172llvm::Expected<llvm::SmallVector<WitnessVal>>
1174 return InvocationInterpreter(moduleOp, tables, field, uninitializedBehavior, rng)
This file implements helper methods for constructing DynamicAPInts.
This file defines methods for symbol lookup across LLZK operations and included files.
Information about the prime finite field used for the interval analysis.
::mlir::Region & getBody()
llvm::Expected< llvm::SmallVector< WitnessVal > > run(llzk::function::FuncDefOp funcOp, mlir::ArrayRef< WitnessVal > args)
Run a function with concrete arguments and return its result values.
FunctionInterpreter(mlir::ModuleOp moduleOp, mlir::SymbolTableCollection &tables, const llzk::Field &field, UninitializedBehavior uninitializedBehavior, std::mt19937_64 rng)
Build an interpreter for one module and field configuration.
llvm::Expected< T > checkedCast(U u)
llvm::Expected< PodValueRef > asPod(const WitnessVal &value)
Require a POD value from the runtime variant.
llvm::Expected< bool > asBool(const WitnessVal &value)
Require a boolean value from the runtime variant.
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
llvm::Expected< size_t > getStaticShapeElementCount(llvm::ArrayRef< int64_t > shape, llvm::StringRef context)
Return the static element count for one shape, rejecting dynamic sizes.
llvm::Expected< WitnessVal > defaultValue(Type type, SymbolTableCollection &tables, Operation *origin, const Field &field, UninitializedBehavior behavior, std::mt19937_64 *rng)
Build a default value for a supported LLZK type.
llvm::Expected< int64_t > asIndex(const WitnessVal &value)
Require an index value from the runtime variant.
llvm::Expected< size_t > checkedShapeDimToSize(int64_t dim, llvm::StringRef context)
Convert one static dimension to size_t, rejecting dynamic or invalid sizes.
std::variant< std::monostate, bool, int64_t, llvm::DynamicAPInt, ArrayValueRef, PodValueRef, StructValueRef > WitnessVal
Runtime value representation used by the tool-local interpreter.
llvm::Expected< llvm::DynamicAPInt > asFelt(const WitnessVal &value)
Require a felt value from the runtime variant.
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
llvm::Expected< StructValueRef > asStruct(const WitnessVal &value)
Require a struct value from the runtime variant.
llvm::Expected< ArrayValueRef > asArray(const WitnessVal &value)
Require an array value from the runtime variant.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
DynamicAPInt modExp(const DynamicAPInt &base, const DynamicAPInt &exp, const DynamicAPInt &mod)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)