20#include <mlir/Analysis/DataFlow/DeadCodeAnalysis.h>
21#include <mlir/Analysis/DataFlow/DenseAnalysis.h>
22#include <mlir/IR/Value.h>
24#include <llvm/Support/Debug.h>
27#include <unordered_set>
29#define DEBUG_TYPE "llzk-cdg"
36using namespace component;
37using namespace constrain;
38using namespace function;
40static bool isOperationLive(DataFlowSolver &solver, Operation *op) {
41 if (!op->getBlock()) {
44 if (
const auto *exec = solver.lookupState<mlir::dataflow::Executable>(
45 solver.getProgramPointBefore(op->getBlock())
47 return exec->isLive();
55 return solver.lookupState<
Lattice>(val);
59 if (
const auto *state =
getLattice(solver, val)) {
60 return state->getValue();
65mlir::FailureOr<SourceRefLatticeValue>
67 llvm::SmallDenseMap<Value, SourceRefLatticeValue, 4> operandVals;
68 for (Value operand : op->getOperands()) {
72 SymbolTableCollection tables;
73 if (
auto memberRefOp = llvm::dyn_cast<MemberRefOpInterface>(op)) {
74 if (!memberRefOp.isRead()) {
75 auto memberOpRes = memberRefOp.getMemberDefOp(tables);
76 ensure(succeeded(memberOpRes),
"could not find member write");
77 auto componentIt = operandVals.find(memberRefOp.getComponent());
78 ensure(componentIt != operandVals.end(),
"missing component lattice for member write");
79 auto memberValsRes = componentIt->second.referenceMember(memberOpRes.value());
80 ensure(succeeded(memberValsRes),
"could not create SourceRef child for member write");
81 return memberValsRes->first;
85 if (
auto arrayAccessOp = llvm::dyn_cast<ArrayAccessOpInterface>(op)) {
86 if (llvm::isa<WriteArrayOp, InsertArrayOp>(arrayAccessOp)) {
87 auto array = arrayAccessOp.getArrRef();
88 auto it = operandVals.find(
array);
89 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
90 const auto &currVals = it->second;
92 std::vector<SourceRefIndex> indices;
93 for (
size_t i = 0; i < arrayAccessOp.getIndices().size(); ++i) {
94 auto idxOperand = arrayAccessOp.getIndices()[i];
95 auto idxIt = operandVals.find(idxOperand);
96 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
97 const auto &idxVals = idxIt->second;
99 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
100 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
102 auto arrayType = llvm::dyn_cast<ArrayType>(
array.getType());
103 auto lower = APInt::getZero(64);
104 assert(i <= std::numeric_limits<unsigned>::max() &&
"index too large");
105 APInt upper(64, arrayType.getDimSize(
static_cast<unsigned>(i)));
106 indices.emplace_back(lower, upper);
110 auto newValsRes = currVals.extract(indices);
111 ensure(succeeded(newValsRes),
"could not create SourceRef child for array access");
112 auto [newVals, _] = *newValsRes;
113 if (llvm::isa<WriteArrayOp>(arrayAccessOp)) {
114 ensure(newVals.isScalar(),
"array write must produce a scalar value");
120 return mlir::failure();
124 if (
auto value = llvm::dyn_cast_if_present<Value>(lattice->getAnchor())) {
130 Operation *op, ArrayRef<const Lattice *> operands, ArrayRef<Lattice *> results
132 LLVM_DEBUG(llvm::dbgs() <<
"SourceRefAnalysis::visitOperation: " << *op <<
'\n');
134 DenseMap<Value, const Lattice *> operandVals;
135 for (
auto [operand, lattice] : llvm::zip(op->getOperands(), operands)) {
136 operandVals[operand] = lattice;
139 if (
auto memberRefOp = llvm::dyn_cast<MemberRefOpInterface>(op)) {
140 auto memberOpRes = memberRefOp.getMemberDefOp(tables);
141 ensure(succeeded(memberOpRes),
"could not find member read");
143 operandVals.at(memberRefOp.getComponent())->getValue().referenceMember(memberOpRes.value());
144 ensure(succeeded(memberValsRes),
"could not create SourceRef child for member reference");
145 if (memberRefOp.isRead()) {
146 auto [memberVals, _] = *memberValsRes;
147 propagateIfChanged(results.front(), results.front()->setValue(memberVals));
152 if (
auto arrayAccessOp = llvm::dyn_cast<ArrayAccessOpInterface>(op)) {
153 if (!results.empty()) {
155 propagateIfChanged(results.front(), results.front()->setValue(newVals));
160 if (
auto createArray = llvm::dyn_cast<CreateArrayOp>(op)) {
161 auto createArrayRes = createArray.getResult();
162 const auto &elements = createArray.getElements();
163 if (elements.empty()) {
166 results.front()->setValue(
SourceRef(llvm::cast<OpResult>(createArrayRes)))
172 for (
size_t i = 0; i < elements.size(); i++) {
175 propagateIfChanged(results.front(), results.front()->setValue(newArrayVal));
179 if (
auto structNewOp = llvm::dyn_cast<CreateStructOp>(op)) {
181 propagateIfChanged(results.front(), results.front()->setValue(newStructValue));
186 for (
Lattice *result : results) {
187 propagateIfChanged(result, updated);
193 CallOpInterface call, ArrayRef<const Lattice *> operandLattices,
194 ArrayRef<Lattice *> resultLattices
196 auto callable = dyn_cast_if_present<CallableOpInterface>(call.resolveCallable());
197 if (!callable || !callable.getCallableRegion()) {
199 for (
auto [result, lattice] : llvm::zip(call->getResults(), resultLattices)) {
201 ensure(succeeded(resultRef),
"could not create external call SourceRef");
202 propagateIfChanged(lattice, lattice->setValue(*resultRef));
209 ensure(succeeded(funcOpRes),
"could not lookup called function");
210 auto funcOp = funcOpRes->get();
212 const auto *predecessors = getOrCreateFor<mlir::dataflow::PredecessorState>(
213 getProgramPointAfter(call), getProgramPointAfter(call)
217 if (!predecessors->allPredecessorsKnown()) {
221 const auto returnSites = predecessors->getKnownPredecessors();
223 std::unordered_map<SourceRef, SourceRefLatticeValue, SourceRef::Hash> translation;
224 for (
unsigned i = 0; i < funcOp.getNumArguments(); i++) {
225 translation[
SourceRef(funcOp.getArgument(i))] =
226 static_cast<const Lattice *
>(operandLattices[i])->getValue();
229 for (
auto [result, resultLattice] : llvm::zip(call->getResults(), resultLattices)) {
232 unsigned resultNum = llvm::cast<OpResult>(result).getResultNumber();
233 for (Operation *returnSite : returnSites) {
235 getProgramPointAfter(call.getOperation()),
236 returnSite->getOperand(resultNum)
239 auto [translatedVal, _] = retVal.translate(translation);
240 (void)combined.
update(translatedVal);
242 propagateIfChanged(resultLattice,
static_cast<Lattice *
>(resultLattice)->setValue(combined));
247 Operation *op,
const OperandValues &operandVals, ArrayRef<Lattice *> results
249 auto updated = ChangeResult::NoChange;
250 for (
auto [res, lattice] : llvm::zip(op->getResults(), results)) {
252 for (
const auto &[_, opVal] : operandVals) {
253 (void)cur.update(opVal->getValue());
255 updated |= lattice->setValue(cur);
264 auto it = operandVals.find(
array);
265 ensure(it != operandVals.end(),
"improperly constructed operandVals map");
266 const auto &currVals = it->second->getValue();
268 std::vector<SourceRefIndex> indices;
269 for (
size_t i = 0; i < arrayAccessOp.
getIndices().size(); ++i) {
270 auto idxOperand = arrayAccessOp.
getIndices()[i];
271 auto idxIt = operandVals.find(idxOperand);
272 ensure(idxIt != operandVals.end(),
"improperly constructed operandVals map");
273 const auto &idxVals = idxIt->second->getValue();
275 if (idxVals.isSingleValue() && idxVals.getSingleValue().isConstant()) {
276 indices.emplace_back(*idxVals.getSingleValue().getConstantValue());
278 auto arrayType = llvm::dyn_cast<ArrayType>(
array.getType());
279 auto lower = APInt::getZero(64);
280 assert(i <= std::numeric_limits<unsigned>::max() &&
"index too large");
281 APInt upper(64, arrayType.getDimSize(
static_cast<unsigned>(i)));
282 indices.emplace_back(lower, upper);
286 auto newValsRes = currVals.extract(indices);
287 ensure(succeeded(newValsRes),
"could not create SourceRef child for array access");
288 auto [newVals, _] = *newValsRes;
289 if (llvm::isa<ReadArrayOp, WriteArrayOp>(arrayAccessOp)) {
290 ensure(newVals.isScalar(),
"array read/write must produce a scalar value");
298 ModuleOp m,
StructDefOp s, DataFlowSolver &solver, AnalysisManager &am,
302 if (cdg.computeConstraints(solver, am).failed()) {
303 return mlir::failure();
315 std::set<std::set<SourceRef>> sortedSets;
316 for (
auto it = signalSets.begin(); it != signalSets.end(); it++) {
317 if (!it->isLeader()) {
321 std::set<SourceRef> sortedMembers;
322 for (
auto mit = signalSets.member_begin(it); mit != signalSets.member_end(); mit++) {
323 sortedMembers.insert(*mit);
328 if (sortedMembers.size() > 1) {
329 sortedSets.insert(sortedMembers);
333 for (
const auto &[ref, constSet] : constantSets) {
334 if (constSet.empty()) {
337 std::set<SourceRef> sortedMembers(constSet.begin(), constSet.end());
338 sortedMembers.insert(ref);
339 sortedSets.insert(sortedMembers);
342 os <<
"ConstraintDependencyGraph { ";
344 for (
auto it = sortedSets.begin(); it != sortedSets.end();) {
346 for (
auto mit = it->begin(); mit != it->end();) {
349 if (mit != it->end()) {
355 if (it == sortedSets.end()) {
365mlir::LogicalResult ConstraintDependencyGraph::computeConstraints(
366 mlir::DataFlowSolver &solver, mlir::AnalysisManager &am
372 "malformed struct " + mlir::Twine(structDef.getName()) +
" must define a constrain function"
383 constrainFnOp.walk([
this, &solver](Operation *op) {
384 if (!isOperationLive(solver, op)) {
388 for (Value operand : op->getOperands()) {
390 for (
const SourceRef &ref : operandRefs) {
391 ref2Val[ref].insert(operand);
394 for (Value result : op->getResults()) {
396 for (
const SourceRef &ref : resultRefs) {
397 ref2Val[ref].insert(result);
401 if (succeeded(writeTargetState)) {
402 for (
const SourceRef &ref : writeTargetState->foldToScalar()) {
403 ref2Val[ref].insert(op);
406 if (isa<EmitEqualityOp, EmitContainmentOp>(op)) {
407 this->walkConstrainOp(solver, op);
418 auto fnCallWalker = [
this, &solver, &am](CallOp fnCall)
mutable {
419 if (!isOperationLive(solver, fnCall.getOperation())) {
423 ensure(mlir::succeeded(res),
"could not resolve constrain call");
425 auto fn = res->get();
426 if (!fn.isStructConstrain()) {
430 auto calledStruct = fn.getOperation()->getParentOfType<StructDefOp>();
434 for (
unsigned i = 0; i < fn.getNumArguments(); i++) {
435 SourceRef prefix(fn.getArgument(i));
436 Value operand = fnCall.getOperand(i);
438 translations.push_back({prefix, val});
440 auto &childAnalysis =
441 am.getChildAnalysis<ConstraintDependencyGraphStructAnalysis>(calledStruct);
442 if (!childAnalysis.constructed(ctx)) {
444 mlir::succeeded(childAnalysis.runAnalysis(solver, am, {.runIntraprocedural = false})),
445 "could not construct CDG for child struct"
448 auto translatedCDG = childAnalysis.getResult(ctx).translate(translations);
450 const auto &translatedRef2Val = translatedCDG.getRef2Val();
451 ref2Val.insert(translatedRef2Val.begin(), translatedRef2Val.end());
455 auto &tSets = translatedCDG.signalSets;
456 for (
auto lit = tSets.begin(); lit != tSets.end(); lit++) {
457 if (!lit->isLeader()) {
460 auto leader = lit->getData();
461 for (
auto mit = tSets.member_begin(lit); mit != tSets.member_end(); mit++) {
462 signalSets.unionSets(leader, *mit);
466 for (
auto &[ref, constSet] : translatedCDG.constantSets) {
467 constantSets[ref].insert(constSet.begin(), constSet.end());
470 if (!ctx.runIntraproceduralAnalysis()) {
471 constrainFnOp.walk(fnCallWalker);
474 return mlir::success();
477void ConstraintDependencyGraph::walkConstrainOp(
478 mlir::DataFlowSolver &solver, mlir::Operation *emitOp
480 std::vector<SourceRef> signalUsages, constUsages;
482 for (
auto operand : emitOp->getOperands()) {
484 for (
const auto &ref : latticeVal.foldToScalar()) {
485 if (ref.isConstant()) {
486 constUsages.push_back(ref);
488 signalUsages.push_back(ref);
494 if (!signalUsages.empty()) {
495 auto it = signalUsages.begin();
496 auto leader = signalSets.getOrInsertLeaderValue(*it);
497 for (it++; it != signalUsages.end(); it++) {
498 signalSets.unionSets(leader, *it);
502 for (
auto &sig : signalUsages) {
503 constantSets[sig].insert(constUsages.begin(), constUsages.end());
511 [&translation](
const SourceRef &elem) -> mlir::FailureOr<std::vector<SourceRef>> {
512 std::vector<SourceRef> refs;
513 for (
auto &[prefix, vals] : translation) {
514 if (!elem.isValidPrefix(prefix)) {
518 if (vals.isArray()) {
520 auto suffix = elem.getSuffix(prefix);
522 mlir::succeeded(suffix),
"failure is nonsensical, we already checked for valid prefix"
525 auto resolvedValsRes = vals.extract(suffix.value());
526 ensure(succeeded(resolvedValsRes),
"could not create SourceRef child while resolving refs");
527 auto [resolvedVals, _] = *resolvedValsRes;
528 auto folded = resolvedVals.foldToScalar();
529 refs.insert(refs.end(), folded.begin(), folded.end());
531 for (
const auto &replacement : vals.getScalarValue()) {
532 auto translated = elem.translate(prefix, replacement);
533 if (mlir::succeeded(translated)) {
534 refs.push_back(translated.value());
540 return mlir::failure();
545 for (
auto leaderIt = signalSets.begin(); leaderIt != signalSets.end(); leaderIt++) {
546 if (!leaderIt->isLeader()) {
550 std::vector<SourceRef> translatedSignals, translatedConsts;
551 for (
auto mit = signalSets.member_begin(leaderIt); mit != signalSets.member_end(); mit++) {
553 if (mlir::failed(member)) {
556 for (
const auto &ref : *member) {
557 if (ref.isConstant()) {
558 translatedConsts.push_back(ref);
560 translatedSignals.push_back(ref);
564 if (
auto it = constantSets.find(*mit); it != constantSets.end()) {
565 const auto &origConstSet = it->second;
566 translatedConsts.insert(translatedConsts.end(), origConstSet.begin(), origConstSet.end());
570 if (translatedSignals.empty()) {
575 auto it = translatedSignals.begin();
577 res.signalSets.insert(leader);
578 for (it++; it != translatedSignals.end(); it++) {
579 res.signalSets.insert(*it);
580 res.signalSets.unionSets(leader, *it);
584 for (
auto &ref : translatedSignals) {
585 res.constantSets[ref].insert(translatedConsts.begin(), translatedConsts.end());
590 for (
const auto &[ref, vals] : ref2Val) {
592 if (succeeded(translationRes)) {
593 for (
const auto &translatedRef : *translationRes) {
594 res.ref2Val[translatedRef].insert(vals.begin(), vals.end());
604 auto currRef = mlir::FailureOr<SourceRef>(ref);
605 while (mlir::succeeded(currRef)) {
607 for (
auto it = signalSets.findLeader(*currRef); it != signalSets.member_end(); it++) {
608 if (currRef.value() != *it) {
613 auto constIt = constantSets.find(*currRef);
614 if (constIt != constantSets.end()) {
615 res.insert(constIt->second.begin(), constIt->second.end());
618 currRef = currRef->getParentPrefix();
626 mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager,
632 if (mlir::failed(result)) {
633 return mlir::failure();
636 return mlir::success();
mlir::LogicalResult runAnalysis(mlir::DataFlowSolver &solver, mlir::AnalysisManager &moduleAnalysisManager, const CDGAnalysisContext &ctx) override
Construct a CDG, using the module's analysis manager to query ConstraintDependencyGraph objects for n...
A dependency graph of constraints enforced by an LLZK struct.
void print(mlir::raw_ostream &os) const
Print the CDG to the specified output stream.
ConstraintDependencyGraph(const ConstraintDependencyGraph &other)
static mlir::FailureOr< ConstraintDependencyGraph > compute(mlir::ModuleOp mod, component::StructDefOp s, mlir::DataFlowSolver &solver, mlir::AnalysisManager &am, const CDGAnalysisContext &ctx)
Compute a ConstraintDependencyGraph (CDG)
SourceRefSet getConstrainingValues(const SourceRef &ref) const
Get the values that are connected to the given ref via emitted constraints.
void dump() const
Dumps the CDG to stderr.
ConstraintDependencyGraph translate(SourceRefRemappings translation) const
Translate the SourceRefs in this CDG to that of a different context.
static mlir::ChangeResult fallbackOpUpdate(mlir::Operation *op, const OperandValues &operandVals, mlir::ArrayRef< Lattice * > results)
void visitExternalCall(mlir::CallOpInterface call, mlir::ArrayRef< const Lattice * > argumentLattices, mlir::ArrayRef< Lattice * > resultLattices) override
Visit a call operation to an externally defined function given the lattices of its arguments.
static mlir::FailureOr< SourceRefLatticeValue > getWriteTargetState(mlir::DataFlowSolver &solver, mlir::Operation *op)
static SourceRefLatticeValue arraySubdivisionOpUpdate(array::ArrayAccessOpInterface op, const OperandValues &operandVals)
static SourceRefLatticeValue getValueState(mlir::DataFlowSolver &solver, mlir::Value val)
void setToEntryState(Lattice *lattice) override
Set the given lattice element(s) at control flow entry point(s).
mlir::LogicalResult visitOperation(mlir::Operation *op, mlir::ArrayRef< const Lattice * > operands, mlir::ArrayRef< Lattice * > results) override
Propagate SourceRef lattice values from operands to results.
static const Lattice * getLattice(mlir::DataFlowSolver &solver, mlir::Value val)
mlir::DenseMap< mlir::Value, const Lattice * > OperandValues
A value at a given point of the SourceRefLattice.
mlir::ChangeResult setValue(const LatticeValue &newValue)
static SourceRefLatticeValue getDefaultValue(ValueTy v)
static mlir::FailureOr< SourceRef > getSourceRef(mlir::Value val)
If val is the source of other values (i.e., a block argument, an allocation-like op result,...
A reference to a "source", which is the base value from which other SSA values are derived.
component::StructDefOp getStruct() const
void setResult(const CDGAnalysisContext &ctx, ConstraintDependencyGraph &&r)
mlir::ModuleOp getModule() const
::mlir::Operation::operand_range getIndices()
Gets the operand range containing the index for each dimension.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
ScalarTy foldToScalar() const
If this is an array value, combine all elements into a single scalar value and return it.
mlir::ChangeResult setValue(const AbstractLatticeValue &rhs)
Sets this value to be equal to rhs.
mlir::ChangeResult update(const Derived &rhs)
Union this value with that of rhs.
const Derived & getElemFlatIdx(size_t i) const
Directly index into the flattened array using a single index.
const SourceRefLattice * getLatticeElementFor(mlir::ProgramPoint *point, mlir::Value value)
void setAllToEntryStates(mlir::ArrayRef< SourceRefLattice * > lattices)
std::vector< std::pair< SourceRef, SourceRefLatticeValue > > SourceRefRemappings
void ensure(bool condition, const llvm::Twine &errMsg)
mlir::FailureOr< SymbolLookupResult< T > > resolveCallable(mlir::SymbolTableCollection &symbolTable, mlir::CallOpInterface call)
Based on mlir::CallOpInterface::resolveCallable, but using LLZK lookup helpers.
Parameters and shared objects to pass to child analyses.