LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
AbstractLatticeValue.h
Go to the documentation of this file.
1//===-- AbstractLatticeValue.h ----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
10#pragma once
11
12#include "llzk/Util/Debug.h"
14
15#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
16#include <mlir/Support/LLVM.h>
17
18#include <llvm/Support/Debug.h>
19
20#include <concepts>
21#include <type_traits>
22#include <variant>
23
24#define DEBUG_TYPE "llzk-abstract-lattice-value"
25
26namespace llzk::dataflow {
27
28template <typename Val>
30 // Require default constructable
31 std::default_initializable<Val> && requires(Val lhs, Val rhs, mlir::raw_ostream &os) {
32 // Require a form of print function
33 { os << lhs } -> std::same_as<mlir::raw_ostream &>;
34 // Require comparability
35 { lhs == rhs } -> std::same_as<bool>;
36 // Require the ability to combine two scalar values
37 { lhs.join(rhs) } -> std::same_as<Val &>;
38 };
39
40template <typename Derived, ScalarLatticeValue ScalarTy> class AbstractLatticeValue {
41 friend Derived;
42
50 using ArrayTy = std::vector<std::unique_ptr<Derived>>;
51
54 static ArrayTy constructArrayTy(const mlir::ArrayRef<int64_t> &shape) {
55 size_t totalElem = 1;
56 for (auto dim : shape) {
57 ensure(!mlir::ShapedType::isDynamic(dim), "Cannot pre-allocate dynamically-sized array");
58 totalElem *= dim;
59 }
60 ArrayTy arr(totalElem);
61 for (auto it = arr.begin(); it != arr.end(); it++) {
62 *it = std::make_unique<Derived>();
63 }
64 return arr;
65 }
66
67 static inline bool isDynamicArray(const mlir::ArrayRef<int64_t> &shape) {
68 return mlir::ShapedType::isDynamicShape(shape);
69 }
70
71 explicit AbstractLatticeValue(ScalarTy s)
72 : value(s), arrayShape(std::nullopt), isDynamic(false) {}
73 AbstractLatticeValue() : AbstractLatticeValue(ScalarTy()) {}
74 explicit AbstractLatticeValue(const mlir::ArrayRef<int64_t> shape)
75 : arrayShape(shape), isDynamic(isDynamicArray(shape)) {
76 if (isDynamic) {
77 value = ScalarTy();
78 } else {
79 value = constructArrayTy(shape);
80 }
81 }
82
83 AbstractLatticeValue(const AbstractLatticeValue &rhs) { *this = rhs; }
84 AbstractLatticeValue(AbstractLatticeValue &&rhs) = default;
85
86 // Enable copying by duplicating unique_ptrs and copying the contained values.
87 AbstractLatticeValue &operator=(const AbstractLatticeValue &rhs) {
88 copyArrayShape(rhs);
89 if (rhs.isScalar() || rhs.isDynamicArray()) {
90 getValue() = rhs.getScalarValue();
91 } else {
92 // create an empty array of the same size
93 getValue() = constructArrayTy(rhs.getArrayShape());
94 auto &lhsArr = getArrayValue();
95 auto &rhsArr = rhs.getArrayValue();
96 for (unsigned i = 0; i < lhsArr.size(); i++) {
97 // Recursive copy assignment of lattice values
98 *lhsArr[i] = *rhsArr[i];
99 }
100 }
101 return *this;
102 }
103 AbstractLatticeValue &operator=(AbstractLatticeValue &&rhs) = default;
104
105public:
106 bool isScalar() const { return std::holds_alternative<ScalarTy>(value); }
107 bool isSingleValue() const { return isScalar() && getScalarValue().size() == 1; }
108 bool isArray() const { return std::holds_alternative<ArrayTy>(value); }
109 bool isDynamicArray() const { return isDynamic; }
110
111 const ScalarTy &getScalarValue() const {
112 ensure(isScalar(), "not a scalar value");
113 return std::get<ScalarTy>(value);
114 }
115
116 ScalarTy &getScalarValue() {
117 ensure(isScalar(), "not a scalar value");
118 return std::get<ScalarTy>(value);
119 }
120
121 const ArrayTy &getArrayValue() const {
122 ensure(isArray() && !isDynamicArray(), "not a static array value");
123 return std::get<ArrayTy>(value);
124 }
125
126 ArrayTy &getArrayValue() {
127 ensure(isArray() && !isDynamicArray(), "not a static array value");
128 return std::get<ArrayTy>(value);
129 }
130
132 const Derived &getElemFlatIdx(size_t i) const {
133 ensure(isArray() && !isDynamicArray(), "not a static array value");
134 auto &arr = getArrayValue();
135 ensure(i < arr.size(), "index out of range");
136 return *arr.at(i);
137 }
138
139 Derived &getElemFlatIdx(size_t i) {
140 ensure(isArray() && !isDynamicArray(), "not a static array value");
141 auto &arr = getArrayValue();
142 ensure(i < arr.size(), "index out of range");
143 return *arr.at(i);
144 }
145
146 size_t getArraySize() const { return getArrayValue().size(); }
147
148 size_t getNumArrayDims() const { return getArrayShape().size(); }
149
150 void print(mlir::raw_ostream &os) const {
151 if (isScalar() || isDynamicArray()) {
152 os << getScalarValue();
153 } else {
154 os << "[ ";
155 const auto &arr = getArrayValue();
156 for (auto it = arr.begin(); it != arr.end();) {
157 (*it)->print(os);
158 it++;
159 if (it != arr.end()) {
160 os << ", ";
161 } else {
162 os << ' ';
163 }
164 }
165 os << ']';
166 }
167 }
168
171 ScalarTy foldToScalar() const {
172 if (isScalar()) {
173 return getScalarValue();
174 }
175
176 ScalarTy res;
177 for (auto &val : getArrayValue()) {
178 auto rhs = val->foldToScalar();
179 res.join(rhs);
180 }
181 return res;
182 }
183
186 mlir::ChangeResult setValue(const AbstractLatticeValue &rhs) {
187 if (*this == rhs) {
188 return mlir::ChangeResult::NoChange;
189 }
190 *this = rhs;
191 return mlir::ChangeResult::Change;
192 }
193
195 mlir::ChangeResult update(const Derived &rhs) {
196 if (isScalar() && rhs.isScalar()) {
197 return updateScalar(rhs.getScalarValue());
198 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
199 return updateArray(rhs.getArrayValue());
200 } else {
201 return foldAndUpdate(rhs);
202 }
203 }
204
205 bool operator==(const AbstractLatticeValue &rhs) const {
206 if (isScalar() && rhs.isScalar()) {
207 return getScalarValue() == rhs.getScalarValue();
208 } else if (isArray() && rhs.isArray() && getArraySize() == rhs.getArraySize()) {
209 for (size_t i = 0; i < getArraySize(); i++) {
210 if (getElemFlatIdx(i) != rhs.getElemFlatIdx(i)) {
211 return false;
212 }
213 }
214 return true;
215 }
216 return false;
217 }
218
219protected:
220 std::variant<ScalarTy, ArrayTy> &getValue() { return value; }
221
222 const std::vector<int64_t> &getArrayShape() const {
223 ensure(arrayShape != std::nullopt, "not an array value");
224 return arrayShape.value();
225 }
226
227 int64_t getArrayDim(unsigned i) const {
228 const auto &arrShape = getArrayShape();
229 ensure(i < arrShape.size(), "dimension index out of bounds");
230 return arrShape.at(i);
231 }
232
233 void copyArrayShape(const AbstractLatticeValue &rhs) {
234 arrayShape = rhs.arrayShape;
235 isDynamic = rhs.isDynamic;
236 }
237
239 mlir::ChangeResult updateScalar(const ScalarTy &rhs) {
240 auto lhs = getScalarValue();
241 lhs.join(rhs);
242 if (getScalarValue() == lhs) {
243 return mlir::ChangeResult::NoChange;
244 }
245 getScalarValue() = lhs;
246 return mlir::ChangeResult::Change;
247 }
248
250 mlir::ChangeResult updateArray(const ArrayTy &rhs) {
251 mlir::ChangeResult res = mlir::ChangeResult::NoChange;
252 auto &lhs = getArrayValue();
253 for (size_t i = 0; i < getArraySize(); i++) {
254 res |= lhs[i]->update(*rhs.at(i));
255 }
256 return res;
257 }
258
261 mlir::ChangeResult foldAndUpdate(const Derived &rhs) {
262 auto folded = foldToScalar();
263 auto rhsScalar = rhs.foldToScalar();
264 folded.join(rhsScalar);
265 if (isScalar() && getScalarValue() == folded) {
266 return mlir::ChangeResult::NoChange;
267 }
268 getValue() = folded;
269 return mlir::ChangeResult::Change;
270 }
271
272private:
273 std::variant<ScalarTy, ArrayTy> value;
274 std::optional<std::vector<int64_t>> arrayShape;
275 bool isDynamic;
276};
277
278template <typename Derived, ScalarLatticeValue ScalarTy>
279mlir::raw_ostream &
280operator<<(mlir::raw_ostream &os, const AbstractLatticeValue<Derived, ScalarTy> &v) {
281 v.print(os);
282 return os;
283}
284
285} // namespace llzk::dataflow
286
287#undef DEBUG_TYPE
bool operator==(const AbstractLatticeValue &rhs) const
mlir::ChangeResult updateArray(const ArrayTy &rhs)
Union this value with the given array.
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
Union this value with the given scalar.
const std::vector< int64_t > & getArrayShape() const
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const Derived &rhs)
Folds the current value into a scalar and folds rhs to a scalar and updates the current value to the join of the two scalars.
void copyArrayShape(const AbstractLatticeValue &rhs)
const Derived & getElemFlatIdx(size_t i) const
Directly index into the flattened array using a single index.
void print(mlir::raw_ostream &os) const
mlir::raw_ostream & operator<<(mlir::raw_ostream &os, const AbstractLatticeValue< Derived, ScalarTy > &v)
void ensure(bool condition, const llvm::Twine &errMsg)
bool isDynamic(IntegerAttr intAttr)