LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SMTAttributes.cpp
Go to the documentation of this file.
1//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "llvm/ADT/TypeSwitch.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/DialectImplementation.h"
14
17
18using namespace mlir;
19using namespace llzk::smt;
20
21//===----------------------------------------------------------------------===//
22// BitVectorAttr
23//===----------------------------------------------------------------------===//
24
25LogicalResult BitVectorAttr::verify(
26 function_ref<InFlightDiagnostic()> emitError,
27 APInt value
28) { // NOLINT(performance-unnecessary-value-param)
29 if (value.getBitWidth() < 1) {
30 return emitError() << "bit-width must be at least 1, but got " << value.getBitWidth();
31 }
32 return success();
33}
34
35std::string BitVectorAttr::getValueAsString(bool prefix) const {
36 unsigned width = getValue().getBitWidth();
37 SmallVector<char> toPrint;
38 StringRef pref = prefix ? "#" : "";
39 if (width % 4 == 0) {
40 getValue().toString(toPrint, 16, false, false, false);
41 // APInt's 'toString' omits leading zeros. However, those are critical here
42 // because they determine the bit-width of the bit-vector.
43 SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
44 return (pref + "x" + Twine(leadingZeros) + toPrint).str();
45 }
46
47 getValue().toString(toPrint, 2, false, false, false);
48 // APInt's 'toString' omits leading zeros
49 SmallVector<char> leadingZeros(width - toPrint.size(), '0');
50 return (pref + "b" + Twine(leadingZeros) + toPrint).str();
51}
52
54static FailureOr<APInt>
55parseBitVectorString(function_ref<InFlightDiagnostic()> emitError, StringRef value) {
56 if (value[0] != '#') {
57 return emitError() << "expected '#'";
58 }
59
60 if (value.size() < 3) {
61 return emitError() << "expected at least one digit";
62 }
63
64 if (value[1] == 'b') {
65 return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), 2);
66 }
67
68 if (value[1] == 'x') {
69 return APInt((value.size() - 2) * 4, std::string(value.begin() + 2, value.end()), 16);
70 }
71
72 return emitError() << "expected either 'b' or 'x'";
73}
74
75BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
76 auto maybeValue = parseBitVectorString(nullptr, value);
77
78 assert(succeeded(maybeValue) && "string must have SMT-LIB format");
79 return Base::get(context, *maybeValue);
80}
81
82BitVectorAttr BitVectorAttr::getChecked(
83 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, StringRef value
84) {
85 auto maybeValue = parseBitVectorString(emitError, value);
86 if (failed(maybeValue)) {
87 return {};
88 }
89
90 return Base::getChecked(emitError, context, *maybeValue);
91}
92
93BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, unsigned width) {
94 return Base::get(context, APInt(width, value));
95}
96
97BitVectorAttr BitVectorAttr::getChecked(
98 function_ref<InFlightDiagnostic()> emitError, MLIRContext *context, uint64_t value,
99 unsigned width
100) {
101 if (width < 64 && value >= (UINT64_C(1) << width)) {
102 emitError() << "value does not fit in a bit-vector of desired width";
103 return {};
104 }
105 return Base::getChecked(emitError, context, APInt(width, value));
106}
107
108Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
109 llvm::SMLoc loc = odsParser.getCurrentLocation();
110
111 APInt val;
112 if (odsParser.parseLess() || odsParser.parseInteger(val) || odsParser.parseGreater()) {
113 return {};
114 }
115
116 // Requires the use of `quantified(<attr>)` in operation assembly formats.
117 if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
118 odsParser.emitError(loc) << "explicit bit-vector type required";
119 return {};
120 }
121
122 unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
123
124 if (width > val.getBitWidth()) {
125 // sext is always safe here, even for unsigned values, because the
126 // parseOptionalInteger method will return something with a zero in the
127 // top bits if it is a positive number.
128 val = val.sext(width);
129 } else if (width < val.getBitWidth()) {
130 // The parser can return an unnecessarily wide result.
131 // This isn't a problem, but truncating off bits is bad.
132 unsigned neededBits = val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
133 if (width < neededBits) {
134 odsParser.emitError(loc) << "integer value out of range for given bit-vector type "
135 << odsType;
136 return {};
137 }
138 val = val.trunc(width);
139 }
140
141 return BitVectorAttr::get(odsParser.getContext(), val);
142}
143
144void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
145 // This printer only works for the extended format where the MLIR
146 // infrastructure prints the type for us. This means, the attribute should
147 // never be used without `quantified` in an assembly format.
148 odsPrinter << "<" << getValue() << ">";
149}
150
151Type BitVectorAttr::getType() const {
152 return BitVectorType::get(getContext(), getValue().getBitWidth());
153}
154
155//===----------------------------------------------------------------------===//
156// ODS Boilerplate
157//===----------------------------------------------------------------------===//
158
159#define GET_ATTRDEF_CLASSES
161
163 addAttributes<
164#define GET_ATTRDEF_LIST
166 >();
167}
static BitVectorType get(::mlir::MLIRContext *context, int64_t width)