18#include <mlir/IR/Builders.h>
19#include <mlir/IR/BuiltinAttributes.h>
20#include <mlir/IR/Diagnostics.h>
21#include <mlir/IR/OpImplementation.h>
22#include <mlir/IR/OperationSupport.h>
23#include <mlir/Support/LLVM.h>
25#include <llvm/ADT/STLExtras.h>
26#include <llvm/ADT/SmallString.h>
27#include <llvm/ADT/SmallVectorExtras.h>
28#include <llvm/ADT/StringSet.h>
29#include <llvm/ADT/TypeSwitch.h>
30#include <llvm/Support/Debug.h>
51static void buildCommon(
52 OpBuilder &builder, OperationState &state, PodType result,
InitializedRecords initialValues
54 SmallVector<Value, 4> values;
55 SmallVector<StringRef, 4>
names;
57 for (
const auto &record : initialValues) {
58 names.push_back(record.name);
59 values.push_back(record.value);
63 state.addTypes(result);
64 state.addOperands(values);
70 OpBuilder &builder, OperationState &state, PodType result, ArrayRef<ValueRange> mapOperands,
73 buildCommon(builder, state, result, initialValues);
80 OpBuilder &builder, OperationState &state, PodType result,
InitializedRecords initialValues
82 buildCommon(builder, state, result, initialValues);
99 return {DestructurableMemorySlot {{
getResult(), podType}, std::move(*destructured)}};
106 const DestructurableMemorySlot &slot,
const SmallPtrSetImpl<Attribute> &usedIndices,
107 OpBuilder &builder, SmallVectorImpl<DestructurableAllocationOpInterface> &newAllocators
110 assert(slot.elemType == getType());
112 builder.setInsertionPointAfter(*
this);
115 DenseMap<Attribute, MemorySlot> slotMap;
116 for (Attribute index : usedIndices) {
117 auto recordName = llvm::dyn_cast<StringAttr>(index);
118 assert(recordName &&
"expected StringAttr");
120 Type destructAs = getType().getTypeAtIndex(recordName);
121 assert(destructAs == slot.subelementTypes.lookup(recordName));
123 auto destructAsPodTy = llvm::dyn_cast<PodType>(destructAs);
124 assert(destructAsPodTy &&
"expected PodType");
126 SmallVector<RecordValue, 1> initialValue;
128 if (record.name == recordName.getValue()) {
129 initialValue.push_back(record);
134 auto subNew = builder.create<
NewPodOp>(getLoc(), destructAsPodTy, initialValue);
135 newAllocators.push_back(subNew);
136 slotMap.try_emplace<MemorySlot>(index, {subNew.getResult(), destructAs});
144 const DestructurableMemorySlot &slot, OpBuilder &
153 ArrayRef<RecordAttr> records = getType().getRecords();
154 if (records.size() != 1) {
157 return {MemorySlot {
getResult(), records.front().getType()}};
163 ArrayRef<RecordAttr> records = getType().getRecords();
164 assert(records.size() == 1 &&
"only single-record pods are promotable");
165 assert(records.front().getType() == slot.elemType);
167 StringRef recordName = records.front().getName().getValue();
169 if (record.name == recordName) {
181 const MemorySlot &slot, Value defaultValue, OpBuilder &
184 if (defaultValue && defaultValue.use_empty()) {
185 if (Operation *defOp = defaultValue.getDefiningOp()) {
186 if (llvm::isa<llzk::NonDetOp>(defOp)) {
197static void collectMapAttrs(Type type, SmallVector<AffineMapAttr> &mapAttrs) {
199 llvm::TypeSwitch<Type, void>(type)
202 collectMapAttrs(record.getType(), mapAttrs);
205 .Case([&mapAttrs](array::ArrayType t) {
206 for (
auto a : t.getDimensionSizes()) {
207 if (
auto m = llvm::dyn_cast<AffineMapAttr>(a)) {
208 mapAttrs.push_back(m);
212 .Case([&mapAttrs](component::StructType t) {
213 if (ArrayAttr params = t.getParams()) {
214 for (
auto param : params) {
215 if (
auto m = llvm::dyn_cast<AffineMapAttr>(param)) {
216 mapAttrs.push_back(m);
220 }).Default([](Type) {});
231static LogicalResult verifyInitialValues(
232 ValueRange values, ArrayRef<Attribute>
names,
PodType retTy,
233 llvm::function_ref<InFlightDiagnostic()> emitError
236 if (
names.size() != values.size()) {
237 emitError() <<
"number of initialized records and initial values does not match ("
238 <<
names.size() <<
" != " << values.size() <<
")";
242 llvm::StringMap<Type> records = retTy.getRecordMap();
243 llvm::StringSet<> seenNames;
244 for (
auto [nameAttr, value] : llvm::zip_equal(
names, values)) {
245 auto name = llvm::cast<StringAttr>(nameAttr).getValue();
246 if (seenNames.contains(name)) {
247 emitError() <<
"found duplicated record name '" << name <<
'\'';
250 seenNames.insert(name);
252 if (!records.contains(name)) {
253 emitError() <<
"record '" << name <<
"' is not part of the struct";
258 auto valueTy = value.getType();
259 auto recordTy = records.at(name);
260 if (valueTy != recordTy) {
261 auto err = emitError();
262 err <<
"record '" << name <<
"' expected type " << recordTy <<
" but got " << valueTy;
265 <<
"types " << valueTy <<
" and " << recordTy
266 <<
" can be unified. Perhaps you can add a 'poly.unifiable_cast' operation?";
272 return failure(failed);
275static LogicalResult verifyAffineMapOperands(
NewPodOp *op, Type retTy) {
276 SmallVector<AffineMapAttr> mapAttrs;
277 collectMapAttrs(retTy, mapAttrs);
279 op->getMapOperands(), op->getNumDimsPerMap(), mapAttrs, *op
287 failed = failed || mlir::failed(x); \
291 auto retTy = llvm::dyn_cast<PodType>(
getResult().getType());
297 return this->emitError();
300 check(verifyAffineMapOperands(
this, retTy));
302 return failure(failed);
311 if (failed(parser.parseSymbolName(name))) {
315 if (parser.parseEqual()) {
318 return parser.parseOperand(operand);
331 SmallVector<Attribute> initializedRecords;
334 llvm::StringMap<UnresolvedOp> initialValuesOperands;
335 auto parseElementFn = [&parser, &initializedRecords, &initialValuesOperands] {
341 initializedRecords.push_back(name);
342 initialValuesOperands.insert({name.getValue(), operand});
345 auto initialValuesLoc = parser.getCurrentLocation();
346 if (parser.parseCommaSeparatedList(AsmParser::Delimiter::OptionalBraces, parseElementFn)) {
349 SmallVector<int32_t> mapOperandsGroupSizes;
350 SmallVector<UnresolvedOp> allMapOperands;
351 Type indexTy = parser.getBuilder().getIndexType();
352 bool colonAlreadyParsed =
true;
353 auto mapOperandsLoc = parser.getCurrentLocation();
356 if (failed(parser.parseOptionalColon())) {
357 colonAlreadyParsed =
false;
358 SmallVector<SmallVector<UnresolvedOp>> mapOperands {};
363 mapOperandsGroupSizes.reserve(mapOperands.size());
364 for (
const auto &subRange : mapOperands) {
365 allMapOperands.append(subRange.begin(), subRange.end());
370 if (!colonAlreadyParsed && parser.parseColon()) {
375 if (parser.parseCustomTypeWithFallback(resultType)) {
380 for (
auto attr : initializedRecords) {
381 auto name = llvm::cast<StringAttr>(attr);
382 auto lookup = resultType.
getRecord(name.getValue(), [&parser, initialValuesLoc] {
383 return parser.emitError(initialValuesLoc);
385 if (failed(lookup)) {
388 const auto &operand = initialValuesOperands.at(name.getValue());
389 if (failed(parser.resolveOperands({operand}, *lookup, initialValuesLoc, result.operands))) {
393 props.operandSegmentSizes = {
397 props.mapOpGroupSizes = parser.getBuilder().getDenseI32ArrayAttr(mapOperandsGroupSizes);
398 props.initializedRecords = parser.getBuilder().getArrayAttr(initializedRecords);
399 result.addTypes({resultType});
401 if (failed(parser.resolveOperands(allMapOperands, indexTy, mapOperandsLoc, result.operands))) {
405 auto loc = parser.getCurrentLocation();
406 if (parser.parseOptionalAttrDict(result.attributes)) {
410 return parser.emitError(loc) <<
'\'' << result.name.getStringRef() <<
"' op ";
420 auto &os = printer.getStream();
422 if (!initializedRecords.empty()) {
424 llvm::interleaveComma(initializedRecords, os, [&os, &printer](
auto record) {
425 printer.printSymbolName(record.name);
427 printer.printOperand(record.value);
436 if (
auto validType = llvm::dyn_cast<PodType>(type)) {
437 printer.printStrippedAttrOrType(validType);
439 printer.printType(type);
442 printer.printOptionalAttrDict(
444 {
"initializedRecords",
"mapOpGroupSizes",
"numDimsPerMap",
"operandSegmentSizes"}
448SmallVector<RecordValue>
450 return llvm::map_to_vector(llvm::zip_equal(initialValues, initializedRecords), [](
auto pair) {
451 auto [value, name] = pair;
452 return RecordValue {.name = llvm::cast<StringAttr>(name).getValue(), .value = value};
466 const DestructurableMemorySlot &slot, SmallPtrSetImpl<Attribute> &usedIndices,
467 SmallVectorImpl<MemorySlot> & ,
const DataLayout &
474 if (!slot.subelementTypes.contains(recordName)) {
478 usedIndices.insert(recordName);
484 const DestructurableMemorySlot &slot, DenseMap<Attribute, MemorySlot> &subslots,
485 OpBuilder & ,
const DataLayout &
491 const MemorySlot &memorySlot = subslots.at(recordName);
494 return DeletionKind::Keep;
502 auto podTy = llvm::dyn_cast<PodType>(
getPodRef().getType());
504 return emitError() <<
"reference operand expected a plain-old-data struct but got "
508 auto lookup = podTy.getRecord(
getRecordName(), [
this]() {
return this->emitError(); });
509 if (failed(lookup)) {
514 return emitError() <<
"operation result type and type of record do not match ("
515 <<
getResult().getType() <<
" != " << *lookup <<
")";
526 auto podTy = llvm::dyn_cast<PodType>(
getPodRef().getType());
528 return emitError() <<
"reference operand expected a plain-old-data struct but got "
532 auto lookup = podTy.getRecord(
getRecordName(), [
this]() {
return this->emitError(); });
533 if (failed(lookup)) {
537 if (
getValue().getType() != *lookup) {
538 return emitError() <<
"type of source value and type of record do not match ("
539 <<
getValue().getType() <<
" != " << *lookup <<
")";
550 return parser.parseCustomAttributeWithFallback(name);
554 printer.printSymbolName(name.getValue());
within a display generated by the Derivative if and wherever such third party notices normally appear The contents of the NOTICE file are for informational purposes only and do not modify the License You may add Your own attribution notices within Derivative Works that You alongside or as an addendum to the NOTICE text from the provided that such additional attribution notices cannot be construed as modifying the License You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for or distribution of Your or for any such Derivative Works as a provided Your and distribution of the Work otherwise complies with the conditions stated in this License Submission of Contributions Unless You explicitly state any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this without any additional terms or conditions Notwithstanding the nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions Trademarks This License does not grant permission to use the trade names
void print(::mlir::OpAsmPrinter &p)
::mlir::Operation::operand_range getInitialValues()
::mlir::OperandRangeRange getMapOperands()
::mlir::SmallVector<::llzk::pod::RecordValue > getInitializedRecordValues()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::llzk::pod::InitializedRecords initialValues={})
::std::optional<::mlir::PromotableAllocationOpInterface > handlePromotionComplete(const ::mlir::MemorySlot &slot, ::mlir::Value defaultValue, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
::llvm::SmallVector<::mlir::DestructurableMemorySlot > getDestructurableSlots()
Required by DestructurableAllocationOpInterface / SROA pass.
::llvm::SmallVector<::mlir::MemorySlot > getPromotableSlots()
Required by PromotableAllocationOpInterface / mem2reg pass.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
::mlir::TypedValue<::llzk::pod::PodType > getResult()
::mlir::Value getDefaultValue(const ::mlir::MemorySlot &slot, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
FoldAdaptor::Properties Properties
::llvm::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> emitError)
::llvm::LogicalResult verify()
::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot > destructure(const ::mlir::DestructurableMemorySlot &slot, const ::llvm::SmallPtrSetImpl<::mlir::Attribute > &usedIndices, ::mlir::OpBuilder &builder, ::mlir::SmallVectorImpl<::mlir::DestructurableAllocationOpInterface > &newAllocators)
Required by DestructurableAllocationOpInterface / SROA pass.
::std::optional<::mlir::DestructurableAllocationOpInterface > handleDestructuringComplete(const ::mlir::DestructurableMemorySlot &slot, ::mlir::OpBuilder &builder)
Required by DestructurableAllocationOpInterface / SROA pass.
::mlir::ArrayAttr getInitializedRecords()
::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result)
void handleBlockArgument(const ::mlir::MemorySlot &slot, ::mlir::BlockArgument argument, ::mlir::OpBuilder &builder)
Required by PromotableAllocationOpInterface / mem2reg pass.
::mlir::OpOperand & getPodRefMutable()
Gets the mutable operand slot holding the SSA Value for the referenced pod.
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Gets the SSA Value for the referenced pod.
inline ::llzk::pod::PodType getPodRefType()
Gets the type of the referenced pod.
::mlir::DeletionKind rewire(const ::mlir::DestructurableMemorySlot &slot, ::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot > &subslots, ::mlir::OpBuilder &builder, const ::mlir::DataLayout &dataLayout)
Required by companion interface DestructurableAccessorOpInterface / SROA pass.
inline ::mlir::StringAttr getRecordNameAsStringAttr()
Gets the record name as an attribute suitable for destructuring indices.
bool canRewire(const ::mlir::DestructurableMemorySlot &slot, ::llvm::SmallPtrSetImpl<::mlir::Attribute > &usedIndices, ::mlir::SmallVectorImpl<::mlir::MemorySlot > &mustBeSafelyUsed, const ::mlir::DataLayout &dataLayout)
Required by companion interface DestructurableAccessorOpInterface / SROA pass.
::llvm::FailureOr<::mlir::Type > getRecord(::llvm::StringRef name, ::llvm::function_ref<::mlir::InFlightDiagnostic()>) const
Searches a record by name.
::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type > > getSubelementIndexMap() const
Required by DestructurableTypeInterface / SROA pass.
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::llvm::StringRef getRecordName()
::mlir::TypedValue<::mlir::Type > getResult()
::llvm::LogicalResult verify()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
::llvm::LogicalResult verify()
::mlir::TypedValue<::mlir::Type > getValue()
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
::llvm::StringRef getRecordName()
OpClass::Properties & buildInstantiationAttrs(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
OpClass::Properties & buildInstantiationAttrsEmpty(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, int32_t firstSegmentSize=0)
Utility for build() functions that initializes the operandSegmentSizes, mapOpGroupSizes,...
mlir::ArrayRef< RecordValue > InitializedRecords
OpAsmParser::UnresolvedOperand UnresolvedOp
ParseResult parseRecordName(AsmParser &parser, FlatSymbolRefAttr &name)
ParseResult parseRecordInitialization(OpAsmParser &parser, StringAttr &name, UnresolvedOp &operand)
SmallVector< RecordValue > getInitializedRecordValues(ValueRange initialValues, ArrayAttr initializedRecords)
void printRecordName(AsmPrinter &printer, Operation *, FlatSymbolRefAttr name)
constexpr T checkedCast(U u) noexcept
void printMultiDimAndSymbolList(mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::OperandRangeRange multiMapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::ParseResult parseMultiDimAndSymbolList(mlir::OpAsmParser &parser, mlir::SmallVector< mlir::SmallVector< mlir::OpAsmParser::UnresolvedOperand > > &multiMapOperands, mlir::DenseI32ArrayAttr &numDimsPerMap)
void setInitializedRecords(const ::mlir::ArrayAttr &propValue)