22#include <mlir/IR/AsmState.h>
38compareDynamicAPInt(
const llvm::DynamicAPInt &lhs,
const llvm::DynamicAPInt &rhs) {
40 return std::strong_ordering::less;
43 return std::strong_ordering::greater;
45 return std::strong_ordering::equal;
48std::strong_ordering compareStringRef(llvm::StringRef lhs, llvm::StringRef rhs) {
49 int cmp = lhs.compare(rhs);
51 return std::strong_ordering::less;
54 return std::strong_ordering::greater;
56 return std::strong_ordering::equal;
60compareSourceRefPaths(llvm::ArrayRef<SourceRefIndex> lhs, llvm::ArrayRef<SourceRefIndex> rhs) {
61 for (
size_t i = 0; i < lhs.size() && i < rhs.size(); i++) {
62 if (
auto cmp = lhs[i] <=> rhs[i];
cmp != std::strong_ordering::equal) {
66 return lhs.size() <=> rhs.size();
80 if (ShapedType::isDynamic(int64_t(high))) {
83 os << low <<
':' << high;
91 return std::strong_ordering::less;
94 return std::strong_ordering::greater;
96 return std::strong_ordering::equal;
104 if (
auto cmp = compareDynamicAPInt(ll, rl);
cmp != std::strong_ordering::equal) {
107 return compareDynamicAPInt(lu, ru);
111 return std::strong_ordering::less;
114 return std::strong_ordering::greater;
117 return std::strong_ordering::less;
119 return std::strong_ordering::greater;
129 unsigned requiredBits = idx.getSignificantBits();
130 auto hash = llvm::hash_value(idx.trunc(requiredBits));
134 return llvm::hash_value(std::get<0>(r)) ^ llvm::hash_value(std::get<1>(r));
142SourceRef::SortCategory SourceRef::getSortCategory()
const {
143 if (isBlockArgument()) {
144 return SortCategory::BlockArgument;
146 if (isCreateStructOp()) {
147 return SortCategory::CreateStruct;
150 return SortCategory::NonDet;
153 return SortCategory::RootResult;
155 if (isTemplateConstant()) {
156 return SortCategory::TemplateConstant;
158 if (isConstantIndex()) {
159 return SortCategory::ConstantIndex;
161 if (isConstantFelt()) {
162 return SortCategory::ConstantFelt;
165 llvm::errs() << *
this <<
'\n';
166 llvm_unreachable(
"unhandled SourceRef sort category");
169StringRef SourceRef::getTemplateConstantName()
const {
170 auto constantVal = getConstant();
171 ensure(succeeded(constantVal),
"template constant must be constant");
172 auto constRead = llvm::dyn_cast<ConstReadOp>(constantVal->getDefiningOp());
173 ensure(constRead,
"template constant must be backed by const.read");
174 return constRead.getConstName();
178SourceRef::compareWithinCategory(
const SourceRef &rhs, SortCategory category)
const {
180 case SortCategory::BlockArgument: {
181 if (
auto cmp = *getInputNum() <=> *rhs.getInputNum();
cmp != std::strong_ordering::equal) {
184 if (
auto cmp = getAsOpaquePointer() <=> rhs.getAsOpaquePointer();
185 cmp != std::strong_ordering::equal) {
188 return compareSourceRefPaths(getPath(), rhs.getPath());
190 case SortCategory::CreateStruct:
191 case SortCategory::NonDet:
192 case SortCategory::RootResult: {
193 if (
auto cmp = getAsOpaquePointer() <=> rhs.getAsOpaquePointer();
194 cmp != std::strong_ordering::equal) {
197 return compareSourceRefPaths(getPath(), rhs.getPath());
199 case SortCategory::TemplateConstant: {
200 if (
auto cmp = compareStringRef(getTemplateConstantName(), rhs.getTemplateConstantName());
201 cmp != std::strong_ordering::equal) {
204 return getAsOpaquePointer() <=> rhs.getAsOpaquePointer();
206 case SortCategory::ConstantIndex:
207 return compareDynamicAPInt(*getConstantIndexValue(), *rhs.getConstantIndexValue());
208 case SortCategory::ConstantFelt:
209 return compareDynamicAPInt(*getConstantFeltValue(), *rhs.getConstantFeltValue());
212 llvm_unreachable(
"unhandled SourceRef category compare");
222SymbolLookupResult<StructDefOp>
230 return std::move(*sDef);
233std::vector<SourceRef>
235 std::vector<SourceRef> res = {root};
236 for (
const SourceRef &child : root.getAllChildren(tables,
mod)) {
237 auto recursiveChildren = getAllSourceRefs(tables,
mod, child);
238 res.insert(res.end(), recursiveChildren.begin(), recursiveChildren.end());
244 std::vector<SourceRef> res;
247 structDef == fnOp->getParentOfType<
StructDefOp>(),
"function must be within the given struct"
251 ensure(succeeded(modOp),
"could not lookup module from struct " + Twine(structDef.getName()));
253 SymbolTableCollection tables;
254 for (
auto a : fnOp.getArguments()) {
256 res.insert(res.end(), argRes.begin(), argRes.end());
263 auto createOp = dyn_cast_if_present<CreateStructOp>(selfVal.getDefiningOp());
264 ensure(createOp,
"self value should originate from struct.new operation");
265 auto selfRes =
getAllSourceRefs(tables, modOp.value(), SourceRef(createOp));
266 res.insert(res.end(), selfRes.begin(), selfRes.end());
273 std::vector<SourceRef> res;
276 memberDef->getParentOfType<
StructDefOp>() == structDef,
277 "Member " + Twine(memberDef.getName()) +
" is not a member of struct " +
278 Twine(structDef.getName())
281 ensure(succeeded(modOp),
"could not lookup module from struct " + Twine(structDef.getName()));
284 BlockArgument self = constrainFnOp.getArguments().front();
285 SourceRef memberRef = SourceRef(self, {
SourceRefIndex(memberDef)});
287 SymbolTableCollection tables;
293 size_t arrayDerefs = 0;
294 size_t idx = pathRef.size();
300 Type currTy = idx > 0 ? pathRef[idx - 1].getMember().getType() : value.getType();
301 if (arrayDerefs > 0) {
302 auto arrTy = dyn_cast<ArrayType>(currTy);
303 ensure(
static_cast<bool>(arrTy),
"SourceRef array indices require an array-typed base");
305 arrayDerefs <= arrTy.getDimensionSizes().size(),
306 "SourceRef indexes more array dimensions than exist in the base type"
309 if (arrayDerefs == arrTy.getDimensionSizes().size()) {
310 currTy = arrTy.getElementType();
313 ArrayType::get(arrTy.getElementType(), arrTy.getDimensionSizes().drop_front(arrayDerefs));
325 auto prefixPath = prefix.
getPath();
326 if (value != prefix.value || pathRef.size() < prefixPath.size()) {
329 for (
size_t i = 0; i < prefixPath.size(); i++) {
330 if (pathRef[i] != prefixPath[i]) {
343 auto prefixPath = prefix.
getPath();
344 suffix.reserve(pathRef.size() - prefixPath.size());
345 for (
size_t i = prefixPath.size(); i < pathRef.size(); i++) {
346 suffix.push_back(pathRef[i]);
356 if (failed(suffix)) {
360 SourceRef newSignalUsage = other;
363 pathRef.insert(pathRef.end(), suffix->begin(), suffix->end());
366 return newSignalUsage;
372 std::vector<SourceRef> res;
374 for (int64_t i = 0; i < arrayTy.getDimSize(0); i++) {
376 ensure(succeeded(childRef),
"array children require a rooted SourceRef");
377 res.push_back(*childRef);
387 std::vector<SourceRef> res;
395 auto structDefCopy = structDefRes;
397 tables, SymbolRefAttr::get(f.getContext(), f.getSymNameAttr()), std::move(structDefCopy),
400 ensure(succeeded(memberLookup),
"could not get SymbolLookupResult of existing MemberDefOp");
402 ensure(succeeded(childRef),
"struct children require a rooted SourceRef");
405 res.push_back(*childRef);
410std::vector<SourceRef>
413 if (
auto structTy = dyn_cast<StructType>(ty)) {
415 }
else if (
auto arrayType = dyn_cast<ArrayType>(ty)) {
428 auto constRead = getDefiningOp<ConstReadOp>();
429 ensure(succeeded(constRead),
"template constant should be backed by a const.read op");
430 auto structDefOp = (*constRead)->getParentOfType<
StructDefOp>();
431 ensure(structDefOp,
"struct template should have a struct parent");
432 os <<
'@' << structDefOp.getName() <<
"<[@" << constRead->getConstName() <<
"]>";
442 os <<
"<call " << callOp.getCallee();
444 Operation *printScope = callOp.getOperation();
445 if (
auto funcOp = callOp->getParentOfType<
FuncDefOp>()) {
446 printScope = funcOp.getOperation();
450 AsmState state(printScope);
451 value.printAsOperand(os, state);
455 OpPrintingFlags flags;
456 value.printAsOperand(os, flags);
459 for (
const auto &f :
getPath()) {
460 os <<
"[" << f <<
"]";
471 return constant == rhs.constant && value == rhs.value && llvm::equal(
getPath(), rhs.
getPath());
476 auto lhsCategory = getSortCategory();
477 auto rhsCategory = rhs.getSortCategory();
478 if (
auto cmp = lhsCategory <=> rhsCategory;
cmp != std::strong_ordering::equal) {
481 return compareWithinCategory(rhs, lhsCategory);
488 return llvm::hash_value(val.getAsOpaquePointer());
492 "unhandled SourceRef hash case"
495 size_t hash = llvm::hash_value(val.getAsOpaquePointer());
496 for (
const auto &f : val.
getPath()) {
497 hash = llvm::hash_combine(hash, f.getHash());
511 insert(rhs.begin(), rhs.end());
517 std::vector<SourceRef> sortedRefs(rhs.begin(), rhs.end());
518 std::sort(sortedRefs.begin(), sortedRefs.end());
519 for (
auto it = sortedRefs.begin(); it != sortedRefs.end();) {
522 if (it != sortedRefs.end()) {
This file implements helper methods for constructing DynamicAPInts.
This file defines methods symbol lookup across LLZK operations and included files.
std::strong_ordering operator<=>(const SourceRefIndex &rhs) const
bool isIndexRange() const
llvm::DynamicAPInt getIndex() const
void print(mlir::raw_ostream &os) const
IndexRange getIndexRange() const
component::MemberDefOp getMember() const
SourceRefIndex(component::MemberDefOp f)
SourceRefSet & join(const SourceRefSet &rhs)
A reference to a "source", which is the base value from which other SSA values are derived.
bool isBlockArgument() const
mlir::FailureOr< SourceRef > createChild(const SourceRefIndex &r) const
std::vector< SourceRef > getAllChildren(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod) const
Get all direct children of this SourceRef, assuming this ref is not a scalar.
mlir::FailureOr< std::vector< SourceRefIndex > > getSuffix(const SourceRef &prefix) const
If prefix is a valid prefix of this reference, return the suffix that remains after removing the pref...
mlir::FailureOr< function::CallOp > getCallOp() const
void print(mlir::raw_ostream &os) const
bool isCallResult() const
bool operator==(const SourceRef &rhs) const
bool isConstantFelt() const
llvm::ArrayRef< SourceRefIndex > getPath() const
bool isValidPrefix(const SourceRef &prefix) const
Returns true iff prefix is a valid prefix of this reference.
std::strong_ordering operator<=>(const SourceRef &rhs) const
mlir::FailureOr< llvm::DynamicAPInt > getConstantFeltValue() const
bool isConstantIndex() const
std::vector< SourceRefIndex > Path
static std::vector< SourceRef > getAllSourceRefs(mlir::SymbolTableCollection &tables, mlir::ModuleOp mod, const SourceRef &root)
Produce all possible SourceRefs that are present starting from the given root.
mlir::FailureOr< llvm::DynamicAPInt > getConstantValue() const
mlir::FailureOr< unsigned > getInputNum() const
mlir::FailureOr< NonDetOp > getNonDetOp() const
mlir::FailureOr< llvm::DynamicAPInt > getConstantIndexValue() const
mlir::FailureOr< SourceRef > translate(const SourceRef &prefix, const SourceRef &other) const
Create a new reference with prefix replaced with other iff prefix is a valid prefix for this referenc...
bool isTemplateConstant() const
bool isConstantInt() const
bool isCreateStructOp() const
mlir::Type getType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
ExpressionValue mod(const llvm::SMTSolverRef &solver, const ExpressionValue &lhs, const ExpressionValue &rhs)
void ensure(bool condition, const llvm::Twine &errMsg)
FailureOr< ModuleOp > getRootModule(Operation *from)
ExpressionValue cmp(const llvm::SMTSolverRef &solver, CmpOp op, const ExpressionValue &lhs, const ExpressionValue &rhs)
Interval operator<<(const Interval &lhs, const Interval &rhs)
std::vector< SourceRef > getAllChildren(SymbolTableCollection &, ModuleOp, ArrayType arrayTy, const SourceRef &root)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
APSInt toAPSInt(const DynamicAPInt &i)
SymbolLookupResult< StructDefOp > getStructDef(SymbolTableCollection &tables, ModuleOp mod, StructType ty)
Lookup a StructDefOp from a given StructType.
size_t operator()(const SourceRefIndex &c) const
size_t operator()(const SourceRef &val) const