23#include <mlir/IR/Attributes.h>
24#include <mlir/IR/BuiltinAttributes.h>
25#include <mlir/IR/Operation.h>
27#include <llvm/ADT/APSInt.h>
28#include <llvm/ADT/SlowDynamicAPInt.h>
29#include <llvm/ADT/Twine.h>
30#include <llvm/ADT/TypeSwitch.h>
31#include <llvm/Support/LogicalResult.h>
// Global registry of known prime fields, keyed by field name.
// Populated with the built-in fields by initKnownFields (run once through
// std::call_once on the lookup path) and extended at runtime by addField.
// NOTE(review): keys are non-owning StringRefs, so the character data behind
// every registered name must outlive this map — confirm for addField callers.
// NOTE(review): only the one-time initialization is synchronized; no locking
// around addField's insertion is visible here.
40static DenseMap<StringRef, Field> knownFields;
42Field::Field(std::string_view primeStr, StringRef name) :
Field(APSInt(primeStr), name) {}
// Primary constructor: stores the field's name and caches values derived from
// the prime.
//   - halfPrime = (primeMod + 1) / 2, i.e. ceil(p / 2) for an odd prime p.
//   - bitwidth  = bit width of the APInt the prime was supplied in.
// NOTE(review): primeMod (presumably p converted to a DynamicAPInt) is set on a
// line not visible in this view, and `felt` is assumed to be a DynamicAPInt
// alias — confirm both.
44Field::Field(
const APInt &prime, StringRef name) : primeName(name) {
46 halfPrime = (primeMod + felt(1)) / felt(2);
47 bitwidth = prime.getBitWidth();
// Ensure the built-in fields are registered exactly once before the first
// lookup; std::call_once makes this safe even with concurrent callers.
51 static std::once_flag fieldsInit;
52 std::call_once(fieldsInit, initKnownFields);
// Search the registry for the requested field name.
54 if (
auto it = knownFields.find(fieldName); it != knownFields.end()) {
// Unknown field name: build the diagnostic through the caller-supplied error
// callback and return it (presumably converting to a failure result).
62 return errFn().append(
"field '", fieldName,
"' is not defined");
// Lookup variant in which an unknown field name is fatal.
68 auto res = tryGetField(fieldName);
// Presumably reached only when the lookup succeeded (the guard is on a line
// not visible here): return a reference to the registry entry.
70 return res.value().get();
// Field not found: report a diagnostic through errFn, then abort the process
// with the same message via llvm::report_fatal_error.
72 std::string msg =
"field \"" + fieldName.str() +
"\" is unsupported";
74 errFn().append(msg).report();
76 llvm::report_fatal_error(msg.c_str());
// Registers `f` in the global field registry.
// If a field with the same name is already registered, a diagnostic naming both
// primes is reported through errFn and the process is aborted with
// llvm::report_fatal_error.
// NOTE(review): lines between the conflict check and the diagnostic are not
// visible here; an identical redefinition may be accepted — confirm.
79void Field::addField(Field &&f, EmitErrorFn errFn) {
// Look up any existing definition registered under the same name.
81 auto existing = Field::tryGetField(f.name());
82 if (succeeded(existing)) {
85 debug::Appender(msg) <<
"Definition of \"" << f.name()
86 <<
"\" conflicts with prior definition: prior="
87 << existing.value().get().prime() <<
", new=" << f.prime();
89 errFn().append(msg).report();
91 llvm::report_fatal_error(msg.c_str());
// NOTE(review): `f` is a named rvalue reference, so it binds as an lvalue here
// and the Field is copied; std::move(f) would avoid the copy. The map key is a
// StringRef (see knownFields' type), so the storage behind f.name() must
// outlive the registry.
96 knownFields.try_emplace(f.name(), f);
// Registers the built-in prime fields. Invoked once, lazily, via std::call_once
// on the field lookup path.
99void Field::initKnownFields() {
// Names are static constexpr arrays so the StringRef keys stored in
// knownFields point at storage that lives for the whole program.
100 static constexpr const char BN128[] =
"bn128", BN254[] =
"bn254", BABYBEAR[] =
"babybear",
101 GOLDILOCKS[] =
"goldilocks", MERSENNE31[] =
"mersenne31",
102 KOALABEAR[] =
"koalabear";
// Parses the decimal prime string and registers the field under `name`;
// try_emplace leaves an existing entry with the same name untouched.
104 auto insert = [](
const char *name,
const char *primeStr) {
105 knownFields.try_emplace(name, Field(primeStr, name));
// bn128 and bn254 are registered as aliases with the identical prime.
// NOTE(review): this literal is the BN254 *base-field* modulus q. The BN254
// scalar field r (the default prime used by circom) is
// 21888242871839275222246405745257275088548364400416034343698204186575808495617.
// Confirm which of the two the circuits are expected to use.
109 insert(BN128,
"21888242871839275222246405745257275088696311157297823662689037894645226208583");
110 insert(BN254,
"21888242871839275222246405745257275088696311157297823662689037894645226208583");
// BabyBear: 15 * 2^27 + 1.
112 insert(BABYBEAR,
"2013265921");
// Goldilocks: 2^64 - 2^32 + 1.
114 insert(GOLDILOCKS,
"18446744069414584321");
// Mersenne-31: 2^31 - 1.
116 insert(MERSENNE31,
"2147483647");
// KoalaBear: 2^31 - 2^24 + 1.
118 insert(KOALABEAR,
"2130706433");
// Reduces `i` into the field. The visible part starts from i % prime().
// NOTE(review): the rest of the body is not visible here; presumably it maps a
// negative remainder into [0, p) — confirm.
121DynamicAPInt Field::reduce(
const DynamicAPInt &i)
const {
122 DynamicAPInt m = i % prime();
129DynamicAPInt Field::reduce(
const APInt &i)
const {
return reduce(
toDynamicAPInt(i)); }
// Returns the multiplicative inverse of the APInt `i` in the field, computed by
// modInversePrime after widening `i` to a DynamicAPInt.
// NOTE(review): i == 0 (mod p) has no inverse; behaviour in that case depends
// on modInversePrime — confirm callers never pass it.
135DynamicAPInt Field::inv(
const APInt &i)
const {
136 return modInversePrime(toDynamicAPInt(i), prime());
// Validates a field-specification attribute by resolving every field name it
// contains against the registry (dispatched with llvm::TypeSwitch):
//   - a single field name     -> a one-element list;
//   - an array of field names -> one entry per element, in order;
//   - any other attribute     -> failure.
// NOTE(review): the Case<> types, the handling of unknown names, and what is
// done with the resolved list sit on lines not visible here; the shapes above
// are inferred from the visible lambda bodies (the array case dyn_casts each
// element to StringAttr).
140static LogicalResult parseFields(Attribute a) {
143 Attribute, FailureOr<SmallVector<std::reference_wrapper<const Field>>>>(a)
// Single name: look it up and wrap the result in a one-element vector.
149 [](
auto s) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
150 auto fieldRes = Field::tryGetField(s);
151 if (failed(fieldRes)) {
154 return SmallVector<std::reference_wrapper<const Field>> {fieldRes.value()};
// Array: resolve every StringAttr element and collect the results in order.
157 [](
auto arr) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
159 SmallVector<std::reference_wrapper<const Field>> res;
160 for (Attribute elem : arr) {
161 if (
auto s = llvm::dyn_cast<StringAttr>(elem)) {
162 auto fieldRes = Field::tryGetField(s);
163 if (failed(fieldRes)) {
166 res.push_back(fieldRes.value());
// Any other attribute kind is not a valid field specification.
173 .Default([](
auto) {
return failure(); });
// Fragment of a caller whose header is not visible here (presumably
// addSpecifiedFields). Visible statements: it returns the result of
// parseFields(a), and it checks whether modOp has an enclosing parent ModuleOp.
179 return parseFields(a);
182 if (ModuleOp parentMod = modOp->getParentOfType<ModuleOp>()) {
// Members of FieldsCtx, the state threaded through handleType/handleAttribute
// (its `fields` member is on a line not visible here).
// Reference to the walk's shared result: a handler sets it to failure() when it
// detects a problem.
194 LogicalResult &status;
// Operation used to anchor diagnostics; nullptr when collectFields runs with
// silent == true.
195 mlir::Operation *scope;
// Forward declaration: handleType and handleAttribute are mutually recursive.
200static void handleAttribute(mlir::Attribute, FieldsCtx &);
// Inspects one type for field usage, recursing into attributes carried by the
// type and into function signatures. Types not matched are ignored.
202static void handleType(mlir::Type type, FieldsCtx &ctx) {
203 TypeSwitch<mlir::Type> ts(type);
// Mark the whole collection as failed (the triggering case is on lines not
// visible here).
208 ctx.status = failure();
// NOTE(review): ctx.scope is nullptr in silent mode (see collectFields); the
// guard that presumably protects this dereference is not visible — confirm.
210 ctx.scope->emitWarning() <<
"felt type is unspecified, which may cause some passes to fail";
// Recurse into the record attribute carried by the type.
217 handleAttribute(record, ctx);
// Function types: visit every input and result type (loop bodies are not
// visible here; presumably each calls handleType).
219 }).Case([&ctx](mlir::FunctionType funcType) {
220 for (
auto i : funcType.getInputs()) {
223 for (
auto o : funcType.getResults()) {
// Every other type carries no field information.
228 ts.Default([](
auto) {});
// Walks an attribute and forwards every type it contains to handleType:
//   - TypeAttr:        the wrapped type;
//   - ArrayAttr:       each element, recursively;
//   - DictionaryAttr:  each entry's value, recursively;
//   - pod::RecordAttr: the record's type;
//   - anything else is ignored.
231static void handleAttribute(mlir::Attribute attr, FieldsCtx &ctx) {
232 TypeSwitch<mlir::Attribute> ts(attr);
233 ts.Case([&ctx](mlir::TypeAttr typeAttr) { handleType(typeAttr.getValue(), ctx); })
234 .Case([&ctx](mlir::ArrayAttr arrayAttr) {
235 for (
auto a : arrayAttr) {
236 handleAttribute(a, ctx);
239 .Case([&ctx](mlir::DictionaryAttr dictAttr) {
240 for (
auto a : dictAttr.getValue()) {
241 handleAttribute(a.getValue(), ctx);
243 }).Case([&ctx](llzk::pod::RecordAttr recordAttr) {
244 handleType(recordAttr.getType(), ctx);
245 }).Default([](
auto) {});
// Body of collectFields (header not visible here): walks every operation under
// `root` and passes each operation's result types, attribute values, and
// region block-argument types to handleType/handleAttribute.
// `status` is shared by reference through FieldsCtx so that any handler can
// flag a failure. When `silent` is true the diagnostic scope is nullptr.
252 LogicalResult status = success();
253 root->walk([&fields, &status, silent](mlir::Operation *op) {
254 FieldsCtx ctx = {.fields = fields, .status = status, .scope = silent ? nullptr : op};
// Types of the values produced by the operation.
256 for (
auto result : op->getOpResults()) {
257 handleType(result.getType(), ctx);
// Values of every attribute attached to the operation.
260 for (
auto attr : op->getAttrs()) {
261 handleAttribute(attr.getValue(), ctx);
// Argument types of every block in every region of the operation.
264 for (
auto &region : op->getRegions()) {
265 for (
auto &block : region) {
266 for (
auto &arg : block.getArguments()) {
267 handleType(arg.getType(), ctx);
// Detects the single field used under the ModuleOp that is `root` or encloses
// it. Returns that field when collection succeeds and finds exactly one field;
// otherwise presumably returns std::nullopt (that return is not visible here).
276std::optional<std::reference_wrapper<const llzk::Field>>
282 ModuleOp modOp = dyn_cast<ModuleOp>(root);
// Not itself a module: fall back to the nearest enclosing ModuleOp (the
// guarding condition is on a line not visible here).
284 modOp = root->getParentOfType<ModuleOp>();
// Collection failed, or the usage is ambiguous / absent (not exactly one field).
292 if (failed(
collectFields(modOp, fields)) || fields.size() != 1) {
// Exactly one field in use: return it.
295 return *fields.begin();
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
llvm::DynamicAPInt half() const
Returns half of p, computed as (p + 1) / 2 (i.e. ⌈p / 2⌉ for an odd prime p).
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < field...
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
static llvm::LogicalResult verifyFieldDefined(llvm::StringRef fieldName, EmitErrorFn errFn)
Search for a field with the given name, reporting an error if the field is not found.
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
::mlir::Type getElementType() const
const ::llzk::Field & getField() const
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
DynamicAPInt toDynamicAPInt(StringRef str)
LogicalResult addSpecifiedFields(ModuleOp modOp)
DynamicAPInt modInversePrime(const DynamicAPInt &f, const DynamicAPInt &p)
llvm::SmallSet< FieldRef, 2 > FieldSet
Type alias for a set of Fields.
std::optional< std::reference_wrapper< const Field > > tryDetectSpecifiedField(mlir::Operation *root)
Try to detect a uniquely used field from the enclosing LLZK module.
constexpr char FIELD_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that defines prime fields used in the circuit.
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.