LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Types.cpp
Go to the documentation of this file.
1//===-- Types.cpp - Array type implementations ------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
15
16#include <mlir/IR/BuiltinTypeInterfaces.h>
17
18using namespace mlir;
19
20namespace llzk::array {
21
23 MLIRContext *ctx, ArrayRef<int64_t> shape, SmallVector<Attribute> &dimensionSizes
24) {
25 Builder builder(ctx);
26 dimensionSizes = llvm::map_to_vector(shape, [&builder](int64_t v) -> Attribute {
27 return builder.getIndexAttr(v);
28 });
29 assert(dimensionSizes.size() == shape.size()); // fully computed by this function
30 return success();
31}
32
34 EmitErrorFn emitError, ArrayRef<Attribute> dimensionSizes, SmallVector<int64_t> &shape
35) {
36 assert(shape.empty()); // fully computed by this function
37
38 // Ensure all Attributes are valid Attribute classes for ArrayType.
39 if (failed(verifyArrayDimSizes(emitError, dimensionSizes))) {
40 return failure();
41 }
42
43 // Convert the Attributes to int64_t
44 for (Attribute a : dimensionSizes) {
45 if (auto p = llvm::dyn_cast_if_present<IntegerAttr>(a)) {
46 shape.push_back(fromAPInt(p.getValue()));
47 } else if (llvm::isa_and_present<SymbolRefAttr, AffineMapAttr>(a)) {
48 // The ShapedTypeInterface uses 'kDynamic' for dimensions with non-static size.
49 shape.push_back(ShapedType::kDynamic);
50 } else {
51 // For every Attribute class in ArrayDimensionTypes, there should be a case here.
52 llvm::report_fatal_error("computeShapeFromDims() is out of sync with ArrayDimensionTypes");
53 return failure();
54 }
55 }
56 assert(shape.size() == dimensionSizes.size()); // fully computed by this function
57 return success();
58}
59
61 AsmParser &parser, SmallVector<int64_t> &shape,
62 SmallVector<Attribute> dimensionSizes // NOLINT(performance-unnecessary-value-param)
63) {
64 // This is not actually parsing. It's computing the derived
65 // `shape` from the `dimensionSizes` attributes.
66 auto emitError = [&parser] {
67 return InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
68 };
69 return computeShapeFromDims(emitError, dimensionSizes, shape);
70}
71void printDerivedShape(AsmPrinter &, ArrayRef<int64_t>, ArrayRef<Attribute>) {
72 // nothing to print, it's derived and therefore not represented in the output
73}
74
75LogicalResult ArrayType::verify(
76 function_ref<InFlightDiagnostic()> emitError, Type elementType,
77 ArrayRef<Attribute> dimensionSizes, ArrayRef<int64_t> /*shape*/
78) {
79 return verifyArrayType(wrapNonNullableInFlightDiagnostic(emitError), elementType, dimensionSizes);
80}
81
82ArrayType ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape, Type elementType) const {
83 return ArrayType::get(elementType, shape.has_value() ? shape.value() : getShape());
84}
85
87ArrayType::cloneWith(Type elementType, std::optional<ArrayRef<Attribute>> dimensions) const {
88 return ArrayType::get(
89 elementType, dimensions.has_value() ? dimensions.value() : getDimensionSizes()
90 );
91}
92
93namespace {
94
95inline ArrayType createArrayOfSizeOne(Type elemType) { return ArrayType::get(elemType, {1}); }
96
97} // namespace
98
99bool ArrayType::collectIndices(llvm::function_ref<void(ArrayAttr)> inserter) const {
100 if (!hasStaticShape()) {
101 return false;
102 }
103 MLIRContext *ctx = getContext();
104 ArrayIndexGen idxGen = ArrayIndexGen::from(*this);
105 for (int64_t e = getNumElements(), i = 0; i < e; ++i) {
106 auto delinearized = idxGen.delinearize(i, ctx);
107 assert(delinearized.has_value()); // cannot fail since loop is over array size
108 inserter(ArrayAttr::get(ctx, delinearized.value()));
109 }
110 return true;
111}
112
113std::optional<SmallVector<ArrayAttr>> ArrayType::getSubelementIndices() const {
114 SmallVector<ArrayAttr> ret;
115 bool success = collectIndices([&ret](ArrayAttr v) { ret.push_back(v); });
116 return success ? std::make_optional(ret) : std::nullopt;
117}
118
120std::optional<DenseMap<Attribute, Type>> ArrayType::getSubelementIndexMap() const {
121 DenseMap<Attribute, Type> ret;
122 Type destructAs = createArrayOfSizeOne(getElementType());
123 bool success = collectIndices([&](ArrayAttr v) { ret[v] = destructAs; });
124 return success ? std::make_optional(ret) : std::nullopt;
125}
126
128Type ArrayType::getTypeAtIndex(Attribute index) const {
129 if (!hasStaticShape()) {
130 return nullptr;
131 }
132 // Since indexing is multi-dimensional, `index` should be ArrayAttr
133 ArrayAttr indexAttr = llvm::dyn_cast<ArrayAttr>(index);
134 if (!indexAttr) {
135 return nullptr;
136 }
137 // Ensure the shape is valid and dimensions are valid for the shape by computing linear index.
138 if (!ArrayIndexGen::from(*this).linearize(indexAttr.getValue())) {
139 return nullptr;
140 }
141 // If that's successful, the destructured type is the size-1 array of the element type.
142 return createArrayOfSizeOne(getElementType());
143}
144
145ParseResult parseAttrVec(AsmParser &parser, SmallVector<Attribute> &value) {
146 SmallVector<Attribute> attrs;
147 auto parseElement = [&parser, &value]() -> ParseResult {
148 auto qResult = parser.parseOptionalQuestion();
149 if (succeeded(qResult)) {
150 auto &builder = parser.getBuilder();
151 value.push_back(builder.getIntegerAttr(builder.getIndexType(), ShapedType::kDynamic));
152 return qResult;
153 }
154 auto attrParseResult = FieldParser<Attribute>::parse(parser);
155 if (succeeded(attrParseResult)) {
156 auto emitError = [&parser] {
157 return InFlightDiagnosticWrapper(parser.emitError(parser.getCurrentLocation()));
158 };
159 FailureOr<Attribute> forced = forceIntAttrType(*attrParseResult, emitError);
160 if (failed(forced)) {
161 return failure();
162 }
163 value.push_back(*forced);
164 }
165 return ParseResult(attrParseResult);
166 };
167 if (failed(parser.parseCommaSeparatedList(AsmParser::Delimiter::None, parseElement))) {
168 return parser.emitError(parser.getCurrentLocation(), "failed to parse array dimensions");
169 }
170 return success();
171}
172
173void printAttrVec(AsmPrinter &printer, ArrayRef<Attribute> value) {
174 printAttrs(printer, value, ",");
175}
176
177} // namespace llzk::array
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
::llvm::ArrayRef< int64_t > getShape() const
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::llvm::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes, ::llvm::ArrayRef< int64_t > shape)
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:113
::mlir::Type getElementType() const
::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type > > getSubelementIndexMap() const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:120
::mlir::Type getTypeAtIndex(::mlir::Attribute index) const
Required by DestructurableTypeInterface / SROA pass.
Definition Types.cpp:128
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
LogicalResult computeShapeFromDims(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes, SmallVector< int64_t > &shape)
Definition Types.cpp:33
ParseResult parseAttrVec(AsmParser &parser, SmallVector< Attribute > &value)
Definition Types.cpp:145
void printDerivedShape(AsmPrinter &, ArrayRef< int64_t >, ArrayRef< Attribute >)
Definition Types.cpp:71
LogicalResult computeDimsFromShape(MLIRContext *ctx, ArrayRef< int64_t > shape, SmallVector< Attribute > &dimensionSizes)
Definition Types.cpp:22
void printAttrVec(AsmPrinter &printer, ArrayRef< Attribute > value)
Definition Types.cpp:173
ParseResult parseDerivedShape(AsmParser &parser, SmallVector< int64_t > &shape, SmallVector< Attribute > dimensionSizes)
Definition Types.cpp:60
FailureOr< Attribute > forceIntAttrType(Attribute attr, EmitErrorFn emitError)
void printAttrs(AsmPrinter &printer, ArrayRef< Attribute > attrs, const StringRef &separator)
LogicalResult verifyArrayType(EmitErrorFn emitError, Type elementType, ArrayRef< Attribute > dimensionSizes)
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
OwningEmitErrorFn wrapNonNullableInFlightDiagnostic(llvm::function_ref< mlir::InFlightDiagnostic()> emitError)
int64_t fromAPInt(const llvm::APInt &i)
LogicalResult verifyArrayDimSizes(EmitErrorFn emitError, ArrayRef< Attribute > dimensionSizes)