LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Field.cpp
Go to the documentation of this file.
1//===-- Field.cpp -----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "llzk/Util/Field.h"
11
19#include "llzk/Util/Constants.h"
20#include "llzk/Util/Debug.h"
22
23#include <mlir/IR/Attributes.h>
24#include <mlir/IR/BuiltinAttributes.h>
25#include <mlir/IR/Operation.h>
26
27#include <llvm/ADT/APSInt.h>
28#include <llvm/ADT/SlowDynamicAPInt.h>
29#include <llvm/ADT/Twine.h>
30#include <llvm/ADT/TypeSwitch.h>
31#include <llvm/Support/LogicalResult.h>
32
33#include <algorithm>
34#include <mutex>
35
36using namespace mlir;
37
38namespace llzk {
39
40// Having `knownFields` as a static local object ensures it is initialized when
41// `getKnownFields` is called, rather than relying on non-local static initialization
42// order (https://en.cppreference.com/cpp/language/initialization).
43static DenseMap<StringRef, Field> &getKnownFields() {
44 static DenseMap<StringRef, Field> knownFields;
45 return knownFields;
46}
47
48Field::Field(std::string_view primeStr, StringRef name) : Field(APSInt(primeStr), name) {}
49
50Field::Field(const APInt &prime, StringRef name) : primeName(name) {
51 primeMod = toDynamicAPInt(prime);
52 halfPrime = (primeMod + felt(1)) / felt(2);
53 bitwidth = prime.getBitWidth();
54}
55
56FailureOr<std::reference_wrapper<const Field>> Field::tryGetField(StringRef fieldName) {
57 static std::once_flag fieldsInit;
58 std::call_once(fieldsInit, initKnownFields);
59
60 auto &knownFields = getKnownFields();
61 if (auto it = knownFields.find(fieldName); it != knownFields.end()) {
62 return {it->second};
63 }
64 return failure();
65}
66
67LogicalResult Field::verifyFieldDefined(StringRef fieldName, EmitErrorFn errFn) {
68 if (failed(Field::tryGetField(fieldName))) {
69 return errFn().append("field '", fieldName, "' is not defined");
70 }
71 return success();
72}
73
74const Field &Field::getField(StringRef fieldName, EmitErrorFn errFn) {
75 auto res = tryGetField(fieldName);
76 if (succeeded(res)) {
77 return res.value().get();
78 }
79 std::string msg = "field \"" + fieldName.str() + "\" is unsupported";
80 if (errFn) {
81 errFn().append(msg).report();
82 }
83 llvm::report_fatal_error(msg.c_str());
84}
85
86void Field::addField(Field &&f, EmitErrorFn errFn) {
87 // Use `tryGetField()` to ensure knownFields is initialized before checking for conflicts.
88 auto existing = Field::tryGetField(f.name());
89 if (succeeded(existing)) {
90 // Field exists and conflicts with existing definition.
91 std::string msg;
92 debug::Appender(msg) << "Definition of \"" << f.name()
93 << "\" conflicts with prior definition: prior="
94 << existing.value().get().prime() << ", new=" << f.prime();
95 if (errFn) {
96 errFn().append(msg).report();
97 } else {
98 llvm::report_fatal_error(msg.c_str());
99 }
100 return;
101 }
102 // Field does not exist, add it.
103 getKnownFields().try_emplace(f.name(), f);
104}
105
106void Field::initKnownFields() {
107 static constexpr const char BN128[] = "bn128", BN254[] = "bn254", GRUMPKIN[] = "grumpkin",
108 BABYBEAR[] = "babybear", GOLDILOCKS[] = "goldilocks",
109 MERSENNE31[] = "mersenne31", KOALABEAR[] = "koalabear";
110
111 auto insert = [](const char *name, const char *primeStr) {
112 getKnownFields().try_emplace(name, Field(primeStr, name));
113 };
114
115 // Reference: https://github.com/iden3/circom/blob/master/program_structure/src/utils/constants.rs
116 // bn128/254, default for circom
117 insert(BN128, "21888242871839275222246405745257275088548364400416034343698204186575808495617");
118 insert(BN254, "21888242871839275222246405745257275088548364400416034343698204186575808495617");
119 // Grumpkin scalar field
120 insert(GRUMPKIN, "21888242871839275222246405745257275088696311157297823662689037894645226208583");
121 // 15 * 2^27 + 1, default for zirgen
122 insert(BABYBEAR, "2013265921");
123 // 2^64 - 2^32 + 1, used for plonky2
124 insert(GOLDILOCKS, "18446744069414584321");
125 // 2^31 - 1, used for Plonky3
126 insert(MERSENNE31, "2147483647");
127 // 2^31 - 2^24 + 1, also for Plonky3
128 insert(KOALABEAR, "2130706433");
129}
130
131DynamicAPInt Field::reduce(const DynamicAPInt &i) const {
132 DynamicAPInt m = i % prime();
133 if (m < 0) {
134 return prime() + m;
135 }
136 return m;
137}
138
139DynamicAPInt Field::reduce(const APInt &i) const { return reduce(toDynamicAPInt(i)); }
140
141DynamicAPInt Field::toSigned(const DynamicAPInt &i) const { return i < half() ? i : i - prime(); }
142
143DynamicAPInt Field::inv(const DynamicAPInt &i) const { return modInversePrime(i, prime()); }
144
145DynamicAPInt Field::inv(const APInt &i) const {
146 return modInversePrime(toDynamicAPInt(i), prime());
147}
148
149IntegerAttr Field::getPrimeAttr(MLIRContext *context, unsigned bitWidth) const {
150 return IntegerAttr::get(
151 IntegerType::get(context, bitWidth), toExactWidthAPInt(prime(), bitWidth)
152 );
153}
154
155// Parses Fields from the given attribute, if able.
156static LogicalResult parseFields(Attribute a) {
157 // clang-format off
158 return TypeSwitch<
159 Attribute, FailureOr<SmallVector<std::reference_wrapper<const Field>>>>(a)
160 .Case<UnitAttr>(
161 [](auto) {
162 return success();
163 })
164 .Case<StringAttr>(
165 [](auto s) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
166 auto fieldRes = Field::tryGetField(s);
167 if (failed(fieldRes)) {
168 return failure();
169 }
170 return SmallVector<std::reference_wrapper<const Field>> {fieldRes.value()};
171 })
172 .Case<ArrayAttr>(
173 [](auto arr) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
174 // An ArrayAttr may only contain inner StringAttr
175 SmallVector<std::reference_wrapper<const Field>> res;
176 for (Attribute elem : arr) {
177 if (auto s = llvm::dyn_cast<StringAttr>(elem)) {
178 auto fieldRes = Field::tryGetField(s);
179 if (failed(fieldRes)) {
180 return failure();
181 }
182 res.push_back(fieldRes.value());
183 } else {
184 return failure();
185 }
186 }
187 return res;
188 })
189 .Default([](auto) { return failure(); });
190 // clang-format on
191}
192
193LogicalResult addSpecifiedFields(ModuleOp modOp) {
194 if (Attribute a = modOp->getAttr(FIELD_ATTR_NAME)) {
195 return parseFields(a);
196 }
197 // Always recurse.
198 if (ModuleOp parentMod = modOp->getParentOfType<ModuleOp>()) {
199 return addSpecifiedFields(parentMod);
200 }
201 return success();
202}
203
204} // namespace llzk
205
namespace {

/// State shared by the recursive type/attribute crawl in `collectFields`.
struct FieldsCtx {
  /// Receives every field found on a felt type.
  llzk::FieldSet &fields;
  /// Set to failure() when a felt type without a field is encountered.
  LogicalResult &status;
  /// Operation that warnings are emitted on; null suppresses warnings.
  mlir::Operation *scope;
};

} // namespace
215
216static void handleAttribute(mlir::Attribute, FieldsCtx &);
217
218static void handleType(mlir::Type type, FieldsCtx &ctx) {
219 TypeSwitch<mlir::Type> ts(type);
220 ts.Case([&ctx](llzk::felt::FeltType felt) {
221 if (felt.hasField()) {
222 ctx.fields.insert(felt.getField());
223 } else {
224 ctx.status = failure();
225 if (ctx.scope) {
226 ctx.scope->emitWarning() << "felt type is unspecified, which may cause some passes to fail";
227 }
228 }
229 })
230 .Case([&ctx](llzk::array::ArrayType array) { handleType(array.getElementType(), ctx); })
231 .Case([&ctx](llzk::pod::PodType pod) {
232 for (auto record : pod.getRecords()) {
233 handleAttribute(record, ctx);
234 }
235 }).Case([&ctx](mlir::FunctionType funcType) {
236 for (auto i : funcType.getInputs()) {
237 handleType(i, ctx);
238 }
239 for (auto o : funcType.getResults()) {
240 handleType(o, ctx);
241 }
242 });
243 // Do nothing by default for any other type
244 ts.Default([](auto) {});
245}
246
247static void handleAttribute(mlir::Attribute attr, FieldsCtx &ctx) {
248 TypeSwitch<mlir::Attribute> ts(attr);
249 ts.Case([&ctx](mlir::TypeAttr typeAttr) { handleType(typeAttr.getValue(), ctx); })
250 .Case([&ctx](mlir::ArrayAttr arrayAttr) {
251 for (auto a : arrayAttr) {
252 handleAttribute(a, ctx);
253 }
254 })
255 .Case([&ctx](mlir::DictionaryAttr dictAttr) {
256 for (auto a : dictAttr.getValue()) {
257 handleAttribute(a.getValue(), ctx);
258 }
259 }).Case([&ctx](llzk::pod::RecordAttr recordAttr) {
260 handleType(recordAttr.getType(), ctx);
261 }).Default([](auto) {});
262}
263
264LogicalResult llzk::collectFields(mlir::Operation *root, llzk::FieldSet &fields, bool silent) {
265 if (!root) {
266 return success(); // Nothing to do
267 }
268 LogicalResult status = success();
269 root->walk([&fields, &status, silent](mlir::Operation *op) {
270 FieldsCtx ctx = {.fields = fields, .status = status, .scope = silent ? nullptr : op};
271 // Crawl for types in the results,
272 for (auto result : op->getOpResults()) {
273 handleType(result.getType(), ctx);
274 }
275 // the attributes,
276 for (auto attr : op->getAttrs()) {
277 handleAttribute(attr.getValue(), ctx);
278 }
279 // block arguments (if any)
280 for (auto &region : op->getRegions()) {
281 for (auto &block : region) {
282 for (auto &arg : block.getArguments()) {
283 handleType(arg.getType(), ctx);
284 }
285 }
286 }
287 });
288
289 return status;
290}
291
292std::optional<std::reference_wrapper<const llzk::Field>>
293llzk::tryDetectSpecifiedField(Operation *root) {
294 if (!root) {
295 return std::nullopt;
296 }
297
298 ModuleOp modOp = dyn_cast<ModuleOp>(root);
299 if (!modOp) {
300 modOp = root->getParentOfType<ModuleOp>();
301 }
302
303 if (!modOp) {
304 return std::nullopt;
305 }
306
307 FieldSet fields;
308 if (failed(collectFields(modOp, fields)) || fields.size() != 1) {
309 return std::nullopt;
310 }
311 return *fields.begin();
312}
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:36
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:75
mlir::IntegerAttr getPrimeAttr(mlir::MLIRContext *context, unsigned bitWidth) const
Return the field prime modulus materialized as an integer attribute at bitWidth.
Definition Field.cpp:149
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
Definition Field.cpp:56
Field()=delete
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < half(), otherwise f - p.
Definition Field.cpp:141
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:72
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
unsigned bitWidth() const
Definition Field.h:107
static llvm::LogicalResult verifyFieldDefined(llvm::StringRef fieldName, EmitErrorFn errFn)
Search for a field with the given name, reporting an error if the field is not found.
Definition Field.cpp:67
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
::mlir::Type getElementType() const
const ::llzk::Field & getField() const
Definition Dialect.cpp:168
bool hasField() const
Definition Types.h.inc:26
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
DynamicAPInt toDynamicAPInt(StringRef str)
APInt toExactWidthAPInt(const DynamicAPInt &val, unsigned bitWidth)
LogicalResult addSpecifiedFields(ModuleOp modOp)
Definition Field.cpp:193
DynamicAPInt modInversePrime(const DynamicAPInt &f, const DynamicAPInt &p)
llvm::SmallSet< FieldRef, 2 > FieldSet
Typealias for a set of Fields.
Definition Field.h:159
std::optional< std::reference_wrapper< const Field > > tryDetectSpecifiedField(mlir::Operation *root)
Try to detect a uniquely used field from the enclosing LLZK module.
constexpr char FIELD_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that defines prime fields used in the circuit.
Definition Constants.h:32
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Definition Field.cpp:264