34#include <mlir/IR/BuiltinOps.h>
35#include <mlir/Transforms/InliningUtils.h>
36#include <mlir/Transforms/WalkPatternRewriteDriver.h>
38#include <llvm/ADT/PostOrderIterator.h>
39#include <llvm/ADT/SmallPtrSet.h>
40#include <llvm/ADT/SmallVector.h>
41#include <llvm/ADT/StringMap.h>
42#include <llvm/ADT/TypeSwitch.h>
43#include <llvm/Support/Debug.h>
50#define GEN_PASS_DECL_INLINESTRUCTSPASS
51#define GEN_PASS_DEF_INLINESTRUCTSPASS
60#define DEBUG_TYPE "llzk-inline-structs"
69using SrcStructMemberToCloneInDest = std::map<StringRef, DestCloneOfSrcStructMember>;
72using DestToSrcToClonedSrcInDest =
73 DenseMap<DestMemberWithSrcStructType, SrcStructMemberToCloneInDest>;
77static inline Value getSelfValue(
FuncDefOp f) {
83 llvm_unreachable(
"expected \"@compute\" or \"@constrain\" function");
97static FailureOr<MemberWriteOp>
98findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
100 for (Operation *user : writtenValue.getUsers()) {
101 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(user)) {
103 if (writeOp.getVal() == writtenValue) {
106 auto diag = emitError().append(
"result should not be written to more than one member.");
107 diag.attachNote(foundWrite.getLoc()).append(
"written here");
108 diag.attachNote(writeOp.getLoc()).append(
"written here");
111 foundWrite = writeOp;
118 return emitError().append(
"result should be written to a member.");
126static bool combineHelper(
131 llvm::dbgs() <<
"[combineHelper] " << readOp <<
" => " << destMemberRefOp <<
'\n';
134 auto srcToClone = destToSrcToClone.find(getDef(tables, destMemberRefOp));
135 if (srcToClone == destToSrcToClone.end()) {
138 SrcStructMemberToCloneInDest oldToNewMembers = srcToClone->second;
139 auto resNewMember = oldToNewMembers.find(readOp.
getMemberName());
140 if (resNewMember == oldToNewMembers.end()) {
145 OpBuilder builder(readOp);
147 readOp.getLoc(), readOp.getType(), destMemberRefOp.
getComponent(),
148 resNewMember->second.getNameAttr()
150 readOp.replaceAllUsesWith(newRead.getOperation());
168static bool combineReadChain(
170 const DestToSrcToClonedSrcInDest &destToSrcToClone
172 LLVM_DEBUG({ llvm::dbgs() <<
"[combineReadChain] " << readOp <<
'\n'; });
175 llvm::dyn_cast_if_present<MemberReadOp>(readOp.
getComponent().getDefiningOp());
176 if (!readThatDefinesBaseComponent) {
179 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
198static LogicalResult combineNewThenReadChain(
200 const DestToSrcToClonedSrcInDest &destToSrcToClone
202 LLVM_DEBUG({ llvm::dbgs() <<
"[combineNewThenReadChain] " << readOp <<
'\n'; });
205 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
206 if (!createThatDefinesBaseComponent) {
209 FailureOr<MemberWriteOp> foundWrite =
210 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
211 return createThatDefinesBaseComponent.emitOpError();
213 if (failed(foundWrite)) {
216 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
219static inline MemberReadOp getMemberReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
221 return llvm::dyn_cast_if_present<MemberReadOp>(selfArgFromCall.getDefiningOp());
226struct PendingErasure {
227 SmallPtrSet<Operation *, 8> memberReadOps;
228 SmallPtrSet<Operation *, 8> memberWriteOps;
229 SmallVector<CreateStructOp> newStructOps;
230 SmallVector<DestMemberWithSrcStructType> memberDefs;
235 SymbolTableCollection &tables;
236 PendingErasure &toDelete;
250 class MemberRefRewriter final :
public OpInterfaceRewritePattern<MemberRefOpInterface> {
258 const SrcStructMemberToCloneInDest &oldToNewMembers;
262 FuncDefOp originalFunc, Value newRefBase,
263 const SrcStructMemberToCloneInDest &oldToNewMemberDef
265 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
266 oldBaseVal(
nullptr), newBaseVal(newRefBase), oldToNewMembers(oldToNewMemberDef) {}
279 rewriter.modifyOpInPlace(op, [
this, &op]() {
280 DestCloneOfSrcStructMember newF = oldToNewMembers.at(op.
getMemberName());
288 static FuncDefOp cloneWithMemberRefUpdate(std::unique_ptr<MemberRefRewriter> thisPat) {
292 thisPat->funcRef = srcFuncClone;
293 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
295 MLIRContext *ctx = thisPat->getContext();
296 RewritePatternSet patterns(ctx, std::move(thisPat));
297 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
306 const StructInliner &data;
307 const DestToSrcToClonedSrcInDest &destToSrcToClone;
312 virtual void processCloneBeforeInlining(
FuncDefOp func) {}
313 virtual ~ImplBase() =
default;
316 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
317 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
321 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
323 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
327 InlinerInterface inliner(destFunc.getContext());
330 auto callHandler = [
this, &inliner, &srcFunc](CallOp callOp) {
332 auto callOpTarget = callOp.getCalleeTarget(this->data.tables);
333 assert(succeeded(callOpTarget));
334 if (callOpTarget->get() != srcFunc) {
335 return WalkResult::advance();
340 MemberRefOpInterface selfMemberRefOp = this->getSelfRefMember(callOp);
341 if (!selfMemberRefOp) {
343 return WalkResult::interrupt();
349 FuncDefOp srcFuncClone = MemberRefRewriter::cloneWithMemberRefUpdate(
350 std::make_unique<MemberRefRewriter>(
352 this->destToSrcToClone.at(this->data.getDef(selfMemberRefOp))
355 this->processCloneBeforeInlining(srcFuncClone);
358 LogicalResult inlineCallRes =
359 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
360 if (failed(inlineCallRes)) {
362 return WalkResult::interrupt();
364 srcFuncClone.erase();
366 return WalkResult::skip();
369 auto memberWriteHandler = [
this](MemberWriteOp writeOp) {
371 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
372 this->data.toDelete.memberWriteOps.insert(writeOp);
374 return WalkResult::advance();
379 auto memberReadHandler = [
this](MemberReadOp readOp) {
381 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
382 this->data.toDelete.memberReadOps.insert(readOp);
385 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
387 : WalkResult::advance();
390 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
391 return TypeSwitch<Operation *, WalkResult>(op)
392 .Case<CallOp>(callHandler)
393 .Case<MemberWriteOp>(memberWriteHandler)
394 .Case<MemberReadOp>(memberReadHandler)
395 .Default([](Operation *) {
return WalkResult::advance(); });
398 return failure(walkRes.wasInterrupted());
402 class ConstrainImpl :
public ImplBase {
403 using ImplBase::ImplBase;
405 MemberRefOpInterface getSelfRefMember(CallOp callOp)
override {
406 LLVM_DEBUG({ llvm::dbgs() <<
"[ConstrainImpl::getSelfRefMember] " << callOp <<
'\n'; });
411 MemberRefOpInterface selfMemberRef =
412 getMemberReadThatDefinesSelfValuePassedToConstrain(callOp);
414 selfMemberRef.getComponent().getType() == this->data.destStruct.getType()) {
415 return selfMemberRef;
420 "\" to be passed a value read from a member in the current stuct."
427 class ComputeImpl :
public ImplBase {
428 using ImplBase::ImplBase;
430 MemberRefOpInterface getSelfRefMember(CallOp callOp)
override {
431 LLVM_DEBUG({ llvm::dbgs() <<
"[ComputeImpl::getSelfRefMember] " << callOp <<
'\n'; });
438 FailureOr<MemberWriteOp> foundWrite =
440 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
442 return static_cast<MemberRefOpInterface
>(foundWrite.value_or(
nullptr));
445 void processCloneBeforeInlining(FuncDefOp func)
override {
449 func.
getBody().walk([
this](CreateStructOp newStructOp) {
450 if (newStructOp.getType() == this->data.srcStruct.getType()) {
451 this->data.toDelete.newStructOps.push_back(newStructOp);
460 DestToSrcToClonedSrcInDest cloneMembers() {
461 DestToSrcToClonedSrcInDest destToSrcToClone;
463 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
464 StructType srcStructType = srcStruct.getType();
465 for (MemberDefOp destMember : destStruct.getMemberDefs()) {
466 if (StructType destMemberType = llvm::dyn_cast<StructType>(destMember.getType())) {
471 assert(unifications.empty());
473 toDelete.memberDefs.push_back(destMember);
476 SrcStructMemberToCloneInDest &srcToClone = destToSrcToClone[destMember];
477 std::vector<MemberDefOp> srcMembers = srcStruct.getMemberDefs();
478 if (srcMembers.empty()) {
481 OpBuilder builder(destMember);
482 std::string newNameBase =
484 for (MemberDefOp srcMember : srcMembers) {
485 DestCloneOfSrcStructMember newF = llvm::cast<MemberDefOp>(builder.clone(*srcMember));
486 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
487 srcToClone[srcMember.getSymNameAttr()] = newF;
489 destStructSymTable.insert(newF);
493 return destToSrcToClone;
497 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
498 return ConstrainImpl(*
this, destToSrcToClone)
499 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
503 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
504 return ComputeImpl(*
this, destToSrcToClone)
505 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
510 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
512 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
514 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
516 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
517 << destStruct.getSymNameAttr() <<
'\n'
520 DestToSrcToClonedSrcInDest destToSrcToClone = cloneMembers();
521 if (failed(inlineConstrainCall(destToSrcToClone)) ||
522 failed(inlineComputeCall(destToSrcToClone))) {
525 return destToSrcToClone;
531 { t.contains(p) } -> std::convertible_to<bool>;
535template <
typename... PendingDeletionSets>
537class DanglingUseHandler {
538 SymbolTableCollection &tables;
539 const DestToSrcToClonedSrcInDest &destToSrcToClone;
540 std::tuple<
const PendingDeletionSets &...> otherRefsToBeDeleted;
544 SymbolTableCollection &symTables,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef,
545 const PendingDeletionSets &...otherRefsPendingDeletion
547 : tables(symTables), destToSrcToClone(destToSrcToCloneRef),
548 otherRefsToBeDeleted(otherRefsPendingDeletion...) {}
555 LogicalResult handle(Operation *op)
const {
556 if (op->use_empty()) {
561 llvm::dbgs() <<
"[DanglingUseHandler::handle] op: " << *op <<
'\n';
562 llvm::dbgs() <<
"[DanglingUseHandler::handle] in function: "
563 << op->getParentOfType<
FuncDefOp>() <<
'\n';
565 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
566 if (
CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
567 if (failed(handleUseInCallOp(
use, c, op))) {
571 Operation *user =
use.getOwner();
573 if (!opWillBeDeleted(user)) {
574 return op->emitOpError()
576 "with use in '", user->getName().getStringRef(),
577 "' is not (currently) supported by this pass."
579 .attachNote(user->getLoc())
580 .append(
"used by this operation");
585 if (!op->use_empty()) {
586 for (Operation *user : op->getUsers()) {
587 if (!opWillBeDeleted(user)) {
588 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
589 llvm_unreachable(
"Expected all uses to be removed");
597 inline LogicalResult handleUseInCallOp(OpOperand &
use,
CallOp inCall, Operation *origin)
const {
599 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall <<
'\n'
601 unsigned argIdx =
use.getOperandNumber() - inCall.
getArgOperands().getBeginOperandIndex();
603 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] at index: " << argIdx <<
'\n'
607 if (failed(tgtFuncRes)) {
609 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
610 .attachNote(inCall.getLoc())
611 .append(
"used by this call");
615 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] call target: " << tgtFunc <<
'\n'
617 if (tgtFunc.isExternal()) {
621 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
622 .attachNote(inCall.getLoc())
623 .append(
"used by this call");
627 TypeSwitch<Operation *, MemberRefOpInterface>(origin)
628 .template Case<MemberReadOp>([](
auto p) {
return p; })
629 .
template Case<CreateStructOp>([](
auto p) {
630 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
631 }).Default([](Operation *p) {
632 llvm::errs() <<
"Encountered unexpected op: "
633 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
634 llvm_unreachable(
"Unexpected op kind");
638 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] member ref op for param: "
641 if (!paramFromMember) {
644 const SrcStructMemberToCloneInDest &newMembers =
645 destToSrcToClone.at(getDef(tables, paramFromMember));
647 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] members to split: "
652 splitFunctionParam(tgtFunc, argIdx, newMembers);
654 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target: " << tgtFunc
656 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target type: "
662 OpBuilder builder(inCall);
663 SmallVector<Value> splitArgs;
667 for (
auto [origName, newMemberRef] : newMembers) {
669 inCall.getLoc(), newMemberRef.getType(), originalBaseVal, newMemberRef.getNameAttr()
675 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
678 inCall.replaceAllUsesWith(builder.create<
CallOp>(
684 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED function: "
685 << origin->getParentOfType<
FuncDefOp>() <<
'\n';
691 inline bool opWillBeDeleted(Operation *otherOp)
const {
692 return std::apply([&](
const auto &...sets) {
693 return ((sets.contains(otherOp)) || ...);
694 }, otherRefsToBeDeleted);
701 static void splitFunctionParam(
702 FuncDefOp func,
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember
706 const SrcStructMemberToCloneInDest &newMembers;
709 Impl(
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember)
710 : inputIdx(paramIdx), newMembers(nameToNewMember) {}
713 SmallVector<Type>
convertInputs(ArrayRef<Type> origTypes)
override {
714 SmallVector<Type> newTypes(origTypes);
715 auto it = newTypes.erase(newTypes.begin() + inputIdx);
716 for (
auto [_, newMember] : newMembers) {
717 newTypes.insert(it, newMember.getType());
722 SmallVector<Type>
convertResults(ArrayRef<Type> origTypes)
override {
723 return SmallVector<Type>(origTypes);
728 SmallVector<Attribute> newAttrs(origAttrs.getValue());
729 newAttrs.insert(newAttrs.begin() + inputIdx, newMembers.size() - 1, origAttrs[inputIdx]);
730 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
739 Value oldStructRef = entryBlock.getArgument(inputIdx);
743 llvm::StringMap<BlockArgument> memberNameToNewArg;
744 Location loc = oldStructRef.getLoc();
745 unsigned idx = inputIdx;
746 for (
auto [memberName, newMember] : newMembers) {
748 BlockArgument newArg = entryBlock.insertArgument(++idx, newMember.getType(), loc);
749 memberNameToNewArg[memberName] = newArg;
754 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
755 if (MemberReadOp readOp = llvm::dyn_cast<MemberReadOp>(oldBlockArgUse.getOwner())) {
757 BlockArgument newArg = memberNameToNewArg.at(readOp.
getMemberName());
758 rewriter.replaceAllUsesWith(readOp, newArg);
759 rewriter.eraseOp(readOp);
764 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
765 << *oldBlockArgUse.getOwner() <<
'\n';
766 llvm_unreachable(
"Not yet implemented");
770 entryBlock.eraseArgument(inputIdx);
773 IRRewriter rewriter(func.getContext());
774 Impl(paramIdx, nameToNewMember).convert(func, rewriter);
778static LogicalResult finalizeStruct(
779 SymbolTableCollection &tables,
StructDefOp caller, PendingErasure &&toDelete,
780 DestToSrcToClonedSrcInDest &&destToSrcToClone
783 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
784 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
785 llvm::dbgs() <<
'\n';
790 combineReadChain(readOp, tables, destToSrcToClone);
794 auto res = computeFn.walk([&tables, &destToSrcToClone, &computeSelfVal](
MemberReadOp readOp) {
795 combineReadChain(readOp, tables, destToSrcToClone);
799 return WalkResult::advance();
801 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
802 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
804 if (res.wasInterrupted()) {
809 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
810 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
811 llvm::dbgs() <<
'\n';
812 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
813 for (Operation *op : toDelete.memberReadOps) {
814 llvm::dbgs().indent(2) << *op <<
'\n';
816 for (Operation *op : toDelete.memberWriteOps) {
817 llvm::dbgs().indent(2) << *op <<
'\n';
820 llvm::dbgs().indent(2) << op <<
'\n';
822 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
823 llvm::dbgs().indent(2) << op <<
'\n';
829 DanglingUseHandler<SmallPtrSet<Operation *, 8>, SmallPtrSet<Operation *, 8>> useHandler(
830 tables, destToSrcToClone, toDelete.memberWriteOps, toDelete.memberReadOps
833 if (failed(useHandler.handle(op))) {
839 for (Operation *op : toDelete.memberWriteOps) {
840 if (failed(useHandler.handle(op))) {
845 for (Operation *op : toDelete.memberReadOps) {
846 if (failed(useHandler.handle(op))) {
855 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
856 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
857 assert(op.getParentOp() == caller);
858 callerSymTab.erase(op);
867 for (
auto &[caller, callees] : plan) {
870 PendingErasure toDelete;
872 DestToSrcToClonedSrcInDest aggregateReplacements;
875 FailureOr<DestToSrcToClonedSrcInDest> res =
876 StructInliner(tables, toDelete, toInline, caller).doInline();
881 for (
auto &[k, v] : res.value()) {
882 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
883 aggregateReplacements[k] = std::move(v);
887 LogicalResult finalizeResult =
888 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
889 if (failed(finalizeResult)) {
899 static uint64_t complexity(FuncDefOp f) {
900 uint64_t complexity = 0;
901 f.
getBody().walk([&complexity](Operation *op) {
902 if (llvm::isa<felt::MulFeltOp>(op)) {
904 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
906 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
915 static FailureOr<FuncDefOp>
916 getIfStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
918 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
919 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
929 static inline StructDefOp getParentStruct(FuncDefOp func) {
932 assert(succeeded(currentNodeParentStruct));
933 return currentNodeParentStruct.value();
937 inline bool exceedsMaxComplexity(uint64_t
check) {
938 return maxComplexity > 0 &&
check > maxComplexity;
943 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
955 WalkResult res = currentFunc.walk([](CallOp c) {
956 return getMemberReadThatDefinesSelfValuePassedToConstrain(c)
957 ? WalkResult::interrupt()
958 : WalkResult::advance();
964 return res.wasInterrupted();
971 inline FailureOr<InliningPlan>
972 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
974 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
975 if (maxComplexity == 0) {
976 llvm::dbgs() <<
"unlimited";
978 llvm::dbgs() << maxComplexity;
980 llvm::dbgs() <<
'\n';
983 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
995 for (
const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
996 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
997 if (!currentNode->isRealNode()) {
1000 if (currentNode->isStructParam()) {
1002 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
1006 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
1007 return reportLoc->emitError(
"Cannot inline structs with parameters.");
1009 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
1010 if (failed(currentFuncOpt)) {
1013 FuncDefOp currentFunc = currentFuncOpt.value();
1014 uint64_t currentComplexity = complexity(currentFunc);
1016 if (exceedsMaxComplexity(currentComplexity)) {
1017 complexityMemo[currentNode] = currentComplexity;
1022 SmallVector<StructDefOp> successorsToMerge;
1023 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
1024 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
1026 auto memoResult = complexityMemo.find(successor);
1027 if (memoResult == complexityMemo.end()) {
1030 uint64_t sComplexity = memoResult->second;
1032 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
1033 "addition will overflow"
1035 uint64_t potentialComplexity = currentComplexity + sComplexity;
1036 if (!exceedsMaxComplexity(potentialComplexity)) {
1037 currentComplexity = potentialComplexity;
1038 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
1039 assert(succeeded(successorFuncOpt));
1040 FuncDefOp successorFunc = successorFuncOpt.value();
1041 if (canInline(currentFunc, successorFunc)) {
1042 successorsToMerge.push_back(getParentStruct(successorFunc));
1046 complexityMemo[currentNode] = currentComplexity;
1047 if (!successorsToMerge.empty()) {
1048 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
1052 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1053 llvm::dbgs() <<
"InlineStructsPass plan:\n";
1054 for (
auto &[caller, callees] : retVal) {
1055 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.
getSymName() <<
"\"\n";
1056 for (StructDefOp c : callees) {
1057 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
1060 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1066 void runOnOperation()
override {
1067 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
1070 SymbolTableCollection tables;
1071 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
1073 signalPassFailure();
1078 signalPassFailure();
1087 return std::make_unique<InlineStructsPass>();
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan)
mlir::SmallVector< std::pair< llzk::component::StructDefOp, mlir::SmallVector< llzk::component::StructDefOp > > > InliningPlan
Maps caller struct to callees that should be inlined.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::StringRef getMemberName()
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the MemberRefOp.
void setMemberName(::llvm::StringRef attrValue)
Sets the member name attribute value in the MemberRefOp.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::mlir::OpOperand & getComponentMutable()
Gets the SSA value with the target component from the MemberRefOp.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
void print(::mlir::OpAsmPrinter &_odsPrinter)
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)