34#include <mlir/Conversion/AffineToStandard/AffineToStandard.h>
35#include <mlir/Dialect/Arith/IR/Arith.h>
36#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
37#include <mlir/Dialect/Func/IR/FuncOps.h>
38#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
39#include <mlir/Dialect/MemRef/IR/MemRef.h>
40#include <mlir/Dialect/SCF/IR/SCF.h>
41#include <mlir/Dialect/Utils/IndexingUtils.h>
42#include <mlir/IR/Builders.h>
43#include <mlir/IR/BuiltinAttributes.h>
44#include <mlir/IR/BuiltinOps.h>
45#include <mlir/IR/SymbolTable.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Transforms/Passes.h>
49#include <llvm/ADT/APInt.h>
50#include <llvm/ADT/STLExtras.h>
51#include <llvm/ADT/SmallString.h>
52#include <llvm/ADT/StringMap.h>
53#include <llvm/ADT/TypeSwitch.h>
54#include <llvm/Support/MathExtras.h>
66 llvm::SmallVector<Value> leaves;
/// Returns the single prime field used by `moduleOp`.
/// Fields are gathered with collectFields(). An error is emitted on the
/// module if collection fails, or if the number of collected fields is not
/// exactly one.
// NOTE(review): the declaration of `fields` and the failure-return lines are
// not visible in this excerpt; presumably each error path returns failure().
70static FailureOr<std::reference_wrapper<const Field>> getModuleField(ModuleOp moduleOp) {
72 if (failed(
collectFields(moduleOp.getOperation(), fields,
false))) {
73 moduleOp.emitError(
"failed to collect fields for llzk-witgen lowering");
// The execution-engine lowering supports exactly one field per module.
76 if (fields.size() != 1) {
77 moduleOp.emitError(
"llzk-witgen execution-engine lowering requires exactly one field");
80 return *fields.begin();
/// Builds the symbol name of the lowered func.func for `funcOp`.
/// The name is the prefix "__llzk_witgen_" followed by each component of the
/// function's fully-qualified symbol path. Every non-alphanumeric character
/// is replaced by '_'.
// NOTE(review): the body of the separator branch (hidden lines) is not
// visible; presumably it appends '_' between components. Confirm.
// NOTE(review): because characters collapse to '_', distinct symbol paths
// (e.g. containing '.' vs '_') could mangle to the same name.
84static std::string mangleFunctionName(function::FuncDefOp funcOp) {
85 auto symbolRef = funcOp.getFullyQualifiedName(
false);
86 llvm::SmallString<128> result(
"__llzk_witgen_");
87 for (StringRef piece :
getNames(symbolRef)) {
88 if (!result.empty() && result.back() !=
'_') {
91 for (
char c : piece) {
92 result += llvm::isAlnum(c) ? c :
'_';
95 return std::string(result);
/// Emits an `arith.constant` of index type holding `value` and returns its
/// result.
99static Value makeIndexConstant(OpBuilder &builder, Location loc, int64_t value) {
100 return builder.create<arith::ConstantIndexOp>(loc, value).getResult();
/// Emits the felt constant 1 as a signless integer of the field's bit width.
// NOTE(review): the end of this create call (hidden lines) is not visible.
104static Value makeOneFelt(OpBuilder &builder, Location loc,
const Field &field) {
105 return builder.create<arith::ConstantOp>(
106 loc, IntegerAttr::get(IntegerType::get(builder.getContext(), field.bitWidth()), 1)
/// Maps a scalar LLZK type to the builtin type used by the lowered code.
/// Felts become signless integers of `field.bitWidth()` bits.
// NOTE(review): the results for `index` and 1-bit integer inputs, and the
// fallback for other types, are on lines not visible in this excerpt.
// Presumably index/i1 pass through and anything else fails. Confirm.
111static FailureOr<Type> lowerScalarType(MLIRContext *context, Type type,
const Field &field) {
112 if (isa<felt::FeltType>(type)) {
113 return IntegerType::get(context, field.bitWidth());
115 if (isa<IndexType>(type)) {
118 if (
auto intType = dyn_cast<IntegerType>(type)) {
119 if (intType.getWidth() == 1) {
/// Returns true for the types this lowering treats as scalar leaves: felts,
/// `index`, and 1-bit integers. All other types are handled as aggregates
/// (arrays, pods, structs) or rejected.
127static bool isScalarType(Type type) {
128 return isa<felt::FeltType, IndexType>(type) ||
129 (isa<IntegerType>(type) && mlir::cast<IntegerType>(type).getWidth() == 1);
/// Appends to `out` the lowered types of every leaf of `type`, depth first.
///  - Scalars lower via lowerScalarType(). Outside storage and with no array
///    prefix they stay SSA scalars; otherwise they become a MemRefType whose
///    shape starts with `prefixShape`.
///  - Arrays append their shape to `prefixShape` and recurse into the
///    element type with `storage` set.
///  - Pods and structs recurse into each record/member, in the order
///    returned by getRecords()/getMemberDefs(), with `storage` set. As a
///    result, scalars nested in aggregates are held in memrefs.
///  - Any other type emits an "unsupported type" error.
/// `tables` and `origin` are used to resolve struct definitions and to
/// report diagnostics.
// NOTE(review): how `shape` is completed for storage scalars with an empty
// prefix (hidden lines) is not visible here. Confirm.
133static LogicalResult flattenTypeLeaves(
134 Type type, SymbolTableCollection &tables, Operation *origin,
const Field &field,
135 SmallVectorImpl<Type> &out, llvm::ArrayRef<int64_t> prefixShape = {},
bool storage = false
137 auto emitScalarLeaf = [&](Type leafType) {
138 auto lowered = lowerScalarType(origin->getContext(), leafType, field);
139 if (failed(lowered)) {
// Plain SSA scalar: not inside an aggregate and not an array element.
142 if (!storage && prefixShape.empty()) {
143 out.push_back(*lowered);
146 llvm::SmallVector<int64_t> shape(prefixShape.begin(), prefixShape.end());
150 out.push_back(MemRefType::get(shape, *lowered));
154 if (isScalarType(type)) {
155 return emitScalarLeaf(type);
158 if (
auto arrayType = dyn_cast<array::ArrayType>(type)) {
// Array dimensions accumulate into the leaf memref's shape.
159 llvm::SmallVector<int64_t> newPrefix(prefixShape.begin(), prefixShape.end());
160 newPrefix.append(arrayType.getShape().begin(), arrayType.getShape().end());
161 return flattenTypeLeaves(
162 arrayType.getElementType(), tables, origin, field, out, newPrefix,
true
166 if (
auto podType = dyn_cast<pod::PodType>(type)) {
167 for (pod::RecordAttr record : podType.getRecords()) {
169 flattenTypeLeaves(record.getType(), tables, origin, field, out, prefixShape,
true)
177 if (
auto structType = dyn_cast<component::StructType>(type)) {
178 auto def = structType.getDefinition(tables, origin);
180 origin->emitError(
"could not resolve struct type during witgen lowering");
183 for (component::MemberDefOp member : def->get().getMemberDefs()) {
185 flattenTypeLeaves(member.getType(), tables, origin, field, out, prefixShape,
true)
193 origin->emitError(
"unsupported type in llzk-witgen lowering: ") << type;
/// Returns a memref type with the given `shape` and `elementType` and a
/// strided layout whose offset and every stride are dynamic.
// Presumably this fully dynamic layout lets views with arbitrary
// offsets/strides (e.g. subviews) be passed across the lowered ABI. Confirm.
// NOTE(review): the return-type line preceding this one (hidden lines) is not
// visible in this excerpt.
199getStridedMemRefType(MLIRContext *context, ArrayRef<int64_t> shape, Type elementType) {
200 SmallVector<int64_t> strides(shape.size(), ShapedType::kDynamic);
201 return MemRefType::get(
202 shape, elementType, StridedLayoutAttr::get(context, ShapedType::kDynamic, strides)
/// ABI counterpart of flattenTypeLeaves(): appends the lowered
/// function-signature types for the leaves of `type` to `out`.
/// Scalars that are neither in an aggregate nor in an array stay scalars.
/// All other leaves become memrefs built by getStridedMemRefType(), with one
/// dynamic dimension per enclosing array dimension (`prefixRank`) and a fully
/// dynamic strided layout. Pods and structs recurse per record/member with
/// `aggregateStorage` set; unsupported types emit an error.
// NOTE(review): the shape used when `prefixRank == 0` (hidden lines) and the
// last argument of the array recursion (hidden lines) are not visible. Confirm.
207static LogicalResult flattenABILeafTypes(
208 Type type, SymbolTableCollection &tables, Operation *origin,
const Field &field,
209 SmallVectorImpl<Type> &out,
size_t prefixRank = 0,
bool aggregateStorage =
false
211 auto emitScalarLeaf = [&](Type leafType) {
212 auto lowered = lowerScalarType(origin->getContext(), leafType, field);
213 if (failed(lowered)) {
216 if (!aggregateStorage && prefixRank == 0) {
217 out.push_back(*lowered);
220 SmallVector<int64_t> shape;
221 if (prefixRank == 0) {
// Array extents are erased at the ABI boundary: every dimension is dynamic.
224 shape.assign(prefixRank, ShapedType::kDynamic);
226 out.push_back(getStridedMemRefType(origin->getContext(), shape, *lowered));
230 if (isScalarType(type)) {
231 return emitScalarLeaf(type);
234 if (
auto arrayType = dyn_cast<array::ArrayType>(type)) {
235 return flattenABILeafTypes(
236 arrayType.getElementType(), tables, origin, field, out, prefixRank + arrayType.getRank(),
241 if (
auto podType = dyn_cast<pod::PodType>(type)) {
242 for (pod::RecordAttr record : podType.getRecords()) {
244 flattenABILeafTypes(record.getType(), tables, origin, field, out, prefixRank,
true)
252 if (
auto structType = dyn_cast<component::StructType>(type)) {
253 auto def = structType.getDefinition(tables, origin);
255 origin->emitError(
"could not resolve struct type during witgen lowering");
258 for (component::MemberDefOp member : def->get().getMemberDefs()) {
260 flattenABILeafTypes(member.getType(), tables, origin, field, out, prefixRank,
true)
268 origin->emitError(
"unsupported type in llzk-witgen lowering: ") << type;
/// Returns the number of leaves `type` flattens to under flattenTypeLeaves(),
/// or failure if flattening fails.
273static FailureOr<size_t>
274getLeafCount(Type type, SymbolTableCollection &tables, Operation *origin,
const Field &field) {
275 SmallVector<Type> leaves;
276 if (failed(flattenTypeLeaves(type, tables, origin, field, leaves))) {
279 return leaves.size();
/// Returns the internal (non-ABI) leaf types of `type` as produced by
/// flattenTypeLeaves(), or failure if flattening fails.
// NOTE(review): the success return is on a line not visible in this excerpt.
283static FailureOr<SmallVector<Type>>
284getLeafTypes(Type type, SymbolTableCollection &tables, Operation *origin,
const Field &field) {
285 SmallVector<Type> leaves;
286 if (failed(flattenTypeLeaves(type, tables, origin, field, leaves))) {
/// Returns the ABI (function-signature) leaf types of `type` as produced by
/// flattenABILeafTypes(), or failure if flattening fails.
// NOTE(review): the success return is on a line not visible in this excerpt.
293static FailureOr<SmallVector<Type>>
294getABILeafTypes(Type type, SymbolTableCollection &tables, Operation *origin,
const Field &field) {
295 SmallVector<Type> leaves;
296 if (failed(flattenABILeafTypes(type, tables, origin, field, leaves))) {
/// Locates the pod record or struct member called `name` in `ownerType`.
/// Returns {offset, count}: the index of its first leaf within the owner's
/// flattened leaves, and its number of leaves.
/// Counts come from getLeafCount(), walking records/members in accessor
/// order. An error is emitted if the struct cannot be resolved or the name is
/// not found.
// NOTE(review): the `field` parameter and the declaration/advance of
// `running` (hidden lines) are not visible. Presumably `running` sums the
// counts of preceding entries. Confirm.
303static FailureOr<std::pair<size_t, size_t>> getNamedLeafSpan(
304 Type ownerType, StringRef name, SymbolTableCollection &tables, Operation *origin,
307 if (
auto podType = dyn_cast<pod::PodType>(ownerType)) {
309 for (pod::RecordAttr record : podType.getRecords()) {
310 auto count = getLeafCount(record.getType(), tables, origin, field);
314 if (record.getName().getValue() == name) {
315 return std::pair<size_t, size_t> {running, *count};
321 if (
auto structType = dyn_cast<component::StructType>(ownerType)) {
322 auto def = structType.getDefinition(tables, origin);
324 origin->emitError(
"could not resolve struct type during witgen lowering");
328 for (component::MemberDefOp member : def->get().getMemberDefs()) {
329 auto count = getLeafCount(member.getType(), tables, origin, field);
333 if (member.getSymName() == name) {
334 return std::pair<size_t, size_t> {running, *count};
340 origin->emitError(
"could not resolve aggregate member/record @") << name;
/// Returns the declared type of the pod record or struct member named `name`
/// in `ownerType`. An error is emitted if the struct definition cannot be
/// resolved or the name is not found.
345static FailureOr<Type>
346getNamedSubType(Type ownerType, StringRef name, SymbolTableCollection &tables, Operation *origin) {
347 if (
auto podType = dyn_cast<pod::PodType>(ownerType)) {
348 for (pod::RecordAttr record : podType.getRecords()) {
349 if (record.getName().getValue() == name) {
350 return record.getType();
354 if (
auto structType = dyn_cast<component::StructType>(ownerType)) {
355 auto def = structType.getDefinition(tables, origin);
357 origin->emitError(
"could not resolve struct type during witgen lowering");
360 for (component::MemberDefOp member : def->get().getMemberDefs()) {
361 if (member.getSymName() == name) {
362 return member.getType();
366 origin->emitError(
"could not resolve aggregate member/record @") << name;
/// Allocates a memref of `memrefType` with memref.alloc and stores a zero
/// constant (index or integer, matching the element type) into every element.
/// The stores are fully unrolled: each flat index is delinearized with
/// mlir::computeStrides()/delinearize(), and one memref.store with constant
/// indices is emitted per element. The emitted IR therefore grows linearly
/// with the element count.
// NOTE(review): no dealloc is emitted here. The computation of
// `elementCount`/`flatSigned` (checked conversions whose errors are reported
// via emitError) is on hidden lines.
371static FailureOr<Value> createZeroMemRef(OpBuilder &builder, Location loc, MemRefType memrefType) {
374 emitError(loc) << llvm::toString(elementCount.takeError());
377 Value alloc = builder.create<memref::AllocOp>(loc, memrefType);
378 auto elementType = memrefType.getElementType();
380 if (isa<IndexType>(elementType)) {
381 zero = builder.create<arith::ConstantIndexOp>(loc, 0);
383 zero = builder.create<arith::ConstantOp>(
384 loc, IntegerAttr::get(mlir::cast<IntegerType>(elementType), 0)
387 auto strides = mlir::computeStrides(memrefType.getShape());
388 for (
size_t flat = 0; flat < *elementCount; ++flat) {
391 emitError(loc) << llvm::toString(flatSigned.takeError());
394 SmallVector<Value> indices;
395 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
396 indices.push_back(makeIndexConstant(builder, loc, index));
398 builder.create<memref::StoreOp>(loc, zero, alloc, indices);
/// Like createZeroMemRef(), but every element receives a constant derived
/// from the caller's random generator. Index, 1-bit integer, and other
/// integer element types each get their own constant form.
// NOTE(review): the `rng` parameter and the lines that draw/reduce each
// random `value` (e.g. modulo the field prime) are not visible in this
// excerpt. Confirm the value range per element type. No dealloc is emitted
// here.
404static FailureOr<Value> createRandomMemRef(
405 OpBuilder &builder, Location loc, MemRefType memrefType,
const Field &field,
410 emitError(loc) << llvm::toString(elementCount.takeError());
413 Value alloc = builder.create<memref::AllocOp>(loc, memrefType);
414 auto elementType = memrefType.getElementType();
415 auto strides = mlir::computeStrides(memrefType.getShape());
416 for (
size_t flat = 0; flat < *elementCount; ++flat) {
419 emitError(loc) << llvm::toString(flatSigned.takeError());
422 SmallVector<Value> indices;
423 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
424 indices.push_back(makeIndexConstant(builder, loc, index));
426 if (isa<IndexType>(elementType)) {
428 builder.create<memref::StoreOp>(
429 loc, builder.create<arith::ConstantIndexOp>(loc, value), alloc, indices
433 auto intType = mlir::cast<IntegerType>(elementType);
434 if (intType.getWidth() == 1) {
435 builder.create<memref::StoreOp>(
437 builder.create<arith::ConstantOp>(
445 builder.create<memref::StoreOp>(
447 builder.create<arith::ConstantOp>(
/// Materializes a default value of `type`, producing one lowered leaf per
/// leaf type from getLeafTypes().
/// Three visible paths exist:
///  - a fail-mode error, reported because materializing a default "would
///    hide uninitialized reads";
///  - a random path, using createRandomMemRef() for memref leaves and random
///    constants for scalars;
///  - a zero path, using createZeroMemRef() for memref leaves and zero
///    constants for scalars.
// NOTE(review): the conditions selecting these paths (presumably the
// configured uninitialized behavior) and the `field`/`rng` parameters are on
// lines not visible in this excerpt. Confirm.
457static FailureOr<LoweredValue> createDefaultValue(
458 OpBuilder &builder, Location loc, Type type, SymbolTableCollection &tables, Operation *origin,
461 LoweredValue lowered {type, {}};
462 auto leafTypes = getLeafTypes(type, tables, origin, field);
463 if (failed(leafTypes)) {
466 for (Type leafType : *leafTypes) {
469 "fail-mode default materialization is unsupported in witgen lowering because it would "
470 "hide uninitialized reads"
// Random materialization path.
475 if (
auto memrefType = dyn_cast<MemRefType>(leafType)) {
476 auto randomMemRef = createRandomMemRef(builder, loc, memrefType, field, rng);
477 if (failed(randomMemRef)) {
480 lowered.leaves.push_back(*randomMemRef);
483 if (isa<IndexType>(leafType)) {
484 lowered.leaves.push_back(
489 auto intType = mlir::cast<IntegerType>(leafType);
490 if (intType.getWidth() == 1) {
491 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
497 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
// Zero materialization path.
502 if (
auto memrefType = dyn_cast<MemRefType>(leafType)) {
503 auto zeroMemRef = createZeroMemRef(builder, loc, memrefType);
504 if (failed(zeroMemRef)) {
507 lowered.leaves.push_back(*zeroMemRef);
510 if (isa<IndexType>(leafType)) {
511 lowered.leaves.push_back(builder.create<arith::ConstantIndexOp>(loc, 0));
514 lowered.leaves.push_back(builder.create<arith::ConstantOp>(
515 loc, IntegerAttr::get(mlir::cast<IntegerType>(leafType), 0)
/// Reduces the unsigned integer `wideValue` modulo the field prime and
/// truncates the result to `dstWidth` bits.
/// The reduction is an arith.remui against the prime materialized at
/// `wideValue`'s width. Callers use this to bring widened intermediate
/// results back to felt width.
// Assumes the prime fits in `dstWidth` bits, so the truncation of a value
// below the prime is lossless. Confirm for all callers.
522static Value normalizeWideValue(
523 OpBuilder &builder, Location loc, Value wideValue,
unsigned dstWidth,
const Field &field
525 auto wideType = mlir::cast<IntegerType>(wideValue.getType());
526 Value modulus = builder.create<arith::ConstantOp>(
527 loc, field.getPrimeAttr(builder.getContext(), wideType.getWidth())
529 Value reduced = builder.create<arith::RemUIOp>(loc, wideValue, modulus);
530 return builder.create<arith::TruncIOp>(
531 loc, IntegerType::get(builder.getContext(), dstWidth), reduced
/// Signed counterpart of normalizeWideValue().
/// Takes the signed remainder of `wideValue` by the prime (arith.remsi).
/// When that remainder is negative, it adds the prime back to reach the
/// non-negative representative. The result is then truncated to `dstWidth`
/// bits.
536static Value normalizeSignedWideValue(
537 OpBuilder &builder, Location loc, Value wideValue,
unsigned dstWidth,
const Field &field
539 auto wideType = mlir::cast<IntegerType>(wideValue.getType());
540 Value modulus = builder.create<arith::ConstantOp>(
541 loc, field.getPrimeAttr(builder.getContext(), wideType.getWidth())
543 Value reduced = builder.create<arith::RemSIOp>(loc, wideValue, modulus);
544 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(wideType, 0));
// remsi keeps the dividend's sign; lift negative remainders by one modulus.
545 Value isNegative = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, reduced, zero);
546 Value adjusted = builder.create<arith::AddIOp>(loc, reduced, modulus);
547 Value canonical = builder.create<arith::SelectOp>(loc, isNegative, adjusted, reduced);
548 return builder.create<arith::TruncIOp>(
549 loc, IntegerType::get(builder.getContext(), dstWidth), canonical
/// Reinterprets a felt as a signed integer one bit wider than the field.
/// Operands at or above `half` map to `operand - prime`; all others are
/// zero-extended unchanged.
// NOTE(review): the return-type line, the name binding for `prime`, and the
// definition of `half` (hidden lines) are not visible. `half` is presumably
// about (p+1)/2. It must share `operand`'s (felt-width) type for the uge
// compare below. Confirm.
555lowerFeltToSignedWide(OpBuilder &builder, Location loc, Value operand,
const Field &field) {
556 unsigned width = field.bitWidth();
557 unsigned wideWidth = width + 1;
558 auto feltType = IntegerType::get(builder.getContext(), width);
559 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
560 Value operandWide = builder.create<arith::ExtUIOp>(loc, wideType, operand);
562 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
563 Value half = builder.create<arith::ConstantOp>(
566 Value isNegative = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, operand, half);
567 Value signedOperand = builder.create<arith::SubIOp>(loc, operandWide, prime);
568 return builder.create<arith::SelectOp>(loc, isNegative, signedOperand, operandWide);
/// Emits a runtime `cf.assert` that `operand` is non-zero; the assertion
/// fails with `message` otherwise.
// Presumably used to guard division/inversion by zero. Confirm at call sites.
572static void assertNonZeroFelt(OpBuilder &builder, Location loc, Value operand, StringRef message) {
573 auto operandType = mlir::cast<IntegerType>(operand.getType());
574 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(operandType, 0));
575 Value isNonZero = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, operand, zero);
576 builder.create<cf::AssertOp>(loc, isNonZero, message);
/// Field addition: (lhs + rhs) mod p.
/// Both operands are zero-extended by one bit so their sum cannot wrap, then
/// the sum is reduced with normalizeWideValue().
// Assumes both operands are already reduced (< p). Confirm.
581lowerFeltAdd(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
582 unsigned width = field.bitWidth();
583 unsigned wideWidth = width + 1;
584 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
585 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
586 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
587 Value sum = builder.create<arith::AddIOp>(loc, lhsWide, rhsWide);
588 return normalizeWideValue(builder, loc, sum, width, field);
/// Field subtraction: (lhs - rhs) mod p.
/// Computes lhs + p - rhs at one extra bit of width so the intermediate
/// stays non-negative for reduced inputs, then reduces modulo p.
// NOTE(review): the line naming `modulus` (hidden lines) is not visible.
593lowerFeltSub(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
594 unsigned width = field.bitWidth();
595 unsigned wideWidth = width + 1;
596 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
597 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
598 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
600 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
601 Value lhsPlusMod = builder.create<arith::AddIOp>(loc, lhsWide, modulus);
602 Value diff = builder.create<arith::SubIOp>(loc, lhsPlusMod, rhsWide);
603 return normalizeWideValue(builder, loc, diff, width, field);
/// Field negation: (p - operand) mod p.
/// The final reduction maps an input of 0 (giving p) back to 0.
607static Value lowerFeltNeg(OpBuilder &builder, Location loc, Value operand,
const Field &field) {
608 unsigned width = field.bitWidth();
609 unsigned wideWidth = width + 1;
610 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
611 Value operandWide = builder.create<arith::ExtUIOp>(loc, wideType, operand);
613 builder.create<arith::ConstantOp>(loc, field.getPrimeAttr(builder.getContext(), wideWidth));
614 Value diff = builder.create<arith::SubIOp>(loc, modulus, operandWide);
615 return normalizeWideValue(builder, loc, diff, width, field);
/// Field multiplication: (lhs * rhs) mod p.
/// Both operands are zero-extended to twice the field width so the full
/// product fits, then the product is reduced with normalizeWideValue().
620lowerFeltMul(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
621 unsigned width = field.bitWidth();
622 unsigned wideWidth = width * 2;
623 auto wideType = IntegerType::get(builder.getContext(), wideWidth);
624 Value lhsWide = builder.create<arith::ExtUIOp>(loc, wideType, lhs);
625 Value rhsWide = builder.create<arith::ExtUIOp>(loc, wideType, rhs);
626 Value product = builder.create<arith::MulIOp>(loc, lhsWide, rhsWide);
627 return normalizeWideValue(builder, loc, product, width, field);
/// Field inversion via square-and-multiply exponentiation of `operand`.
/// The loop runs over the bits of the compile-time `exponent`, so the
/// multiplications are unrolled into the emitted IR.
// NOTE(review): `exponent` and the per-bit condition guarding the multiply
// (hidden lines) are not visible. Presumably exponent = p - 2 (Fermat
// inversion), in which case an input of 0 yields 0 instead of trapping.
// Confirm, and pair with assertNonZeroFelt() where a trap is required.
631static Value lowerFeltInv(OpBuilder &builder, Location loc, Value operand,
const Field &field) {
633 Value result = makeOneFelt(builder, loc, field);
634 Value base = operand;
635 for (
unsigned bit = 0; bit < exponent.getBitWidth(); ++bit) {
637 result = lowerFeltMul(builder, loc, result, base, field);
// Skip the final, unused squaring.
639 if (bit + 1 < exponent.getBitWidth()) {
640 base = lowerFeltMul(builder, loc, base, base, field);
/// Field division: lhs * inverse(rhs), with the inverse computed by
/// lowerFeltInv().
// NOTE(review): no zero-divisor check is visible in this function. Confirm
// that callers emit assertNonZeroFelt() when needed.
648lowerFeltDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
649 return lowerFeltMul(builder, loc, lhs, lowerFeltInv(builder, loc, rhs, field), field);
/// Computes base^exponent in the field, where `exponent` is a runtime value.
/// Emits fully unrolled, LSB-first square-and-multiply over all
/// `field.bitWidth()` bits of `exponent`. For each bit it:
///  - extracts the bit with shrui/andi/cmpi;
///  - uses an `scf.if` to yield either result*base or result;
///  - squares the base, except after the last bit.
/// The IR size is linear in the field bit width.
654lowerFeltPow(OpBuilder &builder, Location loc, Value base, Value exponent,
const Field &field) {
655 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
656 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 0));
657 Value one = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 1));
658 Value result = makeOneFelt(builder, loc, field);
659 Value currentBase = base;
660 for (
unsigned bit = 0; bit < field.bitWidth(); ++bit) {
661 Value bitIndex = builder.create<arith::ConstantOp>(
662 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), bit))
664 Value shifted = builder.create<arith::ShRUIOp>(loc, exponent, bitIndex);
665 Value masked = builder.create<arith::AndIOp>(loc, shifted, one);
666 Value bitIsSet = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, masked, zero);
667 auto ifOp = builder.create<scf::IfOp>(loc, TypeRange {feltType}, bitIsSet,
true);
// Then-branch: multiply the accumulator by the current power of the base.
669 OpBuilder::InsertionGuard guard(builder);
670 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
671 Value multiplied = lowerFeltMul(builder, loc, result, currentBase, field);
672 builder.create<scf::YieldOp>(loc, multiplied);
// Else-branch: keep the accumulator unchanged.
675 OpBuilder::InsertionGuard guard(builder);
676 builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
677 builder.create<scf::YieldOp>(loc, result);
679 result = ifOp.getResult(0);
680 if (bit + 1 < field.bitWidth()) {
681 currentBase = lowerFeltMul(builder, loc, currentBase, currentBase, field);
/// Left shift interpreted in the field: lhs * 2^rhs mod p.
/// The power of two is computed with lowerFeltPow(). Bits that would exceed
/// the width are therefore wrapped modulo p rather than discarded.
689lowerFeltShl(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
690 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
691 Value two = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 2));
692 return lowerFeltMul(builder, loc, lhs, lowerFeltPow(builder, loc, two, rhs, field), field);
/// Bitwise OR of the operands' integer representations, reduced modulo p.
/// The OR of two values below p can reach p or more. The result is
/// zero-extended by one bit only to reuse normalizeWideValue().
697lowerFeltOr(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
698 unsigned width = field.bitWidth();
699 auto wideType = IntegerType::get(builder.getContext(), width + 1);
700 Value orValue = builder.create<arith::OrIOp>(loc, lhs, rhs);
701 Value orWide = builder.create<arith::ExtUIOp>(loc, wideType, orValue);
702 return normalizeWideValue(builder, loc, orWide, width, field);
/// Bitwise XOR of the operands' integer representations, reduced modulo p.
/// The XOR of two values below p can reach p or more. The result is
/// zero-extended by one bit only to reuse normalizeWideValue().
707lowerFeltXor(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
708 unsigned width = field.bitWidth();
709 auto wideType = IntegerType::get(builder.getContext(), width + 1);
710 Value xorValue = builder.create<arith::XOrIOp>(loc, lhs, rhs);
711 Value xorWide = builder.create<arith::ExtUIOp>(loc, wideType, xorValue);
712 return normalizeWideValue(builder, loc, xorWide, width, field);
/// Unsigned integer (floor) division of the operands' integer
/// representations with arith.divui. This is not field division.
// NOTE(review): no zero-divisor guard is visible here (arith.divui by zero is
// undefined). Confirm that callers assert a non-zero divisor.
716static Value lowerFeltUnsignedDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
717 return builder.create<arith::DivUIOp>(loc, lhs, rhs);
/// Signed integer division of felts.
/// Both operands are mapped to signed wide integers with
/// lowerFeltToSignedWide() and divided with arith.divsi. The quotient is then
/// mapped back to a felt with normalizeSignedWideValue().
// NOTE(review): no zero-divisor guard is visible here. Confirm at call sites.
722lowerFeltSignedDiv(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
723 unsigned width = field.bitWidth();
724 Value lhsSigned = lowerFeltToSignedWide(builder, loc, lhs, field);
725 Value rhsSigned = lowerFeltToSignedWide(builder, loc, rhs, field);
726 Value quotient = builder.create<arith::DivSIOp>(loc, lhsSigned, rhsSigned);
727 return normalizeSignedWideValue(builder, loc, quotient, width, field);
/// Unsigned integer remainder of the operands' integer representations with
/// arith.remui.
// NOTE(review): no zero-divisor guard is visible here. Confirm at call sites.
731static Value lowerFeltUnsignedMod(OpBuilder &builder, Location loc, Value lhs, Value rhs) {
732 return builder.create<arith::RemUIOp>(loc, lhs, rhs);
/// Signed integer remainder of felts.
/// Both operands are mapped to signed wide integers with
/// lowerFeltToSignedWide() and combined with arith.remsi. The remainder is
/// then mapped back to a felt with normalizeSignedWideValue().
// NOTE(review): no zero-divisor guard is visible here. Confirm at call sites.
737lowerFeltSignedMod(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
738 unsigned width = field.bitWidth();
739 Value lhsSigned = lowerFeltToSignedWide(builder, loc, lhs, field);
740 Value rhsSigned = lowerFeltToSignedWide(builder, loc, rhs, field);
741 Value remainder = builder.create<arith::RemSIOp>(loc, lhsSigned, rhsSigned);
742 return normalizeSignedWideValue(builder, loc, remainder, width, field);
/// Logical right shift of the felt's integer representation.
/// Shift amounts at or above the field bit width produce 0. The shift
/// actually emitted is clamped to width-1, presumably because arith.shrui is
/// undefined for over-wide shift amounts; the select then substitutes 0.
/// The result never exceeds `lhs`, so no modular reduction is needed.
747lowerFeltShr(OpBuilder &builder, Location loc, Value lhs, Value rhs,
const Field &field) {
748 auto feltType = IntegerType::get(builder.getContext(), field.bitWidth());
749 Value width = builder.create<arith::ConstantOp>(
750 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), field.bitWidth()))
752 Value shiftTooLarge = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge, rhs, width);
753 Value zero = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(feltType, 0));
754 Value maxValidShift = builder.create<arith::ConstantOp>(
755 loc, IntegerAttr::get(feltType, llvm::APInt(field.bitWidth(), field.bitWidth() - 1))
757 Value clampedShift = builder.create<arith::MinUIOp>(loc, rhs, maxValidShift);
758 Value shifted = builder.create<arith::ShRUIOp>(loc, lhs, clampedShift);
759 return builder.create<arith::SelectOp>(loc, shiftTooLarge, zero, shifted);
/// Bitwise NOT over the field's bit width, reduced modulo p.
/// The operand is XORed with an all-ones mask of `field.bitWidth()` bits,
/// and the complement is reduced with normalizeWideValue(); the complement
/// may be p or larger.
763static Value lowerFeltNot(OpBuilder &builder, Location loc, Value operand,
const Field &field) {
764 unsigned width = field.bitWidth();
765 auto feltType = IntegerType::get(builder.getContext(), width);
766 auto wideType = IntegerType::get(builder.getContext(), width + 1);
767 Value maxMask = builder.create<arith::ConstantOp>(
768 loc, IntegerAttr::get(feltType, llvm::APInt::getAllOnes(width))
770 Value complement = builder.create<arith::XOrIOp>(loc, operand, maxMask);
771 Value complementWide = builder.create<arith::ExtUIOp>(loc, wideType, complement);
772 return normalizeWideValue(builder, loc, complementWide, width, field);
/// Loads the scalar held by a storage memref leaf.
/// Reads the element at index 0 in every dimension of `storageLeaf`.
776static Value loadStorageScalar(OpBuilder &builder, Location loc, Value storageLeaf) {
777 auto memrefType = mlir::cast<MemRefType>(storageLeaf.getType());
778 SmallVector<Value> indices;
779 indices.reserve(memrefType.getRank());
780 for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
781 indices.push_back(makeIndexConstant(builder, loc, 0));
783 return builder.create<memref::LoadOp>(loc, storageLeaf, indices);
/// Stores `scalar` into a storage memref leaf.
/// Writes the element at index 0 in every dimension of `storageLeaf`; this
/// is the inverse of loadStorageScalar().
787static void storeStorageScalar(OpBuilder &builder, Location loc, Value scalar, Value storageLeaf) {
788 auto memrefType = mlir::cast<MemRefType>(storageLeaf.getType());
789 SmallVector<Value> indices;
790 indices.reserve(memrefType.getRank());
791 for (int64_t dim = 0; dim < memrefType.getRank(); ++dim) {
792 indices.push_back(makeIndexConstant(builder, loc, 0));
794 builder.create<memref::StoreOp>(loc, scalar, storageLeaf, indices);
/// Copies a lowered value of `sourceType` into existing storage leaves.
/// Leaves whose flattened type is a memref are copied with memref.copy.
/// Scalar leaves are stored into their single-element storage slot with
/// storeStorageScalar(). An error is emitted if the destination, source, and
/// leaf-type counts disagree.
// NOTE(review): the `field` parameter line (hidden lines) is not visible.
798static LogicalResult copyIntoStorage(
799 OpBuilder &builder, Location loc, Type sourceType, ArrayRef<Value> destLeaves,
800 ArrayRef<Value> sourceLeaves, SymbolTableCollection &tables, Operation *origin,
803 auto leafTypes = getLeafTypes(sourceType, tables, origin, field);
804 if (failed(leafTypes)) {
807 if (destLeaves.size() != sourceLeaves.size() || destLeaves.size() != leafTypes->size()) {
808 origin->emitError(
"flattened leaf mismatch while copying aggregate storage");
811 for (
auto [leafType, destLeaf, srcLeaf] : llvm::zip(*leafTypes, destLeaves, sourceLeaves)) {
812 if (isa<MemRefType>(leafType)) {
813 builder.create<memref::CopyOp>(loc, srcLeaf, destLeaf);
816 storeStorageScalar(builder, loc, srcLeaf, destLeaf);
/// Reads record/member `name` out of the lowered aggregate `owner`.
/// Takes the owner's leaves in the span from getNamedLeafSpan(). Leaves whose
/// sub-type leaf is a memref are forwarded as-is, so the result aliases the
/// owner's storage. Scalar leaves are loaded from their storage slot.
822static FailureOr<LoweredValue> readNamedAggregateValue(
823 OpBuilder &builder, Location loc, Type ownerType, StringRef name,
const LoweredValue &owner,
824 SymbolTableCollection &tables, Operation *origin,
const Field &field
826 auto subType = getNamedSubType(ownerType, name, tables, origin);
827 if (failed(subType)) {
830 auto span = getNamedLeafSpan(ownerType, name, tables, origin, field);
834 LoweredValue result {*subType, {}};
835 auto leafTypes = getLeafTypes(*subType, tables, origin, field);
836 if (failed(leafTypes)) {
839 auto leaves = ArrayRef<Value>(owner.leaves).slice(span->first, span->second);
840 for (
auto [leafType, leafValue] : llvm::zip(*leafTypes, leaves)) {
841 if (isa<MemRefType>(leafType)) {
842 result.leaves.push_back(leafValue);
844 result.leaves.push_back(loadStorageScalar(builder, loc, leafValue));
/// Writes `value` into record/member `name` of the lowered aggregate `owner`.
/// Copies into the owner's leaves in the span from getNamedLeafSpan(), using
/// copyIntoStorage(). The visible code only reads `owner.leaves`; the write
/// itself happens through the memref leaves.
851static LogicalResult writeNamedAggregateValue(
852 OpBuilder &builder, Location loc, Type ownerType, StringRef name, LoweredValue &owner,
853 const LoweredValue &value, SymbolTableCollection &tables, Operation *origin,
const Field &field
855 auto subType = getNamedSubType(ownerType, name, tables, origin);
856 if (failed(subType)) {
859 auto span = getNamedLeafSpan(ownerType, name, tables, origin, field);
863 return copyIntoStorage(
864 builder, loc, *subType, ArrayRef<Value>(owner.leaves).slice(span->first, span->second),
865 value.leaves, tables, origin, field
/// Returns a memref.subview of `source` selecting the sub-array at
/// `outerIndices`, which index the leading dimensions.
///  - Offsets: the indices, followed by 0 for the remaining dimensions.
///  - Sizes: 1 for indexed dimensions, the full extent for the rest.
///  - Strides: all 1.
/// The indexed dimensions are dropped through a rank-reduced result type. If
/// no dimension would remain, a shape of [1] is kept. The result aliases
/// `source`.
// NOTE(review): the computation of `indexedRank`, `reserveSize`, and
// `dimIndex` (checked conversions, hidden lines) is not visible.
870static FailureOr<Value>
871createElementSubview(OpBuilder &builder, Location loc, Value
source, ValueRange outerIndices) {
872 auto sourceType = mlir::cast<MemRefType>(
source.getType());
873 SmallVector<OpFoldResult> mixedOffsets;
874 SmallVector<OpFoldResult> mixedSizes;
875 SmallVector<OpFoldResult> mixedStrides;
878 emitError(loc) << llvm::toString(indexedRank.takeError());
881 mixedOffsets.reserve(sourceType.getRank());
882 mixedSizes.reserve(sourceType.getRank());
883 mixedStrides.reserve(sourceType.getRank());
884 for (Value index : outerIndices) {
885 mixedOffsets.push_back(index);
887 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
888 mixedOffsets.push_back(builder.getIndexAttr(0));
890 for (int64_t dim = 0; dim < *indexedRank; ++dim) {
891 mixedSizes.push_back(builder.getIndexAttr(1));
893 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
894 mixedSizes.push_back(memref::getMixedSize(builder, loc,
source, dim));
896 for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
897 mixedStrides.push_back(builder.getIndexAttr(1));
// Result shape: extents of the un-indexed trailing dimensions (static when
// known, dynamic otherwise).
899 SmallVector<int64_t> desiredShape;
902 emitError(loc) << llvm::toString(reserveSize.takeError());
905 desiredShape.reserve(*reserveSize);
906 for (int64_t dim = *indexedRank; dim < sourceType.getRank(); ++dim) {
909 emitError(loc) << llvm::toString(dimIndex.takeError());
912 if (
auto attr = llvm::dyn_cast<Attribute>(mixedSizes[*dimIndex])) {
913 desiredShape.push_back(mlir::cast<IntegerAttr>(attr).getInt());
915 desiredShape.push_back(ShapedType::kDynamic);
// Fully indexed: keep a single-element memref instead of a rank-0 one.
918 if (desiredShape.empty()) {
919 desiredShape.push_back(1);
921 auto resultType = mlir::cast<MemRefType>(memref::SubViewOp::inferRankReducedResultType(
922 desiredShape, sourceType, mixedOffsets, mixedSizes, mixedStrides
924 auto op = builder.create<memref::SubViewOp>(
925 loc, resultType,
source, mixedOffsets, mixedSizes, mixedStrides
927 return success(op.getResult());
/// Reads element `indices` of a lowered array value.
/// For a scalar element type, loads from the array's single memref leaf.
/// For an aggregate element type, each leaf becomes a subview from
/// createElementSubview(); these views alias the array's storage rather than
/// copying it.
931static FailureOr<LoweredValue> readArrayElement(
932 OpBuilder &builder, Location loc, array::ArrayType arrayType,
const LoweredValue &arrayValue,
933 ArrayRef<Value> indices
935 Type elementType = arrayType.getElementType();
936 LoweredValue result {elementType, {}};
937 if (isScalarType(elementType)) {
938 result.leaves.push_back(
939 builder.create<memref::LoadOp>(loc, arrayValue.leaves.front(), indices)
944 for (Value sourceLeaf : arrayValue.leaves) {
945 auto subview = createElementSubview(builder, loc, sourceLeaf, indices);
946 if (failed(subview)) {
949 result.leaves.push_back(*subview);
/// Writes `elementValue` into element `indices` of a lowered array value.
/// For a scalar element type, stores into the array's single memref leaf.
/// For an aggregate element type, copies each source leaf into a subview of
/// the corresponding destination leaf.
// NOTE(review): llvm::zip stops at the shorter range, and no leaf-count check
// is visible here. A mismatch would silently skip leaves. Confirm.
955static LogicalResult writeArrayElement(
956 OpBuilder &builder, Location loc, array::ArrayType arrayType, LoweredValue &arrayValue,
957 ArrayRef<Value> indices,
const LoweredValue &elementValue
959 Type elementType = arrayType.getElementType();
960 if (isScalarType(elementType)) {
961 builder.create<memref::StoreOp>(
962 loc, elementValue.leaves.front(), arrayValue.leaves.front(), indices
967 for (
auto [destLeaf, srcLeaf] : llvm::zip(arrayValue.leaves, elementValue.leaves)) {
968 auto subview = createElementSubview(builder, loc, destLeaf, indices);
969 if (failed(subview)) {
972 builder.create<memref::CopyOp>(loc, srcLeaf, *subview);
/// Appends the leaves of `value` to `out`, adapted to `targetLeafTypes`
/// (e.g. a callee's ABI signature).
/// Leaves that already have the target type pass through. A memref leaf with
/// a different memref target type is converted with memref.cast. Any other
/// mismatch, or a differing leaf count, emits an error on `origin`.
978static LogicalResult appendFlatLeavesToTypes(
979 OpBuilder &builder, Location loc,
const LoweredValue &value, ArrayRef<Type> targetLeafTypes,
980 SmallVectorImpl<Value> &out, Operation *origin
982 if (targetLeafTypes.size() != value.leaves.size()) {
983 origin->emitError(
"flattened leaf mismatch during call lowering");
986 for (
auto [leafValue, leafType] : llvm::zip(value.leaves, targetLeafTypes)) {
987 if (leafValue.getType() == leafType) {
988 out.push_back(leafValue);
991 if (isa<MemRefType>(leafValue.getType()) && isa<MemRefType>(leafType)) {
992 out.push_back(builder.create<memref::CastOp>(loc, leafType, leafValue));
995 origin->emitError(
"lowered leaf type mismatch during call lowering");
/// Binds the lowering to module `mod`, its symbol tables, and the module's
/// field. Also captures the configured uninitialized-value behavior and a
/// random generator created by makeDefaultValueRng(options).
// NOTE(review): the constructor's name line (hidden lines) and the enclosing
// class header are not visible in this excerpt.
1006 ModuleOp
mod, SymbolTableCollection &symbolTables,
const Field &moduleField,
1007 const WitgenOptions &options
1009 : moduleOp(
mod), tables(symbolTables), field(moduleField),
1010 uninitializedBehavior(options.uninitializedBehavior), rng(
makeDefaultValueRng(options)) {}
 /// Lowers one LLZK function into a new func.func named by
 /// mangleFunctionName() and appended to the end of the module.
 /// Only functions with a body consisting of a single block are supported.
 /// Steps:
 ///  1. Flatten argument and result types with flattenABILeafTypes().
 ///  2. Map each original block argument to its next getLeafCount(argType)
 ///     entry-block arguments.
 ///  3. Lower the body with lowerBlock().
 /// On the visible failure paths, the partially built function is erased.
 // NOTE(review): the argument mapping assumes flattenTypeLeaves() and
 // flattenABILeafTypes() yield the same number of leaves for every type.
1013 FailureOr<func::FuncOp> lowerFunction(function::FuncDefOp funcOp) {
1014 if (funcOp.isExternal()) {
1015 funcOp.emitError(
"execution-engine backend does not lower extern functions");
1018 if (!funcOp.getBody().hasOneBlock()) {
1019 funcOp.emitError(
"execution-engine backend only supports single-block functions");
1023 SmallVector<Type> loweredArgTypes;
1024 for (Type argType : funcOp.getArgumentTypes()) {
1026 flattenABILeafTypes(argType, tables, funcOp.getOperation(), field, loweredArgTypes)
1031 SmallVector<Type> loweredResultTypes;
1032 for (Type resultType : funcOp.getResultTypes()) {
1033 if (failed(flattenABILeafTypes(
1034 resultType, tables, funcOp.getOperation(), field, loweredResultTypes
1040 OpBuilder moduleBuilder(moduleOp.getContext());
1041 moduleBuilder.setInsertionPointToEnd(moduleOp.getBody());
1042 auto loweredFunc = moduleBuilder.create<func::FuncOp>(
1043 funcOp.getLoc(), mangleFunctionName(funcOp),
1044 moduleBuilder.getFunctionType(loweredArgTypes, loweredResultTypes)
1046 Block *entry = loweredFunc.addEntryBlock();
1047 OpBuilder builder(entry, entry->begin());
1049 DenseMap<Value, LoweredValue> valueMap;
 // Index of the next unconsumed entry-block argument.
1050 unsigned cursor = 0;
1051 for (
auto [arg, argType] :
1052 llvm::zip(funcOp.getBody().front().getArguments(), funcOp.getArgumentTypes())) {
1053 auto leafCount = getLeafCount(argType, tables, funcOp.getOperation(), field);
1054 if (failed(leafCount)) {
1055 loweredFunc.erase();
1058 LoweredValue lowered {argType, {}};
1059 lowered.leaves.append(
1060 entry->getArguments().begin() + cursor,
1061 entry->getArguments().begin() + cursor + *leafCount
1063 cursor += *leafCount;
1064 valueMap[arg] = std::move(lowered);
1067 if (failed(lowerBlock(builder, funcOp.getBody().front(), valueMap))) {
1068 loweredFunc.erase();
1076 SymbolTableCollection &tables;
1079 std::mt19937_64 rng;
1082 FailureOr<LoweredValue>
1083 lookup(Value value, DenseMap<Value, LoweredValue> &valueMap, Operation *origin) {
1084 auto it = valueMap.find(value);
1085 if (it == valueMap.end()) {
1086 origin->emitError(
"failed to find lowered SSA value");
1094 lookupScalar(Value value, DenseMap<Value, LoweredValue> &valueMap, Operation *origin) {
1095 auto lowered = lookup(value, valueMap, origin);
1096 if (failed(lowered) || lowered->leaves.size() != 1 ||
1097 isa<MemRefType>(lowered->leaves.front().getType())) {
1098 origin->emitError(
"expected scalar lowered value");
1101 return lowered->leaves.front();
1106 lowerBlock(OpBuilder &builder, Block &block, DenseMap<Value, LoweredValue> &valueMap) {
1107 for (Operation &op : block) {
1108 if (failed(lowerOperation(builder, op, valueMap))) {
1117 lowerFeltCmp(OpBuilder &builder, Location loc, boolean::CmpOp cmpOp, Value lhs, Value rhs) {
1118 arith::CmpIPredicate predicate;
1119 switch (cmpOp.getPredicate()) {
1121 predicate = arith::CmpIPredicate::eq;
1124 predicate = arith::CmpIPredicate::ne;
1127 predicate = arith::CmpIPredicate::ult;
1130 predicate = arith::CmpIPredicate::ule;
1133 predicate = arith::CmpIPredicate::ugt;
1136 predicate = arith::CmpIPredicate::uge;
1139 return builder.create<arith::CmpIOp>(loc, predicate, lhs, rhs).getResult();
1144 lowerOperation(OpBuilder &builder, Operation &op, DenseMap<Value, LoweredValue> &valueMap) {
1145 Location loc = op.getLoc();
1147 auto bind = [&](Value result, LoweredValue lowered) {
1148 valueMap[result] = std::move(lowered);
1152 if (
auto returnOp = dyn_cast<function::ReturnOp>(op)) {
1153 SmallVector<Value> results;
1154 for (Value operand : returnOp.getOperands()) {
1155 auto lowered = lookup(operand, valueMap, returnOp.getOperation());
1156 auto leafTypes = getABILeafTypes(operand.getType(), tables, returnOp.getOperation(), field);
1157 if (failed(lowered) || failed(leafTypes) ||
1158 failed(appendFlatLeavesToTypes(
1159 builder, loc, *lowered, *leafTypes, results, returnOp.getOperation()
1164 builder.create<func::ReturnOp>(loc, results);
1168 if (
auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1169 SmallVector<Value> results;
1170 for (Value operand : yieldOp.getOperands()) {
1171 auto lowered = lookup(operand, valueMap, yieldOp.getOperation());
1172 auto leafTypes = getABILeafTypes(operand.getType(), tables, yieldOp.getOperation(), field);
1173 if (failed(lowered) || failed(leafTypes) ||
1174 failed(appendFlatLeavesToTypes(
1175 builder, loc, *lowered, *leafTypes, results, yieldOp.getOperation()
1180 builder.create<scf::YieldOp>(loc, results);
1183 if (
auto conditionOp = dyn_cast<scf::ConditionOp>(op)) {
1185 lookupScalar(conditionOp.getCondition(), valueMap, conditionOp.getOperation());
1186 if (failed(condition)) {
1189 SmallVector<Value> results;
1190 for (Value operand : conditionOp.getArgs()) {
1191 auto lowered = lookup(operand, valueMap, conditionOp.getOperation());
1193 getABILeafTypes(operand.getType(), tables, conditionOp.getOperation(), field);
1194 if (failed(lowered) || failed(leafTypes) ||
1195 failed(appendFlatLeavesToTypes(
1196 builder, loc, *lowered, *leafTypes, results, conditionOp.getOperation()
1201 builder.create<scf::ConditionOp>(loc, *condition, results);
1205 if (
auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1206 Operation *clone = builder.clone(op);
1208 constantOp.getResult(), LoweredValue {constantOp.getType(), {clone->getResult(0)}}
1212 if (
auto feltConst = dyn_cast<felt::FeltConstantOp>(op)) {
1213 auto intType = IntegerType::get(builder.getContext(), field.bitWidth());
1216 auto modVal = constVal % field.prime();
1218 Value lowered = builder.create<arith::ConstantOp>(loc, IntegerAttr::get(intType, intVal));
1219 return bind(feltConst.getResult(), LoweredValue {feltConst.getType(), {lowered}});
1222 if (
auto nondetOp = dyn_cast<llzk::NonDetOp>(op)) {
1223 auto lowered = createDefaultValue(
1224 builder, loc, nondetOp.getType(), tables, nondetOp.getOperation(), field,
1225 uninitializedBehavior, rng
1227 if (failed(lowered)) {
1230 return bind(nondetOp.getResult(), std::move(*lowered));
1233 if (
auto addOp = dyn_cast<felt::AddFeltOp>(op)) {
1234 auto lhs = lookupScalar(addOp.getLhs(), valueMap, addOp.getOperation());
1235 auto rhs = lookupScalar(addOp.getRhs(), valueMap, addOp.getOperation());
1236 if (failed(lhs) || failed(rhs)) {
1241 LoweredValue {addOp.getType(), {lowerFeltAdd(builder, loc, *lhs, *rhs, field)}}
1244 if (
auto powOp = dyn_cast<felt::PowFeltOp>(op)) {
1245 auto lhs = lookupScalar(powOp.getLhs(), valueMap, powOp.getOperation());
1246 auto rhs = lookupScalar(powOp.getRhs(), valueMap, powOp.getOperation());
1247 if (failed(lhs) || failed(rhs)) {
1252 LoweredValue {powOp.getType(), {lowerFeltPow(builder, loc, *lhs, *rhs, field)}}
1255 if (
auto andOp = dyn_cast<felt::AndFeltOp>(op)) {
1256 auto lhs = lookupScalar(andOp.getLhs(), valueMap, andOp.getOperation());
1257 auto rhs = lookupScalar(andOp.getRhs(), valueMap, andOp.getOperation());
1258 if (failed(lhs) || failed(rhs)) {
1263 LoweredValue {andOp.getType(), {builder.create<arith::AndIOp>(loc, *lhs, *rhs)}}
1266 if (
auto orOp = dyn_cast<felt::OrFeltOp>(op)) {
1267 auto lhs = lookupScalar(orOp.getLhs(), valueMap, orOp.getOperation());
1268 auto rhs = lookupScalar(orOp.getRhs(), valueMap, orOp.getOperation());
1269 if (failed(lhs) || failed(rhs)) {
1274 LoweredValue {orOp.getType(), {lowerFeltOr(builder, loc, *lhs, *rhs, field)}}
1277 if (
auto xorOp = dyn_cast<felt::XorFeltOp>(op)) {
1278 auto lhs = lookupScalar(xorOp.getLhs(), valueMap, xorOp.getOperation());
1279 auto rhs = lookupScalar(xorOp.getRhs(), valueMap, xorOp.getOperation());
1280 if (failed(lhs) || failed(rhs)) {
1285 LoweredValue {xorOp.getType(), {lowerFeltXor(builder, loc, *lhs, *rhs, field)}}
1288 if (
auto subOp = dyn_cast<felt::SubFeltOp>(op)) {
1289 auto lhs = lookupScalar(subOp.getLhs(), valueMap, subOp.getOperation());
1290 auto rhs = lookupScalar(subOp.getRhs(), valueMap, subOp.getOperation());
1291 if (failed(lhs) || failed(rhs)) {
1296 LoweredValue {subOp.getType(), {lowerFeltSub(builder, loc, *lhs, *rhs, field)}}
1299 if (
auto mulOp = dyn_cast<felt::MulFeltOp>(op)) {
1300 auto lhs = lookupScalar(mulOp.getLhs(), valueMap, mulOp.getOperation());
1301 auto rhs = lookupScalar(mulOp.getRhs(), valueMap, mulOp.getOperation());
1302 if (failed(lhs) || failed(rhs)) {
1307 LoweredValue {mulOp.getType(), {lowerFeltMul(builder, loc, *lhs, *rhs, field)}}
1310 if (
auto negOp = dyn_cast<felt::NegFeltOp>(op)) {
1311 auto operand = lookupScalar(negOp.getOperand(), valueMap, negOp.getOperation());
1312 if (failed(operand)) {
1317 LoweredValue {negOp.getType(), {lowerFeltNeg(builder, loc, *operand, field)}}
1320 if (
auto invOp = dyn_cast<felt::InvFeltOp>(op)) {
1321 auto operand = lookupScalar(invOp.getOperand(), valueMap, invOp.getOperation());
1322 if (failed(operand)) {
1327 LoweredValue {invOp.getType(), {lowerFeltInv(builder, loc, *operand, field)}}
1330 if (
auto divOp = dyn_cast<felt::DivFeltOp>(op)) {
1331 auto lhs = lookupScalar(divOp.getLhs(), valueMap, divOp.getOperation());
1332 auto rhs = lookupScalar(divOp.getRhs(), valueMap, divOp.getOperation());
1333 if (failed(lhs) || failed(rhs)) {
1338 LoweredValue {divOp.getType(), {lowerFeltDiv(builder, loc, *lhs, *rhs, field)}}
1341 if (
auto uintDivOp = dyn_cast<felt::UnsignedIntDivFeltOp>(op)) {
1342 auto lhs = lookupScalar(uintDivOp.getLhs(), valueMap, uintDivOp.getOperation());
1343 auto rhs = lookupScalar(uintDivOp.getRhs(), valueMap, uintDivOp.getOperation());
1344 if (failed(lhs) || failed(rhs)) {
1347 assertNonZeroFelt(builder, loc, *rhs,
"felt.uintdiv divisor must be non-zero");
1349 uintDivOp.getResult(),
1350 LoweredValue {uintDivOp.getType(), {lowerFeltUnsignedDiv(builder, loc, *lhs, *rhs)}}
1353 if (
auto sintDivOp = dyn_cast<felt::SignedIntDivFeltOp>(op)) {
1354 auto lhs = lookupScalar(sintDivOp.getLhs(), valueMap, sintDivOp.getOperation());
1355 auto rhs = lookupScalar(sintDivOp.getRhs(), valueMap, sintDivOp.getOperation());
1356 if (failed(lhs) || failed(rhs)) {
1359 assertNonZeroFelt(builder, loc, *rhs,
"felt.sintdiv divisor must be non-zero");
1361 sintDivOp.getResult(),
1362 LoweredValue {sintDivOp.getType(), {lowerFeltSignedDiv(builder, loc, *lhs, *rhs, field)}}
1365 if (
auto umodOp = dyn_cast<felt::UnsignedModFeltOp>(op)) {
1366 auto lhs = lookupScalar(umodOp.getLhs(), valueMap, umodOp.getOperation());
1367 auto rhs = lookupScalar(umodOp.getRhs(), valueMap, umodOp.getOperation());
1368 if (failed(lhs) || failed(rhs)) {
1371 assertNonZeroFelt(builder, loc, *rhs,
"felt.umod divisor must be non-zero");
1374 LoweredValue {umodOp.getType(), {lowerFeltUnsignedMod(builder, loc, *lhs, *rhs)}}
1377 if (
auto smodOp = dyn_cast<felt::SignedModFeltOp>(op)) {
1378 auto lhs = lookupScalar(smodOp.getLhs(), valueMap, smodOp.getOperation());
1379 auto rhs = lookupScalar(smodOp.getRhs(), valueMap, smodOp.getOperation());
1380 if (failed(lhs) || failed(rhs)) {
1383 assertNonZeroFelt(builder, loc, *rhs,
"felt.smod divisor must be non-zero");
1386 LoweredValue {smodOp.getType(), {lowerFeltSignedMod(builder, loc, *lhs, *rhs, field)}}
1389 if (
auto shrOp = dyn_cast<felt::ShrFeltOp>(op)) {
1390 auto lhs = lookupScalar(shrOp.getLhs(), valueMap, shrOp.getOperation());
1391 auto rhs = lookupScalar(shrOp.getRhs(), valueMap, shrOp.getOperation());
1392 if (failed(lhs) || failed(rhs)) {
1397 LoweredValue {shrOp.getType(), {lowerFeltShr(builder, loc, *lhs, *rhs, field)}}
1400 if (
auto shlOp = dyn_cast<felt::ShlFeltOp>(op)) {
1401 auto lhs = lookupScalar(shlOp.getLhs(), valueMap, shlOp.getOperation());
1402 auto rhs = lookupScalar(shlOp.getRhs(), valueMap, shlOp.getOperation());
1403 if (failed(lhs) || failed(rhs)) {
1408 LoweredValue {shlOp.getType(), {lowerFeltShl(builder, loc, *lhs, *rhs, field)}}
1411 if (
auto notOp = dyn_cast<felt::NotFeltOp>(op)) {
1412 auto operand = lookupScalar(
notOp.getOperand(), valueMap,
notOp.getOperation());
1413 if (failed(operand)) {
1418 LoweredValue {notOp.getType(), {lowerFeltNot(builder, loc, *operand, field)}}
1422 if (
auto cmpOp = dyn_cast<boolean::CmpOp>(op)) {
1423 auto lhs = lookupScalar(cmpOp.getLhs(), valueMap, cmpOp.getOperation());
1424 auto rhs = lookupScalar(cmpOp.getRhs(), valueMap, cmpOp.getOperation());
1425 if (failed(lhs) || failed(rhs)) {
1428 auto lowered = lowerFeltCmp(builder, loc, cmpOp, *lhs, *rhs);
1429 if (failed(lowered)) {
1432 return bind(cmpOp.getResult(), LoweredValue {cmpOp.getType(), {*lowered}});
1434 if (
auto assertOp = dyn_cast<boolean::AssertOp>(op)) {
1435 auto condition = lookupScalar(assertOp.getCondition(), valueMap, assertOp.getOperation());
1436 if (failed(condition)) {
1439 builder.create<cf::AssertOp>(
1440 loc, *condition, assertOp.getMsg() ? assertOp.getMsg()->str() :
"bool.assert failed"
1444 if (
auto andOp = dyn_cast<boolean::AndBoolOp>(op)) {
1445 auto lhs = lookupScalar(andOp.getLhs(), valueMap, andOp.getOperation());
1446 auto rhs = lookupScalar(andOp.getRhs(), valueMap, andOp.getOperation());
1447 if (failed(lhs) || failed(rhs)) {
1452 LoweredValue {andOp.getType(), {builder.create<arith::AndIOp>(loc, *lhs, *rhs)}}
1455 if (
auto orOp = dyn_cast<boolean::OrBoolOp>(op)) {
1456 auto lhs = lookupScalar(orOp.getLhs(), valueMap, orOp.getOperation());
1457 auto rhs = lookupScalar(orOp.getRhs(), valueMap, orOp.getOperation());
1458 if (failed(lhs) || failed(rhs)) {
1463 LoweredValue {orOp.getType(), {builder.create<arith::OrIOp>(loc, *lhs, *rhs)}}
1466 if (
auto xorOp = dyn_cast<boolean::XorBoolOp>(op)) {
1467 auto lhs = lookupScalar(xorOp.getLhs(), valueMap, xorOp.getOperation());
1468 auto rhs = lookupScalar(xorOp.getRhs(), valueMap, xorOp.getOperation());
1469 if (failed(lhs) || failed(rhs)) {
1474 LoweredValue {xorOp.getType(), {builder.create<arith::XOrIOp>(loc, *lhs, *rhs)}}
1477 if (
auto notOp = dyn_cast<boolean::NotBoolOp>(op)) {
1478 auto operand = lookupScalar(
notOp.getOperand(), valueMap,
notOp.getOperation());
1479 if (failed(operand)) {
1482 Value one = builder.create<arith::ConstantOp>(
1483 loc, IntegerAttr::get(IntegerType::get(builder.getContext(), 1), 1)
1487 LoweredValue {notOp.getType(), {builder.create<arith::XOrIOp>(loc, *operand, one)}}
1491 if (
auto intToFelt = dyn_cast<cast::IntToFeltOp>(op)) {
1492 auto operand = lookupScalar(intToFelt.getValue(), valueMap, intToFelt.getOperation());
1493 if (failed(operand)) {
1496 auto dstType = IntegerType::get(builder.getContext(), field.bitWidth());
1498 if (isa<IndexType>((*operand).getType())) {
1499 lowered = builder.create<arith::IndexCastUIOp>(loc, dstType, *operand);
1501 auto intType = mlir::cast<IntegerType>((*operand).getType());
1502 if (intType.getWidth() < dstType.getWidth()) {
1503 lowered = builder.create<arith::ExtUIOp>(loc, dstType, *operand);
1504 }
else if (intType.getWidth() > dstType.getWidth()) {
1505 lowered = normalizeWideValue(builder, loc, *operand, dstType.getWidth(), field);
1510 return bind(intToFelt.getResult(), LoweredValue {intToFelt.getType(), {lowered}});
1512 if (
auto feltToIndex = dyn_cast<cast::FeltToIndexOp>(op)) {
1513 auto operand = lookupScalar(feltToIndex.getValue(), valueMap, feltToIndex.getOperation());
1514 if (failed(operand)) {
1518 feltToIndex.getResult(),
1520 feltToIndex.getType(),
1521 {builder.create<arith::IndexCastUIOp>(loc, builder.getIndexType(), *operand)}
1526 if (
auto structNewOp = dyn_cast<component::CreateStructOp>(op)) {
1527 auto lowered = createDefaultValue(
1528 builder, loc, structNewOp.getType(), tables, structNewOp.getOperation(), field,
1529 uninitializedBehavior, rng
1531 if (failed(lowered)) {
1534 return bind(structNewOp.getResult(), std::move(*lowered));
1536 if (
auto readMemberOp = dyn_cast<component::MemberReadOp>(op)) {
1537 auto componentValue =
1538 lookup(readMemberOp.getComponent(), valueMap, readMemberOp.getOperation());
1539 if (failed(componentValue)) {
1542 auto lowered = readNamedAggregateValue(
1543 builder, loc, readMemberOp.getComponent().getType(), readMemberOp.getMemberName(),
1544 *componentValue, tables, readMemberOp.getOperation(), field
1546 if (failed(lowered)) {
1549 return bind(readMemberOp.getResult(), std::move(*lowered));
1551 if (
auto writeMemberOp = dyn_cast<component::MemberWriteOp>(op)) {
1552 auto componentValue =
1553 lookup(writeMemberOp.getComponent(), valueMap, writeMemberOp.getOperation());
1554 auto memberValue = lookup(writeMemberOp.getVal(), valueMap, writeMemberOp.getOperation());
1555 if (failed(componentValue) || failed(memberValue)) {
1558 return writeNamedAggregateValue(
1559 builder, loc, writeMemberOp.getComponent().getType(), writeMemberOp.getMemberName(),
1560 valueMap[writeMemberOp.getComponent()], *memberValue, tables,
1561 writeMemberOp.getOperation(), field
1565 if (
auto newPodOp = dyn_cast<pod::NewPodOp>(op)) {
1566 auto lowered = createDefaultValue(
1567 builder, loc, newPodOp.getType(), tables, newPodOp.getOperation(), field,
1568 uninitializedBehavior, rng
1570 if (failed(lowered)) {
1573 for (pod::RecordValue init : newPodOp.getInitializedRecordValues()) {
1574 auto value = lookup(init.value, valueMap, newPodOp.getOperation());
1575 if (failed(value) || failed(writeNamedAggregateValue(
1576 builder, loc, newPodOp.getType(), init.name, *lowered, *value,
1577 tables, newPodOp.getOperation(), field
1582 return bind(newPodOp.getResult(), std::move(*lowered));
1584 if (
auto readPodOp = dyn_cast<pod::ReadPodOp>(op)) {
1585 auto podValue = lookup(readPodOp.getPodRef(), valueMap, readPodOp.getOperation());
1586 if (failed(podValue)) {
1589 auto lowered = readNamedAggregateValue(
1590 builder, loc, readPodOp.getPodRef().getType(), readPodOp.getRecordName(), *podValue,
1591 tables, readPodOp.getOperation(), field
1593 if (failed(lowered)) {
1596 return bind(readPodOp.getResult(), std::move(*lowered));
1598 if (
auto writePodOp = dyn_cast<pod::WritePodOp>(op)) {
1599 auto recordValue = lookup(writePodOp.getValue(), valueMap, writePodOp.getOperation());
1600 if (failed(recordValue)) {
1603 return writeNamedAggregateValue(
1604 builder, loc, writePodOp.getPodRef().getType(), writePodOp.getRecordName(),
1605 valueMap[writePodOp.getPodRef()], *recordValue, tables, writePodOp.getOperation(), field
1609 if (
auto arrayNewOp = dyn_cast<array::CreateArrayOp>(op)) {
1610 auto lowered = createDefaultValue(
1611 builder, loc, arrayNewOp.getType(), tables, arrayNewOp.getOperation(), field,
1612 uninitializedBehavior, rng
1614 if (failed(lowered)) {
1617 if (!arrayNewOp.getElements().empty()) {
1618 auto elementCount = checkedCast<size_t>(arrayNewOp.getType().getNumElements());
1619 if (!elementCount) {
1620 arrayNewOp.emitError() << llvm::toString(elementCount.takeError());
1623 if (arrayNewOp.getElements().size() != *elementCount) {
1624 arrayNewOp.emitError(
"expected one explicit element per array slot in witgen lowering");
1627 auto shape = arrayNewOp.getType().getShape();
1628 for (
auto [flatIndex, operand] : llvm::enumerate(arrayNewOp.getElements())) {
1629 auto elementValue = lookup(operand, valueMap, arrayNewOp.getOperation());
1630 if (failed(elementValue)) {
1633 SmallVector<Value> indices;
1634 auto strides = mlir::computeStrides(shape);
1635 auto flatSigned = checkedCast<int64_t>(flatIndex);
1637 arrayNewOp.emitError() << llvm::toString(flatSigned.takeError());
1640 for (int64_t index : mlir::delinearize(*flatSigned, strides)) {
1641 indices.push_back(makeIndexConstant(builder, loc, index));
1643 if (failed(writeArrayElement(
1644 builder, loc, arrayNewOp.getType(), *lowered, indices, *elementValue
1650 return bind(arrayNewOp.getResult(), std::move(*lowered));
1652 if (
auto readArrayOp = dyn_cast<array::ReadArrayOp>(op)) {
1653 SmallVector<Value> indices;
1654 for (Value indexValue : readArrayOp.getIndices()) {
1655 auto loweredIndex = lookupScalar(indexValue, valueMap, readArrayOp.getOperation());
1656 if (failed(loweredIndex)) {
1659 indices.push_back(*loweredIndex);
1661 auto arrayValue = lookup(readArrayOp.getArrRef(), valueMap, readArrayOp.getOperation());
1662 if (failed(arrayValue)) {
1665 auto lowered = readArrayElement(
1666 builder, loc, mlir::cast<array::ArrayType>(readArrayOp.getArrRef().getType()),
1667 *arrayValue, indices
1669 if (failed(lowered)) {
1672 return bind(readArrayOp.getResult(), std::move(*lowered));
1674 if (
auto writeArrayOp = dyn_cast<array::WriteArrayOp>(op)) {
1675 SmallVector<Value> indices;
1676 for (Value indexValue : writeArrayOp.getIndices()) {
1677 auto loweredIndex = lookupScalar(indexValue, valueMap, writeArrayOp.getOperation());
1678 if (failed(loweredIndex)) {
1681 indices.push_back(*loweredIndex);
1683 auto elementValue = lookup(writeArrayOp.getRvalue(), valueMap, writeArrayOp.getOperation());
1684 if (failed(elementValue)) {
1687 return writeArrayElement(
1688 builder, loc, mlir::cast<array::ArrayType>(writeArrayOp.getArrRef().getType()),
1689 valueMap[writeArrayOp.getArrRef()], indices, *elementValue
1693 if (
auto cmpiOp = dyn_cast<arith::CmpIOp>(op)) {
1694 auto lhs = lookupScalar(cmpiOp.getLhs(), valueMap, cmpiOp.getOperation());
1695 auto rhs = lookupScalar(cmpiOp.getRhs(), valueMap, cmpiOp.getOperation());
1696 if (failed(lhs) || failed(rhs)) {
1703 {builder.create<arith::CmpIOp>(loc, cmpiOp.getPredicate(), *lhs, *rhs)}
1707 if (
auto selectOp = dyn_cast<arith::SelectOp>(op)) {
1708 auto cond = lookupScalar(selectOp.getCondition(), valueMap, selectOp.getOperation());
1709 auto trueValue = lookupScalar(selectOp.getTrueValue(), valueMap, selectOp.getOperation());
1710 auto falseValue = lookupScalar(selectOp.getFalseValue(), valueMap, selectOp.getOperation());
1711 if (failed(cond) || failed(trueValue) || failed(falseValue)) {
1715 selectOp.getResult(),
1718 {builder.create<arith::SelectOp>(loc, *cond, *trueValue, *falseValue)}
1722 if (
auto addiOp = dyn_cast<arith::AddIOp>(op)) {
1723 auto lhs = lookupScalar(addiOp.getLhs(), valueMap, addiOp.getOperation());
1724 auto rhs = lookupScalar(addiOp.getRhs(), valueMap, addiOp.getOperation());
1725 if (failed(lhs) || failed(rhs)) {
1730 LoweredValue {addiOp.getType(), {builder.create<arith::AddIOp>(loc, *lhs, *rhs)}}
1733 if (
auto subiOp = dyn_cast<arith::SubIOp>(op)) {
1734 auto lhs = lookupScalar(subiOp.getLhs(), valueMap, subiOp.getOperation());
1735 auto rhs = lookupScalar(subiOp.getRhs(), valueMap, subiOp.getOperation());
1736 if (failed(lhs) || failed(rhs)) {
1741 LoweredValue {subiOp.getType(), {builder.create<arith::SubIOp>(loc, *lhs, *rhs)}}
1745 if (
auto callOp = dyn_cast<function::CallOp>(op)) {
1746 if (callOp.getTemplateParams() || !callOp.getMapOperands().empty()) {
1747 callOp.emitError(
"execution-engine backend encountered an unflattened function.call");
1750 auto *callable = callOp.resolveCallableInTable(&tables);
1751 auto callee = dyn_cast_or_null<function::FuncDefOp>(callable);
1753 callOp.emitError(
"failed to resolve callee during execution-engine lowering");
1756 SmallVector<Type> resultTypes;
1757 for (Type resultType : callOp.getResultTypes()) {
1759 flattenABILeafTypes(resultType, tables, callOp.getOperation(), field, resultTypes)
1764 SmallVector<Value> flatArgs;
1765 for (Value operand : callOp.getArgOperands()) {
1766 auto lowered = lookup(operand, valueMap, callOp.getOperation());
1767 auto leafTypes = getABILeafTypes(operand.getType(), tables, callOp.getOperation(), field);
1768 if (failed(lowered) || failed(leafTypes) ||
1769 failed(appendFlatLeavesToTypes(
1770 builder, loc, *lowered, *leafTypes, flatArgs, callOp.getOperation()
1776 builder.create<func::CallOp>(loc, mangleFunctionName(callee), resultTypes, flatArgs);
1777 auto loweredCallResults = loweredCall.getResults();
1778 size_t totalResults = loweredCallResults.size();
1780 for (
auto [oldResult, oldType] : llvm::zip(callOp.getResults(), callOp.getResultTypes())) {
1781 auto leafCount = getLeafCount(oldType, tables, callOp.getOperation(), field);
1782 if (failed(leafCount)) {
1785 bool overflow =
false;
1786 size_t nextCursor = llvm::SaturatingAdd(cursor, *leafCount, &overflow);
1787 if (overflow || nextCursor > totalResults) {
1788 callOp.emitError(
"leaf count overflow while lowering function call results");
1791 LoweredValue lowered {oldType, {}};
1792 lowered.leaves.append(
1793 loweredCallResults.begin() +
static_cast<ptrdiff_t
>(cursor),
1794 loweredCallResults.begin() +
static_cast<ptrdiff_t
>(nextCursor)
1796 valueMap[oldResult] = std::move(lowered);
1797 cursor = nextCursor;
1802 if (
auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1803 SmallVector<Value> initArgs;
1804 SmallVector<size_t> beforeLeafCounts;
1805 for (
auto [init, initType] : llvm::zip(whileOp.getInits(), whileOp.getOperandTypes())) {
1806 auto lowered = lookup(init, valueMap, whileOp.getOperation());
1807 auto leafTypes = getABILeafTypes(initType, tables, whileOp.getOperation(), field);
1808 if (failed(lowered) || failed(leafTypes) ||
1809 failed(appendFlatLeavesToTypes(
1810 builder, loc, *lowered, *leafTypes, initArgs, whileOp.getOperation()
1814 auto count = getLeafCount(initType, tables, whileOp.getOperation(), field);
1815 if (failed(count)) {
1818 beforeLeafCounts.push_back(*count);
1821 SmallVector<size_t> resultLeafCounts;
1822 SmallVector<Type> loweredResultTypes;
1823 for (Type resultType : whileOp.getResultTypes()) {
1824 auto leafTypes = getABILeafTypes(resultType, tables, whileOp.getOperation(), field);
1825 auto count = getLeafCount(resultType, tables, whileOp.getOperation(), field);
1826 if (failed(leafTypes) || failed(count)) {
1829 loweredResultTypes.append(leafTypes->begin(), leafTypes->end());
1830 resultLeafCounts.push_back(*count);
1833 auto mapRegionArguments = [&](
auto oldArgs,
auto oldTypes,
auto leafCounts,
auto newArgs,
1834 StringRef overflowMessage,
1835 DenseMap<Value, LoweredValue> ®ionMap) -> LogicalResult {
1836 size_t totalArgs = newArgs.size();
1838 for (
auto [oldArg, oldType, leafCount] : llvm::zip(oldArgs, oldTypes, leafCounts)) {
1839 bool overflow =
false;
1840 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1841 if (overflow || nextCursor > totalArgs) {
1842 whileOp.emitError(overflowMessage);
1845 LoweredValue lowered {oldType, {}};
1846 lowered.leaves.append(
1850 regionMap[oldArg] = std::move(lowered);
1851 cursor = nextCursor;
1856 LogicalResult whileLoweringStatus = success();
1857 auto newWhile = builder.create<scf::WhileOp>(
1858 loc, loweredResultTypes, initArgs,
1859 [&](OpBuilder ®ionBuilder, Location , ValueRange beforeArgs) {
1860 DenseMap<Value, LoweredValue> beforeMap(valueMap.begin(), valueMap.end());
1861 if (failed(mapRegionArguments(
1862 whileOp.getBeforeArguments(), whileOp.getOperandTypes(), beforeLeafCounts,
1863 beforeArgs,
"leaf count overflow while lowering while-loop before-region args",
1866 failed(lowerBlock(regionBuilder, whileOp.getBefore().front(), beforeMap))) {
1867 whileLoweringStatus = failure();
1869 }, [&](OpBuilder ®ionBuilder, Location , ValueRange afterArgs) {
1870 DenseMap<Value, LoweredValue> afterMap(valueMap.begin(), valueMap.end());
1871 if (failed(mapRegionArguments(
1872 whileOp.getAfterArguments(), whileOp.getResultTypes(), resultLeafCounts, afterArgs,
1873 "leaf count overflow while lowering while-loop after-region args", afterMap
1875 failed(lowerBlock(regionBuilder, whileOp.getAfter().front(), afterMap))) {
1876 whileLoweringStatus = failure();
1880 if (failed(whileLoweringStatus)) {
1885 auto newWhileResults = newWhile.getResults();
1886 size_t totalResults = newWhileResults.size();
1888 for (
auto [oldResult, oldType, leafCount] :
1889 llvm::zip(whileOp.getResults(), whileOp.getResultTypes(), resultLeafCounts)) {
1890 bool overflow =
false;
1891 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1892 if (overflow || nextCursor > totalResults) {
1893 whileOp.emitError(
"leaf count overflow while lowering while-loop results");
1896 LoweredValue lowered {oldType, {}};
1897 lowered.leaves.append(
1901 valueMap[oldResult] = std::move(lowered);
1902 cursor = nextCursor;
1907 if (
auto ifOp = dyn_cast<scf::IfOp>(op)) {
1908 auto condition = lookupScalar(ifOp.getCondition(), valueMap, ifOp.getOperation());
1909 if (failed(condition)) {
1913 SmallVector<size_t> resultLeafCounts;
1914 SmallVector<Type> loweredResultTypes;
1915 for (Type resultType : ifOp.getResultTypes()) {
1916 auto leafTypes = getABILeafTypes(resultType, tables, ifOp.getOperation(), field);
1917 auto count = getLeafCount(resultType, tables, ifOp.getOperation(), field);
1918 if (failed(leafTypes) || failed(count)) {
1921 loweredResultTypes.append(leafTypes->begin(), leafTypes->end());
1922 resultLeafCounts.push_back(*count);
1925 auto newIf = builder.create<scf::IfOp>(
1926 loc, loweredResultTypes, *condition,
true, !ifOp.getElseRegion().empty()
1930 OpBuilder thenBuilder = OpBuilder::atBlockBegin(&newIf.getThenRegion().front());
1931 DenseMap<Value, LoweredValue> thenMap(valueMap.begin(), valueMap.end());
1932 if (failed(lowerBlock(thenBuilder, ifOp.getThenRegion().front(), thenMap))) {
1937 if (!ifOp.getElseRegion().empty()) {
1938 OpBuilder elseBuilder = OpBuilder::atBlockBegin(&newIf.getElseRegion().front());
1939 DenseMap<Value, LoweredValue> elseMap(valueMap.begin(), valueMap.end());
1940 if (failed(lowerBlock(elseBuilder, ifOp.getElseRegion().front(), elseMap))) {
1945 auto newIfResults = newIf.getResults();
1946 size_t totalResults = newIfResults.size();
1948 for (
auto [oldResult, oldType, leafCount] :
1949 llvm::zip(ifOp.getResults(), ifOp.getResultTypes(), resultLeafCounts)) {
1950 bool overflow =
false;
1951 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
1952 if (overflow || nextCursor > totalResults) {
1953 ifOp.emitError(
"leaf count overflow while lowering if-op results");
1956 LoweredValue lowered {oldType, {}};
1957 lowered.leaves.append(
1961 valueMap[oldResult] = std::move(lowered);
1962 cursor = nextCursor;
1967 if (
auto forOp = dyn_cast<scf::ForOp>(op)) {
1968 auto lb = lookupScalar(forOp.getLowerBound(), valueMap, forOp.getOperation());
1969 auto ub = lookupScalar(forOp.getUpperBound(), valueMap, forOp.getOperation());
1970 auto step = lookupScalar(forOp.getStep(), valueMap, forOp.getOperation());
1971 if (failed(lb) || failed(ub) || failed(step)) {
1975 SmallVector<Value> initArgs;
1976 SmallVector<size_t> initLeafCounts;
1977 for (
auto [init, resultType] : llvm::zip(forOp.getInitArgs(), forOp.getResultTypes())) {
1978 auto lowered = lookup(init, valueMap, forOp.getOperation());
1979 auto leafTypes = getABILeafTypes(resultType, tables, forOp.getOperation(), field);
1980 if (failed(lowered) || failed(leafTypes) ||
1981 failed(appendFlatLeavesToTypes(
1982 builder, loc, *lowered, *leafTypes, initArgs, forOp.getOperation()
1986 auto count = getLeafCount(resultType, tables, forOp.getOperation(), field);
1987 if (failed(count)) {
1990 initLeafCounts.push_back(*count);
1993 auto newFor = builder.create<scf::ForOp>(loc, *lb, *ub, *step, initArgs);
1994 if (Attribute unsignedCmpAttr = forOp->getAttr(
"unsignedCmp")) {
1995 newFor->setAttr(
"unsignedCmp", unsignedCmpAttr);
1997 DenseMap<Value, LoweredValue> bodyMap(valueMap.begin(), valueMap.end());
1998 bodyMap[forOp.getInductionVar()] =
1999 LoweredValue {forOp.getInductionVar().getType(), {newFor.getInductionVar()}};
2001 auto newForIterArgs = newFor.getRegionIterArgs();
2002 size_t totalIterArgs = newForIterArgs.size();
2004 for (
auto [oldIterArg, oldType, leafCount] :
2005 llvm::zip(forOp.getRegionIterArgs(), forOp.getResultTypes(), initLeafCounts)) {
2006 bool overflow =
false;
2007 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
2008 if (overflow || nextCursor > totalIterArgs) {
2009 forOp.emitError(
"leaf count overflow while lowering for-loop region iter args");
2012 LoweredValue lowered {oldType, {}};
2013 lowered.leaves.append(
2014 newForIterArgs.begin() +
static_cast<ptrdiff_t
>(cursor),
2015 newForIterArgs.begin() +
static_cast<ptrdiff_t
>(nextCursor)
2017 bodyMap[oldIterArg] = std::move(lowered);
2018 cursor = nextCursor;
2022 newFor.getBody()->clear();
2023 OpBuilder bodyBuilder = OpBuilder::atBlockBegin(newFor.getBody());
2024 if (failed(lowerBlock(bodyBuilder, *forOp.getBody(), bodyMap))) {
2029 auto newForResults = newFor.getResults();
2030 size_t totalForResults = newForResults.size();
2032 for (
auto [oldResult, oldType, leafCount] :
2033 llvm::zip(forOp.getResults(), forOp.getResultTypes(), initLeafCounts)) {
2034 bool overflow =
false;
2035 size_t nextCursor = llvm::SaturatingAdd(cursor, leafCount, &overflow);
2036 if (overflow || nextCursor > totalForResults) {
2037 forOp.emitError(
"leaf count overflow while lowering for-loop results");
2040 LoweredValue lowered {oldType, {}};
2041 lowered.leaves.append(
2042 newForResults.begin() +
static_cast<ptrdiff_t
>(cursor),
2043 newForResults.begin() +
static_cast<ptrdiff_t
>(nextCursor)
2045 valueMap[oldResult] = std::move(lowered);
2046 cursor = nextCursor;
2052 op.emitError(
"unsupported operation in execution-engine lowering: ") << op.getName();
2058class LowerComputeToCorePass :
public PassWrapper<LowerComputeToCorePass, OperationPass<ModuleOp>> {
2060 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerComputeToCorePass)
2062 explicit LowerComputeToCorePass(
const WitgenOptions &opts) : options(opts) {}
2065 StringRef getArgument() const final {
return "llzk-lower-compute-to-core"; }
2068 StringRef getDescription() const final {
2069 return "Lower LLZK compute IR to func/arith/cf/scf/memref";
2073 StringRef getName()
const override {
return "LowerComputeToCorePass"; }
2076 void runOnOperation()
override {
2077 ModuleOp moduleOp = getOperation();
2078 auto field = getModuleField(moduleOp);
2079 if (failed(field)) {
2080 signalPassFailure();
2084 SymbolTableCollection tables;
2085 SmallVector<function::FuncDefOp> funcs;
2086 moduleOp.walk([&](function::FuncDefOp funcOp) {
2087 if (funcOp.nameIsConstrain()) {
2090 funcs.push_back(funcOp);
2093 BodyLowerer lowerer(moduleOp, tables, field->get(), options);
2094 for (function::FuncDefOp funcOp : funcs) {
2095 if (failed(lowerer.lowerFunction(funcOp))) {
2096 signalPassFailure();
2103 WitgenOptions options;
2107class CreateWitgenEntryPass :
public PassWrapper<CreateWitgenEntryPass, OperationPass<ModuleOp>> {
2109 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CreateWitgenEntryPass)
2112 explicit CreateWitgenEntryPass(
bool fullWitness =
false) : emitFullWitness(fullWitness) {}
2115 StringRef getArgument() const final {
return "llzk-create-witgen-entry"; }
2118 StringRef getDescription() const final {
2119 return "Create the llzk-witgen execution-engine entry wrapper";
2123 StringRef getName()
const override {
return "CreateWitgenEntryPass"; }
2126 void runOnOperation()
override {
2127 ModuleOp moduleOp = getOperation();
2128 auto field = getModuleField(moduleOp);
2129 if (failed(field)) {
2130 signalPassFailure();
2134 SymbolTableCollection tables;
2136 if (failed(mainDef) || !mainDef.value()) {
2137 moduleOp.emitError(
"module is missing a concrete llzk.main struct");
2138 signalPassFailure();
2141 function::FuncDefOp computeFunc = mainDef->get().getComputeFuncOp();
2143 moduleOp.emitError(
"main struct is missing @compute");
2144 signalPassFailure();
2149 mainDef->get(), tables, computeFunc.getOperation(),
2150 emitFullWitness ? OutputScope::FullWitness : OutputScope::Public
2152 if (failed(outputs)) {
2153 signalPassFailure();
2157 OpBuilder builder(moduleOp.getContext());
2158 builder.setInsertionPointToEnd(moduleOp.getBody());
2160 SmallVector<Type> wrapperArgs;
2161 for (Type argType : computeFunc.getArgumentTypes()) {
2162 SmallVector<Type> loweredLeafTypes;
2163 if (failed(flattenTypeLeaves(
2164 argType, tables, computeFunc.getOperation(), field->get(), loweredLeafTypes, {},
true
2166 signalPassFailure();
2169 if (loweredLeafTypes.size() != 1 || !isa<MemRefType>(loweredLeafTypes.front())) {
2170 computeFunc.emitError(
2171 "execution-engine wrapper only supports felt and array<...xfelt> inputs"
2173 signalPassFailure();
2176 wrapperArgs.push_back(loweredLeafTypes.front());
2178 for (
const OutputBinding &output : *outputs) {
2179 SmallVector<Type> loweredLeafTypes;
2180 if (failed(flattenTypeLeaves(
2181 output.type, tables, computeFunc.getOperation(), field->get(), loweredLeafTypes, {},
2184 signalPassFailure();
2187 if (loweredLeafTypes.size() != 1 || !isa<MemRefType>(loweredLeafTypes.front())) {
2188 computeFunc.emitError(
2189 "execution-engine wrapper only supports felt and array<...xfelt> outputs"
2191 signalPassFailure();
2194 wrapperArgs.push_back(loweredLeafTypes.front());
2197 auto wrapper = builder.create<func::FuncOp>(
2198 computeFunc.getLoc(),
"__llzk_witgen_main",
2199 builder.getFunctionType(wrapperArgs, TypeRange {})
2201 wrapper->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(), builder.getUnitAttr());
2202 Block *entry = wrapper.addEntryBlock();
2203 builder.setInsertionPointToStart(entry);
2205 SmallVector<Type> loweredMainResultTypes;
2206 for (Type resultType : computeFunc.getResultTypes()) {
2207 if (failed(flattenABILeafTypes(
2208 resultType, tables, computeFunc.getOperation(), field->get(), loweredMainResultTypes
2210 signalPassFailure();
2215 SmallVector<Value> mainArgs;
2216 for (
auto [argType, wrapperArg] : llvm::zip(
2217 computeFunc.getArgumentTypes(),
2218 entry->getArguments().take_front(computeFunc.getNumArguments())
2220 if (isScalarType(argType)) {
2221 mainArgs.push_back(loadStorageScalar(builder, computeFunc.getLoc(), wrapperArg));
2224 getABILeafTypes(argType, tables, computeFunc.getOperation(), field->get());
2225 if (failed(abiLeafTypes) || abiLeafTypes->size() != 1 ||
2226 !isa<MemRefType>(abiLeafTypes->front())) {
2227 computeFunc.emitError(
"failed to derive execution-engine ABI type for main input");
2228 signalPassFailure();
2231 if (wrapperArg.getType() == abiLeafTypes->front()) {
2232 mainArgs.push_back(wrapperArg);
2234 mainArgs.push_back(builder.create<memref::CastOp>(
2235 computeFunc.getLoc(), abiLeafTypes->front(), wrapperArg
2240 auto loweredMain = builder.create<func::CallOp>(
2241 computeFunc.getLoc(), mangleFunctionName(computeFunc), loweredMainResultTypes, mainArgs
2244 LoweredValue mainResultValue {
2245 computeFunc.getResultTypes().front(),
2246 llvm::SmallVector<Value>(loweredMain.getResults().begin(), loweredMain.getResults().end())
2249 auto extractOutputSlice = [&](ArrayRef<std::string> path, Type currentType,
2250 ArrayRef<Value> leaves,
2251 auto &self) -> FailureOr<SmallVector<Value>> {
2253 return SmallVector<Value>(leaves.begin(), leaves.end());
2255 if (
auto structType = dyn_cast<component::StructType>(currentType)) {
2256 auto defLookup = structType.getDefinition(tables, computeFunc.getOperation());
2257 if (failed(defLookup)) {
2260 unsigned localCursor = 0;
2261 for (component::MemberDefOp member : defLookup->get().getMemberDefs()) {
2263 getLeafCount(member.getType(), tables, member.getOperation(), field->get());
2264 if (failed(leafCount)) {
2267 ArrayRef<Value> slice = ArrayRef<Value>(leaves).slice(localCursor, *leafCount);
2268 localCursor += *leafCount;
2269 if (member.getSymName() == path.front()) {
2270 return self(path.drop_front(), member.getType(), slice, self);
2273 computeFunc.emitError(
"failed to find struct member while wiring witgen outputs");
2276 if (
auto podType = dyn_cast<pod::PodType>(currentType)) {
2277 unsigned localCursor = 0;
2278 for (pod::RecordAttr record : podType.getRecords()) {
2280 getLeafCount(record.getType(), tables, computeFunc.getOperation(), field->get());
2281 if (failed(leafCount)) {
2284 ArrayRef<Value> slice = ArrayRef<Value>(leaves).slice(localCursor, *leafCount);
2285 localCursor += *leafCount;
2286 if (record.getName().getValue() == path.front()) {
2287 return self(path.drop_front(), record.getType(), slice, self);
2290 computeFunc.emitError(
"failed to find POD record while wiring witgen outputs");
2293 computeFunc.emitError(
"extra witness path components for non-aggregate output");
2297 auto outputArgs = entry->getArguments().drop_front(computeFunc.getNumArguments());
2298 for (
auto [output, outputMemRef] : llvm::zip(*outputs, outputArgs)) {
2299 auto slice = extractOutputSlice(
2300 output.path, mainResultValue.sourceType, mainResultValue.leaves, extractOutputSlice
2302 if (failed(slice) || slice->empty()) {
2303 wrapper.emitError(
"missing selected witness output slice while building witgen entry");
2304 signalPassFailure();
2307 if (isScalarType(output.type)) {
2309 builder, computeFunc.getLoc(),
2310 loadStorageScalar(builder, computeFunc.getLoc(), slice->front()), outputMemRef
2313 builder.create<memref::CopyOp>(computeFunc.getLoc(), slice->front(), outputMemRef);
2316 builder.create<func::ReturnOp>(computeFunc.getLoc());
2319 moduleOp->removeAttr(MAIN_ATTR_NAME);
2321 SmallVector<Operation *> toErase;
2322 for (Operation &op : moduleOp.getBody()->getOperations()) {
2323 if (!isa<func::FuncOp>(op)) {
2324 toErase.push_back(&op);
2327 for (Operation *op : toErase) {
2333 bool emitFullWitness;
2343 pm.addPass(mlir::createLowerAffinePass());
2346 pm.addPass(mlir::createCanonicalizerPass());
2347 pm.addPass(mlir::createCSEPass());
2351 return std::make_unique<LowerComputeToCorePass>(options);
2355 return std::make_unique<CreateWitgenEntryPass>(emitFullWitness);
This file implements helper methods for constructing DynamicAPInts.
Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, […] "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, […]
std::unique_ptr<::mlir::Pass > createFlatteningPass()
llvm::Expected< T > checkedCast(U u)
std::mt19937_64 makeDefaultValueRng(const WitgenOptions &options)
Seed an RNG for random/default witness value materialization.
FailureOr< llvm::SmallVector< OutputBinding > > collectOutputBindings(component::StructDefOp mainDef, SymbolTableCollection &tables, Operation *origin, OutputScope scope)
Collect the selected output bindings for the requested scope.
std::unique_ptr< Pass > createCreateWitgenEntryPass(bool emitFullWitness)
Create the pass that synthesizes the stable llzk-witgen JIT entry wrapper.
void addWitgenPreparePipeline(OpPassManager &pm, const WitgenOptions &)
UninitializedBehavior
Control how witgen materializes uninitialized/default values.
std::unique_ptr< Pass > createLowerComputeToCorePass(const WitgenOptions &options)
Create the pass that lowers supported LLZK compute IR into core MLIR dialects suitable for LLVM lower...
llvm::DynamicAPInt randomFieldElement(std::mt19937_64 &rng, const Field &field)
Draw a uniformly distributed field element in [0, prime).
bool randomBoolValue(std::mt19937_64 &rng)
Draw a uniformly distributed boolean value.
llvm::Expected< size_t > getStaticElementCount(ShapedType type, llvm::StringRef context)
int64_t randomIndexValue(std::mt19937_64 &rng)
Draw a uniformly distributed signed index value.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
DynamicAPInt toDynamicAPInt(StringRef str)
constexpr T checkedCast(U u) noexcept
APInt toExactWidthAPInt(const DynamicAPInt &val, unsigned bitWidth)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
llvm::SmallSet< FieldRef, 2 > FieldSet
Type alias for a set of Fields.
ExpressionValue notOp(const llvm::SMTSolverRef &solver, const ExpressionValue &val)
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Configure one llzk-witgen execution.