LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeHelper.cpp
Go to the documentation of this file.
1//===-- TypeHelper.cpp ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
21#include "llzk/Util/Compare.h"
22#include "llzk/Util/Debug.h"
25
26#include <llvm/ADT/STLExtras.h>
27#include <llvm/ADT/SmallVector.h>
28#include <llvm/ADT/TypeSwitch.h>
29#include <llvm/Support/Debug.h>
30
31#include <cstdint>
32#include <numeric>
33
34#define DEBUG_TYPE "llzk-type-helpers"
35
36using namespace mlir;
37
38namespace llzk {
39
40using namespace array;
41using namespace component;
42using namespace felt;
43using namespace polymorphic;
44using namespace string;
45using namespace pod;
46
49template <typename Derived, typename ResultType> struct LLZKTypeSwitch {
50 inline ResultType match(Type type) {
51 return llvm::TypeSwitch<Type, ResultType>(type)
52 .template Case<IndexType>([this](auto t) {
53 return static_cast<Derived *>(this)->caseIndex(t);
54 })
55 .template Case<FeltType>([this](auto t) {
56 return static_cast<Derived *>(this)->caseFelt(t);
57 })
58 .template Case<StringType>([this](auto t) {
59 return static_cast<Derived *>(this)->caseString(t);
60 })
61 .template Case<TypeVarType>([this](auto t) {
62 return static_cast<Derived *>(this)->caseTypeVar(t);
63 })
64 .template Case<ArrayType>([this](auto t) {
65 return static_cast<Derived *>(this)->caseArray(t);
66 })
67 .template Case<PodType>([this](auto t) { return static_cast<Derived *>(this)->casePod(t); })
68 .template Case<StructType>([this](auto t) {
69 return static_cast<Derived *>(this)->caseStruct(t);
70 }).Default([this](Type t) {
71 if (t.isSignlessInteger(1)) {
72 return static_cast<Derived *>(this)->caseBool(cast<IntegerType>(t));
73 } else {
74 return static_cast<Derived *>(this)->caseInvalid(t);
75 }
76 });
77 }
78
79private:
80 friend Derived;
81 LLZKTypeSwitch() = default;
82};
83
84void BuildShortTypeString::appendSymName(StringRef str) {
85 if (str.empty()) {
86 ss << '?';
87 } else {
88 ss << '@' << str;
89 }
90}
91
92void BuildShortTypeString::appendSymRef(SymbolRefAttr sa) {
93 appendSymName(sa.getRootReference().getValue());
94 for (FlatSymbolRefAttr nestedRef : sa.getNestedReferences()) {
95 ss << "::";
96 appendSymName(nestedRef.getValue());
97 }
98}
99
100BuildShortTypeString &BuildShortTypeString::append(Type type) {
101 size_t position = ret.size();
102 (void)position; // tell compiler it's intentionally unused in builds without assertions
103
104 struct Impl : LLZKTypeSwitch<Impl, void> {
105 BuildShortTypeString &outer;
106 Impl(BuildShortTypeString &outerRef) : outer(outerRef) {}
107
108 void caseInvalid(Type) { outer.ss << "!INVALID"; }
109 void caseBool(IntegerType) { outer.ss << 'b'; }
110 void caseIndex(IndexType) { outer.ss << 'i'; }
111 void caseFelt(FeltType) { outer.ss << 'f'; }
112 void caseString(StringType) { outer.ss << 's'; }
113 void caseTypeVar(TypeVarType t) {
114 outer.ss << "!t<";
115 outer.appendSymName(llvm::cast<TypeVarType>(t).getRefName());
116 outer.ss << '>';
117 }
118 void caseArray(ArrayType t) {
119 outer.ss << "!a<";
120 outer.append(t.getElementType());
121 outer.ss << ':';
122 outer.append(t.getDimensionSizes());
123 outer.ss << '>';
124 }
125 void casePod(PodType t) {
126 outer.ss << "!r<";
127 for (auto record : t.getRecords()) {
128 outer.appendSymRef(record.getNameSym());
129 }
130 outer.ss << '>';
131 }
132 void caseStruct(StructType t) {
133 outer.ss << "!s<";
134 outer.appendSymRef(t.getNameRef());
135 if (ArrayAttr params = t.getParams()) {
136 outer.ss << '_';
137 outer.append(params.getValue());
138 }
139 outer.ss << '>';
140 }
141 };
142 Impl(*this).match(type);
143
144 assert(
145 ret.find(PLACEHOLDER, position) == std::string::npos &&
146 "formatting a Type should not produce the 'PLACEHOLDER' char"
147 );
148 return *this;
149}
150
151BuildShortTypeString &BuildShortTypeString::append(Attribute a) {
152 // Special case for inserting the `PLACEHOLDER`
153 if (a == nullptr) {
154 ss << PLACEHOLDER;
155 return *this;
156 }
157
158 size_t position = ret.size();
159 (void)position; // tell compiler it's intentionally unused in builds without assertions
160
161 // Adapted from AsmPrinter::Impl::printAttributeImpl()
162 if (auto ia = llvm::dyn_cast<IntegerAttr>(a)) {
163 Type ty = ia.getType();
164 bool isUnsigned = ty.isUnsignedInteger() || ty.isSignlessInteger(1);
165 ia.getValue().print(ss, !isUnsigned);
166 } else if (auto sra = llvm::dyn_cast<SymbolRefAttr>(a)) {
167 appendSymRef(sra);
168 } else if (auto ta = llvm::dyn_cast<TypeAttr>(a)) {
169 append(ta.getValue());
170 } else if (auto ama = llvm::dyn_cast<AffineMapAttr>(a)) {
171 ss << "!m<";
172 // Filter to remove spaces from the affine_map representation
173 filtered_raw_ostream fs(ss, [](char c) { return c == ' '; });
174 ama.getValue().print(fs);
175 fs.flush();
176 ss << '>';
177 } else if (auto aa = llvm::dyn_cast<ArrayAttr>(a)) {
178 append(aa.getValue());
179 } else {
180 // All valid/legal cases must be covered above
182 }
183 assert(
184 ret.find(PLACEHOLDER, position) == std::string::npos &&
185 "formatting a non-null Attribute should not produce the 'PLACEHOLDER' char"
186 );
187 return *this;
188}
189
190BuildShortTypeString &BuildShortTypeString::append(ArrayRef<Attribute> attrs) {
191 llvm::interleave(attrs, ss, [this](Attribute a) { append(a); }, "_");
192 return *this;
193}
194
195std::string BuildShortTypeString::from(const std::string &base, ArrayRef<Attribute> attrs) {
196 BuildShortTypeString bldr;
197
198 bldr.ret.reserve(base.size() + attrs.size()); // reserve minimum space required
199
200 // First handle replacements of PLACEHOLDER
201 const auto *END = attrs.end();
202 const auto *IT = attrs.begin();
203 {
204 size_t start = 0;
205 for (size_t pos; (pos = base.find(PLACEHOLDER, start)) != std::string::npos; start = pos + 1) {
206 // Append original up to the PLACEHOLDER
207 bldr.ret.append(base, start, pos - start);
208 // Append the formatted Attribute
209 assert(IT != END && "must have an Attribute for every 'PLACEHOLDER' char");
210 bldr.append(*IT++);
211 }
212 // Append remaining suffix of the original
213 bldr.ret.append(base, start, base.size() - start);
214 }
215
216 // Append any remaining Attributes
217 if (IT != END) {
218 bldr.ss << '_';
219 bldr.append(ArrayRef(IT, END));
220 }
221
222 return bldr.ret;
223}
224
225namespace {
226
227template <typename... Types> class TypeList {
228
230 template <typename StreamType> struct Appender {
231
232 // single
233 template <typename Ty> static inline void append(StreamType &stream) {
234 stream << '\'' << Ty::name << '\'';
235 }
236
237 // multiple
238 template <typename First, typename Second, typename... Rest>
239 static void append(StreamType &stream) {
240 append<First>(stream);
241 stream << ", ";
242 append<Second, Rest...>(stream);
243 }
244
245 // full list with wrapping brackets
246 static inline void append(StreamType &stream) {
247 stream << '[';
248 append<Types...>(stream);
249 stream << ']';
250 }
251 };
252
253public:
254 // Checks if the provided value is an instance of any of `Types`
255 template <typename T> static inline bool matches(const T &value) {
256 return llvm::isa_and_present<Types...>(value);
257 }
258
259 static void reportInvalid(EmitErrorFn emitError, const Twine &foundName, const char *aspect) {
260 InFlightDiagnosticWrapper diag = emitError().append(aspect, " must be one of ");
261 Appender<InFlightDiagnosticWrapper>::append(diag);
262 diag.append(" but found '", foundName, '\'').report();
263 }
264
265 static inline void reportInvalid(EmitErrorFn emitError, Attribute found, const char *aspect) {
266 if (emitError) {
267 reportInvalid(emitError, found ? found.getAbstractAttribute().getName() : "nullptr", aspect);
268 }
269 }
270
271 // Returns a comma-separated list formatted string of the names of `Types`
272 static inline std::string getNames() {
273 return buildStringViaCallback(Appender<llvm::raw_string_ostream>::append);
274 }
275};
276
279template <class... Ts> struct make_unique {
280 using type = TypeList<Ts...>;
281};
282
283template <class... Ts> struct make_unique<TypeList<>, Ts...> : make_unique<Ts...> {};
284
285template <class U, class... Us, class... Ts>
286struct make_unique<TypeList<U, Us...>, Ts...>
287 : std::conditional_t<
288 (std::is_same_v<U, Us> || ...) || (std::is_same_v<U, Ts> || ...),
289 make_unique<TypeList<Us...>, Ts...>, make_unique<TypeList<Us...>, Ts..., U>> {};
290
291template <class... Ts> using TypeListUnion = typename make_unique<Ts...>::type;
292
293// Dimensions in the ArrayType must be one of the following:
294// - Integer constants
295// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
296// - AffineMap (for array created within a loop where size depends on loop variable)
297using ArrayDimensionTypes = TypeList<IntegerAttr, SymbolRefAttr, AffineMapAttr>;
298
299// Parameters in the StructType must be one of the following:
300// - Integer constants
301// - Field element constants
302// - SymbolRef (flat ref for struct params, non-flat for global constants from another module)
303// - Type
304// - AffineMap (for array of non-homogeneous structs)
305using StructParamTypes =
306 TypeList<IntegerAttr, FeltConstAttr, SymbolRefAttr, TypeAttr, AffineMapAttr>;
307
308class AllowedTypes {
309 struct ColumnCheckData {
310 SymbolTableCollection *symbolTable = nullptr;
311 Operation *op = nullptr;
312 };
313
314 bool no_felt : 1 = false;
315 bool no_string : 1 = false;
316 bool no_struct : 1 = false;
317 bool no_array : 1 = false;
318 bool no_pod : 1 = false;
319 bool no_var : 1 = false;
320 bool no_int : 1 = false;
321 bool no_struct_params : 1 = false;
322 bool must_be_column : 1 = false;
323
324 ColumnCheckData columnCheck;
325
329 bool validColumns(StructType s) {
330 if (!must_be_column) {
331 return true;
332 }
333 assert(columnCheck.symbolTable);
334 assert(columnCheck.op);
335 return succeeded(s.hasColumns(*columnCheck.symbolTable, columnCheck.op));
336 }
337
338public:
339 constexpr AllowedTypes &noFelt() {
340 no_felt = true;
341 return *this;
342 }
343
344 constexpr AllowedTypes &noString() {
345 no_string = true;
346 return *this;
347 }
348
349 constexpr AllowedTypes &noStruct() {
350 no_struct = true;
351 return *this;
352 }
353
354 constexpr AllowedTypes &noArray() {
355 no_array = true;
356 return *this;
357 }
358
359 constexpr AllowedTypes &noPod() {
360 no_pod = true;
361 return *this;
362 }
363
364 constexpr AllowedTypes &noVar() {
365 no_var = true;
366 return *this;
367 }
368
369 constexpr AllowedTypes &noInt() {
370 no_int = true;
371 return *this;
372 }
373
374 constexpr AllowedTypes &noStructParams(bool noStructParams = true) {
375 no_struct_params = noStructParams;
376 return *this;
377 }
378
379 constexpr AllowedTypes &onlyInt() {
380 no_int = false;
381 return noFelt().noString().noStruct().noArray().noPod().noVar();
382 }
383
384 constexpr AllowedTypes &mustBeColumn(SymbolTableCollection &symbolTable, Operation *op) {
385 must_be_column = true;
386 columnCheck.symbolTable = &symbolTable;
387 columnCheck.op = op;
388 return *this;
389 }
390
391 // This is the main check for allowed types.
392 bool isValidTypeImpl(Type type);
393
394 bool areValidArrayDimSizes(ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr) {
395 // In LLZK, the number of array dimensions must always be known, i.e., `hasRank()==true`
396 if (dimensionSizes.empty()) {
397 if (emitError) {
398 emitError().append("array must have at least one dimension").report();
399 }
400 return false;
401 }
402 // Rather than immediately returning on failure, we check all dimensions and aggregate to
403 // provide as many errors are possible in a single verifier run.
404 bool success = true;
405 for (Attribute a : dimensionSizes) {
406 if (!ArrayDimensionTypes::matches(a)) {
407 ArrayDimensionTypes::reportInvalid(emitError, a, "Array dimension");
408 success = false;
409 } else if (no_var && !llvm::isa_and_present<IntegerAttr>(a)) {
410 TypeList<IntegerAttr>::reportInvalid(emitError, a, "Concrete array dimension");
411 success = false;
412 } else if (failed(verifyAffineMapAttrType(emitError, a))) {
413 success = false;
414 } else if (failed(verifyIntAttrType(emitError, a))) {
415 success = false;
416 }
417 }
418 return success;
419 }
420
421 bool isValidArrayElemTypeImpl(Type type) {
422 // ArrayType element can be any valid type sans ArrayType itself.
423 return !llvm::isa<ArrayType>(type) && isValidTypeImpl(type);
424 }
425
426 bool isValidArrayTypeImpl(
427 Type elementType, ArrayRef<Attribute> dimensionSizes, EmitErrorFn emitError = nullptr
428 ) {
429 if (!areValidArrayDimSizes(dimensionSizes, emitError)) {
430 return false;
431 }
432
433 // Ensure array element type is valid
434 if (!isValidArrayElemTypeImpl(elementType)) {
435 if (emitError) {
436 // Print proper message if `elementType` is not a valid LLZK type or
437 // if it's simply not the right kind of type for an array element.
438 if (succeeded(checkValidType(emitError, elementType))) {
439 emitError()
440 .append(
441 '\'', ArrayType::name, "' element type cannot be '",
442 elementType.getAbstractType().getName(), '\''
443 )
444 .report();
445 }
446 }
447 return false;
448 }
449 return true;
450 }
451
452 bool isValidArrayTypeImpl(Type type) {
453 if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
454 return isValidArrayTypeImpl(arrTy.getElementType(), arrTy.getDimensionSizes());
455 }
456 return false;
457 }
458
459 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter (if any) except
460 // for the `no_struct_params` flag which requires that `params` is null or empty.
461 bool areValidStructTypeParams(ArrayAttr params, EmitErrorFn emitError = nullptr) {
462 if (isNullOrEmpty(params)) {
463 return true;
464 }
465 if (no_struct_params) {
466 return false;
467 }
468 bool success = true;
469 for (Attribute p : params) {
470 if (!StructParamTypes::matches(p)) {
471 StructParamTypes::reportInvalid(emitError, p, "Struct parameter");
472 success = false;
473 } else if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(p)) {
474 if (!isValidTypeImpl(tyAttr.getValue())) {
475 if (emitError) {
476 emitError().append("expected a valid LLZK type but found ", tyAttr.getValue()).report();
477 }
478 success = false;
479 }
480 } else if (no_var && !llvm::isa<IntegerAttr>(p)) {
481 TypeList<IntegerAttr>::reportInvalid(emitError, p, "Concrete struct parameter");
482 success = false;
483 } else if (failed(verifyAffineMapAttrType(emitError, p))) {
484 success = false;
485 } else if (failed(verifyIntAttrType(emitError, p))) {
486 success = false;
487 }
488 }
489
490 return success;
491 }
492
493 bool areValidPodRecords(ArrayRef<RecordAttr> records) {
494 return llvm::all_of(records, [this](auto record) { return isValidTypeImpl(record.getType()); });
495 }
496};
497
498bool AllowedTypes::isValidTypeImpl(Type type) {
499 assert(
500 !(no_int && no_felt && no_string && no_var && no_struct && no_array && no_pod) &&
501 "All types have been deactivated"
502 );
503 struct Impl : LLZKTypeSwitch<Impl, bool> {
504 AllowedTypes &outer;
505 Impl(AllowedTypes &outerRef) : outer(outerRef) {}
506
507 bool caseBool(IntegerType t) { return !outer.no_int && t.isSignlessInteger(1); }
508 bool caseIndex(IndexType) { return !outer.no_int; }
509 bool caseFelt(FeltType) { return !outer.no_felt; }
510 bool caseString(StringType) { return !outer.no_string; }
511 bool caseTypeVar(TypeVarType) { return !outer.no_var; }
512 bool caseArray(ArrayType t) {
513 return !outer.no_array &&
514 outer.isValidArrayTypeImpl(t.getElementType(), t.getDimensionSizes());
515 }
516 bool casePod(PodType t) { return !outer.no_pod && outer.areValidPodRecords(t.getRecords()); }
517 bool caseStruct(StructType t) {
518 // Note: The `no*` flags here refer to Types nested within a TypeAttr parameter.
519 if (outer.no_struct || !outer.validColumns(t)) {
520 return false;
521 }
522 return !outer.no_struct && outer.areValidStructTypeParams(t.getParams());
523 }
524 bool caseInvalid(Type) { return false; }
525 };
526 return Impl(*this).match(type);
527}
528
529} // namespace
530
531bool isValidType(Type type) { return AllowedTypes().isValidTypeImpl(type); }
532
533bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op) {
534 return AllowedTypes().noString().noInt().mustBeColumn(symbolTable, op).isValidTypeImpl(type);
535}
536
537bool isValidGlobalType(Type type) { return AllowedTypes().noVar().isValidTypeImpl(type); }
538
539bool isValidEmitEqType(Type type) {
540 return AllowedTypes().noString().noStruct().isValidTypeImpl(type);
541}
542
543// Allowed types must be a subset of StructParamTypes (defined below)
544bool isValidConstReadType(Type type) {
545 return AllowedTypes().noString().noStruct().noArray().noPod().isValidTypeImpl(type);
546}
547
548bool isValidArrayElemType(Type type) { return AllowedTypes().isValidArrayElemTypeImpl(type); }
549
550bool isValidArrayType(Type type) { return AllowedTypes().isValidArrayTypeImpl(type); }
551
552bool isConcreteType(Type type, bool allowStructParams) {
553 return AllowedTypes().noVar().noStructParams(!allowStructParams).isValidTypeImpl(type);
554}
555
556bool hasAffineMapAttr(Type type) {
557 bool encountered = false;
558 type.walk([&](AffineMapAttr) {
559 encountered = true;
560 return WalkResult::interrupt();
561 });
562 return encountered;
563}
564
565bool isDynamic(IntegerAttr intAttr) { return ShapedType::isDynamic(fromAPInt(intAttr.getValue())); }
566
567uint64_t computeEmitEqCardinality(Type type) {
568 struct Impl : LLZKTypeSwitch<Impl, uint64_t> {
569 uint64_t caseBool(IntegerType) { return 1; }
570 uint64_t caseIndex(IndexType) { return 1; }
571 uint64_t caseFelt(FeltType) { return 1; }
572 uint64_t caseArray(ArrayType t) {
573 int64_t n = t.getNumElements();
575 }
576 uint64_t caseStruct(StructType) { llvm_unreachable("not a valid EmitEq type"); }
577 uint64_t casePod(PodType t) {
578 return std::accumulate(
579 t.getRecords().begin(), t.getRecords().end(), 0,
580 [](const uint64_t &acc, const RecordAttr &record) {
581 return computeEmitEqCardinality(record.getType()) + acc;
582 }
583 );
584 }
585 uint64_t caseString(StringType) { llvm_unreachable("not a valid EmitEq type"); }
586 uint64_t caseTypeVar(TypeVarType) { llvm_unreachable("tvar has unknown cardinality"); }
587 uint64_t caseInvalid(Type) { llvm_unreachable("not a valid LLZK type"); }
588 };
589 return Impl().match(type);
590}
591
592namespace {
593
602using AffineInstantiations = DenseMap<std::pair<AffineMapAttr, Side>, IntegerAttr>;
603
604struct UnifierImpl {
605 ArrayRef<StringRef> rhsRevPrefix;
606 UnificationMap *unifications;
607 AffineInstantiations *affineToIntTracker;
608 // This optional function can be used to provide an exception to the standard unification
609 // rules and return a true/success result when it otherwise may not.
610 llvm::function_ref<bool(Type oldTy, Type newTy)> overrideSuccess;
611
612 UnifierImpl(UnificationMap *unificationMap, ArrayRef<StringRef> rhsReversePrefix = {})
613 : rhsRevPrefix(rhsReversePrefix), unifications(unificationMap), affineToIntTracker(nullptr),
614 overrideSuccess(nullptr) {}
615
616 UnifierImpl &trackAffineToInt(AffineInstantiations *tracker) {
617 this->affineToIntTracker = tracker;
618 return *this;
619 }
620
621 UnifierImpl &withOverrides(llvm::function_ref<bool(Type oldTy, Type newTy)> overrides) {
622 this->overrideSuccess = overrides;
623 return *this;
624 }
625
628 template <typename Iter1, typename Iter2> bool typeListsUnify(Iter1 lhs, Iter2 rhs) {
629 return (lhs.size() == rhs.size()) &&
630 std::equal(lhs.begin(), lhs.end(), rhs.begin(), [this](Type a, Type b) {
631 return this->typesUnify(a, b);
632 });
633 }
634
637 bool typeParamsUnify(
638 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
639 bool unifyDynamicSize = false
640 ) {
641 auto pred = [this, unifyDynamicSize](auto lhsAttr, auto rhsAttr) {
642 return paramAttrUnify(lhsAttr, rhsAttr, unifyDynamicSize);
643 };
644 return (lhsParams.size() == rhsParams.size()) &&
645 std::equal(lhsParams.begin(), lhsParams.end(), rhsParams.begin(), pred);
646 }
647
652 bool typeParamsUnify(
653 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, bool unifyDynamicSize = false
654 ) {
655 ArrayRef<Attribute> emptyParams;
656 return typeParamsUnify(
657 lhsParams ? lhsParams.getValue() : emptyParams,
658 rhsParams ? rhsParams.getValue() : emptyParams, unifyDynamicSize
659 );
660 }
661
662 bool arrayTypesUnify(ArrayType lhs, ArrayType rhs) {
663 // Check if the element types of the two arrays can unify
664 if (!typesUnify(lhs.getElementType(), rhs.getElementType())) {
665 return false;
666 }
667 // Check if the dimension size attributes unify between the LHS and RHS
668 return typeParamsUnify(
669 lhs.getDimensionSizes(), rhs.getDimensionSizes(), /*unifyDynamicSize=*/true
670 );
671 }
672
673 bool structTypesUnify(StructType lhs, StructType rhs) {
674 LLVM_DEBUG({
675 llvm::dbgs() << "[structTypesUnify] lhs = " << lhs << ", rhs = " << rhs << '\n';
676 });
677 // Check if it references the same StructDefOp, considering the additional RHS path prefix.
678 SmallVector<StringRef> rhsNames = getNames(rhs.getNameRef());
679 rhsNames.insert(rhsNames.begin(), rhsRevPrefix.rbegin(), rhsRevPrefix.rend());
680 auto lhsNames = getNames(lhs.getNameRef());
681 if (rhsNames != lhsNames) {
682 LLVM_DEBUG({
683 llvm::interleaveComma(
684 lhsNames, llvm::dbgs() << "[structTypesUnify] names do not match\n"
685 << " lhsNames = ["
686 );
687 llvm::interleaveComma(
688 rhsNames, llvm::dbgs() << "]\n"
689 << " rhsNames = ["
690 );
691 llvm::dbgs() << "]\n";
692 });
693 return false;
694 }
695 LLVM_DEBUG({ llvm::dbgs() << "[structTypesUnify] checking unification of parameters\n"; });
696 // Check if the parameters unify between the LHS and RHS
697 return typeParamsUnify(lhs.getParams(), rhs.getParams(), /*unifyDynamicSize=*/false);
698 }
699
700 bool podTypesUnify(PodType lhs, PodType rhs) {
701 // Same number of records, with the same names in the same order and record types unify.
702 auto lhsRecords = lhs.getRecords();
703 auto rhsRecords = rhs.getRecords();
704
705 return lhsRecords.size() == rhsRecords.size() &&
706 llvm::all_of(llvm::zip_equal(lhsRecords, rhsRecords), [this](auto &&records) {
707 auto &&[lhsRecord, rhsRecord] = records;
708 return lhsRecord.getName() == rhsRecord.getName() &&
709 typesUnify(lhsRecord.getType(), rhsRecord.getType());
710 });
711 }
712
713 bool functionTypesUnify(FunctionType lhs, FunctionType rhs) {
714 return typeListsUnify(lhs.getInputs(), rhs.getInputs()) &&
715 typeListsUnify(lhs.getResults(), rhs.getResults());
716 }
717
718 bool typesUnify(Type lhs, Type rhs) {
719 if (lhs == rhs) {
720 return true;
721 }
722 if (overrideSuccess && overrideSuccess(lhs, rhs)) {
723 return true;
724 }
725 // A type variable can be any type, thus it unifies with anything.
726 if (TypeVarType lhsTvar = llvm::dyn_cast<TypeVarType>(lhs)) {
727 track(Side::LHS, lhsTvar.getNameRef(), rhs);
728 return true;
729 }
730 if (TypeVarType rhsTvar = llvm::dyn_cast<TypeVarType>(rhs)) {
731 track(Side::RHS, rhsTvar.getNameRef(), lhs);
732 return true;
733 }
734 if (llvm::isa<StructType>(lhs) && llvm::isa<StructType>(rhs)) {
735 return structTypesUnify(llvm::cast<StructType>(lhs), llvm::cast<StructType>(rhs));
736 }
737 if (llvm::isa<ArrayType>(lhs) && llvm::isa<ArrayType>(rhs)) {
738 return arrayTypesUnify(llvm::cast<ArrayType>(lhs), llvm::cast<ArrayType>(rhs));
739 }
740 if (llvm::isa<PodType>(lhs) && llvm::isa<PodType>(rhs)) {
741 return podTypesUnify(llvm::cast<PodType>(lhs), llvm::cast<PodType>(rhs));
742 }
743 if (llvm::isa<FunctionType>(lhs) && llvm::isa<FunctionType>(rhs)) {
744 return functionTypesUnify(llvm::cast<FunctionType>(lhs), llvm::cast<FunctionType>(rhs));
745 }
746 return false;
747 }
748
749private:
750 template <typename Tracker, typename Key, typename Val>
751 inline void track(Tracker &tracker, Side side, Key keyHead, Val val) {
752 auto key = std::make_pair(keyHead, side);
753 auto it = tracker.find(key);
754 if (it == tracker.end()) {
755 tracker.try_emplace(key, val);
756 } else if (it->getSecond() != val) {
757 it->second = nullptr;
758 }
759 }
760
761 void track(Side side, SymbolRefAttr symRef, Type ty) {
762 if (unifications) {
763 Attribute attr;
764 if (TypeVarType tvar = dyn_cast<TypeVarType>(ty)) {
765 // If 'ty' is TypeVarType<@S>, just map to @S directly.
766 attr = tvar.getNameRef();
767 } else {
768 // Otherwise wrap as a TypeAttr.
769 attr = TypeAttr::get(ty);
770 }
771 assert(symRef);
772 assert(attr);
773 track(*unifications, side, symRef, attr);
774 }
775 }
776
777 void track(Side side, SymbolRefAttr symRef, Attribute attr) {
778 if (unifications) {
779 // If 'attr' is TypeAttr<TypeVarType<@S>>, just map to @S directly.
780 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(attr)) {
781 if (TypeVarType tvar = dyn_cast<TypeVarType>(tyAttr.getValue())) {
782 attr = tvar.getNameRef();
783 }
784 }
785 assert(symRef);
786 assert(attr);
787 // If 'attr' is a SymbolRefAttr, map in both directions for the correctness of
788 // `isMoreConcreteUnification()` which relies on RHS check while other external
789 // checks on the UnificationMap may do LHS checks, and in the case of both being
790 // SymbolRefAttr, unification in either direction is possible.
791 if (SymbolRefAttr otherSymAttr = dyn_cast<SymbolRefAttr>(attr)) {
792 track(*unifications, reverse(side), otherSymAttr, symRef);
793 }
794 track(*unifications, side, symRef, attr);
795 }
796 }
797
798 void track(Side side, AffineMapAttr affineAttr, IntegerAttr intAttr) {
799 if (affineToIntTracker) {
800 assert(affineAttr);
801 assert(intAttr);
802 assert(!isDynamic(intAttr));
803 track(*affineToIntTracker, side, affineAttr, intAttr);
804 }
805 }
806
807 bool paramAttrUnify(Attribute lhsAttr, Attribute rhsAttr, bool unifyDynamicSize = false) {
810 // Straightforward equality check.
811 if (lhsAttr == rhsAttr) {
812 return true;
813 }
814 // AffineMapAttr can unify with IntegerAttr (other than kDynamic) because struct parameter
815 // instantiation will result in conversion of AffineMapAttr to IntegerAttr.
816 if (AffineMapAttr lhsAffine = llvm::dyn_cast<AffineMapAttr>(lhsAttr)) {
817 if (IntegerAttr rhsInt = llvm::dyn_cast<IntegerAttr>(rhsAttr)) {
818 if (!isDynamic(rhsInt)) {
819 track(Side::LHS, lhsAffine, rhsInt);
820 return true;
821 }
822 }
823 }
824 if (AffineMapAttr rhsAffine = llvm::dyn_cast<AffineMapAttr>(rhsAttr)) {
825 if (IntegerAttr lhsInt = llvm::dyn_cast<IntegerAttr>(lhsAttr)) {
826 if (!isDynamic(lhsInt)) {
827 track(Side::RHS, rhsAffine, lhsInt);
828 return true;
829 }
830 }
831 }
832 // If either side is a SymbolRefAttr, assume they unify because either flattening or a pass with
833 // a more involved value analysis is required to check if they are actually the same value.
834 if (SymbolRefAttr lhsSymRef = llvm::dyn_cast<SymbolRefAttr>(lhsAttr)) {
835 track(Side::LHS, lhsSymRef, rhsAttr);
836 return true;
837 }
838 if (SymbolRefAttr rhsSymRef = llvm::dyn_cast<SymbolRefAttr>(rhsAttr)) {
839 track(Side::RHS, rhsSymRef, lhsAttr);
840 return true;
841 }
842 // If either side is ShapedType::kDynamic then, similarly to Symbols, assume they unify.
843 // NOTE: Dynamic array dimensions (i.e. '?') are allowed in LLZK but should generally be
844 // restricted to scenarios where it can be replaced with a concrete value during the flattening
845 // pass, such as a `unifiable_cast` where the other side of the cast has concrete dimensions or
846 // extern functions with varargs.
847 if (unifyDynamicSize) {
848 auto dyn_cast_if_dynamic = [](Attribute attr) -> IntegerAttr {
849 if (IntegerAttr intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
850 if (isDynamic(intAttr)) {
851 return intAttr;
852 }
853 }
854 return nullptr;
855 };
856 auto is_const_like = [](Attribute attr) {
857 return llvm::isa_and_present<IntegerAttr, SymbolRefAttr, AffineMapAttr>(attr);
858 };
859 if (IntegerAttr lhsIntAttr = dyn_cast_if_dynamic(lhsAttr)) {
860 if (is_const_like(rhsAttr)) {
861 return true;
862 }
863 }
864 if (IntegerAttr rhsIntAttr = dyn_cast_if_dynamic(rhsAttr)) {
865 if (is_const_like(lhsAttr)) {
866 return true;
867 }
868 }
869 }
870 // If both are type refs, check for unification of the types.
871 if (TypeAttr lhsTy = llvm::dyn_cast<TypeAttr>(lhsAttr)) {
872 if (TypeAttr rhsTy = llvm::dyn_cast<TypeAttr>(rhsAttr)) {
873 return typesUnify(lhsTy.getValue(), rhsTy.getValue());
874 }
875 }
876 // Otherwise, they do not unify.
877 return false;
878 }
879};
880
881} // namespace
882
884 const ArrayRef<Attribute> &lhsParams, const ArrayRef<Attribute> &rhsParams,
885 UnificationMap *unifications
886) {
887 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
888}
889
893 const ArrayAttr &lhsParams, const ArrayAttr &rhsParams, UnificationMap *unifications
894) {
895 return UnifierImpl(unifications).typeParamsUnify(lhsParams, rhsParams);
896}
897
899 ArrayType lhs, ArrayType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
900) {
901 return UnifierImpl(unifications, rhsReversePrefix).arrayTypesUnify(lhs, rhs);
902}
903
905 StructType lhs, StructType rhs, ArrayRef<StringRef> rhsReversePrefix,
906 UnificationMap *unifications
907) {
908 return UnifierImpl(unifications, rhsReversePrefix).structTypesUnify(lhs, rhs);
909}
910
912 PodType lhs, PodType rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
913) {
914 return UnifierImpl(unifications, rhsReversePrefix).podTypesUnify(lhs, rhs);
915}
916
918 FunctionType lhs, FunctionType rhs, ArrayRef<StringRef> rhsReversePrefix,
919 UnificationMap *unifications
920) {
921 return UnifierImpl(unifications, rhsReversePrefix).functionTypesUnify(lhs, rhs);
922}
923
925 Type lhs, Type rhs, ArrayRef<StringRef> rhsReversePrefix, UnificationMap *unifications
926) {
927 return UnifierImpl(unifications, rhsReversePrefix).typesUnify(lhs, rhs);
928}
929
931 Type oldTy, Type newTy, llvm::function_ref<bool(Type oldTy, Type newTy)> knownOldToNew
932) {
933 UnificationMap unifications;
934 AffineInstantiations affineInstantiations;
935 // Run type unification with the addition that affine map can become integer in the new type.
936 if (!UnifierImpl(&unifications)
937 .trackAffineToInt(&affineInstantiations)
938 .withOverrides(knownOldToNew)
939 .typesUnify(oldTy, newTy)) {
940 return false;
941 }
942
943 // If either map contains RHS-keyed mappings then the old type is "more concrete" than the new.
944 // In the UnificationMap, a RHS key would indicate that the new type contains a SymbolRef (i.e.
945 // the "least concrete" attribute kind) where the old type contained any other attribute. In the
946 // AffineInstantiations map, a RHS key would indicate that the new type contains an AffineMapAttr
947 // where the old type contains an IntegerAttr.
948 auto entryIsRHS = [](const auto &entry) { return entry.first.second == Side::RHS; };
949 return !llvm::any_of(unifications, entryIsRHS) && !llvm::any_of(affineInstantiations, entryIsRHS);
950}
951
952FailureOr<IntegerAttr> forceIntType(IntegerAttr attr, EmitErrorFn emitError) {
953 if (llvm::isa<IndexType>(attr.getType())) {
954 return attr;
955 }
956 // Ensure the APInt is the right bitwidth for IndexType or else
957 // IntegerAttr::verify(..) will report an error.
958 APInt value = attr.getValue();
959 auto compare = value.getBitWidth() <=> IndexType::kInternalStorageBitWidth;
960 if (compare < 0) {
961 value = value.zext(IndexType::kInternalStorageBitWidth);
962 } else if (compare > 0) {
963 return emitError().append("value is too large for `index` type: ", debug::toStringOne(value));
964 }
965 return IntegerAttr::get(IndexType::get(attr.getContext()), value);
966}
967
968FailureOr<Attribute> forceIntAttrType(Attribute attr, EmitErrorFn emitError) {
969 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
970 return forceIntType(intAttr, emitError);
971 }
972 return attr;
973}
974
975FailureOr<SmallVector<Attribute>>
976forceIntAttrTypes(ArrayRef<Attribute> attrList, EmitErrorFn emitError) {
977 SmallVector<Attribute> result;
978 for (Attribute attr : attrList) {
979 FailureOr<Attribute> forced = forceIntAttrType(attr, emitError);
980 if (failed(forced)) {
981 return failure();
982 }
983 result.push_back(*forced);
984 }
985 return result;
986}
987
988LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in) {
989 if (IntegerAttr intAttr = llvm::dyn_cast_if_present<IntegerAttr>(in)) {
990 Type attrTy = intAttr.getType();
991 if (!AllowedTypes().onlyInt().isValidTypeImpl(attrTy)) {
992 if (emitError) {
993 emitError()
994 .append("IntegerAttr must have type 'index' or 'i1' but found '", attrTy, '\'')
995 .report();
996 }
997 return failure();
998 }
999 }
1000 return success();
1001}
1002
1003LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in) {
1004 if (AffineMapAttr affineAttr = llvm::dyn_cast_if_present<AffineMapAttr>(in)) {
1005 AffineMap map = affineAttr.getValue();
1006 if (map.getNumResults() != 1) {
1007 if (emitError) {
1008 emitError()
1009 .append(
1010 "AffineMapAttr must yield a single result, but found ", map.getNumResults(),
1011 " results"
1012 )
1013 .report();
1014 }
1015 return failure();
1016 }
1017 }
1018 return success();
1019}
1020
1021LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params) {
1022 return success(AllowedTypes().areValidStructTypeParams(params, emitError));
1023}
1024
1025LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes) {
1026 return success(AllowedTypes().areValidArrayDimSizes(dimensionSizes, emitError));
1027}
1028
1029LogicalResult
1030verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef<Attribute> dimensionSizes) {
1031 return success(AllowedTypes().isValidArrayTypeImpl(elementType, dimensionSizes, emitError));
1032}
1033
1034void assertValidAttrForParamOfType(Attribute attr) {
1035 // Must be the union of valid attribute types within ArrayType, StructType, and TypeVarType.
1036 using TypeVarAttrs = TypeList<SymbolRefAttr>; // per ODS spec of TypeVarType
1037 if (!TypeListUnion<ArrayDimensionTypes, StructParamTypes, TypeVarAttrs>::matches(attr)) {
1038 llvm::report_fatal_error(
1039 "Legal type parameters are inconsistent. Encountered " +
1040 attr.getAbstractAttribute().getName()
1041 );
1042 }
1043}
1044
1045LogicalResult
1046verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType) {
1047 ArrayRef<Attribute> dimsFromArr = arrayType.getDimensionSizes();
1048 size_t numArrDims = dimsFromArr.size();
1049 ArrayRef<Attribute> dimsFromSubArr = subArrayType.getDimensionSizes();
1050 size_t numSubArrDims = dimsFromSubArr.size();
1051
1052 if (numArrDims < numSubArrDims) {
1053 return emitError().append(
1054 "subarray type ", subArrayType, " has more dimensions than array type ", arrayType
1055 );
1056 }
1057
1058 size_t toDrop = numArrDims - numSubArrDims;
1059 ArrayRef<Attribute> dimsFromArrReduced = dimsFromArr.drop_front(toDrop);
1060
1061 // Ensure dimension sizes are compatible (ignoring the indexed dimensions)
1062 if (!typeParamsUnify(dimsFromArrReduced, dimsFromSubArr)) {
1063 std::string message;
1064 llvm::raw_string_ostream ss(message);
1065 auto appendOne = [&ss](Attribute a) { appendWithoutType(ss, a); };
1066 ss << "cannot unify array dimensions [";
1067 llvm::interleaveComma(dimsFromArrReduced, ss, appendOne);
1068 ss << "] with [";
1069 llvm::interleaveComma(dimsFromSubArr, ss, appendOne);
1070 ss << "]";
1071 return emitError().append(message);
1072 }
1073
1074 // Ensure element types of the arrays are compatible
1075 if (!typesUnify(arrayType.getElementType(), subArrayType.getElementType())) {
1076 return emitError().append(
1077 "incorrect array element type; expected: ", arrayType.getElementType(),
1078 ", found: ", subArrayType.getElementType()
1079 );
1080 }
1081
1082 return success();
1083}
1084
1085LogicalResult
1086verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType) {
1087 if (auto subArrayType = llvm::dyn_cast<ArrayType>(subArrayOrElemType)) {
1088 return verifySubArrayType(emitError, arrayType, subArrayType);
1089 }
1090 if (!typesUnify(arrayType.getElementType(), subArrayOrElemType)) {
1091 return emitError().append(
1092 "incorrect array element type; expected: ", arrayType.getElementType(),
1093 ", found: ", subArrayOrElemType
1094 );
1095 }
1096
1097 return success();
1098}
1099
1101 return TypeSwitch<Type, bool>(ty)
1102 .Case<FeltType>([](auto) { return true; })
1103 .Case<ArrayType>([](auto arrTy) {
1104 return isFeltOrSimpleFeltAggregate(arrTy.getElementType());
1105 })
1106 .Case<PodType>([](auto podTy) {
1107 for (auto record : podTy.getRecords()) {
1108 if (!isFeltOrSimpleFeltAggregate(record.getType())) {
1109 return false;
1110 }
1111 }
1112 return true;
1113 }).Default([](auto) { return false; });
1114}
1115
1116bool isValidMainSignalType(Type pType) {
1117 if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
1118 return llvm::isa<FeltType>(arrayParamTy.getElementType());
1119 }
1120 return llvm::isa<FeltType>(pType);
1121}
1122
1123} // namespace llzk
Note: If any symbol refs in an input Type/Attribute use any of the special characters that this class...
Definition TypeHelper.h:39
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:55
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
std::string toStringOne(const T &value)
Definition Debug.h:182
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
llvm::SmallVector< StringRef > getNames(SymbolRefAttr ref)
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
Convert an IntegerAttr with a type other than IndexType to use IndexType.
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isFeltOrSimpleFeltAggregate(Type ty)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
bool isNullOrEmpty(mlir::ArrayAttr a)
bool podTypesUnify(PodType lhs, PodType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
bool isValidEmitEqType(Type type)
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:144
int64_t fromAPInt(const llvm::APInt &i)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
bool functionTypesUnify(FunctionType lhs, FunctionType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
void appendWithoutType(mlir::raw_ostream &os, mlir::Attribute a)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:114
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
Template pattern for performing some operation by cases based on a given LLZK type.
ResultType match(Type type)