35#include <mlir/IR/BuiltinOps.h>
36#include <mlir/Transforms/InliningUtils.h>
37#include <mlir/Transforms/WalkPatternRewriteDriver.h>
39#include <llvm/ADT/DenseMap.h>
40#include <llvm/ADT/SmallPtrSet.h>
41#include <llvm/ADT/SmallVector.h>
42#include <llvm/ADT/StringMap.h>
43#include <llvm/ADT/TypeSwitch.h>
44#include <llvm/Support/Debug.h>
51#define GEN_PASS_DEF_INLINESTRUCTSPASS
61#define DEBUG_TYPE "llzk-inline-structs"
70using SrcStructMemberToCloneInDest = std::map<StringRef, DestCloneOfSrcStructMember>;
73using DestToSrcToClonedSrcInDest =
74 DenseMap<DestMemberWithSrcStructType, SrcStructMemberToCloneInDest>;
78static inline Value getSelfValue(
FuncDefOp f) {
84 llvm_unreachable(
"expected \"@compute\" or \"@constrain\" function");
98static FailureOr<MemberWriteOp>
99findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
101 for (Operation *user : writtenValue.getUsers()) {
102 if (
MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(user)) {
104 if (writeOp.getVal() == writtenValue) {
107 auto diag = emitError().append(
"result should not be written to more than one member.");
108 diag.attachNote(foundWrite.getLoc()).append(
"written here");
109 diag.attachNote(writeOp.getLoc()).append(
"written here");
112 foundWrite = writeOp;
119 return emitError().append(
"result should be written to a member.");
127static bool combineHelper(
132 llvm::dbgs() <<
"[combineHelper] " << readOp <<
" => " << destMemberRefOp <<
'\n';
135 auto srcToClone = destToSrcToClone.find(getDef(tables, destMemberRefOp));
136 if (srcToClone == destToSrcToClone.end()) {
139 SrcStructMemberToCloneInDest oldToNewMembers = srcToClone->second;
140 auto resNewMember = oldToNewMembers.find(readOp.
getMemberName());
141 if (resNewMember == oldToNewMembers.end()) {
146 OpBuilder builder(readOp);
148 readOp.getLoc(), readOp.getType(), destMemberRefOp.
getComponent(),
149 resNewMember->second.getNameAttr()
151 readOp.replaceAllUsesWith(newRead.getOperation());
169static bool combineReadChain(
171 const DestToSrcToClonedSrcInDest &destToSrcToClone
173 LLVM_DEBUG({ llvm::dbgs() <<
"[combineReadChain] " << readOp <<
'\n'; });
176 llvm::dyn_cast_if_present<MemberReadOp>(readOp.
getComponent().getDefiningOp());
177 if (!readThatDefinesBaseComponent) {
180 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
199static LogicalResult combineNewThenReadChain(
201 const DestToSrcToClonedSrcInDest &destToSrcToClone
203 LLVM_DEBUG({ llvm::dbgs() <<
"[combineNewThenReadChain] " << readOp <<
'\n'; });
206 llvm::dyn_cast_if_present<CreateStructOp>(readOp.
getComponent().getDefiningOp());
207 if (!createThatDefinesBaseComponent) {
210 FailureOr<MemberWriteOp> foundWrite =
211 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
212 return createThatDefinesBaseComponent.emitOpError();
214 if (failed(foundWrite)) {
217 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
220static inline MemberReadOp getMemberReadThatDefinesSelfValuePassedToConstrain(
CallOp callOp) {
222 return llvm::dyn_cast_if_present<MemberReadOp>(selfArgFromCall.getDefiningOp());
227struct PendingErasure {
228 SmallPtrSet<Operation *, 8> memberReadOps;
229 SmallPtrSet<Operation *, 8> memberWriteOps;
230 SmallVector<CreateStructOp> newStructOps;
231 SmallVector<DestMemberWithSrcStructType> memberDefs;
236 SymbolTableCollection &tables;
237 PendingErasure &toDelete;
251 class MemberRefRewriter final :
public OpInterfaceRewritePattern<MemberRefOpInterface> {
259 const SrcStructMemberToCloneInDest &oldToNewMembers;
263 FuncDefOp originalFunc, Value newRefBase,
264 const SrcStructMemberToCloneInDest &oldToNewMemberDef
266 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
267 oldBaseVal(
nullptr), newBaseVal(newRefBase), oldToNewMembers(oldToNewMemberDef) {}
280 rewriter.modifyOpInPlace(op, [
this, &op]() {
281 DestCloneOfSrcStructMember newF = oldToNewMembers.at(op.
getMemberName());
289 static FuncDefOp cloneWithMemberRefUpdate(std::unique_ptr<MemberRefRewriter> thisPat) {
293 thisPat->funcRef = srcFuncClone;
294 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
296 MLIRContext *ctx = thisPat->getContext();
297 RewritePatternSet patterns(ctx, std::move(thisPat));
298 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
307 const StructInliner &data;
308 const DestToSrcToClonedSrcInDest &destToSrcToClone;
313 virtual void processCloneBeforeInlining(
FuncDefOp func) {}
314 virtual ~ImplBase() =
default;
317 ImplBase(
const StructInliner &inliner,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
318 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
322 llvm::dbgs() <<
"[doInlining] SOURCE FUNCTION:\n";
324 llvm::dbgs() <<
"[doInlining] DESTINATION FUNCTION:\n";
328 InlinerInterface inliner(destFunc.getContext());
331 auto callHandler = [
this, &inliner, &srcFunc](
CallOp callOp) {
334 assert(succeeded(callOpTarget));
335 if (callOpTarget->get() != srcFunc) {
336 return WalkResult::advance();
342 if (!selfMemberRefOp) {
344 return WalkResult::interrupt();
350 FuncDefOp srcFuncClone = MemberRefRewriter::cloneWithMemberRefUpdate(
351 std::make_unique<MemberRefRewriter>(
353 this->destToSrcToClone.at(this->data.getDef(selfMemberRefOp))
356 this->processCloneBeforeInlining(srcFuncClone);
359 LogicalResult inlineCallRes =
360 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.
getBody(),
false);
361 if (failed(inlineCallRes)) {
363 return WalkResult::interrupt();
365 srcFuncClone.erase();
367 return WalkResult::skip();
372 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
373 this->data.toDelete.memberWriteOps.insert(writeOp);
375 return WalkResult::advance();
382 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
383 this->data.toDelete.memberReadOps.insert(readOp);
386 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
388 : WalkResult::advance();
391 WalkResult walkRes = destFunc.
getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
392 return TypeSwitch<Operation *, WalkResult>(op)
393 .Case<
CallOp>(callHandler)
394 .Case<MemberWriteOp>(memberWriteHandler)
396 .Default([](Operation *) {
return WalkResult::advance(); });
399 return failure(walkRes.wasInterrupted());
403 class ConstrainImpl :
public ImplBase {
404 using ImplBase::ImplBase;
407 LLVM_DEBUG({ llvm::dbgs() <<
"[ConstrainImpl::getSelfRefMember] " << callOp <<
'\n'; });
413 getMemberReadThatDefinesSelfValuePassedToConstrain(callOp);
415 selfMemberRef.getComponent().getType() == this->data.destStruct.getType()) {
416 return selfMemberRef;
421 "\" to be passed a value read from a member in the current stuct."
428 class ComputeImpl :
public ImplBase {
429 using ImplBase::ImplBase;
431 MemberRefOpInterface getSelfRefMember(CallOp callOp)
override {
432 LLVM_DEBUG({ llvm::dbgs() <<
"[ComputeImpl::getSelfRefMember] " << callOp <<
'\n'; });
439 FailureOr<MemberWriteOp> foundWrite =
441 return callOp.emitOpError().append(
"\"@", FUNC_NAME_COMPUTE,
"\" ");
443 return static_cast<MemberRefOpInterface
>(foundWrite.value_or(
nullptr));
446 void processCloneBeforeInlining(FuncDefOp func)
override {
450 func.
getBody().walk([
this](CreateStructOp newStructOp) {
451 if (newStructOp.getType() == this->data.srcStruct.getType()) {
452 this->data.toDelete.newStructOps.push_back(newStructOp);
461 DestToSrcToClonedSrcInDest cloneMembers() {
462 DestToSrcToClonedSrcInDest destToSrcToClone;
464 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
465 StructType srcStructType = srcStruct.getType();
466 for (MemberDefOp destMember : destStruct.getMemberDefs()) {
467 if (StructType destMemberType = llvm::dyn_cast<StructType>(destMember.getType())) {
472 assert(unifications.empty());
474 toDelete.memberDefs.push_back(destMember);
477 SrcStructMemberToCloneInDest &srcToClone = destToSrcToClone[destMember];
478 std::vector<MemberDefOp> srcMembers = srcStruct.getMemberDefs();
479 if (srcMembers.empty()) {
482 OpBuilder builder(destMember);
483 std::string newNameBase =
485 for (MemberDefOp srcMember : srcMembers) {
486 DestCloneOfSrcStructMember newF = llvm::cast<MemberDefOp>(builder.clone(*srcMember));
487 newF.setName(builder.getStringAttr(newNameBase +
'+' + newF.getName()));
488 srcToClone[srcMember.getSymNameAttr()] = newF;
490 destStructSymTable.insert(newF);
494 return destToSrcToClone;
498 inline LogicalResult inlineConstrainCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
499 return ConstrainImpl(*
this, destToSrcToClone)
500 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
504 inline LogicalResult inlineComputeCall(
const DestToSrcToClonedSrcInDest &destToSrcToClone) {
505 return ComputeImpl(*
this, destToSrcToClone)
506 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
511 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp
from, StructDefOp into
513 : tables(tbls), toDelete(opsToDelete), srcStruct(
from), destStruct(into) {}
515 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
517 llvm::dbgs() <<
"[StructInliner] merge " << srcStruct.getSymNameAttr() <<
" into "
518 << destStruct.getSymNameAttr() <<
'\n'
521 DestToSrcToClonedSrcInDest destToSrcToClone = cloneMembers();
522 if (failed(inlineConstrainCall(destToSrcToClone)) ||
523 failed(inlineComputeCall(destToSrcToClone))) {
526 return destToSrcToClone;
532 { t.contains(p) } -> std::convertible_to<bool>;
536template <
typename... PendingDeletionSets>
538class DanglingUseHandler {
539 SymbolTableCollection &tables;
540 const DestToSrcToClonedSrcInDest &destToSrcToClone;
541 std::tuple<
const PendingDeletionSets &...> otherRefsToBeDeleted;
545 SymbolTableCollection &symTables,
const DestToSrcToClonedSrcInDest &destToSrcToCloneRef,
546 const PendingDeletionSets &...otherRefsPendingDeletion
548 : tables(symTables), destToSrcToClone(destToSrcToCloneRef),
549 otherRefsToBeDeleted(otherRefsPendingDeletion...) {}
556 LogicalResult handle(Operation *op)
const {
557 if (op->use_empty()) {
562 llvm::dbgs() <<
"[DanglingUseHandler::handle] op: " << *op <<
'\n';
563 llvm::dbgs() <<
"[DanglingUseHandler::handle] in function: "
564 << op->getParentOfType<
FuncDefOp>() <<
'\n';
566 for (OpOperand &
use : llvm::make_early_inc_range(op->getUses())) {
567 if (
CallOp c = llvm::dyn_cast<CallOp>(
use.getOwner())) {
568 if (failed(handleUseInCallOp(
use, c, op))) {
572 Operation *user =
use.getOwner();
574 if (!opWillBeDeleted(user)) {
575 return op->emitOpError()
577 "with use in '", user->getName().getStringRef(),
578 "' is not (currently) supported by this pass."
580 .attachNote(user->getLoc())
581 .append(
"used by this operation");
586 if (!op->use_empty()) {
587 for (Operation *user : op->getUsers()) {
588 if (!opWillBeDeleted(user)) {
589 llvm::errs() <<
"Op has remaining use(s) that could not be removed: " << *op <<
'\n';
590 llvm_unreachable(
"Expected all uses to be removed");
598 inline LogicalResult handleUseInCallOp(OpOperand &
use,
CallOp inCall, Operation *origin)
const {
600 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall <<
'\n'
602 unsigned argIdx =
use.getOperandNumber() - inCall.
getArgOperands().getBeginOperandIndex();
604 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] at index: " << argIdx <<
'\n'
608 if (failed(tgtFuncRes)) {
610 ->emitOpError(
"as argument to an unknown function is not supported by this pass.")
611 .attachNote(inCall.getLoc())
612 .append(
"used by this call");
616 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] call target: " << tgtFunc <<
'\n'
618 if (tgtFunc.isExternal()) {
622 ->emitOpError(
"as argument to a no-body free function is not supported by this pass.")
623 .attachNote(inCall.getLoc())
624 .append(
"used by this call");
628 TypeSwitch<Operation *, MemberRefOpInterface>(origin)
629 .template Case<MemberReadOp>([](
auto p) {
return p; })
630 .
template Case<CreateStructOp>([](
auto p) {
631 return findOpThatStoresSubcmp(p, [&p]() {
return p.emitOpError(); }).value_or(
nullptr);
632 }).Default([](Operation *p) {
633 llvm::errs() <<
"Encountered unexpected op: "
634 << (p ? p->getName().getStringRef() :
"<<null>>") <<
'\n';
635 llvm_unreachable(
"Unexpected op kind");
639 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] member ref op for param: "
642 if (!paramFromMember) {
645 const SrcStructMemberToCloneInDest &newMembers =
646 destToSrcToClone.at(getDef(tables, paramFromMember));
648 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] members to split: "
653 splitFunctionParam(tgtFunc, argIdx, newMembers);
655 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target: " << tgtFunc
657 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED call target type: "
663 OpBuilder builder(inCall);
664 SmallVector<Value> splitArgs;
668 for (
auto [origName, newMemberRef] : newMembers) {
670 inCall.getLoc(), newMemberRef.getType(), originalBaseVal, newMemberRef.getNameAttr()
676 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
679 inCall.replaceAllUsesWith(builder.create<
CallOp>(
685 llvm::dbgs() <<
"[DanglingUseHandler::handleUseInCallOp] UPDATED function: "
686 << origin->getParentOfType<
FuncDefOp>() <<
'\n';
692 inline bool opWillBeDeleted(Operation *otherOp)
const {
693 return std::apply([&](
const auto &...sets) {
694 return ((sets.contains(otherOp)) || ...);
695 }, otherRefsToBeDeleted);
702 static void splitFunctionParam(
703 FuncDefOp func,
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember
707 const SrcStructMemberToCloneInDest &newMembers;
708 std::optional<std::string> originalArgName;
709 SmallVector<std::string> existingArgNames;
712 Impl(FuncDefOp func,
unsigned paramIdx,
const SrcStructMemberToCloneInDest &nameToNewMember)
713 : inputIdx(paramIdx), newMembers(nameToNewMember) {
714 for (
unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
715 if (std::optional<StringAttr> argName = func.getArgNameAttr(i)) {
716 existingArgNames.push_back(argName->getValue().str());
718 originalArgName = argName->getValue().str();
725 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes)
override {
726 SmallVector<Type> newTypes(origTypes);
727 auto *it = newTypes.erase(newTypes.begin() + inputIdx);
728 for (
auto [_, newMember] : newMembers) {
729 newTypes.insert(it, newMember.getType());
734 SmallVector<Type> convertResults(ArrayRef<Type> origTypes)
override {
735 return SmallVector<Type>(origTypes);
737 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type>)
override {
740 SmallVector<Attribute> newAttrs(origAttrs.getValue());
741 auto splitAttr = llvm::cast<DictionaryAttr>(origAttrs[inputIdx]);
742 SmallVector<Attribute> splitAttrs;
743 if (originalArgName) {
744 llvm::StringSet<> usedArgNames;
745 for (StringRef argName : existingArgNames) {
746 usedArgNames.insert(argName);
748 for (
auto [memberName, _] : newMembers) {
749 std::string desiredName = (*originalArgName +
'.' + memberName).str();
755 splitAttrs.append(newMembers.size(), splitAttr);
757 newAttrs[inputIdx] = splitAttrs.front();
759 newAttrs.begin() + inputIdx + 1, splitAttrs.begin() + 1, splitAttrs.end()
761 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
765 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type>)
override {
769 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter)
override {
770 Value oldStructRef = entryBlock.getArgument(inputIdx);
774 llvm::StringMap<BlockArgument> memberNameToNewArg;
775 Location loc = oldStructRef.getLoc();
776 unsigned idx = inputIdx;
777 for (
auto [memberName, newMember] : newMembers) {
779 BlockArgument newArg = entryBlock.insertArgument(++idx, newMember.getType(), loc);
780 memberNameToNewArg[memberName] = newArg;
785 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
786 if (
MemberReadOp readOp = llvm::dyn_cast<MemberReadOp>(oldBlockArgUse.getOwner())) {
788 BlockArgument newArg = memberNameToNewArg.at(readOp.
getMemberName());
789 rewriter.replaceAllUsesWith(readOp, newArg);
790 rewriter.eraseOp(readOp);
795 llvm::errs() <<
"Unexpected use of " << oldBlockArgUse.get() <<
" in "
796 << *oldBlockArgUse.getOwner() <<
'\n';
797 llvm_unreachable(
"Not yet implemented");
801 entryBlock.eraseArgument(inputIdx);
804 IRRewriter rewriter(func.getContext());
805 Impl(func, paramIdx, nameToNewMember).convert(func, rewriter);
809static LogicalResult finalizeStruct(
810 SymbolTableCollection &tables,
StructDefOp caller, PendingErasure &&toDelete,
811 DestToSrcToClonedSrcInDest &&destToSrcToClone
814 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
815 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
816 llvm::dbgs() <<
'\n';
821 combineReadChain(readOp, tables, destToSrcToClone);
825 auto res = computeFn.walk([&tables, &destToSrcToClone, &computeSelfVal](
MemberReadOp readOp) {
826 combineReadChain(readOp, tables, destToSrcToClone);
830 return WalkResult::advance();
832 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
833 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
835 if (res.wasInterrupted()) {
840 llvm::dbgs() <<
"[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
841 caller.
print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
842 llvm::dbgs() <<
'\n';
843 llvm::dbgs() <<
"[finalizeStruct] ops marked for deletion:\n";
844 for (Operation *op : toDelete.memberReadOps) {
845 llvm::dbgs().indent(2) << *op <<
'\n';
847 for (Operation *op : toDelete.memberWriteOps) {
848 llvm::dbgs().indent(2) << *op <<
'\n';
851 llvm::dbgs().indent(2) << op <<
'\n';
853 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
854 llvm::dbgs().indent(2) << op <<
'\n';
860 DanglingUseHandler<SmallPtrSet<Operation *, 8>, SmallPtrSet<Operation *, 8>> useHandler(
861 tables, destToSrcToClone, toDelete.memberWriteOps, toDelete.memberReadOps
864 if (failed(useHandler.handle(op))) {
870 for (Operation *op : toDelete.memberWriteOps) {
871 if (failed(useHandler.handle(op))) {
876 for (Operation *op : toDelete.memberReadOps) {
877 if (failed(useHandler.handle(op))) {
886 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
887 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
888 assert(op.getParentOp() == caller);
889 callerSymTab.erase(op);
898 for (
auto &[caller, callees] : plan) {
901 PendingErasure toDelete;
903 DestToSrcToClonedSrcInDest aggregateReplacements;
906 FailureOr<DestToSrcToClonedSrcInDest> res =
907 StructInliner(tables, toDelete, toInline, caller).doInline();
912 for (
auto &[k, v] : res.value()) {
913 assert(!aggregateReplacements.contains(k) &&
"duplicate not possible");
914 aggregateReplacements[k] = std::move(v);
918 LogicalResult finalizeResult =
919 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
920 if (failed(finalizeResult)) {
930 using Base = InlineStructsPassBase<PassImpl>;
933 static uint64_t complexity(FuncDefOp f) {
934 uint64_t complexity = 0;
935 f.
getBody().walk([&complexity](Operation *op) {
936 if (llvm::isa<felt::MulFeltOp>(op)) {
938 }
else if (
auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
940 }
else if (
auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
955 getIfResolvableStructConstrain(
const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
960 if (failed(lookupRes)) {
963 FuncDefOp func = llvm::dyn_cast<FuncDefOp>(lookupRes->get());
972 static inline StructDefOp getParentStruct(FuncDefOp func) {
975 assert(currentNodeParentStruct);
976 return currentNodeParentStruct;
980 inline bool exceedsMaxComplexity(uint64_t
check) {
981 return maxComplexity > 0 &&
check > maxComplexity;
986 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
998 WalkResult res = currentFunc.walk([](CallOp c) {
999 return getMemberReadThatDefinesSelfValuePassedToConstrain(c)
1000 ? WalkResult::interrupt()
1001 : WalkResult::advance();
1007 return res.wasInterrupted();
1010 static LogicalResult
1011 verifyNoTemplateSymbolBindings(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
1012 for (
const SymbolUseGraphNode *node : useGraph.
nodesIter()) {
1022 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
1023 return reportLoc->emitError() <<
"Cannot inline struct within a template. Run "
1024 "`llzk-flatten` to instantiate templated structs.";
1029 static LogicalResult emitConstrainReachableCycleError(
1030 ArrayRef<const SymbolUseGraphNode *> dfsStack,
const SymbolUseGraphNode *cycleHead,
1031 SymbolTableCollection &tables
1033 SmallVector<const SymbolUseGraphNode *, 8> cycle;
1034 bool inCycle =
false;
1035 for (
const SymbolUseGraphNode *node : dfsStack) {
1036 if (node == cycleHead) {
1040 cycle.push_back(node);
1043 if (cycle.empty()) {
1044 cycle.push_back(cycleHead);
1048 for (
const SymbolUseGraphNode *node : cycle) {
1053 if (failed(lookupRes)) {
1056 Operation *op = lookupRes->get();
1058 if (llvm::isa<FuncDefOp>(op)) {
1063 InFlightDiagnostic diag = reportOp->emitError();
1064 diag <<
"Cannot inline structs when a symbol-use cycle is reachable from a struct "
1065 "\"@constrain\" function. Prover-side recursion is allowed only when "
1066 "\"@constrain\" cannot reach it.";
1068 for (
const SymbolUseGraphNode *node : cycle) {
1072 if (
auto lookupRes = node->
lookupSymbol(tables,
false);
1073 succeeded(lookupRes)) {
1074 diag.attachNote(lookupRes->get()->getLoc()) <<
"cycle contains " << node->
getSymbolPath();
1089 static LogicalResult computeConstrainReachablePostOrder(
1090 const SymbolUseGraph &useGraph, SymbolTableCollection &tables,
1091 SmallVectorImpl<const SymbolUseGraphNode *> &postOrder
1093 enum class VisitState : std::uint8_t { Active, Done };
1095 DenseMap<const SymbolUseGraphNode *, VisitState> state;
1096 SmallVector<const SymbolUseGraphNode *, 32> dfsStack;
1098 auto dfs = [&](
auto &&self,
const SymbolUseGraphNode *node) -> LogicalResult {
1099 auto seen = state.find(node);
1100 if (seen != state.end()) {
1101 if (seen->second == VisitState::Active) {
1102 return emitConstrainReachableCycleError(dfsStack, node, tables);
1107 state[node] = VisitState::Active;
1108 dfsStack.push_back(node);
1109 for (
const SymbolUseGraphNode *successor : node->
successorIter()) {
1110 if (failed(self(self, successor))) {
1114 dfsStack.pop_back();
1116 state[node] = VisitState::Done;
1117 postOrder.push_back(node);
1121 for (
const SymbolUseGraphNode *node : useGraph.
nodesIter()) {
1122 if (!getIfResolvableStructConstrain(node, tables)) {
1125 if (failed(dfs(dfs, node))) {
1137 inline FailureOr<InliningPlan>
1138 makePlan(
const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
1140 llvm::dbgs() <<
"Running InlineStructsPass with max complexity ";
1141 if (maxComplexity == 0) {
1142 llvm::dbgs() <<
"unlimited";
1144 llvm::dbgs() << maxComplexity;
1146 llvm::dbgs() <<
'\n';
1149 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
1151 if (failed(verifyNoTemplateSymbolBindings(useGraph, tables))) {
1155 SmallVector<const SymbolUseGraphNode *, 32> constrainPostOrder;
1156 if (failed(computeConstrainReachablePostOrder(useGraph, tables, constrainPostOrder))) {
1163 for (
const SymbolUseGraphNode *currentNode : constrainPostOrder) {
1164 LLVM_DEBUG(llvm::dbgs() <<
"\ncurrentNode = " << currentNode->toString());
1165 FuncDefOp currentFunc = getIfResolvableStructConstrain(currentNode, tables);
1169 uint64_t currentComplexity = complexity(currentFunc);
1171 if (exceedsMaxComplexity(currentComplexity)) {
1172 complexityMemo[currentNode] = currentComplexity;
1177 SmallVector<StructDefOp> successorsToMerge;
1178 for (
const SymbolUseGraphNode *successor : currentNode->successorIter()) {
1179 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"successor: " << successor->toString() <<
'\n');
1181 auto memoResult = complexityMemo.find(successor);
1182 if (memoResult == complexityMemo.end()) {
1185 uint64_t sComplexity = memoResult->second;
1187 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
1188 "addition will overflow"
1190 uint64_t potentialComplexity = currentComplexity + sComplexity;
1191 if (!exceedsMaxComplexity(potentialComplexity)) {
1192 currentComplexity = potentialComplexity;
1193 FuncDefOp successorFunc = getIfResolvableStructConstrain(successor, tables);
1194 if (!successorFunc) {
1197 if (canInline(currentFunc, successorFunc)) {
1198 successorsToMerge.push_back(getParentStruct(successorFunc));
1202 complexityMemo[currentNode] = currentComplexity;
1203 if (!successorsToMerge.empty()) {
1204 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
1208 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1209 llvm::dbgs() <<
"InlineStructsPass plan:\n";
1210 for (
auto &[caller, callees] : retVal) {
1211 llvm::dbgs().indent(2) <<
"inlining the following into \"" << caller.
getSymName() <<
"\"\n";
1212 for (StructDefOp c : callees) {
1213 llvm::dbgs().indent(4) <<
"\"" << c.getSymName() <<
"\"\n";
1216 llvm::dbgs() <<
"-----------------------------------------------------------------\n";
1222 void runOnOperation()
override {
1223 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
1226 SymbolTableCollection tables;
1227 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
1229 signalPassFailure();
1234 signalPassFailure();
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan)
mlir::SmallVector< std::pair< llzk::component::StructDefOp, mlir::SmallVector< llzk::component::StructDefOp > > > InliningPlan
Maps caller struct to callees that should be inlined.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
bool isTemplateSymbolBinding() const
Return true iff the symbol is a defined by a TemplateSymbolBindingOpInterface.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::StringRef getMemberName()
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the MemberRefOp.
void setMemberName(::llvm::StringRef attrValue)
Sets the member name attribute value in the MemberRefOp.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::mlir::OpOperand & getComponentMutable()
Gets the SSA value with the target component from the MemberRefOp.
::llvm::StringRef getSymName()
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
void print(::mlir::OpAsmPrinter &_odsPrinter)
::mlir::Operation::operand_range getArgOperands()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
::mlir::OperandRangeRange getMapOperands()
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
::mlir::FunctionType getFunctionType()
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
::mlir::Region & getBody()
std::string toStringOne(const T &value)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
mlir::DictionaryAttr withFunctionArgNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given argument attribute dictionary with function.arg_name set to name.
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
std::string reserveUniqueAttrName(llvm::StringSet<> &usedNames, llvm::StringRef desiredName)
Reserve and return a unique function argument/result name based on desiredName.