36#include <mlir/IR/BuiltinOps.h>
37#include <mlir/Transforms/InliningUtils.h>
38#include <mlir/Transforms/WalkPatternRewriteDriver.h>
40#include <llvm/ADT/PostOrderIterator.h>
41#include <llvm/ADT/SmallPtrSet.h>
42#include <llvm/ADT/SmallVector.h>
43#include <llvm/ADT/StringMap.h>
44#include <llvm/ADT/TypeSwitch.h>
45#include <llvm/Support/Debug.h>
52#define GEN_PASS_DECL_INLINESTRUCTSPASS
53#define GEN_PASS_DEF_INLINESTRUCTSPASS
63#define DEBUG_TYPE "llzk-inline-structs"
72using SrcStructMemberToCloneInDest = std::map<StringRef, DestCloneOfSrcStructMember>;
75using DestToSrcToClonedSrcInDest =
76 DenseMap<DestMemberWithSrcStructType, SrcStructMemberToCloneInDest>;
80static inline Value getSelfValue(
FuncDefOp f) {
86 llvm_unreachable(
"expected \"@compute\" or \"@constrain\" function");
100static FailureOr<MemberWriteOp>
101findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
103 for (Operation *user : writtenValue.getUsers()) {
104 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(user)) {
106 if (writeOp.getVal() == writtenValue) {
109 auto diag = emitError().append(
"result should not be written to more than one member.");
110 diag.attachNote(foundWrite.getLoc()).append(
"written here");
111 diag.attachNote(writeOp.getLoc()).append(
"written here");
114 foundWrite = writeOp;
121 return emitError().append(
"result should be written to a member.");
129static bool combineHelper(
134 llvm::dbgs() <<
"[combineHelper] " << readOp <<
" => " << destMemberRefOp <<
'\n';
137 auto srcToClone = destToSrcToClone.find(getDef(tables, destMemberRefOp));
138 if (srcToClone == destToSrcToClone.end()) {
141 SrcStructMemberToCloneInDest oldToNewMembers = srcToClone->second;
142 auto resNewMember = oldToNewMembers.find(readOp.
getMemberName());
143 if (resNewMember == oldToNewMembers.end()) {
148 OpBuilder builder(readOp);
150 readOp.getLoc(), readOp.getType(), destMemberRefOp.
getComponent(),
151 resNewMember->second.getNameAttr()
153 readOp.replaceAllUsesWith(newRead.getOperation());
171static bool combineReadChain(
173 const DestToSrcToClonedSrcInDest &destToSrcToClone
175 LLVM_DEBUG({ llvm::dbgs() <<
"[combineReadChain] " << readOp <<
'\n'; });
178 llvm::dyn_cast_if_present<MemberReadOp>(readOp.
getComponent().getDefiningOp());
179 if (!readThatDefinesBaseComponent) {
182 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
201static LogicalResult combineNewThenReadChain(
203 const DestToSrcToClonedSrcInDest &destToSrcToClone
205 LLVM_DEBUG({ llvm::dbgs() <<
"[combineNewThenReadChain] " << readOp <<
'\n'; });
208 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
209 if (!createThatDefinesBaseComponent) {
212 FailureOr<MemberWriteOp> foundWrite =
213 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
214 return createThatDefinesBaseComponent.emitOpError();
216 if (failed(foundWrite)) {
219 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
222static inline MemberReadOp getMemberReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
224 return llvm::dyn_cast_if_present<MemberReadOp>(selfArgFromCall.getDefiningOp());
229struct PendingErasure {
230 SmallPtrSet<Operation *, 8> memberReadOps;
231 SmallPtrSet<Operation *, 8> memberWriteOps;
232 SmallVector<CreateStructOp> newStructOps;
233 SmallVector<DestMemberWithSrcStructType> memberDefs;
238 SymbolTableCollection &tables;
239 PendingErasure &toDelete;
253 class MemberRefRewriter final :
public OpInterfaceRewritePattern<MemberRefOpInterface> {
261 const SrcStructMemberToCloneInDest &oldToNewMembers;
266 const SrcStructMemberToCloneInDest &oldToNewMemberDef
268 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
269 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewMembers(oldToNewMemberDef) {}
277 op.getComponent() == oldBaseVal && oldToNewMembers.contains(op.getMemberName())
282 rewriter.modifyOpInPlace(op, [
this, &op]() {
283 DestCloneOfSrcStructMember newF = oldToNewMembers.at(op.getMemberName());
284 op.setMemberName(newF.getSymName());
285 op.getComponentMutable().set(this->newBaseVal);
291 static FuncDefOp cloneWithMemberRefUpdate(std::unique_ptr<MemberRefRewriter> thisPat) {
295 thisPat->funcRef = srcFuncClone;
296 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
298 MLIRContext *ctx = thisPat->getContext();
299 RewritePatternSet patterns(ctx, std::move(thisPat));
300 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
309 const StructInliner &data;
310 const DestToSrcToClonedSrcInDest &destToSrcToClone;
315 virtual void processCloneBeforeInlining(
FuncDefOp func) {}
316 virtual ~ImplBase() =
default;
319 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
320 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
324 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
326 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
330 InlinerInterface inliner(destFunc.getContext());
333 auto callHandler = [
this, &inliner, &srcFunc](
CallOp callOp) {
336 assert(succeeded(callOpTarget));
337 if (callOpTarget->get() != srcFunc) {
338 return WalkResult::advance();
344 if (!selfMemberRefOp) {
346 return WalkResult::interrupt();
352 FuncDefOp srcFuncClone = MemberRefRewriter::cloneWithMemberRefUpdate(
353 std::make_unique<MemberRefRewriter>(
355 this->destToSrcToClone.at(this->data.getDef(selfMemberRefOp))
358 this->processCloneBeforeInlining(srcFuncClone);
361 LogicalResult inlineCallRes =
362 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
363 if (failed(inlineCallRes)) {
365 return WalkResult::interrupt();
367 srcFuncClone.erase();
369 return WalkResult::skip();
372 auto memberWriteHandler = [
this](MemberWriteOp writeOp) {
374 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
375 this->data.toDelete.memberWriteOps.insert(writeOp);
377 return WalkResult::advance();
382 auto memberReadHandler = [
this](MemberReadOp readOp) {
384 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
385 this->data.toDelete.memberReadOps.insert(readOp);
388 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
390 : WalkResult::advance();
393 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
394 return TypeSwitch<Operation *, WalkResult>(op)
395 .Case<CallOp>(callHandler)
396 .Case<MemberWriteOp>(memberWriteHandler)
397 .Case<MemberReadOp>(memberReadHandler)
398 .Default([](Operation *) {
return WalkResult::advance(); });
401 return failure(walkRes.wasInterrupted());
405 class ConstrainImpl :
public ImplBase {
406 using ImplBase::ImplBase;
408 MemberRefOpInterface getSelfRefMember(CallOp callOp)
override {
409 LLVM_DEBUG({ llvm::dbgs() <<
"[ConstrainImpl::getSelfRefMember] " << callOp <<
'\n'; });
414 MemberRefOpInterface selfMemberRef =
415 getMemberReadThatDefinesSelfValuePassedToConstrain(callOp);
417 selfMemberRef.getComponent().getType() == this->data.destStruct.getType()) {
418 return selfMemberRef;
423 "\" to be passed a value read from a member in the current stuct."
430 class ComputeImpl :
public ImplBase {
431 using ImplBase::ImplBase;
433 MemberRefOpInterface getSelfRefMember(CallOp callOp)
override {
434 LLVM_DEBUG({ llvm::dbgs() <<
"[ComputeImpl::getSelfRefMember] " << callOp <<
'\n'; });
441 FailureOr<MemberWriteOp> foundWrite =
443 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
445 return static_cast<MemberRefOpInterface
>(foundWrite.value_or(
nullptr));
448 void processCloneBeforeInlining(FuncDefOp func)
override {
452 func.
getBody().walk([
this](CreateStructOp newStructOp) {
453 if (newStructOp.getType() == this->data.srcStruct.getType()) {
454 this->data.toDelete.newStructOps.push_back(newStructOp);
463 DestToSrcToClonedSrcInDest cloneMembers() {
464 DestToSrcToClonedSrcInDest destToSrcToClone;
466 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
467 StructType srcStructType = srcStruct.getType();
468 for (MemberDefOp destMember : destStruct.getMemberDefs()) {
469 if (StructType destMemberType = llvm::dyn_cast<StructType>(destMember.getType())) {
474 assert(unifications.empty());
476 toDelete.memberDefs.push_back(destMember);
479 SrcStructMemberToCloneInDest &srcToClone = destToSrcToClone[destMember];
480 std::vector<MemberDefOp> srcMembers = srcStruct.getMemberDefs();
481 if (srcMembers.empty()) {
484 OpBuilder builder(destMember);
485 std::string newNameBase =
487 for (MemberDefOp srcMember : srcMembers) {
488 DestCloneOfSrcStructMember newF = llvm::cast<MemberDefOp>(builder.clone(*srcMember));
489 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
490 srcToClone[srcMember.getSymNameAttr()] = newF;
492 destStructSymTable.insert(newF);
496 return destToSrcToClone;
500 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
501 return ConstrainImpl(*
this, destToSrcToClone)
502 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
506 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
507 return ComputeImpl(*
this, destToSrcToClone)
508 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
513 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
515 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
517 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
519 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
520 << destStruct.getSymNameAttr() <<
'\n'
523 DestToSrcToClonedSrcInDest destToSrcToClone = cloneMembers();
524 if (failed(inlineConstrainCall(destToSrcToClone)) ||
525 failed(inlineComputeCall(destToSrcToClone))) {
528 return destToSrcToClone;
534 { t.contains(p) } -> std::convertible_to<bool>;
538template <
typename... PendingDeletionSets>
540class DanglingUseHandler {
541 SymbolTableCollection &tables;
542 const DestToSrcToClonedSrcInDest &destToSrcToClone;
543 std::tuple<
const PendingDeletionSets &...> otherRefsToBeDeleted;
547 SymbolTableCollection &symTables,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef,
548 const PendingDeletionSets &...otherRefsPendingDeletion
550 : tables(symTables), destToSrcToClone(destToSrcToCloneRef),
551 otherRefsToBeDeleted(otherRefsPendingDeletion...) {}
558 LogicalResult handle(Operation *op)
const {
559 if (op->use_empty()) {
564 llvm::dbgs() <<
"[DanglingUseHandler::handle] op: " << *op <<
'\n';
565 llvm::dbgs() <<
"[DanglingUseHandler::handle] in function: "
566 << op->getParentOfType<
FuncDefOp>() <<
'\n';
568 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
569 if (
CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
570 if (failed(handleUseInCallOp(
use, c, op))) {
574 Operation *user =
use.getOwner();
576 if (!opWillBeDeleted(user)) {
577 return op->emitOpError()
579 "with use in '", user->getName().getStringRef(),
580 "' is not (currently) supported by this pass."
582 .attachNote(user->getLoc())
583 .append(
"used by this operation");
588 if (!op->use_empty()) {
589 for (Operation *user : op->getUsers()) {
590 if (!opWillBeDeleted(user)) {
591 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
592 llvm_unreachable(
"Expected all uses to be removed");
600 inline LogicalResult handleUseInCallOp(OpOperand &
use,
CallOp inCall, Operation *origin)
const {
602 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall <<
'\n'
604 unsigned argIdx =
use.getOperandNumber() - inCall.
getArgOperands().getBeginOperandIndex();
606 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] at index: " << argIdx <<
'\n'
610 if (failed(tgtFuncRes)) {
612 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
613 .attachNote(inCall.getLoc())
614 .append(
"used by this call");
618 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] call target: " << tgtFunc <<
'\n'
620 if (tgtFunc.isExternal()) {
624 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
625 .attachNote(inCall.getLoc())
626 .append(
"used by this call");
630 TypeSwitch<Operation *, MemberRefOpInterface>(origin)
631 .template Case<MemberReadOp>([](
auto p) {
return p; })
632 .
template Case<CreateStructOp>([](
auto p) {
633 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
634 }).Default([](Operation *p) {
635 llvm::errs() <<
"Encountered unexpected op: "
636 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
637 llvm_unreachable(
"Unexpected op kind");
641 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] member ref op for param: "
644 if (!paramFromMember) {
647 const SrcStructMemberToCloneInDest &newMembers =
648 destToSrcToClone.at(getDef(tables, paramFromMember));
650 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] members to split: "
655 splitFunctionParam(tgtFunc, argIdx, newMembers);
657 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target: " << tgtFunc
659 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target type: "
665 OpBuilder builder(inCall);
666 SmallVector<Value> splitArgs;
670 for (
auto [origName, newMemberRef] : newMembers) {
672 inCall.getLoc(), newMemberRef.getType(), originalBaseVal, newMemberRef.getNameAttr()
678 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
681 inCall.replaceAllUsesWith(builder.create<
CallOp>(
687 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED function: "
688 << origin->getParentOfType<
FuncDefOp>() <<
'\n';
694 inline bool opWillBeDeleted(Operation *otherOp)
const {
695 return std::apply([&](
const auto &...sets) {
696 return ((sets.contains(otherOp)) || ...);
697 }, otherRefsToBeDeleted);
704 static void splitFunctionParam(
705 FuncDefOp func,
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember
709 const SrcStructMemberToCloneInDest &newMembers;
712 Impl(
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember)
713 : inputIdx(paramIdx), newMembers(nameToNewMember) {}
716 SmallVector<Type>
convertInputs(ArrayRef<Type> origTypes)
override {
717 SmallVector<Type> newTypes(origTypes);
718 auto *it = newTypes.erase(newTypes.begin() + inputIdx);
719 for (
auto [_, newMember] : newMembers) {
720 newTypes.insert(it, newMember.getType());
725 SmallVector<Type>
convertResults(ArrayRef<Type> origTypes)
override {
726 return SmallVector<Type>(origTypes);
731 SmallVector<Attribute> newAttrs(origAttrs.getValue());
732 newAttrs.insert(newAttrs.begin() + inputIdx, newMembers.size() - 1, origAttrs[inputIdx]);
733 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
742 Value oldStructRef = entryBlock.getArgument(inputIdx);
746 llvm::StringMap<BlockArgument> memberNameToNewArg;
747 Location loc = oldStructRef.getLoc();
748 unsigned idx = inputIdx;
749 for (
auto [memberName, newMember] : newMembers) {
751 BlockArgument newArg = entryBlock.insertArgument(++idx, newMember.getType(), loc);
752 memberNameToNewArg[memberName] = newArg;
757 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
758 if (MemberReadOp readOp = llvm::dyn_cast<MemberReadOp>(oldBlockArgUse.getOwner())) {
760 BlockArgument newArg = memberNameToNewArg.at(readOp.
getMemberName());
761 rewriter.replaceAllUsesWith(readOp, newArg);
762 rewriter.eraseOp(readOp);
767 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
768 << *oldBlockArgUse.getOwner() <<
'\n';
769 llvm_unreachable(
"Not yet implemented");
773 entryBlock.eraseArgument(inputIdx);
776 IRRewriter rewriter(func.getContext());
777 Impl(paramIdx, nameToNewMember).convert(func, rewriter);
781static LogicalResult finalizeStruct(
782 SymbolTableCollection &tables,
StructDefOp caller, PendingErasure &&toDelete,
783 DestToSrcToClonedSrcInDest &&destToSrcToClone
786 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
787 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
788 llvm::dbgs() <<
'\n';
793 combineReadChain(readOp, tables, destToSrcToClone);
797 auto res = computeFn.walk([&tables, &destToSrcToClone, &computeSelfVal](
MemberReadOp readOp) {
798 combineReadChain(readOp, tables, destToSrcToClone);
802 return WalkResult::advance();
804 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
805 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
807 if (res.wasInterrupted()) {
812 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
813 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
814 llvm::dbgs() <<
'\n';
815 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
816 for (Operation *op : toDelete.memberReadOps) {
817 llvm::dbgs().indent(2) << *op <<
'\n';
819 for (Operation *op : toDelete.memberWriteOps) {
820 llvm::dbgs().indent(2) << *op <<
'\n';
823 llvm::dbgs().indent(2) << op <<
'\n';
825 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
826 llvm::dbgs().indent(2) << op <<
'\n';
832 DanglingUseHandler<SmallPtrSet<Operation *, 8>, SmallPtrSet<Operation *, 8>> useHandler(
833 tables, destToSrcToClone, toDelete.memberWriteOps, toDelete.memberReadOps
836 if (failed(useHandler.handle(op))) {
842 for (Operation *op : toDelete.memberWriteOps) {
843 if (failed(useHandler.handle(op))) {
848 for (Operation *op : toDelete.memberReadOps) {
849 if (failed(useHandler.handle(op))) {
858 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
859 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
860 assert(op.getParentOp() == caller);
861 callerSymTab.erase(op);
870 for (
auto &[caller, callees] : plan) {
873 PendingErasure toDelete;
875 DestToSrcToClonedSrcInDest aggregateReplacements;
878 FailureOr<DestToSrcToClonedSrcInDest> res =
879 StructInliner(tables, toDelete, toInline, caller).doInline();
884 for (
auto &[k, v] : res.value()) {
885 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
886 aggregateReplacements[k] = std::move(v);
890 LogicalResult finalizeResult =
891 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
892 if (failed(finalizeResult)) {
902 static uint64_t complexity(FuncDefOp f) {
903 uint64_t complexity = 0;
904 f.
getBody().walk([&complexity](Operation *op) {
905 if (llvm::isa<felt::MulFeltOp>(op)) {
907 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
909 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
918 static FailureOr<FuncDefOp>
919 getIfStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
921 assert(succeeded(lookupRes) &&
"graph contains node with invalid path");
922 if (FuncDefOp f = llvm::dyn_cast<FuncDefOp>(lookupRes->get())) {
932 static inline StructDefOp getParentStruct(FuncDefOp func) {
935 assert(currentNodeParentStruct);
936 return currentNodeParentStruct;
940 inline bool exceedsMaxComplexity(uint64_t
check) {
941 return maxComplexity > 0 &&
check > maxComplexity;
946 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
958 WalkResult res = currentFunc.walk([](CallOp c) {
959 return getMemberReadThatDefinesSelfValuePassedToConstrain(c)
960 ? WalkResult::interrupt()
961 : WalkResult::advance();
967 return res.wasInterrupted();
974 inline FailureOr<InliningPlan>
975 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
977 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
978 if (maxComplexity == 0) {
979 llvm::dbgs() <<
"unlimited";
981 llvm::dbgs() << maxComplexity;
983 llvm::dbgs() <<
'\n';
986 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
998 for (
const SymbolUseGraphNode *currentNode : llvm::post_order(&useGraph)) {
999 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
1000 if (!currentNode->isRealNode()) {
1003 if (currentNode->isTemplateSymbolBinding()) {
1005 Operation *lookupFrom = currentNode->getSymbolPathRoot().getOperation();
1009 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
1010 return reportLoc->emitError() <<
"Cannot inline struct within a template. Run "
1011 "`llzk-flatten` to instantiate templated structs.";
1013 FailureOr<FuncDefOp> currentFuncOpt = getIfStructConstrain(currentNode, tables);
1014 if (failed(currentFuncOpt)) {
1017 FuncDefOp currentFunc = currentFuncOpt.value();
1018 uint64_t currentComplexity = complexity(currentFunc);
1020 if (exceedsMaxComplexity(currentComplexity)) {
1021 complexityMemo[currentNode] = currentComplexity;
1026 SmallVector<StructDefOp> successorsToMerge;
1027 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
1028 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
1030 auto memoResult = complexityMemo.find(successor);
1031 if (memoResult == complexityMemo.end()) {
1034 uint64_t sComplexity = memoResult->second;
1036 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
1037 "addition will overflow"
1039 uint64_t potentialComplexity = currentComplexity + sComplexity;
1040 if (!exceedsMaxComplexity(potentialComplexity)) {
1041 currentComplexity = potentialComplexity;
1042 FailureOr<FuncDefOp> successorFuncOpt = getIfStructConstrain(successor, tables);
1043 assert(succeeded(successorFuncOpt));
1044 FuncDefOp successorFunc = successorFuncOpt.value();
1045 if (canInline(currentFunc, successorFunc)) {
1046 successorsToMerge.push_back(getParentStruct(successorFunc));
1050 complexityMemo[currentNode] = currentComplexity;
1051 if (!successorsToMerge.empty()) {
1052 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
1056 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1057 llvm::dbgs() <<
"InlineStructsPass plan:\n";
1058 for (
auto &[caller, callees] : retVal) {
1059 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.
getSymName() <<
"\"\n";
1060 for (StructDefOp c : callees) {
1061 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
1064 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1070 void runOnOperation()
override {
1071 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
1074 SymbolTableCollection tables;
1075 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
1077 signalPassFailure();
1082 signalPassFailure();
1091 return std::make_unique<InlineStructsPass>();
Apache License, Version 2.0, January 2004. TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use...
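Apache License, Version 2.0, January 2004, http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION. 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from...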
LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan)
mlir::SmallVector< std::pair< llzk::component::StructDefOp, mlir::SmallVector< llzk::component::StructDefOp > > > InliningPlan
Maps caller struct to callees that should be inlined.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
virtual void processBlockArgs(mlir::Block &entryBlock, mlir::RewriterBase &rewriter)=0
virtual llvm::SmallVector< mlir::Type > convertResults(mlir::ArrayRef< mlir::Type > origTypes)=0
virtual mlir::ArrayAttr convertResultAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual mlir::ArrayAttr convertInputAttrs(mlir::ArrayAttr origAttrs, llvm::SmallVector< mlir::Type > newTypes)=0
virtual llvm::SmallVector< mlir::Type > convertInputs(mlir::ArrayRef< mlir::Type > origTypes)=0
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::StringRef getMemberName()
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the MemberRefOp.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
void print(::mlir::OpAsmPrinter &_odsPrinter)
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
std::unique_ptr< mlir::Pass > createInlineStructsPass()
bool hasCycle(const GraphT &G)