LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
18#include "llzk/Util/Debug.h"
22
23#include <llvm/ADT/STLExtras.h>
24#include <llvm/ADT/SmallVector.h>
25#include <llvm/ADT/TypeSwitch.h>
26#include <llvm/Support/Debug.h>
27
28#include <cstdint>
29#include <numeric>
30
31#define DEBUG_TYPE "llzk-type-helpers"
32
33using namespace mlir;
34
35namespace llzk {
36
37using namespace array;
38using namespace component;
39using namespace felt;
40using namespace polymorphic;
41using namespace string;
42using namespace pod;
43
46template <typename Derived, typename ResultType> struct LLZKTypeSwitch {
47 inline ResultType match(Type type) {
48 return llvm::TypeSwitch<Type, ResultType>(type)
49 .template Case<IndexType>([this](auto t) {
50 return static_cast<Derived *>(this)->caseIndex(t);
51 })
52 .template Case<FeltType>([this](auto t) {
53 return static_cast<Derived *>(this)->caseFelt(t);
54 })
55 .template Case<StringType>([this](auto t) {
56 return static_cast<Derived *>(this)->caseString(t);
57 })
58 .template Case<TypeVarType>([this](auto t) {
59 return static_cast<Derived *>(this)->caseTypeVar(t);
60 })
61 .template Case<ArrayType>([this](auto t) {
62 return static_cast<Derived *>(this)->caseArray(t);
63 })
64 .template Case<PodType>([this](auto t) { return static_cast<Derived *>(this)->casePod(t); })
65 .template Case<StructType>([this](auto t) {
66 return static_cast<Derived *>(this)->caseStruct(t);
67 }).Default([this](Type t) {
68 if (t.isSignlessInteger(1)) {
69 return static_cast<Derived *>(this)->caseBool(cast<IntegerType>(t));
70 } else {
71 return static_cast<Derived *>(this)->caseInvalid(t);
72 }
73 });
74 }
75};
76
77void BuildShortTypeString::appendSymName(StringRef str) {
78 if (str.empty()) {
79 ss << '?';
80 } else {
81 ss << '@' << str;
82 }
83}
84
85void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
86 appendSymName(sa.getRootReference().getValue());
87 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
88 ss << "::";
89 appendSymName(nestedRef.getValue());
90 }
91}
92
93BuildShortTypeString &BuildShortTypeString::append(Type type) {
94 size_t position = ret.size();
95 (void)position; // tell compiler it's intentionally unused in builds without assertions
96
97 struct Impl : LLZKTypeSwitch<Impl, void> {
98 BuildShortTypeString &outer;
99 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
100
101 void caseInvalid(Type) { outer.ss << "!INVALID"; }
102 void caseBool(IntegerType) { outer.ss << 'b'; }
103 void caseIndex(IndexType) { outer.ss << 'i'; }
104 void caseFelt(FeltType) { outer.ss << 'f'; }
105 void caseString(StringType) { outer.ss << 's'; }
106 void caseTypeVar(TypeVarType t) {
107 outer.ss << "!t<";
108 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
109 outer.ss << '>';
110 }
111 void caseArray(ArrayType t) {
112 outer.ss << "!a<";
113 outer.append(t.getElementType());
114 outer.ss << ':';
115 outer.append(t.getDimensionSizes());
116 outer.ss << '>';
117 }
118 void casePod(PodType t) {
119 outer.ss << "!r<";
120 for (auto record : t.getRecords()) {
121 outer.appendSymRef(record.getNameSym());
122 }
123 outer.ss << '>';
124 }
125 void caseStruct(StructType t) {
126 outer.ss << "!s<";
127 outer.appendSymRef(t.getNameRef());
128 if (ArrayAttr params = t.getParams()) {
129 outer.ss << '_';
130 outer.append(params.getValue());
131 }
132 outer.ss << '>';
133 }
134 };
135 Impl(*this).match(type);
136
137 assert(
138 ret.find(PLACEHOLDER, position) == std::string::npos &&
139 "formatting a Type should not produce the 'PLACEHOLDER' char"
140 );
141 return *this;
142}
143
144BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
145 // Special case for inserting the `PLACEHOLDER`
146 if (a == nullptr) {
147 ss << PLACEHOLDER;
148 return *this;
149 }
150
151 size_t position = ret.size();
152 (void)position; // tell compiler it's intentionally unused in builds without assertions
153
154 // Adapted from AsmPrinter::Impl::printAttributeImpl()
155 if (auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
156 Type ty = ia.getType();
157 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
158 ia.getValue().print(ss, !isUnsigned);
159 } else if (auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
160 appendSymRef(sra);
161 } else if (auto ta = llvm::dyn_cast<TypeAttr>(a)) {
162 append(ta.getValue());
163 } else if (auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
164 ss << "!m<";
165 // Filter to remove spaces from the affine_map representation
166 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
167 ama.getValue().print(fs);
168 fs.flush();
169 ss << '>';
170 } else if (auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
171 append(aa.getValue());
172 } else {
173 // All valid/legal cases must be covered above
175 }
176 assert(
177 ret.find(PLACEHOLDER, position) == std::string::npos &&
178 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
179 );
180 return *this;
181}
182
183BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
184 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
185 return *this;
186}
187
188std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
189 BuildShortTypeString bldr;
190
191 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
192
193 // First handle replacements of PLACEHOLDER
194 auto END = attrs.end();
195 auto IT = attrs.begin();
196 {
197 size_t start = 0;
198 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
199 // Append original up to the PLACEHOLDER
200 bldr.ret.append(base, start, pos - start);
201 // Append the formatted Attribute
202 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
203 bldr.append(*IT++);
204 }
205 // Append remaining suffix of the original
206 bldr.ret.append(base, start, base.size() - start);
207 }
208
209 // Append any remaining Attributes
210 if (IT != END) {
211 bldr.ss << '_';
212 bldr.append(ArrayRef(IT, END));
213 }
214
215 return bldr.ret;
216}
217
218namespace {
219
220template <typename... Types> class TypeList {
221
223 template <typename StreamType> struct Appender {
224
225 // single
226 template <typename Ty> static inline void append(StreamType &stream) {
227 stream << '\'' << Ty::name << '\'';
228 }
229
230 // multiple
231 template <typename First, typename Second, typename... Rest>
232 static void append(StreamType &stream) {
233 append<First>(stream);
234 stream << ", ";
235 append<Second, Rest...>(stream);
236 }
237
238 // full list with wrapping brackets
239 static inline void append(StreamType &stream) {
240 stream << '[';
241 append<Types...>(stream);
242 stream << ']';
243 }
244 };
245
246public:
247 // Checks if the provided value is an instance of any of `Types`
248 template <typename T> static inline bool matches(const T &value) {
249 return llvm::isa_and_present<Types...>(value);
250 }
251
252 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
253 InFlightDiagnosticWrapper diag = emitError().append(aspect, " must be one of ");
254 Appender<InFlightDiagnosticWrapper>::append(diag);
255 diag.append(" but found '", foundName, '\'').report();
256 }
257
258 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
259 if (emitError) {
260 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
261 }
262 }
263
264 // Returns a comma-separated list formatted string of the names of `Types`
265 static inline std::string getNames() {
266 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
267 }
268};
269
272template <class... Ts> struct make_unique {
273 using type = TypeList<Ts...>;
274};
275
276template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
277
278template <class U, class... Us, class... Ts>
279struct make_unique<TypeList<U, Us...>, Ts...>
280 : std::conditional_t<
281 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
282 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
283
284template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
285
286// Dimensions in the ArrayType must be one of the following:
287// - Integer constants
288// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
289// - AffineMap (for array created within a loop where size depends on loop variable)
290using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
291
292// Parameters in the StructType must be one of the following:
293// - Integer constants
294// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
295// - Type
296// - AffineMap (for array of non-homogeneous structs)
297using StructParamTypes = TypeList<IntegerAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
298
299class AllowedTypes {
300 struct ColumnCheckData {
301 SymbolTableCollection *symbolTable = nullptr;
302 Operation *op = nullptr;
303 };
304
305 bool no_felt : 1 = false;
306 bool no_string : 1 = false;
307 bool no_struct : 1 = false;
308 bool no_array : 1 = false;
309 bool no_pod : 1 = false;
310 bool no_var : 1 = false;
311 bool no_int : 1 = false;
312 bool no_struct_params : 1 = false;
313 bool must_be_column : 1 = false;
314
315 ColumnCheckData columnCheck;
316
320 bool validColumns(StructType s) {
321 if (!must_be_column) {
322 return true;
323 }
324 assert(columnCheck.symbolTable);
325 assert(columnCheck.op);
326 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
327 }
328
329public:
330 constexpr AllowedTypes &noFelt() {
331 no_felt = true;
332 return *this;
333 }
334
335 constexpr AllowedTypes &noString() {
336 no_string = true;
337 return *this;
338 }
339
340 constexpr AllowedTypes &noStruct() {
341 no_struct = true;
342 return *this;
343 }
344
345 constexpr AllowedTypes &noArray() {
346 no_array = true;
347 return *this;
348 }
349
350 constexpr AllowedTypes &noPod() {
351 no_pod = true;
352 return *this;
353 }
354
355 constexpr AllowedTypes &noVar() {
356 no_var = true;
357 return *this;
358 }
359
360 constexpr AllowedTypes &noInt() {
361 no_int = true;
362 return *this;
363 }
364
365 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
366 no_struct_params = noStructParams;
367 return *this;
368 }
369
370 constexpr AllowedTypes &onlyInt() {
371 no_int = false;
372 return noFelt().noString().noStruct().noArray().noPod().noVar();
373 }
374
375 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
376 must_be_column = true;
377 columnCheck.symbolTable = &symbolTable;
378 columnCheck.op = op;
379 return *this;
380 }
381
382 // This is the main check for allowed types.
383 bool isValidTypeImpl(Type type);
384
385 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
386 // In LLZK, the number of array dimensions must always be known, i.e., `hasRank()==true`
387 if (dimensionSizes.empty()) {
388 if (emitError) {
389 emitError().append("array must have at least one dimension").report();
390 }
391 return false;
392 }
393 // Rather than immediately returning on failure, we check all dimensions and aggregate to
394 // provide as many errors are possible in a single verifier run.
395 bool success = true;
396 for (Attribute a : dimensionSizes) {
397 if (!ArrayDimensionTypes::matches(a)) {
398 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
399 success = false;
400 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
401 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
402 success = false;
403 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
404 success = false;
405 } else if (failed(verifyIntAttrType(emitError, a))) {
406 success = false;
407 }
408 }
409 return success;
410 }
411
412 bool isValidArrayElemTypeImpl(Type type) {
413 // ArrayType element can be any valid type sans ArrayType itself.
414 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
415 }
416
417 bool isValidArrayTypeImpl(
418 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
419 ) {
420 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
421 return false;
422 }
423
424 // Ensure array element type is valid
425 if (!isValidArrayElemTypeImpl(elementType)) {
426 if (emitError) {
427 // Print proper message if `elementType` is not a valid LLZK type or
428 // if it's simply not the right kind of type for an array element.
429 if (succeeded(checkValidType(emitError, elementType))) {
430 emitError()
431 .append(
432 '\'', ArrayType::name, "' element type cannot be '",
433 elementType.getAbstractType().getName(), '\''
434 )
435 .report();
436 }
437 }
438 return false;
439 }
440 return true;
441 }
442
443 bool isValidArrayTypeImpl(Type type) {
444 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
445 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
446 }
447 return false;
448 }
449
450 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
451 // for the `no_struct_params` flag which requires that `params` is null or empty.
452 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
453 if (isNullOrEmpty(params)) {
454 return true;
455 }
456 if (no_struct_params) {
457 return false;
458 }
459 bool success = true;
460 for (Attribute p : params) {
461 if (!StructParamTypes::matches(p)) {
462 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
463 success = false;
464 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
465 if (!isValidTypeImpl(tyAttr.getValue())) {
466 if (emitError) {
467 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
468 }
469 success = false;
470 }
471 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
472 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
473 success = false;
474 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
475 success = false;
476 } else if (failed(verifyIntAttrType(emitError, p))) {
477 success = false;
478 }
479 }
480
481 return success;
482 }
483
484 bool areValidPodRecords(ArrayRef<RecordAttr> records) {
485 return llvm::all_of(records, [this](auto record) { return isValidTypeImpl(record.getType()); });
486 }
487};
488
489bool AllowedTypes::isValidTypeImpl(Type type) {
490 assert(
491 !(no_int && no_felt && no_string && no_var && no_struct && no_array && no_pod) &&
492 "All types have been deactivated"
493 );
494 struct Impl : LLZKTypeSwitch<Impl, bool> {
495 AllowedTypes &outer;
496 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
497
498 bool caseBool(IntegerType t) { return !outer.no_int && t.isSignlessInteger(1); }
499 bool caseIndex(IndexType) { return !outer.no_int; }
500 bool caseFelt(FeltType) { return !outer.no_felt; }
501 bool caseString(StringType) { return !outer.no_string; }
502 bool caseTypeVar(TypeVarType) { return !outer.no_var; }
503 bool caseArray(ArrayType t) {
504 return !outer.no_array &&
505 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
506 }
507 bool casePod(PodType t) { return !outer.no_pod && outer.areValidPodRecords(t.getRecords()); }
508 bool caseStruct(StructType t) {
509 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
510 if (outer.no_struct || !outer.validColumns(t)) {
511 return false;
512 }
513 return !outer.no_struct && outer.areValidStructTypeParams(t.getParams());
514 }
515 bool caseInvalid(Type _) { return false; }
516 };
517 return Impl(*this).match(type);
518}
519
520} // namespace
521
522bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
523
524bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
525 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
526}
527
528bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
529
530bool isValidEmitEqType(Type type) {
531 return AllowedTypes().noString().noStruct().isValidTypeImpl(type);
532}
533
534// Allowed types must align with StructParamTypes (defined below)
535bool isValidConstReadType(Type type) {
536 return AllowedTypes().noString().noStruct().noArray().isValidTypeImpl(type);
537}
538
539bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
540
541bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
542
543bool isConcreteType(Type type, bool allowStructParams) {
544 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
545}
546
547bool hasAffineMapAttr(Type type) {
548 bool encountered = false;
549 type.walk([&](AffineMapAttr a) {
550 encountered = true;
551 return WalkResult::interrupt();
552 });
553 return encountered;
554}
555
556bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
557
558uint64_t computeEmitEqCardinality(Type type) {
559 struct Impl : LLZKTypeSwitch<Impl, uint64_t> {
560 uint64_t caseBool(IntegerType) { return 1; }
561 uint64_t caseIndex(IndexType) { return 1; }
562 uint64_t caseFelt(FeltType) { return 1; }
563 uint64_t caseArray(ArrayType t) {
564 int64_t n = t.getNumElements();
565 assert(n >= 0);
566 return static_cast<uint64_t>(n) * computeEmitEqCardinality(t.getElementType());
567 }
568 uint64_t caseStruct(StructType t) { llvm_unreachable("not a valid EmitEq type"); }
569 uint64_t casePod(PodType t) {
570 return std::accumulate(
571 t.getRecords().begin(), t.getRecords().end(), 0,
572 [](const uint64_t &acc, const RecordAttr &record) {
573 return computeEmitEqCardinality(record.getType()) + acc;
574 }
575 );
576 }
577 uint64_t caseString(StringType) { llvm_unreachable("not a valid EmitEq type"); }
578 uint64_t caseTypeVar(TypeVarType) { llvm_unreachable("tvar has unknown cardinality"); }
579 uint64_t caseInvalid(Type) { llvm_unreachable("not a valid LLZK type"); }
580 };
581 return Impl().match(type);
582}
583
584namespace {
585
594using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
595
596struct UnifierImpl {
597 ArrayRef<StringRef> rhsRevPrefix;
598 UnificationMap *unifications;
599 AffineInstantiations *affineToIntTracker;
600 // This optional function can be used to provide an exception to the standard unification
601 // rules and return a true/success result when it otherwise may not.
602 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
603
604 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
605 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
606 overrideSuccess(nullptr) {}
607
608 bool typeParamsUnify(
609 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
610 bool unifyDynamicSize = false
611 ) {
612 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
613 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
614 };
615 return (lhsParams.size() == rhsParams.size()) &&
616 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
617 }
618
619 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
620 this->affineToIntTracker = tracker;
621 return *this;
622 }
623
624 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
625 this->overrideSuccess = overrides;
626 return *this;
627 }
628
633 bool typeParamsUnify(
634 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
635 ) {
636 ArrayRef<Attribute> emptyParams;
637 return typeParamsUnify(
638 lhsParams ? lhsParams.getValue() : emptyParams,
639 rhsParams ? rhsParams.getValue() : emptyParams, unifyDynamicSize
640 );
641 }
642
643 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
644 // Check if the element types of the two arrays can unify
645 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
646 return false;
647 }
648 // Check if the dimension size attributes unify between the LHS and RHS
649 return typeParamsUnify(
650 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
651 );
652 }
653
654 bool structTypesUnify(StructType lhs, StructType rhs) {
655 LLVM_DEBUG({
656 llvm::dbgs() << "[structTypesUnify] lhs = " << lhs << ", rhs = " << rhs << '\n';
657 });
658 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
659 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
660 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
661 auto lhsNames = getNames(lhs.getNameRef());
662 if (rhsNames != lhsNames) {
663 LLVM_DEBUG({
664 llvm::interleaveComma(
665 lhsNames, llvm::dbgs() << "[structTypesUnify] names do not match\n"
666 << " lhsNames = ["
667 );
668 llvm::interleaveComma(
669 rhsNames, llvm::dbgs() << "]\n"
670 << " rhsNames = ["
671 );
672 llvm::dbgs() << "]\n";
673 });
674 return false;
675 }
676 LLVM_DEBUG({ llvm::dbgs() << "[structTypesUnify] checking unification of parameters\n"; });
677 // Check if the parameters unify between the LHS and RHS
678 return typeParamsUnify(lhs.getParams(), rhs.getParams(), /*unifyDynamicSize=*/false);
679 }
680
681 bool podTypesUnify(PodType lhs, PodType rhs) {
682 // Same number of records, with the same names in the same order and record types unify.
683 auto lhsRecords = lhs.getRecords();
684 auto rhsRecords = rhs.getRecords();
685
686 return lhsRecords.size() == rhsRecords.size() &&
687 llvm::all_of(llvm::zip_equal(lhsRecords, rhsRecords), [this](auto &&records) {
688 auto &&[lhsRecord, rhsRecord] = records;
689 return lhsRecord.getName() == rhsRecord.getName() &&
690 typesUnify(lhsRecord.getType(), rhsRecord.getType());
691 });
692 }
693
694 bool typesUnify(Type lhs, Type rhs) {
695 if (lhs == rhs) {
696 return true;
697 }
698 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
699 return true;
700 }
701 // A type variable can be any type, thus it unifies with anything.
702 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
703 track(Side::LHS, lhsTvar.getNameRef(), rhs);
704 return true;
705 }
706 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
707 track(Side::RHS, rhsTvar.getNameRef(), lhs);
708 return true;
709 }
710 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
711 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
712 }
713 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
714 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
715 }
716 if (llvm::isa<PodType>(lhs) && llvm::isa<PodType>(rhs)) {
717 return podTypesUnify(llvm::cast<PodType>(lhs), llvm::cast<PodType>(rhs));
718 }
719 return false;
720 }
721
722private:
723 template <typename Tracker, typename Key, typename Val>
724 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
725 auto key = std::make_pair(keyHead, side);
726 auto it = tracker.find(key);
727 if (it == tracker.end()) {
728 tracker.try_emplace(key, val);
729 } else if (it->getSecond() != val) {
730 it->second = nullptr;
731 }
732 }
733
734 void track(Side side, SymbolRefAttr symRef, Type ty) {
735 if (unifications) {
736 Attribute attr;
737 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
738 // If 'ty' is TypeVarType<@S>, just map to @S directly.
739 attr = tvar.getNameRef();
740 } else {
741 // Otherwise wrap as a TypeAttr.
742 attr = TypeAttr::get(ty);
743 }
744 assert(symRef);
745 assert(attr);
746 track(*unifications, side, symRef, attr);
747 }
748 }
749
750 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
751 if (unifications) {
752 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
753 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
754 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
755 attr = tvar.getNameRef();
756 }
757 }
758 assert(symRef);
759 assert(attr);
760 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
761 // `isMoreConcreteUnification()` which relies on RHS check while other external
762 // checks on the UnificationMap may do LHS checks, and in the case of both being
763 // SymbolRefAttr, unification in either direction is possible.
764 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
765 track(*unifications, reverse(side), otherSymAttr, symRef);
766 }
767 track(*unifications, side, symRef, attr);
768 }
769 }
770
771 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
772 if (affineToIntTracker) {
773 assert(affineAttr);
774 assert(intAttr);
775 assert(!isDynamic(intAttr));
776 track(*affineToIntTracker, side, affineAttr, intAttr);
777 }
778 }
779
780 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
783 // Straightforward equality check.
784 if (lhsAttr == rhsAttr) {
785 return true;
786 }
787 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
788 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
789 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
790 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
791 if (!isDynamic(rhsInt)) {
792 track(Side::LHS, lhsAffine, rhsInt);
793 return true;
794 }
795 }
796 }
797 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
798 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
799 if (!isDynamic(lhsInt)) {
800 track(Side::RHS, rhsAffine, lhsInt);
801 return true;
802 }
803 }
804 }
805 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
806 // a more involved value analysis is required to check if they are actually the same value.
807 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
808 track(Side::LHS, lhsSymRef, rhsAttr);
809 return true;
810 }
811 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
812 track(Side::RHS, rhsSymRef, lhsAttr);
813 return true;
814 }
815 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
816 // NOTE: Dynamic array dimensions (i.e. '?') are allowed in LLZK but should generally be
817 // restricted to scenarios where it can be replaced with a concrete value during the flattening
818 // pass, such as a `unifiable_cast` where the other side of the cast has concrete dimensions or
819 // extern functions with varargs.
820 if (unifyDynamicSize) {
821 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
822 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
823 if (isDynamic(intAttr)) {
824 return intAttr;
825 }
826 }
827 return nullptr;
828 };
829 auto is_const_like = [](Attribute attr) {
830 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
831 };
832 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
833 if (is_const_like(rhsAttr)) {
834 return true;
835 }
836 }
837 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
838 if (is_const_like(lhsAttr)) {
839 return true;
840 }
841 }
842 }
843 // If both are type refs, check for unification of the types.
844 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
845 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
846 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
847 }
848 }
849 // Otherwise, they do not unify.
850 return false;
851 }
852};
853
854} // namespace
855
857 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
858 UnificationMap *unifications
859) {
860 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
861}
862
866 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
867) {
868 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
869}
870
872 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
873) {
874 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
875}
876
878 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
879 UnificationMap *unifications
880) {
881 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
882}
883
885 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
886) {
887 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
888}
889
891 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
892) {
893 UnificationMap unifications;
894 AffineInstantiations affineInstantiations;
895 // Run type unification with the addition that affine map can become integer in the new type.
896 if (!UnifierImpl(&unifications)
897 .trackAffineToInt(&affineInstantiations)
898 .withOverrides(knownOldToNew)
899 .typesUnify(oldTy, newTy)) {
900 return false;
901 }
902
903 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
904 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
905 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
906 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
907 // where the old type contains an IntegerAttr.
908 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
909 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
910}
911
912FailureOr<IntegerAttr> forceIntType(IntegerAttr attr, EmitErrorFn emitError) {
913 if (llvm::isa<IndexType>(attr.getType())) {
914 return attr;
915 }
916 // Ensure the APInt is the right bitwidth for IndexType or else
917 // IntegerAttr::verify(..) will report an error.
918 APInt value = attr.getValue();
919 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
920 if (compare < 0) {
921 value = value.zext(IndexType::kInternalStorageBitWidth);
922 } else if (compare > 0) {
923 return emitError().append("value is too large for `index` type: ", debug::toStringOne(value));
924 }
925 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
926}
927
928FailureOr<Attribute> forceIntAttrType(Attribute attr, EmitErrorFn emitError) {
929 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
930 return forceIntType(intAttr, emitError);
931 }
932 return attr;
933}
934
935FailureOr<SmallVector<Attribute>>
936forceIntAttrTypes(ArrayRef<Attribute> attrList, EmitErrorFn emitError) {
937 SmallVector<Attribute> result;
938 for (Attribute attr : attrList) {
939 FailureOr<Attribute> forced = forceIntAttrType(attr, emitError);
940 if (failed(forced)) {
941 return failure();
942 }
943 result.push_back(*forced);
944 }
945 return result;
946}
947
948LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
949 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
950 Type attrTy = intAttr.getType();
951 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
952 if (emitError) {
953 emitError()
954 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, '\'')
955 .report();
956 }
957 return failure();
958 }
959 }
960 return success();
961}
962
963LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
964 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
965 AffineMap map = affineAttr.getValue();
966 if (map.getNumResults() != 1) {
967 if (emitError) {
968 emitError()
969 .append(
970 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
971 " results"
972 )
973 .report();
974 }
975 return failure();
976 }
977 }
978 return success();
979}
980
981LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
982 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
983}
984
985LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
986 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
987}
988
989LogicalResult
990verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
991 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
992}
993
994void assertValidAttrForParamOfType(Attribute attr) {
995 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
996 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
997 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
998 llvm::report_fatal_error(
999 "Legal type parameters are inconsistent. Encountered " +
1000 attr.getAbstractAttribute().getName()
1001 );
1002 }
1003}
1004
1005LogicalResult
1006verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType) {
1007 ArrayRef<Attribute> dimsFromArr = arrayType.getDimensionSizes();
1008 size_t numArrDims = dimsFromArr.size();
1009 ArrayRef<Attribute> dimsFromSubArr = subArrayType.getDimensionSizes();
1010 size_t numSubArrDims = dimsFromSubArr.size();
1011
1012 if (numArrDims < numSubArrDims) {
1013 return emitError().append(
1014 "subarray type ", subArrayType, " has more dimensions than array type ", arrayType
1015 );
1016 }
1017
1018 size_t toDrop = numArrDims - numSubArrDims;
1019 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
1020
1021 // Ensure dimension sizes are compatible (ignoring the indexed dimensions)
1022 if (!typeParamsUnify(dimsFromArrReduced, dimsFromSubArr)) {
1023 std::string message;
1024 llvm::raw_string_ostream ss(message);
1025 auto appendOne = [&ss](Attribute a) { appendWithoutType(ss, a); };
1026 ss << "cannot unify array dimensions [";
1027 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
1028 ss << "] with [";
1029 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
1030 ss << "]";
1031 return emitError().append(message);
1032 }
1033
1034 // Ensure element types of the arrays are compatible
1035 if (!typesUnify(arrayType.getElementType(), subArrayType.getElementType())) {
1036 return emitError().append(
1037 "incorrect array element type; expected: ", arrayType.getElementType(),
1038 ", found: ", subArrayType.getElementType()
1039 );
1040 }
1041
1042 return success();
1043}
1044
1045LogicalResult
1046verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType) {
1047 if (auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1048 return verifySubArrayType(emitError, arrayType, subArrayType);
1049 }
1050 if (!typesUnify(arrayType.getElementType(), subArrayOrElemType)) {
1051 return emitError().append(
1052 "incorrect array element type; expected: ", arrayType.getElementType(),
1053 ", found: ", subArrayOrElemType
1054 );
1055 }
1056
1057 return success();
1058}
1059
1060} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:36
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
std::string toStringOne(const T &value)
Definition Debug.h:175
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:183
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:141
int64_t fromAPInt(const llvm::APInt &i)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:111
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)