LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
ArrayTypeHelper.cpp
Go to the documentation of this file.
1//===-- ArrayTypeHelper.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
16
18
19#include <mlir/Dialect/Arith/IR/Arith.h>
20#include <mlir/Dialect/Utils/IndexingUtils.h>
21#include <mlir/IR/Matchers.h>
22
23#include <llvm/ADT/APInt.h>
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/STLFunctionalExtras.h>
26
27using namespace mlir;
28using namespace llzk;
29using namespace llzk::array;
30
31ArrayIndexGen::ArrayIndexGen(ArrayType t)
32 : shape(t.getShape()), linearSize(t.getNumElements()), strides(mlir::computeStrides(shape)) {}
33
35 assert(t.hasStaticShape());
36 return ArrayIndexGen(t);
37}
38
39namespace {
40
41inline bool isInRange(int64_t idx, int64_t dimSize) { return 0 <= idx && idx < dimSize; }
42
43// This can support Value, Attribute, and Operation* per matchPattern() implementations.
44template <typename TypeOfIndex> inline std::optional<int64_t> toI64(TypeOfIndex index) {
45 llvm::APInt idxAP;
46 if (!mlir::matchPattern(index, mlir::m_ConstantInt(&idxAP))) {
47 return std::nullopt;
48 }
49 return llzk::fromAPInt(idxAP);
50}
51
52template <typename OutType> struct CheckAndConvert {
53 template <typename InType>
54 static std::optional<OutType> from(InType /*index*/, int64_t /*dimSize*/) {
55 static_assert(sizeof(OutType) == 0, "CheckAndConvert not implemented for requested type.");
56 assert(false);
57 }
58};
59
60// Specialization to produce `int64_t`
61template <> struct CheckAndConvert<int64_t> {
62 template <typename InType> static std::optional<int64_t> from(InType index, int64_t dimSize) {
63 if (auto idxVal = toI64<InType>(index)) {
64 if (isInRange(*idxVal, dimSize)) {
65 return idxVal;
66 }
67 }
68 return std::nullopt;
69 }
70};
71
72// Specialization to produce `Attribute`
73template <> struct CheckAndConvert<Attribute> {
74 template <typename InType> static std::optional<Attribute> from(InType index, int64_t dimSize) {
75 if (auto c = CheckAndConvert<int64_t>::from(index, dimSize)) {
76 return IntegerAttr::get(IndexType::get(index.getContext()), *c);
77 }
78 return std::nullopt;
79 }
80};
81
82template <typename OutType, typename InListType>
83inline std::optional<SmallVector<OutType>>
84checkAndConvertMulti(InListType multiDimIndex, ArrayRef<int64_t> shape, bool mustBeEqual) {
85 if (mustBeEqual) {
86 assert(
87 llvm::all_equal({llvm::range_size(multiDimIndex), llvm::range_size(shape)}) &&
88 "Iteratees do not have equal length"
89 );
90 }
91 SmallVector<OutType> ret;
92 for (auto [idx, dimSize] : llvm::zip_first(multiDimIndex, shape)) {
93 std::optional<OutType> next = CheckAndConvert<OutType>::from(idx, dimSize);
94 if (!next.has_value()) {
95 return std::nullopt;
96 }
97 ret.push_back(next.value());
98 }
99 return ret;
100}
101
102inline std::optional<int64_t> linearizeImpl(
103 ArrayRef<int64_t> multiDimIndex, const ArrayRef<int64_t> &shape,
104 const SmallVector<int64_t> &strides
105) {
106 // Ensure the index for each dimension is in range. Then the linearized index will be as well.
107 for (auto [idx, dimSize] : llvm::zip_equal(multiDimIndex, shape)) {
108 if (!isInRange(idx, dimSize)) {
109 return std::nullopt;
110 }
111 }
112 return mlir::linearize(multiDimIndex, strides);
113}
114
115template <typename TypeOfIndex>
116inline std::optional<int64_t> linearizeImpl(
117 ArrayRef<TypeOfIndex> multiDimIndex, const ArrayRef<int64_t> &shape,
118 const SmallVector<int64_t> &strides
119) {
120 std::optional<SmallVector<int64_t>> conv =
121 checkAndConvertMulti<int64_t>(multiDimIndex, shape, true /*TODO: I think*/);
122 if (!conv.has_value()) {
123 return std::nullopt;
124 }
125 return mlir::linearize(conv.value(), strides);
126}
127
128template <typename ResultElemType>
129inline std::optional<SmallVector<ResultElemType>> delinearizeImpl(
130 int64_t linearIndex, int64_t linearSize, const SmallVector<int64_t> &strides, MLIRContext *ctx,
131 llvm::function_ref<ResultElemType(IntegerAttr)> convert
132) {
133 if (!isInRange(linearIndex, linearSize)) {
134 return std::nullopt;
135 }
136 SmallVector<ResultElemType> ret;
137 for (int64_t idx : mlir::delinearize(linearIndex, strides)) {
138 ret.push_back(convert(IntegerAttr::get(IndexType::get(ctx), idx)));
139 }
140 return ret;
141}
142
143} // namespace
144
145std::optional<SmallVector<Value>>
146ArrayIndexGen::delinearize(int64_t linearIndex, Location loc, OpBuilder &bldr) const {
147 return delinearizeImpl<Value>(
148 linearIndex, linearSize, strides, bldr.getContext(),
149 [&](IntegerAttr a) { return bldr.create<arith::ConstantOp>(loc, a); }
150 );
151}
152
153std::optional<SmallVector<Attribute>>
154ArrayIndexGen::delinearize(int64_t linearIndex, MLIRContext *ctx) const {
155 return delinearizeImpl<Attribute>(linearIndex, linearSize, strides, ctx, [](IntegerAttr a) {
156 return a;
157 });
158}
159
160template <typename InListType> std::optional<int64_t> ArrayIndexGen::linearize(InListType) const {
161 static_assert(sizeof(InListType) == 0, "linearize() not implemented for requested type.");
162 llvm_unreachable("must have concrete instantiation");
163 return std::nullopt;
164}
165
166template <> std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<int64_t> multiDimIndex) const {
167 return linearizeImpl(multiDimIndex, shape, strides);
168}
169
170template <>
171std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Attribute> multiDimIndex) const {
172 return linearizeImpl(multiDimIndex, shape, strides);
173}
174
175template <>
176std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Operation *> multiDimIndex) const {
177 return linearizeImpl(multiDimIndex, shape, strides);
178}
179
180template <> std::optional<int64_t> ArrayIndexGen::linearize(ArrayRef<Value> multiDimIndex) const {
181 return linearizeImpl(multiDimIndex, shape, strides);
182}
183
184template <typename InListType>
185std::optional<SmallVector<Attribute>> ArrayIndexGen::checkAndConvert(InListType) {
186 static_assert(sizeof(InListType) == 0, "checkAndConvert() not implemented for requested type.");
187 llvm_unreachable("must have concrete instantiation");
188 return std::nullopt;
189}
190
191template <>
192std::optional<SmallVector<Attribute>> ArrayIndexGen::checkAndConvert(OperandRange multiDimIndex) {
193 return checkAndConvertMulti<Attribute>(multiDimIndex, shape, false);
194}
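Apache License, Version 2.0, January 2004 — TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work. "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from …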
Definition LICENSE.txt:45
std::optional< llvm::SmallVector< mlir::Attribute > > checkAndConvert(InListType multiDimIndex)
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< int64_t > linearize(InListType multiDimIndex) const
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
int64_t fromAPInt(const llvm::APInt &i)