LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
TypeHelper.h
Go to the documentation of this file.
1//===-- TypeHelper.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
13
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/SymbolTable.h>

#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/DenseMap.h>
#include <llvm/ADT/StringRef.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
21
22namespace llzk {
23
24// Forward declarations
25namespace component {
26class StructType;
27} // namespace component
28namespace array {
29class ArrayType;
30} // namespace array
31
36class BuildShortTypeString {
37 static constexpr char PLACEHOLDER = '\x1A';
38
39 std::string ret;
40 llvm::raw_string_ostream ss;
41
42 BuildShortTypeString() : ret(), ss(ret) {}
43 BuildShortTypeString &append(mlir::Type);
44 BuildShortTypeString &append(mlir::ArrayRef<mlir::Attribute>);
45 BuildShortTypeString &append(mlir::Attribute);
46
47 void appendSymRef(mlir::SymbolRefAttr);
48 void appendSymName(mlir::StringRef);
49
50public:
52 static inline std::string from(mlir::Type type) {
53 return BuildShortTypeString().append(type).ret;
54 }
55
58 static inline std::string from(mlir::ArrayRef<mlir::Attribute> attrs) {
59 return BuildShortTypeString().append(attrs).ret;
60 }
61
69 static std::string from(const std::string &base, mlir::ArrayRef<mlir::Attribute> attrs);
70};
71
72// This function asserts that the given Attribute kind is legal within the LLZK types that can
73// contain Attribute parameters (i.e., ArrayType, StructType, and TypeVarType). This should be used
74// in any function that examines the attribute parameters within parameterized LLZK types to ensure
75// that the function handles all possible cases properly, especially if more legal attributes are
76// added in the future. Throw a fatal error if anything illegal is found, indicating that the caller
77// of this function should be updated.
78void assertValidAttrForParamOfType(mlir::Attribute attr);
79
81bool isValidType(mlir::Type type);
82
86 mlir::Type type, mlir::SymbolTableCollection &symbolTable, mlir::Operation *op
87);
88
90bool isValidGlobalType(mlir::Type type);
91
93bool isValidEmitEqType(mlir::Type type);
94
96bool isValidConstReadType(mlir::Type type);
97
99bool isValidArrayElemType(mlir::Type type);
100
102bool isValidArrayType(mlir::Type type);
103
109bool isConcreteType(mlir::Type type, bool allowStructParams = true);
110
111inline mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type) {
112 if (!isValidType(type)) {
113 return emitError() << "expected a valid LLZK type but found " << type;
114 } else {
115 return mlir::success();
116 }
117}
118
120bool hasAffineMapAttr(mlir::Type type);
121
122enum class Side : std::uint8_t { EMPTY = 0, LHS, RHS, TOMB };
123static inline mlir::raw_ostream &operator<<(mlir::raw_ostream &os, const Side &val) {
124 switch (val) {
125 case Side::EMPTY:
126 os << "EMPTY";
127 break;
128 case Side::TOMB:
129 os << "TOMB";
130 break;
131 case Side::LHS:
132 os << "LHS";
133 break;
134 case Side::RHS:
135 os << "RHS";
136 break;
137 }
138 return os;
139}
140
141inline Side reverse(Side in) {
142 switch (in) {
143 case Side::LHS:
144 return Side::RHS;
145 case Side::RHS:
146 return Side::LHS;
147 default:
148 return in;
149 }
150}
151
152} // namespace llzk
153
154namespace llvm {
155template <> struct DenseMapInfo<llzk::Side> {
156 using T = llzk::Side;
157 static inline T getEmptyKey() { return T::EMPTY; }
158 static inline T getTombstoneKey() { return T::TOMB; }
159 static unsigned getHashValue(const T &val) {
160 using UT = std::underlying_type_t<T>;
161 return llvm::DenseMapInfo<UT>::getHashValue(static_cast<UT>(val));
162 }
163 static bool isEqual(const T &lhs, const T &rhs) { return lhs == rhs; }
164};
165} // namespace llvm
166
167namespace llzk {
168
169bool isDynamic(mlir::IntegerAttr intAttr);
170
173uint64_t computeEmitEqCardinality(mlir::Type type);
174
183using UnificationMap = mlir::DenseMap<std::pair<mlir::SymbolRefAttr, Side>, mlir::Attribute>;
184
188 const mlir::ArrayRef<mlir::Attribute> &lhsParams,
189 const mlir::ArrayRef<mlir::Attribute> &rhsParams, UnificationMap *unifications = nullptr
190);
191
195 const mlir::ArrayAttr &lhsParams, const mlir::ArrayAttr &rhsParams,
196 UnificationMap *unifications = nullptr
197);
198
203 mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {}, UnificationMap *unifications = nullptr
204);
205
210 mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {}, UnificationMap *unifications = nullptr
211);
212
216 mlir::Type lhs, mlir::Type rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
217 UnificationMap *unifications = nullptr
218);
219
222template <typename Iter1, typename Iter2>
223inline bool typeListsUnify(
224 Iter1 lhs, Iter2 rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
225 UnificationMap *unifications = nullptr
226) {
227 return (lhs.size() == rhs.size()) &&
228 std::equal(lhs.begin(), lhs.end(), rhs.begin(), [&](mlir::Type a, mlir::Type b) {
229 return typesUnify(a, b, rhsReversePrefix, unifications);
230 });
231}
232
233template <typename Iter1, typename Iter2>
235 Iter1 lhs, Iter2 rhs, mlir::ArrayRef<llvm::StringRef> rhsReversePrefix = {},
236 UnificationMap *unifications = nullptr
237) {
238 return lhs.size() == 1 && rhs.size() == 1 &&
239 typesUnify(lhs.front(), rhs.front(), rhsReversePrefix, unifications);
240}
241
249 mlir::Type oldTy, mlir::Type newTy,
250 llvm::function_ref<bool(mlir::Type oldTy, mlir::Type newTy)> knownOldToNew = nullptr
251);
252
253template <typename TypeClass> inline TypeClass getIfSingleton(mlir::TypeRange types) {
254 return (types.size() == 1) ? llvm::dyn_cast<TypeClass>(types.front()) : nullptr;
255}
256
257template <typename TypeClass> inline TypeClass getAtIndex(mlir::TypeRange types, size_t index) {
258 return (types.size() > index) ? llvm::dyn_cast<TypeClass>(types[index]) : nullptr;
259}
260
262mlir::FailureOr<mlir::IntegerAttr> forceIntType(mlir::IntegerAttr attr, EmitErrorFn emitError);
263
265mlir::FailureOr<mlir::Attribute> forceIntAttrType(mlir::Attribute attr, EmitErrorFn emitError);
266
268mlir::FailureOr<llvm::SmallVector<mlir::Attribute>>
269forceIntAttrTypes(llvm::ArrayRef<mlir::Attribute> attrList, EmitErrorFn emitError);
270
272mlir::LogicalResult verifyIntAttrType(EmitErrorFn emitError, mlir::Attribute in);
273
275mlir::LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, mlir::Attribute in);
276
278mlir::LogicalResult verifyStructTypeParams(EmitErrorFn emitError, mlir::ArrayAttr params);
279
281mlir::LogicalResult
282verifyArrayDimSizes(EmitErrorFn emitError, mlir::ArrayRef<mlir::Attribute> dimensionSizes);
283
285mlir::LogicalResult verifyArrayType(
286 EmitErrorFn emitError, mlir::Type elementType, mlir::ArrayRef<mlir::Attribute> dimensionSizes
287);
288
294mlir::LogicalResult verifySubArrayType(
295 EmitErrorFn emitError, array::ArrayType arrayType, array::ArrayType subArrayType
296);
297
301mlir::LogicalResult verifySubArrayOrElementType(
302 EmitErrorFn emitError, array::ArrayType arrayType, mlir::Type subArrayOrElemType
303);
304
305} // namespace llzk
static std::string from(mlir::ArrayRef< mlir::Attribute > attrs)
Return a brief string representation of the attribute list from a parameterized type.
Definition TypeHelper.h:58
static std::string from(const std::string &base, mlir::ArrayRef< mlir::Attribute > attrs)
Take an existing name prefix/base that contains N>=0 PLACEHOLDER character(s) and the Attribute list ...
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
LogicalResult verifyAffineMapAttrType(EmitErrorFn emitError, Attribute in)
void assertValidAttrForParamOfType(Attribute attr)
LogicalResult verifySubArrayType(EmitErrorFn emitError, ArrayType arrayType, ArrayType subArrayType)
Determine if the subArrayType is a valid subarray of arrayType.
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
uint64_t computeEmitEqCardinality(Type type)
bool isValidArrayType(Type type)
LogicalResult verifyIntAttrType(EmitErrorFn emitError, Attribute in)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:223
bool isConcreteType(Type type, bool allowStructParams)
bool isValidArrayElemType(Type type)
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:253
bool isValidGlobalType(Type type)
FailureOr< IntegerAttr > forceIntType(IntegerAttr attr, EmitErrorFn emitError)
bool singletonTypeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Definition TypeHelper.h:234
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
LogicalResult verifySubArrayOrElementType(EmitErrorFn emitError, ArrayType arrayType, Type subArrayOrElemType)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:183
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
FailureOr< SmallVector< Attribute > > forceIntAttrTypes(ArrayRef< Attribute > attrList, EmitErrorFn emitError)
Interval operator<<(const Interval &lhs, const Interval &rhs)
bool isValidEmitEqType(Type type)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:257
bool isValidType(Type type)
bool arrayTypesUnify(ArrayType lhs, ArrayType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isDynamic(IntegerAttr intAttr)
Side reverse(Side in)
Definition TypeHelper.h:141
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool typeParamsUnify(const ArrayRef< Attribute > &lhsParams, const ArrayRef< Attribute > &rhsParams, UnificationMap *unifications)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
LogicalResult verifyStructTypeParams(EmitErrorFn emitError, ArrayAttr params)
bool hasAffineMapAttr(Type type)
mlir::LogicalResult checkValidType(EmitErrorFn emitError, mlir::Type type)
Definition TypeHelper.h:111
bool isValidConstReadType(Type type)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)
static bool isEqual(const T &lhs, const T &rhs)
Definition TypeHelper.h:163
static unsigned getHashValue(const T &val)
Definition TypeHelper.h:159