23#include <mlir/IR/Attributes.h>
24#include <mlir/IR/BuiltinAttributes.h>
25#include <mlir/IR/Operation.h>
27#include <llvm/ADT/APSInt.h>
28#include <llvm/ADT/SlowDynamicAPInt.h>
29#include <llvm/ADT/Twine.h>
30#include <llvm/ADT/TypeSwitch.h>
31#include <llvm/Support/LogicalResult.h>
43static DenseMap<StringRef, Field> &getKnownFields() {
44 static DenseMap<StringRef, Field> knownFields;
48Field::Field(std::string_view primeStr, StringRef name) :
Field(APSInt(primeStr), name) {}
50Field::Field(
const APInt &prime, StringRef name) : primeName(name) {
52 halfPrime = (primeMod + felt(1)) / felt(2);
53 bitwidth = prime.getBitWidth();
57 static std::once_flag fieldsInit;
58 std::call_once(fieldsInit, initKnownFields);
60 auto &knownFields = getKnownFields();
61 if (
auto it = knownFields.find(fieldName); it != knownFields.end()) {
69 return errFn().append(
"field '", fieldName,
"' is not defined");
75 auto res = tryGetField(fieldName);
77 return res.value().get();
79 std::string msg =
"field \"" + fieldName.str() +
"\" is unsupported";
81 errFn().append(msg).report();
83 llvm::report_fatal_error(msg.c_str());
86void Field::addField(Field &&f, EmitErrorFn errFn) {
88 auto existing = Field::tryGetField(f.name());
89 if (succeeded(existing)) {
92 debug::Appender(msg) <<
"Definition of \"" << f.name()
93 <<
"\" conflicts with prior definition: prior="
94 << existing.value().get().prime() <<
", new=" << f.prime();
96 errFn().append(msg).report();
98 llvm::report_fatal_error(msg.c_str());
103 getKnownFields().try_emplace(f.name(), f);
106void Field::initKnownFields() {
107 static constexpr const char BN128[] =
"bn128", BN254[] =
"bn254", GRUMPKIN[] =
"grumpkin",
108 BABYBEAR[] =
"babybear", GOLDILOCKS[] =
"goldilocks",
109 MERSENNE31[] =
"mersenne31", KOALABEAR[] =
"koalabear";
111 auto insert = [](
const char *name,
const char *primeStr) {
112 getKnownFields().try_emplace(name, Field(primeStr, name));
117 insert(BN128,
"21888242871839275222246405745257275088548364400416034343698204186575808495617");
118 insert(BN254,
"21888242871839275222246405745257275088548364400416034343698204186575808495617");
120 insert(GRUMPKIN,
"21888242871839275222246405745257275088696311157297823662689037894645226208583");
122 insert(BABYBEAR,
"2013265921");
124 insert(GOLDILOCKS,
"18446744069414584321");
126 insert(MERSENNE31,
"2147483647");
128 insert(KOALABEAR,
"2130706433");
131DynamicAPInt Field::reduce(
const DynamicAPInt &i)
const {
132 DynamicAPInt m = i % prime();
139DynamicAPInt Field::reduce(
const APInt &i)
const {
return reduce(
toDynamicAPInt(i)); }
145DynamicAPInt Field::inv(
const APInt &i)
const {
146 return modInversePrime(toDynamicAPInt(i), prime());
150 return IntegerAttr::get(
156static LogicalResult parseFields(Attribute a) {
159 Attribute, FailureOr<SmallVector<std::reference_wrapper<const Field>>>>(a)
165 [](
auto s) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
166 auto fieldRes = Field::tryGetField(s);
167 if (failed(fieldRes)) {
170 return SmallVector<std::reference_wrapper<const Field>> {fieldRes.value()};
173 [](
auto arr) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
175 SmallVector<std::reference_wrapper<const Field>> res;
176 for (Attribute elem : arr) {
177 if (
auto s = llvm::dyn_cast<StringAttr>(elem)) {
178 auto fieldRes = Field::tryGetField(s);
179 if (failed(fieldRes)) {
182 res.push_back(fieldRes.value());
189 .Default([](
auto) {
return failure(); });
195 return parseFields(a);
198 if (ModuleOp parentMod = modOp->getParentOfType<ModuleOp>()) {
210 LogicalResult &status;
211 mlir::Operation *scope;
216static void handleAttribute(mlir::Attribute, FieldsCtx &);
218static void handleType(mlir::Type type, FieldsCtx &ctx) {
219 TypeSwitch<mlir::Type> ts(type);
224 ctx.status = failure();
226 ctx.scope->emitWarning() <<
"felt type is unspecified, which may cause some passes to fail";
233 handleAttribute(record, ctx);
235 }).Case([&ctx](mlir::FunctionType funcType) {
236 for (
auto i : funcType.getInputs()) {
239 for (
auto o : funcType.getResults()) {
244 ts.Default([](
auto) {});
247static void handleAttribute(mlir::Attribute attr, FieldsCtx &ctx) {
248 TypeSwitch<mlir::Attribute> ts(attr);
249 ts.Case([&ctx](mlir::TypeAttr typeAttr) { handleType(typeAttr.getValue(), ctx); })
250 .Case([&ctx](mlir::ArrayAttr arrayAttr) {
251 for (
auto a : arrayAttr) {
252 handleAttribute(a, ctx);
255 .Case([&ctx](mlir::DictionaryAttr dictAttr) {
256 for (
auto a : dictAttr.getValue()) {
257 handleAttribute(a.getValue(), ctx);
259 }).Case([&ctx](llzk::pod::RecordAttr recordAttr) {
260 handleType(recordAttr.getType(), ctx);
261 }).Default([](
auto) {});
268 LogicalResult status = success();
269 root->walk([&fields, &status, silent](mlir::Operation *op) {
270 FieldsCtx ctx = {.fields = fields, .status = status, .scope = silent ? nullptr : op};
272 for (
auto result : op->getOpResults()) {
273 handleType(result.getType(), ctx);
276 for (
auto attr : op->getAttrs()) {
277 handleAttribute(attr.getValue(), ctx);
280 for (
auto ®ion : op->getRegions()) {
281 for (
auto &block : region) {
282 for (
auto &arg : block.getArguments()) {
283 handleType(arg.getType(), ctx);
292std::optional<std::reference_wrapper<const llzk::Field>>
298 ModuleOp modOp = dyn_cast<ModuleOp>(root);
300 modOp = root->getParentOfType<ModuleOp>();
308 if (failed(
collectFields(modOp, fields)) || fields.size() != 1) {
311 return *fields.begin();
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
llvm::DynamicAPInt half() const
Returns p / 2.
mlir::IntegerAttr getPrimeAttr(mlir::MLIRContext *context, unsigned bitWidth) const
Return the field prime modulus materialized as an integer attribute at bitWidth.
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < field...
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
static llvm::LogicalResult verifyFieldDefined(llvm::StringRef fieldName, EmitErrorFn errFn)
Search for a field with the given name, reporting an error if the field is not found.
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
::mlir::Type getElementType() const
const ::llzk::Field & getField() const
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
DynamicAPInt toDynamicAPInt(StringRef str)
APInt toExactWidthAPInt(const DynamicAPInt &val, unsigned bitWidth)
LogicalResult addSpecifiedFields(ModuleOp modOp)
DynamicAPInt modInversePrime(const DynamicAPInt &f, const DynamicAPInt &p)
llvm::SmallSet< FieldRef, 2 > FieldSet
Typealias for a set of Fields.
std::optional< std::reference_wrapper< const Field > > tryDetectSpecifiedField(mlir::Operation *root)
Try to detect a uniquely used field from the enclosing LLZK module.
constexpr char FIELD_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that defines prime fields used in the circuit.
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.