16#include <mlir/IR/Builders.h>
17#include <mlir/IR/BuiltinAttributes.h>
18#include <mlir/IR/Diagnostics.h>
19#include <mlir/IR/OpImplementation.h>
20#include <mlir/IR/OperationSupport.h>
21#include <mlir/Support/LLVM.h>
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/SmallString.h>
25#include <llvm/ADT/SmallVectorExtras.h>
26#include <llvm/ADT/StringSet.h>
27#include <llvm/ADT/TypeSwitch.h>
28#include <llvm/Support/Debug.h>
45static void buildCommon(
46 OpBuilder &builder, OperationState &state, PodType result,
InitializedRecords initialValues
48 SmallVector<Value, 4> values;
49 SmallVector<StringRef, 4>
names;
51 for (
const auto &record : initialValues) {
52 names.push_back(record.name);
53 values.push_back(record.value);
57 state.addTypes(result);
58 state.addOperands(values);
64 OpBuilder &builder, OperationState &state, PodType result, ArrayRef<ValueRange> mapOperands,
67 buildCommon(builder, state, result, initialValues);
72 OpBuilder &builder, OperationState &state, PodType result,
InitializedRecords initialValues
74 buildCommon(builder, state, result, initialValues);
75 assert(std::cmp_less_equal(initialValues.size(), std::numeric_limits<int32_t>::max()));
77 builder, state,
static_cast<int32_t
>(initialValues.size())
87static void collectMapAttrs(Type type, SmallVector<AffineMapAttr> &mapAttrs) {
88 llvm::TypeSwitch<Type, void>(type)
91 collectMapAttrs(record.getType(), mapAttrs);
94 .Case([&mapAttrs](array::ArrayType t) {
95 for (
auto a : t.getDimensionSizes()) {
96 if (
auto m = llvm::dyn_cast<AffineMapAttr>(a)) {
97 mapAttrs.push_back(m);
101 .Case([&mapAttrs](component::StructType t) {
102 for (
auto param : t.getParams()) {
103 if (
auto m = llvm::dyn_cast<AffineMapAttr>(param)) {
104 mapAttrs.push_back(m);
107 }).Default([](Type) {});
117static LogicalResult verifyInitialValues(
118 ValueRange values, ArrayRef<Attribute>
names,
PodType retTy,
119 llvm::function_ref<InFlightDiagnostic()> emitError
122 if (
names.size() != values.size()) {
123 emitError() <<
"number of initialized records and initial values does not match ("
124 <<
names.size() <<
" != " << values.size() <<
")";
128 llvm::StringMap<Type> records = retTy.getRecordMap();
129 llvm::StringSet<> seenNames;
130 for (
auto [nameAttr, value] : llvm::zip_equal(
names, values)) {
131 auto name = llvm::cast<StringAttr>(nameAttr).getValue();
132 if (seenNames.contains(name)) {
133 emitError() <<
"found duplicated record name '" << name <<
'\'';
136 seenNames.insert(name);
138 if (!records.contains(name)) {
139 emitError() <<
"record '" << name <<
"' is not part of the struct";
144 auto valueTy = value.getType();
145 auto recordTy = records.at(name);
146 if (valueTy != recordTy) {
147 auto err = emitError();
148 err <<
"record '" << name <<
"' expected type " << recordTy <<
" but got " << valueTy;
151 <<
"types " << valueTy <<
" and " << recordTy
152 <<
" can be unified. Perhaps you can add a 'poly.unifiable_cast' operation?";
158 return failure(failed);
161static LogicalResult verifyAffineMapOperands(
NewPodOp *op, Type retTy) {
162 SmallVector<AffineMapAttr> mapAttrs;
163 collectMapAttrs(retTy, mapAttrs);
165 op->getMapOperands(), op->getNumDimsPerMap(), mapAttrs, *op
173 failed = failed || mlir::failed(x); \
177 auto retTy = llvm::dyn_cast<PodType>(
getResult().getType());
183 return this->emitError();
186 check(verifyAffineMapOperands(
this, retTy));
188 return failure(failed);
197 if (failed(parser.parseSymbolName(name))) {
201 if (parser.parseEqual()) {
204 return parser.parseOperand(operand);
217 SmallVector<Attribute> initializedRecords;
220 llvm::StringMap<UnresolvedOp> initialValuesOperands;
221 auto parseElementFn = [&parser, &initializedRecords, &initialValuesOperands] {
227 initializedRecords.push_back(name);
228 initialValuesOperands.insert({name.getValue(), operand});
231 auto initialValuesLoc = parser.getCurrentLocation();
232 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::OptionalBraces, parseElementFn)) {
235 SmallVector<int32_t> mapOperandsGroupSizes;
236 SmallVector<UnresolvedOp> allMapOperands;
237 Type indexTy = parser.getBuilder().getIndexType();
238 bool colonAlreadyParsed =
true;
239 auto mapOperandsLoc = parser.getCurrentLocation();
242 if (failed(parser.parseOptionalColon())) {
243 colonAlreadyParsed =
false;
244 SmallVector<SmallVector<UnresolvedOp>> mapOperands {};
249 mapOperandsGroupSizes.reserve(mapOperands.size());
250 for (
const auto &subRange : mapOperands) {
251 allMapOperands.append(subRange.begin(), subRange.end());
252 assert(std::cmp_less_equal(subRange.size(), std::numeric_limits<int32_t>::max()));
253 mapOperandsGroupSizes.push_back(
static_cast<int32_t
>(subRange.size()));
257 if (!colonAlreadyParsed && parser.parseColon()) {
262 if (parser.parseCustomTypeWithFallback(resultType)) {
267 for (
auto attr : initializedRecords) {
268 auto name = llvm::cast<StringAttr>(attr);
269 auto lookup = resultType.
getRecord(name.getValue(), [&parser, initialValuesLoc] {
270 return parser.emitError(initialValuesLoc);
272 if (failed(lookup)) {
275 const auto &operand = initialValuesOperands.at(name.getValue());
276 if (failed(parser.resolveOperands({operand}, *lookup, initialValuesLoc, result.operands))) {
280 assert(std::cmp_less_equal(initializedRecords.size(), std::numeric_limits<int32_t>::max()));
281 assert(std::cmp_less_equal(allMapOperands.size(), std::numeric_limits<int32_t>::max()));
282 props.operandSegmentSizes = {
283 static_cast<int32_t
>(initializedRecords.size()),
static_cast<int32_t
>(allMapOperands.size())
285 props.mapOpGroupSizes = parser.getBuilder().getDenseI32ArrayAttr(mapOperandsGroupSizes);
286 props.initializedRecords = parser.getBuilder().getArrayAttr(initializedRecords);
287 result.addTypes({resultType});
289 if (failed(parser.resolveOperands(allMapOperands, indexTy, mapOperandsLoc, result.operands))) {
293 auto loc = parser.getCurrentLocation();
294 if (parser.parseOptionalAttrDict(result.attributes)) {
298 return parser.emitError(loc) <<
'\'' << result.name.getStringRef() <<
"' op ";
308 auto &os = printer.getStream();
310 if (!initializedRecords.empty()) {
312 llvm::interleaveComma(initializedRecords, os, [&os, &printer](
auto record) {
313 printer.printSymbolName(record.name);
315 printer.printOperand(record.value);
324 if (
auto validType = llvm::dyn_cast<PodType>(type)) {
325 printer.printStrippedAttrOrType(validType);
327 printer.printType(type);
330 printer.printOptionalAttrDict(
332 {
"initializedRecords",
"mapOpGroupSizes",
"numDimsPerMap",
"operandSegmentSizes"}
337 return llvm::map_to_vector(
339 auto [value, name] = pair;
340 return RecordValue {.name = llvm::cast<StringAttr>(name).getValue(), .value = value};
350 auto podTy = llvm::dyn_cast<PodType>(
getPodRef().getType());
352 return emitError() <<
"reference operand expected a plain-old-data struct but got "
356 auto lookup = podTy.getRecord(
getRecordName(), [
this]() {
return this->emitError(); });
357 if (failed(lookup)) {
362 return emitError() <<
"operation result type and type of record do not match ("
363 <<
getResult().getType() <<
" != " << *lookup <<
")";
374 auto podTy = llvm::dyn_cast<PodType>(
getPodRef().getType());
376 return emitError() <<
"reference operand expected a plain-old-data struct but got "
380 auto lookup = podTy.getRecord(
getRecordName(), [
this]() {
return this->emitError(); });
381 if (failed(lookup)) {
385 if (
getValue().getType() != *lookup) {
386 return emitError() <<
"type of source value and type of record do not match ("
387 <<
getValue().getType() <<
" != " << *lookup <<
")";
398 return parser.parseCustomAttributeWithFallback(name);
402 printer.printSymbolName(name.getValue());
within a display generated by the Derivative if and wherever such third party notices normally appear The contents of the NOTICE file are for informational purposes only and do not modify the License You may add Your own attribution notices within Derivative Works that You alongside or as an addendum to the NOTICE text from the provided that such additional attribution notices cannot be construed as modifying the License You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for or distribution of Your or for any such Derivative Works as a provided Your and distribution of the Work otherwise complies with the conditions stated in this License Submission of Contributions Unless You explicitly state any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this without any additional terms or conditions Notwithstanding the nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions Trademarks This License does not grant permission to use the trade names
void print(::mlir::OpAsmPrinter &p)
::mlir::Operation::operand_range getInitialValues()
::mlir::OperandRangeRange getMapOperands()
::mlir::SmallVector<::llzk::pod::RecordValue > getInitializedRecordValues()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llzk::pod::InitializedRecords initialValues={})
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::TypedValue<::llzk::pod::PodType > getResult()
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
FoldAdaptor::Properties Properties
::llvm::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError)
::llvm::LogicalResult verify()
::mlir::ArrayAttr getInitializedRecords()
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
::llvm::FailureOr<::mlir::Type > getRecord(::llvm::StringRef name, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) const
Searches a record by name.
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::llvm::StringRef getRecordName()
::mlir::TypedValue<::mlir::Type > getResult()
::llvm::LogicalResult verify()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
::llvm::LogicalResult verify()
::mlir::TypedValue<::mlir::Type > getValue()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
::llvm::StringRef getRecordName()
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
mlir::ArrayRef< RecordValue > InitializedRecords
OpAsmParser::UnresolvedOperand UnresolvedOp
ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name)
ParseResult parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand)
void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name)
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
void setInitializedRecords(const ::mlir::ArrayAttr &propValue)