23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/SmallVector.h>
25#include <llvm/ADT/TypeSwitch.h>
26#include <llvm/Support/Debug.h>
31#define DEBUG_TYPE "llzk-type-helpers"
38using namespace component;
40using namespace polymorphic;
41using namespace string;
47 inline ResultType
match(Type type) {
48 return llvm::TypeSwitch<Type, ResultType>(type)
49 .template Case<IndexType>([
this](
auto t) {
50 return static_cast<Derived *
>(
this)->caseIndex(t);
52 .
template Case<FeltType>([
this](
auto t) {
53 return static_cast<Derived *
>(
this)->caseFelt(t);
55 .
template Case<StringType>([
this](
auto t) {
56 return static_cast<Derived *
>(
this)->caseString(t);
58 .
template Case<TypeVarType>([
this](
auto t) {
59 return static_cast<Derived *
>(
this)->caseTypeVar(t);
61 .
template Case<ArrayType>([
this](
auto t) {
62 return static_cast<Derived *
>(
this)->caseArray(t);
64 .
template Case<PodType>([
this](
auto t) {
return static_cast<Derived *
>(
this)->casePod(t); })
65 .
template Case<StructType>([
this](
auto t) {
66 return static_cast<Derived *
>(
this)->caseStruct(t);
67 }).Default([
this](Type t) {
68 if (t.isSignlessInteger(1)) {
71 return static_cast<Derived *
>(
this)->caseInvalid(t);
77void BuildShortTypeString::appendSymName(StringRef str) {
85void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
86 appendSymName(sa.getRootReference().getValue());
87 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
89 appendSymName(nestedRef.getValue());
94 size_t position = ret.size();
97 struct Impl : LLZKTypeSwitch<Impl, void> {
98 BuildShortTypeString &outer;
99 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
101 void caseInvalid(Type) { outer.ss <<
"!INVALID"; }
102 void caseBool(IntegerType) { outer.ss <<
'b'; }
103 void caseIndex(IndexType) { outer.ss <<
'i'; }
104 void caseFelt(FeltType) { outer.ss <<
'f'; }
105 void caseString(StringType) { outer.ss <<
's'; }
106 void caseTypeVar(TypeVarType t) {
108 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
111 void caseArray(ArrayType t) {
113 outer.append(t.getElementType());
115 outer.append(t.getDimensionSizes());
118 void casePod(PodType t) {
120 for (
auto record : t.getRecords()) {
121 outer.appendSymRef(record.getNameSym());
125 void caseStruct(StructType t) {
127 outer.appendSymRef(t.getNameRef());
128 if (ArrayAttr params = t.getParams()) {
130 outer.append(params.getValue());
135 Impl(*this).match(type);
138 ret.find(PLACEHOLDER, position) == std::string::npos &&
139 "formatting a Type should not produce the 'PLACEHOLDER' char"
151 size_t position = ret.size();
155 if (
auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
156 Type ty = ia.getType();
157 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
158 ia.getValue().print(ss, !isUnsigned);
159 }
else if (
auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
161 }
else if (
auto ta = llvm::dyn_cast<TypeAttr>(a)) {
162 append(ta.getValue());
163 }
else if (
auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
166 filtered_raw_ostream fs(ss, [](
char c) {
return c ==
' '; });
167 ama.getValue().print(fs);
170 }
else if (
auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
171 append(aa.getValue());
177 ret.find(PLACEHOLDER, position) == std::string::npos &&
178 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
184 llvm::interleave(attrs, ss, [
this](Attribute a) { append(a); },
"_");
189 BuildShortTypeString bldr;
191 bldr.ret.reserve(base.size() + attrs.size());
194 auto END = attrs.end();
195 auto IT = attrs.begin();
198 for (
size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
200 bldr.ret.append(base, start, pos - start);
202 assert(IT != END &&
"must have an Attribute for every 'PLACEHOLDER' char");
206 bldr.ret.append(base, start, base.size() - start);
212 bldr.append(ArrayRef(IT, END));
220template <
typename... Types>
class TypeList {
223 template <
typename StreamType>
struct Appender {
226 template <
typename Ty>
static inline void append(StreamType &stream) {
227 stream <<
'\'' << Ty::name <<
'\'';
231 template <
typename First,
typename Second,
typename... Rest>
232 static void append(StreamType &stream) {
233 append<First>(stream);
235 append<Second, Rest...>(stream);
239 static inline void append(StreamType &stream) {
241 append<Types...>(stream);
248 template <
typename T>
static inline bool matches(
const T &value) {
249 return llvm::isa_and_present<Types...>(value);
252 static void reportInvalid(
EmitErrorFn emitError,
const Twine &foundName,
const char *aspect) {
253 InFlightDiagnosticWrapper diag = emitError().append(aspect,
" must be one of ");
254 Appender<InFlightDiagnosticWrapper>::append(diag);
255 diag.append(
" but found '", foundName,
'\'').report();
258 static inline void reportInvalid(
EmitErrorFn emitError, Attribute found,
const char *aspect) {
260 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() :
"nullptr", aspect);
265 static inline std::string
getNames() {
272template <
class... Ts>
struct make_unique {
273 using type = TypeList<Ts...>;
276template <
class... Ts>
struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
278template <
class U,
class... Us,
class... Ts>
279struct make_unique<TypeList<U, Us...>, Ts...>
280 : std::conditional_t<
281 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
282 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
284template <
class... Ts>
using TypeListUnion =
typename make_unique<Ts...>::type;
/// Attribute kinds accepted as array dimension sizes. areValidArrayDimSizes()
/// checks every dimension with ArrayDimensionTypes::matches(). On a mismatch,
/// ArrayDimensionTypes::reportInvalid() lists these kinds, in this order, in
/// the emitted diagnostic, so the element order here is user-visible. The list
/// also feeds the TypeListUnion used by assertValidAttrForParamOfType().
290using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
/// Attribute kinds accepted as struct type parameters, checked with
/// StructParamTypes::matches() in areValidStructTypeParams(). Compared to
/// ArrayDimensionTypes, this also admits TypeAttr; the wrapped type of such a
/// parameter is then validated recursively via isValidTypeImpl(). The element
/// order is reflected in the "must be one of" diagnostic from reportInvalid().
297using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
300 struct ColumnCheckData {
301 SymbolTableCollection *symbolTable =
nullptr;
302 Operation *op =
nullptr;
305 bool no_felt : 1 =
false;
306 bool no_string : 1 =
false;
307 bool no_struct : 1 =
false;
308 bool no_array : 1 =
false;
309 bool no_pod : 1 =
false;
310 bool no_var : 1 =
false;
311 bool no_int : 1 =
false;
312 bool no_struct_params : 1 =
false;
313 bool must_be_column : 1 =
false;
315 ColumnCheckData columnCheck;
320 bool validColumns(StructType s) {
321 if (!must_be_column) {
324 assert(columnCheck.symbolTable);
325 assert(columnCheck.op);
326 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
330 constexpr AllowedTypes &noFelt() {
335 constexpr AllowedTypes &noString() {
340 constexpr AllowedTypes &noStruct() {
345 constexpr AllowedTypes &noArray() {
350 constexpr AllowedTypes &noPod() {
355 constexpr AllowedTypes &noVar() {
360 constexpr AllowedTypes &noInt() {
365 constexpr AllowedTypes &noStructParams(
bool noStructParams =
true) {
366 no_struct_params = noStructParams;
370 constexpr AllowedTypes &onlyInt() {
372 return noFelt().noString().noStruct().noArray().noPod().noVar();
375 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
376 must_be_column =
true;
377 columnCheck.symbolTable = &symbolTable;
383 bool isValidTypeImpl(Type type);
385 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr) {
387 if (dimensionSizes.empty()) {
389 emitError().append(
"array must have at least one dimension").report();
396 for (Attribute a : dimensionSizes) {
397 if (!ArrayDimensionTypes::matches(a)) {
398 ArrayDimensionTypes::reportInvalid(emitError, a,
"Array dimension");
400 }
else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
401 TypeList<IntegerAttr>::reportInvalid(emitError, a,
"Concrete array dimension");
412 bool isValidArrayElemTypeImpl(Type type) {
414 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
417 bool isValidArrayTypeImpl(
418 Type elementType, ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr
420 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
425 if (!isValidArrayElemTypeImpl(elementType)) {
433 elementType.getAbstractType().getName(),
'\''
443 bool isValidArrayTypeImpl(Type type) {
444 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
445 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
452 bool areValidStructTypeParams(ArrayAttr params,
EmitErrorFn emitError =
nullptr) {
456 if (no_struct_params) {
460 for (Attribute p : params) {
461 if (!StructParamTypes::matches(p)) {
462 StructParamTypes::reportInvalid(emitError, p,
"Struct parameter");
464 }
else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
465 if (!isValidTypeImpl(tyAttr.getValue())) {
467 emitError().append(
"expected a valid LLZK type but found ", tyAttr.getValue()).report();
471 }
else if (no_var && !llvm::isa<IntegerAttr>(p)) {
472 TypeList<IntegerAttr>::reportInvalid(emitError, p,
"Concrete struct parameter");
484 bool areValidPodRecords(ArrayRef<RecordAttr> records) {
485 return llvm::all_of(records, [
this](
auto record) {
return isValidTypeImpl(record.getType()); });
489bool AllowedTypes::isValidTypeImpl(Type type) {
491 !(no_int && no_felt && no_string && no_var && no_struct && no_array && no_pod) &&
492 "All types have been deactivated"
494 struct Impl : LLZKTypeSwitch<Impl, bool> {
496 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
498 bool caseBool(IntegerType t) {
return !outer.no_int && t.isSignlessInteger(1); }
499 bool caseIndex(IndexType) {
return !outer.no_int; }
500 bool caseFelt(FeltType) {
return !outer.no_felt; }
501 bool caseString(StringType) {
return !outer.no_string; }
502 bool caseTypeVar(TypeVarType) {
return !outer.no_var; }
503 bool caseArray(ArrayType t) {
504 return !outer.no_array &&
505 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
507 bool casePod(PodType t) {
return !outer.no_pod && outer.areValidPodRecords(t.getRecords()); }
508 bool caseStruct(StructType t) {
510 if (outer.no_struct || !outer.validColumns(t)) {
513 return !outer.no_struct && outer.areValidStructTypeParams(t.getParams());
515 bool caseInvalid(Type _) {
return false; }
517 return Impl(*this).match(type);
522bool isValidType(Type type) {
return AllowedTypes().isValidTypeImpl(type); }
525 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
531 return AllowedTypes().noString().noStruct().isValidTypeImpl(type);
536 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
544 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
548 bool encountered =
false;
549 type.walk([&](AffineMapAttr a) {
551 return WalkResult::interrupt();
560 uint64_t caseBool(IntegerType) {
return 1; }
561 uint64_t caseIndex(IndexType) {
return 1; }
562 uint64_t caseFelt(
FeltType) {
return 1; }
564 int64_t n = t.getNumElements();
568 uint64_t caseStruct(
StructType t) { llvm_unreachable(
"not a valid EmitEq type"); }
570 return std::accumulate(
572 [](
const uint64_t &acc,
const RecordAttr &record) {
573 return computeEmitEqCardinality(record.getType()) + acc;
577 uint64_t caseString(
StringType) { llvm_unreachable(
"not a valid EmitEq type"); }
578 uint64_t caseTypeVar(
TypeVarType) { llvm_unreachable(
"tvar has unknown cardinality"); }
579 uint64_t caseInvalid(Type) { llvm_unreachable(
"not a valid LLZK type"); }
581 return Impl().match(type);
/// Records, for each (AffineMapAttr, Side) pair encountered during
/// unification, the IntegerAttr it was unified against. The key pairs the
/// affine map with the side of the unification it appeared on. If the same key
/// is tracked again with a different IntegerAttr, UnifierImpl::track()
/// overwrites the stored value with nullptr to mark the conflicting
/// instantiation.
594using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
597 ArrayRef<StringRef> rhsRevPrefix;
598 UnificationMap *unifications;
599 AffineInstantiations *affineToIntTracker;
602 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
604 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
605 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
606 overrideSuccess(nullptr) {}
609 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
610 bool unifyDynamicSize =
false
612 auto pred = [
this, unifyDynamicSize](
auto lhsAttr,
auto rhsAttr) {
613 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
615 return (lhsParams.size() == rhsParams.size()) &&
616 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
619 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
620 this->affineToIntTracker = tracker;
624 UnifierImpl &withOverrides(llvm::function_ref<
bool(Type oldTy, Type newTy)> overrides) {
625 this->overrideSuccess = overrides;
634 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
bool unifyDynamicSize =
false
636 ArrayRef<Attribute> emptyParams;
638 lhsParams ? lhsParams.getValue() : emptyParams,
639 rhsParams ? rhsParams.getValue() : emptyParams, unifyDynamicSize
645 if (!
typesUnify(lhs.getElementType(), rhs.getElementType())) {
650 lhs.getDimensionSizes(), rhs.getDimensionSizes(),
true
656 llvm::dbgs() <<
"[structTypesUnify] lhs = " << lhs <<
", rhs = " << rhs <<
'\n';
659 SmallVector<StringRef> rhsNames =
getNames(rhs.getNameRef());
660 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
661 auto lhsNames =
getNames(lhs.getNameRef());
662 if (rhsNames != lhsNames) {
664 llvm::interleaveComma(
665 lhsNames, llvm::dbgs() <<
"[structTypesUnify] names do not match\n"
668 llvm::interleaveComma(
669 rhsNames, llvm::dbgs() <<
"]\n"
672 llvm::dbgs() <<
"]\n";
676 LLVM_DEBUG({ llvm::dbgs() <<
"[structTypesUnify] checking unification of parameters\n"; });
681 bool podTypesUnify(PodType lhs, PodType rhs) {
683 auto lhsRecords = lhs.getRecords();
684 auto rhsRecords = rhs.getRecords();
686 return lhsRecords.size() == rhsRecords.size() &&
687 llvm::all_of(llvm::zip_equal(lhsRecords, rhsRecords), [
this](
auto &&records) {
688 auto &&[lhsRecord, rhsRecord] = records;
689 return lhsRecord.getName() == rhsRecord.getName() &&
690 typesUnify(lhsRecord.getType(), rhsRecord.getType());
698 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
702 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
703 track(Side::LHS, lhsTvar.getNameRef(), rhs);
706 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
707 track(Side::RHS, rhsTvar.getNameRef(), lhs);
710 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
711 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
713 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
714 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
716 if (llvm::isa<PodType>(lhs) && llvm::isa<PodType>(rhs)) {
717 return podTypesUnify(llvm::cast<PodType>(lhs), llvm::cast<PodType>(rhs));
723 template <
typename Tracker,
typename Key,
typename Val>
724 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
725 auto key = std::make_pair(keyHead, side);
726 auto it = tracker.find(key);
727 if (it == tracker.end()) {
728 tracker.try_emplace(key, val);
729 }
else if (it->getSecond() != val) {
730 it->second =
nullptr;
734 void track(Side side, SymbolRefAttr symRef, Type ty) {
737 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
739 attr = tvar.getNameRef();
742 attr = TypeAttr::get(ty);
746 track(*unifications, side, symRef, attr);
750 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
753 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
754 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
755 attr = tvar.getNameRef();
764 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
765 track(*unifications,
reverse(side), otherSymAttr, symRef);
767 track(*unifications, side, symRef, attr);
771 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
772 if (affineToIntTracker) {
776 track(*affineToIntTracker, side, affineAttr, intAttr);
780 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr,
bool unifyDynamicSize =
false) {
784 if (lhsAttr == rhsAttr) {
789 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
790 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
792 track(Side::LHS, lhsAffine, rhsInt);
797 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
798 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
800 track(Side::RHS, rhsAffine, lhsInt);
807 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
808 track(Side::LHS, lhsSymRef, rhsAttr);
811 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
812 track(Side::RHS, rhsSymRef, lhsAttr);
820 if (unifyDynamicSize) {
821 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
822 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
829 auto is_const_like = [](Attribute attr) {
830 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
832 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
833 if (is_const_like(rhsAttr)) {
837 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
838 if (is_const_like(lhsAttr)) {
844 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
845 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
846 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
857 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
860 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
866 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
UnificationMap *unifications
868 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
874 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
881 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
885 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix,
UnificationMap *unifications
887 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
891 Type oldTy, Type newTy, llvm::function_ref<
bool(Type oldTy, Type newTy)> knownOldToNew
894 AffineInstantiations affineInstantiations;
896 if (!UnifierImpl(&unifications)
897 .trackAffineToInt(&affineInstantiations)
898 .withOverrides(knownOldToNew)
908 auto entryIsRHS = [](
const auto &entry) {
return entry.first.second ==
Side::RHS; };
909 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
913 if (llvm::isa<IndexType>(attr.getType())) {
918 APInt value = attr.getValue();
919 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
921 value = value.zext(IndexType::kInternalStorageBitWidth);
922 }
else if (compare > 0) {
923 return emitError().append(
"value is too large for `index` type: ",
debug::toStringOne(value));
925 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
929 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
935FailureOr<SmallVector<Attribute>>
937 SmallVector<Attribute> result;
938 for (Attribute attr : attrList) {
940 if (failed(forced)) {
943 result.push_back(*forced);
949 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
950 Type attrTy = intAttr.getType();
951 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
954 .append(
"IntegerAttr must have type 'index' or 'i1' but found '", attrTy,
'\'')
964 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
965 AffineMap map = affineAttr.getValue();
966 if (map.getNumResults() != 1) {
970 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
982 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
986 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
991 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
996 using TypeVarAttrs = TypeList<SymbolRefAttr>;
997 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
998 llvm::report_fatal_error(
999 "Legal type parameters are inconsistent. Encountered " +
1000 attr.getAbstractAttribute().getName()
1008 size_t numArrDims = dimsFromArr.size();
1010 size_t numSubArrDims = dimsFromSubArr.size();
1012 if (numArrDims < numSubArrDims) {
1013 return emitError().append(
1014 "subarray type ", subArrayType,
" has more dimensions than array type ", arrayType
1018 size_t toDrop = numArrDims - numSubArrDims;
1019 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
1023 std::string message;
1024 llvm::raw_string_ostream ss(message);
1026 ss <<
"cannot unify array dimensions [";
1027 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
1029 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
1031 return emitError().append(message);
1036 return emitError().append(
1037 "incorrect array element type; expected: ", arrayType.
getElementType(),
1047 if (
auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1051 return emitError().append(
1052 "incorrect array element type; expected: ", arrayType.
getElementType(),
1053 ", found: ", subArrayOrElemType
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
std::string toStringOne(const T &value)
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
int64_t fromAPInt(const llvm::APInt &i)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)