LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Felt operation implementations ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
14#include "llzk/Util/Field.h"
16
17#include <mlir/IR/Builders.h>
18
19#include <llvm/ADT/DynamicAPInt.h>
20#include <llvm/ADT/SmallString.h>
21
22// TableGen'd implementation files
24
25// TableGen'd implementation files
26#define GET_OP_CLASSES
28
29using namespace mlir;
30using namespace llzk;
31
32namespace llzk::felt {
33
34//===------------------------------------------------------------------===//
35// Constant folding helpers
36//===------------------------------------------------------------------===//
37
38namespace {
39
40struct BinaryFoldData {
41 DynamicAPInt lhsVal, rhsVal;
42 StringRef fieldName;
43 const Field *field;
44};
45
46struct UnaryFoldData {
47 DynamicAPInt val;
48 StringRef fieldName;
49 const Field *field;
50};
51
57static std::optional<BinaryFoldData> tryGetBinaryFoldData(Attribute lhsAttr, Attribute rhsAttr) {
58 auto lhs = llvm::dyn_cast_or_null<FeltConstAttr>(lhsAttr);
59 auto rhs = llvm::dyn_cast_or_null<FeltConstAttr>(rhsAttr);
60 if (!lhs || !rhs) {
61 return std::nullopt;
62 }
63
64 StringAttr lhsFieldName = lhs.getFieldName();
65 StringAttr rhsFieldName = rhs.getFieldName();
66 if (!lhsFieldName || !rhsFieldName || lhsFieldName != rhsFieldName) {
67 return std::nullopt;
68 }
69
70 auto fieldRes = Field::tryGetField(lhsFieldName.getValue());
71 if (failed(fieldRes)) {
72 return std::nullopt;
73 }
74
75 return BinaryFoldData {
76 toDynamicAPInt(lhs.getValue()), toDynamicAPInt(rhs.getValue()), lhsFieldName.getValue(),
77 &fieldRes.value().get()
78 };
79}
80
82static std::optional<UnaryFoldData> tryGetUnaryFoldData(Attribute operandAttr) {
83 auto operand = llvm::dyn_cast_or_null<FeltConstAttr>(operandAttr);
84 if (!operand) {
85 return std::nullopt;
86 }
87
88 StringAttr fieldNameAttr = operand.getFieldName();
89 if (!fieldNameAttr) {
90 return std::nullopt;
91 }
92
93 auto fieldRes = Field::tryGetField(fieldNameAttr.getValue());
94 if (failed(fieldRes)) {
95 return std::nullopt;
96 }
97
98 return UnaryFoldData {
99 toDynamicAPInt(operand.getValue()), fieldNameAttr.getValue(), &fieldRes.value().get()
100 };
101}
102
104static FeltConstAttr buildFoldResult(
105 MLIRContext *ctx, const DynamicAPInt &val, const Field &field, StringRef fieldName
106) {
107 return FeltConstAttr::get(ctx, toAPInt(val, field.bitWidth()), fieldName);
108}
109
110} // namespace
111
112//===------------------------------------------------------------------===//
113// FeltConstantOp
114//===------------------------------------------------------------------===//
115
116void FeltConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
117 SmallString<32> buf;
118 llvm::raw_svector_ostream(buf) << "felt_const_";
119 getValue().getValue().toStringUnsigned(buf);
120 setNameFn(getResult(), buf);
121}
122
124
126 MLIRContext *context, std::optional<Location> /*loc*/, Adaptor adaptor,
127 SmallVectorImpl<Type> &inferred
128) {
129 inferred.resize(1);
130 auto value = adaptor.getValue(); // FeltConstAttr
131 inferred[0] = value ? value.getType() : FeltType::get(context, StringAttr());
132 return success();
133}
134
135bool FeltConstantOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { return l == r; }
136
137//===------------------------------------------------------------------===//
138// Binary op folds
139//===------------------------------------------------------------------===//
140
141OpFoldResult AddFeltOp::fold(FoldAdaptor adaptor) {
142 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
143 if (!data) {
144 return {};
145 }
146 return buildFoldResult(
147 getContext(), data->field->reduce(data->lhsVal + data->rhsVal), *data->field, data->fieldName
148 );
149}
150
151OpFoldResult SubFeltOp::fold(FoldAdaptor adaptor) {
152 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
153 if (!data) {
154 return {};
155 }
156 return buildFoldResult(
157 getContext(), data->field->reduce(data->lhsVal - data->rhsVal), *data->field, data->fieldName
158 );
159}
160
161OpFoldResult MulFeltOp::fold(FoldAdaptor adaptor) {
162 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
163 if (!data) {
164 return {};
165 }
166 return buildFoldResult(
167 getContext(), data->field->reduce(data->lhsVal * data->rhsVal), *data->field, data->fieldName
168 );
169}
170
171OpFoldResult PowFeltOp::fold(FoldAdaptor adaptor) {
172 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
173 if (!data) {
174 return {};
175 }
176 return buildFoldResult(
177 getContext(), modExp(data->lhsVal, data->rhsVal, data->field->prime()), *data->field,
178 data->fieldName
179 );
180}
181
182OpFoldResult DivFeltOp::fold(FoldAdaptor adaptor) {
183 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
184 if (!data || data->rhsVal == 0) {
185 return {};
186 }
187 return buildFoldResult(
188 getContext(), data->field->reduce(data->lhsVal * data->field->inv(data->rhsVal)),
189 *data->field, data->fieldName
190 );
191}
192
194 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
195 if (!data || data->rhsVal == 0) {
196 return {};
197 }
198 // Both values are non-negative field elements; standard integer division
199 // gives the correct unsigned quotient, already in [0, lhs] < prime.
200 return buildFoldResult(getContext(), data->lhsVal / data->rhsVal, *data->field, data->fieldName);
201}
202
204 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
205 if (!data) {
206 return {};
207 }
208 const Field *field = data->field;
209 DynamicAPInt rhs = data->rhsVal;
210 if (rhs == 0 || rhs == field->prime()) {
211 return {};
212 }
213 DynamicAPInt sRhs = field->toSigned(rhs);
214 DynamicAPInt sLhs = field->toSigned(data->lhsVal);
215 // DynamicAPInt / truncates toward zero (same as C++ signed int division).
216 return buildFoldResult(getContext(), field->reduce(sLhs / sRhs), *field, data->fieldName);
217}
218
220 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
221 if (!data || data->rhsVal == 0) {
222 return {};
223 }
224 // Both non-negative, so % gives the correct unsigned remainder in [0, rhs) < prime.
225 return buildFoldResult(getContext(), data->lhsVal % data->rhsVal, *data->field, data->fieldName);
226}
227
228OpFoldResult SignedModFeltOp::fold(FoldAdaptor adaptor) {
229 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
230 if (!data) {
231 return {};
232 }
233 const Field *field = data->field;
234 DynamicAPInt rhs = data->rhsVal;
235 if (rhs == 0 || rhs == field->prime()) {
236 return {};
237 }
238 DynamicAPInt sRhs = field->toSigned(rhs);
239 DynamicAPInt sLhs = field->toSigned(data->lhsVal);
240 return buildFoldResult(getContext(), field->reduce(sLhs % sRhs), *field, data->fieldName);
241}
242
243OpFoldResult AndFeltOp::fold(FoldAdaptor adaptor) {
244 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
245 if (!data) {
246 return {};
247 }
248 return buildFoldResult(
249 getContext(), data->field->reduce(data->lhsVal & data->rhsVal), *data->field, data->fieldName
250 );
251}
252
253OpFoldResult OrFeltOp::fold(FoldAdaptor adaptor) {
254 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
255 if (!data) {
256 return {};
257 }
258 return buildFoldResult(
259 getContext(), data->field->reduce(data->lhsVal | data->rhsVal), *data->field, data->fieldName
260 );
261}
262
263OpFoldResult XorFeltOp::fold(FoldAdaptor adaptor) {
264 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
265 if (!data) {
266 return {};
267 }
268 return buildFoldResult(
269 getContext(), data->field->reduce(data->lhsVal ^ data->rhsVal), *data->field, data->fieldName
270 );
271}
272
273OpFoldResult ShlFeltOp::fold(FoldAdaptor adaptor) {
274 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
275 if (!data) {
276 return {};
277 }
278 return buildFoldResult(
279 getContext(), data->field->reduce(data->lhsVal << data->rhsVal), *data->field, data->fieldName
280 );
281}
282
283OpFoldResult ShrFeltOp::fold(FoldAdaptor adaptor) {
284 auto data = tryGetBinaryFoldData(adaptor.getLhs(), adaptor.getRhs());
285 if (!data) {
286 return {};
287 }
288 // Any shift `amount >= bitwidth` will yield zero.
289 const Field *field = data->field;
290 if (data->rhsVal >= DynamicAPInt(field->bitWidth())) {
291 return buildFoldResult(getContext(), DynamicAPInt(0), *field, data->fieldName);
292 }
293 // Right-shifting a non-negative value always yields a value in [0, lhs] < prime;
294 // no modular reduction required.
295 return buildFoldResult(getContext(), data->lhsVal >> data->rhsVal, *field, data->fieldName);
296}
297
298//===------------------------------------------------------------------===//
299// Unary op folds
300//===------------------------------------------------------------------===//
301
302OpFoldResult NegFeltOp::fold(FoldAdaptor adaptor) {
303 auto data = tryGetUnaryFoldData(adaptor.getOperand());
304 if (!data) {
305 return {};
306 }
307 return buildFoldResult(
308 getContext(), data->field->reduce(-data->val), *data->field, data->fieldName
309 );
310}
311
312OpFoldResult InvFeltOp::fold(FoldAdaptor adaptor) {
313 auto data = tryGetUnaryFoldData(adaptor.getOperand());
314 if (!data || data->val == 0) {
315 return {};
316 }
317 return buildFoldResult(getContext(), data->field->inv(data->val), *data->field, data->fieldName);
318}
319
320OpFoldResult NotFeltOp::fold(FoldAdaptor adaptor) {
321 auto data = tryGetUnaryFoldData(adaptor.getOperand());
322 if (!data) {
323 return {};
324 }
325 // One's complement at field.bitWidth() bits: maxMask = 2^bitWidth - 1,
326 // result = reduce(maxMask ^ val). The operator<< here is llzk::operator<<
327 // on DynamicAPInt (defined in DynamicAPIntHelper.h).
328 DynamicAPInt maxMask =
329 (DynamicAPInt(1) << DynamicAPInt(data->field->bitWidth())) - DynamicAPInt(1);
330 return buildFoldResult(
331 getContext(), data->field->reduce(maxMask ^ data->val), *data->field, data->fieldName
332 );
333}
334
335} // namespace llzk::felt
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
Definition Field.h:35
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
Definition Field.cpp:50
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < field...
Definition Field.cpp:131
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
Definition Field.h:71
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
unsigned bitWidth() const
Definition Field.h:106
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:141
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:187
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:243
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:374
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:182
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:561
::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location > location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type > &inferredReturnTypes)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:773
::llzk::felt::FeltConstAttr getValueAttr()
Definition Ops.h.inc:825
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:812
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:116
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:123
static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r)
Definition Ops.cpp:135
::llzk::felt::FeltConstAttr getValue()
Definition Ops.cpp.inc:794
static FeltType get(::mlir::MLIRContext *context, ::mlir::StringAttr fieldName)
Definition Types.cpp.inc:67
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:312
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:954
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1132
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:161
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:302
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1315
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1489
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:320
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1667
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:253
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:171
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1854
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2041
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:273
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:283
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2228
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:203
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2415
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:228
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2602
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:151
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2789
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:2976
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:193
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:219
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:3163
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:3350
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:263
DynamicAPInt toDynamicAPInt(StringRef str)
APInt toAPInt(const DynamicAPInt &val, unsigned bitWidth)
DynamicAPInt modExp(const DynamicAPInt &base, const DynamicAPInt &exp, const DynamicAPInt &mod)