LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Interpreter.cpp
Go to the documentation of this file.
1//===-- Interpreter.cpp - llzk-witgen compute interpreter -------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "Interpreter.h"
11
12#include "Errors.h"
13#include "WitgenUtils.h"
14
22#include "llzk/Util/Compare.h"
25
26#include <mlir/Dialect/Arith/IR/Arith.h>
27#include <mlir/Dialect/SCF/IR/SCF.h>
28#include <mlir/Dialect/Utils/IndexingUtils.h>
29#include <mlir/IR/Operation.h>
30
31#include <llvm/ADT/STLExtras.h>
32#include <llvm/ADT/SmallVector.h>
33#include <llvm/Support/MathExtras.h>
34
35#include <limits>
36#include <random>
37
38using namespace mlir;
39
40namespace llzk::witgen {
41
42namespace {
43
45static bool usesUnsignedCmp(scf::ForOp forOp) {
46 if (auto boolAttr = forOp->getAttrOfType<BoolAttr>("unsignedCmp")) {
47 return boolAttr.getValue();
48 }
49 return forOp->hasAttr("unsignedCmp");
50}
51
53struct BlockResult {
54 bool terminated = false;
55 llvm::SmallVector<WitnessVal> values;
56};
57
59llvm::Expected<size_t> checkedLinearize(
60 llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> indices, llvm::StringRef context
61) {
62 if (shape.size() != indices.size()) {
63 return makeError("wrong number of array indices");
64 }
65 for (auto [idx, dim] : llvm::zip_equal(indices, shape)) {
66 if (idx < 0 || dim < 0 || idx >= dim) {
67 return makeError(context);
68 }
69 }
70 auto strides = mlir::computeStrides(shape);
71 return checkedCast<size_t>(mlir::linearize(indices, strides));
72}
73
74} // namespace
75
78 ModuleOp module, SymbolTableCollection &symbolTables, const Field &moduleField,
79 UninitializedBehavior behavior, std::mt19937_64 r
80)
81 : moduleOp(module), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
82 rng(r) {}
83
84namespace {
85
87class InvocationInterpreter {
88public:
90 InvocationInterpreter(
91 ModuleOp module, SymbolTableCollection &symbolTables, const Field &moduleField,
92 UninitializedBehavior behavior, std::mt19937_64 &r
93 )
94 : moduleOp(module), tables(symbolTables), field(moduleField), uninitializedBehavior(behavior),
95 rng(r) {}
96
98 llvm::Expected<llvm::SmallVector<WitnessVal>>
99 run(function::FuncDefOp funcOp, ArrayRef<WitnessVal> args) {
100 if (funcOp.isExternal()) {
101 return makeError("extern functions are not supported in llzk-witgen");
102 }
103 if (!funcOp.getBody().hasOneBlock()) {
104 return makeError("multi-block functions are not supported in llzk-witgen");
105 }
106 if (funcOp.getNumArguments() != args.size()) {
107 return makeError("wrong number of arguments passed to function");
108 }
109
110 llvm::DenseMap<mlir::Value, WitnessVal> scope;
111 Block &entry = funcOp.getBody().front();
112 for (auto [arg, value] : llvm::zip(entry.getArguments(), args)) {
113 scope[arg] = value;
114 }
115
116 auto result = runBlock(entry, scope);
117 if (!result) {
118 return result.takeError();
119 }
120 return result->values;
121 }
122
123private:
124 ModuleOp moduleOp;
125 SymbolTableCollection &tables;
126 const Field &field;
127 UninitializedBehavior uninitializedBehavior;
128 std::mt19937_64 &rng;
129
131 llvm::Expected<BlockResult>
132 runBlock(Block &block, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
133 for (Operation &op : block) {
134 auto handled = runOperation(op, scope);
135 if (!handled) {
136 return handled.takeError();
137 }
138 if (handled->terminated) {
139 return *handled;
140 }
141 }
142 return BlockResult {};
143 }
144
146 llvm::Expected<BlockResult> runRegion(
147 Region &region, ArrayRef<WitnessVal> args, llvm::DenseMap<mlir::Value, WitnessVal> scope
148 ) {
149 if (!region.hasOneBlock()) {
150 return makeError("multi-block regions are not supported in llzk-witgen");
151 }
152 Block &block = region.front();
153 if (block.getNumArguments() != args.size()) {
154 return makeError("region argument count mismatch");
155 }
156 for (auto [arg, value] : llvm::zip(block.getArguments(), args)) {
157 scope[arg] = value;
158 }
159 return runBlock(block, scope);
160 }
161
163 llvm::Expected<WitnessVal>
164 lookup(mlir::Value value, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
165 auto it = scope.find(value);
166 if (it == scope.end()) {
167 return makeError("failed to find SSA value during interpretation");
168 }
169 return it->second;
170 }
171
173 llvm::Expected<llvm::SmallVector<WitnessVal>>
174 collectOperands(OperandRange operands, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
175 llvm::SmallVector<WitnessVal> values;
176 values.reserve(operands.size());
177 for (mlir::Value operand : operands) {
178 auto value = lookup(operand, scope);
179 if (!value) {
180 return value.takeError();
181 }
182 values.push_back(*value);
183 }
184 return values;
185 }
186
188 llvm::Expected<BlockResult>
189 runOperation(Operation &op, llvm::DenseMap<mlir::Value, WitnessVal> &scope) {
190 if (auto returnOp = dyn_cast<function::ReturnOp>(op)) {
191 auto values = collectOperands(returnOp.getOperands(), scope);
192 if (!values) {
193 return values.takeError();
194 }
195 return BlockResult {true, std::move(*values)};
196 }
197 if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
198 auto values = collectOperands(yieldOp.getOperands(), scope);
199 if (!values) {
200 return values.takeError();
201 }
202 return BlockResult {true, std::move(*values)};
203 }
204 if (auto conditionOp = dyn_cast<scf::ConditionOp>(op)) {
205 auto values = collectOperands(conditionOp.getOperands(), scope);
206 if (!values) {
207 return values.takeError();
208 }
209 return BlockResult {true, std::move(*values)};
210 }
211
212 auto bind = [&](ArrayRef<WitnessVal> results) -> llvm::Expected<BlockResult> {
213 if (results.size() != op.getNumResults()) {
214 return makeError("internal result count mismatch");
215 }
216 for (auto [result, value] : llvm::zip(op.getResults(), results)) {
217 scope[result] = value;
218 }
219 return BlockResult {};
220 };
221
222 if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
223 Attribute valueAttr = constantOp.getValue();
224 if (auto integerAttr = dyn_cast<IntegerAttr>(valueAttr)) {
225 if (integerAttr.getType().isInteger(1)) {
226 return bind({WitnessVal(integerAttr.getValue().getBoolValue())});
227 }
228 return bind({WitnessVal(integerAttr.getValue().getSExtValue())});
229 }
230 return makeError("unsupported arith.constant value");
231 }
232
233 if (auto nondetOp = dyn_cast<llzk::NonDetOp>(op)) {
234 auto value = defaultValue(
235 nondetOp.getType(), tables, nondetOp.getOperation(), field, uninitializedBehavior, &rng
236 );
237 if (!value) {
238 return value.takeError();
239 }
240 return bind({*value});
241 }
242
243 if (auto assertOp = dyn_cast<boolean::AssertOp>(op)) {
244 auto condition = lookup(assertOp.getCondition(), scope);
245 if (!condition) {
246 return condition.takeError();
247 }
248 auto boolValue = asBool(*condition);
249 if (!boolValue) {
250 return boolValue.takeError();
251 }
252 if (!*boolValue) {
253 std::string msg = "bool.assert failed";
254 if (auto attr = assertOp.getMsg()) {
255 msg = attr->str();
256 }
257 return makeError(msg);
258 }
259 return BlockResult {};
260 }
261
262 if (auto andOp = dyn_cast<boolean::AndBoolOp>(op)) {
263 auto lhsValue = lookup(andOp.getLhs(), scope);
264 auto rhsValue = lookup(andOp.getRhs(), scope);
265 if (!lhsValue) {
266 return lhsValue.takeError();
267 }
268 if (!rhsValue) {
269 return rhsValue.takeError();
270 }
271 auto lhs = asBool(*lhsValue);
272 if (!lhs) {
273 return lhs.takeError();
274 }
275 auto rhs = asBool(*rhsValue);
276 if (!rhs) {
277 return rhs.takeError();
278 }
279 return bind({WitnessVal(*lhs && *rhs)});
280 }
281 if (auto orOp = dyn_cast<boolean::OrBoolOp>(op)) {
282 auto lhsValue = lookup(orOp.getLhs(), scope);
283 auto rhsValue = lookup(orOp.getRhs(), scope);
284 if (!lhsValue) {
285 return lhsValue.takeError();
286 }
287 if (!rhsValue) {
288 return rhsValue.takeError();
289 }
290 auto lhs = asBool(*lhsValue);
291 if (!lhs) {
292 return lhs.takeError();
293 }
294 auto rhs = asBool(*rhsValue);
295 if (!rhs) {
296 return rhs.takeError();
297 }
298 return bind({WitnessVal(*lhs || *rhs)});
299 }
300 if (auto xorOp = dyn_cast<boolean::XorBoolOp>(op)) {
301 auto lhsValue = lookup(xorOp.getLhs(), scope);
302 auto rhsValue = lookup(xorOp.getRhs(), scope);
303 if (!lhsValue) {
304 return lhsValue.takeError();
305 }
306 if (!rhsValue) {
307 return rhsValue.takeError();
308 }
309 auto lhs = asBool(*lhsValue);
310 if (!lhs) {
311 return lhs.takeError();
312 }
313 auto rhs = asBool(*rhsValue);
314 if (!rhs) {
315 return rhs.takeError();
316 }
317 return bind({WitnessVal(*lhs != *rhs)});
318 }
319 if (auto notOp = dyn_cast<boolean::NotBoolOp>(op)) {
320 auto operand = lookup(notOp.getOperand(), scope);
321 if (!operand) {
322 return operand.takeError();
323 }
324 auto boolValue = asBool(*operand);
325 if (!boolValue) {
326 return boolValue.takeError();
327 }
328 return bind({WitnessVal(!*boolValue)});
329 }
330 if (auto cmpOp = dyn_cast<boolean::CmpOp>(op)) {
331 auto lhs = lookup(cmpOp.getLhs(), scope);
332 auto rhs = lookup(cmpOp.getRhs(), scope);
333 if (!lhs) {
334 return lhs.takeError();
335 }
336 if (!rhs) {
337 return rhs.takeError();
338 }
339 auto lhsValue = asFelt(*lhs);
340 if (!lhsValue) {
341 return lhsValue.takeError();
342 }
343 auto rhsValue = asFelt(*rhs);
344 if (!rhsValue) {
345 return rhsValue.takeError();
346 }
347 bool result = false;
348 switch (cmpOp.getPredicate()) {
350 result = *lhsValue == *rhsValue;
351 break;
353 result = *lhsValue != *rhsValue;
354 break;
356 result = *lhsValue < *rhsValue;
357 break;
359 result = *lhsValue <= *rhsValue;
360 break;
362 result = *lhsValue > *rhsValue;
363 break;
365 result = *lhsValue >= *rhsValue;
366 break;
367 }
368 return bind({WitnessVal(result)});
369 }
370
371 if (auto feltConst = dyn_cast<felt::FeltConstantOp>(op)) {
372 return bind({WitnessVal(field.reduce(feltConst.getValue().getValue()))});
373 }
374
375 auto handleBinaryFelt = [&](auto feltOp, auto fn) -> llvm::Expected<BlockResult> {
376 auto lhsValue = lookup(feltOp.getLhs(), scope);
377 auto rhsValue = lookup(feltOp.getRhs(), scope);
378 if (!lhsValue) {
379 return lhsValue.takeError();
380 }
381 if (!rhsValue) {
382 return rhsValue.takeError();
383 }
384 auto lhs = asFelt(*lhsValue);
385 if (!lhs) {
386 return lhs.takeError();
387 }
388 auto rhs = asFelt(*rhsValue);
389 if (!rhs) {
390 return rhs.takeError();
391 }
392 return bind({WitnessVal(field.reduce(fn(*lhs, *rhs)))});
393 };
394
395 if (auto addOp = dyn_cast<felt::AddFeltOp>(op)) {
396 return handleBinaryFelt(addOp, [](const auto &lhs, const auto &rhs) { return lhs + rhs; });
397 }
398 if (auto powOp = dyn_cast<felt::PowFeltOp>(op)) {
399 return handleBinaryFelt(powOp, [&](const auto &lhs, const auto &rhs) {
400 return modExp(lhs, rhs, field.prime());
401 });
402 }
403 if (auto andOp = dyn_cast<felt::AndFeltOp>(op)) {
404 return handleBinaryFelt(andOp, [](const auto &lhs, const auto &rhs) { return lhs & rhs; });
405 }
406 if (auto orOp = dyn_cast<felt::OrFeltOp>(op)) {
407 return handleBinaryFelt(orOp, [](const auto &lhs, const auto &rhs) { return lhs | rhs; });
408 }
409 if (auto xorOp = dyn_cast<felt::XorFeltOp>(op)) {
410 return handleBinaryFelt(xorOp, [](const auto &lhs, const auto &rhs) { return lhs ^ rhs; });
411 }
412 if (auto subOp = dyn_cast<felt::SubFeltOp>(op)) {
413 return handleBinaryFelt(subOp, [](const auto &lhs, const auto &rhs) { return lhs - rhs; });
414 }
415 if (auto mulOp = dyn_cast<felt::MulFeltOp>(op)) {
416 return handleBinaryFelt(mulOp, [](const auto &lhs, const auto &rhs) { return lhs * rhs; });
417 }
418 if (auto divOp = dyn_cast<felt::DivFeltOp>(op)) {
419 return handleBinaryFelt(divOp, [&](const auto &lhs, const auto &rhs) {
420 return lhs * field.inv(rhs);
421 });
422 }
423 if (auto uintDivOp = dyn_cast<felt::UnsignedIntDivFeltOp>(op)) {
424 auto lhsValue = lookup(uintDivOp.getLhs(), scope);
425 auto rhsValue = lookup(uintDivOp.getRhs(), scope);
426 if (!lhsValue) {
427 return lhsValue.takeError();
428 }
429 if (!rhsValue) {
430 return rhsValue.takeError();
431 }
432 auto lhs = asFelt(*lhsValue);
433 if (!lhs) {
434 return lhs.takeError();
435 }
436 auto rhs = asFelt(*rhsValue);
437 if (!rhs) {
438 return rhs.takeError();
439 }
440 if (*rhs == 0) {
441 return makeError("felt.uintdiv divisor must be non-zero");
442 }
443 return bind({WitnessVal(*lhs / *rhs)});
444 }
445 if (auto sintDivOp = dyn_cast<felt::SignedIntDivFeltOp>(op)) {
446 auto lhsValue = lookup(sintDivOp.getLhs(), scope);
447 auto rhsValue = lookup(sintDivOp.getRhs(), scope);
448 if (!lhsValue) {
449 return lhsValue.takeError();
450 }
451 if (!rhsValue) {
452 return rhsValue.takeError();
453 }
454 auto lhs = asFelt(*lhsValue);
455 if (!lhs) {
456 return lhs.takeError();
457 }
458 auto rhs = asFelt(*rhsValue);
459 if (!rhs) {
460 return rhs.takeError();
461 }
462 if (*rhs == 0) {
463 return makeError("felt.sintdiv divisor must be non-zero");
464 }
465 return bind({WitnessVal(field.reduce(field.toSigned(*lhs) / field.toSigned(*rhs)))});
466 }
467 if (auto umodOp = dyn_cast<felt::UnsignedModFeltOp>(op)) {
468 auto lhsValue = lookup(umodOp.getLhs(), scope);
469 auto rhsValue = lookup(umodOp.getRhs(), scope);
470 if (!lhsValue) {
471 return lhsValue.takeError();
472 }
473 if (!rhsValue) {
474 return rhsValue.takeError();
475 }
476 auto lhs = asFelt(*lhsValue);
477 if (!lhs) {
478 return lhs.takeError();
479 }
480 auto rhs = asFelt(*rhsValue);
481 if (!rhs) {
482 return rhs.takeError();
483 }
484 if (*rhs == 0) {
485 return makeError("felt.umod divisor must be non-zero");
486 }
487 return bind({WitnessVal(*lhs % *rhs)});
488 }
489 if (auto smodOp = dyn_cast<felt::SignedModFeltOp>(op)) {
490 auto lhsValue = lookup(smodOp.getLhs(), scope);
491 auto rhsValue = lookup(smodOp.getRhs(), scope);
492 if (!lhsValue) {
493 return lhsValue.takeError();
494 }
495 if (!rhsValue) {
496 return rhsValue.takeError();
497 }
498 auto lhs = asFelt(*lhsValue);
499 if (!lhs) {
500 return lhs.takeError();
501 }
502 auto rhs = asFelt(*rhsValue);
503 if (!rhs) {
504 return rhs.takeError();
505 }
506 if (*rhs == 0) {
507 return makeError("felt.smod divisor must be non-zero");
508 }
509 return bind({WitnessVal(field.reduce(field.toSigned(*lhs) % field.toSigned(*rhs)))});
510 }
511 if (auto shrOp = dyn_cast<felt::ShrFeltOp>(op)) {
512 auto lhsValue = lookup(shrOp.getLhs(), scope);
513 auto rhsValue = lookup(shrOp.getRhs(), scope);
514 if (!lhsValue) {
515 return lhsValue.takeError();
516 }
517 if (!rhsValue) {
518 return rhsValue.takeError();
519 }
520 auto lhs = asFelt(*lhsValue);
521 if (!lhs) {
522 return lhs.takeError();
523 }
524 auto rhs = asFelt(*rhsValue);
525 if (!rhs) {
526 return rhs.takeError();
527 }
528 llvm::DynamicAPInt result(0);
529 if (*rhs < llvm::DynamicAPInt(field.bitWidth())) {
530 result = *lhs >> *rhs;
531 }
532 return bind({WitnessVal(result)});
533 }
534 if (auto shlOp = dyn_cast<felt::ShlFeltOp>(op)) {
535 auto lhsValue = lookup(shlOp.getLhs(), scope);
536 auto rhsValue = lookup(shlOp.getRhs(), scope);
537 if (!lhsValue) {
538 return lhsValue.takeError();
539 }
540 if (!rhsValue) {
541 return rhsValue.takeError();
542 }
543 auto lhs = asFelt(*lhsValue);
544 if (!lhs) {
545 return lhs.takeError();
546 }
547 auto rhs = asFelt(*rhsValue);
548 if (!rhs) {
549 return rhs.takeError();
550 }
551 llvm::DynamicAPInt two(2);
552 // It's more efficient to use modExp than native shift left, as for large
553 // exponents, << could allocate large temporaries, whereas modExp will be bounded
554 // by the field prime.
555 return bind({WitnessVal(field.reduce(*lhs * modExp(two, *rhs, field.prime())))});
556 }
557 if (auto negOp = dyn_cast<felt::NegFeltOp>(op)) {
558 auto operand = lookup(negOp.getOperand(), scope);
559 if (!operand) {
560 return operand.takeError();
561 }
562 auto feltValue = asFelt(*operand);
563 if (!feltValue) {
564 return feltValue.takeError();
565 }
566 return bind({WitnessVal(field.reduce(-*feltValue))});
567 }
568 if (auto invOp = dyn_cast<felt::InvFeltOp>(op)) {
569 auto operand = lookup(invOp.getOperand(), scope);
570 if (!operand) {
571 return operand.takeError();
572 }
573 auto feltValue = asFelt(*operand);
574 if (!feltValue) {
575 return feltValue.takeError();
576 }
577 return bind({WitnessVal(field.inv(*feltValue))});
578 }
579 if (auto notOp = dyn_cast<felt::NotFeltOp>(op)) {
580 auto operand = lookup(notOp.getOperand(), scope);
581 if (!operand) {
582 return operand.takeError();
583 }
584 auto feltValue = asFelt(*operand);
585 if (!feltValue) {
586 return feltValue.takeError();
587 }
588 llvm::DynamicAPInt maxMask =
589 (llvm::DynamicAPInt(1) << llvm::DynamicAPInt(field.bitWidth())) - llvm::DynamicAPInt(1);
590 return bind({WitnessVal(field.reduce(maxMask ^ *feltValue))});
591 }
592 // Reduces signed integers to unsigned field elements using Field::reduce.
593 // Negative results are reduced by subtracting from the prime (e.g., -1 -> p - 1).
594 if (auto intToFeltOp = dyn_cast<cast::IntToFeltOp>(op)) {
595 auto operand = lookup(intToFeltOp.getValue(), scope);
596 if (!operand) {
597 return operand.takeError();
598 }
599 if (std::holds_alternative<bool>(*operand)) {
600 return bind({WitnessVal(field.reduce(std::get<bool>(*operand) ? 1 : 0))});
601 }
602 auto integer = asIndex(*operand);
603 if (!integer) {
604 return integer.takeError();
605 }
606 return bind({WitnessVal(field.reduce(*integer))});
607 }
608 // Field elements are unsigned. If the field element would overflow the 64-bit
609 // index, an error is reported.
610 if (auto feltToIndexOp = dyn_cast<cast::FeltToIndexOp>(op)) {
611 auto operand = lookup(feltToIndexOp.getValue(), scope);
612 if (!operand) {
613 return operand.takeError();
614 }
615 auto feltValue = asFelt(*operand);
616 if (!feltValue) {
617 return feltValue.takeError();
618 }
619 auto &felt = *feltValue;
620 if (felt < 0 || felt > std::numeric_limits<int64_t>::max()) {
621 return makeError("felt value does not fit in index");
622 }
623 return bind({WitnessVal(int64_t(felt))});
624 }
625
626 if (auto structNewOp = dyn_cast<component::CreateStructOp>(op)) {
627 auto value = defaultValue(
628 structNewOp.getType(), tables, structNewOp.getOperation(), field, uninitializedBehavior,
629 &rng
630 );
631 if (!value) {
632 return value.takeError();
633 }
634 return bind({*value});
635 }
636 if (auto readMemberOp = dyn_cast<component::MemberReadOp>(op)) {
637 auto componentValue = lookup(readMemberOp.getComponent(), scope);
638 if (!componentValue) {
639 return componentValue.takeError();
640 }
641 auto structValue = asStruct(*componentValue);
642 if (!structValue) {
643 return structValue.takeError();
644 }
645 auto it = (*structValue)->members.find(readMemberOp.getMemberName());
646 if (it == (*structValue)->members.end()) {
647 return makeError("missing struct member");
648 }
649 return bind({it->second});
650 }
651 if (auto writeMemberOp = dyn_cast<component::MemberWriteOp>(op)) {
652 auto componentValue = lookup(writeMemberOp.getComponent(), scope);
653 auto memberValue = lookup(writeMemberOp.getVal(), scope);
654 if (!componentValue) {
655 return componentValue.takeError();
656 }
657 if (!memberValue) {
658 return memberValue.takeError();
659 }
660 auto structValue = asStruct(*componentValue);
661 if (!structValue) {
662 return structValue.takeError();
663 }
664 (*structValue)->members[writeMemberOp.getMemberName()] = *memberValue;
665 return BlockResult {};
666 }
667
668 if (auto newPodOp = dyn_cast<pod::NewPodOp>(op)) {
669 auto podValue = defaultValue(
670 newPodOp.getType(), tables, newPodOp.getOperation(), field, uninitializedBehavior, &rng
671 );
672 if (!podValue) {
673 return podValue.takeError();
674 }
675 auto podRef = asPod(*podValue);
676 if (!podRef) {
677 return podRef.takeError();
678 }
679 auto initValues = newPodOp.getInitializedRecordValues();
680 for (pod::RecordValue init : initValues) {
681 auto value = lookup(init.value, scope);
682 if (!value) {
683 return value.takeError();
684 }
685 (*podRef)->records[init.name] = *value;
686 }
687 return bind({*podRef});
688 }
689 if (auto readPodOp = dyn_cast<pod::ReadPodOp>(op)) {
690 auto podValue = lookup(readPodOp.getPodRef(), scope);
691 if (!podValue) {
692 return podValue.takeError();
693 }
694 auto podRef = asPod(*podValue);
695 if (!podRef) {
696 return podRef.takeError();
697 }
698 auto it = (*podRef)->records.find(readPodOp.getRecordName());
699 if (it == (*podRef)->records.end()) {
700 return makeError("missing pod record");
701 }
702 return bind({it->second});
703 }
704 if (auto writePodOp = dyn_cast<pod::WritePodOp>(op)) {
705 auto podValue = lookup(writePodOp.getPodRef(), scope);
706 auto recordValue = lookup(writePodOp.getValue(), scope);
707 if (!podValue) {
708 return podValue.takeError();
709 }
710 if (!recordValue) {
711 return recordValue.takeError();
712 }
713 auto podRef = asPod(*podValue);
714 if (!podRef) {
715 return podRef.takeError();
716 }
717 (*podRef)->records[writePodOp.getRecordName()] = *recordValue;
718 return BlockResult {};
719 }
720
721 if (auto arrayNewOp = dyn_cast<array::CreateArrayOp>(op)) {
722 auto arrayValue = std::make_shared<ArrayValue>();
723 arrayValue->type = arrayNewOp.getType();
724 if (arrayNewOp.getElements().empty()) {
725 auto elementCount =
726 getStaticShapeElementCount(arrayValue->type.getShape(), "array.create default value");
727 if (!elementCount) {
728 return elementCount.takeError();
729 }
730 arrayValue->elements.reserve(*elementCount);
731 for (size_t i = 0; i < *elementCount; ++i) {
732 auto elem = defaultValue(
733 arrayValue->type.getElementType(), tables, arrayNewOp.getOperation(), field,
734 uninitializedBehavior, &rng
735 );
736 if (!elem) {
737 return elem.takeError();
738 }
739 arrayValue->elements.push_back(*elem);
740 }
741 } else {
742 auto values = collectOperands(arrayNewOp.getElements(), scope);
743 if (!values) {
744 return values.takeError();
745 }
746 arrayValue->elements.assign(values->begin(), values->end());
747 }
748 return bind({arrayValue});
749 }
750 if (auto readArrayOp = dyn_cast<array::ReadArrayOp>(op)) {
751 auto arrayValue = lookup(readArrayOp.getArrRef(), scope);
752 if (!arrayValue) {
753 return arrayValue.takeError();
754 }
755 auto arrayRef = asArray(*arrayValue);
756 if (!arrayRef) {
757 return arrayRef.takeError();
758 }
759 llvm::SmallVector<int64_t> indices;
760 for (mlir::Value indexVal : readArrayOp.getIndices()) {
761 auto value = lookup(indexVal, scope);
762 if (!value) {
763 return value.takeError();
764 }
765 auto index = asIndex(*value);
766 if (!index) {
767 return index.takeError();
768 }
769 indices.push_back(*index);
770 }
771 auto offset =
772 checkedLinearize((*arrayRef)->type.getShape(), indices, "array index out of bounds");
773 if (!offset) {
774 return offset.takeError();
775 }
776 return bind({(*arrayRef)->elements[*offset]});
777 }
778 if (auto writeArrayOp = dyn_cast<array::WriteArrayOp>(op)) {
779 auto arrayValue = lookup(writeArrayOp.getArrRef(), scope);
780 auto rvalue = lookup(writeArrayOp.getRvalue(), scope);
781 if (!arrayValue) {
782 return arrayValue.takeError();
783 }
784 if (!rvalue) {
785 return rvalue.takeError();
786 }
787 auto arrayRef = asArray(*arrayValue);
788 if (!arrayRef) {
789 return arrayRef.takeError();
790 }
791 llvm::SmallVector<int64_t> indices;
792 for (mlir::Value indexVal : writeArrayOp.getIndices()) {
793 auto value = lookup(indexVal, scope);
794 if (!value) {
795 return value.takeError();
796 }
797 auto index = asIndex(*value);
798 if (!index) {
799 return index.takeError();
800 }
801 indices.push_back(*index);
802 }
803 auto offset =
804 checkedLinearize((*arrayRef)->type.getShape(), indices, "array index out of bounds");
805 if (!offset) {
806 return offset.takeError();
807 }
808 (*arrayRef)->elements[*offset] = *rvalue;
809 return BlockResult {};
810 }
811 if (auto extractArrayOp = dyn_cast<array::ExtractArrayOp>(op)) {
812 auto arrayValue = lookup(extractArrayOp.getArrRef(), scope);
813 if (!arrayValue) {
814 return arrayValue.takeError();
815 }
816 auto arrayRef = asArray(*arrayValue);
817 if (!arrayRef) {
818 return arrayRef.takeError();
819 }
820 llvm::SmallVector<int64_t> indices;
821 for (mlir::Value indexVal : extractArrayOp.getIndices()) {
822 auto value = lookup(indexVal, scope);
823 if (!value) {
824 return value.takeError();
825 }
826 auto index = asIndex(*value);
827 if (!index) {
828 return index.takeError();
829 }
830 indices.push_back(*index);
831 }
832 llvm::ArrayRef<int64_t> shape = (*arrayRef)->type.getShape();
833 if (indices.size() >= shape.size()) {
834 return makeError("array.extract indices exceed array rank");
835 }
836 auto subArraySize =
837 getStaticShapeElementCount(shape.drop_front(indices.size()), "array.extract shape");
838 if (!subArraySize) {
839 return subArraySize.takeError();
840 }
841 auto prefixOffset =
842 checkedLinearize(shape.take_front(indices.size()), indices, "array index out of bounds");
843 if (!prefixOffset) {
844 return prefixOffset.takeError();
845 }
846 bool baseOverflow = false;
847 size_t base = llvm::SaturatingMultiply(*prefixOffset, *subArraySize, &baseOverflow);
848 if (baseOverflow) {
849 return makeError("array.extract element offset would overflow size_t");
850 }
851 auto subArray = std::make_shared<ArrayValue>();
852 subArray->type = extractArrayOp.getType();
853 subArray->elements.reserve(*subArraySize);
854 for (size_t i = 0; i < *subArraySize; ++i) {
855 bool overflow = false;
856 size_t elementOffset = llvm::SaturatingAdd(base, i, &overflow);
857 if (overflow) {
858 return makeError("array.extract element offset would overflow size_t");
859 }
860 subArray->elements.push_back((*arrayRef)->elements[elementOffset]);
861 }
862 return bind({subArray});
863 }
864 if (auto insertArrayOp = dyn_cast<array::InsertArrayOp>(op)) {
865 auto arrayValue = lookup(insertArrayOp.getArrRef(), scope);
866 auto subArrayValue = lookup(insertArrayOp.getRvalue(), scope);
867 if (!arrayValue) {
868 return arrayValue.takeError();
869 }
870 if (!subArrayValue) {
871 return subArrayValue.takeError();
872 }
873 auto arrayRef = asArray(*arrayValue);
874 auto subArrayRef = asArray(*subArrayValue);
875 if (!arrayRef) {
876 return arrayRef.takeError();
877 }
878 if (!subArrayRef) {
879 return subArrayRef.takeError();
880 }
881 llvm::SmallVector<int64_t> indices;
882 for (mlir::Value indexVal : insertArrayOp.getIndices()) {
883 auto value = lookup(indexVal, scope);
884 if (!value) {
885 return value.takeError();
886 }
887 auto index = asIndex(*value);
888 if (!index) {
889 return index.takeError();
890 }
891 indices.push_back(*index);
892 }
893 llvm::ArrayRef<int64_t> shape = (*arrayRef)->type.getShape();
894 size_t subArraySize = (*subArrayRef)->elements.size();
895 auto prefixOffset =
896 checkedLinearize(shape.take_front(indices.size()), indices, "array index out of bounds");
897 if (!prefixOffset) {
898 return prefixOffset.takeError();
899 }
900 bool baseOverflow = false;
901 size_t base = llvm::SaturatingMultiply(*prefixOffset, subArraySize, &baseOverflow);
902 if (baseOverflow) {
903 return makeError("array.insert element offset would overflow size_t");
904 }
905 for (size_t i = 0; i < subArraySize; ++i) {
906 bool overflow = false;
907 size_t elementOffset = llvm::SaturatingAdd(base, i, &overflow);
908 if (overflow) {
909 return makeError("array.insert element offset would overflow size_t");
910 }
911 (*arrayRef)->elements[elementOffset] = (*subArrayRef)->elements[i];
912 }
913 return BlockResult {};
914 }
915 if (auto arrayLenOp = dyn_cast<array::ArrayLengthOp>(op)) {
916 auto dimValue = lookup(arrayLenOp.getDim(), scope);
917 if (!dimValue) {
918 return dimValue.takeError();
919 }
920 auto dim = asIndex(*dimValue);
921 if (!dim) {
922 return dim.takeError();
923 }
924 llvm::ArrayRef<int64_t> shape = arrayLenOp.getArrRefType().getShape();
925 auto dimIndex = checkedShapeDimToSize(*dim, "array.len dimension");
926 if (!dimIndex) {
927 return dimIndex.takeError();
928 }
929 if (*dimIndex >= shape.size()) {
930 return makeError("array.len dimension out of bounds");
931 }
932 return bind({WitnessVal(shape[*dimIndex])});
933 }
934
935 if (auto callOp = dyn_cast<function::CallOp>(op)) {
936 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
937 return makeError("templated or affine-instantiated calls are not supported in llzk-witgen");
938 }
939 auto callee = lookupTopLevelSymbol<function::FuncDefOp>(tables, callOp.getCalleeAttr(), &op);
940 if (failed(callee)) {
941 return makeError("could not resolve called function");
942 }
943 auto args = collectOperands(callOp.getArgOperands(), scope);
944 if (!args) {
945 return args.takeError();
946 }
947 auto results = run(callee->get(), *args);
948 if (!results) {
949 return results.takeError();
950 }
951 return bind(*results);
952 }
953
954 auto handleBinaryIndex = [&](auto arithOp, auto fn) -> llvm::Expected<BlockResult> {
955 auto lhs = lookup(arithOp.getLhs(), scope);
956 auto rhs = lookup(arithOp.getRhs(), scope);
957 if (!lhs) {
958 return lhs.takeError();
959 }
960 if (!rhs) {
961 return rhs.takeError();
962 }
963 auto lhsValue = asIndex(*lhs);
964 if (!lhsValue) {
965 return lhsValue.takeError();
966 }
967 auto rhsValue = asIndex(*rhs);
968 if (!rhsValue) {
969 return rhsValue.takeError();
970 }
971 return bind({WitnessVal(fn(*lhsValue, *rhsValue))});
972 };
973
974 if (auto addIOp = dyn_cast<arith::AddIOp>(op)) {
975 return handleBinaryIndex(addIOp, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
976 }
977 if (auto subIOp = dyn_cast<arith::SubIOp>(op)) {
978 return handleBinaryIndex(subIOp, [](int64_t lhs, int64_t rhs) { return lhs - rhs; });
979 }
980 if (auto mulIOp = dyn_cast<arith::MulIOp>(op)) {
981 return handleBinaryIndex(mulIOp, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
982 }
983 if (auto divUIOp = dyn_cast<arith::DivUIOp>(op)) {
984 return handleBinaryIndex(divUIOp, [](int64_t lhs, int64_t rhs) {
985 // Unsigned division directly interprets int64 as unsigned value.
986 auto divRes = static_cast<uint64_t>(lhs) / static_cast<uint64_t>(rhs);
987 return static_cast<int64_t>(divRes);
988 });
989 }
990 if (auto cmpIOp = dyn_cast<arith::CmpIOp>(op)) {
991 return handleBinaryIndex(cmpIOp, [&cmpIOp](int64_t lhs, int64_t rhs) -> bool {
992 switch (cmpIOp.getPredicate()) {
993 case arith::CmpIPredicate::eq:
994 return lhs == rhs;
995 case arith::CmpIPredicate::ne:
996 return lhs != rhs;
997 case arith::CmpIPredicate::slt:
998 return lhs < rhs;
999 case arith::CmpIPredicate::sle:
1000 return lhs <= rhs;
1001 case arith::CmpIPredicate::sgt:
1002 return lhs > rhs;
1003 case arith::CmpIPredicate::sge:
1004 return lhs >= rhs;
1005 // Unsigned comparisons directly interprets int64 as unsigned value.
1006 case arith::CmpIPredicate::ult:
1007 return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
1008 case arith::CmpIPredicate::ule:
1009 return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
1010 case arith::CmpIPredicate::ugt:
1011 return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
1012 case arith::CmpIPredicate::uge:
1013 return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
1014 }
1015 llvm_unreachable("unknown comparison predicate");
1016 });
1017 }
1018
1019 if (auto selectOp = dyn_cast<arith::SelectOp>(op)) {
1020 auto cond = lookup(selectOp.getCondition(), scope);
1021 auto trueValue = lookup(selectOp.getTrueValue(), scope);
1022 auto falseValue = lookup(selectOp.getFalseValue(), scope);
1023 if (!cond) {
1024 return cond.takeError();
1025 }
1026 if (!trueValue) {
1027 return trueValue.takeError();
1028 }
1029 if (!falseValue) {
1030 return falseValue.takeError();
1031 }
1032 auto condition = asBool(*cond);
1033 if (!condition) {
1034 return condition.takeError();
1035 }
1036 return bind({*condition ? *trueValue : *falseValue});
1037 }
1038
1039 if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1040 auto cond = lookup(ifOp.getCondition(), scope);
1041 if (!cond) {
1042 return cond.takeError();
1043 }
1044 auto condition = asBool(*cond);
1045 if (!condition) {
1046 return condition.takeError();
1047 }
1048 if (!*condition && ifOp.getNumResults() == 0 && ifOp.getElseRegion().empty()) {
1049 return bind({});
1050 }
1051 Region &region = *condition ? ifOp.getThenRegion() : ifOp.getElseRegion();
1052 auto result = runRegion(region, {}, scope);
1053 if (!result) {
1054 return result.takeError();
1055 }
1056 return bind(result->values);
1057 }
1058
1059 if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1060 auto lowerBoundValue = lookup(forOp.getLowerBound(), scope);
1061 auto upperBoundValue = lookup(forOp.getUpperBound(), scope);
1062 auto stepValue = lookup(forOp.getStep(), scope);
1063 if (!lowerBoundValue) {
1064 return lowerBoundValue.takeError();
1065 }
1066 if (!upperBoundValue) {
1067 return upperBoundValue.takeError();
1068 }
1069 if (!stepValue) {
1070 return stepValue.takeError();
1071 }
1072 auto lowerBound = asIndex(*lowerBoundValue);
1073 if (!lowerBound) {
1074 return lowerBound.takeError();
1075 }
1076 auto upperBound = asIndex(*upperBoundValue);
1077 if (!upperBound) {
1078 return upperBound.takeError();
1079 }
1080 auto step = asIndex(*stepValue);
1081 if (!step) {
1082 return step.takeError();
1083 }
1084 auto iterValuesOrErr = collectOperands(forOp.getInitArgs(), scope);
1085 if (!iterValuesOrErr) {
1086 return iterValuesOrErr.takeError();
1087 }
1088 llvm::SmallVector<WitnessVal> iterValues = std::move(*iterValuesOrErr);
1089
1090 if (usesUnsignedCmp(forOp)) {
1091 // Unsigned comparison directly interprets int64 as unsigned value.
1092 auto lowerBoundUIntValue = static_cast<uint64_t>(*lowerBound);
1093 auto upperBoundUIntValue = static_cast<uint64_t>(*upperBound);
1094 auto stepUInt = static_cast<uint64_t>(*step);
1095 for (uint64_t iv = lowerBoundUIntValue, ub = upperBoundUIntValue, unsignedStep = stepUInt;
1096 iv < ub; iv += unsignedStep) {
1097 auto signedIV = checkedCast<int64_t>(iv);
1098 if (!signedIV) {
1099 return signedIV.takeError();
1100 }
1101 llvm::SmallVector<WitnessVal> regionArgs;
1102 regionArgs.push_back(WitnessVal(*signedIV));
1103 regionArgs.append(iterValues.begin(), iterValues.end());
1104 auto result = runRegion(forOp.getRegion(), regionArgs, scope);
1105 if (!result) {
1106 return result.takeError();
1107 }
1108 iterValues = std::move(result->values);
1109 }
1110 } else {
1111 for (int64_t iv = *lowerBound; iv < *upperBound; iv += *step) {
1112 llvm::SmallVector<WitnessVal> regionArgs;
1113 regionArgs.push_back(WitnessVal(iv));
1114 regionArgs.append(iterValues.begin(), iterValues.end());
1115 auto result = runRegion(forOp.getRegion(), regionArgs, scope);
1116 if (!result) {
1117 return result.takeError();
1118 }
1119 iterValues = std::move(result->values);
1120 }
1121 }
1122 return bind(iterValues);
1123 }
1124
1125 if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1126 auto iterValuesOrErr = collectOperands(whileOp.getInits(), scope);
1127 if (!iterValuesOrErr) {
1128 return iterValuesOrErr.takeError();
1129 }
1130 llvm::SmallVector<WitnessVal> iterValues = std::move(*iterValuesOrErr);
1131 while (true) {
1132 auto beforeResult = runRegion(whileOp.getBefore(), iterValues, scope);
1133 if (!beforeResult) {
1134 return beforeResult.takeError();
1135 }
1136 if (!beforeResult->terminated) {
1137 return makeError("scf.while before region must terminate with scf.condition");
1138 }
1139 if (beforeResult->values.empty()) {
1140 return makeError("scf.while before region did not produce a condition");
1141 }
1142
1143 auto condition = asBool(beforeResult->values.front());
1144 if (!condition) {
1145 return condition.takeError();
1146 }
1147
1148 llvm::SmallVector<WitnessVal> nextValues;
1149 nextValues.append(beforeResult->values.begin() + 1, beforeResult->values.end());
1150 if (!*condition) {
1151 return bind(nextValues);
1152 }
1153
1154 auto afterResult = runRegion(whileOp.getAfter(), nextValues, scope);
1155 if (!afterResult) {
1156 return afterResult.takeError();
1157 }
1158 if (!afterResult->terminated) {
1159 return makeError("scf.while after region must terminate with scf.yield");
1160 }
1161 iterValues = std::move(afterResult->values);
1162 }
1163 }
1164
1165 return makeError(llvm::Twine("unsupported op in llzk-witgen: ") + op.getName().getStringRef());
1166 }
1167};
1168
1169} // namespace
1170
1172llvm::Expected<llvm::SmallVector<WitnessVal>>
1173FunctionInterpreter::run(function::FuncDefOp funcOp, ArrayRef<WitnessVal> args) {
1174 return InvocationInterpreter(moduleOp, tables, field, uninitializedBehavior, rng)
1175 .run(funcOp, args);
1176}
1177
1178} // namespace llzk::witgen
This file implements helper methods for constructing DynamicAPInts.
This file defines methods for symbol lookup across LLZK operations and included files.
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
::mlir::Region & getBody()
Definition Ops.h.inc:690
llvm::Expected< llvm::SmallVector< WitnessVal > > run(llzk::function::FuncDefOp funcOp, mlir::ArrayRef< WitnessVal > args)
Run a function with concrete arguments and return its result values.
FunctionInterpreter(mlir::ModuleOp moduleOp, mlir::SymbolTableCollection &tables, const llzk::Field &field, UninitializedBehavior uninitializedBehavior, std::mt19937_64 rng)
Build an interpreter for one module and field configuration.
llvm::Expected< T > checkedCast(U u)
Definition WitgenUtils.h:28
llvm::Expected< PodValueRef > asPod(const WitnessVal &value)
Require a POD value from the runtime variant.
llvm::Expected< bool > asBool(const WitnessVal &value)
Require a boolean value from the runtime variant.
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
Definition ValueModel.h:55
llvm::Expected< size_t > getStaticShapeElementCount(llvm::ArrayRef< int64_t > shape, llvm::StringRef context)
Return the static element count for one shape, rejecting dynamic sizes.
llvm::Expected< WitnessVal > defaultValue(Type type, SymbolTableCollection &tables, Operation *origin, const Field &field, UninitializedBehavior behavior, std::mt19937_64 *rng)
Build a default value for a supported LLZK type.
llvm::Expected< int64_t > asIndex(const WitnessVal &value)
Require an index value from the runtime variant.
llvm::Expected< size_t > checkedShapeDimToSize(int64_t dim, llvm::StringRef context)
Convert one static dimension to size_t, rejecting dynamic or invalid sizes.
std::variant< std::monostate, bool, int64_t, llvm::DynamicAPInt, ArrayValueRef, PodValueRef, StructValueRef > WitnessVal
Runtime value representation used by the tool-local interpreter.
Definition ValueModel.h:51
llvm::Expected< llvm::DynamicAPInt > asFelt(const WitnessVal &value)
Require a felt value from the runtime variant.
llvm::Error makeError(const llvm::Twine &msg)
Build a string-backed error for user-facing witgen failures.
Definition Errors.h:18
llvm::Expected< StructValueRef > asStruct(const WitnessVal &value)
Require a struct value from the runtime variant.
llvm::Expected< ArrayValueRef > asArray(const WitnessVal &value)
Require an array value from the runtime variant.
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
DynamicAPInt modExp(const DynamicAPInt &base, const DynamicAPInt &exp, const DynamicAPInt &mod)
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)