17#include <mlir/IR/Builders.h>
19#include <llvm/ADT/DynamicAPInt.h>
20#include <llvm/ADT/SmallString.h>
40struct BinaryFoldData {
41 DynamicAPInt lhsVal, rhsVal;
57static std::optional<BinaryFoldData> tryGetBinaryFoldData(Attribute lhsAttr, Attribute rhsAttr) {
58 auto lhs = llvm::dyn_cast_or_null<FeltConstAttr>(lhsAttr);
59 auto rhs = llvm::dyn_cast_or_null<FeltConstAttr>(rhsAttr);
64 StringAttr lhsFieldName = lhs.getFieldName();
65 StringAttr rhsFieldName = rhs.getFieldName();
66 if (!lhsFieldName || !rhsFieldName || lhsFieldName != rhsFieldName) {
71 if (failed(fieldRes)) {
75 return BinaryFoldData {
77 &fieldRes.value().get()
82static std::optional<UnaryFoldData> tryGetUnaryFoldData(Attribute operandAttr) {
83 auto operand = llvm::dyn_cast_or_null<FeltConstAttr>(operandAttr);
88 StringAttr fieldNameAttr = operand.getFieldName();
94 if (failed(fieldRes)) {
98 return UnaryFoldData {
99 toDynamicAPInt(operand.getValue()), fieldNameAttr.getValue(), &fieldRes.value().get()
104static FeltConstAttr buildFoldResult(
105 MLIRContext *ctx,
const DynamicAPInt &val,
const Field &field, StringRef fieldName
107 return FeltConstAttr::get(ctx,
toAPInt(val, field.
bitWidth()), fieldName);
118 llvm::raw_svector_ostream(buf) <<
"felt_const_";
119 getValue().getValue().toStringUnsigned(buf);
126 MLIRContext *context, std::optional<Location> , Adaptor adaptor,
127 SmallVectorImpl<Type> &inferred
130 auto value = adaptor.getValue();
131 inferred[0] = value ? value.getType() :
FeltType::get(context, StringAttr());
142 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
146 return buildFoldResult(
147 getContext(), data->field->reduce(data->lhsVal + data->rhsVal), *data->field, data->fieldName
152 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
156 return buildFoldResult(
157 getContext(), data->field->reduce(data->lhsVal - data->rhsVal), *data->field, data->fieldName
162 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
166 return buildFoldResult(
167 getContext(), data->field->reduce(data->lhsVal * data->rhsVal), *data->field, data->fieldName
172 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
176 return buildFoldResult(
177 getContext(),
modExp(data->lhsVal, data->rhsVal, data->field->prime()), *data->field,
183 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
184 if (!data || data->rhsVal == 0) {
187 return buildFoldResult(
188 getContext(), data->field->reduce(data->lhsVal * data->field->inv(data->rhsVal)),
189 *data->field, data->fieldName
194 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
195 if (!data || data->rhsVal == 0) {
200 return buildFoldResult(getContext(), data->lhsVal / data->rhsVal, *data->field, data->fieldName);
204 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
208 const Field *field = data->field;
209 DynamicAPInt rhs = data->rhsVal;
210 if (rhs == 0 || rhs == field->
prime()) {
213 DynamicAPInt sRhs = field->
toSigned(rhs);
214 DynamicAPInt sLhs = field->
toSigned(data->lhsVal);
216 return buildFoldResult(getContext(), field->
reduce(sLhs / sRhs), *field, data->fieldName);
220 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
221 if (!data || data->rhsVal == 0) {
225 return buildFoldResult(getContext(), data->lhsVal % data->rhsVal, *data->field, data->fieldName);
229 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
233 const Field *field = data->field;
234 DynamicAPInt rhs = data->rhsVal;
235 if (rhs == 0 || rhs == field->
prime()) {
238 DynamicAPInt sRhs = field->
toSigned(rhs);
239 DynamicAPInt sLhs = field->
toSigned(data->lhsVal);
240 return buildFoldResult(getContext(), field->
reduce(sLhs % sRhs), *field, data->fieldName);
244 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
248 return buildFoldResult(
249 getContext(), data->field->reduce(data->lhsVal & data->rhsVal), *data->field, data->fieldName
254 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
258 return buildFoldResult(
259 getContext(), data->field->reduce(data->lhsVal | data->rhsVal), *data->field, data->fieldName
264 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
268 return buildFoldResult(
269 getContext(), data->field->reduce(data->lhsVal ^ data->rhsVal), *data->field, data->fieldName
274 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
278 return buildFoldResult(
279 getContext(), data->field->reduce(data->lhsVal << data->rhsVal), *data->field, data->fieldName
284 auto data = tryGetBinaryFoldData(adaptor.
getLhs(), adaptor.
getRhs());
289 const Field *field = data->field;
290 if (data->rhsVal >= DynamicAPInt(field->
bitWidth())) {
291 return buildFoldResult(getContext(), DynamicAPInt(0), *field, data->fieldName);
295 return buildFoldResult(getContext(), data->lhsVal >> data->rhsVal, *field, data->fieldName);
303 auto data = tryGetUnaryFoldData(adaptor.
getOperand());
307 return buildFoldResult(
308 getContext(), data->field->reduce(-data->val), *data->field, data->fieldName
313 auto data = tryGetUnaryFoldData(adaptor.
getOperand());
314 if (!data || data->val == 0) {
317 return buildFoldResult(getContext(), data->field->inv(data->val), *data->field, data->fieldName);
321 auto data = tryGetUnaryFoldData(adaptor.
getOperand());
328 DynamicAPInt maxMask =
329 (DynamicAPInt(1) << DynamicAPInt(data->field->bitWidth())) - DynamicAPInt(1);
330 return buildFoldResult(
331 getContext(), data->field->reduce(maxMask ^ data->val), *data->field, data->fieldName
This file implements helper methods for constructing DynamicAPInts.
Information about the prime finite field used for the interval analysis.
static llvm::FailureOr< std::reference_wrapper< const Field > > tryGetField(llvm::StringRef fieldName)
Get a Field from a given field name string, or failure if the field is not defined.
llvm::DynamicAPInt toSigned(const llvm::DynamicAPInt &i) const
Converts a canonical field element to its signed integer representation: toSigned(f) = f if f < field...
llvm::DynamicAPInt prime() const
For the prime field p, returns p.
llvm::DynamicAPInt reduce(const llvm::DynamicAPInt &i) const
Returns i mod p and reduces the result into the appropriate bitwidth.
unsigned bitWidth() const
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location > location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type > &inferredReturnTypes)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::llzk::felt::FeltConstAttr getValueAttr()
::mlir::TypedValue<::mlir::Type > getResult()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
static bool isCompatibleReturnTypes(::mlir::TypeRange l, ::mlir::TypeRange r)
::llzk::felt::FeltConstAttr getValue()
static FeltType get(::mlir::MLIRContext *context, ::mlir::StringAttr fieldName)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
DynamicAPInt toDynamicAPInt(StringRef str)
APInt toAPInt(const DynamicAPInt &val, unsigned bitWidth)
DynamicAPInt modExp(const DynamicAPInt &base, const DynamicAPInt &exp, const DynamicAPInt &mod)