LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Field.cpp
Go to the documentation of this file.
1//===-- Field.cpp -----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#include "llzk/Util/Field.h"
11
19#include "llzk/Util/Constants.h"
20#include "llzk/Util/Debug.h"
22
#include <mlir/IR/Attributes.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/Operation.h>

#include <llvm/ADT/APSInt.h>
#include <llvm/ADT/SlowDynamicAPInt.h>
#include <llvm/ADT/Twine.h>
#include <llvm/ADT/TypeSwitch.h>
#include <llvm/Support/Allocator.h>
#include <llvm/Support/LogicalResult.h>
#include <llvm/Support/StringSaver.h>

#include <algorithm>
#include <mutex>
35
36using namespace mlir;
37
38namespace llzk {
39
40static DenseMap<StringRef, Field> knownFields;
41
42Field::Field(std::string_view primeStr, StringRef name) : Field(APSInt(primeStr), name) {}
43
44Field::Field(const APInt &prime, StringRef name) : primeName(name) {
45 primeMod = toDynamicAPInt(prime);
46 halfPrime = (primeMod + felt(1)) / felt(2);
47 bitwidth = prime.getBitWidth();
48}
49
50FailureOr<std::reference_wrapper<const Field>> Field::tryGetField(StringRef fieldName) {
51 static std::once_flag fieldsInit;
52 std::call_once(fieldsInit, initKnownFields);
53
54 if (auto it = knownFields.find(fieldName); it != knownFields.end()) {
55 return {it->second};
56 }
57 return failure();
58}
59
60LogicalResult Field::verifyFieldDefined(StringRef fieldName, EmitErrorFn errFn) {
61 if (failed(Field::tryGetField(fieldName))) {
62 return errFn().append("field '", fieldName, "' is not defined");
63 }
64 return success();
65}
66
67const Field &Field::getField(StringRef fieldName, EmitErrorFn errFn) {
68 auto res = tryGetField(fieldName);
69 if (succeeded(res)) {
70 return res.value().get();
71 }
72 std::string msg = "field \"" + fieldName.str() + "\" is unsupported";
73 if (errFn) {
74 errFn().append(msg).report();
75 }
76 llvm::report_fatal_error(msg.c_str());
77}
78
79void Field::addField(Field &&f, EmitErrorFn errFn) {
80 // Use `tryGetField()` to ensure knownFields is initialized before checking for conflicts.
81 auto existing = Field::tryGetField(f.name());
82 if (succeeded(existing)) {
83 // Field exists and conflicts with existing definition.
84 std::string msg;
85 debug::Appender(msg) << "Definition of \"" << f.name()
86 << "\" conflicts with prior definition: prior="
87 << existing.value().get().prime() << ", new=" << f.prime();
88 if (errFn) {
89 errFn().append(msg).report();
90 } else {
91 llvm::report_fatal_error(msg.c_str());
92 }
93 return;
94 }
95 // Field does not exist, add it.
96 knownFields.try_emplace(f.name(), f);
97}
98
99void Field::initKnownFields() {
100 static constexpr const char BN128[] = "bn128", BN254[] = "bn254", BABYBEAR[] = "babybear",
101 GOLDILOCKS[] = "goldilocks", MERSENNE31[] = "mersenne31",
102 KOALABEAR[] = "koalabear";
103
104 auto insert = [](const char *name, const char *primeStr) {
105 knownFields.try_emplace(name, Field(primeStr, name));
106 };
107
108 // bn128/254, default for circom
109 insert(BN128, "21888242871839275222246405745257275088696311157297823662689037894645226208583");
110 insert(BN254, "21888242871839275222246405745257275088696311157297823662689037894645226208583");
111 // 15 * 2^27 + 1, default for zirgen
112 insert(BABYBEAR, "2013265921");
113 // 2^64 - 2^32 + 1, used for plonky2
114 insert(GOLDILOCKS, "18446744069414584321");
115 // 2^31 - 1, used for Plonky3
116 insert(MERSENNE31, "2147483647");
117 // 2^31 - 2^24 + 1, also for Plonky3
118 insert(KOALABEAR, "2130706433");
119}
120
121DynamicAPInt Field::reduce(const DynamicAPInt &i) const {
122 DynamicAPInt m = i % prime();
123 if (m < 0) {
124 return prime() + m;
125 }
126 return m;
127}
128
129DynamicAPInt Field::reduce(const APInt &i) const { return reduce(toDynamicAPInt(i)); }
130
131DynamicAPInt Field::toSigned(const DynamicAPInt &i) const { return i < half() ? i : i - prime(); }
132
133DynamicAPInt Field::inv(const DynamicAPInt &i) const { return modInversePrime(i, prime()); }
134
135DynamicAPInt Field::inv(const APInt &i) const {
136 return modInversePrime(toDynamicAPInt(i), prime());
137}
138
139// Parses Fields from the given attribute, if able.
140static LogicalResult parseFields(Attribute a) {
141 // clang-format off
142 return TypeSwitch<
143 Attribute, FailureOr<SmallVector<std::reference_wrapper<const Field>>>>(a)
144 .Case<UnitAttr>(
145 [](auto) {
146 return success();
147 })
148 .Case<StringAttr>(
149 [](auto s) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
150 auto fieldRes = Field::tryGetField(s);
151 if (failed(fieldRes)) {
152 return failure();
153 }
154 return SmallVector<std::reference_wrapper<const Field>> {fieldRes.value()};
155 })
156 .Case<ArrayAttr>(
157 [](auto arr) -> FailureOr<SmallVector<std::reference_wrapper<const Field>>> {
158 // An ArrayAttr may only contain inner StringAttr
159 SmallVector<std::reference_wrapper<const Field>> res;
160 for (Attribute elem : arr) {
161 if (auto s = llvm::dyn_cast<StringAttr>(elem)) {
162 auto fieldRes = Field::tryGetField(s);
163 if (failed(fieldRes)) {
164 return failure();
165 }
166 res.push_back(fieldRes.value());
167 } else {
168 return failure();
169 }
170 }
171 return res;
172 })
173 .Default([](auto) { return failure(); });
174 // clang-format on
175}
176
177LogicalResult addSpecifiedFields(ModuleOp modOp) {
178 if (Attribute a = modOp->getAttr(FIELD_ATTR_NAME)) {
179 return parseFields(a);
180 }
181 // Always recurse.
182 if (ModuleOp parentMod = modOp->getParentOfType<ModuleOp>()) {
183 return addSpecifiedFields(parentMod);
184 }
185 return success();
186}
187
188} // namespace llzk
189
190namespace {
191
192struct FieldsCtx {
193 llzk::FieldSet &fields;
194 LogicalResult &status;
195 mlir::Operation *scope;
196};
197
198} // namespace
199
200static void handleAttribute(mlir::Attribute, FieldsCtx &);
201
202static void handleType(mlir::Type type, FieldsCtx &ctx) {
203 TypeSwitch<mlir::Type> ts(type);
204 ts.Case([&ctx](llzk::felt::FeltType felt) {
205 if (felt.hasField()) {
206 ctx.fields.insert(felt.getField());
207 } else {
208 ctx.status = failure();
209 if (ctx.scope) {
210 ctx.scope->emitWarning() << "felt type is unspecified, which may cause some passes to fail";
211 }
212 }
213 })
214 .Case([&ctx](llzk::array::ArrayType array) { handleType(array.getElementType(), ctx); })
215 .Case([&ctx](llzk::pod::PodType pod) {
216 for (auto record : pod.getRecords()) {
217 handleAttribute(record, ctx);
218 }
219 }).Case([&ctx](mlir::FunctionType funcType) {
220 for (auto i : funcType.getInputs()) {
221 handleType(i, ctx);
222 }
223 for (auto o : funcType.getResults()) {
224 handleType(o, ctx);
225 }
226 });
227 // Do nothing by default for any other type
228 ts.Default([](auto) {});
229}
230
231static void handleAttribute(mlir::Attribute attr, FieldsCtx &ctx) {
232 TypeSwitch<mlir::Attribute> ts(attr);
233 ts.Case([&ctx](mlir::TypeAttr typeAttr) { handleType(typeAttr.getValue(), ctx); })
234 .Case([&ctx](mlir::ArrayAttr arrayAttr) {
235 for (auto a : arrayAttr) {
236 handleAttribute(a, ctx);
237 }
238 })
239 .Case([&ctx](mlir::DictionaryAttr dictAttr) {
240 for (auto a : dictAttr.getValue()) {
241 handleAttribute(a.getValue(), ctx);
242 }
243 }).Case([&ctx](llzk::pod::RecordAttr recordAttr) {
244 handleType(recordAttr.getType(), ctx);
245 }).Default([](auto) {});
246}
247
248LogicalResult llzk::collectFields(mlir::Operation *root, llzk::FieldSet &fields, bool silent) {
249 if (!root) {
250 return success(); // Nothing to do
251 }
252 LogicalResult status = success();
253 root->walk([&fields, &status, silent](mlir::Operation *op) {
254 FieldsCtx ctx = {.fields = fields, .status = status, .scope = silent ? nullptr : op};
255 // Crawl for types in the results,
256 for (auto result : op->getOpResults()) {
257 handleType(result.getType(), ctx);
258 }
259 // the attributes,
260 for (auto attr : op->getAttrs()) {
261 handleAttribute(attr.getValue(), ctx);
262 }
263 // block arguments (if any)
264 for (auto &region : op->getRegions()) {
265 for (auto &block : region) {
266 for (auto &arg : block.getArguments()) {
267 handleType(arg.getType(), ctx);
268 }
269 }
270 }
271 });
272
273 return status;
274}
275
276std::optional<std::reference_wrapper<const llzk::Field>>
277llzk::tryDetectSpecifiedField(Operation *root) {
278 if (!root) {
279 return std::nullopt;
280 }
281
282 ModuleOp modOp = dyn_cast<ModuleOp>(root);
283 if (!modOp) {
284 modOp = root->getParentOfType<ModuleOp>();
285 }
286
287 if (!modOp) {
288 return std::nullopt;
289 }
290
291 FieldSet fields;
292 if (failed(collectFields(modOp, fields)) || fields.size() != 1) {
293 return std::nullopt;
294 }
295 return *fields.begin();
296}
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:35
llvm::DynamicAPInt half() const
Returns p / 2.
Definition Field.h:74
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
Definition Field.cpp:50
Field()=delete
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < field...
Definition Field.cpp:131
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:71
llvm::DynamicAPInt inv(const llvm::DynamicAPInt &i) const
Returns the multiplicative inverse of i in prime field p.
static llvm::LogicalResult verifyFieldDefined(llvm::StringRef fieldName, EmitErrorFn errFn)
Search for a field with the given name, reporting an error if the field is not found.
Definition Field.cpp:60
static const Field & getField(llvm::StringRef fieldName, EmitErrorFn errFn)
Get a Field from a given field name string.
::mlir::Type getElementType() const
const ::llzk::Field & getField() const
Definition Dialect.cpp:168
bool hasField() const
Definition Types.h.inc:26
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
DynamicAPInt toDynamicAPInt(StringRef str)
LogicalResult addSpecifiedFields(ModuleOp modOp)
Definition Field.cpp:177
DynamicAPInt modInversePrime(const DynamicAPInt &f, const DynamicAPInt &p)
llvm::SmallSet< FieldRef, 2 > FieldSet
Type alias for a set of Fields.
Definition Field.h:155
std::optional< std::reference_wrapper< const Field > > tryDetectSpecifiedField(mlir::Operation *root)
Try to detect a uniquely used field from the enclosing LLZK module.
constexpr char FIELD_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that defines prime fields used in the circuit.
Definition Constants.h:32
mlir::LogicalResult collectFields(mlir::Operation *root, FieldSet &fields, bool silent=true)
Collects all the fields used in a circuit.
Definition Field.cpp:248