LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
WitgenLowering.cpp
Go to the documentation of this file.
1//===-- WitgenLowering.cpp --------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "WitgenLowering.h"
11
12#include "WitgenDriver.h"
13#include "WitgenUtils.h"
14#include "WitnessSelection.h"
15
28#include "llzk/Util/Compare.h"
29#include "llzk/Util/Constants.h"
31#include "llzk/Util/Field.h"
33
34#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
35#include <mlir/Dialect/Arith/IR/Arith.h>
36#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
37#include <mlir/Dialect/Func/IR/FuncOps.h>
38#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
39#include <mlir/Dialect/MemRef/IR/MemRef.h>
40#include <mlir/Dialect/SCF/IR/SCF.h>
41#include <mlir/Dialect/Utils/IndexingUtils.h>
42#include <mlir/IR/Builders.h>
43#include <mlir/IR/BuiltinAttributes.h>
44#include <mlir/IR/BuiltinOps.h>
45#include <mlir/IR/SymbolTable.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Transforms/Passes.h>
48
49#include <llvm/ADT/APInt.h>
50#include <llvm/ADT/STLExtras.h>
51#include <llvm/ADT/SmallString.h>
52#include <llvm/ADT/StringMap.h>
53#include <llvm/ADT/TypeSwitch.h>
54#include <llvm/Support/MathExtras.h>
55
56#include <limits>
57
58using namespace mlir;
59
60namespace llzk::witgen {
61namespace {
62
64struct LoweredValue {
65 Type sourceType;
66 llvm::SmallVector<Value> leaves;
67};
68
70static FailureOr<std::reference_wrapper<const Field>> getModuleField(ModuleOp moduleOp) {
71 FieldSet fields;
72 if (failed(collectFields(moduleOp.getOperation(), fields, false))) {
73 moduleOp.emitError("failed to collect fields for llzk-witgen lowering");
74 return failure();
75 }
76 if (fields.size() != 1) {
77 moduleOp.emitError("llzk-witgen execution-engine lowering requires exactly one field");
78 return failure();
79 }
80 return *fields.begin();
81}
82
84static std::string mangleFunctionName(function::FuncDefOp funcOp) {
85 auto symbolRef = funcOp.getFullyQualifiedName(false);
86 llvm::SmallString<128> result("__llzk_witgen_");
87 for (StringRef piece : getNames(symbolRef)) {
88 if (!result.empty() && result.back() != '_') {
89 result += "__";
90 }
91 for (char c : piece) {
92 result += llvm::isAlnum(c) ? c : '_';
93 }
94 }
95 return std::string(result);
96}
97
99static Value makeIndexConstant(OpBuilder &builder, Location loc, int64_t value) {
100 return builder.create<arith::ConstantIndexOp>(loc, value).getResult();
101}
102
104static Value makeOneFelt(OpBuilder &builder, Location loc, const Field &field) {
105 return builder.create<arith::ConstantOp>(
106 loc, IntegerAttr::get(IntegerType::get(builder.getContext(), field.bitWidth()), 1)
107 );
108}
109
111static FailureOr<Type> lowerScalarType(MLIRContext *context, Type type, const Field &field) {
112 if (isa<felt::FeltType>(type)) {
113 return IntegerType::get(context, field.bitWidth());
114 }
115 if (isa<IndexType>(type)) {
116 return type;
117 }
118 if (auto intType = dyn_cast<IntegerType>(type)) {
119 if (intType.getWidth() == 1) {
120 return intType;
121 }
122 }
123 return failure();
124}
125
127static bool isScalarType(Type type) {
128 return isa<felt::FeltType, IndexType>(type) ||
129 (isa<IntegerType>(type) && mlir::cast<IntegerType>(type).getWidth() == 1);
130}
131
133static LogicalResult flattenTypeLeaves(
134 Type type, SymbolTableCollection &tables, Operation *origin, const Field &field,
135 SmallVectorImpl<Type> &out, llvm::ArrayRef<int64_t> prefixShape = {}, bool storage = false
136) {
137 auto emitScalarLeaf = [&](Type leafType) {
138 auto lowered = lowerScalarType(origin->getContext(), leafType, field);
139 if (failed(lowered)) {
140 return failure();
141 }
142 if (!storage && prefixShape.empty()) {
143 out.push_back(*lowered);
144 return success();
145 }
146 llvm::SmallVector<int64_t> shape(prefixShape.begin(), prefixShape.end());
147 if (shape.empty()) {
148 shape.push_back(1);
149 }
150 out.push_back(MemRefType::get(shape, *lowered));
151 return success();
152 };
153
154 if (isScalarType(type)) {
155 return emitScalarLeaf(type);
156 }
157
158 if (auto arrayType = dyn_cast<array::ArrayType>(type)) {
159 llvm::SmallVector<int64_t> newPrefix(prefixShape.begin(), prefixShape.end());
160 newPrefix.append(arrayType.getShape().begin(), arrayType.getShape().end());
161 return flattenTypeLeaves(
162 arrayType.getElementType(), tables, origin, field, out, newPrefix, true
163 );
164 }
165
166 if (auto podType = dyn_cast<pod::PodType>(type)) {
167 for (pod::RecordAttr record : podType.getRecords()) {
168 if (failed(
169 flattenTypeLeaves(record.getType(), tables, origin, field, out, prefixShape, true)
170 )) {
171 return failure();
172 }
173 }
174 return success();
175 }
176
177 if (auto structType = dyn_cast<component::StructType>(type)) {
178 auto def = structType.getDefinition(tables, origin);
179 if (failed(def)) {
180 origin->emitError("could not resolve struct type during witgen lowering");
181 return failure();
182 }
183 for (component::MemberDefOp member : def->get().getMemberDefs()) {
184 if (failed(
185 flattenTypeLeaves(member.getType(), tables, origin, field, out, prefixShape, true)
186 )) {
187 return failure();
188 }
189 }
190 return success();
191 }
192
193 origin->emitError("unsupported type in llzk-witgen lowering: ") << type;
194 return failure();
195}
196
198static MemRefType
199getStridedMemRefType(MLIRContext *context, ArrayRef<int64_t> shape, Type elementType) {
200 SmallVector<int64_t> strides(shape.size(), ShapedType::kDynamic);
201 return MemRefType::get(
202 shape, elementType, StridedLayoutAttr::get(context, ShapedType::kDynamic, strides)
203 );
204}
205
207static LogicalResult flattenABILeafTypes(
208 Type type, SymbolTableCollection &tables, Operation *origin, const Field &field,
209 SmallVectorImpl<Type> &out, size_t prefixRank = 0, bool aggregateStorage = false
210) {
211 auto emitScalarLeaf = [&](Type leafType) {
212 auto lowered = lowerScalarType(origin->getContext(), leafType, field);
213 if (failed(lowered)) {
214 return failure();
215 }
216 if (!aggregateStorage && prefixRank == 0) {
217 out.push_back(*lowered);
218 return success();
219 }
220 SmallVector<int64_t> shape;
221 if (prefixRank == 0) {
222 shape.push_back(1);
223 } else {
224 shape.assign(prefixRank, ShapedType::kDynamic);
225 }
226 out.push_back(getStridedMemRefType(origin->getContext(), shape, *lowered));
227 return success();
228 };
229
230 if (isScalarType(type)) {
231 return emitScalarLeaf(type);
232 }
233
234 if (auto arrayType = dyn_cast<array::ArrayType>(type)) {
235 return flattenABILeafTypes(
236 arrayType.getElementType(), tables, origin, field, out, prefixRank + arrayType.getRank(),
237 true
238 );
239 }
240
241 if (auto podType = dyn_cast<pod::PodType>(type)) {
242 for (pod::RecordAttr record : podType.getRecords()) {
243 if (failed(
244 flattenABILeafTypes(record.getType(), tables, origin, field, out, prefixRank, true)
245 )) {
246 return failure();
247 }
248 }
249 return success();
250 }
251
252 if (auto structType = dyn_cast<component::StructType>(type)) {
253 auto def = structType.getDefinition(tables, origin);
254 if (failed(def)) {
255 origin->emitError("could not resolve struct type during witgen lowering");
256 return failure();
257 }
258 for (component::MemberDefOp member : def->get().getMemberDefs()) {
259 if (failed(
260 flattenABILeafTypes(member.getType(), tables, origin, field, out, prefixRank, true)
261 )) {
262 return failure();
263 }
264 }
265 return success();
266 }
267
268 origin->emitError("unsupported type in llzk-witgen lowering: ") << type;
269 return failure();
270}
271
273static FailureOr<size_t>
274getLeafCount(Type type, SymbolTableCollection &tables, Operation *origin, const Field &field) {
275 SmallVector<Type> leaves;
276 if (failed(flattenTypeLeaves(type, tables, origin, field, leaves))) {
277 return failure();
278 }
279 return leaves.size();
280}
281
283static FailureOr<SmallVector<Type>>
284getLeafTypes(Type type, SymbolTableCollection &tables, Operation *origin, const Field &field) {
285 SmallVector<Type> leaves;
286 if (failed(flattenTypeLeaves(type, tables, origin, field, leaves))) {
287 return failure();
288 }
289 return leaves;
290}
291
293static FailureOr<SmallVector<Type>>
294getABILeafTypes(Type type, SymbolTableCollection &tables, Operation *origin, const Field &field) {
295 SmallVector<Type> leaves;
296 if (failed(flattenABILeafTypes(type, tables, origin, field, leaves))) {
297 return failure();
298 }
299 return leaves;
300}
301
303static FailureOr<std::pair<size_t, size_t>> getNamedLeafSpan(
304 Type ownerType, StringRef name, SymbolTableCollection &tables, Operation *origin,
305 const Field &field
306) {
307 if (auto podType = dyn_cast<pod::PodType>(ownerType)) {
308 size_t running = 0;
309 for (pod::RecordAttr record : podType.getRecords()) {
310 auto count = getLeafCount(record.getType(), tables, origin, field);
311 if (failed(count)) {
312 return failure();
313 }
314 if (record.getName().getValue() == name) {
315 return std::pair<size_t, size_t> {running, *count};
316 }
317 running += *count;
318 }
319 }
320
321 if (auto structType = dyn_cast<component::StructType>(ownerType)) {
322 auto def = structType.getDefinition(tables, origin);
323 if (failed(def)) {
324 origin->emitError("could not resolve struct type during witgen lowering");
325 return failure();
326 }
327 size_t running = 0;
328 for (component::MemberDefOp member : def->get().getMemberDefs()) {
329 auto count = getLeafCount(member.getType(), tables, origin, field);
330 if (failed(count)) {
331 return failure();
332 }
333 if (member.getSymName() == name) {
334 return std::pair<size_t, size_t> {running, *count};
335 }
336 running += *count;
337 }
338 }
339
340 origin->emitError("could not resolve aggregate member/record @") << name;
341 return failure();
342}
343
345static FailureOr<Type>
346getNamedSubType(Type ownerType, StringRef name, SymbolTableCollection &tables, Operation *origin) {
347 if (auto podType = dyn_cast<pod::PodType>(ownerType)) {
348 for (pod::RecordAttr record : podType.getRecords()) {
349 if (record.getName().getValue() == name) {
350 return record.getType();
351 }
352 }
353 }
354 if (auto structType = dyn_cast<component::StructType>(ownerType)) {
355 auto def = structType.getDefinition(tables, origin);
356 if (failed(def)) {
357 origin->emitError("could not resolve struct type during witgen lowering");
358 return failure();
359 }
360 for (component::MemberDefOp member : def->get().getMemberDefs()) {
361 if (member.getSymName() == name) {
362 return member.getType();
363 }
364 }
365 }
366 origin->emitError("could not resolve aggregate member/record @") << name;
367 return failure();
368}
369
371static FailureOr<Value> createZeroMemRef(OpBuilder &builder, Location loc, MemRefType memrefType) {
372 auto elementCount = getStaticElementCount(memrefType, "witgen zero memref");
373 if (!elementCount) {
374 emitError(loc) << llvm::toString(elementCount.takeError());
375 return failure();
376 }
377 Value alloc = builder.create<memref::AllocOp>(loc, memrefType);
378 auto elementType = memrefType.getElementType();
379 Value zero;
380 if (isa<IndexType>(elementType)) {
381 zero = builder.create<arith::ConstantIndexOp>(loc, 0);
382 } else {
383 zero = builder.create<arith::ConstantOp>(
384 loc, IntegerAttr::get(mlir::cast<IntegerType>(elementType), 0)
385 );
386 }
387 auto strides = mlir::computeStrides(memrefType.getShape());
388 for (size_t flat = 0; flat < *elementCount; ++flat) {
389 auto flatSigned = checkedCast<int64_t>(flat);
390 if (!flatSigned) {
391 emitError(loc) << llvm::toString(flatSigned.takeError());
392 return failure();
393 }
394 SmallVector<Value> indices;
395 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
396 indices.push_back(makeIndexConstant(builder, loc, index));
397 }
398 builder.create<memref::StoreOp>(loc, zero, alloc, indices);
399 }
400 return alloc;
401}
402
404static FailureOr<Value> createRandomMemRef(
405 OpBuilder &builder, Location loc, MemRefType memrefType, const Field &field,
406 std::mt19937_64 &rng
407) {
408 auto elementCount = getStaticElementCount(memrefType, "witgen random memref");
409 if (!elementCount) {
410 emitError(loc) << llvm::toString(elementCount.takeError());
411 return failure();
412 }
413 Value alloc = builder.create<memref::AllocOp>(loc, memrefType);
414 auto elementType = memrefType.getElementType();
415 auto strides = mlir::computeStrides(memrefType.getShape());
416 for (size_t flat = 0; flat < *elementCount; ++flat) {
417 auto flatSigned = checkedCast<int64_t>(flat);
418 if (!flatSigned) {
419 emitError(loc) << llvm::toString(flatSigned.takeError());
420 return failure();
421 }
422 SmallVector<Value> indices;
423 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
424 indices.push_back(makeIndexConstant(builder, loc, index));
425 }
426 if (isa<IndexType>(elementType)) {
427 auto value = randomIndexValue(rng);
428 builder.create<memref::StoreOp>(
429 loc, builder.create<arith::ConstantIndexOp>(loc, value), alloc, indices
430 );
431 continue;
432 }
433 auto intType = mlir::cast<IntegerType>(elementType);
434 if (intType.getWidth() == 1) {
435 builder.create<memref::StoreOp>(
436 loc,
437 builder.create<arith::ConstantOp>(
438 loc, IntegerAttr::get(intType, APInt(1, randomBoolValue(rng)))
439 ),
440 alloc, indices
441 );
442 continue;
443 }
444 auto candidate = randomFieldElement(rng, field);
445 builder.create<memref::StoreOp>(
446 loc,
447 builder.create<arith::ConstantOp>(
448 loc, IntegerAttr::get(intType, llzk::toExactWidthAPInt(candidate, intType.getWidth()))
449 ),
450 alloc, indices
451 );
452 }
453 return alloc;
454}
455
457static FailureOr<LoweredValue> createDefaultValue(
458 OpBuilder &builder, Location loc, Type type, SymbolTableCollection &tables, Operation *origin,
459 const Field &field, UninitializedBehavior behavior, std::mt19937_64 &rng
460) {
461 LoweredValue lowered {type, {}};
462 auto leafTypes = getLeafTypes(type, tables, origin, field);
463 if (failed(leafTypes)) {
464 return failure();
465 }
466 for (Type leafType : *leafTypes) {
467 if (behavior == UninitializedBehavior::Fail) {
468 origin->emitError(
469 "fail-mode default materialization is unsupported in witgen lowering because it would "
470 "hide uninitialized reads"
471 );
472 return failure();
473 }
474 if (behavior == UninitializedBehavior::Random) {
475 if (auto memrefType = dyn_cast<MemRefType>(leafType)) {
476 auto randomMemRef = createRandomMemRef(builder, loc, memrefType, field, rng);
477 if (failed(randomMemRef)) {
478 return failure();
479 }
480 lowered.leaves.push_back(*randomMemRef);
481 continue;
482 }
483 if (isa<IndexType>(leafType)) {
484 lowered.leaves.push_back(
485 builder.create<arith::ConstantIndexOp>(loc, randomIndexValue(rng))
486 );
487 continue;
488 }
489 auto intType = mlir::cast<IntegerType>(leafType);
490 if (intType.getWidth() == 1) {
491 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
492 loc, IntegerAttr::get(intType, APInt(1, randomBoolValue(rng)))
493 ));
494 continue;
495 }
496 auto candidate = randomFieldElement(rng, field);
497 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
498 loc, IntegerAttr::get(intType, llzk::toExactWidthAPInt(candidate, intType.getWidth()))
499 ));
500 continue;
501 }
502 if (auto memrefType = dyn_cast<MemRefType>(leafType)) {
503 auto zeroMemRef = createZeroMemRef(builder, loc, memrefType);
504 if (failed(zeroMemRef)) {
505 return failure();
506 }
507 lowered.leaves.push_back(*zeroMemRef);
508 continue;
509 }
510 if (isa<IndexType>(leafType)) {
511 lowered.leaves.push_back(builder.create<arith::ConstantIndexOp>(loc, 0));
512 continue;
513 }
514 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
515 loc, IntegerAttr::get(mlir::cast<IntegerType>(leafType), 0)
516 ));
517 }
518 return lowered;
519}
520
522static Value normalizeWideValue(
523 OpBuilder &builder, Location loc, Value wideValue, unsigned dstWidth, const Field &field
524) {
525 auto wideType = mlir::cast<IntegerType>(wideValue.getType());
526 Value modulus = builder.create<arith::ConstantOp>(
527 loc, field.getPrimeAttr(builder.getContext(), wideType.getWidth())
528 );
529 Value reduced = builder.create<arith::RemUIOp>(loc, wideValue, modulus);
530 return builder.create<arith::TruncIOp>(
531 loc, IntegerType::get(builder.getContext(), dstWidth), reduced
532 );
533}
534
536static Value normalizeSignedWideValue(
537 OpBuilder &builder, Location loc, Value wideValue, unsigned dstWidth, const Field &field
538) {
539 auto wideType = mlir::cast<IntegerType>(wideValue.getType());
540 Value modulus = builder.create<arith::ConstantOp>(
541 loc, field.getPrimeAttr(builder.getContext(), wideType.getWidth())
542 );
543 Value reduced = builder.create<arith::RemSIOp>(loc, wideValue, modulus);
544 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(wideType, 0));
545 Value isNegative = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, reduced, zero);
546 Value adjusted = builder.create<arith::AddIOp>(loc, reduced, modulus);
547 Value canonical = builder.create<arith::SelectOp>(loc, isNegative, adjusted, reduced);
548 return builder.create<arith::TruncIOp>(
549 loc, IntegerType::get(builder.getContext(), dstWidth), canonical
550 );
551}
552
554static Value
555lowerFeltToSignedWide(OpBuilder &builder, Location loc, Value operand, const Field &field) {
556 unsigned width = field.bitWidth();
557 unsigned wideWidth = width + 1;
558 auto feltType = IntegerType::get(builder.getContext(), width);
559 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
560 Value operandWide = builder.create<arith::ExtUIOp>(loc, wideType, operand);
561 Value prime =
562 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
563 Value half = builder.create<arith::ConstantOp>(
564 loc, IntegerAttr::get(feltType, toExactWidthAPInt(field.half(), width))
565 );
566 Value isNegative = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, operand, half);
567 Value signedOperand = builder.create<arith::SubIOp>(loc, operandWide, prime);
568 return builder.create<arith::SelectOp>(loc, isNegative, signedOperand, operandWide);
569}
570
572static void assertNonZeroFelt(OpBuilder &builder, Location loc, Value operand, StringRef message) {
573 auto operandType = mlir::cast<IntegerType>(operand.getType());
574 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(operandType, 0));
575 Value isNonZero = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, operand, zero);
576 builder.create<cf::AssertOp>(loc, isNonZero, message);
577}
578
580static Value
581lowerFeltAdd(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
582 unsigned width = field.bitWidth();
583 unsigned wideWidth = width + 1;
584 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
585 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
586 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
587 Value sum = builder.create<arith::AddIOp>(loc, lhsWide, rhsWide);
588 return normalizeWideValue(builder, loc, sum, width, field);
589}
590
592static Value
593lowerFeltSub(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
594 unsigned width = field.bitWidth();
595 unsigned wideWidth = width + 1;
596 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
597 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
598 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
599 Value modulus =
600 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
601 Value lhsPlusMod = builder.create<arith::AddIOp>(loc, lhsWide, modulus);
602 Value diff = builder.create<arith::SubIOp>(loc, lhsPlusMod, rhsWide);
603 return normalizeWideValue(builder, loc, diff, width, field);
604}
605
607static Value lowerFeltNeg(OpBuilder &builder, Location loc, Value operand, const Field &field) {
608 unsigned width = field.bitWidth();
609 unsigned wideWidth = width + 1;
610 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
611 Value operandWide = builder.create<arith::ExtUIOp>(loc, wideType, operand);
612 Value modulus =
613 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
614 Value diff = builder.create<arith::SubIOp>(loc, modulus, operandWide);
615 return normalizeWideValue(builder, loc, diff, width, field);
616}
617
619static Value
620lowerFeltMul(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
621 unsigned width = field.bitWidth();
622 unsigned wideWidth = width * 2;
623 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
624 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
625 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
626 Value product = builder.create<arith::MulIOp>(loc, lhsWide, rhsWide);
627 return normalizeWideValue(builder, loc, product, width, field);
628}
629
631static Value lowerFeltInv(OpBuilder &builder, Location loc, Value operand, const Field &field) {
632 llvm::APInt exponent = toExactWidthAPInt(field.prime() - 2, field.bitWidth());
633 Value result = makeOneFelt(builder, loc, field);
634 Value base = operand;
635 for (unsigned bit = 0; bit < exponent.getBitWidth(); ++bit) {
636 if (exponent[bit]) {
637 result = lowerFeltMul(builder, loc, result, base, field);
638 }
639 if (bit + 1 < exponent.getBitWidth()) {
640 base = lowerFeltMul(builder, loc, base, base, field);
641 }
642 }
643 return result;
644}
645
647static Value
648lowerFeltDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
649 return lowerFeltMul(builder, loc, lhs, lowerFeltInv(builder, loc, rhs, field), field);
650}
651
653static Value
654lowerFeltPow(OpBuilder &builder, Location loc, Value base, Value exponent, const Field &field) {
655 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
656 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 0));
657 Value one = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 1));
658 Value result = makeOneFelt(builder, loc, field);
659 Value currentBase = base;
660 for (unsigned bit = 0; bit < field.bitWidth(); ++bit) {
661 Value bitIndex = builder.create<arith::ConstantOp>(
662 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), bit))
663 );
664 Value shifted = builder.create<arith::ShRUIOp>(loc, exponent, bitIndex);
665 Value masked = builder.create<arith::AndIOp>(loc, shifted, one);
666 Value bitIsSet = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, masked, zero);
667 auto ifOp = builder.create<scf::IfOp>(loc, TypeRange {feltType}, bitIsSet, true);
668 {
669 OpBuilder::InsertionGuard guard(builder);
670 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
671 Value multiplied = lowerFeltMul(builder, loc, result, currentBase, field);
672 builder.create<scf::YieldOp>(loc, multiplied);
673 }
674 {
675 OpBuilder::InsertionGuard guard(builder);
676 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
677 builder.create<scf::YieldOp>(loc, result);
678 }
679 result = ifOp.getResult(0);
680 if (bit + 1 < field.bitWidth()) {
681 currentBase = lowerFeltMul(builder, loc, currentBase, currentBase, field);
682 }
683 }
684 return result;
685}
686
688static Value
689lowerFeltShl(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
690 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
691 Value two = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 2));
692 return lowerFeltMul(builder, loc, lhs, lowerFeltPow(builder, loc, two, rhs, field), field);
693}
694
696static Value
697lowerFeltOr(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
698 unsigned width = field.bitWidth();
699 auto wideType = IntegerType::get(builder.getContext(), width + 1);
700 Value orValue = builder.create<arith::OrIOp>(loc, lhs, rhs);
701 Value orWide = builder.create<arith::ExtUIOp>(loc, wideType, orValue);
702 return normalizeWideValue(builder, loc, orWide, width, field);
703}
704
706static Value
707lowerFeltXor(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
708 unsigned width = field.bitWidth();
709 auto wideType = IntegerType::get(builder.getContext(), width + 1);
710 Value xorValue = builder.create<arith::XOrIOp>(loc, lhs, rhs);
711 Value xorWide = builder.create<arith::ExtUIOp>(loc, wideType, xorValue);
712 return normalizeWideValue(builder, loc, xorWide, width, field);
713}
714
716static Value lowerFeltUnsignedDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
717 return builder.create<arith::DivUIOp>(loc, lhs, rhs);
718}
719
721static Value
722lowerFeltSignedDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
723 unsigned width = field.bitWidth();
724 Value lhsSigned = lowerFeltToSignedWide(builder, loc, lhs, field);
725 Value rhsSigned = lowerFeltToSignedWide(builder, loc, rhs, field);
726 Value quotient = builder.create<arith::DivSIOp>(loc, lhsSigned, rhsSigned);
727 return normalizeSignedWideValue(builder, loc, quotient, width, field);
728}
729
731static Value lowerFeltUnsignedMod(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
732 return builder.create<arith::RemUIOp>(loc, lhs, rhs);
733}
734
736static Value
737lowerFeltSignedMod(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
738 unsigned width = field.bitWidth();
739 Value lhsSigned = lowerFeltToSignedWide(builder, loc, lhs, field);
740 Value rhsSigned = lowerFeltToSignedWide(builder, loc, rhs, field);
741 Value remainder = builder.create<arith::RemSIOp>(loc, lhsSigned, rhsSigned);
742 return normalizeSignedWideValue(builder, loc, remainder, width, field);
743}
744
746static Value
747lowerFeltShr(OpBuilder &builder, Location loc, Value lhs, Value rhs, const Field &field) {
748 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
749 Value width = builder.create<arith::ConstantOp>(
750 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), field.bitWidth()))
751 );
752 Value shiftTooLarge = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, rhs, width);
753 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 0));
754 Value maxValidShift = builder.create<arith::ConstantOp>(
755 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), field.bitWidth() - 1))
756 );
757 Value clampedShift = builder.create<arith::MinUIOp>(loc, rhs, maxValidShift);
758 Value shifted = builder.create<arith::ShRUIOp>(loc, lhs, clampedShift);
759 return builder.create<arith::SelectOp>(loc, shiftTooLarge, zero, shifted);
760}
761
763static Value lowerFeltNot(OpBuilder &builder, Location loc, Value operand, const Field &field) {
764 unsigned width = field.bitWidth();
765 auto feltType = IntegerType::get(builder.getContext(), width);
766 auto wideType = IntegerType::get(builder.getContext(), width + 1);
767 Value maxMask = builder.create<arith::ConstantOp>(
768 loc, IntegerAttr::get(feltType, llvm::APInt::getAllOnes(width))
769 );
770 Value complement = builder.create<arith::XOrIOp>(loc, operand, maxMask);
771 Value complementWide = builder.create<arith::ExtUIOp>(loc, wideType, complement);
772 return normalizeWideValue(builder, loc, complementWide, width, field);
773}
774
776static Value loadStorageScalar(OpBuilder &builder, Location loc, Value storageLeaf) {
777 auto memrefType = mlir::cast<MemRefType>(storageLeaf.getType());
778 SmallVector<Value> indices;
779 indices.reserve(memrefType.getRank());
780 for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
781 indices.push_back(makeIndexConstant(builder, loc, 0));
782 }
783 return builder.create<memref::LoadOp>(loc, storageLeaf, indices);
784}
785
787static void storeStorageScalar(OpBuilder &builder, Location loc, Value scalar, Value storageLeaf) {
788 auto memrefType = mlir::cast<MemRefType>(storageLeaf.getType());
789 SmallVector<Value> indices;
790 indices.reserve(memrefType.getRank());
791 for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
792 indices.push_back(makeIndexConstant(builder, loc, 0));
793 }
794 builder.create<memref::StoreOp>(loc, scalar, storageLeaf, indices);
795}
796
798static LogicalResult copyIntoStorage(
799 OpBuilder &builder, Location loc, Type sourceType, ArrayRef<Value> destLeaves,
800 ArrayRef<Value> sourceLeaves, SymbolTableCollection &tables, Operation *origin,
801 const Field &field
802) {
803 auto leafTypes = getLeafTypes(sourceType, tables, origin, field);
804 if (failed(leafTypes)) {
805 return failure();
806 }
807 if (destLeaves.size() != sourceLeaves.size() || destLeaves.size() != leafTypes->size()) {
808 origin->emitError("flattened leaf mismatch while copying aggregate storage");
809 return failure();
810 }
811 for (auto [leafType, destLeaf, srcLeaf] : llvm::zip(*leafTypes, destLeaves, sourceLeaves)) {
812 if (isa<MemRefType>(leafType)) {
813 builder.create<memref::CopyOp>(loc, srcLeaf, destLeaf);
814 continue;
815 }
816 storeStorageScalar(builder, loc, srcLeaf, destLeaf);
817 }
818 return success();
819}
820
822static FailureOr<LoweredValue> readNamedAggregateValue(
823 OpBuilder &builder, Location loc, Type ownerType, StringRef name, const LoweredValue &owner,
824 SymbolTableCollection &tables, Operation *origin, const Field &field
825) {
826 auto subType = getNamedSubType(ownerType, name, tables, origin);
827 if (failed(subType)) {
828 return failure();
829 }
830 auto span = getNamedLeafSpan(ownerType, name, tables, origin, field);
831 if (failed(span)) {
832 return failure();
833 }
834 LoweredValue result {*subType, {}};
835 auto leafTypes = getLeafTypes(*subType, tables, origin, field);
836 if (failed(leafTypes)) {
837 return failure();
838 }
839 auto leaves = ArrayRef<Value>(owner.leaves).slice(span->first, span->second);
840 for (auto [leafType, leafValue] : llvm::zip(*leafTypes, leaves)) {
841 if (isa<MemRefType>(leafType)) {
842 result.leaves.push_back(leafValue);
843 } else {
844 result.leaves.push_back(loadStorageScalar(builder, loc, leafValue));
845 }
846 }
847 return result;
848}
849
851static LogicalResult writeNamedAggregateValue(
852 OpBuilder &builder, Location loc, Type ownerType, StringRef name, LoweredValue &owner,
853 const LoweredValue &value, SymbolTableCollection &tables, Operation *origin, const Field &field
854) {
855 auto subType = getNamedSubType(ownerType, name, tables, origin);
856 if (failed(subType)) {
857 return failure();
858 }
859 auto span = getNamedLeafSpan(ownerType, name, tables, origin, field);
860 if (failed(span)) {
861 return failure();
862 }
863 return copyIntoStorage(
864 builder, loc, *subType, ArrayRef<Value>(owner.leaves).slice(span->first, span->second),
865 value.leaves, tables, origin, field
866 );
867}
868
870static FailureOr<Value>
871createElementSubview(OpBuilder &builder, Location loc, Value source, ValueRange outerIndices) {
872 auto sourceType = mlir::cast<MemRefType>(source.getType());
873 SmallVector<OpFoldResult> mixedOffsets;
874 SmallVector<OpFoldResult> mixedSizes;
875 SmallVector<OpFoldResult> mixedStrides;
876 auto indexedRank = checkedCast<int64_t>(outerIndices.size());
877 if (!indexedRank) {
878 emitError(loc) << llvm::toString(indexedRank.takeError());
879 return failure();
880 }
881 mixedOffsets.reserve(sourceType.getRank());
882 mixedSizes.reserve(sourceType.getRank());
883 mixedStrides.reserve(sourceType.getRank());
884 for (Value index : outerIndices) {
885 mixedOffsets.push_back(index);
886 }
887 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
888 mixedOffsets.push_back(builder.getIndexAttr(0));
889 }
890 for (int64_t dim = 0; dim < *indexedRank; ++dim) {
891 mixedSizes.push_back(builder.getIndexAttr(1));
892 }
893 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
894 mixedSizes.push_back(memref::getMixedSize(builder, loc, source, dim));
895 }
896 for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
897 mixedStrides.push_back(builder.getIndexAttr(1));
898 }
899 SmallVector<int64_t> desiredShape;
900 auto reserveSize = checkedCast<size_t>(sourceType.getRank() - *indexedRank);
901 if (!reserveSize) {
902 emitError(loc) << llvm::toString(reserveSize.takeError());
903 return failure();
904 }
905 desiredShape.reserve(*reserveSize);
906 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
907 auto dimIndex = checkedCast<size_t>(dim);
908 if (!dimIndex) {
909 emitError(loc) << llvm::toString(dimIndex.takeError());
910 return failure();
911 }
912 if (auto attr = llvm::dyn_cast<Attribute>(mixedSizes[*dimIndex])) {
913 desiredShape.push_back(mlir::cast<IntegerAttr>(attr).getInt());
914 } else {
915 desiredShape.push_back(ShapedType::kDynamic);
916 }
917 }
918 if (desiredShape.empty()) {
919 desiredShape.push_back(1);
920 }
921 auto resultType = mlir::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
922 desiredShape, sourceType, mixedOffsets, mixedSizes, mixedStrides
923 ));
924 auto op = builder.create<memref::SubViewOp>(
925 loc, resultType, source, mixedOffsets, mixedSizes, mixedStrides
926 );
927 return success(op.getResult());
928}
929
931static FailureOr<LoweredValue> readArrayElement(
932 OpBuilder &builder, Location loc, array::ArrayType arrayType, const LoweredValue &arrayValue,
933 ArrayRef<Value> indices
934) {
935 Type elementType = arrayType.getElementType();
936 LoweredValue result {elementType, {}};
937 if (isScalarType(elementType)) {
938 result.leaves.push_back(
939 builder.create<memref::LoadOp>(loc, arrayValue.leaves.front(), indices)
940 );
941 return result;
942 }
943
944 for (Value sourceLeaf : arrayValue.leaves) {
945 auto subview = createElementSubview(builder, loc, sourceLeaf, indices);
946 if (failed(subview)) {
947 return failure();
948 }
949 result.leaves.push_back(*subview);
950 }
951 return result;
952}
953
955static LogicalResult writeArrayElement(
956 OpBuilder &builder, Location loc, array::ArrayType arrayType, LoweredValue &arrayValue,
957 ArrayRef<Value> indices, const LoweredValue &elementValue
958) {
959 Type elementType = arrayType.getElementType();
960 if (isScalarType(elementType)) {
961 builder.create<memref::StoreOp>(
962 loc, elementValue.leaves.front(), arrayValue.leaves.front(), indices
963 );
964 return success();
965 }
966
967 for (auto [destLeaf, srcLeaf] : llvm::zip(arrayValue.leaves, elementValue.leaves)) {
968 auto subview = createElementSubview(builder, loc, destLeaf, indices);
969 if (failed(subview)) {
970 return failure();
971 }
972 builder.create<memref::CopyOp>(loc, srcLeaf, *subview);
973 }
974 return success();
975}
976
978static LogicalResult appendFlatLeavesToTypes(
979 OpBuilder &builder, Location loc, const LoweredValue &value, ArrayRef<Type> targetLeafTypes,
980 SmallVectorImpl<Value> &out, Operation *origin
981) {
982 if (targetLeafTypes.size() != value.leaves.size()) {
983 origin->emitError("flattened leaf mismatch during call lowering");
984 return failure();
985 }
986 for (auto [leafValue, leafType] : llvm::zip(value.leaves, targetLeafTypes)) {
987 if (leafValue.getType() == leafType) {
988 out.push_back(leafValue);
989 continue;
990 }
991 if (isa<MemRefType>(leafValue.getType()) && isa<MemRefType>(leafType)) {
992 out.push_back(builder.create<memref::CastOp>(loc, leafType, leafValue));
993 continue;
994 }
995 origin->emitError("lowered leaf type mismatch during call lowering");
996 return failure();
997 }
998 return success();
999}
1000
1002class BodyLowerer {
1003public:
1005 BodyLowerer(
1006 ModuleOp mod, SymbolTableCollection &symbolTables, const Field &moduleField,
1007 const WitgenOptions &options
1008 )
1009 : moduleOp(mod), tables(symbolTables), field(moduleField),
1010 uninitializedBehavior(options.uninitializedBehavior), rng(makeDefaultValueRng(options)) {}
1011
1013 FailureOr<func::FuncOp> lowerFunction(function::FuncDefOp funcOp) {
1014 if (funcOp.isExternal()) {
1015 funcOp.emitError("execution-engine backend does not lower extern functions");
1016 return failure();
1017 }
1018 if (!funcOp.getBody().hasOneBlock()) {
1019 funcOp.emitError("execution-engine backend only supports single-block functions");
1020 return failure();
1021 }
1022
1023 SmallVector<Type> loweredArgTypes;
1024 for (Type argType : funcOp.getArgumentTypes()) {
1025 if (failed(
1026 flattenABILeafTypes(argType, tables, funcOp.getOperation(), field, loweredArgTypes)
1027 )) {
1028 return failure();
1029 }
1030 }
1031 SmallVector<Type> loweredResultTypes;
1032 for (Type resultType : funcOp.getResultTypes()) {
1033 if (failed(flattenABILeafTypes(
1034 resultType, tables, funcOp.getOperation(), field, loweredResultTypes
1035 ))) {
1036 return failure();
1037 }
1038 }
1039
1040 OpBuilder moduleBuilder(moduleOp.getContext());
1041 moduleBuilder.setInsertionPointToEnd(moduleOp.getBody());
1042 auto loweredFunc = moduleBuilder.create<func::FuncOp>(
1043 funcOp.getLoc(), mangleFunctionName(funcOp),
1044 moduleBuilder.getFunctionType(loweredArgTypes, loweredResultTypes)
1045 );
1046 Block *entry = loweredFunc.addEntryBlock();
1047 OpBuilder builder(entry, entry->begin());
1048
1049 DenseMap<Value, LoweredValue> valueMap;
1050 unsigned cursor = 0;
1051 for (auto [arg, argType] :
1052 llvm::zip(funcOp.getBody().front().getArguments(), funcOp.getArgumentTypes())) {
1053 auto leafCount = getLeafCount(argType, tables, funcOp.getOperation(), field);
1054 if (failed(leafCount)) {
1055 loweredFunc.erase();
1056 return failure();
1057 }
1058 LoweredValue lowered {argType, {}};
1059 lowered.leaves.append(
1060 entry->getArguments().begin() + cursor,
1061 entry->getArguments().begin() + cursor + *leafCount
1062 );
1063 cursor += *leafCount;
1064 valueMap[arg] = std::move(lowered);
1065 }
1066
1067 if (failed(lowerBlock(builder, funcOp.getBody().front(), valueMap))) {
1068 loweredFunc.erase();
1069 return failure();
1070 }
1071 return loweredFunc;
1072 }
1073
1074private:
1075 ModuleOp moduleOp;
1076 SymbolTableCollection &tables;
1077 const Field &field;
1078 UninitializedBehavior uninitializedBehavior;
1079 std::mt19937_64 rng;
1080
1082 FailureOr<LoweredValue>
1083 lookup(Value value, DenseMap<Value, LoweredValue> &valueMap, Operation *origin) {
1084 auto it = valueMap.find(value);
1085 if (it == valueMap.end()) {
1086 origin->emitError("failed to find lowered SSA value");
1087 return failure();
1088 }
1089 return it->second;
1090 }
1091
1093 FailureOr<Value>
1094 lookupScalar(Value value, DenseMap<Value, LoweredValue> &valueMap, Operation *origin) {
1095 auto lowered = lookup(value, valueMap, origin);
1096 if (failed(lowered) || lowered->leaves.size() != 1 ||
1097 isa<MemRefType>(lowered->leaves.front().getType())) {
1098 origin->emitError("expected scalar lowered value");
1099 return failure();
1100 }
1101 return lowered->leaves.front();
1102 }
1103
1105 LogicalResult
1106 lowerBlock(OpBuilder &builder, Block &block, DenseMap<Value, LoweredValue> &valueMap) {
1107 for (Operation &op : block) {
1108 if (failed(lowerOperation(builder, op, valueMap))) {
1109 return failure();
1110 }
1111 }
1112 return success();
1113 }
1114
1116 FailureOr<Value>
1117 lowerFeltCmp(OpBuilder &builder, Location loc, boolean::CmpOp cmpOp, Value lhs, Value rhs) {
1118 arith::CmpIPredicate predicate;
1119 switch (cmpOp.getPredicate()) {
1121 predicate = arith::CmpIPredicate::eq;
1122 break;
1124 predicate = arith::CmpIPredicate::ne;
1125 break;
1127 predicate = arith::CmpIPredicate::ult;
1128 break;
1130 predicate = arith::CmpIPredicate::ule;
1131 break;
1133 predicate = arith::CmpIPredicate::ugt;
1134 break;
1136 predicate = arith::CmpIPredicate::uge;
1137 break;
1138 }
1139 return builder.create<arith::CmpIOp>(loc, predicate, lhs, rhs).getResult();
1140 }
1141
1143 LogicalResult
1144 lowerOperation(OpBuilder &builder, Operation &op, DenseMap<Value, LoweredValue> &valueMap) {
1145 Location loc = op.getLoc();
1146
1147 auto bind = [&](Value result, LoweredValue lowered) {
1148 valueMap[result] = std::move(lowered);
1149 return success();
1150 };
1151
1152 if (auto returnOp = dyn_cast<function::ReturnOp>(op)) {
1153 SmallVector<Value> results;
1154 for (Value operand : returnOp.getOperands()) {
1155 auto lowered = lookup(operand, valueMap, returnOp.getOperation());
1156 auto leafTypes = getABILeafTypes(operand.getType(), tables, returnOp.getOperation(), field);
1157 if (failed(lowered) || failed(leafTypes) ||
1158 failed(appendFlatLeavesToTypes(
1159 builder, loc, *lowered, *leafTypes, results, returnOp.getOperation()
1160 ))) {
1161 return failure();
1162 }
1163 }
1164 builder.create<func::ReturnOp>(loc, results);
1165 return success();
1166 }
1167
1168 if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1169 SmallVector<Value> results;
1170 for (Value operand : yieldOp.getOperands()) {
1171 auto lowered = lookup(operand, valueMap, yieldOp.getOperation());
1172 auto leafTypes = getABILeafTypes(operand.getType(), tables, yieldOp.getOperation(), field);
1173 if (failed(lowered) || failed(leafTypes) ||
1174 failed(appendFlatLeavesToTypes(
1175 builder, loc, *lowered, *leafTypes, results, yieldOp.getOperation()
1176 ))) {
1177 return failure();
1178 }
1179 }
1180 builder.create<scf::YieldOp>(loc, results);
1181 return success();
1182 }
1183 if (auto conditionOp = dyn_cast<scf::ConditionOp>(op)) {
1184 auto condition =
1185 lookupScalar(conditionOp.getCondition(), valueMap, conditionOp.getOperation());
1186 if (failed(condition)) {
1187 return failure();
1188 }
1189 SmallVector<Value> results;
1190 for (Value operand : conditionOp.getArgs()) {
1191 auto lowered = lookup(operand, valueMap, conditionOp.getOperation());
1192 auto leafTypes =
1193 getABILeafTypes(operand.getType(), tables, conditionOp.getOperation(), field);
1194 if (failed(lowered) || failed(leafTypes) ||
1195 failed(appendFlatLeavesToTypes(
1196 builder, loc, *lowered, *leafTypes, results, conditionOp.getOperation()
1197 ))) {
1198 return failure();
1199 }
1200 }
1201 builder.create<scf::ConditionOp>(loc, *condition, results);
1202 return success();
1203 }
1204
1205 if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1206 Operation *clone = builder.clone(op);
1207 return bind(
1208 constantOp.getResult(), LoweredValue {constantOp.getType(), {clone->getResult(0)}}
1209 );
1210 }
1211
1212 if (auto feltConst = dyn_cast<felt::FeltConstantOp>(op)) {
1213 auto intType = IntegerType::get(builder.getContext(), field.bitWidth());
1214 // Reduce into the field first, then build an APInt with the exact storage width.
1215 auto constVal = toDynamicAPInt(feltConst.getValue().getValue());
1216 auto modVal = constVal % field.prime();
1217 auto intVal = llzk::toExactWidthAPInt(modVal, field.bitWidth());
1218 Value lowered = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, intVal));
1219 return bind(feltConst.getResult(), LoweredValue {feltConst.getType(), {lowered}});
1220 }
1221
1222 if (auto nondetOp = dyn_cast<llzk::NonDetOp>(op)) {
1223 auto lowered = createDefaultValue(
1224 builder, loc, nondetOp.getType(), tables, nondetOp.getOperation(), field,
1225 uninitializedBehavior, rng
1226 );
1227 if (failed(lowered)) {
1228 return failure();
1229 }
1230 return bind(nondetOp.getResult(), std::move(*lowered));
1231 }
1232
1233 if (auto addOp = dyn_cast<felt::AddFeltOp>(op)) {
1234 auto lhs = lookupScalar(addOp.getLhs(), valueMap, addOp.getOperation());
1235 auto rhs = lookupScalar(addOp.getRhs(), valueMap, addOp.getOperation());
1236 if (failed(lhs) || failed(rhs)) {
1237 return failure();
1238 }
1239 return bind(
1240 addOp.getResult(),
1241 LoweredValue {addOp.getType(), {lowerFeltAdd(builder, loc, *lhs, *rhs, field)}}
1242 );
1243 }
1244 if (auto powOp = dyn_cast<felt::PowFeltOp>(op)) {
1245 auto lhs = lookupScalar(powOp.getLhs(), valueMap, powOp.getOperation());
1246 auto rhs = lookupScalar(powOp.getRhs(), valueMap, powOp.getOperation());
1247 if (failed(lhs) || failed(rhs)) {
1248 return failure();
1249 }
1250 return bind(
1251 powOp.getResult(),
1252 LoweredValue {powOp.getType(), {lowerFeltPow(builder, loc, *lhs, *rhs, field)}}
1253 );
1254 }
1255 if (auto andOp = dyn_cast<felt::AndFeltOp>(op)) {
1256 auto lhs = lookupScalar(andOp.getLhs(), valueMap, andOp.getOperation());
1257 auto rhs = lookupScalar(andOp.getRhs(), valueMap, andOp.getOperation());
1258 if (failed(lhs) || failed(rhs)) {
1259 return failure();
1260 }
1261 return bind(
1262 andOp.getResult(),
1263 LoweredValue {andOp.getType(), {builder.create<arith::AndIOp>(loc, *lhs, *rhs)}}
1264 );
1265 }
1266 if (auto orOp = dyn_cast<felt::OrFeltOp>(op)) {
1267 auto lhs = lookupScalar(orOp.getLhs(), valueMap, orOp.getOperation());
1268 auto rhs = lookupScalar(orOp.getRhs(), valueMap, orOp.getOperation());
1269 if (failed(lhs) || failed(rhs)) {
1270 return failure();
1271 }
1272 return bind(
1273 orOp.getResult(),
1274 LoweredValue {orOp.getType(), {lowerFeltOr(builder, loc, *lhs, *rhs, field)}}
1275 );
1276 }
1277 if (auto xorOp = dyn_cast<felt::XorFeltOp>(op)) {
1278 auto lhs = lookupScalar(xorOp.getLhs(), valueMap, xorOp.getOperation());
1279 auto rhs = lookupScalar(xorOp.getRhs(), valueMap, xorOp.getOperation());
1280 if (failed(lhs) || failed(rhs)) {
1281 return failure();
1282 }
1283 return bind(
1284 xorOp.getResult(),
1285 LoweredValue {xorOp.getType(), {lowerFeltXor(builder, loc, *lhs, *rhs, field)}}
1286 );
1287 }
1288 if (auto subOp = dyn_cast<felt::SubFeltOp>(op)) {
1289 auto lhs = lookupScalar(subOp.getLhs(), valueMap, subOp.getOperation());
1290 auto rhs = lookupScalar(subOp.getRhs(), valueMap, subOp.getOperation());
1291 if (failed(lhs) || failed(rhs)) {
1292 return failure();
1293 }
1294 return bind(
1295 subOp.getResult(),
1296 LoweredValue {subOp.getType(), {lowerFeltSub(builder, loc, *lhs, *rhs, field)}}
1297 );
1298 }
1299 if (auto mulOp = dyn_cast<felt::MulFeltOp>(op)) {
1300 auto lhs = lookupScalar(mulOp.getLhs(), valueMap, mulOp.getOperation());
1301 auto rhs = lookupScalar(mulOp.getRhs(), valueMap, mulOp.getOperation());
1302 if (failed(lhs) || failed(rhs)) {
1303 return failure();
1304 }
1305 return bind(
1306 mulOp.getResult(),
1307 LoweredValue {mulOp.getType(), {lowerFeltMul(builder, loc, *lhs, *rhs, field)}}
1308 );
1309 }
1310 if (auto negOp = dyn_cast<felt::NegFeltOp>(op)) {
1311 auto operand = lookupScalar(negOp.getOperand(), valueMap, negOp.getOperation());
1312 if (failed(operand)) {
1313 return failure();
1314 }
1315 return bind(
1316 negOp.getResult(),
1317 LoweredValue {negOp.getType(), {lowerFeltNeg(builder, loc, *operand, field)}}
1318 );
1319 }
1320 if (auto invOp = dyn_cast<felt::InvFeltOp>(op)) {
1321 auto operand = lookupScalar(invOp.getOperand(), valueMap, invOp.getOperation());
1322 if (failed(operand)) {
1323 return failure();
1324 }
1325 return bind(
1326 invOp.getResult(),
1327 LoweredValue {invOp.getType(), {lowerFeltInv(builder, loc, *operand, field)}}
1328 );
1329 }
1330 if (auto divOp = dyn_cast<felt::DivFeltOp>(op)) {
1331 auto lhs = lookupScalar(divOp.getLhs(), valueMap, divOp.getOperation());
1332 auto rhs = lookupScalar(divOp.getRhs(), valueMap, divOp.getOperation());
1333 if (failed(lhs) || failed(rhs)) {
1334 return failure();
1335 }
1336 return bind(
1337 divOp.getResult(),
1338 LoweredValue {divOp.getType(), {lowerFeltDiv(builder, loc, *lhs, *rhs, field)}}
1339 );
1340 }
1341 if (auto uintDivOp = dyn_cast<felt::UnsignedIntDivFeltOp>(op)) {
1342 auto lhs = lookupScalar(uintDivOp.getLhs(), valueMap, uintDivOp.getOperation());
1343 auto rhs = lookupScalar(uintDivOp.getRhs(), valueMap, uintDivOp.getOperation());
1344 if (failed(lhs) || failed(rhs)) {
1345 return failure();
1346 }
1347 assertNonZeroFelt(builder, loc, *rhs, "felt.uintdiv divisor must be non-zero");
1348 return bind(
1349 uintDivOp.getResult(),
1350 LoweredValue {uintDivOp.getType(), {lowerFeltUnsignedDiv(builder, loc, *lhs, *rhs)}}
1351 );
1352 }
1353 if (auto sintDivOp = dyn_cast<felt::SignedIntDivFeltOp>(op)) {
1354 auto lhs = lookupScalar(sintDivOp.getLhs(), valueMap, sintDivOp.getOperation());
1355 auto rhs = lookupScalar(sintDivOp.getRhs(), valueMap, sintDivOp.getOperation());
1356 if (failed(lhs) || failed(rhs)) {
1357 return failure();
1358 }
1359 assertNonZeroFelt(builder, loc, *rhs, "felt.sintdiv divisor must be non-zero");
1360 return bind(
1361 sintDivOp.getResult(),
1362 LoweredValue {sintDivOp.getType(), {lowerFeltSignedDiv(builder, loc, *lhs, *rhs, field)}}
1363 );
1364 }
1365 if (auto umodOp = dyn_cast<felt::UnsignedModFeltOp>(op)) {
1366 auto lhs = lookupScalar(umodOp.getLhs(), valueMap, umodOp.getOperation());
1367 auto rhs = lookupScalar(umodOp.getRhs(), valueMap, umodOp.getOperation());
1368 if (failed(lhs) || failed(rhs)) {
1369 return failure();
1370 }
1371 assertNonZeroFelt(builder, loc, *rhs, "felt.umod divisor must be non-zero");
1372 return bind(
1373 umodOp.getResult(),
1374 LoweredValue {umodOp.getType(), {lowerFeltUnsignedMod(builder, loc, *lhs, *rhs)}}
1375 );
1376 }
1377 if (auto smodOp = dyn_cast<felt::SignedModFeltOp>(op)) {
1378 auto lhs = lookupScalar(smodOp.getLhs(), valueMap, smodOp.getOperation());
1379 auto rhs = lookupScalar(smodOp.getRhs(), valueMap, smodOp.getOperation());
1380 if (failed(lhs) || failed(rhs)) {
1381 return failure();
1382 }
1383 assertNonZeroFelt(builder, loc, *rhs, "felt.smod divisor must be non-zero");
1384 return bind(
1385 smodOp.getResult(),
1386 LoweredValue {smodOp.getType(), {lowerFeltSignedMod(builder, loc, *lhs, *rhs, field)}}
1387 );
1388 }
1389 if (auto shrOp = dyn_cast<felt::ShrFeltOp>(op)) {
1390 auto lhs = lookupScalar(shrOp.getLhs(), valueMap, shrOp.getOperation());
1391 auto rhs = lookupScalar(shrOp.getRhs(), valueMap, shrOp.getOperation());
1392 if (failed(lhs) || failed(rhs)) {
1393 return failure();
1394 }
1395 return bind(
1396 shrOp.getResult(),
1397 LoweredValue {shrOp.getType(), {lowerFeltShr(builder, loc, *lhs, *rhs, field)}}
1398 );
1399 }
1400 if (auto shlOp = dyn_cast<felt::ShlFeltOp>(op)) {
1401 auto lhs = lookupScalar(shlOp.getLhs(), valueMap, shlOp.getOperation());
1402 auto rhs = lookupScalar(shlOp.getRhs(), valueMap, shlOp.getOperation());
1403 if (failed(lhs) || failed(rhs)) {
1404 return failure();
1405 }
1406 return bind(
1407 shlOp.getResult(),
1408 LoweredValue {shlOp.getType(), {lowerFeltShl(builder, loc, *lhs, *rhs, field)}}
1409 );
1410 }
1411 if (auto notOp = dyn_cast<felt::NotFeltOp>(op)) {
1412 auto operand = lookupScalar(notOp.getOperand(), valueMap, notOp.getOperation());
1413 if (failed(operand)) {
1414 return failure();
1415 }
1416 return bind(
1417 notOp.getResult(),
1418 LoweredValue {notOp.getType(), {lowerFeltNot(builder, loc, *operand, field)}}
1419 );
1420 }
1421
1422 if (auto cmpOp = dyn_cast<boolean::CmpOp>(op)) {
1423 auto lhs = lookupScalar(cmpOp.getLhs(), valueMap, cmpOp.getOperation());
1424 auto rhs = lookupScalar(cmpOp.getRhs(), valueMap, cmpOp.getOperation());
1425 if (failed(lhs) || failed(rhs)) {
1426 return failure();
1427 }
1428 auto lowered = lowerFeltCmp(builder, loc, cmpOp, *lhs, *rhs);
1429 if (failed(lowered)) {
1430 return failure();
1431 }
1432 return bind(cmpOp.getResult(), LoweredValue {cmpOp.getType(), {*lowered}});
1433 }
1434 if (auto assertOp = dyn_cast<boolean::AssertOp>(op)) {
1435 auto condition = lookupScalar(assertOp.getCondition(), valueMap, assertOp.getOperation());
1436 if (failed(condition)) {
1437 return failure();
1438 }
1439 builder.create<cf::AssertOp>(
1440 loc, *condition, assertOp.getMsg() ? assertOp.getMsg()->str() : "bool.assert failed"
1441 );
1442 return success();
1443 }
1444 if (auto andOp = dyn_cast<boolean::AndBoolOp>(op)) {
1445 auto lhs = lookupScalar(andOp.getLhs(), valueMap, andOp.getOperation());
1446 auto rhs = lookupScalar(andOp.getRhs(), valueMap, andOp.getOperation());
1447 if (failed(lhs) || failed(rhs)) {
1448 return failure();
1449 }
1450 return bind(
1451 andOp.getResult(),
1452 LoweredValue {andOp.getType(), {builder.create<arith::AndIOp>(loc, *lhs, *rhs)}}
1453 );
1454 }
1455 if (auto orOp = dyn_cast<boolean::OrBoolOp>(op)) {
1456 auto lhs = lookupScalar(orOp.getLhs(), valueMap, orOp.getOperation());
1457 auto rhs = lookupScalar(orOp.getRhs(), valueMap, orOp.getOperation());
1458 if (failed(lhs) || failed(rhs)) {
1459 return failure();
1460 }
1461 return bind(
1462 orOp.getResult(),
1463 LoweredValue {orOp.getType(), {builder.create<arith::OrIOp>(loc, *lhs, *rhs)}}
1464 );
1465 }
1466 if (auto xorOp = dyn_cast<boolean::XorBoolOp>(op)) {
1467 auto lhs = lookupScalar(xorOp.getLhs(), valueMap, xorOp.getOperation());
1468 auto rhs = lookupScalar(xorOp.getRhs(), valueMap, xorOp.getOperation());
1469 if (failed(lhs) || failed(rhs)) {
1470 return failure();
1471 }
1472 return bind(
1473 xorOp.getResult(),
1474 LoweredValue {xorOp.getType(), {builder.create<arith::XOrIOp>(loc, *lhs, *rhs)}}
1475 );
1476 }
1477 if (auto notOp = dyn_cast<boolean::NotBoolOp>(op)) {
1478 auto operand = lookupScalar(notOp.getOperand(), valueMap, notOp.getOperation());
1479 if (failed(operand)) {
1480 return failure();
1481 }
1482 Value one = builder.create<arith::ConstantOp>(
1483 loc, IntegerAttr::get(IntegerType::get(builder.getContext(), 1), 1)
1484 );
1485 return bind(
1486 notOp.getResult(),
1487 LoweredValue {notOp.getType(), {builder.create<arith::XOrIOp>(loc, *operand, one)}}
1488 );
1489 }
1490
1491 if (auto intToFelt = dyn_cast<cast::IntToFeltOp>(op)) {
1492 auto operand = lookupScalar(intToFelt.getValue(), valueMap, intToFelt.getOperation());
1493 if (failed(operand)) {
1494 return failure();
1495 }
1496 auto dstType = IntegerType::get(builder.getContext(), field.bitWidth());
1497 Value lowered;
1498 if (isa<IndexType>((*operand).getType())) {
1499 lowered = builder.create<arith::IndexCastUIOp>(loc, dstType, *operand);
1500 } else {
1501 auto intType = mlir::cast<IntegerType>((*operand).getType());
1502 if (intType.getWidth() < dstType.getWidth()) {
1503 lowered = builder.create<arith::ExtUIOp>(loc, dstType, *operand);
1504 } else if (intType.getWidth() > dstType.getWidth()) {
1505 lowered = normalizeWideValue(builder, loc, *operand, dstType.getWidth(), field);
1506 } else {
1507 lowered = *operand;
1508 }
1509 }
1510 return bind(intToFelt.getResult(), LoweredValue {intToFelt.getType(), {lowered}});
1511 }
1512 if (auto feltToIndex = dyn_cast<cast::FeltToIndexOp>(op)) {
1513 auto operand = lookupScalar(feltToIndex.getValue(), valueMap, feltToIndex.getOperation());
1514 if (failed(operand)) {
1515 return failure();
1516 }
1517 return bind(
1518 feltToIndex.getResult(),
1519 LoweredValue {
1520 feltToIndex.getType(),
1521 {builder.create<arith::IndexCastUIOp>(loc, builder.getIndexType(), *operand)}
1522 }
1523 );
1524 }
1525
1526 if (auto structNewOp = dyn_cast<component::CreateStructOp>(op)) {
1527 auto lowered = createDefaultValue(
1528 builder, loc, structNewOp.getType(), tables, structNewOp.getOperation(), field,
1529 uninitializedBehavior, rng
1530 );
1531 if (failed(lowered)) {
1532 return failure();
1533 }
1534 return bind(structNewOp.getResult(), std::move(*lowered));
1535 }
1536 if (auto readMemberOp = dyn_cast<component::MemberReadOp>(op)) {
1537 auto componentValue =
1538 lookup(readMemberOp.getComponent(), valueMap, readMemberOp.getOperation());
1539 if (failed(componentValue)) {
1540 return failure();
1541 }
1542 auto lowered = readNamedAggregateValue(
1543 builder, loc, readMemberOp.getComponent().getType(), readMemberOp.getMemberName(),
1544 *componentValue, tables, readMemberOp.getOperation(), field
1545 );
1546 if (failed(lowered)) {
1547 return failure();
1548 }
1549 return bind(readMemberOp.getResult(), std::move(*lowered));
1550 }
1551 if (auto writeMemberOp = dyn_cast<component::MemberWriteOp>(op)) {
1552 auto componentValue =
1553 lookup(writeMemberOp.getComponent(), valueMap, writeMemberOp.getOperation());
1554 auto memberValue = lookup(writeMemberOp.getVal(), valueMap, writeMemberOp.getOperation());
1555 if (failed(componentValue) || failed(memberValue)) {
1556 return failure();
1557 }
1558 return writeNamedAggregateValue(
1559 builder, loc, writeMemberOp.getComponent().getType(), writeMemberOp.getMemberName(),
1560 valueMap[writeMemberOp.getComponent()], *memberValue, tables,
1561 writeMemberOp.getOperation(), field
1562 );
1563 }
1564
1565 if (auto newPodOp = dyn_cast<pod::NewPodOp>(op)) {
1566 auto lowered = createDefaultValue(
1567 builder, loc, newPodOp.getType(), tables, newPodOp.getOperation(), field,
1568 uninitializedBehavior, rng
1569 );
1570 if (failed(lowered)) {
1571 return failure();
1572 }
1573 for (pod::RecordValue init : newPodOp.getInitializedRecordValues()) {
1574 auto value = lookup(init.value, valueMap, newPodOp.getOperation());
1575 if (failed(value) || failed(writeNamedAggregateValue(
1576 builder, loc, newPodOp.getType(), init.name, *lowered, *value,
1577 tables, newPodOp.getOperation(), field
1578 ))) {
1579 return failure();
1580 }
1581 }
1582 return bind(newPodOp.getResult(), std::move(*lowered));
1583 }
1584 if (auto readPodOp = dyn_cast<pod::ReadPodOp>(op)) {
1585 auto podValue = lookup(readPodOp.getPodRef(), valueMap, readPodOp.getOperation());
1586 if (failed(podValue)) {
1587 return failure();
1588 }
1589 auto lowered = readNamedAggregateValue(
1590 builder, loc, readPodOp.getPodRef().getType(), readPodOp.getRecordName(), *podValue,
1591 tables, readPodOp.getOperation(), field
1592 );
1593 if (failed(lowered)) {
1594 return failure();
1595 }
1596 return bind(readPodOp.getResult(), std::move(*lowered));
1597 }
1598 if (auto writePodOp = dyn_cast<pod::WritePodOp>(op)) {
1599 auto recordValue = lookup(writePodOp.getValue(), valueMap, writePodOp.getOperation());
1600 if (failed(recordValue)) {
1601 return failure();
1602 }
1603 return writeNamedAggregateValue(
1604 builder, loc, writePodOp.getPodRef().getType(), writePodOp.getRecordName(),
1605 valueMap[writePodOp.getPodRef()], *recordValue, tables, writePodOp.getOperation(), field
1606 );
1607 }
1608
1609 if (auto arrayNewOp = dyn_cast<array::CreateArrayOp>(op)) {
1610 auto lowered = createDefaultValue(
1611 builder, loc, arrayNewOp.getType(), tables, arrayNewOp.getOperation(), field,
1612 uninitializedBehavior, rng
1613 );
1614 if (failed(lowered)) {
1615 return failure();
1616 }
1617 if (!arrayNewOp.getElements().empty()) {
1618 auto elementCount = checkedCast<size_t>(arrayNewOp.getType().getNumElements());
1619 if (!elementCount) {
1620 arrayNewOp.emitError() << llvm::toString(elementCount.takeError());
1621 return failure();
1622 }
1623 if (arrayNewOp.getElements().size() != *elementCount) {
1624 arrayNewOp.emitError("expected one explicit element per array slot in witgen lowering");
1625 return failure();
1626 }
1627 auto shape = arrayNewOp.getType().getShape();
1628 for (auto [flatIndex, operand] : llvm::enumerate(arrayNewOp.getElements())) {
1629 auto elementValue = lookup(operand, valueMap, arrayNewOp.getOperation());
1630 if (failed(elementValue)) {
1631 return failure();
1632 }
1633 SmallVector<Value> indices;
1634 auto strides = mlir::computeStrides(shape);
1635 auto flatSigned = checkedCast<int64_t>(flatIndex);
1636 if (!flatSigned) {
1637 arrayNewOp.emitError() << llvm::toString(flatSigned.takeError());
1638 return failure();
1639 }
1640 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
1641 indices.push_back(makeIndexConstant(builder, loc, index));
1642 }
1643 if (failed(writeArrayElement(
1644 builder, loc, arrayNewOp.getType(), *lowered, indices, *elementValue
1645 ))) {
1646 return failure();
1647 }
1648 }
1649 }
1650 return bind(arrayNewOp.getResult(), std::move(*lowered));
1651 }
1652 if (auto readArrayOp = dyn_cast<array::ReadArrayOp>(op)) {
1653 SmallVector<Value> indices;
1654 for (Value indexValue : readArrayOp.getIndices()) {
1655 auto loweredIndex = lookupScalar(indexValue, valueMap, readArrayOp.getOperation());
1656 if (failed(loweredIndex)) {
1657 return failure();
1658 }
1659 indices.push_back(*loweredIndex);
1660 }
1661 auto arrayValue = lookup(readArrayOp.getArrRef(), valueMap, readArrayOp.getOperation());
1662 if (failed(arrayValue)) {
1663 return failure();
1664 }
1665 auto lowered = readArrayElement(
1666 builder, loc, mlir::cast<array::ArrayType>(readArrayOp.getArrRef().getType()),
1667 *arrayValue, indices
1668 );
1669 if (failed(lowered)) {
1670 return failure();
1671 }
1672 return bind(readArrayOp.getResult(), std::move(*lowered));
1673 }
1674 if (auto writeArrayOp = dyn_cast<array::WriteArrayOp>(op)) {
1675 SmallVector<Value> indices;
1676 for (Value indexValue : writeArrayOp.getIndices()) {
1677 auto loweredIndex = lookupScalar(indexValue, valueMap, writeArrayOp.getOperation());
1678 if (failed(loweredIndex)) {
1679 return failure();
1680 }
1681 indices.push_back(*loweredIndex);
1682 }
1683 auto elementValue = lookup(writeArrayOp.getRvalue(), valueMap, writeArrayOp.getOperation());
1684 if (failed(elementValue)) {
1685 return failure();
1686 }
1687 return writeArrayElement(
1688 builder, loc, mlir::cast<array::ArrayType>(writeArrayOp.getArrRef().getType()),
1689 valueMap[writeArrayOp.getArrRef()], indices, *elementValue
1690 );
1691 }
1692
1693 if (auto cmpiOp = dyn_cast<arith::CmpIOp>(op)) {
1694 auto lhs = lookupScalar(cmpiOp.getLhs(), valueMap, cmpiOp.getOperation());
1695 auto rhs = lookupScalar(cmpiOp.getRhs(), valueMap, cmpiOp.getOperation());
1696 if (failed(lhs) || failed(rhs)) {
1697 return failure();
1698 }
1699 return bind(
1700 cmpiOp.getResult(),
1701 LoweredValue {
1702 cmpiOp.getType(),
1703 {builder.create<arith::CmpIOp>(loc, cmpiOp.getPredicate(), *lhs, *rhs)}
1704 }
1705 );
1706 }
1707 if (auto selectOp = dyn_cast<arith::SelectOp>(op)) {
1708 auto cond = lookupScalar(selectOp.getCondition(), valueMap, selectOp.getOperation());
1709 auto trueValue = lookupScalar(selectOp.getTrueValue(), valueMap, selectOp.getOperation());
1710 auto falseValue = lookupScalar(selectOp.getFalseValue(), valueMap, selectOp.getOperation());
1711 if (failed(cond) || failed(trueValue) || failed(falseValue)) {
1712 return failure();
1713 }
1714 return bind(
1715 selectOp.getResult(),
1716 LoweredValue {
1717 selectOp.getType(),
1718 {builder.create<arith::SelectOp>(loc, *cond, *trueValue, *falseValue)}
1719 }
1720 );
1721 }
1722 if (auto addiOp = dyn_cast<arith::AddIOp>(op)) {
1723 auto lhs = lookupScalar(addiOp.getLhs(), valueMap, addiOp.getOperation());
1724 auto rhs = lookupScalar(addiOp.getRhs(), valueMap, addiOp.getOperation());
1725 if (failed(lhs) || failed(rhs)) {
1726 return failure();
1727 }
1728 return bind(
1729 addiOp.getResult(),
1730 LoweredValue {addiOp.getType(), {builder.create<arith::AddIOp>(loc, *lhs, *rhs)}}
1731 );
1732 }
1733 if (auto subiOp = dyn_cast<arith::SubIOp>(op)) {
1734 auto lhs = lookupScalar(subiOp.getLhs(), valueMap, subiOp.getOperation());
1735 auto rhs = lookupScalar(subiOp.getRhs(), valueMap, subiOp.getOperation());
1736 if (failed(lhs) || failed(rhs)) {
1737 return failure();
1738 }
1739 return bind(
1740 subiOp.getResult(),
1741 LoweredValue {subiOp.getType(), {builder.create<arith::SubIOp>(loc, *lhs, *rhs)}}
1742 );
1743 }
1744
1745 if (auto callOp = dyn_cast<function::CallOp>(op)) {
1746 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
1747 callOp.emitError("execution-engine backend encountered an unflattened function.call");
1748 return failure();
1749 }
1750 auto *callable = callOp.resolveCallableInTable(&tables);
1751 auto callee = dyn_cast_or_null<function::FuncDefOp>(callable);
1752 if (!callee) {
1753 callOp.emitError("failed to resolve callee during execution-engine lowering");
1754 return failure();
1755 }
1756 SmallVector<Type> resultTypes;
1757 for (Type resultType : callOp.getResultTypes()) {
1758 if (failed(
1759 flattenABILeafTypes(resultType, tables, callOp.getOperation(), field, resultTypes)
1760 )) {
1761 return failure();
1762 }
1763 }
1764 SmallVector<Value> flatArgs;
1765 for (Value operand : callOp.getArgOperands()) {
1766 auto lowered = lookup(operand, valueMap, callOp.getOperation());
1767 auto leafTypes = getABILeafTypes(operand.getType(), tables, callOp.getOperation(), field);
1768 if (failed(lowered) || failed(leafTypes) ||
1769 failed(appendFlatLeavesToTypes(
1770 builder, loc, *lowered, *leafTypes, flatArgs, callOp.getOperation()
1771 ))) {
1772 return failure();
1773 }
1774 }
1775 auto loweredCall =
1776 builder.create<func::CallOp>(loc, mangleFunctionName(callee), resultTypes, flatArgs);
1777 auto loweredCallResults = loweredCall.getResults();
1778 size_t totalResults = loweredCallResults.size();
1779 size_t cursor = 0;
1780 for (auto [oldResult, oldType] : llvm::zip(callOp.getResults(), callOp.getResultTypes())) {
1781 auto leafCount = getLeafCount(oldType, tables, callOp.getOperation(), field);
1782 if (failed(leafCount)) {
1783 return failure();
1784 }
1785 bool overflow = false;
1786 size_t nextCursor = llvm::SaturatingAdd(cursor, *leafCount, &overflow);
1787 if (overflow || nextCursor > totalResults) {
1788 callOp.emitError("leaf count overflow while lowering function call results");
1789 return failure();
1790 }
1791 LoweredValue lowered {oldType, {}};
1792 lowered.leaves.append(
1793 loweredCallResults.begin() + static_cast<ptrdiff_t>(cursor),
1794 loweredCallResults.begin() + static_cast<ptrdiff_t>(nextCursor)
1795 );
1796 valueMap[oldResult] = std::move(lowered);
1797 cursor = nextCursor;
1798 }
1799 return success();
1800 }
1801
1802 if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1803 SmallVector<Value> initArgs;
1804 SmallVector<size_t> beforeLeafCounts;
1805 for (auto [init, initType] : llvm::zip(whileOp.getInits(), whileOp.getOperandTypes())) {
1806 auto lowered = lookup(init, valueMap, whileOp.getOperation());
1807 auto leafTypes = getABILeafTypes(initType, tables, whileOp.getOperation(), field);
1808 if (failed(lowered) || failed(leafTypes) ||
1809 failed(appendFlatLeavesToTypes(
1810 builder, loc, *lowered, *leafTypes, initArgs, whileOp.getOperation()
1811 ))) {
1812 return failure();
1813 }
1814 auto count = getLeafCount(initType, tables, whileOp.getOperation(), field);
1815 if (failed(count)) {
1816 return failure();
1817 }
1818 beforeLeafCounts.push_back(*count);
1819 }
1820
1821 SmallVector<size_t> resultLeafCounts;
1822 SmallVector<Type> loweredResultTypes;
1823 for (Type resultType : whileOp.getResultTypes()) {
1824 auto leafTypes = getABILeafTypes(resultType, tables, whileOp.getOperation(), field);
1825 auto count = getLeafCount(resultType, tables, whileOp.getOperation(), field);
1826 if (failed(leafTypes) || failed(count)) {
1827 return failure();
1828 }
1829 loweredResultTypes.append(leafTypes->begin(), leafTypes->end());
1830 resultLeafCounts.push_back(*count);
1831 }
1832
1833 auto mapRegionArguments = [&](auto oldArgs, auto oldTypes, auto leafCounts, auto newArgs,
1834 StringRef overflowMessage,
1835 DenseMap<Value, LoweredValue> &regionMap) -> LogicalResult {
1836 size_t totalArgs = newArgs.size();
1837 size_t cursor = 0;
1838 for (auto [oldArg, oldType, leafCount] : llvm::zip(oldArgs, oldTypes, leafCounts)) {
1839 bool overflow = false;
1840 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1841 if (overflow || nextCursor > totalArgs) {
1842 whileOp.emitError(overflowMessage);
1843 return failure();
1844 }
1845 LoweredValue lowered {oldType, {}};
1846 lowered.leaves.append(
1847 newArgs.begin() + llzk::checkedCast<ptrdiff_t>(cursor),
1848 newArgs.begin() + llzk::checkedCast<ptrdiff_t>(nextCursor)
1849 );
1850 regionMap[oldArg] = std::move(lowered);
1851 cursor = nextCursor;
1852 }
1853 return success();
1854 };
1855
1856 LogicalResult whileLoweringStatus = success();
1857 auto newWhile = builder.create<scf::WhileOp>(
1858 loc, loweredResultTypes, initArgs,
1859 [&](OpBuilder &regionBuilder, Location /*regionLoc*/, ValueRange beforeArgs) {
1860 DenseMap<Value, LoweredValue> beforeMap(valueMap.begin(), valueMap.end());
1861 if (failed(mapRegionArguments(
1862 whileOp.getBeforeArguments(), whileOp.getOperandTypes(), beforeLeafCounts,
1863 beforeArgs, "leaf count overflow while lowering while-loop before-region args",
1864 beforeMap
1865 )) ||
1866 failed(lowerBlock(regionBuilder, whileOp.getBefore().front(), beforeMap))) {
1867 whileLoweringStatus = failure();
1868 }
1869 }, [&](OpBuilder &regionBuilder, Location /*regionLoc*/, ValueRange afterArgs) {
1870 DenseMap<Value, LoweredValue> afterMap(valueMap.begin(), valueMap.end());
1871 if (failed(mapRegionArguments(
1872 whileOp.getAfterArguments(), whileOp.getResultTypes(), resultLeafCounts, afterArgs,
1873 "leaf count overflow while lowering while-loop after-region args", afterMap
1874 )) ||
1875 failed(lowerBlock(regionBuilder, whileOp.getAfter().front(), afterMap))) {
1876 whileLoweringStatus = failure();
1877 }
1878 }
1879 );
1880 if (failed(whileLoweringStatus)) {
1881 newWhile.erase();
1882 return failure();
1883 }
1884
1885 auto newWhileResults = newWhile.getResults();
1886 size_t totalResults = newWhileResults.size();
1887 size_t cursor = 0;
1888 for (auto [oldResult, oldType, leafCount] :
1889 llvm::zip(whileOp.getResults(), whileOp.getResultTypes(), resultLeafCounts)) {
1890 bool overflow = false;
1891 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1892 if (overflow || nextCursor > totalResults) {
1893 whileOp.emitError("leaf count overflow while lowering while-loop results");
1894 return failure();
1895 }
1896 LoweredValue lowered {oldType, {}};
1897 lowered.leaves.append(
1898 newWhileResults.begin() + llzk::checkedCast<ptrdiff_t>(cursor),
1899 newWhileResults.begin() + llzk::checkedCast<ptrdiff_t>(nextCursor)
1900 );
1901 valueMap[oldResult] = std::move(lowered);
1902 cursor = nextCursor;
1903 }
1904 return success();
1905 }
1906
1907 if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1908 auto condition = lookupScalar(ifOp.getCondition(), valueMap, ifOp.getOperation());
1909 if (failed(condition)) {
1910 return failure();
1911 }
1912
1913 SmallVector<size_t> resultLeafCounts;
1914 SmallVector<Type> loweredResultTypes;
1915 for (Type resultType : ifOp.getResultTypes()) {
1916 auto leafTypes = getABILeafTypes(resultType, tables, ifOp.getOperation(), field);
1917 auto count = getLeafCount(resultType, tables, ifOp.getOperation(), field);
1918 if (failed(leafTypes) || failed(count)) {
1919 return failure();
1920 }
1921 loweredResultTypes.append(leafTypes->begin(), leafTypes->end());
1922 resultLeafCounts.push_back(*count);
1923 }
1924
1925 auto newIf = builder.create<scf::IfOp>(
1926 loc, loweredResultTypes, *condition, true, !ifOp.getElseRegion().empty()
1927 );
1928
1929 {
1930 OpBuilder thenBuilder = OpBuilder::atBlockBegin(&newIf.getThenRegion().front());
1931 DenseMap<Value, LoweredValue> thenMap(valueMap.begin(), valueMap.end());
1932 if (failed(lowerBlock(thenBuilder, ifOp.getThenRegion().front(), thenMap))) {
1933 newIf.erase();
1934 return failure();
1935 }
1936 }
1937 if (!ifOp.getElseRegion().empty()) {
1938 OpBuilder elseBuilder = OpBuilder::atBlockBegin(&newIf.getElseRegion().front());
1939 DenseMap<Value, LoweredValue> elseMap(valueMap.begin(), valueMap.end());
1940 if (failed(lowerBlock(elseBuilder, ifOp.getElseRegion().front(), elseMap))) {
1941 newIf.erase();
1942 return failure();
1943 }
1944 }
1945 auto newIfResults = newIf.getResults();
1946 size_t totalResults = newIfResults.size();
1947 size_t cursor = 0;
1948 for (auto [oldResult, oldType, leafCount] :
1949 llvm::zip(ifOp.getResults(), ifOp.getResultTypes(), resultLeafCounts)) {
1950 bool overflow = false;
1951 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1952 if (overflow || nextCursor > totalResults) {
1953 ifOp.emitError("leaf count overflow while lowering if-op results");
1954 return failure();
1955 }
1956 LoweredValue lowered {oldType, {}};
1957 lowered.leaves.append(
1958 newIfResults.begin() + llzk::checkedCast<ptrdiff_t>(cursor),
1959 newIfResults.begin() + llzk::checkedCast<ptrdiff_t>(nextCursor)
1960 );
1961 valueMap[oldResult] = std::move(lowered);
1962 cursor = nextCursor;
1963 }
1964 return success();
1965 }
1966
1967 if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1968 auto lb = lookupScalar(forOp.getLowerBound(), valueMap, forOp.getOperation());
1969 auto ub = lookupScalar(forOp.getUpperBound(), valueMap, forOp.getOperation());
1970 auto step = lookupScalar(forOp.getStep(), valueMap, forOp.getOperation());
1971 if (failed(lb) || failed(ub) || failed(step)) {
1972 return failure();
1973 }
1974
1975 SmallVector<Value> initArgs;
1976 SmallVector<size_t> initLeafCounts;
1977 for (auto [init, resultType] : llvm::zip(forOp.getInitArgs(), forOp.getResultTypes())) {
1978 auto lowered = lookup(init, valueMap, forOp.getOperation());
1979 auto leafTypes = getABILeafTypes(resultType, tables, forOp.getOperation(), field);
1980 if (failed(lowered) || failed(leafTypes) ||
1981 failed(appendFlatLeavesToTypes(
1982 builder, loc, *lowered, *leafTypes, initArgs, forOp.getOperation()
1983 ))) {
1984 return failure();
1985 }
1986 auto count = getLeafCount(resultType, tables, forOp.getOperation(), field);
1987 if (failed(count)) {
1988 return failure();
1989 }
1990 initLeafCounts.push_back(*count);
1991 }
1992
1993 auto newFor = builder.create<scf::ForOp>(loc, *lb, *ub, *step, initArgs);
1994 if (Attribute unsignedCmpAttr = forOp->getAttr("unsignedCmp")) {
1995 newFor->setAttr("unsignedCmp", unsignedCmpAttr);
1996 }
1997 DenseMap<Value, LoweredValue> bodyMap(valueMap.begin(), valueMap.end());
1998 bodyMap[forOp.getInductionVar()] =
1999 LoweredValue {forOp.getInductionVar().getType(), {newFor.getInductionVar()}};
2000 {
2001 auto newForIterArgs = newFor.getRegionIterArgs();
2002 size_t totalIterArgs = newForIterArgs.size();
2003 size_t cursor = 0;
2004 for (auto [oldIterArg, oldType, leafCount] :
2005 llvm::zip(forOp.getRegionIterArgs(), forOp.getResultTypes(), initLeafCounts)) {
2006 bool overflow = false;
2007 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
2008 if (overflow || nextCursor > totalIterArgs) {
2009 forOp.emitError("leaf count overflow while lowering for-loop region iter args");
2010 return failure();
2011 }
2012 LoweredValue lowered {oldType, {}};
2013 lowered.leaves.append(
2014 newForIterArgs.begin() + static_cast<ptrdiff_t>(cursor),
2015 newForIterArgs.begin() + static_cast<ptrdiff_t>(nextCursor)
2016 );
2017 bodyMap[oldIterArg] = std::move(lowered);
2018 cursor = nextCursor;
2019 }
2020 }
2021
2022 newFor.getBody()->clear();
2023 OpBuilder bodyBuilder = OpBuilder::atBlockBegin(newFor.getBody());
2024 if (failed(lowerBlock(bodyBuilder, *forOp.getBody(), bodyMap))) {
2025 return failure();
2026 }
2027
2028 {
2029 auto newForResults = newFor.getResults();
2030 size_t totalForResults = newForResults.size();
2031 size_t cursor = 0;
2032 for (auto [oldResult, oldType, leafCount] :
2033 llvm::zip(forOp.getResults(), forOp.getResultTypes(), initLeafCounts)) {
2034 bool overflow = false;
2035 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
2036 if (overflow || nextCursor > totalForResults) {
2037 forOp.emitError("leaf count overflow while lowering for-loop results");
2038 return failure();
2039 }
2040 LoweredValue lowered {oldType, {}};
2041 lowered.leaves.append(
2042 newForResults.begin() + static_cast<ptrdiff_t>(cursor),
2043 newForResults.begin() + static_cast<ptrdiff_t>(nextCursor)
2044 );
2045 valueMap[oldResult] = std::move(lowered);
2046 cursor = nextCursor;
2047 }
2048 }
2049 return success();
2050 }
2051
2052 op.emitError("unsupported operation in execution-engine lowering: ") << op.getName();
2053 return failure();
2054 }
2055};
2056
2058class LowerComputeToCorePass : public PassWrapper<LowerComputeToCorePass, OperationPass<ModuleOp>> {
2059public:
2060 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerComputeToCorePass)
2061
2062 explicit LowerComputeToCorePass(const WitgenOptions &opts) : options(opts) {}
2063
2065 StringRef getArgument() const final { return "llzk-lower-compute-to-core"; }
2066
2068 StringRef getDescription() const final {
2069 return "Lower LLZK compute IR to func/arith/cf/scf/memref";
2070 }
2071
2073 StringRef getName() const override { return "LowerComputeToCorePass"; }
2074
2076 void runOnOperation() override {
2077 ModuleOp moduleOp = getOperation();
2078 auto field = getModuleField(moduleOp);
2079 if (failed(field)) {
2080 signalPassFailure();
2081 return;
2082 }
2083
2084 SymbolTableCollection tables;
2085 SmallVector<function::FuncDefOp> funcs;
2086 moduleOp.walk([&](function::FuncDefOp funcOp) {
2087 if (funcOp.nameIsConstrain()) {
2088 return;
2089 }
2090 funcs.push_back(funcOp);
2091 });
2092
2093 BodyLowerer lowerer(moduleOp, tables, field->get(), options);
2094 for (function::FuncDefOp funcOp : funcs) {
2095 if (failed(lowerer.lowerFunction(funcOp))) {
2096 signalPassFailure();
2097 return;
2098 }
2099 }
2100 }
2101
2102private:
2103 WitgenOptions options;
2104};
2105
2107class CreateWitgenEntryPass : public PassWrapper<CreateWitgenEntryPass, OperationPass<ModuleOp>> {
2108public:
2109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CreateWitgenEntryPass)
2110
2111
2112 explicit CreateWitgenEntryPass(bool fullWitness = false) : emitFullWitness(fullWitness) {}
2113
2115 StringRef getArgument() const final { return "llzk-create-witgen-entry"; }
2116
2118 StringRef getDescription() const final {
2119 return "Create the llzk-witgen execution-engine entry wrapper";
2120 }
2121
2123 StringRef getName() const override { return "CreateWitgenEntryPass"; }
2124
2126 void runOnOperation() override {
2127 ModuleOp moduleOp = getOperation();
2128 auto field = getModuleField(moduleOp);
2129 if (failed(field)) {
2130 signalPassFailure();
2131 return;
2132 }
2133
2134 SymbolTableCollection tables;
2135 auto mainDef = getMainInstanceDef(tables, moduleOp.getOperation());
2136 if (failed(mainDef) || !mainDef.value()) {
2137 moduleOp.emitError("module is missing a concrete llzk.main struct");
2138 signalPassFailure();
2139 return;
2140 }
2141 function::FuncDefOp computeFunc = mainDef->get().getComputeFuncOp();
2142 if (!computeFunc) {
2143 moduleOp.emitError("main struct is missing @compute");
2144 signalPassFailure();
2145 return;
2146 }
2147
2148 auto outputs = collectOutputBindings(
2149 mainDef->get(), tables, computeFunc.getOperation(),
2150 emitFullWitness ? OutputScope::FullWitness : OutputScope::Public
2151 );
2152 if (failed(outputs)) {
2153 signalPassFailure();
2154 return;
2155 }
2156
2157 OpBuilder builder(moduleOp.getContext());
2158 builder.setInsertionPointToEnd(moduleOp.getBody());
2159
2160 SmallVector<Type> wrapperArgs;
2161 for (Type argType : computeFunc.getArgumentTypes()) {
2162 SmallVector<Type> loweredLeafTypes;
2163 if (failed(flattenTypeLeaves(
2164 argType, tables, computeFunc.getOperation(), field->get(), loweredLeafTypes, {}, true
2165 ))) {
2166 signalPassFailure();
2167 return;
2168 }
2169 if (loweredLeafTypes.size() != 1 || !isa<MemRefType>(loweredLeafTypes.front())) {
2170 computeFunc.emitError(
2171 "execution-engine wrapper only supports felt and array<...xfelt> inputs"
2172 );
2173 signalPassFailure();
2174 return;
2175 }
2176 wrapperArgs.push_back(loweredLeafTypes.front());
2177 }
2178 for (const OutputBinding &output : *outputs) {
2179 SmallVector<Type> loweredLeafTypes;
2180 if (failed(flattenTypeLeaves(
2181 output.type, tables, computeFunc.getOperation(), field->get(), loweredLeafTypes, {},
2182 true
2183 ))) {
2184 signalPassFailure();
2185 return;
2186 }
2187 if (loweredLeafTypes.size() != 1 || !isa<MemRefType>(loweredLeafTypes.front())) {
2188 computeFunc.emitError(
2189 "execution-engine wrapper only supports felt and array<...xfelt> outputs"
2190 );
2191 signalPassFailure();
2192 return;
2193 }
2194 wrapperArgs.push_back(loweredLeafTypes.front());
2195 }
2196
2197 auto wrapper = builder.create<func::FuncOp>(
2198 computeFunc.getLoc(), "__llzk_witgen_main",
2199 builder.getFunctionType(wrapperArgs, TypeRange {})
2200 );
2201 wrapper->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), builder.getUnitAttr());
2202 Block *entry = wrapper.addEntryBlock();
2203 builder.setInsertionPointToStart(entry);
2204
2205 SmallVector<Type> loweredMainResultTypes;
2206 for (Type resultType : computeFunc.getResultTypes()) {
2207 if (failed(flattenABILeafTypes(
2208 resultType, tables, computeFunc.getOperation(), field->get(), loweredMainResultTypes
2209 ))) {
2210 signalPassFailure();
2211 return;
2212 }
2213 }
2214
2215 SmallVector<Value> mainArgs;
2216 for (auto [argType, wrapperArg] : llvm::zip(
2217 computeFunc.getArgumentTypes(),
2218 entry->getArguments().take_front(computeFunc.getNumArguments())
2219 )) {
2220 if (isScalarType(argType)) {
2221 mainArgs.push_back(loadStorageScalar(builder, computeFunc.getLoc(), wrapperArg));
2222 } else {
2223 auto abiLeafTypes =
2224 getABILeafTypes(argType, tables, computeFunc.getOperation(), field->get());
2225 if (failed(abiLeafTypes) || abiLeafTypes->size() != 1 ||
2226 !isa<MemRefType>(abiLeafTypes->front())) {
2227 computeFunc.emitError("failed to derive execution-engine ABI type for main input");
2228 signalPassFailure();
2229 return;
2230 }
2231 if (wrapperArg.getType() == abiLeafTypes->front()) {
2232 mainArgs.push_back(wrapperArg);
2233 } else {
2234 mainArgs.push_back(builder.create<memref::CastOp>(
2235 computeFunc.getLoc(), abiLeafTypes->front(), wrapperArg
2236 ));
2237 }
2238 }
2239 }
2240 auto loweredMain = builder.create<func::CallOp>(
2241 computeFunc.getLoc(), mangleFunctionName(computeFunc), loweredMainResultTypes, mainArgs
2242 );
2243
2244 LoweredValue mainResultValue {
2245 computeFunc.getResultTypes().front(),
2246 llvm::SmallVector<Value>(loweredMain.getResults().begin(), loweredMain.getResults().end())
2247 };
2248
2249 auto extractOutputSlice = [&](ArrayRef<std::string> path, Type currentType,
2250 ArrayRef<Value> leaves,
2251 auto &self) -> FailureOr<SmallVector<Value>> {
2252 if (path.empty()) {
2253 return SmallVector<Value>(leaves.begin(), leaves.end());
2254 }
2255 if (auto structType = dyn_cast<component::StructType>(currentType)) {
2256 auto defLookup = structType.getDefinition(tables, computeFunc.getOperation());
2257 if (failed(defLookup)) {
2258 return failure();
2259 }
2260 unsigned localCursor = 0;
2261 for (component::MemberDefOp member : defLookup->get().getMemberDefs()) {
2262 auto leafCount =
2263 getLeafCount(member.getType(), tables, member.getOperation(), field->get());
2264 if (failed(leafCount)) {
2265 return failure();
2266 }
2267 ArrayRef<Value> slice = ArrayRef<Value>(leaves).slice(localCursor, *leafCount);
2268 localCursor += *leafCount;
2269 if (member.getSymName() == path.front()) {
2270 return self(path.drop_front(), member.getType(), slice, self);
2271 }
2272 }
2273 computeFunc.emitError("failed to find struct member while wiring witgen outputs");
2274 return failure();
2275 }
2276 if (auto podType = dyn_cast<pod::PodType>(currentType)) {
2277 unsigned localCursor = 0;
2278 for (pod::RecordAttr record : podType.getRecords()) {
2279 auto leafCount =
2280 getLeafCount(record.getType(), tables, computeFunc.getOperation(), field->get());
2281 if (failed(leafCount)) {
2282 return failure();
2283 }
2284 ArrayRef<Value> slice = ArrayRef<Value>(leaves).slice(localCursor, *leafCount);
2285 localCursor += *leafCount;
2286 if (record.getName().getValue() == path.front()) {
2287 return self(path.drop_front(), record.getType(), slice, self);
2288 }
2289 }
2290 computeFunc.emitError("failed to find POD record while wiring witgen outputs");
2291 return failure();
2292 }
2293 computeFunc.emitError("extra witness path components for non-aggregate output");
2294 return failure();
2295 };
2296
2297 auto outputArgs = entry->getArguments().drop_front(computeFunc.getNumArguments());
2298 for (auto [output, outputMemRef] : llvm::zip(*outputs, outputArgs)) {
2299 auto slice = extractOutputSlice(
2300 output.path, mainResultValue.sourceType, mainResultValue.leaves, extractOutputSlice
2301 );
2302 if (failed(slice) || slice->empty()) {
2303 wrapper.emitError("missing selected witness output slice while building witgen entry");
2304 signalPassFailure();
2305 return;
2306 }
2307 if (isScalarType(output.type)) {
2308 storeStorageScalar(
2309 builder, computeFunc.getLoc(),
2310 loadStorageScalar(builder, computeFunc.getLoc(), slice->front()), outputMemRef
2311 );
2312 } else {
2313 builder.create<memref::CopyOp>(computeFunc.getLoc(), slice->front(), outputMemRef);
2314 }
2315 }
2316 builder.create<func::ReturnOp>(computeFunc.getLoc());
2317
2318 // Remove `llzk.main` attribute because the main struct is deleted below.
2319 moduleOp->removeAttr(MAIN_ATTR_NAME);
2320
2321 SmallVector<Operation *> toErase;
2322 for (Operation &op : moduleOp.getBody()->getOperations()) {
2323 if (!isa<func::FuncOp>(op)) {
2324 toErase.push_back(&op);
2325 }
2326 }
2327 for (Operation *op : toErase) {
2328 op->erase();
2329 }
2330 }
2331
2332private:
2333 bool emitFullWitness;
2334};
2335
2336} // namespace
2337
/// Append the preparation passes that run before the witgen lowering to `pm`, in order:
/// the flattening pass, `mlir::createLowerAffinePass()`, the canonicalizer and CSE.
/// The `WitgenOptions` parameter is unnamed and not consulted here.
void addWitgenPreparePipeline(OpPassManager &pm, const WitgenOptions &) {
  // Brings `createFlatteningPass` into scope from the polymorphic namespace.
  using namespace llzk::polymorphic;
  pm.addPass(createFlatteningPass(
  ));
  // Rewrite affine ops into standard-dialect equivalents.
  pm.addPass(mlir::createLowerAffinePass());
  // TODO: simplify lowering with `llzk-inline-structs` and `llzk-pod-to-scalar` when both are
  // available and support PODs.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
}
2349
2350std::unique_ptr<Pass> createLowerComputeToCorePass(const WitgenOptions &options) {
2351 return std::make_unique<LowerComputeToCorePass>(options);
2352}
2353
2354std::unique_ptr<Pass> createCreateWitgenEntryPass(bool emitFullWitness) {
2355 return std::make_unique<CreateWitgenEntryPass>(emitFullWitness);
2356}
2357
2358} // namespace llzk::witgen
This file implements helper methods for constructing DynamicAPInts.
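Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.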
Definition LICENSE.txt:28
std::unique_ptr<::mlir::Pass > createFlatteningPass()
llvm::Expected< T > checkedCast(U u)
Definition WitgenUtils.h:28
std::mt19937_64 makeDefaultValueRng(const WitgenOptions &options)
Seed an RNG for random/default witness value materialization.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
std::unique_ptr< Pass > createCreateWitgenEntryPass(bool emitFullWitness)
Create the pass that synthesizes the stable llzk-witgen JIT entry wrapper.
void addWitgenPreparePipeline(OpPassManager &pm, const WitgenOptions &)
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
Definition ValueModel.h:55
std::unique_ptr< Pass > createLowerComputeToCorePass(const WitgenOptions &options)
Create the pass that lowers supported LLZK compute IR into core MLIR dialects suitable for LLVM lower...
llvm::DynamicAPInt randomFieldElement(std::mt19937_64 &rng, const Field &field)
Draw a uniformly distributed field element in [0, prime).
bool randomBoolValue(std::mt19937_64 &rng)
Draw a uniformly distributed boolean value.
llvm::Expected< size_t > getStaticElementCount(ShapedType type, llvm::StringRef context)
int64_t randomIndexValue(std::mt19937_64 &rng)
Draw a uniformly distributed signed index value.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
DynamicAPInt toDynamicAPInt(StringRef str)
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
APInt toExactWidthAPInt(const DynamicAPInt &val, unsigned bitWidth)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
llvm::SmallSet< FieldRef, 2 > FieldSet
Typealias for a set of Fields.
Definition Field.h:159
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Definition Field.cpp:264
Configure one llzk-witgen execution.