11#include "llvm/ADT/APSInt.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/OpImplementation.h"
24 mlir::MLIRContext *context, std::optional<mlir::Location> location, ::mlir::ValueRange operands,
25 ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties,
26 ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes
28 inferredReturnTypes.push_back(properties.as<
Properties *>()->getValue().getType());
33 SmallVector<char, 128> specialNameBuffer;
34 llvm::raw_svector_ostream specialName(specialNameBuffer);
35 specialName <<
"c" <<
getValue().getValue() <<
"_bv" <<
getValue().getValue().getBitWidth();
36 setNameFn(
getResult(), specialName.str());
40 assert(adaptor.
getOperands().empty() &&
"constant has no operands");
57 if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes()) {
58 return emitOpError() <<
"types of yielded values must match return values";
60 if (getBody()->getArgumentTypes() !=
getInputs().getTypes()) {
61 return emitOpError() <<
"block argument types must match the types of the 'inputs'";
72 if (
getSatRegion().front().getTerminator()->getOperands().getTypes() != getResultTypes()) {
73 return emitOpError() <<
"types of yielded values in 'sat' region must "
74 "match return values";
76 if (
getUnknownRegion().front().getTerminator()->getOperands().getTypes() != getResultTypes()) {
77 return emitOpError() <<
"types of yielded values in 'unknown' region must "
78 "match return values";
80 if (
getUnsatRegion().front().getTerminator()->getOperands().getTypes() != getResultTypes()) {
81 return emitOpError() <<
"types of yielded values in 'unsat' region must "
82 "match return values";
93parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, OperationState &result) {
94 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs;
95 SMLoc loc = parser.getCurrentLocation();
98 if (parser.parseOperandList(inputs) || parser.parseOptionalAttrDict(result.attributes) ||
99 parser.parseColon() || parser.parseType(type)) {
103 result.addTypes(BoolType::get(parser.getContext()));
104 if (parser.resolveOperands(
105 inputs, SmallVector<Type>(inputs.size(), type), loc, result.operands
113ParseResult
EqOp::parse(OpAsmParser &parser, OperationState &result) {
114 return parseSameOperandTypeVariadicToBoolOp(parser, result);
119 printer.printOptionalAttrDict(getOperation()->getAttrs());
120 printer <<
" : " <<
getInputs().front().getType();
125 return emitOpError() <<
"'inputs' must have at least size 2, but got " <<
getInputs().size();
136 return parseSameOperandTypeVariadicToBoolOp(parser, result);
141 printer.printOptionalAttrDict(getOperation()->getAttrs());
142 printer <<
" : " <<
getInputs().front().getType();
147 return emitOpError() <<
"'inputs' must have at least size 2, but got " <<
getInputs().size();
158 unsigned rangeWidth = getType().getWidth();
160 if (
getLowBit() + rangeWidth > inputWidth) {
162 "range to be extracted is too big, expected range "
165 <<
getLowBit() <<
" of length " << rangeWidth <<
" requires input width of at least "
166 << (
getLowBit() + rangeWidth) <<
", but the input width is only " << inputWidth;
176 MLIRContext *context, std::optional<Location> location, ValueRange operands,
177 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
178 SmallVectorImpl<Type> &inferredReturnTypes
180 inferredReturnTypes.push_back(
195 unsigned resultWidth = getType().getWidth();
196 if (resultWidth % inputWidth != 0) {
197 return emitOpError() <<
"result bit-vector width must be a multiple of the "
198 "input bit-vector width";
206 unsigned resultWidth = getType().getWidth();
207 return resultWidth / inputWidth;
210void RepeatOp::build(OpBuilder &builder, OperationState &state,
unsigned count, Value input) {
213 build(builder, state, resultTy, input);
217 OpAsmParser::UnresolvedOperand input;
219 llvm::SMLoc countLoc = parser.getCurrentLocation();
222 if (parser.parseInteger(count) || parser.parseKeyword(
"times")) {
226 if (count.isNonPositive()) {
227 return parser.emitError(countLoc) <<
"integer must be positive";
230 llvm::SMLoc inputLoc = parser.getCurrentLocation();
231 if (parser.parseOperand(input) || parser.parseOptionalAttrDict(result.attributes) ||
232 parser.parseColon() || parser.parseType(inputType)) {
236 if (parser.resolveOperand(input, inputType, result.operands)) {
240 auto bvInputTy = dyn_cast<BitVectorType>(inputType);
242 return parser.emitError(inputLoc) <<
"input must have bit-vector type";
247 const unsigned maxBw = 63;
248 if (count.getActiveBits() > maxBw) {
249 return parser.emitError(countLoc) <<
"integer must fit into " << maxBw <<
" bits";
255 APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw);
256 if (resultBw.getActiveBits() > maxBw) {
257 return parser.emitError(countLoc)
258 <<
"result bit-width (provided integer times bit-width of the input "
259 "type) must fit into "
264 result.addTypes(resultTy);
270 printer.printOptionalAttrDict((*this)->getAttrs());
271 printer <<
" : " <<
getInput().getType();
283 assert(adaptor.
getOperands().empty() &&
"constant has no operands");
292 SmallVector<char, 32> specialNameBuffer;
293 llvm::raw_svector_ostream specialName(specialNameBuffer);
295 setNameFn(
getResult(), specialName.str());
299 assert(adaptor.
getOperands().empty() &&
"constant has no operands");
305 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
310 if (parser.parseInteger(value)) {
314 result.getOrAddProperties<
Properties>().setValue(
315 IntegerAttr::get(parser.getContext(), APSInt(value))
318 if (parser.parseOptionalAttrDict(result.attributes)) {
322 result.addTypes(smt::IntType::get(parser.getContext()));
330template <
typename QuantifierOp>
static LogicalResult verifyQuantifierRegions(QuantifierOp op) {
331 if (op.getBoundVarNames() && op.getBody().getNumArguments() != op.getBoundVarNames()->size()) {
332 return op.emitOpError(
"number of bound variable names must match number of block arguments");
335 return op.emitOpError() <<
"bound variables must by any non-function SMT value";
338 if (op.getBody().front().getTerminator()->getNumOperands() != 1) {
339 return op.emitOpError(
"must have exactly one yielded value");
341 if (!isa<BoolType>(op.getBody().front().getTerminator()->getOperand(0).getType())) {
342 return op.emitOpError(
"yielded value must be of '!smt.bool' type");
345 for (
auto regionWithIndex : llvm::enumerate(op.getPatterns())) {
346 unsigned i = regionWithIndex.index();
347 Region ®ion = regionWithIndex.value();
349 if (op.getBody().getArgumentTypes() != region.getArgumentTypes()) {
350 return op.emitOpError() <<
"block argument number and types of the 'body' "
351 "and 'patterns' region #"
352 << i <<
" must match";
354 if (region.front().getTerminator()->getNumOperands() < 1) {
355 return op.emitOpError() <<
"'patterns' region #" << i
356 <<
" must have at least one yielded value";
360 auto result = region.walk([&](Operation *childOp) {
361 if (!isa<SMTDialect>(childOp->getDialect())) {
362 auto diag = op.emitOpError()
363 <<
"the 'patterns' region #" << i <<
" may only contain SMT dialect operations";
364 diag.attachNote(childOp->getLoc()) <<
"first non-SMT operation here";
365 return WalkResult::interrupt();
370 if (isa<ForallOp, ExistsOp>(childOp)) {
371 auto diag = op.emitOpError() <<
"the 'patterns' region #" << i
372 <<
" must not contain "
373 "any variable binding operations";
374 diag.attachNote(childOp->getLoc()) <<
"first violating operation here";
375 return WalkResult::interrupt();
378 return WalkResult::advance();
380 if (result.wasInterrupted()) {
388template <
typename Properties>
389static void buildQuantifier(
390 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
391 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
392 std::optional<ArrayRef<StringRef>> boundVarNames,
393 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight,
396 odsState.addTypes(BoolType::get(odsBuilder.getContext()));
398 odsState.getOrAddProperties<Properties>().weight =
399 odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight);
402 odsState.getOrAddProperties<Properties>().noPattern = odsBuilder.getUnitAttr();
404 if (boundVarNames.has_value()) {
405 SmallVector<Attribute> boundVarNamesList;
406 for (StringRef str : *boundVarNames) {
407 boundVarNamesList.emplace_back(odsBuilder.getStringAttr(str));
409 odsState.getOrAddProperties<Properties>().boundVarNames =
410 odsBuilder.getArrayAttr(boundVarNamesList);
413 OpBuilder::InsertionGuard guard(odsBuilder);
414 Region *region = odsState.addRegion();
415 Block *block = odsBuilder.createBlock(region);
417 boundVarTypes, SmallVector<Location>(boundVarTypes.size(), odsState.location)
419 Value returnVal = bodyBuilder(odsBuilder, odsState.location, block->getArguments());
420 odsBuilder.create<llzk::smt::YieldOp>(odsState.location, returnVal);
422 if (patternBuilder) {
423 Region *region = odsState.addRegion();
424 OpBuilder::InsertionGuard guard(odsBuilder);
425 Block *block = odsBuilder.createBlock(region);
427 boundVarTypes, SmallVector<Location>(boundVarTypes.size(), odsState.location)
429 ValueRange returnVals = patternBuilder(odsBuilder, odsState.location, block->getArguments());
430 odsBuilder.create<llzk::smt::YieldOp>(odsState.location, returnVals);
436 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
437 "specified at the same time";
446 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
447 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
448 std::optional<ArrayRef<StringRef>> boundVarNames,
449 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight,
452 buildQuantifier<Properties>(
453 odsBuilder, odsState, boundVarTypes, bodyBuilder, boundVarNames, patternBuilder, weight,
464 return emitOpError() <<
"patterns and the no_pattern attribute must not be "
465 "specified at the same time";
474 OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes,
475 function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder,
476 std::optional<ArrayRef<StringRef>> boundVarNames,
477 function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, uint32_t weight,
480 buildQuantifier<Properties>(
481 odsBuilder, odsState, boundVarTypes, bodyBuilder, boundVarNames, patternBuilder, weight,
486#define GET_OP_CLASSES
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::llzk::smt::BitVectorAttr getValue()
::llzk::smt::BitVectorAttr getValueAttr()
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location > location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type > &inferredReturnTypes)
FoldAdaptor::Properties Properties
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::TypedValue<::llzk::smt::BitVectorType > getResult()
static BitVectorType get(::mlir::MLIRContext *context, int64_t width)
::mlir::TypedValue<::llzk::smt::BoolType > getResult()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::BoolAttr getValueAttr()
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::Region & getUnsatRegion()
::mlir::Region & getUnknownRegion()
::mlir::Region & getSatRegion()
::llvm::LogicalResult verifyRegions()
::llvm::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location > location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type > &inferredReturnTypes)
::mlir::TypedValue<::mlir::Type > getResult()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::std::optional< ::llvm::StringRef > getNamePrefix()
::mlir::Operation::operand_range getInputs()
void print(::mlir::OpAsmPrinter &p)
::llvm::LogicalResult verify()
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
::mlir::Operation::operand_range getInputs()
void print(::mlir::OpAsmPrinter &p)
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
::llvm::LogicalResult verify()
::llvm::LogicalResult verify()
::llvm::LogicalResult verifyRegions()
::mlir::MutableArrayRef<::mlir::Region > getPatterns()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, mlir::TypeRange boundVarTypes, llvm::function_ref< mlir::Value(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)> bodyBuilder, std::optional< llvm::ArrayRef< mlir::StringRef > > boundVarNames=std::nullopt, llvm::function_ref< mlir::ValueRange(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)> patternBuilder={}, uint32_t weight=0, bool noPattern=false)
::llvm::LogicalResult verifyRegions()
::llvm::LogicalResult verify()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, mlir::TypeRange boundVarTypes, llvm::function_ref< mlir::Value(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)> bodyBuilder, std::optional< llvm::ArrayRef< mlir::StringRef > > boundVarNames=std::nullopt, llvm::function_ref< mlir::ValueRange(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)> patternBuilder={}, uint32_t weight=0, bool noPattern=false)
::mlir::MutableArrayRef<::mlir::Region > getPatterns()
void print(::mlir::OpAsmPrinter &p)
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
::mlir::TypedValue<::llzk::smt::IntType > getResult()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
FoldAdaptor::Properties Properties
::mlir::IntegerAttr getValueAttr()
::mlir::TypedValue<::llzk::smt::BitVectorType > getInput()
::llvm::LogicalResult verify()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, unsigned count, mlir::Value input)
void print(::mlir::OpAsmPrinter &p)
unsigned getCount()
Get the number of times the input operand is repeated.
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
::llvm::LogicalResult verifyRegions()
::mlir::Operation::operand_range getInputs()
bool isAnyNonFuncSMTValueType(mlir::Type type)
Returns whether the given type is an SMT value type (excluding functions).