LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeHelper.h
Go to the documentation of this file.
1//===-- TypeHelper.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/SymbolTable.h>

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/StringRef.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
21
22namespace llzk {
23
// Forward declarations of the LLZK type classes that appear in the helper signatures of this
// header.
namespace component { class StructType; } // namespace component
namespace array { class ArrayType; }      // namespace array
namespace pod { class PodType; }          // namespace pod
34
39class BuildShortTypeString {
40 static constexpr char PLACEHOLDER = '\x1A';
41
42 std::string ret;
43 llvm::raw_string_ostream ss;
44
45 BuildShortTypeString() : ret(), ss(ret) {}
46 BuildShortTypeString &append(mlir::Type);
47 BuildShortTypeString &append(mlir::ArrayRef<mlir::Attribute>);
48 BuildShortTypeString &append(mlir::Attribute);
49
50 void appendSymRef(mlir::SymbolRefAttr);
51 void appendSymName(mlir::StringRef);
52
53public:
55 static inline std::string from(mlir::Type type) {
56 return BuildShortTypeString().append(type).ret;
57 }
58
61 static inline std::string from(mlir::ArrayRef<mlir::Attribute> attrs) {
62 return BuildShortTypeString().append(attrs).ret;
63 }
64
72 static std::string from(const std::string &base, mlir::ArrayRef<mlir::Attribute> attrs);
73};
74
75// This function asserts that the given Attribute kind is legal within the LLZK types that can
76// contain Attribute parameters (i.e., ArrayType, StructType, and TypeVarType). This should be used
77// in any function that examines the attribute parameters within parameterized LLZK types to ensure
78// that the function handles all possible cases properly, especially if more legal attributes are
79// added in the future. Throw a fatal error if anything illegal is found, indicating that the caller
80// of this function should be updated.
81void assertValidAttrForParamOfType(mlir::Attribute attr);
82
84bool isValidType(mlir::Type type);
85
89 mlir::Type type, mlir::SymbolTableCollection &symbolTable, mlir::Operation *op
90);
91
93bool isValidGlobalType(mlir::Type type);
94
96bool isValidEmitEqType(mlir::Type type);
97
99bool isValidConstReadType(mlir::Type type);
100
102bool isValidArrayElemType(mlir::Type type);
103
105bool isValidArrayType(mlir::Type type);
106
112bool isConcreteType(mlir::Type type, bool allowStructParams = true);
113
114inline mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type) {
115 if (!isValidType(type)) {
116 return emitError() << "expected a valid LLZK type but found " << type;
117 } else {
118 return mlir::success();
119 }
120}
121
123bool hasAffineMapAttr(mlir::Type type);
124
/// Which side of a type unification a symbol belongs to (`LHS` or `RHS`). `EMPTY` and `TOMB` are
/// the empty and tombstone keys used by `llvm::DenseMapInfo<llzk::Side>`.
enum class Side : std::uint8_t {
  EMPTY = 0,
  LHS = 1,
  RHS = 2,
  TOMB = 3,
};
126static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const Side &val) {
127 switch (val) {
128 case Side::EMPTY:
129 os << "EMPTY";
130 break;
131 case Side::TOMB:
132 os << "TOMB";
133 break;
134 case Side::LHS:
135 os << "LHS";
136 break;
137 case Side::RHS:
138 os << "RHS";
139 break;
140 }
141 return os;
142}
143
144inline Side reverse(Side in) {
145 switch (in) {
146 case Side::LHS:
147 return Side::RHS;
148 case Side::RHS:
149 return Side::LHS;
150 default:
151 return in;
152 }
153}
154
155} // namespace llzk
156
157namespace llvm {
158template <> struct DenseMapInfo<llzk::Side> {
159 using T = llzk::Side;
160 static inline T getEmptyKey() { return T::EMPTY; }
161 static inline T getTombstoneKey() { return T::TOMB; }
162 static unsigned getHashValue(const T &val) {
163 using UT = std::underlying_type_t<T>;
164 return llvm::DenseMapInfo<UT>::getHashValue(static_cast<UT>(val));
165 }
166 static bool isEqual(const T &lhs, const T &rhs) { return lhs == rhs; }
167};
168} // namespace llvm
169
170namespace llzk {
171
172bool isDynamic(mlir::IntegerAttr intAttr);
173
176uint64_t computeEmitEqCardinality(mlir::Type type);
177
186using UnificationMap = mlir::DenseMap<std::pair<mlir::SymbolRefAttr, Side>, mlir::Attribute>;
187
191 const mlir::ArrayRef<mlir::Attribute> &lhsParams,
192 const mlir::ArrayRef<mlir::Attribute> &rhsParams, UnificationMap *unifications = nullptr
193);
194
198 const mlir::ArrayAttr &lhsParams, const mlir::ArrayAttr &rhsParams,
199 UnificationMap *unifications = nullptr
200);
201
206 mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {}, UnificationMap *unifications = nullptr
207);
208
213 mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {}, UnificationMap *unifications = nullptr
214);
215
219 pod::PodType lhs, pod::PodType rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
220 UnificationMap *unifications = nullptr
221);
222
226 mlir::FunctionType lhs, mlir::FunctionType rhs,
227 mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {}, UnificationMap *unifications = nullptr
228);
229
233 mlir::Type lhs, mlir::Type rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
234 UnificationMap *unifications = nullptr
235);
236
239template <typename Iter1, typename Iter2>
240inline bool typeListsUnify(
241 Iter1 lhs, Iter2 rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
242 UnificationMap *unifications = nullptr
243) {
244 return (lhs.size() == rhs.size()) &&
245 std::equal(lhs.begin(), lhs.end(), rhs.begin(), [&](mlir::Type a, mlir::Type b) {
246 return typesUnify(a, b, rhsReversePrefix, unifications);
247 });
248}
249
250template <typename Iter1, typename Iter2>
252 Iter1 lhs, Iter2 rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
253 UnificationMap *unifications = nullptr
254) {
255 return lhs.size() == 1 && rhs.size() == 1 &&
256 typesUnify(lhs.front(), rhs.front(), rhsReversePrefix, unifications);
257}
258
266 mlir::Type oldTy, mlir::Type newTy,
267 llvm::function_ref<bool(mlir::Type oldTy, mlir::Type newTy)> knownOldToNew = nullptr
268);
269
270template <typename TypeClass> inline TypeClass getIfSingleton(mlir::TypeRange types) {
271 return (types.size() == 1) ? llvm::dyn_cast<TypeClass>(types.front()) : nullptr;
272}
273
274template <typename TypeClass> inline TypeClass getAtIndex(mlir::TypeRange types, size_t index) {
275 return (types.size() > index) ? llvm::dyn_cast<TypeClass>(types[index]) : nullptr;
276}
277
279mlir::FailureOr<mlir::IntegerAttr> forceIntType(mlir::IntegerAttr attr, EmitErrorFn emitError);
280
282mlir::FailureOr<mlir::Attribute> forceIntAttrType(mlir::Attribute attr, EmitErrorFn emitError);
283
285mlir::FailureOr<llvm::SmallVector<mlir::Attribute>>
286forceIntAttrTypes(llvm::ArrayRef<mlir::Attribute> attrList, EmitErrorFn emitError);
287
289mlir::LogicalResult verifyIntAttrType(EmitErrorFn emitError, mlir::Attribute in);
290
292mlir::LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, mlir::Attribute in);
293
295mlir::LogicalResult verifyStructTypeParams(EmitErrorFn emitError, mlir::ArrayAttr params);
296
298mlir::LogicalResult
299verifyArrayDimSizes(EmitErrorFn emitError, mlir::ArrayRef<mlir::Attribute> dimensionSizes);
300
302mlir::LogicalResult verifyArrayType(
303 EmitErrorFn emitError, mlir::Type elementType, mlir::ArrayRef<mlir::Attribute> dimensionSizes
304);
305
311mlir::LogicalResult verifySubArrayType(
312 EmitErrorFn emitError, array::ArrayType arrayType, array::ArrayType subArrayType
313);
314
318mlir::LogicalResult verifySubArrayOrElementType(
319 EmitErrorFn emitError, array::ArrayType arrayType, mlir::Type subArrayOrElemType
320);
321
324bool isFeltOrSimpleFeltAggregate(mlir::Type ty);
325
329bool isValidMainSignalType(mlir::Type pType);
330
331} // namespace llzk
static std::string from(mlir::ArrayRef< mlir::Attribute > attrs)
Return a brief string representation of the attribute list from a parameterized type.
Definition TypeHelper.h:61
static std::string from(const std::string &base, mlir::ArrayRef< mlir::Attribute > attrs)
Take an existing name prefix/base that contains N>=0 PLACEHOLDER character(s) and the Attribute list ...
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:55
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:270
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
Convert an IntegerAttr with a type other than IndexType to use IndexType.
bool singletonTypeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Definition TypeHelper.h:251
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
bool isFeltOrSimpleFeltAggregate(Type ty)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
Interval operator<<(const Interval &lhs, const Interval &rhs)
bool podTypesUnify(PodType lhs, PodType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isValidEmitEqType(Type type)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:274
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:144
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
bool functionTypesUnify(FunctionType lhs, FunctionType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:114
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
static bool isEqual(const T &lhs, const T &rhs)
Definition TypeHelper.h:166
static unsigned getHashValue(const T &val)
Definition TypeHelper.h:162