20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
22#include <mlir/IR/Value.h>
24#include <llvm/Support/Debug.h>
27#include <unordered_set>
29#define DEBUG_TYPE "llzk-cdg"
36using namespace component;
37using namespace constrain;
38using namespace function;
43 return solver.lookupState<
Lattice>(val);
47 if (
const auto *state =
getLattice(solver, val)) {
48 return state->getValue();
53mlir::FailureOr<SourceRefLatticeValue>
55 llvm::SmallDenseMap<Value, SourceRefLatticeValue, 4> operandVals;
56 for (Value operand : op->getOperands()) {
60 SymbolTableCollection tables;
61 if (
auto memberRefOp = llvm::dyn_cast<MemberRefOpInterface>(op)) {
62 if (!memberRefOp.isRead()) {
63 auto memberOpRes = memberRefOp.getMemberDefOp(tables);
64 ensure(succeeded(memberOpRes),
"could not find member write");
65 auto componentIt = operandVals.find(memberRefOp.getComponent());
66 ensure(componentIt != operandVals.end(),
"missing component lattice for member write");
67 auto memberValsRes = componentIt->second.referenceMember(memberOpRes.value());
68 ensure(succeeded(memberValsRes),
"could not create SourceRef child for member write");
69 return memberValsRes->first;
73 if (
auto arrayAccessOp = llvm::dyn_cast<ArrayAccessOpInterface>(op)) {
74 if (llvm::isa<WriteArrayOp, InsertArrayOp>(arrayAccessOp)) {
75 auto array = arrayAccessOp.getArrRef();
76 auto it = operandVals.find(
array);
77 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
78 const auto &currVals = it->second;
80 std::vector<SourceRefIndex> indices;
81 for (
size_t i = 0; i < arrayAccessOp.getIndices().size(); ++i) {
82 auto idxOperand = arrayAccessOp.getIndices()[i];
83 auto idxIt = operandVals.find(idxOperand);
84 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
85 const auto &idxVals = idxIt->second;
87 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
88 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
90 auto arrayType = llvm::dyn_cast<ArrayType>(
array.getType());
91 auto lower = APInt::getZero(64);
92 assert(i <= std::numeric_limits<unsigned>::max() &&
"index too large");
93 APInt upper(64, arrayType.getDimSize(
static_cast<unsigned>(i)));
94 indices.emplace_back(lower, upper);
98 auto newValsRes = currVals.extract(indices);
99 ensure(succeeded(newValsRes),
"could not create SourceRef child for array access");
100 auto [newVals, _] = *newValsRes;
101 if (llvm::isa<WriteArrayOp>(arrayAccessOp)) {
102 ensure(newVals.isScalar(),
"array write must produce a scalar value");
108 return mlir::failure();
112 if (
auto value = llvm::dyn_cast_if_present<Value>(lattice->getAnchor())) {
118 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
120 LLVM_DEBUG(llvm::dbgs() <<
"SourceRefAnalysis::visitOperation: " << *op <<
'\n');
122 DenseMap<Value, const Lattice *> operandVals;
123 for (
auto [operand, lattice] : llvm::zip(op->getOperands(), operands)) {
124 operandVals[operand] = lattice;
127 if (
auto memberRefOp = llvm::dyn_cast<MemberRefOpInterface>(op)) {
128 auto memberOpRes = memberRefOp.getMemberDefOp(tables);
129 ensure(succeeded(memberOpRes),
"could not find member read");
131 operandVals.at(memberRefOp.getComponent())->getValue().referenceMember(memberOpRes.value());
132 ensure(succeeded(memberValsRes),
"could not create SourceRef child for member reference");
133 if (memberRefOp.isRead()) {
134 auto [memberVals, _] = *memberValsRes;
135 propagateIfChanged(results.front(), results.front()->setValue(memberVals));
140 if (
auto arrayAccessOp = llvm::dyn_cast<ArrayAccessOpInterface>(op)) {
141 if (!results.empty()) {
143 propagateIfChanged(results.front(), results.front()->setValue(newVals));
148 if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
149 auto createArrayRes = createArray.getResult();
150 const auto &elements = createArray.getElements();
151 if (elements.empty()) {
154 results.front()->setValue(
SourceRef(llvm::cast<OpResult>(createArrayRes)))
160 for (
size_t i = 0; i < elements.size(); i++) {
163 propagateIfChanged(results.front(), results.front()->setValue(newArrayVal));
167 if (
auto structNewOp = llvm::dyn_cast<CreateStructOp>(op)) {
169 propagateIfChanged(results.front(), results.front()->setValue(newStructValue));
174 for (
Lattice *result : results) {
175 propagateIfChanged(result, updated);
181 CallOpInterface call, ArrayRef<const Lattice *> operandLattices,
182 ArrayRef<Lattice *> resultLattices
184 auto callable = dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
185 if (!callable || !callable.getCallableRegion()) {
187 for (
auto [result, lattice] : llvm::zip(call->getResults(), resultLattices)) {
189 ensure(succeeded(resultRef),
"could not create external call SourceRef");
190 propagateIfChanged(lattice, lattice->setValue(*resultRef));
194 if (resultLattices.empty()) {
203 ensure(succeeded(funcOpRes),
"could not lookup called function");
204 auto funcOp = funcOpRes->get();
206 const auto *predecessors = getOrCreateFor<mlir::dataflow::PredecessorState>(
207 getProgramPointAfter(call), getProgramPointAfter(call)
211 if (!predecessors->allPredecessorsKnown()) {
215 const auto returnSites = predecessors->getKnownPredecessors();
217 std::unordered_map<SourceRef, SourceRefLatticeValue, SourceRef::Hash> translation;
218 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
219 translation[
SourceRef(funcOp.getArgument(i))] =
220 static_cast<const Lattice *
>(operandLattices[i])->getValue();
223 for (
auto [result, resultLattice] : llvm::zip(call->getResults(), resultLattices)) {
226 unsigned resultNum = llvm::cast<OpResult>(result).getResultNumber();
227 for (Operation *returnSite : returnSites) {
229 getProgramPointAfter(call.getOperation()),
230 returnSite->getOperand(resultNum)
233 auto [translatedVal, _] = retVal.translate(translation);
234 (void)combined.
update(translatedVal);
236 propagateIfChanged(resultLattice,
static_cast<Lattice *
>(resultLattice)->setValue(combined));
241 Operation *op,
const OperandValues &operandVals, ArrayRef<Lattice *> results
243 auto updated = ChangeResult::NoChange;
244 for (
auto [res, lattice] : llvm::zip(op->getResults(), results)) {
246 for (
const auto &[_, opVal] : operandVals) {
247 (void)cur.update(opVal->getValue());
249 updated |= lattice->setValue(cur);
258 auto it = operandVals.find(
array);
259 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
260 const auto &currVals = it->second->getValue();
262 std::vector<SourceRefIndex> indices;
263 for (
size_t i = 0; i < arrayAccessOp.
getIndices().size(); ++i) {
264 auto idxOperand = arrayAccessOp.
getIndices()[i];
265 auto idxIt = operandVals.find(idxOperand);
266 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
267 const auto &idxVals = idxIt->second->getValue();
269 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
270 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
272 auto arrayType = llvm::dyn_cast<ArrayType>(
array.getType());
273 auto lower = APInt::getZero(64);
274 assert(i <= std::numeric_limits<unsigned>::max() &&
"index too large");
275 APInt upper(64, arrayType.getDimSize(
static_cast<unsigned>(i)));
276 indices.emplace_back(lower, upper);
280 auto newValsRes = currVals.extract(indices);
281 ensure(succeeded(newValsRes),
"could not create SourceRef child for array access");
282 auto [newVals, _] = *newValsRes;
283 if (llvm::isa<ReadArrayOp, WriteArrayOp>(arrayAccessOp)) {
284 ensure(newVals.isScalar(),
"array read/write must produce a scalar value");
292 ModuleOp m,
StructDefOp s, DataFlowSolver &solver, AnalysisManager &am,
296 if (cdg.computeConstraints(solver, am).failed()) {
297 return mlir::failure();
309 std::set<std::set<SourceRef>> sortedSets;
310 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
311 if (!it->isLeader()) {
315 std::set<SourceRef> sortedMembers;
316 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
317 sortedMembers.insert(*mit);
322 if (sortedMembers.size() > 1) {
323 sortedSets.insert(sortedMembers);
327 for (
const auto &[ref, constSet] : constantSets) {
328 if (constSet.empty()) {
331 std::set<SourceRef> sortedMembers(constSet.begin(), constSet.end());
332 sortedMembers.insert(ref);
333 sortedSets.insert(sortedMembers);
336 os <<
"ConstraintDependencyGraph { ";
338 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
340 for (
auto mit = it->begin(); mit != it->end();) {
343 if (mit != it->end()) {
349 if (it == sortedSets.end()) {
359mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
360 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
366 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
377 constrainFnOp.walk([
this, &solver](Operation *op) {
382 for (Value operand : op->getOperands()) {
384 for (
const SourceRef &ref : operandRefs) {
385 ref2Val[ref].insert(operand);
388 for (Value result : op->getResults()) {
390 for (
const SourceRef &ref : resultRefs) {
391 ref2Val[ref].insert(result);
395 if (succeeded(writeTargetState)) {
396 for (
const SourceRef &ref : writeTargetState->foldToScalar()) {
397 ref2Val[ref].insert(op);
400 if (isa<EmitEqualityOp, EmitContainmentOp>(op)) {
401 this->walkConstrainOp(solver, op);
412 auto fnCallWalker = [
this, &solver, &am](CallOp fnCall)
mutable {
417 ensure(mlir::succeeded(res),
"could not resolve constrain call");
419 auto fn = res->get();
420 if (!fn.isStructConstrain()) {
424 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
428 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
429 SourceRef prefix(fn.getArgument(i));
430 Value operand = fnCall.getOperand(i);
432 translations.push_back({prefix, val});
434 auto &childAnalysis =
435 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
436 if (!childAnalysis.constructed(ctx)) {
438 mlir::succeeded(childAnalysis.runAnalysis(solver, am, {.runIntraprocedural = false})),
439 "could not construct CDG for child struct"
442 auto translatedCDG = childAnalysis.getResult(ctx).translate(translations);
444 const auto &translatedRef2Val = translatedCDG.getRef2Val();
445 ref2Val.insert(translatedRef2Val.begin(), translatedRef2Val.end());
449 auto &tSets = translatedCDG.signalSets;
450 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
451 if (!lit->isLeader()) {
454 auto leader = lit->getData();
455 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
456 signalSets.unionSets(leader, *mit);
460 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
461 constantSets[ref].insert(constSet.begin(), constSet.end());
464 if (!ctx.runIntraproceduralAnalysis()) {
465 constrainFnOp.walk(fnCallWalker);
468 return mlir::success();
471void ConstraintDependencyGraph::walkConstrainOp(
472 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
474 std::vector<SourceRef> signalUsages, constUsages;
476 for (
auto operand : emitOp->getOperands()) {
478 for (
const auto &ref : latticeVal.foldToScalar()) {
479 if (ref.isConstant()) {
480 constUsages.push_back(ref);
482 signalUsages.push_back(ref);
488 if (!signalUsages.empty()) {
489 auto it = signalUsages.begin();
490 auto leader = signalSets.getOrInsertLeaderValue(*it);
491 for (it++; it != signalUsages.end(); it++) {
492 signalSets.unionSets(leader, *it);
496 for (
auto &sig : signalUsages) {
497 constantSets[sig].insert(constUsages.begin(), constUsages.end());
505 [&translation](
const SourceRef &elem) -> mlir::FailureOr<std::vector<SourceRef>> {
506 std::vector<SourceRef> refs;
507 for (
auto &[prefix, vals] : translation) {
508 if (!elem.isValidPrefix(prefix)) {
512 if (vals.isArray()) {
514 auto suffix = elem.getSuffix(prefix);
516 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
519 auto resolvedValsRes = vals.extract(suffix.value());
520 ensure(succeeded(resolvedValsRes),
"could not create SourceRef child while resolving refs");
521 auto [resolvedVals, _] = *resolvedValsRes;
522 auto folded = resolvedVals.foldToScalar();
523 refs.insert(refs.end(), folded.begin(), folded.end());
525 for (
const auto &replacement : vals.getScalarValue()) {
526 auto translated = elem.translate(prefix, replacement);
527 if (mlir::succeeded(translated)) {
528 refs.push_back(translated.value());
534 return mlir::failure();
539 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
540 if (!leaderIt->isLeader()) {
544 std::vector<SourceRef> translatedSignals, translatedConsts;
545 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
547 if (mlir::failed(member)) {
550 for (
const auto &ref : *member) {
551 if (ref.isConstant()) {
552 translatedConsts.push_back(ref);
554 translatedSignals.push_back(ref);
558 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
559 const auto &origConstSet = it->second;
560 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
564 if (translatedSignals.empty()) {
569 auto it = translatedSignals.begin();
571 res.signalSets.insert(leader);
572 for (it++; it != translatedSignals.end(); it++) {
573 res.signalSets.insert(*it);
574 res.signalSets.unionSets(leader, *it);
578 for (
auto &ref : translatedSignals) {
579 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
584 for (
const auto &[ref, vals] : ref2Val) {
586 if (succeeded(translationRes)) {
587 for (
const auto &translatedRef : *translationRes) {
588 res.ref2Val[translatedRef].insert(vals.begin(), vals.end());
598 auto currRef = mlir::FailureOr<SourceRef>(ref);
599 while (mlir::succeeded(currRef)) {
601 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
602 if (currRef.value() != *it) {
607 auto constIt = constantSets.find(*currRef);
608 if (constIt != constantSets.end()) {
609 res.insert(constIt->second.begin(), constIt->second.end());
612 currRef = currRef->getParentPrefix();
620 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
626 if (mlir::failed(result)) {
627 return mlir::failure();
630 return mlir::success();
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, const CDGAnalysisContext &ctx) override
Construct a CDG, using the module's analysis manager to query ConstraintDependencyGraph objects for n...
A dependency graph of constraints enforced by an LLZK struct.
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const CDGAnalysisContext &ctx)
Compute a ConstraintDependencyGraph (CDG)
SourceRefSet getConstrainingValues(const SourceRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
void dump() const
Dumps the CDG to stderr.
ConstraintDependencyGraph translate(SourceRefRemappings translation) const
Translate the SourceRefs in this CDG to that of a different context.
static mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const OperandValues &operandVals, mlir::ArrayRef< Lattice * > results)
void visitExternalCall(mlir::CallOpInterface call, mlir::ArrayRef< const Lattice * > argumentLattices, mlir::ArrayRef< Lattice * > resultLattices) override
Visit a call operation to an externally defined function given the lattices of its arguments.
static mlir::FailureOr< SourceRefLatticeValue > getWriteTargetState(mlir::DataFlowSolver &solver, mlir::Operation *op)
static SourceRefLatticeValue arraySubdivisionOpUpdate(array::ArrayAccessOpInterface op, const OperandValues &operandVals)
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
void setToEntryState(Lattice *lattice) override
Set the given lattice element(s) at control flow entry point(s).
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Propagate SourceRef lattice values from operands to results.
static const Lattice * getLattice(mlir::DataFlowSolver &solver, mlir::Value val)
mlir::DenseMap< mlir::Value, const Lattice * > OperandValues
A value at a given point of the SourceRefLattice.
mlir::ChangeResult setValue(const LatticeValue &newValue)
static SourceRefLatticeValue getDefaultValue(ValueTy v)
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument, an allocation-like op result,...
A reference to a "source", which is the base value from which other SSA values are derived.
component::StructDefOp getStruct() const
void setResult(const CDGAnalysisContext &ctx, ConstraintDependencyGraph &&r)
mlir::ModuleOp getModule() const
::mlir::Operation::operand_range getIndices()
Gets the operand range containing the index for each dimension.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
const Derived & getElemFlatIdx(size_t i) const
Directly index into the flattened array using a single index.
const SourceRefLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
void setAllToEntryStates(mlir::ArrayRef< SourceRefLattice * > lattices)
bool isOperationLive(DataFlowSolver &solver, Operation *op)
std::vector< std::pair< SourceRef, SourceRefLatticeValue > > SourceRefRemappings
void ensure(bool condition, const llvm::Twine &errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
Parameters and shared objects to pass to child analyses.