LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
LLZKInlineStructsPass.cpp
Go to the documentation of this file.
1//===-- LLZKInlineStructsPass.cpp -------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
19//===----------------------------------------------------------------------===//
20
22
31#include "llzk/Util/Debug.h"
34
35#include <mlir/IR/BuiltinOps.h>
36#include <mlir/Transforms/InliningUtils.h>
37#include <mlir/Transforms/WalkPatternRewriteDriver.h>
38
39#include <llvm/ADT/DenseMap.h>
40#include <llvm/ADT/SmallPtrSet.h>
41#include <llvm/ADT/SmallVector.h>
42#include <llvm/ADT/StringMap.h>
43#include <llvm/ADT/TypeSwitch.h>
44#include <llvm/Support/Debug.h>
45
46#include <concepts>
47#include <optional>
48
49// Include the generated base pass class definitions.
50namespace llzk {
51#define GEN_PASS_DEF_INLINESTRUCTSPASS
53} // namespace llzk
54
55using namespace mlir;
56using namespace llzk;
57using namespace llzk::component;
58using namespace llzk::function;
59using namespace llzk::polymorphic;
60
61#define DEBUG_TYPE "llzk-inline-structs"
62
63namespace {
64
65using DestMemberWithSrcStructType = MemberDefOp;
66using DestCloneOfSrcStructMember = MemberDefOp;
70using SrcStructMemberToCloneInDest = std::map<StringRef, DestCloneOfSrcStructMember>;
73using DestToSrcToClonedSrcInDest =
74 DenseMap<DestMemberWithSrcStructType, SrcStructMemberToCloneInDest>;
75
78static inline Value getSelfValue(FuncDefOp f) {
79 if (f.nameIsCompute()) {
80 return f.getSelfValueFromCompute();
81 } else if (f.nameIsConstrain()) {
83 } else {
84 llvm_unreachable("expected \"@compute\" or \"@constrain\" function");
85 }
86}
87
90static inline MemberDefOp getDef(SymbolTableCollection &tables, MemberRefOpInterface fRef) {
91 auto r = fRef.getMemberDefOp(tables);
92 assert(succeeded(r));
93 return r->get();
94}
95
98static FailureOr<MemberWriteOp>
99findOpThatStoresSubcmp(Value writtenValue, function_ref<InFlightDiagnostic()> emitError) {
100 MemberWriteOp foundWrite = nullptr;
101 for (Operation *user : writtenValue.getUsers()) {
102 if (MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(user)) {
103 // Find the write op that stores the created value
104 if (writeOp.getVal() == writtenValue) {
105 if (foundWrite) {
106 // Note: There is no reason for a subcomponent to be stored to more than one member.
107 auto diag = emitError().append("result should not be written to more than one member.");
108 diag.attachNote(foundWrite.getLoc()).append("written here");
109 diag.attachNote(writeOp.getLoc()).append("written here");
110 return diag;
111 } else {
112 foundWrite = writeOp;
113 }
114 }
115 }
116 }
117 if (!foundWrite) {
118 // Note: There is no reason to construct a subcomponent and not store it to a member.
119 return emitError().append("result should be written to a member.");
120 }
121 return foundWrite;
122}
123
127static bool combineHelper(
128 MemberReadOp readOp, SymbolTableCollection &tables,
129 const DestToSrcToClonedSrcInDest &destToSrcToClone, MemberRefOpInterface destMemberRefOp
130) {
131 LLVM_DEBUG({
132 llvm::dbgs() << "[combineHelper] " << readOp << " => " << destMemberRefOp << '\n';
133 });
134
135 auto srcToClone = destToSrcToClone.find(getDef(tables, destMemberRefOp));
136 if (srcToClone == destToSrcToClone.end()) {
137 return false;
138 }
139 SrcStructMemberToCloneInDest oldToNewMembers = srcToClone->second;
140 auto resNewMember = oldToNewMembers.find(readOp.getMemberName());
141 if (resNewMember == oldToNewMembers.end()) {
142 return false;
143 }
144
145 // Replace this MemberReadOp with a new one that targets the cloned member.
146 OpBuilder builder(readOp);
147 MemberReadOp newRead = builder.create<MemberReadOp>(
148 readOp.getLoc(), readOp.getType(), destMemberRefOp.getComponent(),
149 resNewMember->second.getNameAttr()
150 );
151 readOp.replaceAllUsesWith(newRead.getOperation());
152 readOp.erase(); // delete the original MemberReadOp
153 return true;
154}
155
169static bool combineReadChain(
170 MemberReadOp readOp, SymbolTableCollection &tables,
171 const DestToSrcToClonedSrcInDest &destToSrcToClone
172) {
173 LLVM_DEBUG({ llvm::dbgs() << "[combineReadChain] " << readOp << '\n'; });
174
175 MemberReadOp readThatDefinesBaseComponent =
176 llvm::dyn_cast_if_present<MemberReadOp>(readOp.getComponent().getDefiningOp());
177 if (!readThatDefinesBaseComponent) {
178 return false;
179 }
180 return combineHelper(readOp, tables, destToSrcToClone, readThatDefinesBaseComponent);
181}
182
199static LogicalResult combineNewThenReadChain(
200 MemberReadOp readOp, SymbolTableCollection &tables,
201 const DestToSrcToClonedSrcInDest &destToSrcToClone
202) {
203 LLVM_DEBUG({ llvm::dbgs() << "[combineNewThenReadChain] " << readOp << '\n'; });
204
205 CreateStructOp createThatDefinesBaseComponent =
206 llvm::dyn_cast_if_present<CreateStructOp>(readOp.getComponent().getDefiningOp());
207 if (!createThatDefinesBaseComponent) {
208 return success(); // No error. The pattern simply doesn't match.
209 }
210 FailureOr<MemberWriteOp> foundWrite =
211 findOpThatStoresSubcmp(createThatDefinesBaseComponent, [&createThatDefinesBaseComponent]() {
212 return createThatDefinesBaseComponent.emitOpError();
213 });
214 if (failed(foundWrite)) {
215 return failure(); // error already printed within findOpThatStoresSubcmp()
216 }
217 return success(combineHelper(readOp, tables, destToSrcToClone, foundWrite.value()));
218}
219
220static inline MemberReadOp getMemberReadThatDefinesSelfValuePassedToConstrain(CallOp callOp) {
221 Value selfArgFromCall = callOp.getSelfValueFromConstrain();
222 return llvm::dyn_cast_if_present<MemberReadOp>(selfArgFromCall.getDefiningOp());
223}
224
227struct PendingErasure {
228 SmallPtrSet<Operation *, 8> memberReadOps;
229 SmallPtrSet<Operation *, 8> memberWriteOps;
230 SmallVector<CreateStructOp> newStructOps;
231 SmallVector<DestMemberWithSrcStructType> memberDefs;
232};
233
235class StructInliner {
236 SymbolTableCollection &tables;
237 PendingErasure &toDelete;
239 StructDefOp srcStruct;
241 StructDefOp destStruct;
242
243 inline MemberDefOp getDef(MemberRefOpInterface fRef) const { return ::getDef(tables, fRef); }
244
245 // Update member read/write ops that target the "self" value of the FuncDefOp plus some key in
246 // `oldToNewMemberDef` to instead target the new base Value provided to the constructor plus the
247 // mapped Value from `oldToNewMemberDef`.
248 // Example:
249 // old: %1 = struct.readm %0[@f1] : <@Component1A>, !felt.type
250 // new: %1 = struct.readm %self[@"f2:!s<@Component1A>+f1"] : <@Component1B>, !felt.type
251 class MemberRefRewriter final : public OpInterfaceRewritePattern<MemberRefOpInterface> {
254 FuncDefOp funcRef;
256 Value oldBaseVal;
258 Value newBaseVal;
259 const SrcStructMemberToCloneInDest &oldToNewMembers;
260
261 public:
262 MemberRefRewriter(
263 FuncDefOp originalFunc, Value newRefBase,
264 const SrcStructMemberToCloneInDest &oldToNewMemberDef
265 )
266 : OpInterfaceRewritePattern(originalFunc.getContext()), funcRef(originalFunc),
267 oldBaseVal(nullptr), newBaseVal(newRefBase), oldToNewMembers(oldToNewMemberDef) {}
268
269 LogicalResult match(MemberRefOpInterface op) const final {
270 assert(oldBaseVal); // ensure it's used via `cloneWithMemberRefUpdate()` only
271 // Check if the MemberRef accesses a member of "self" within the `oldToNewMembers` map.
272 // Per `cloneWithMemberRefUpdate()`, `oldBaseVal` is the "self" value of `funcRef` so
273 // check for a match there and then check that the referenced member name is in the map.
274 return success(
275 op.getComponent() == oldBaseVal && oldToNewMembers.contains(op.getMemberName())
276 );
277 }
278
279 void rewrite(MemberRefOpInterface op, PatternRewriter &rewriter) const final {
280 rewriter.modifyOpInPlace(op, [this, &op]() {
281 DestCloneOfSrcStructMember newF = oldToNewMembers.at(op.getMemberName());
282 op.setMemberName(newF.getSymName());
283 op.getComponentMutable().set(this->newBaseVal);
284 });
285 }
286
289 static FuncDefOp cloneWithMemberRefUpdate(std::unique_ptr<MemberRefRewriter> thisPat) {
290 IRMapping mapper;
291 FuncDefOp srcFuncClone = thisPat->funcRef.clone(mapper);
292 // Update some data in the `MemberRefRewriter` instance before moving it.
293 thisPat->funcRef = srcFuncClone;
294 thisPat->oldBaseVal = getSelfValue(srcFuncClone);
295 // Run the rewriter to replace read/write ops
296 MLIRContext *ctx = thisPat->getContext();
297 RewritePatternSet patterns(ctx, std::move(thisPat));
298 walkAndApplyPatterns(srcFuncClone, std::move(patterns));
299
300 return srcFuncClone;
301 }
302 };
303
305 class ImplBase {
306 protected:
307 const StructInliner &data;
308 const DestToSrcToClonedSrcInDest &destToSrcToClone;
309
312 virtual MemberRefOpInterface getSelfRefMember(CallOp callOp) = 0;
313 virtual void processCloneBeforeInlining(FuncDefOp func) {}
314 virtual ~ImplBase() = default;
315
316 public:
317 ImplBase(const StructInliner &inliner, const DestToSrcToClonedSrcInDest &destToSrcToCloneRef)
318 : data(inliner), destToSrcToClone(destToSrcToCloneRef) {}
319
320 LogicalResult doInlining(FuncDefOp srcFunc, FuncDefOp destFunc) {
321 LLVM_DEBUG({
322 llvm::dbgs() << "[doInlining] SOURCE FUNCTION:\n";
323 srcFunc.dump();
324 llvm::dbgs() << "[doInlining] DESTINATION FUNCTION:\n";
325 destFunc.dump();
326 });
328 InlinerInterface inliner(destFunc.getContext());
331 auto callHandler = [this, &inliner, &srcFunc](CallOp callOp) {
332 // Ensure the CallOp targets `srcFunc`
333 auto callOpTarget = callOp.getCalleeTarget(this->data.tables);
334 assert(succeeded(callOpTarget));
335 if (callOpTarget->get() != srcFunc) {
336 return WalkResult::advance();
337 }
339 // Get the "self" struct parameter from the CallOp and determine which member that struct
340 // was stored in within the caller (i.e. `destFunc`).
341 MemberRefOpInterface selfMemberRefOp = this->getSelfRefMember(callOp);
342 if (!selfMemberRefOp) {
343 // Note: error message was already printed within `getSelfRefMember()`
344 return WalkResult::interrupt(); // use interrupt to signal failure
345 }
347 // Create a clone of the source function (must do the whole function not just the body
348 // region because `inlineCall()` expects the Region to have a parent op) and update member
349 // references to the old struct members to instead use the new struct members.
350 FuncDefOp srcFuncClone = MemberRefRewriter::cloneWithMemberRefUpdate(
351 std::make_unique<MemberRefRewriter>(
352 srcFunc, selfMemberRefOp.getComponent(),
353 this->destToSrcToClone.at(this->data.getDef(selfMemberRefOp))
355 );
356 this->processCloneBeforeInlining(srcFuncClone);
357
358 // Inline the cloned function in place of `callOp`
359 LogicalResult inlineCallRes =
360 inlineCall(inliner, callOp, srcFuncClone, &srcFuncClone.getBody(), false);
361 if (failed(inlineCallRes)) {
362 callOp.emitError().append("Failed to inline ", srcFunc.getFullyQualifiedName()).report();
363 return WalkResult::interrupt(); // use interrupt to signal failure
364 }
365 srcFuncClone.erase(); // delete what's left after transferring the body elsewhere
366 callOp.erase(); // delete the original CallOp
367 return WalkResult::skip(); // Must skip because the CallOp was erased.
368 };
369
370 auto memberWriteHandler = [this](MemberWriteOp writeOp) {
371 // Check if the member ref op should be deleted in the end
372 if (this->destToSrcToClone.contains(this->data.getDef(writeOp))) {
373 this->data.toDelete.memberWriteOps.insert(writeOp);
374 }
375 return WalkResult::advance();
376 };
377
380 auto memberReadHandler = [this](MemberReadOp readOp) {
381 // Check if the member ref op should be deleted in the end
382 if (this->destToSrcToClone.contains(this->data.getDef(readOp))) {
383 this->data.toDelete.memberReadOps.insert(readOp);
385 // If the MemberReadOp was replaced/erased, must skip.
386 return combineReadChain(readOp, this->data.tables, destToSrcToClone)
387 ? WalkResult::skip()
388 : WalkResult::advance();
389 };
390
391 WalkResult walkRes = destFunc.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
392 return TypeSwitch<Operation *, WalkResult>(op)
393 .Case<CallOp>(callHandler)
394 .Case<MemberWriteOp>(memberWriteHandler)
395 .Case<MemberReadOp>(memberReadHandler)
396 .Default([](Operation *) { return WalkResult::advance(); });
397 });
398
399 return failure(walkRes.wasInterrupted());
400 }
401 };
402
403 class ConstrainImpl : public ImplBase {
404 using ImplBase::ImplBase;
405
406 MemberRefOpInterface getSelfRefMember(CallOp callOp) override {
407 LLVM_DEBUG({ llvm::dbgs() << "[ConstrainImpl::getSelfRefMember] " << callOp << '\n'; });
408
409 // The typical pattern is to read a struct instance from a member and then call "constrain()"
410 // on it. Get the Value passed as the "self" struct to the CallOp and determine which member
411 // it was read from in the current struct (i.e., `destStruct`).
412 MemberRefOpInterface selfMemberRef =
413 getMemberReadThatDefinesSelfValuePassedToConstrain(callOp);
414 if (selfMemberRef &&
415 selfMemberRef.getComponent().getType() == this->data.destStruct.getType()) {
416 return selfMemberRef;
417 }
418 callOp.emitError()
419 .append(
420 "expected \"self\" parameter to \"@", FUNC_NAME_CONSTRAIN,
421 "\" to be passed a value read from a member in the current stuct."
422 )
423 .report();
424 return nullptr;
425 }
426 };
427
428 class ComputeImpl : public ImplBase {
429 using ImplBase::ImplBase;
430
431 MemberRefOpInterface getSelfRefMember(CallOp callOp) override {
432 LLVM_DEBUG({ llvm::dbgs() << "[ComputeImpl::getSelfRefMember] " << callOp << '\n'; });
433
434 // The typical pattern is to write the return value of "compute()" to a member in
435 // the current struct (i.e., `destStruct`).
436 // It doesn't really make sense (although there is no semantic restriction against it) to just
437 // pass the "compute()" result into another function and never write it to a member since that
438 // leaves no way for the "constrain()" function to call "constrain()" on that result struct.
439 FailureOr<MemberWriteOp> foundWrite =
440 findOpThatStoresSubcmp(callOp.getSelfValueFromCompute(), [&callOp]() {
441 return callOp.emitOpError().append("\"@", FUNC_NAME_COMPUTE, "\" ");
442 });
443 return static_cast<MemberRefOpInterface>(foundWrite.value_or(nullptr));
444 }
445
446 void processCloneBeforeInlining(FuncDefOp func) override {
447 // Within the compute function, find `CreateStructOp` with `srcStruct` type and mark them
448 // for later deletion. The deletion must occur later because these values may still have
449 // uses until ALL callees of a function have been inlined.
450 func.getBody().walk([this](CreateStructOp newStructOp) {
451 if (newStructOp.getType() == this->data.srcStruct.getType()) {
452 this->data.toDelete.newStructOps.push_back(newStructOp);
453 }
454 });
455 }
456 };
457
458 // Find any member(s) in `destStruct` whose type matches `srcStruct` (allowing any parameters, if
459 // applicable). For each such member, clone all members from `srcStruct` into `destStruct` and
460 // cache the mapping of `destStruct` to `srcStruct` to cloned members in the return value.
461 DestToSrcToClonedSrcInDest cloneMembers() {
462 DestToSrcToClonedSrcInDest destToSrcToClone;
463
464 SymbolTable &destStructSymTable = tables.getSymbolTable(destStruct);
465 StructType srcStructType = srcStruct.getType();
466 for (MemberDefOp destMember : destStruct.getMemberDefs()) {
467 if (StructType destMemberType = llvm::dyn_cast<StructType>(destMember.getType())) {
468 UnificationMap unifications;
469 if (!structTypesUnify(srcStructType, destMemberType, {}, &unifications)) {
470 continue;
471 }
472 assert(unifications.empty()); // `makePlan()` reports failure earlier
473 // Mark the original `destMember` for deletion
474 toDelete.memberDefs.push_back(destMember);
475 // Clone each member from 'srcStruct' into 'destStruct'. Add an entry to `destToSrcToClone`
476 // even if there are no members in `srcStruct` so its presence can be used as a marker.
477 SrcStructMemberToCloneInDest &srcToClone = destToSrcToClone[destMember];
478 std::vector<MemberDefOp> srcMembers = srcStruct.getMemberDefs();
479 if (srcMembers.empty()) {
480 continue;
481 }
482 OpBuilder builder(destMember);
483 std::string newNameBase =
484 destMember.getName().str() + ':' + BuildShortTypeString::from(destMemberType);
485 for (MemberDefOp srcMember : srcMembers) {
486 DestCloneOfSrcStructMember newF = llvm::cast<MemberDefOp>(builder.clone(*srcMember));
487 newF.setName(builder.getStringAttr(newNameBase + '+' + newF.getName()));
488 srcToClone[srcMember.getSymNameAttr()] = newF;
489 // Also update the cached SymbolTable
490 destStructSymTable.insert(newF);
491 }
492 }
493 }
494 return destToSrcToClone;
495 }
496
498 inline LogicalResult inlineConstrainCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
499 return ConstrainImpl(*this, destToSrcToClone)
500 .doInlining(srcStruct.getConstrainFuncOp(), destStruct.getConstrainFuncOp());
501 }
502
504 inline LogicalResult inlineComputeCall(const DestToSrcToClonedSrcInDest &destToSrcToClone) {
505 return ComputeImpl(*this, destToSrcToClone)
506 .doInlining(srcStruct.getComputeFuncOp(), destStruct.getComputeFuncOp());
507 }
508
509public:
510 StructInliner(
511 SymbolTableCollection &tbls, PendingErasure &opsToDelete, StructDefOp from, StructDefOp into
512 )
513 : tables(tbls), toDelete(opsToDelete), srcStruct(from), destStruct(into) {}
514
515 FailureOr<DestToSrcToClonedSrcInDest> doInline() {
516 LLVM_DEBUG(
517 llvm::dbgs() << "[StructInliner] merge " << srcStruct.getSymNameAttr() << " into "
518 << destStruct.getSymNameAttr() << '\n'
519 );
520
521 DestToSrcToClonedSrcInDest destToSrcToClone = cloneMembers();
522 if (failed(inlineConstrainCall(destToSrcToClone)) ||
523 failed(inlineComputeCall(destToSrcToClone))) {
524 return failure(); // error already printed within doInlining()
525 }
526 return destToSrcToClone;
527 }
528};
529
530template <typename T>
531concept HasContainsOp = requires(const T &t, Operation *p) {
532 { t.contains(p) } -> std::convertible_to<bool>;
533};
534
536template <typename... PendingDeletionSets>
538class DanglingUseHandler {
539 SymbolTableCollection &tables;
540 const DestToSrcToClonedSrcInDest &destToSrcToClone;
541 std::tuple<const PendingDeletionSets &...> otherRefsToBeDeleted;
542
543public:
544 DanglingUseHandler(
545 SymbolTableCollection &symTables, const DestToSrcToClonedSrcInDest &destToSrcToCloneRef,
546 const PendingDeletionSets &...otherRefsPendingDeletion
547 )
548 : tables(symTables), destToSrcToClone(destToSrcToCloneRef),
549 otherRefsToBeDeleted(otherRefsPendingDeletion...) {}
550
556 LogicalResult handle(Operation *op) const {
557 if (op->use_empty()) {
558 return success(); // safe to erase
559 }
560
561 LLVM_DEBUG({
562 llvm::dbgs() << "[DanglingUseHandler::handle] op: " << *op << '\n';
563 llvm::dbgs() << "[DanglingUseHandler::handle] in function: "
564 << op->getParentOfType<FuncDefOp>() << '\n';
565 });
566 for (OpOperand &use : llvm::make_early_inc_range(op->getUses())) {
567 if (CallOp c = llvm::dyn_cast<CallOp>(use.getOwner())) {
568 if (failed(handleUseInCallOp(use, c, op))) {
569 return failure();
570 }
571 } else {
572 Operation *user = use.getOwner();
573 // Report an error for any user other than some member ref that will be deleted anyway.
574 if (!opWillBeDeleted(user)) {
575 return op->emitOpError()
576 .append(
577 "with use in '", user->getName().getStringRef(),
578 "' is not (currently) supported by this pass."
579 )
580 .attachNote(user->getLoc())
581 .append("used by this operation");
582 }
583 }
584 }
585 // Ensure that all users of the 'op' were deleted above, or will be per 'otherRefsToBeDeleted'.
586 if (!op->use_empty()) {
587 for (Operation *user : op->getUsers()) {
588 if (!opWillBeDeleted(user)) {
589 llvm::errs() << "Op has remaining use(s) that could not be removed: " << *op << '\n';
590 llvm_unreachable("Expected all uses to be removed");
591 }
592 }
593 }
594 return success();
595 }
596
597private:
598 inline LogicalResult handleUseInCallOp(OpOperand &use, CallOp inCall, Operation *origin) const {
599 LLVM_DEBUG(
600 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] use in call: " << inCall << '\n'
601 );
602 unsigned argIdx = use.getOperandNumber() - inCall.getArgOperands().getBeginOperandIndex();
603 LLVM_DEBUG(
604 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] at index: " << argIdx << '\n'
605 );
606
607 auto tgtFuncRes = inCall.getCalleeTarget(tables);
608 if (failed(tgtFuncRes)) {
609 return origin
610 ->emitOpError("as argument to an unknown function is not supported by this pass.")
611 .attachNote(inCall.getLoc())
612 .append("used by this call");
613 }
614 FuncDefOp tgtFunc = tgtFuncRes->get();
615 LLVM_DEBUG(
616 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] call target: " << tgtFunc << '\n'
617 );
618 if (tgtFunc.isExternal()) {
619 // Those without a body (i.e. external implementation) present a problem because LLZK does
620 // not define a memory layout for the external implementation to interpret the struct.
621 return origin
622 ->emitOpError("as argument to a no-body free function is not supported by this pass.")
623 .attachNote(inCall.getLoc())
624 .append("used by this call");
625 }
626
627 MemberRefOpInterface paramFromMember =
628 TypeSwitch<Operation *, MemberRefOpInterface>(origin)
629 .template Case<MemberReadOp>([](auto p) { return p; })
630 .template Case<CreateStructOp>([](auto p) {
631 return findOpThatStoresSubcmp(p, [&p]() { return p.emitOpError(); }).value_or(nullptr);
632 }).Default([](Operation *p) {
633 llvm::errs() << "Encountered unexpected op: "
634 << (p ? p->getName().getStringRef() : "<<null>>") << '\n';
635 llvm_unreachable("Unexpected op kind");
636 return nullptr;
637 });
638 LLVM_DEBUG({
639 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] member ref op for param: "
640 << (paramFromMember ? debug::toStringOne(paramFromMember) : "<<null>>") << '\n';
641 });
642 if (!paramFromMember) {
643 return failure(); // error already printed within findOpThatStoresSubcmp()
644 }
645 const SrcStructMemberToCloneInDest &newMembers =
646 destToSrcToClone.at(getDef(tables, paramFromMember));
647 LLVM_DEBUG({
648 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] members to split: "
649 << debug::toStringList(newMembers) << '\n';
650 });
651
652 // Convert the FuncDefOp side first (to use the easier builder for the new CallOp).
653 splitFunctionParam(tgtFunc, argIdx, newMembers);
654 LLVM_DEBUG({
655 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] UPDATED call target: " << tgtFunc
656 << '\n';
657 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] UPDATED call target type: "
658 << tgtFunc.getFunctionType() << '\n';
659 });
660
661 // Convert the CallOp side. Add a MemberReadOp for each value from the struct and pass them
662 // individually in place of the struct parameter.
663 OpBuilder builder(inCall);
664 SmallVector<Value> splitArgs;
665 // Before the CallOp, insert a read from every new member. These Values will replace the
666 // original argument in the CallOp.
667 Value originalBaseVal = paramFromMember.getComponent();
668 for (auto [origName, newMemberRef] : newMembers) {
669 splitArgs.push_back(builder.create<MemberReadOp>(
670 inCall.getLoc(), newMemberRef.getType(), originalBaseVal, newMemberRef.getNameAttr()
671 ));
672 }
673 // Generate the new argument list from the original but replace 'argIdx'
674 SmallVector<Value> newOpArgs(inCall.getArgOperands());
675 newOpArgs.insert(
676 newOpArgs.erase(newOpArgs.begin() + argIdx), splitArgs.begin(), splitArgs.end()
677 );
678 // Create the new CallOp, replace uses of the old with the new, delete the old
679 inCall.replaceAllUsesWith(builder.create<CallOp>(
680 inCall.getLoc(), tgtFunc, CallOp::toVectorOfValueRange(inCall.getMapOperands()),
681 inCall.getNumDimsPerMapAttr(), newOpArgs
682 ));
683 inCall.erase();
684 LLVM_DEBUG({
685 llvm::dbgs() << "[DanglingUseHandler::handleUseInCallOp] UPDATED function: "
686 << origin->getParentOfType<FuncDefOp>() << '\n';
687 });
688 return success();
689 }
690
692 inline bool opWillBeDeleted(Operation *otherOp) const {
693 return std::apply([&](const auto &...sets) {
694 return ((sets.contains(otherOp)) || ...);
695 }, otherRefsToBeDeleted);
696 }
697
702 static void splitFunctionParam(
703 FuncDefOp func, unsigned paramIdx, const SrcStructMemberToCloneInDest &nameToNewMember
704 ) {
705 class Impl : public FunctionTypeConverter {
706 unsigned inputIdx;
707 const SrcStructMemberToCloneInDest &newMembers;
708 std::optional<std::string> originalArgName;
709 SmallVector<std::string> existingArgNames;
710
711 public:
712 Impl(FuncDefOp func, unsigned paramIdx, const SrcStructMemberToCloneInDest &nameToNewMember)
713 : inputIdx(paramIdx), newMembers(nameToNewMember) {
714 for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
715 if (std::optional<StringAttr> argName = func.getArgNameAttr(i)) {
716 existingArgNames.push_back(argName->getValue().str());
717 if (i == inputIdx) {
718 originalArgName = argName->getValue().str();
719 }
720 }
721 }
722 }
723
724 protected:
725 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
726 SmallVector<Type> newTypes(origTypes);
727 auto *it = newTypes.erase(newTypes.begin() + inputIdx);
728 for (auto [_, newMember] : newMembers) {
729 newTypes.insert(it, newMember.getType());
730 ++it;
731 }
732 return newTypes;
733 }
734 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
735 return SmallVector<Type>(origTypes);
736 }
737 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
738 if (origAttrs) {
739 // Replicate the value at `origAttrs[inputIdx]` to have `newMembers.size()`
740 SmallVector<Attribute> newAttrs(origAttrs.getValue());
741 auto splitAttr = llvm::cast<DictionaryAttr>(origAttrs[inputIdx]);
742 SmallVector<Attribute> splitAttrs;
743 if (originalArgName) {
744 llvm::StringSet<> usedArgNames;
745 for (StringRef argName : existingArgNames) {
746 usedArgNames.insert(argName);
747 }
748 for (auto [memberName, _] : newMembers) {
749 std::string desiredName = (*originalArgName + '.' + memberName).str();
750 splitAttrs.push_back(withFunctionArgNameAttr(
751 splitAttr, reserveUniqueAttrName(usedArgNames, desiredName)
752 ));
753 }
754 } else {
755 splitAttrs.append(newMembers.size(), splitAttr);
756 }
757 newAttrs[inputIdx] = splitAttrs.front();
758 newAttrs.insert(
759 newAttrs.begin() + inputIdx + 1, splitAttrs.begin() + 1, splitAttrs.end()
760 );
761 return ArrayAttr::get(origAttrs.getContext(), newAttrs);
762 }
763 return nullptr;
764 }
765 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type>) override {
766 return origAttrs;
767 }
768
769 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
770 Value oldStructRef = entryBlock.getArgument(inputIdx);
771
772 // Insert new Block arguments, one per member, following the original one. Keep a map
773 // of member name to the associated block argument for replacing MemberReadOp.
774 llvm::StringMap<BlockArgument> memberNameToNewArg;
775 Location loc = oldStructRef.getLoc();
776 unsigned idx = inputIdx;
777 for (auto [memberName, newMember] : newMembers) {
778 // note: pre-increment so the original to be erased is still at `inputIdx`
779 BlockArgument newArg = entryBlock.insertArgument(++idx, newMember.getType(), loc);
780 memberNameToNewArg[memberName] = newArg;
781 }
782
783 // Find all member reads from the original Block argument and replace uses of those
784 // reads with the appropriate new Block argument.
785 for (OpOperand &oldBlockArgUse : llvm::make_early_inc_range(oldStructRef.getUses())) {
786 if (MemberReadOp readOp = llvm::dyn_cast<MemberReadOp>(oldBlockArgUse.getOwner())) {
787 if (readOp.getComponent() == oldStructRef) {
788 BlockArgument newArg = memberNameToNewArg.at(readOp.getMemberName());
789 rewriter.replaceAllUsesWith(readOp, newArg);
790 rewriter.eraseOp(readOp);
791 continue;
792 }
793 }
794 // Currently, there's no other way in which a StructType parameter can be used.
795 llvm::errs() << "Unexpected use of " << oldBlockArgUse.get() << " in "
796 << *oldBlockArgUse.getOwner() << '\n';
797 llvm_unreachable("Not yet implemented");
798 }
799
800 // Delete the original Block argument
801 entryBlock.eraseArgument(inputIdx);
802 }
803 };
804 IRRewriter rewriter(func.getContext());
805 Impl(func, paramIdx, nameToNewMember).convert(func, rewriter);
806 }
807};
808
809static LogicalResult finalizeStruct(
810 SymbolTableCollection &tables, StructDefOp caller, PendingErasure &&toDelete,
811 DestToSrcToClonedSrcInDest &&destToSrcToClone
812) {
813 LLVM_DEBUG({
814 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before compressing chains:\n";
815 caller.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
816 llvm::dbgs() << '\n';
817 });
818
819 // Compress chains of reads that result after inlining multiple callees.
820 caller.getConstrainFuncOp().walk([&tables, &destToSrcToClone](MemberReadOp readOp) {
821 combineReadChain(readOp, tables, destToSrcToClone);
822 });
823 FuncDefOp computeFn = caller.getComputeFuncOp();
824 Value computeSelfVal = computeFn.getSelfValueFromCompute();
825 auto res = computeFn.walk([&tables, &destToSrcToClone, &computeSelfVal](MemberReadOp readOp) {
826 combineReadChain(readOp, tables, destToSrcToClone);
827 // Reads targeting the "self" value from "compute()" are not eligible for the compression
828 // provided in `combineNewThenReadChain()` and will actually cause an error within.
829 if (readOp.getComponent() == computeSelfVal) {
830 return WalkResult::advance();
831 }
832 LogicalResult innerRes = combineNewThenReadChain(readOp, tables, destToSrcToClone);
833 return failed(innerRes) ? WalkResult::interrupt() : WalkResult::advance();
834 });
835 if (res.wasInterrupted()) {
836 return failure(); // error already printed within combineNewThenReadChain()
837 }
838
839 LLVM_DEBUG({
840 llvm::dbgs() << "[finalizeStruct] dumping 'caller' struct before deleting ops:\n";
841 caller.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
842 llvm::dbgs() << '\n';
843 llvm::dbgs() << "[finalizeStruct] ops marked for deletion:\n";
844 for (Operation *op : toDelete.memberReadOps) {
845 llvm::dbgs().indent(2) << *op << '\n';
846 }
847 for (Operation *op : toDelete.memberWriteOps) {
848 llvm::dbgs().indent(2) << *op << '\n';
849 }
850 for (CreateStructOp op : toDelete.newStructOps) {
851 llvm::dbgs().indent(2) << op << '\n';
852 }
853 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
854 llvm::dbgs().indent(2) << op << '\n';
855 }
856 });
857
858 // Handle remaining uses of CreateStructOp before deleting anything because this process
859 // needs to be able to find the MemberWriteOp instances that store the result of these ops.
860 DanglingUseHandler<SmallPtrSet<Operation *, 8>, SmallPtrSet<Operation *, 8>> useHandler(
861 tables, destToSrcToClone, toDelete.memberWriteOps, toDelete.memberReadOps
862 );
863 for (CreateStructOp op : toDelete.newStructOps) {
864 if (failed(useHandler.handle(op))) {
865 return failure(); // error already printed within handle()
866 }
867 }
868 // Next, to avoid "still has uses" errors, must erase MemberWriteOp first, then MemberReadOp,
869 // before erasing the CreateStructOp or MemberDefOp.
870 for (Operation *op : toDelete.memberWriteOps) {
871 if (failed(useHandler.handle(op))) {
872 return failure(); // error already printed within handle()
873 }
874 op->erase();
875 }
876 for (Operation *op : toDelete.memberReadOps) {
877 if (failed(useHandler.handle(op))) {
878 return failure(); // error already printed within handle()
879 }
880 op->erase();
881 }
882 for (CreateStructOp op : toDelete.newStructOps) {
883 op.erase();
884 }
885 // Finally, erase MemberDefOp via SymbolTable so table itself is updated too.
886 SymbolTable &callerSymTab = tables.getSymbolTable(caller);
887 for (DestMemberWithSrcStructType op : toDelete.memberDefs) {
888 assert(op.getParentOp() == caller); // using correct SymbolTable
889 callerSymTab.erase(op);
890 }
891
892 return success();
893}
894
895} // namespace
896
897LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan) {
898 for (auto &[caller, callees] : plan) {
899 // Cache operations that should be deleted but must wait until all callees are processed
900 // to ensure that all uses of the values defined by these operations are replaced.
901 PendingErasure toDelete;
902 // Cache old-to-new member mappings across all callees inlined for the current struct.
903 DestToSrcToClonedSrcInDest aggregateReplacements;
904 // Inline callees/subcomponents of the current struct
905 for (StructDefOp toInline : callees) {
906 FailureOr<DestToSrcToClonedSrcInDest> res =
907 StructInliner(tables, toDelete, toInline, caller).doInline();
908 if (failed(res)) {
909 return failure();
910 }
911 // Add current member replacements to the aggregate
912 for (auto &[k, v] : res.value()) {
913 assert(!aggregateReplacements.contains(k) && "duplicate not possible");
914 aggregateReplacements[k] = std::move(v);
915 }
916 }
917 // Complete steps to finalize/cleanup the caller
918 LogicalResult finalizeResult =
919 finalizeStruct(tables, caller, std::move(toDelete), std::move(aggregateReplacements));
920 if (failed(finalizeResult)) {
921 return failure();
922 }
923 }
924 return success();
925}
926
927namespace {
928
929class PassImpl : public llzk::impl::InlineStructsPassBase<PassImpl> {
930 using Base = InlineStructsPassBase<PassImpl>;
931 using Base::Base;
932
933 static uint64_t complexity(FuncDefOp f) {
934 uint64_t complexity = 0;
935 f.getBody().walk([&complexity](Operation *op) {
936 if (llvm::isa<felt::MulFeltOp>(op)) {
937 ++complexity;
938 } else if (auto ee = llvm::dyn_cast<constrain::EmitEqualityOp>(op)) {
939 complexity += computeEmitEqCardinality(ee.getLhs().getType());
940 } else if (auto ec = llvm::dyn_cast<constrain::EmitContainmentOp>(op)) {
941 // TODO: increment based on dimension sizes in the operands
942 // Pending update to implementation/semantics of EmitContainmentOp.
943 ++complexity;
944 }
945 });
946 return complexity;
947 }
948
954 static FuncDefOp
955 getIfResolvableStructConstrain(const SymbolUseGraphNode *node, SymbolTableCollection &tables) {
956 if (!node || !node->isRealNode() || node->isTemplateSymbolBinding()) {
957 return nullptr;
958 }
959 auto lookupRes = node->lookupSymbol(tables, /*reportMissing=*/false);
960 if (failed(lookupRes)) {
961 return nullptr;
962 }
963 FuncDefOp func = llvm::dyn_cast<FuncDefOp>(lookupRes->get());
964 if (!func || !func.isStructConstrain()) {
965 return nullptr;
966 }
967 return func;
968 }
969
972 static inline StructDefOp getParentStruct(FuncDefOp func) {
973 assert(func.isStructConstrain()); // pre-condition
974 StructDefOp currentNodeParentStruct = getParentOfType<StructDefOp>(func);
975 assert(currentNodeParentStruct); // follows from ODS definition
976 return currentNodeParentStruct;
977 }
978
980 inline bool exceedsMaxComplexity(uint64_t check) {
981 return maxComplexity > 0 && check > maxComplexity;
982 }
983
986 static inline bool canInline(FuncDefOp currentFunc, FuncDefOp successorFunc) {
987 // Find CallOp for `successorFunc` within `currentFunc` and check the condition used by
988 // `ConstrainImpl::getSelfRefMember()`.
989 //
990 // Implementation Note: There is a possibility that the "self" value is not from a member read.
991 // It could be a parameter to the current/destination function or a global read. Inlining a
992 // struct stored to a global would probably require splitting up the global into multiple, one
993 // for each member in the successor/source struct. That may not be a good idea. The parameter
994 // case could be handled but it will not have a mapping in `destToSrcToClone` in
995 // `getSelfRefMember()` and new members will still need to be added. They can be prefixed with
996 // parameter index since there is no current member name to use as the unique prefix. Handling
997 // that would require refactoring the inlining process a bit.
998 WalkResult res = currentFunc.walk([](CallOp c) {
999 return getMemberReadThatDefinesSelfValuePassedToConstrain(c)
1000 ? WalkResult::interrupt() // use interrupt to indicate success
1001 : WalkResult::advance();
1002 });
1003 LLVM_DEBUG({
1004 llvm::dbgs() << "[canInline] " << successorFunc.getFullyQualifiedName() << " into "
1005 << currentFunc.getFullyQualifiedName() << "? " << res.wasInterrupted() << '\n';
1006 });
1007 return res.wasInterrupted();
1008 }
1009
1010 static LogicalResult
1011 verifyNoTemplateSymbolBindings(const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
1012 for (const SymbolUseGraphNode *node : useGraph.nodesIter()) {
1013 if (!node->isTemplateSymbolBinding()) {
1014 continue;
1015 }
1016
1017 // Try to get the location of the TemplateOp to report an error.
1018 Operation *lookupFrom = node->getSymbolPathRoot().getOperation();
1019 SymbolRefAttr prefix = getPrefixAsSymbolRefAttr(node->getSymbolPath());
1020 auto res = lookupSymbolIn<TemplateOp>(tables, prefix, lookupFrom, lookupFrom, false);
1021 // If that lookup did not work for some reason, report at the path root location.
1022 Operation *reportLoc = succeeded(res) ? res->get() : lookupFrom;
1023 return reportLoc->emitError() << "Cannot inline struct within a template. Run "
1024 "`llzk-flatten` to instantiate templated structs.";
1025 }
1026 return success();
1027 }
1028
1029 static LogicalResult emitConstrainReachableCycleError(
1030 ArrayRef<const SymbolUseGraphNode *> dfsStack, const SymbolUseGraphNode *cycleHead,
1031 SymbolTableCollection &tables
1032 ) {
1033 SmallVector<const SymbolUseGraphNode *, 8> cycle;
1034 bool inCycle = false;
1035 for (const SymbolUseGraphNode *node : dfsStack) {
1036 if (node == cycleHead) {
1037 inCycle = true;
1038 }
1039 if (inCycle) {
1040 cycle.push_back(node);
1041 }
1042 }
1043 if (cycle.empty()) {
1044 cycle.push_back(cycleHead);
1045 }
1046
1047 Operation *reportOp = cycleHead->getSymbolPathRoot().getOperation();
1048 for (const SymbolUseGraphNode *node : cycle) {
1049 if (!node->isRealNode()) {
1050 continue;
1051 }
1052 auto lookupRes = node->lookupSymbol(tables, /*reportMissing=*/false);
1053 if (failed(lookupRes)) {
1054 continue;
1055 }
1056 Operation *op = lookupRes->get();
1057 reportOp = op;
1058 if (llvm::isa<FuncDefOp>(op)) {
1059 break;
1060 }
1061 }
1062
1063 InFlightDiagnostic diag = reportOp->emitError();
1064 diag << "Cannot inline structs when a symbol-use cycle is reachable from a struct "
1065 "\"@constrain\" function. Prover-side recursion is allowed only when "
1066 "\"@constrain\" cannot reach it.";
1067
1068 for (const SymbolUseGraphNode *node : cycle) {
1069 if (!node->isRealNode()) {
1070 continue;
1071 }
1072 if (auto lookupRes = node->lookupSymbol(tables, /*reportMissing=*/false);
1073 succeeded(lookupRes)) {
1074 diag.attachNote(lookupRes->get()->getLoc()) << "cycle contains " << node->getSymbolPath();
1075 } else {
1076 diag.attachNote(node->getSymbolPathRoot().getLoc())
1077 << "cycle contains " << node->getSymbolPath();
1078 }
1079 }
1080
1081 return failure();
1082 }
1083
1089 static LogicalResult computeConstrainReachablePostOrder(
1090 const SymbolUseGraph &useGraph, SymbolTableCollection &tables,
1091 SmallVectorImpl<const SymbolUseGraphNode *> &postOrder
1092 ) {
1093 enum class VisitState : std::uint8_t { Active, Done };
1094
1095 DenseMap<const SymbolUseGraphNode *, VisitState> state;
1096 SmallVector<const SymbolUseGraphNode *, 32> dfsStack;
1097
1098 auto dfs = [&](auto &&self, const SymbolUseGraphNode *node) -> LogicalResult {
1099 auto seen = state.find(node);
1100 if (seen != state.end()) {
1101 if (seen->second == VisitState::Active) {
1102 return emitConstrainReachableCycleError(dfsStack, node, tables);
1103 }
1104 return success();
1105 }
1106
1107 state[node] = VisitState::Active;
1108 dfsStack.push_back(node);
1109 for (const SymbolUseGraphNode *successor : node->successorIter()) {
1110 if (failed(self(self, successor))) {
1111 return failure();
1112 }
1113 }
1114 dfsStack.pop_back();
1115
1116 state[node] = VisitState::Done;
1117 postOrder.push_back(node);
1118 return success();
1119 };
1120
1121 for (const SymbolUseGraphNode *node : useGraph.nodesIter()) {
1122 if (!getIfResolvableStructConstrain(node, tables)) {
1123 continue;
1124 }
1125 if (failed(dfs(dfs, node))) {
1126 return failure();
1127 }
1128 }
1129
1130 return success();
1131 }
1132
1137 inline FailureOr<InliningPlan>
1138 makePlan(const SymbolUseGraph &useGraph, SymbolTableCollection &tables) {
1139 LLVM_DEBUG({
1140 llvm::dbgs() << "Running InlineStructsPass with max complexity ";
1141 if (maxComplexity == 0) {
1142 llvm::dbgs() << "unlimited";
1143 } else {
1144 llvm::dbgs() << maxComplexity;
1145 }
1146 llvm::dbgs() << '\n';
1147 });
1148 InliningPlan retVal;
1149 DenseMap<const SymbolUseGraphNode *, uint64_t> complexityMemo;
1150
1151 if (failed(verifyNoTemplateSymbolBindings(useGraph, tables))) {
1152 return failure();
1153 }
1154
1155 SmallVector<const SymbolUseGraphNode *, 32> constrainPostOrder;
1156 if (failed(computeConstrainReachablePostOrder(useGraph, tables, constrainPostOrder))) {
1157 return failure();
1158 }
1159
1160 // Traverse "constrain" function nodes to compute their complexity and an inlining plan. Use
1161 // post-order traversal so the complexity of all successor nodes is computed before computing
1162 // the current node's complexity.
1163 for (const SymbolUseGraphNode *currentNode : constrainPostOrder) {
1164 LLVM_DEBUG(llvm::dbgs() << "\ncurrentNode = " << currentNode->toString());
1165 FuncDefOp currentFunc = getIfResolvableStructConstrain(currentNode, tables);
1166 if (!currentFunc) {
1167 continue;
1168 }
1169 uint64_t currentComplexity = complexity(currentFunc);
1170 // If the current complexity is already too high, store it and continue.
1171 if (exceedsMaxComplexity(currentComplexity)) {
1172 complexityMemo[currentNode] = currentComplexity;
1173 continue;
1174 }
1175 // Otherwise, make a plan that adds successor "constrain" functions unless the
1176 // complexity becomes too high by adding that successor.
1177 SmallVector<StructDefOp> successorsToMerge;
1178 for (const SymbolUseGraphNode *successor : currentNode->successorIter()) {
1179 LLVM_DEBUG(llvm::dbgs().indent(2) << "successor: " << successor->toString() << '\n');
1180 // Note: all "constrain" function nodes will have a value, and all other nodes will not.
1181 auto memoResult = complexityMemo.find(successor);
1182 if (memoResult == complexityMemo.end()) {
1183 continue; // inner loop
1184 }
1185 uint64_t sComplexity = memoResult->second;
1186 assert(
1187 sComplexity <= (std::numeric_limits<uint64_t>::max() - currentComplexity) &&
1188 "addition will overflow"
1189 );
1190 uint64_t potentialComplexity = currentComplexity + sComplexity;
1191 if (!exceedsMaxComplexity(potentialComplexity)) {
1192 currentComplexity = potentialComplexity;
1193 FuncDefOp successorFunc = getIfResolvableStructConstrain(successor, tables);
1194 if (!successorFunc) {
1195 continue;
1196 }
1197 if (canInline(currentFunc, successorFunc)) {
1198 successorsToMerge.push_back(getParentStruct(successorFunc));
1199 }
1200 }
1201 }
1202 complexityMemo[currentNode] = currentComplexity;
1203 if (!successorsToMerge.empty()) {
1204 retVal.emplace_back(getParentStruct(currentFunc), std::move(successorsToMerge));
1205 }
1206 }
1207 LLVM_DEBUG({
1208 llvm::dbgs() << "-----------------------------------------------------------------\n";
1209 llvm::dbgs() << "InlineStructsPass plan:\n";
1210 for (auto &[caller, callees] : retVal) {
1211 llvm::dbgs().indent(2) << "inlining the following into \"" << caller.getSymName() << "\"\n";
1212 for (StructDefOp c : callees) {
1213 llvm::dbgs().indent(4) << "\"" << c.getSymName() << "\"\n";
1214 }
1215 }
1216 llvm::dbgs() << "-----------------------------------------------------------------\n";
1217 });
1218 return retVal;
1219 }
1220
1221public:
1222 void runOnOperation() override {
1223 const SymbolUseGraph &useGraph = getAnalysis<SymbolUseGraph>();
1224 LLVM_DEBUG(useGraph.dumpToDotFile());
1225
1226 SymbolTableCollection tables;
1227 FailureOr<InliningPlan> plan = makePlan(useGraph, tables);
1228 if (failed(plan)) {
1229 signalPassFailure(); // error already printed w/in makePlan()
1230 return;
1231 }
1232
1233 if (failed(performInlining(tables, plan.value()))) {
1234 signalPassFailure();
1235 return;
1236 };
1237 }
1238};
1239
1240} // namespace
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for and distribution as defined by Sections through of this document Licensor shall mean the copyright owner or entity authorized by the copyright owner that is granting the License Legal Entity shall mean the union of the acting entity and all other entities that control are controlled by or are under common control with that entity For the purposes of this definition control direct or to cause the direction or management of such whether by contract or including but not limited to software source documentation and configuration files Object form shall mean any form resulting from mechanical transformation or translation of a Source including but not limited to compiled object generated and conversions to other media types Work shall mean the work of whether in Source or Object made available under the as indicated by a copyright notice that is included in or attached to the whether in Source or Object that is based or other modifications as a an original work of authorship For the purposes of this Derivative Works shall not include works that remain separable from
Definition LICENSE.txt:45
LogicalResult performInlining(SymbolTableCollection &tables, InliningPlan &plan)
mlir::SmallVector< std::pair< llzk::component::StructDefOp, mlir::SmallVector< llzk::component::StructDefOp > > > InliningPlan
Maps caller struct to callees that should be inlined.
#define check(x)
Definition Ops.cpp:285
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:55
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
bool isRealNode() const
Return 'false' iff this node is an artificial node created for the graph head/tail.
bool isTemplateSymbolBinding() const
Return true iff the symbol is defined by a TemplateSymbolBindingOpInterface.
mlir::SymbolRefAttr getSymbolPath() const
The symbol path+name relative to the closest root ModuleOp.
mlir::ModuleOp getSymbolPathRoot() const
Return the root ModuleOp for the path.
llvm::iterator_range< iterator > successorIter() const
Range over successor nodes.
void dumpToDotFile(std::string filename="") const
Dump the graph to file in dot graph format.
llvm::iterator_range< iterator > nodesIter() const
Range over all nodes in the graph.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:688
::llvm::StringRef getMemberName()
Definition Ops.cpp.inc:974
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:691
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Gets the SSA value with the target component from the MemberRefOp.
void setMemberName(::llvm::StringRef attrValue)
Sets the member name attribute value in the MemberRefOp.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::mlir::OpOperand & getComponentMutable()
Gets the SSA value with the target component from the MemberRefOp.
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:472
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:468
void print(::mlir::OpAsmPrinter &_odsPrinter)
Definition Ops.cpp.inc:1698
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:1156
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
static ::llvm::SmallVector<::mlir::ValueRange > toVectorOfValueRange(::mlir::OperandRangeRange)
Allocate consecutive storage of the ValueRange instances in the parameter so it can be passed to the ...
Definition Ops.cpp:1186
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:1151
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1161
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:302
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::Value getSelfValueFromCompute()
Return the "self" value (i.e.
Definition Ops.cpp:457
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
::mlir::Value getSelfValueFromConstrain()
Return the "self" value (i.e.
Definition Ops.cpp:476
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:865
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:869
bool isStructConstrain()
Return true iff the function is within a StructDefOp and named FUNC_NAME_CONSTRAIN.
Definition Ops.h.inc:882
::mlir::SymbolRefAttr getFullyQualifiedName(bool requireParent=true)
Return the full name for this function from the root module, including all surrounding symbol table n...
Definition Ops.cpp:447
::mlir::Region & getBody()
Definition Ops.h.inc:690
std::string toStringOne(const T &value)
Definition Debug.h:182
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:156
mlir::SymbolRefAttr getPrefixAsSymbolRefAttr(mlir::SymbolRefAttr symbol)
Return SymbolRefAttr like the one given but with the leaf/final element removed.
uint64_t computeEmitEqCardinality(Type type)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
mlir::DictionaryAttr withFunctionArgNameAttr(mlir::DictionaryAttr attrs, llvm::StringRef name)
Return a copy of the given argument attribute dictionary with function.arg_name set to name.
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
std::string reserveUniqueAttrName(llvm::StringSet<> &usedNames, llvm::StringRef desiredName)
Reserve and return a unique function argument/result name based on desiredName.