19#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
20#include <mlir/IR/Value.h>
22#include <llvm/Support/Debug.h>
25#include <unordered_set>
27#define DEBUG_TYPE "llzk-constrain-ref-lattice"
34using namespace component;
36using namespace polymorphic;
49std::pair<SourceRefLatticeValue, mlir::ChangeResult>
52 auto res = mlir::ChangeResult::NoChange;
53 if (newVal.isScalar()) {
54 res = newVal.translateScalar(translation);
56 for (
auto &elem : newVal.getArrayValue()) {
57 auto [newElem, elemRes] = elem->translate(translation);
65mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
68 auto transform = [&idx](
const SourceRef &r) -> mlir::FailureOr<SourceRef> {
69 return r.createChild(idx);
74mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
80 std::vector<size_t> currIdxs {0};
81 for (
unsigned i = 0; i < indices.size(); i++) {
82 const auto &idx = indices[i];
85 std::vector<size_t> newIdxs;
86 ensure(idx.isIndex() || idx.isIndexRange(),
"wrong type of index for array");
88 int64_t idxVal(idx.getIndex());
90 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
91 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
94 auto [low, high] = idx.getIndexRange();
95 int64_t lowInt(low), highInt(high);
96 for (int64_t idxVal = lowInt; idxVal < highInt; idxVal++) {
98 currIdxs.begin(), currIdxs.end(), std::back_inserter(newIdxs),
99 [&currDim, &idxVal](
size_t j) { return j * currDim + idxVal; }
106 std::vector<int64_t> newArrayDims;
110 newArrayDims.push_back(dim);
113 if (newArrayDims.empty()) {
116 for (
auto idx : currIdxs) {
119 return std::make_pair(extractedVal, mlir::ChangeResult::Change);
123 for (
auto chunkStart : currIdxs) {
124 for (
size_t i = 0; i < chunkSz; i++) {
128 return std::make_pair(extractedVal, mlir::ChangeResult::Change);
131 auto currVal = *
this;
132 auto res = mlir::ChangeResult::NoChange;
133 for (
const auto &idx : indices) {
134 auto transform = [&idx](
const SourceRef &r) -> mlir::FailureOr<SourceRef> {
135 return r.createChild(idx);
137 auto transformedVal = currVal.elementwiseTransform(transform);
138 if (failed(transformedVal)) {
139 return mlir::failure();
141 auto [newVal, transformRes] = *transformedVal;
142 currVal = std::move(newVal);
145 return std::make_pair(currVal, res);
150 auto res = mlir::ChangeResult::NoChange;
158 for (
const SourceRef &currRef : currVal) {
159 for (
const auto &[prefix, replacementVal] : translation) {
160 if (currRef.isValidPrefix(prefix)) {
161 for (
const SourceRef &replacementPrefix : replacementVal.foldToScalar()) {
162 auto translatedRefRes = currRef.translate(prefix, replacementPrefix);
163 if (succeeded(translatedRefRes)) {
164 res |=
insert(*translatedRefRes);
173mlir::FailureOr<std::pair<SourceRefLatticeValue, mlir::ChangeResult>>
175 llvm::function_ref<mlir::FailureOr<SourceRef>(
const SourceRef &)> transform
178 auto res = mlir::ChangeResult::NoChange;
179 if (newVal.isScalar()) {
181 for (
const auto &ref : newVal.getScalarValue()) {
182 auto transformedRef = transform(ref);
183 if (failed(transformedRef)) {
184 return mlir::failure();
186 auto [_, inserted] = indexed.insert(*transformedRef);
188 res |= mlir::ChangeResult::Change;
191 newVal.getScalarValue() = indexed;
193 for (
auto &elem : newVal.getArrayValue()) {
194 auto transformedElem = elem->elementwiseTransform(transform);
195 if (failed(transformedElem)) {
196 return mlir::failure();
198 auto [newElem, elemRes] = *transformedElem;
199 (*elem) = std::move(newElem);
203 return std::make_pair(newVal, res);
214 if (
auto blockArg = llvm::dyn_cast<mlir::BlockArgument>(val)) {
216 }
else if (
auto *defOp = val.getDefiningOp()) {
217 if (
auto feltConst = llvm::dyn_cast<FeltConstantOp>(defOp)) {
219 }
else if (
auto constIdx = llvm::dyn_cast<mlir::arith::ConstantIndexOp>(defOp)) {
221 }
else if (
auto readConst = llvm::dyn_cast<ConstReadOp>(defOp)) {
223 }
else if (
auto structNew = llvm::dyn_cast<CreateStructOp>(defOp)) {
225 }
else if (
auto nonDet = llvm::dyn_cast<NonDetOp>(defOp)) {
227 }
else if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(defOp)) {
228 return SourceRef(createArray->getResult(0));
229 }
else if (llvm::isa<function::CallOp>(defOp)) {
230 auto callResult = llvm::dyn_cast<mlir::OpResult>(val);
231 ensure(callResult !=
nullptr,
"function.call value should be an OpResult");
235 return mlir::failure();
239 if (
auto asVal = llvm::dyn_cast_if_present<Value>(v)) {
241 if (mlir::succeeded(sourceRef)) {
253 llvm::report_fatal_error(
"meet operation is not supported for SourceRefLattice");
254 return ChangeResult::NoChange;
258 os <<
"SourceRefLattice { " << value <<
" }";
262 return value.setValue(newValue);
273raw_ostream &
operator<<(raw_ostream &os, llvm::PointerUnion<mlir::Value, mlir::Operation *> ptr) {
274 if (
auto asVal = llvm::dyn_cast_if_present<Value>(ptr)) {
276 }
else if (
auto *asOp = llvm::dyn_cast_if_present<Operation *>(ptr)) {
279 os <<
"<<null PointerUnion>>";
Defines an index into an LLZK object.
A value at a given point of the SourceRefLattice.
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > referenceMember(SymbolLookupResult< component::MemberDefOp > memberRef) const
Add the given memberRef to the SourceRefs contained within this value.
virtual mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > elementwiseTransform(llvm::function_ref< mlir::FailureOr< SourceRef >(const SourceRef &)> transform) const
Perform a recursive transformation over all elements of this value and return a new value with the mo...
mlir::ChangeResult insert(const SourceRef &rhs)
Directly insert the ref into this value.
std::pair< SourceRefLatticeValue, mlir::ChangeResult > translate(const TranslationMap &translation) const
For the refs contained in this value, translate them given the translation map and return the transfo...
mlir::FailureOr< std::pair< SourceRefLatticeValue, mlir::ChangeResult > > extract(const std::vector< SourceRefIndex > &indices) const
Perform an array.extract or array.read operation, depending on how many indices are provided.
SourceRefLatticeValue(ScalarTy s)
mlir::ChangeResult translateScalar(const TranslationMap &translation)
Translate this value using the translation map, assuming this value is a scalar.
Sparse SSA-value lattice for SourceRef propagation.
mlir::ChangeResult join(const AbstractSparseLattice &rhs) override
mlir::ChangeResult setValue(const LatticeValue &newValue)
mlir::ChangeResult meet(const AbstractSparseLattice &rhs) override
static SourceRefLatticeValue getDefaultValue(ValueTy v)
void print(mlir::raw_ostream &os) const override
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument, an allocation-like op result,...
llvm::PointerUnion< mlir::Value, mlir::Operation * > ValueTy
SourceRefLatticeValue LatticeValue
A reference to a "source", which is the base value from which other SSA values are derived.
size_t getNumArrayDims() const
mlir::ChangeResult updateScalar(const ScalarTy &rhs)
int64_t getArrayDim(unsigned i) const
std::variant< ScalarTy, ArrayTy > & getValue()
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
mlir::ChangeResult foldAndUpdate(const SourceRefLatticeValue &rhs)
const ScalarTy & getScalarValue() const
const SourceRefLatticeValue & getElemFlatIdx(size_t i) const
void print(mlir::raw_ostream &os) const
raw_ostream & operator<<(raw_ostream &os, llvm::PointerUnion< mlir::Value, mlir::Operation * > ptr)
void ensure(bool condition, const llvm::Twine &errMsg)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::unordered_map< SourceRef, SourceRefLatticeValue, SourceRef::Hash > TranslationMap