LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SMTAttributes.cpp
Go to the documentation of this file.
1//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
13
14#include <mlir/IR/Builders.h>
15#include <mlir/IR/DialectImplementation.h>
16
17#include <llvm/ADT/TypeSwitch.h>
18
19using namespace mlir;
20using namespace llzk::smt;
21
22//===----------------------------------------------------------------------===//
23// BitVectorAttr
24//===----------------------------------------------------------------------===//
25
26LogicalResult BitVectorAttr::verify(
27 function_ref<InFlightDiagnostic()> emitError,
28 APInt value // NOLINT(performance-unnecessary-value-param)
29) {
30 if (value.getBitWidth() < 1) {
31 return emitError() << "bit-width must be at least 1, but got " << value.getBitWidth();
32 }
33 return success();
34}
35
36std::string BitVectorAttr::getValueAsString(bool prefix) const {
37 unsigned width = getValue().getBitWidth();
38 SmallVector<char> toPrint;
39 StringRef pref = prefix ? "#" : "";
40 if (width % 4 == 0) {
41 getValue().toString(toPrint, 16, false, false, false);
42 // APInt's 'toString' omits leading zeros. However, those are critical here
43 // because they determine the bit-width of the bit-vector.
44 SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
45 return (pref + "x" + Twine(leadingZeros) + toPrint).str();
46 }
47
48 getValue().toString(toPrint, 2, false, false, false);
49 // APInt's 'toString' omits leading zeros
50 SmallVector<char> leadingZeros(width - toPrint.size(), '0');
51 return (pref + "b" + Twine(leadingZeros) + toPrint).str();
52}
53
55static FailureOr<APInt>
56parseBitVectorString(function_ref<InFlightDiagnostic()> emitError, StringRef value) {
57 auto reportError = [&](StringRef msg) -> FailureOr<APInt> {
58 if (emitError) {
59 return emitError() << msg;
60 }
61 return failure();
62 };
63
64 if (value[0] != '#') {
65 return reportError("expected '#'");
66 }
67
68 if (value.size() < 3) {
69 return reportError("expected at least one digit");
70 }
71
72 if (value[1] == 'b') {
73 return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), 2);
74 }
75
76 if (value[1] == 'x') {
77 return APInt((value.size() - 2) * 4, std::string(value.begin() + 2, value.end()), 16);
78 }
79
80 return reportError("expected either 'b' or 'x'");
81}
82
83BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
84 auto maybeValue = parseBitVectorString(nullptr, value);
85
86 assert(succeeded(maybeValue) && "string must have SMT-LIB format");
87 return Base::get(context, *maybeValue);
88}
89
90BitVectorAttr BitVectorAttr::getChecked(
91 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, StringRef value
92) {
93 auto maybeValue = parseBitVectorString(emitError, value);
94 if (failed(maybeValue)) {
95 return {};
96 }
97
98 return Base::getChecked(emitError, context, *maybeValue);
99}
100
101BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, unsigned width) {
102 return Base::get(context, APInt(width, value));
103}
104
105BitVectorAttr BitVectorAttr::getChecked(
106 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, uint64_t value,
107 unsigned width
108) {
109 if (width < 64 && value >= (UINT64_C(1) << width)) {
110 emitError() << "value does not fit in a bit-vector of desired width";
111 return {};
112 }
113 return Base::getChecked(emitError, context, APInt(width, value));
114}
115
116Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
117 llvm::SMLoc loc = odsParser.getCurrentLocation();
118
119 APInt val;
120 if (odsParser.parseLess() || odsParser.parseInteger(val) || odsParser.parseGreater()) {
121 return {};
122 }
123
124 // Requires the use of `quantified(<attr>)` in operation assembly formats.
125 if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
126 odsParser.emitError(loc) << "explicit bit-vector type required";
127 return {};
128 }
129
130 unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
131
132 if (width > val.getBitWidth()) {
133 // sext is always safe here, even for unsigned values, because the
134 // parseOptionalInteger method will return something with a zero in the
135 // top bits if it is a positive number.
136 val = val.sext(width);
137 } else if (width < val.getBitWidth()) {
138 // The parser can return an unnecessarily wide result.
139 // This isn't a problem, but truncating off bits is bad.
140 unsigned neededBits = val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
141 if (width < neededBits) {
142 odsParser.emitError(loc) << "integer value out of range for given bit-vector type "
143 << odsType;
144 return {};
145 }
146 val = val.trunc(width);
147 }
148
149 return BitVectorAttr::get(odsParser.getContext(), val);
150}
151
152void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
153 // This printer only works for the extended format where the MLIR
154 // infrastructure prints the type for us. This means, the attribute should
155 // never be used without `quantified` in an assembly format.
156 odsPrinter << "<" << getValue() << ">";
157}
158
159Type BitVectorAttr::getType() const {
160 return BitVectorType::get(getContext(), getValue().getBitWidth());
161}
162
163//===----------------------------------------------------------------------===//
164// ODS Boilerplate
165//===----------------------------------------------------------------------===//
166
167#define GET_ATTRDEF_CLASSES
169
171 // clang-format off
172 // Suppress false positive from `clang-tidy`
173 // NOLINTNEXTLINE(clang-analyzer-core.StackAddressEscape)
174 addAttributes<
175 #define GET_ATTRDEF_LIST
177 >();
178 // clang-format on
179}
static BitVectorType get(::mlir::MLIRContext *context, int64_t width)