26#include <llvm/ADT/STLExtras.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/ADT/TypeSwitch.h>
29#include <llvm/Support/Debug.h>
34#define DEBUG_TYPE "llzk-type-helpers"
41using namespace component;
43using namespace polymorphic;
44using namespace string;
49template <
typename Derived,
typename ResultType>
struct LLZKTypeSwitch {
50 inline ResultType
match(Type type) {
51 return llvm::TypeSwitch<Type, ResultType>(type)
52 .template Case<IndexType>([
this](
auto t) {
53 return static_cast<Derived *
>(
this)->caseIndex(t);
55 .
template Case<FeltType>([
this](
auto t) {
56 return static_cast<Derived *
>(
this)->caseFelt(t);
58 .
template Case<StringType>([
this](
auto t) {
59 return static_cast<Derived *
>(
this)->caseString(t);
61 .
template Case<TypeVarType>([
this](
auto t) {
62 return static_cast<Derived *
>(
this)->caseTypeVar(t);
64 .
template Case<ArrayType>([
this](
auto t) {
65 return static_cast<Derived *
>(
this)->caseArray(t);
67 .
template Case<PodType>([
this](
auto t) {
return static_cast<Derived *
>(
this)->casePod(t); })
68 .
template Case<StructType>([
this](
auto t) {
69 return static_cast<Derived *
>(
this)->caseStruct(t);
70 }).Default([
this](Type t) {
71 if (t.isSignlessInteger(1)) {
74 return static_cast<Derived *
>(
this)->caseInvalid(t);
84void BuildShortTypeString::appendSymName(StringRef str) {
92void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
93 appendSymName(sa.getRootReference().getValue());
94 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
96 appendSymName(nestedRef.getValue());
101 size_t position = ret.size();
104 struct Impl : LLZKTypeSwitch<Impl, void> {
105 BuildShortTypeString &outer;
106 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
108 void caseInvalid(Type) { outer.ss <<
"!INVALID"; }
109 void caseBool(IntegerType) { outer.ss <<
'b'; }
110 void caseIndex(IndexType) { outer.ss <<
'i'; }
111 void caseFelt(FeltType) { outer.ss <<
'f'; }
112 void caseString(StringType) { outer.ss <<
's'; }
113 void caseTypeVar(TypeVarType t) {
115 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
118 void caseArray(ArrayType t) {
120 outer.append(t.getElementType());
122 outer.append(t.getDimensionSizes());
125 void casePod(PodType t) {
127 for (
auto record : t.getRecords()) {
128 outer.appendSymRef(record.getNameSym());
132 void caseStruct(StructType t) {
134 outer.appendSymRef(t.getNameRef());
135 if (ArrayAttr params = t.getParams()) {
137 outer.append(params.getValue());
142 Impl(*this).match(type);
145 ret.find(PLACEHOLDER, position) == std::string::npos &&
146 "formatting a Type should not produce the 'PLACEHOLDER' char"
158 size_t position = ret.size();
162 if (
auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
163 Type ty = ia.getType();
164 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
165 ia.getValue().print(ss, !isUnsigned);
166 }
else if (
auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
168 }
else if (
auto ta = llvm::dyn_cast<TypeAttr>(a)) {
169 append(ta.getValue());
170 }
else if (
auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
173 filtered_raw_ostream fs(ss, [](
char c) {
return c ==
' '; });
174 ama.getValue().print(fs);
177 }
else if (
auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
178 append(aa.getValue());
184 ret.find(PLACEHOLDER, position) == std::string::npos &&
185 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
191 llvm::interleave(attrs, ss, [
this](Attribute a) { append(a); },
"_");
196 BuildShortTypeString bldr;
198 bldr.ret.reserve(base.size() + attrs.size());
201 const auto *END = attrs.end();
202 const auto *IT = attrs.begin();
205 for (
size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
207 bldr.ret.append(base, start, pos - start);
209 assert(IT != END &&
"must have an Attribute for every 'PLACEHOLDER' char");
213 bldr.ret.append(base, start, base.size() - start);
219 bldr.append(ArrayRef(IT, END));
227template <
typename... Types>
class TypeList {
230 template <
typename StreamType>
struct Appender {
233 template <
typename Ty>
static inline void append(StreamType &stream) {
234 stream <<
'\'' << Ty::name <<
'\'';
238 template <
typename First,
typename Second,
typename... Rest>
239 static void append(StreamType &stream) {
240 append<First>(stream);
242 append<Second, Rest...>(stream);
246 static inline void append(StreamType &stream) {
248 append<Types...>(stream);
255 template <
typename T>
static inline bool matches(
const T &value) {
256 return llvm::isa_and_present<Types...>(value);
259 static void reportInvalid(
EmitErrorFn emitError,
const Twine &foundName,
const char *aspect) {
260 InFlightDiagnosticWrapper diag = emitError().append(aspect,
" must be one of ");
261 Appender<InFlightDiagnosticWrapper>::append(diag);
262 diag.append(
" but found '", foundName,
'\'').report();
265 static inline void reportInvalid(
EmitErrorFn emitError, Attribute found,
const char *aspect) {
267 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() :
"nullptr", aspect);
272 static inline std::string
getNames() {
279template <
class... Ts>
struct make_unique {
280 using type = TypeList<Ts...>;
283template <
class... Ts>
struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
285template <
class U,
class... Us,
class... Ts>
286struct make_unique<TypeList<U, Us...>, Ts...>
287 : std::conditional_t<
288 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
289 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
291template <
class... Ts>
using TypeListUnion =
typename make_unique<Ts...>::type;
297using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
305using StructParamTypes =
306 TypeList<IntegerAttr, FeltConstAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
309 struct ColumnCheckData {
310 SymbolTableCollection *symbolTable =
nullptr;
311 Operation *op =
nullptr;
314 bool no_felt : 1 =
false;
315 bool no_string : 1 =
false;
316 bool no_struct : 1 =
false;
317 bool no_array : 1 =
false;
318 bool no_pod : 1 =
false;
319 bool no_var : 1 =
false;
320 bool no_int : 1 =
false;
321 bool no_struct_params : 1 =
false;
322 bool must_be_column : 1 =
false;
324 ColumnCheckData columnCheck;
329 bool validColumns(StructType s) {
330 if (!must_be_column) {
333 assert(columnCheck.symbolTable);
334 assert(columnCheck.op);
335 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
339 constexpr AllowedTypes &noFelt() {
344 constexpr AllowedTypes &noString() {
349 constexpr AllowedTypes &noStruct() {
354 constexpr AllowedTypes &noArray() {
359 constexpr AllowedTypes &noPod() {
364 constexpr AllowedTypes &noVar() {
369 constexpr AllowedTypes &noInt() {
374 constexpr AllowedTypes &noStructParams(
bool noStructParams =
true) {
375 no_struct_params = noStructParams;
379 constexpr AllowedTypes &onlyInt() {
381 return noFelt().noString().noStruct().noArray().noPod().noVar();
384 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
385 must_be_column =
true;
386 columnCheck.symbolTable = &symbolTable;
392 bool isValidTypeImpl(Type type);
394 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr) {
396 if (dimensionSizes.empty()) {
398 emitError().append(
"array must have at least one dimension").report();
405 for (Attribute a : dimensionSizes) {
406 if (!ArrayDimensionTypes::matches(a)) {
407 ArrayDimensionTypes::reportInvalid(emitError, a,
"Array dimension");
409 }
else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
410 TypeList<IntegerAttr>::reportInvalid(emitError, a,
"Concrete array dimension");
421 bool isValidArrayElemTypeImpl(Type type) {
423 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
426 bool isValidArrayTypeImpl(
427 Type elementType, ArrayRef<Attribute> dimensionSizes,
EmitErrorFn emitError =
nullptr
429 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
434 if (!isValidArrayElemTypeImpl(elementType)) {
442 elementType.getAbstractType().getName(),
'\''
452 bool isValidArrayTypeImpl(Type type) {
453 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
454 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
461 bool areValidStructTypeParams(ArrayAttr params,
EmitErrorFn emitError =
nullptr) {
465 if (no_struct_params) {
469 for (Attribute p : params) {
470 if (!StructParamTypes::matches(p)) {
471 StructParamTypes::reportInvalid(emitError, p,
"Struct parameter");
473 }
else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
474 if (!isValidTypeImpl(tyAttr.getValue())) {
476 emitError().append(
"expected a valid LLZK type but found ", tyAttr.getValue()).report();
480 }
else if (no_var && !llvm::isa<IntegerAttr>(p)) {
481 TypeList<IntegerAttr>::reportInvalid(emitError, p,
"Concrete struct parameter");
493 bool areValidPodRecords(ArrayRef<RecordAttr> records) {
494 return llvm::all_of(records, [
this](
auto record) {
return isValidTypeImpl(record.getType()); });
498bool AllowedTypes::isValidTypeImpl(Type type) {
500 !(no_int && no_felt && no_string && no_var && no_struct && no_array && no_pod) &&
501 "All types have been deactivated"
503 struct Impl : LLZKTypeSwitch<Impl, bool> {
505 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
507 bool caseBool(IntegerType t) {
return !outer.no_int && t.isSignlessInteger(1); }
508 bool caseIndex(IndexType) {
return !outer.no_int; }
509 bool caseFelt(FeltType) {
return !outer.no_felt; }
510 bool caseString(StringType) {
return !outer.no_string; }
511 bool caseTypeVar(TypeVarType) {
return !outer.no_var; }
512 bool caseArray(ArrayType t) {
513 return !outer.no_array &&
514 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
516 bool casePod(PodType t) {
return !outer.no_pod && outer.areValidPodRecords(t.getRecords()); }
517 bool caseStruct(StructType t) {
519 if (outer.no_struct || !outer.validColumns(t)) {
522 return !outer.no_struct && outer.areValidStructTypeParams(t.getParams());
524 bool caseInvalid(Type) {
return false; }
526 return Impl(*this).match(type);
531bool isValidType(Type type) {
return AllowedTypes().isValidTypeImpl(type); }
534 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
540 return AllowedTypes().noString().noStruct().isValidTypeImpl(type);
545 return AllowedTypes().noString().noStruct().noArray().noPod().isValidTypeImpl(type);
553 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
557 bool encountered =
false;
558 type.walk([&](AffineMapAttr) {
560 return WalkResult::interrupt();
569 uint64_t caseBool(IntegerType) {
return 1; }
570 uint64_t caseIndex(IndexType) {
return 1; }
571 uint64_t caseFelt(
FeltType) {
return 1; }
573 int64_t n = t.getNumElements();
576 uint64_t caseStruct(
StructType) { llvm_unreachable(
"not a valid EmitEq type"); }
578 return std::accumulate(
580 [](
const uint64_t &acc,
const RecordAttr &record) {
581 return computeEmitEqCardinality(record.getType()) + acc;
585 uint64_t caseString(
StringType) { llvm_unreachable(
"not a valid EmitEq type"); }
586 uint64_t caseTypeVar(
TypeVarType) { llvm_unreachable(
"tvar has unknown cardinality"); }
587 uint64_t caseInvalid(Type) { llvm_unreachable(
"not a valid LLZK type"); }
589 return Impl().match(type);
602using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
605 ArrayRef<StringRef> rhsRevPrefix;
606 UnificationMap *unifications;
607 AffineInstantiations *affineToIntTracker;
610 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
612 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
613 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
614 overrideSuccess(nullptr) {}
616 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
617 this->affineToIntTracker = tracker;
621 UnifierImpl &withOverrides(llvm::function_ref<
bool(Type oldTy, Type newTy)> overrides) {
622 this->overrideSuccess = overrides;
628 template <
typename Iter1,
typename Iter2>
bool typeListsUnify(Iter1 lhs, Iter2 rhs) {
629 return (lhs.size() == rhs.size()) &&
630 std::equal(lhs.begin(), lhs.end(), rhs.begin(), [
this](Type a, Type b) {
631 return this->typesUnify(a, b);
638 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
639 bool unifyDynamicSize =
false
641 auto pred = [
this, unifyDynamicSize](
auto lhsAttr,
auto rhsAttr) {
642 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
644 return (lhsParams.size() == rhsParams.size()) &&
645 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
653 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
bool unifyDynamicSize =
false
655 ArrayRef<Attribute> emptyParams;
657 lhsParams ? lhsParams.getValue() : emptyParams,
658 rhsParams ? rhsParams.getValue() : emptyParams, unifyDynamicSize
664 if (!
typesUnify(lhs.getElementType(), rhs.getElementType())) {
669 lhs.getDimensionSizes(), rhs.getDimensionSizes(),
true
675 llvm::dbgs() <<
"[structTypesUnify] lhs = " << lhs <<
", rhs = " << rhs <<
'\n';
678 SmallVector<StringRef> rhsNames =
getNames(rhs.getNameRef());
679 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
680 auto lhsNames =
getNames(lhs.getNameRef());
681 if (rhsNames != lhsNames) {
683 llvm::interleaveComma(
684 lhsNames, llvm::dbgs() <<
"[structTypesUnify] names do not match\n"
687 llvm::interleaveComma(
688 rhsNames, llvm::dbgs() <<
"]\n"
691 llvm::dbgs() <<
"]\n";
695 LLVM_DEBUG({ llvm::dbgs() <<
"[structTypesUnify] checking unification of parameters\n"; });
702 auto lhsRecords = lhs.getRecords();
703 auto rhsRecords = rhs.getRecords();
705 return lhsRecords.size() == rhsRecords.size() &&
706 llvm::all_of(llvm::zip_equal(lhsRecords, rhsRecords), [
this](
auto &&records) {
707 auto &&[lhsRecord, rhsRecord] = records;
708 return lhsRecord.getName() == rhsRecord.getName() &&
709 typesUnify(lhsRecord.getType(), rhsRecord.getType());
722 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
726 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
727 track(Side::LHS, lhsTvar.getNameRef(), rhs);
730 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
731 track(Side::RHS, rhsTvar.getNameRef(), lhs);
734 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
735 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
737 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
738 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
740 if (llvm::isa<PodType>(lhs) && llvm::isa<PodType>(rhs)) {
741 return podTypesUnify(llvm::cast<PodType>(lhs), llvm::cast<PodType>(rhs));
743 if (llvm::isa<FunctionType>(lhs) && llvm::isa<FunctionType>(rhs)) {
744 return functionTypesUnify(llvm::cast<FunctionType>(lhs), llvm::cast<FunctionType>(rhs));
750 template <
typename Tracker,
typename Key,
typename Val>
751 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
752 auto key = std::make_pair(keyHead, side);
753 auto it = tracker.find(key);
754 if (it == tracker.end()) {
755 tracker.try_emplace(key, val);
756 }
else if (it->getSecond() != val) {
757 it->second =
nullptr;
761 void track(Side side, SymbolRefAttr symRef, Type ty) {
764 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
766 attr = tvar.getNameRef();
769 attr = TypeAttr::get(ty);
773 track(*unifications, side, symRef, attr);
777 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
780 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
781 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
782 attr = tvar.getNameRef();
791 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
792 track(*unifications,
reverse(side), otherSymAttr, symRef);
794 track(*unifications, side, symRef, attr);
798 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
799 if (affineToIntTracker) {
803 track(*affineToIntTracker, side, affineAttr, intAttr);
807 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr,
bool unifyDynamicSize =
false) {
811 if (lhsAttr == rhsAttr) {
816 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
817 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
819 track(Side::LHS, lhsAffine, rhsInt);
824 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
825 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
827 track(Side::RHS, rhsAffine, lhsInt);
834 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
835 track(Side::LHS, lhsSymRef, rhsAttr);
838 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
839 track(Side::RHS, rhsSymRef, lhsAttr);
847 if (unifyDynamicSize) {
848 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
849 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
856 auto is_const_like = [](Attribute attr) {
857 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
859 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
860 if (is_const_like(rhsAttr)) {
864 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
865 if (is_const_like(lhsAttr)) {
871 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
872 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
873 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
884 const ArrayRef<Attribute> &lhsParams,
const ArrayRef<Attribute> &rhsParams,
887 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
893 const ArrayAttr &lhsParams,
const ArrayAttr &rhsParams,
UnificationMap *unifications
895 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
901 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
908 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
914 return UnifierImpl(unifications, rhsReversePrefix).podTypesUnify(lhs, rhs);
918 FunctionType lhs, FunctionType rhs, ArrayRef<StringRef> rhsReversePrefix,
921 return UnifierImpl(unifications, rhsReversePrefix).functionTypesUnify(lhs, rhs);
925 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix,
UnificationMap *unifications
927 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
931 Type oldTy, Type newTy, llvm::function_ref<
bool(Type oldTy, Type newTy)> knownOldToNew
934 AffineInstantiations affineInstantiations;
936 if (!UnifierImpl(&unifications)
937 .trackAffineToInt(&affineInstantiations)
938 .withOverrides(knownOldToNew)
948 auto entryIsRHS = [](
const auto &entry) {
return entry.first.second ==
Side::RHS; };
949 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
953 if (llvm::isa<IndexType>(attr.getType())) {
958 APInt value = attr.getValue();
959 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
961 value = value.zext(IndexType::kInternalStorageBitWidth);
962 }
else if (compare > 0) {
963 return emitError().append(
"value is too large for `index` type: ",
debug::toStringOne(value));
965 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
969 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
975FailureOr<SmallVector<Attribute>>
977 SmallVector<Attribute> result;
978 for (Attribute attr : attrList) {
980 if (failed(forced)) {
983 result.push_back(*forced);
989 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
990 Type attrTy = intAttr.getType();
991 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
994 .append(
"IntegerAttr must have type 'index' or 'i1' but found '", attrTy,
'\'')
1004 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
1005 AffineMap map = affineAttr.getValue();
1006 if (map.getNumResults() != 1) {
1010 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
1022 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
1026 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
1031 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
1036 using TypeVarAttrs = TypeList<SymbolRefAttr>;
1037 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
1038 llvm::report_fatal_error(
1039 "Legal type parameters are inconsistent. Encountered " +
1040 attr.getAbstractAttribute().getName()
1048 size_t numArrDims = dimsFromArr.size();
1050 size_t numSubArrDims = dimsFromSubArr.size();
1052 if (numArrDims < numSubArrDims) {
1053 return emitError().append(
1054 "subarray type ", subArrayType,
" has more dimensions than array type ", arrayType
1058 size_t toDrop = numArrDims - numSubArrDims;
1059 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
1063 std::string message;
1064 llvm::raw_string_ostream ss(message);
1066 ss <<
"cannot unify array dimensions [";
1067 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
1069 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
1071 return emitError().append(message);
1076 return emitError().append(
1077 "incorrect array element type; expected: ", arrayType.
getElementType(),
1087 if (
auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1091 return emitError().append(
1092 "incorrect array element type; expected: ", arrayType.
getElementType(),
1093 ", found: ", subArrayOrElemType
1101 return TypeSwitch<Type, bool>(ty)
1102 .Case<
FeltType>([](
auto) {
return true; })
1103 .Case<ArrayType>([](
auto arrTy) {
1106 .Case<PodType>([](
auto podTy) {
1107 for (
auto record : podTy.getRecords()) {
1113 }).Default([](
auto) {
return false; });
1117 if (
auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
1118 return llvm::isa<FeltType>(arrayParamTy.getElementType());
1120 return llvm::isa<FeltType>(pType);
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
std::string toStringOne(const T &value)
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
Convert an IntegerAttr with a type other than IndexType to use IndexType.
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isFeltOrSimpleFeltAggregate(Type ty)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool podTypesUnify(PodType lhs, PodType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
constexpr T checkedCast(U u) noexcept
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
int64_t fromAPInt(const llvm::APInt &i)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
bool functionTypesUnify(FunctionType lhs, FunctionType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)