LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
PodToScalarPass.cpp
Go to the documentation of this file.
1//===-- PodToScalarPass.cpp -------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2026 Project LLZK
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
19/// record and remembers how each original member was split.
20///
21/// 2. Run a dialect conversion that does the following:
22///
23/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1
24/// so they instead perform scalar reads and writes from the new members. The transformation is
25/// local to the current op. Therefore, when replacing the `MemberReadOp` a new pod is
26/// created locally and all uses of the `MemberReadOp` are replaced with the new pod Value,
27/// then each scalar member read is followed by scalar write into the new pod. Similarly,
28/// when replacing a `MemberWriteOp`, each element in the pod operand needs a scalar read
29/// from the pod followed by a scalar write to the new member. Making only local changes
30/// keeps this step simple and later steps will optimize.
31///
32/// - Remove optional initialization from `NewPodOp` and instead insert a list of `WritePodOp`
33/// immediately following.
34///
35/// - Split pods to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
36/// create/read/write ops so the changes are as local as possible (just as described for
37/// `MemberReadOp` and `MemberWriteOp`)
38///
39/// 3. Promote pod reads and writes out of `scf.if`, `scf.for`, and `scf.while` regions when the
40/// access can be modeled as an SSA value flowing through the region boundary. This puts the
41/// pod accesses that mem2reg must eliminate into a parent block or loop-carried value.
42///
43/// 4. Run MLIR "sroa" pass to split each pod with `N` records into `N` pods with 1 record each
44/// (to prepare for the "mem2reg" pass because its API cannot split memory by itself).
45///
46/// 5. Run MLIR "mem2reg" pass to convert all single-record pod allocations and accesses into SSA
47/// values.
48///
49/// ** Steps 4 and 5 are rerun until a fixpoint is reached, because each round can expose nested
49/// POD types that still need to be split.
50///
51/// Note: This transformation imposes a "last write wins" semantics on pod records. If
52/// different/configurable semantics are added in the future, some additional transformation would
53/// be necessary before/during this pass so that multiple writes to the same record can be handled
54/// properly while they still exist.
55///
56/// Note: This transformation will introduce a `nondet` op when there exists a read from a pod
57/// record that was not earlier written to.
58///
59//===----------------------------------------------------------------------===//
60
81#include "llzk/Util/Concepts.h"
82#include "llzk/Util/Walk.h"
83
84#include <mlir/Dialect/SCF/IR/SCF.h>
85#include <mlir/Pass/PassManager.h>
86#include <mlir/Transforms/DialectConversion.h>
87#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
88#include <mlir/Transforms/Passes.h>
89
90#include <llvm/ADT/DenseMapInfo.h>
91#include <llvm/ADT/STLExtras.h>
92#include <llvm/Support/Debug.h>
93
94// Include the generated base pass class definitions.
95namespace llzk::pod {
96#define GEN_PASS_DEF_PODTOSCALARPASS
98} // namespace llzk::pod
99
100using namespace mlir;
101using namespace llzk;
102using namespace llzk::pod;
103using namespace llzk::function;
104using namespace llzk::component;
105
106#define DEBUG_TYPE "llzk-pod-to-scalar"
107
108namespace {
109
111struct RecordChain {
112 SmallVector<StringAttr> nameList;
113
114 RecordChain() = default;
115
116 explicit RecordChain(ArrayRef<StringAttr> names) : nameList(names.begin(), names.end()) {}
117
118 bool operator==(const RecordChain &other) const { return nameList == other.nameList; }
119};
120
121} // namespace
122
123namespace llvm {
124template <> struct DenseMapInfo<RecordChain> {
125 static RecordChain getEmptyKey() {
126 return RecordChain {{DenseMapInfo<StringAttr>::getEmptyKey()}};
127 }
128
129 static RecordChain getTombstoneKey() {
130 return RecordChain {{DenseMapInfo<StringAttr>::getTombstoneKey()}};
131 }
132
133 static unsigned getHashValue(const RecordChain &chain) {
134 return llvm::hash_combine_range(chain.nameList.begin(), chain.nameList.end());
135 }
136
137 static bool isEqual(const RecordChain &lhs, const RecordChain &rhs) { return lhs == rhs; }
138};
139} // namespace llvm
140
141namespace {
142
144inline static PodType splittablePod(PodType pt) { return pt; }
145
147inline static PodType splittablePod(Type t) {
148 if (PodType pt = dyn_cast<PodType>(t)) {
149 return splittablePod(pt);
150 } else {
151 return nullptr;
152 }
153}
154
157inline static bool containsSplittablePodType(ArrayRef<Type> types) {
158 for (Type t : types) {
159 if (splittablePod(t)) {
160 return true;
161 }
162 }
163 return false;
164}
165
168template <typename T> static bool containsSplittablePodType(ValueTypeRange<T> types) {
169 for (Type t : types) {
170 if (splittablePod(t)) {
171 return true;
172 }
173 }
174 return false;
175}
176
179size_t splitPodTypeTo(Type t, SmallVector<Type> &collect) {
180 if (PodType pt = splittablePod(t)) {
181 auto records = pt.getRecords();
182 for (RecordAttr record : records) {
183 collect.push_back(record.getType());
184 }
185 return records.size();
186 } else {
187 collect.push_back(t);
188 return 1;
189 }
190}
191
193template <typename TypeCollection>
194inline void splitPodTypeTo(
195 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
196) {
197 for (Type t : types) {
198 size_t count = splitPodTypeTo(t, collect);
199 if (originalIdxToSize) {
200 originalIdxToSize->push_back(count);
201 }
202 }
203}
204
207template <typename TypeCollection>
208inline SmallVector<Type>
209splitPodType(TypeCollection types, SmallVector<size_t> *originalIdxToSize = nullptr) {
210 SmallVector<Type> collect;
211 splitPodTypeTo(types, collect, originalIdxToSize);
212 return collect;
213}
214
216inline static ReadPodOp
217genRead(Location loc, Value podRef, StringAttr recordName, OpBuilder &rewriter) {
218 Type resultType =
219 llvm::cast<PodType>(podRef.getType()).getRecordMap().lookup(recordName.getValue());
220 return rewriter.create<ReadPodOp>(loc, resultType, podRef, recordName);
221}
222
224inline static WritePodOp
225genWrite(Location loc, Value podRef, StringAttr recordName, Value value, OpBuilder &rewriter) {
226 return rewriter.create<WritePodOp>(loc, podRef, recordName, value);
227}
228
230static SmallVector<std::string> getSplitRecordNameSuffixes(Type type) {
231 SmallVector<std::string> suffixes;
232 if (PodType pt = splittablePod(type)) {
233 suffixes.reserve(pt.getRecords().size());
234 for (RecordAttr record : pt.getRecords()) {
235 StringRef name = record.getName().getValue();
236 std::string result;
237 result.reserve(name.size() + 1);
238 result.push_back('.');
239 result.append(name.data(), name.size());
240 suffixes.push_back(result);
241 }
242 }
243 return suffixes;
244}
245
246// If the operand has PodType, add reads from all pod records to the `newOperands` list otherwise
247// add the original operand to the list.
248static void processInputOperand(
249 Location loc, Value operand, SmallVector<Value> &newOperands,
250 ConversionPatternRewriter &rewriter
251) {
252 if (PodType pt = splittablePod(operand.getType())) {
253 for (RecordAttr record : pt.getRecords()) {
254 newOperands.push_back(genRead(loc, operand, record.getName(), rewriter));
255 }
256 } else {
257 newOperands.push_back(operand);
258 }
259}
260
263static void processInputOperands(
264 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
265 ConversionPatternRewriter &rewriter
266) {
267 SmallVector<Value> newOperands;
268 for (Value v : operands) {
269 processInputOperand(op->getLoc(), v, newOperands, rewriter);
270 }
271 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
272 outputOpRef.assign(ValueRange(newOperands));
273 });
274}
275
/// Shared `ConversionTarget` configuration used by the conversion steps in this file: marks a
/// list of dialects and `ModuleOp` as legal.
// NOTE(review): only the last entry of the dialect list (`scf::SCFDialect`) is visible in this
// listing; the earlier entries are not shown here.
277inline static void baseTargetSetup(ConversionTarget &target) {
278 target.addLegalDialect<
283 scf::SCFDialect>();
284 target.addLegalOp<ModuleOp>();
285}
286
289class NondetToNewPod : public OpConversionPattern<NonDetOp> {
290 using OpConversionPattern<NonDetOp>::OpConversionPattern;
291 LogicalResult matchAndRewrite(
292 NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter
293 ) const override {
294 if (auto pt = dyn_cast<PodType>(nondetOp.getType())) {
295 rewriter.replaceOpWithNewOp<NewPodOp>(nondetOp, pt);
296 return success();
297 }
298 return failure();
299 }
300};
301
303static LogicalResult step0(ModuleOp modOp) {
304 MLIRContext *ctx = modOp.getContext();
305 RewritePatternSet patterns {ctx};
306 patterns.add<NondetToNewPod>(ctx);
307 ConversionTarget target {*ctx};
308
309 baseTargetSetup(target);
310 target.addDynamicallyLegalOp<NonDetOp>([](NonDetOp op) { return !isa<PodType>(op.getType()); });
311
312 return applyFullConversion(modOp, target, std::move(patterns));
313}
314
/// Name and type of a scalar member that replaced one leaf record of a pod-typed member.
316using MemberInfo = std::pair<StringAttr, Type>;
/// For one original pod-typed member: maps the record path of each leaf to its replacement member.
318using LocalMemberReplacementMap = DenseMap<RecordChain, MemberInfo>;
/// For each struct: maps the name of each split member to its `LocalMemberReplacementMap`.
320using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
321
323static StringAttr
324getFlattenedMemberName(MLIRContext *ctx, StringAttr memberName, ArrayRef<StringAttr> recordChain) {
325 std::string flatName;
326 llvm::raw_string_ostream os(flatName);
327 os << memberName.getValue();
328 for (StringAttr recordName : recordChain) {
329 os << '_' << recordName.getValue();
330 }
331 return StringAttr::get(ctx, flatName);
332}
333
335static void flattenPodMemberIntoLeaves(
336 MemberDefOp originalMember, PodType podTy, SmallVectorImpl<StringAttr> &recordChain,
337 LocalMemberReplacementMap &localRepMapRef, SymbolTable &structSymbolTable,
338 ConversionPatternRewriter &rewriter
339) {
340 for (RecordAttr record : podTy.getRecords()) {
341 recordChain.push_back(record.getName());
342 if (PodType nestedPodTy = dyn_cast<PodType>(record.getType())) {
343 flattenPodMemberIntoLeaves(
344 originalMember, nestedPodTy, recordChain, localRepMapRef, structSymbolTable, rewriter
345 );
346 recordChain.pop_back();
347 continue;
348 }
349
350 StringAttr name = getFlattenedMemberName(
351 originalMember.getContext(), originalMember.getSymNameAttr(), recordChain
352 );
353 Type ty = record.getType();
354 MemberDefOp newMember = rewriter.create<MemberDefOp>(
355 originalMember.getLoc(), name, ty, originalMember.getSignal(), originalMember.getColumn()
356 );
357 newMember.setPublicAttr(originalMember.hasPublicAttr());
358 localRepMapRef[RecordChain(recordChain)] =
359 std::make_pair(structSymbolTable.insert(newMember), ty);
360 recordChain.pop_back();
361 }
362}
363
368class SplitPodInMemberDefOp : public OpConversionPattern<MemberDefOp> {
369 SymbolTableCollection &tables;
370 MemberReplacementMap &repMapRef;
371
372public:
373 SplitPodInMemberDefOp(
374 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
375 )
376 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
377
378 inline static bool legal(MemberDefOp op) { return !splittablePod(op.getType()); }
379
380 LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); }
381
382 void
383 rewrite(MemberDefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
384 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
385 assert(inStruct);
386 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
387
388 PodType podTy = llvm::cast<PodType>(adaptor.getType()); // safe per legal() check
389
390 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
391 SmallVector<StringAttr> recordChain;
392 flattenPodMemberIntoLeaves(op, podTy, recordChain, localRepMapRef, structSymbolTable, rewriter);
393 rewriter.eraseOp(op);
394 }
395};
396
398static LogicalResult
399step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
400 MLIRContext *ctx = modOp.getContext();
401
402 RewritePatternSet patterns(ctx);
403
404 patterns.add<SplitPodInMemberDefOp>(ctx, symTables, memberRepMap);
405
406 ConversionTarget target(*ctx);
407 baseTargetSetup(target);
408 target.addDynamicallyLegalOp<MemberDefOp>(SplitPodInMemberDefOp::legal);
409
410 LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split pod-type members\n";);
411 return applyFullConversion(modOp, target, std::move(patterns));
412}
413
415class SplitInitFromNewPodOp : public OpConversionPattern<NewPodOp> {
416public:
417 using OpConversionPattern<NewPodOp>::OpConversionPattern;
418
419 static bool legal(NewPodOp op) { return op.getInitialValues().empty(); }
420
421 LogicalResult match(NewPodOp op) const override { return failure(legal(op)); }
422
423 void rewrite(NewPodOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
424 // Generate an individual write for each initialization
425 rewriter.setInsertionPointAfter(op);
426 Location loc = op.getLoc();
427 for (auto [name, init] :
428 llvm::zip_equal(adaptor.getInitializedRecords(), adaptor.getInitialValues())) {
429 // Create the write
430 rewriter.create<WritePodOp>(loc, op.getResult(), llvm::cast<StringAttr>(name), init);
431 }
432 // Remove initializations from `op`
433 rewriter.modifyOpInPlace(op, [&op]() {
434 op.getInitialValuesMutable().clear();
435 op.setInitializedRecordsAttr(ArrayAttr::get(op.getContext(), {})); // DefaultValuedAttr:{}
436 });
437 }
438};
439
/// Conversion pattern that rewrites a `FuncDefOp` whose argument or result types contain a pod so
/// that each pod is replaced by one scalar per record.
// NOTE(review): this listing omits some lines of the class (for example the call that wraps the
// argument lists in `convertInputAttrs` / `convertResultAttrs`); comments below cover only what
// is visible.
447class SplitPodInFuncDefOp : public OpConversionPattern<FuncDefOp> {
448public:
449 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
450
/// A function is legal when neither its argument types nor its result types contain a pod.
451 inline static bool legal(FuncDefOp op) {
452 return !containsSplittablePodType(op.getArgumentTypes()) &&
453 !containsSplittablePodType(op.getResultTypes());
454 }
455
/// Matches exactly the ops that are not `legal`.
456 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
457
458 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
459 // Update in/out types of the function to replace pods with scalars
460 class Impl : public FunctionTypeConverter {
// Number of split types produced for each original input/result index (filled by
// `convertInputs` / `convertResults`).
461 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
// Name information collected from the original function in the constructor.
462 SplitFunctionNameInfo inputNameInfo;
463 SplitFunctionNameInfo resultNameInfo;
464
465 protected:
// Split each input type and remember how many types each original input produced.
466 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
467 return splitPodType(origTypes, &originalInputIdxToSize);
468 }
// Split each result type and remember how many types each original result produced.
469 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
470 return splitPodType(origTypes, &originalResultIdxToSize);
471 }
472 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
474 origAttrs, originalInputIdxToSize, newTypes, ARG_NAME_ATTR_NAME,
475 inputNameInfo.originalNames, inputNameInfo.existingNames,
476 inputNameInfo.splitNameSuffixes
477 );
478 }
479 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
481 origAttrs, originalResultIdxToSize, newTypes, RES_NAME_ATTR_NAME,
482 resultNameInfo.originalNames, resultNameInfo.existingNames,
483 resultNameInfo.splitNameSuffixes
484 );
485 }
486
// Replaces each pod-typed entry block argument with one argument per record. A new pod is
// created at the start of the block, takes over all uses of the old argument, and is filled
// by writing each new scalar argument into its record.
491 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
492 OpBuilder::InsertionGuard guard(rewriter);
493 rewriter.setInsertionPointToStart(&entryBlock);
494
// `i` is advanced manually because erasing and inserting arguments changes the argument count.
495 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
496 Value oldV = entryBlock.getArgument(i);
497 if (PodType pt = splittablePod(oldV.getType())) {
498 Location loc = oldV.getLoc();
499 // Generate `NewPodOp` and replace uses of the argument with it.
500 auto newPod = rewriter.create<NewPodOp>(loc, pt);
501 rewriter.replaceAllUsesWith(oldV, newPod);
502 // Remove the argument from the block
503 entryBlock.eraseArgument(i);
504 // For all indices in the PodType (i.e., the element count), generate a new
505 // block argument and a write of that argument to the new pod.
506 for (RecordAttr record : pt.getRecords()) {
507 BlockArgument newArg = entryBlock.insertArgument(i, record.getType(), loc);
508 genWrite(loc, newPod, record.getName(), newArg, rewriter);
509 ++i;
510 }
511 } else {
512 ++i;
513 }
514 }
515 }
516
517 public:
// Collects argument and result name information from `op`, using
// `getSplitRecordNameSuffixes` to supply the per-record suffixes.
518 Impl(FuncDefOp op) {
519 inputNameInfo = collectSplitFunctionNameInfo(op.getArgumentTypes(), [&op](unsigned i) {
520 return op.getArgNameAttr(i);
521 }, getSplitRecordNameSuffixes);
522 resultNameInfo = collectSplitFunctionNameInfo(
523 op.getResultTypes(), [resultAttrs = op.getAllResultAttrs()](unsigned i) {
524 return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME);
525 }, getSplitRecordNameSuffixes
526 );
527 }
528 };
529 Impl(op).convert(op, rewriter);
530 }
531};
532
538class SplitPodInReturnOp : public OpConversionPattern<ReturnOp> {
539public:
540 using OpConversionPattern<ReturnOp>::OpConversionPattern;
541
542 inline static bool legal(ReturnOp op) {
543 return !containsSplittablePodType(op.getOperands().getTypes());
544 }
545
546 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
547
548 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
549 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
550 }
551};
552
/// Replaces `oldCall` with a call whose result types have every pod split into one result per
/// record, and returns the new call.
///
/// For each pod-typed result of the old call a `NewPodOp` is created after the old call, takes
/// over all uses of that result, and is filled by writing the corresponding new call results
/// into its records. Non-pod results are forwarded one-to-one. The old call is erased.
// NOTE(review): the line that declares `newCall` is not visible in this listing; only its
// argument list is shown.
554static CallOp newCallOpWithSplitResults(
555 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
556) {
557 OpBuilder::InsertionGuard guard(rewriter);
558 rewriter.setInsertionPointAfter(oldCall);
559
560 Operation::result_range oldResults = oldCall.getResults();
562 oldCall.getLoc(), splitPodType(oldResults.getTypes()), oldCall, adaptor.getMapOperands(),
563 adaptor.getArgOperands(), rewriter
564 );
565
// Cursor over the new call results; advanced once per scalar consumed below.
566 auto newResults = newCall.getResults().begin();
567 for (Value oldVal : oldResults) {
568 if (PodType pt = splittablePod(oldVal.getType())) {
569 Location loc = oldVal.getLoc();
570 // Generate `NewPodOp` and replace uses of the result with it.
571 auto newPod = rewriter.create<NewPodOp>(loc, pt);
572 rewriter.replaceAllUsesWith(oldVal, newPod);
573
574 // For each record in the PodType, write the next result from the new CallOp to the new pod.
575 for (RecordAttr record : pt.getRecords()) {
576 genWrite(loc, newPod, record.getName(), *newResults, rewriter);
577 newResults++;
578 }
579 } else {
580 rewriter.replaceAllUsesWith(oldVal, *newResults);
581 newResults++;
582 }
583 }
584 // erase the original CallOp
585 rewriter.eraseOp(oldCall);
586
587 return newCall;
588}
589
596class SplitPodInCallOp : public OpConversionPattern<CallOp> {
597public:
598 using OpConversionPattern<CallOp>::OpConversionPattern;
599
600 inline static bool legal(CallOp op) {
601 return !containsSplittablePodType(op.getArgOperands().getTypes()) &&
602 !containsSplittablePodType(op.getResultTypes());
603 }
604
605 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
606
607 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
608 // Create new CallOp with split results first so, then process its inputs to split types
609 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
610 processInputOperands(
611 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
612 );
613 }
614};
615
617static Value
618genReadAlongPath(Location loc, Value podRef, RecordChain recordChain, OpBuilder &rewriter) {
619 Value value = podRef;
620 for (StringAttr attr : recordChain.nameList) {
621 value = genRead(loc, value, attr, rewriter);
622 }
623 return value;
624}
625
/// State carried while a pod-typed `MemberReadOp` is replaced by scalar reads.
627struct RebuildPodReadState {
// The new pod that replaces the result of the original member read.
628 NewPodOp pod;
// Scalar value read for each leaf record, keyed by the record path to that leaf.
629 DenseMap<RecordChain, Value> leafValues;
630};
631
633static Value rebuildFlattenedPodRecord(
634 Location loc, Type recordType, SmallVectorImpl<StringAttr> &recordChain,
635 const DenseMap<RecordChain, Value> &leafValues, ConversionPatternRewriter &rewriter
636) {
637 if (PodType nestedPodTy = dyn_cast<PodType>(recordType)) {
638 NewPodOp nestedPod = rewriter.create<NewPodOp>(loc, nestedPodTy);
639 for (RecordAttr record : nestedPodTy.getRecords()) {
640 recordChain.push_back(record.getName());
641 Value recordValue =
642 rebuildFlattenedPodRecord(loc, record.getType(), recordChain, leafValues, rewriter);
643 genWrite(loc, nestedPod, record.getName(), recordValue, rewriter);
644 recordChain.pop_back();
645 }
646 return nestedPod;
647 }
648
649 auto it = leafValues.find(RecordChain(recordChain));
650 assert(it != leafValues.end() && "missing flattened POD leaf value");
651 return it->second;
652}
653
655class SplitPodInMemberWriteOp : public SplitAggregateInMemberRefOp<
656 SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain> {
657public:
658 using SplitAggregateInMemberRefOp<
659 SplitPodInMemberWriteOp, MemberWriteOp, void *, RecordChain>::SplitAggregateInMemberRefOp;
660
661 static bool legal(MemberWriteOp op) { return !containsSplittablePodType(op.getVal().getType()); }
662
663 static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; }
664
665 static void forId(
666 Location loc, void *&, RecordChain id, MemberInfo newMember, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter
668 ) {
669 Value scalarRead = genReadAlongPath(loc, adaptor.getVal(), id, rewriter);
670 rewriter.create<MemberWriteOp>(
671 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
672 );
673 }
674};
675
/// Splits a `MemberReadOp` of a pod-typed member into scalar member reads and rebuilds the pod
/// value locally from them.
// NOTE(review): the line naming the base class template is not visible in this listing; the
// `using` declaration below shows it is `SplitAggregateInMemberRefOp`.
677class SplitPodInMemberReadOp
679 SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState, RecordChain> {
680public:
681 using SplitAggregateInMemberRefOp<
682 SplitPodInMemberReadOp, MemberReadOp, RebuildPodReadState,
683 RecordChain>::SplitAggregateInMemberRefOp;
684
/// A member read is legal when its result does not have a splittable pod type.
685 static bool legal(MemberReadOp op) {
686 return !containsSplittablePodType(op.getResult().getType());
687 }
688
/// Creates the replacement pod and redirects all uses of the original read to it.
689 static RebuildPodReadState genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) {
690 RebuildPodReadState state;
691 state.pod = rewriter.create<NewPodOp>(op.getLoc(), llvm::cast<PodType>(op.getType()));
692 rewriter.replaceAllUsesWith(op, state.pod);
693 return state;
694 }
695
/// Reads the scalar member `newMember` and stores the value under record path `id`.
696 static void forId(
697 Location loc, RebuildPodReadState &state, RecordChain id, MemberInfo newMember,
698 OpAdaptor adaptor, ConversionPatternRewriter &rewriter
699 ) {
700 Value scalarRead = rewriter.create<MemberReadOp>(
701 loc, newMember.second, adaptor.getComponent(), newMember.first
702 );
703 state.leafValues[id] = scalarRead;
704 }
705
/// Writes every top-level record of the replacement pod, rebuilding nested pods from the
/// collected leaf values.
706 static void finalize(
707 MemberReadOp op, RebuildPodReadState &state, OpAdaptor, ConversionPatternRewriter &rewriter
708 ) {
709 auto podTy = llvm::cast<PodType>(op.getType());
710 SmallVector<StringAttr> recordChain;
711 for (RecordAttr record : podTy.getRecords()) {
712 recordChain.push_back(record.getName());
713 Value recordValue = rebuildFlattenedPodRecord(
714 op.getLoc(), record.getType(), recordChain, state.leafValues, rewriter
715 );
716 genWrite(op.getLoc(), state.pod, record.getName(), recordValue, rewriter);
717 recordChain.pop_back();
718 }
719 }
720};
721
724static LogicalResult
725step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) {
726 MLIRContext *ctx = modOp.getContext();
727
728 RewritePatternSet patterns(ctx);
729 patterns.add<
730 // clang-format off
731 SplitInitFromNewPodOp,
732 SplitPodInFuncDefOp,
733 SplitPodInReturnOp,
734 SplitPodInCallOp
735 // clang-format on
736 >(ctx);
737
738 patterns.add<
739 // clang-format off
740 SplitPodInMemberWriteOp,
741 SplitPodInMemberReadOp
742 // clang-format on
743 >(ctx, symTables, memberRepMap);
744
745 ConversionTarget target(*ctx);
746 baseTargetSetup(target);
747 target.addDynamicallyLegalOp<NewPodOp>(SplitInitFromNewPodOp::legal);
748 target.addDynamicallyLegalOp<FuncDefOp>(SplitPodInFuncDefOp::legal);
749 target.addDynamicallyLegalOp<ReturnOp>(SplitPodInReturnOp::legal);
750 target.addDynamicallyLegalOp<CallOp>(SplitPodInCallOp::legal);
751 target.addDynamicallyLegalOp<MemberWriteOp>(SplitPodInMemberWriteOp::legal);
752 target.addDynamicallyLegalOp<MemberReadOp>(SplitPodInMemberReadOp::legal);
753
754 LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other pod ops\n";);
755 return applyFullConversion(modOp, target, std::move(patterns));
756}
757
759inline static StringAttr getRecordNameAsStringAttr(ReadPodOp readOp) {
760 return readOp.getRecordNameAttr().getLeafReference();
761}
762
764inline static StringAttr getRecordNameAsStringAttr(WritePodOp writeOp) {
765 return writeOp.getRecordNameAttr().getLeafReference();
766}
767
769inline static bool isSamePodRecord(ReadPodOp readOp, Value podRef, StringAttr recordName) {
770 return readOp.getPodRef() == podRef && getRecordNameAsStringAttr(readOp) == recordName;
771}
772
774inline static bool isSamePodRecord(WritePodOp writeOp, Value podRef, StringAttr recordName) {
775 return writeOp.getPodRef() == podRef && getRecordNameAsStringAttr(writeOp) == recordName;
776}
777
779static bool hasNestedWriteToRecord(Operation &op, Value podRef, StringAttr recordName) {
780 return walkContainsMatch<WritePodOp>(op, [&](WritePodOp writeOp) {
781 return writeOp.getOperation() != &op && isSamePodRecord(writeOp, podRef, recordName);
782 });
783}
784
786static bool hasReadFromRecord(Operation &op, Value podRef, StringAttr recordName) {
787 return walkContainsMatch<ReadPodOp>(op, [&](ReadPodOp readOp) {
788 return isSamePodRecord(readOp, podRef, recordName);
789 });
790}
791
793static bool hasValueUse(Operation &op, Value value) {
794 return walkContainsMatch<Operation *>(op, [&value](Operation *nestedOp) {
795 return llvm::is_contained(nestedOp->getOperands(), value);
796 });
797}
798
800static bool hasEarlierWriteInBlock(ReadPodOp readOp) {
801 Value podRef = readOp.getPodRef();
802 StringAttr recordName = getRecordNameAsStringAttr(readOp);
803
804 for (Operation &op : *readOp->getBlock()) {
805 if (&op == readOp.getOperation()) {
806 return false;
807 }
808
809 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
810 if (isSamePodRecord(writeOp, podRef, recordName)) {
811 return true;
812 }
813 continue;
814 }
815
816 if (hasNestedWriteToRecord(op, podRef, recordName)) {
817 return true;
818 }
819 }
820 return false;
821}
822
827static bool isValueDefinedInside(Operation *ancestor, Value value) {
828 if (Operation *defOp = value.getDefiningOp()) {
829 return ancestor->isAncestor(defOp);
830 }
831
832 auto blockArg = llvm::dyn_cast<BlockArgument>(value);
833 Operation *parentOp = blockArg.getOwner()->getParentOp();
834 return parentOp && ancestor->isAncestor(parentOp);
835}
836
838static WritePodOp findPrecedingWriteForIfRead(ReadPodOp readOp) {
839 auto ifOp = readOp->getParentOfType<scf::IfOp>();
840 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
841 return nullptr;
842 }
843 if (hasEarlierWriteInBlock(readOp)) {
844 return nullptr;
845 }
846
847 Block *ifBlock = ifOp->getBlock();
848 if (!ifBlock) {
849 return nullptr;
850 }
851
852 Value podRef = readOp.getPodRef();
853 StringAttr recordName = getRecordNameAsStringAttr(readOp);
854 WritePodOp replacement = nullptr;
855 for (Operation &op : *ifBlock) {
856 if (&op == ifOp.getOperation()) {
857 break;
858 }
859
860 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
861 if (isSamePodRecord(writeOp, podRef, recordName)) {
862 replacement = writeOp;
863 }
864 continue;
865 }
866
867 if (hasNestedWriteToRecord(op, podRef, recordName)) {
868 replacement = nullptr;
869 }
870 }
871
872 return replacement;
873}
874
876class ReplaceIfReadPattern final : public OpRewritePattern<ReadPodOp> {
877public:
878 using OpRewritePattern<ReadPodOp>::OpRewritePattern;
879
880 LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter) const override {
881 auto ifOp = readOp->getParentOfType<scf::IfOp>();
882 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
883 return failure();
884 }
885 if (isValueDefinedInside(ifOp, readOp.getPodRef()) || hasEarlierWriteInBlock(readOp)) {
886 return failure();
887 }
888
889 if (WritePodOp writeOp = findPrecedingWriteForIfRead(readOp)) {
890 rewriter.replaceOp(readOp, writeOp.getValue());
891 return success();
892 }
893
894 rewriter.setInsertionPoint(ifOp);
895 rewriter.replaceOp(
896 readOp,
897 genRead(readOp.getLoc(), readOp.getPodRef(), getRecordNameAsStringAttr(readOp), rewriter)
898 .getResult()
899 );
900 return success();
901 }
902};
903
917class FoldIfCarriedPodReadAfterWritePattern final : public OpRewritePattern<ReadPodOp> {
918public:
919 using OpRewritePattern<ReadPodOp>::OpRewritePattern;
920
921 LogicalResult matchAndRewrite(ReadPodOp readOp, PatternRewriter &rewriter) const override {
922 auto podRes = dyn_cast<OpResult>(readOp.getPodRef());
923 if (!podRes) {
924 return failure();
925 }
926
927 auto ifOp = dyn_cast<scf::IfOp>(podRes.getOwner());
928 if (!ifOp) {
929 return failure();
930 }
931
932 auto writeOp = dyn_cast_or_null<WritePodOp>(readOp->getPrevNode());
933 if (!writeOp || getRecordNameAsStringAttr(writeOp) != getRecordNameAsStringAttr(readOp)) {
934 return failure();
935 }
936
937 auto valueRes = dyn_cast<OpResult>(writeOp.getValue());
938 if (!valueRes || valueRes.getOwner() != ifOp.getOperation()) {
939 return failure();
940 }
941
942 Value carriedPod = writeOp.getPodRef();
943 unsigned podResultIndex = podRes.getResultNumber();
944
945 auto thenYield = dyn_cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
946 if (!thenYield || thenYield.getOperand(podResultIndex) != carriedPod) {
947 return failure();
948 }
949
950 Region &elseRegion = ifOp.getElseRegion();
951 if (Block *elseBlock = elseRegion.empty() ? nullptr : &elseRegion.front()) {
952 auto elseYield = dyn_cast<scf::YieldOp>(elseBlock->getTerminator());
953 if (!elseYield || elseYield.getOperand(podResultIndex) != carriedPod) {
954 return failure();
955 }
956 }
957
958 rewriter.replaceOp(readOp, valueRes);
959 return success();
960 }
961};
962
/// Bookkeeping for one pod record that is written directly inside a branch of an `scf.if`.
967struct IfWriteSlot {
// Pod value and record name that identify the slot.
968 Value podRef;
969 StringAttr recordName;
// Type of the written value, taken from the first write that created the slot.
970 Type type;
// Last direct write to the slot seen in the then / else block (null when there is none).
971 WritePodOp thenWrite;
972 WritePodOp elseWrite;
// Value used for a branch that has no write to the slot.
973 Value incomingValue;
974};
975
977static IfWriteSlot *
978lookupSlot(SmallVectorImpl<IfWriteSlot> &slots, Value podRef, StringAttr recordName) {
979 for (IfWriteSlot &slot : slots) {
980 if (slot.podRef == podRef && slot.recordName == recordName) {
981 return &slot;
982 }
983 }
984 return nullptr;
985}
986
988static IfWriteSlot &getOrCreateSlot(
989 SmallVectorImpl<IfWriteSlot> &slots, Value podRef, StringAttr recordName, Type type
990) {
991 if (IfWriteSlot *slot = lookupSlot(slots, podRef, recordName)) {
992 return *slot;
993 }
994 slots.push_back(IfWriteSlot {podRef, recordName, type, nullptr, nullptr, Value()});
995 return slots.back();
996}
997
999static Block *getElseBlockOrNull(scf::IfOp ifOp) {
1000 return ifOp.getElseRegion().empty() ? nullptr : &ifOp.getElseRegion().front();
1001}
1002
1004static void
1005collectDirectWrites(Block *block, bool isThenBlock, SmallVectorImpl<IfWriteSlot> &slots) {
1006 if (!block) {
1007 return;
1008 }
1009
1010 for (Operation &op : *block) {
1011 if (op.hasTrait<OpTrait::IsTerminator>()) {
1012 break;
1013 }
1014
1015 auto writeOp = dyn_cast<WritePodOp>(&op);
1016 if (!writeOp) {
1017 continue;
1018 }
1019
1020 IfWriteSlot &slot = getOrCreateSlot(
1021 slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp), writeOp.getValue().getType()
1022 );
1023 if (isThenBlock) {
1024 slot.thenWrite = writeOp;
1025 } else {
1026 slot.elseWrite = writeOp;
1027 }
1028 }
1029}
1030
1035static bool branchSlotCanBeLifted(Block *block, Value podRef, StringAttr recordName) {
1036 if (!block) {
1037 return true;
1038 }
1039
1040 bool seenDirectWrite = false;
1041 for (Operation &op : *block) {
1042 if (op.hasTrait<OpTrait::IsTerminator>()) {
1043 return true;
1044 }
1045
1046 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
1047 if (isSamePodRecord(writeOp, podRef, recordName)) {
1048 seenDirectWrite = true;
1049 continue;
1050 }
1051 }
1052
1053 if (hasNestedWriteToRecord(op, podRef, recordName)) {
1054 return false;
1055 }
1056 if (seenDirectWrite && (hasReadFromRecord(op, podRef, recordName) || hasValueUse(op, podRef))) {
1057 return false;
1058 }
1059 }
1060 return true;
1061}
1062
1064static bool isLiftedWrite(Operation &op, ArrayRef<IfWriteSlot> slots) {
1065 auto writeOp = dyn_cast<WritePodOp>(&op);
1066 return writeOp && llvm::any_of(slots, [&writeOp](const IfWriteSlot &slot) {
1067 return isSamePodRecord(writeOp, slot.podRef, slot.recordName);
1068 });
1069}
1070
1072static scf::YieldOp getYieldOp(Block &block) {
1073 auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator());
1074 assert(yieldOp && "expected scf.if branch to terminate with scf.yield");
1075 return yieldOp;
1076}
1077
1079static void dropTerminatorIfPresent(Block &block) {
1080 if (!block.empty() && block.back().hasTrait<OpTrait::IsTerminator>()) {
1081 block.back().erase();
1082 }
1083}
1084
1086static void
1087moveBranchWithoutLiftedWrites(Block *srcBlock, Block &destBlock, ArrayRef<IfWriteSlot> slots) {
1088 if (srcBlock) {
1089 for (auto it = srcBlock->begin(), end = srcBlock->end(); it != end;) {
1090 Operation &op = *it++;
1091 if (op.hasTrait<OpTrait::IsTerminator>() || isLiftedWrite(op, slots)) {
1092 continue;
1093 }
1094 op.moveBefore(&destBlock, destBlock.end());
1095 }
1096 }
1097}
1098
1101static void appendYield(
1102 Location loc, Block &block, ValueRange priorYieldValues, ArrayRef<IfWriteSlot> slots,
1103 bool isThenBlock, OpBuilder &builder
1104) {
1105 SmallVector<Value> yieldValues = llvm::to_vector(priorYieldValues);
1106 llvm::append_range(yieldValues, llvm::map_range(slots, [isThenBlock](const IfWriteSlot &slot) {
1107 WritePodOp writeOp = isThenBlock ? slot.thenWrite : slot.elseWrite;
1108 return writeOp ? writeOp.getValue() : slot.incomingValue;
1109 }));
1110
1111 builder.setInsertionPointToEnd(&block);
1112 builder.create<scf::YieldOp>(loc, yieldValues);
1113}
1114
1120struct LoopPodSlot {
1121 Value podRef;
1122 StringAttr recordName;
1123 Type type;
1124
1126 bool matches(Value findPodRef, StringAttr findRecordName) const {
1127 return this->podRef == findPodRef && this->recordName == findRecordName;
1128 }
1129};
1130
1132static LoopPodSlot *
1133lookupLoopSlot(SmallVectorImpl<LoopPodSlot> &slots, Value podRef, StringAttr recordName) {
1134 auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) {
1135 return slot.matches(podRef, recordName);
1136 });
1137 return it == slots.end() ? nullptr : &*it;
1138}
1139
1141static bool hasLoopSlot(ArrayRef<LoopPodSlot> slots, Value podRef, StringAttr recordName) {
1142 auto it = llvm::find_if(slots, [&podRef, &recordName](const LoopPodSlot &slot) {
1143 return slot.matches(podRef, recordName);
1144 });
1145 return it != slots.end();
1146}
1147
1149static LoopPodSlot &getOrCreateLoopSlot(
1150 SmallVectorImpl<LoopPodSlot> &slots, Value podRef, StringAttr recordName, Type type
1151) {
1152 if (LoopPodSlot *slot = lookupLoopSlot(slots, podRef, recordName)) {
1153 return *slot;
1154 }
1155 slots.push_back(LoopPodSlot {podRef, recordName, type});
1156 return slots.back();
1157}
1158
1160static std::optional<size_t>
1161findLoopSlotIndex(ArrayRef<LoopPodSlot> slots, Value podRef, StringAttr recordName) {
1162 for (auto [idx, slot] : llvm::enumerate(slots)) {
1163 if (slot.podRef == podRef && slot.recordName == recordName) {
1164 return idx;
1165 }
1166 }
1167 return std::nullopt;
1168}
1169
1172static void
1173collectDirectLoopPodSlots(Block &block, Operation *ancestor, SmallVectorImpl<LoopPodSlot> &slots) {
1174 for (Operation &op : block) {
1175 if (auto readOp = dyn_cast<ReadPodOp>(&op)) {
1176 if (!isValueDefinedInside(ancestor, readOp.getPodRef())) {
1177 getOrCreateLoopSlot(
1178 slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp), readOp.getType()
1179 );
1180 }
1181 continue;
1182 }
1183
1184 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
1185 if (!isValueDefinedInside(ancestor, writeOp.getPodRef())) {
1186 getOrCreateLoopSlot(
1187 slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp),
1188 writeOp.getValue().getType()
1189 );
1190 }
1191 }
1192 }
1193}
1194
1196static bool opUsesTrackedPodRefDirectly(Operation &op, ArrayRef<LoopPodSlot> slots) {
1197 return llvm::any_of(op.getOperands(), [&slots](Value operand) {
1198 return llvm::any_of(slots, [&operand](const LoopPodSlot &slot) {
1199 return slot.podRef == operand;
1200 });
1201 });
1202}
1203
1205static bool hasNestedTrackedPodAccess(Operation &op, ArrayRef<LoopPodSlot> slots) {
1206 return op
1207 .walk([&op, &slots](Operation *nestedOp) {
1208 if (nestedOp == &op) {
1209 return WalkResult::advance();
1210 }
1211
1212 if (auto readOp = dyn_cast<ReadPodOp>(nestedOp)) {
1213 if (hasLoopSlot(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) {
1214 return WalkResult::interrupt();
1215 }
1216 return WalkResult::advance();
1217 }
1218
1219 if (auto writeOp = dyn_cast<WritePodOp>(nestedOp)) {
1220 if (hasLoopSlot(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1221 return WalkResult::interrupt();
1222 }
1223 }
1224 return WalkResult::advance();
1225 }).wasInterrupted();
1226}
1227
1230static bool hasUnliftableLoopPodUses(Block &block, ArrayRef<LoopPodSlot> slots) {
1231 for (Operation &op : block) {
1232 if (isa<ReadPodOp, WritePodOp>(op)) {
1233 continue;
1234 }
1235 if (opUsesTrackedPodRefDirectly(op, slots) || hasNestedTrackedPodAccess(op, slots)) {
1236 return true;
1237 }
1238 }
1239 return false;
1240}
1241
1245class LiftPodWritesFromIfBlocksPattern final : public OpRewritePattern<scf::IfOp> {
1246public:
1247 using OpRewritePattern<scf::IfOp>::OpRewritePattern;
1248
1249 LogicalResult matchAndRewrite(scf::IfOp ifOp, PatternRewriter &rewriter) const override {
1250 SmallVector<IfWriteSlot> slots;
1251 Block &thenBlock = *ifOp.thenBlock();
1252 Block *elseBlock = getElseBlockOrNull(ifOp);
1253 collectDirectWrites(&thenBlock, true, slots);
1254 collectDirectWrites(elseBlock, false, slots);
1255 if (slots.empty()) {
1256 return failure();
1257 }
1258
1259 llvm::erase_if(slots, [&](const IfWriteSlot &slot) {
1260 return isValueDefinedInside(ifOp, slot.podRef) ||
1261 !branchSlotCanBeLifted(&thenBlock, slot.podRef, slot.recordName) ||
1262 !branchSlotCanBeLifted(elseBlock, slot.podRef, slot.recordName);
1263 });
1264 if (slots.empty()) {
1265 return failure();
1266 }
1267
1268 for (IfWriteSlot &slot : slots) {
1269 if (slot.thenWrite && slot.elseWrite) {
1270 continue;
1271 }
1272 rewriter.setInsertionPoint(ifOp);
1273 slot.incomingValue =
1274 genRead(ifOp.getLoc(), slot.podRef, slot.recordName, rewriter).getResult();
1275 }
1276
1277 SmallVector<Type> resultTypes = llvm::to_vector(ifOp.getResultTypes());
1278 llvm::append_range(resultTypes, llvm::map_range(slots, [](auto slot) { return slot.type; }));
1279
1280 SmallVector<Value> originalThenYields;
1281 if (!ifOp.getResults().empty()) {
1282 scf::YieldOp thenYieldOp = getYieldOp(thenBlock);
1283 originalThenYields.append(thenYieldOp.getOperands().begin(), thenYieldOp.getOperands().end());
1284 }
1285
1286 SmallVector<Value> originalElseYields;
1287 if (elseBlock && !ifOp.getResults().empty()) {
1288 scf::YieldOp elseYieldOp = getYieldOp(*elseBlock);
1289 originalElseYields.append(elseYieldOp.getOperands().begin(), elseYieldOp.getOperands().end());
1290 }
1291
1292 rewriter.setInsertionPoint(ifOp);
1293 auto newIf = rewriter.create<scf::IfOp>(ifOp.getLoc(), resultTypes, ifOp.getCondition(), true);
1294 Block &newThenBlock = *newIf.thenBlock();
1295 Block &newElseBlock = *newIf.elseBlock();
1296 dropTerminatorIfPresent(newThenBlock);
1297 dropTerminatorIfPresent(newElseBlock);
1298
1299 moveBranchWithoutLiftedWrites(&thenBlock, newThenBlock, slots);
1300 moveBranchWithoutLiftedWrites(elseBlock, newElseBlock, slots);
1301 appendYield(ifOp.getLoc(), newThenBlock, originalThenYields, slots, true, rewriter);
1302 appendYield(ifOp.getLoc(), newElseBlock, originalElseYields, slots, false, rewriter);
1303
1304 rewriter.setInsertionPointAfter(newIf);
1305 unsigned originalResultCount = ifOp.getNumResults();
1306 for (auto [idx, slot] : llvm::enumerate(slots)) {
1307 genWrite(
1308 ifOp.getLoc(), slot.podRef, slot.recordName, newIf.getResult(originalResultCount + idx),
1309 rewriter
1310 );
1311 }
1312
1313 rewriter.replaceOp(ifOp, newIf.getResults().take_front(originalResultCount));
1314 return success();
1315 }
1316};
1317
1320class LiftPodAccessesFromForLoopPattern final : public OpRewritePattern<scf::ForOp> {
1321public:
1322 using OpRewritePattern<scf::ForOp>::OpRewritePattern;
1323
1324 LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override {
1325 Block &body = *forOp.getBody();
1326 SmallVector<LoopPodSlot> slots;
1327 collectDirectLoopPodSlots(body, forOp.getOperation(), slots);
1328 if (slots.empty() || hasUnliftableLoopPodUses(body, slots)) {
1329 return failure();
1330 }
1331
1332 Location loc = forOp.getLoc();
1333
1334 SmallVector<Value> newInitArgs = llvm::to_vector(forOp.getInitArgs());
1335 rewriter.setInsertionPoint(forOp);
1336 for (const LoopPodSlot &slot : slots) {
1337 newInitArgs.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult());
1338 }
1339
1340 auto newFor = rewriter.create<scf::ForOp>(
1341 loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newInitArgs
1342 );
1343 newFor->setAttrs(forOp->getAttrs());
1344
1345 Block &newBody = *newFor.getBody();
1346 dropTerminatorIfPresent(newBody);
1347
1348 IRMapping mapping;
1349 mapping.map(forOp.getInductionVar(), newFor.getInductionVar());
1350 for (auto [idx, oldArg] : llvm::enumerate(forOp.getRegionIterArgs())) {
1351 mapping.map(oldArg, newFor.getRegionIterArg(idx));
1352 }
1353
1354 SmallVector<Value> slotValues = llvm::map_to_vector(
1355 llvm::seq<size_t>(0, slots.size()),
1356 [base = static_cast<size_t>(forOp.getNumRegionIterArgs()), &newFor](size_t idx) -> Value {
1357 return newFor.getRegionIterArg(llzk::checkedCast<unsigned>(base + idx));
1358 }
1359 );
1360
1361 rewriter.setInsertionPointToEnd(&newBody);
1362 for (Operation &op : body) {
1363 if (auto yieldOp = dyn_cast<scf::YieldOp>(&op)) {
1364 auto yieldValues = llvm::map_to_vector(yieldOp.getOperands(), [&mapping](Value operand) {
1365 return mapping.lookupOrDefault(operand);
1366 });
1367 llvm::append_range(yieldValues, slotValues);
1368 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
1369 continue;
1370 }
1371
1372 if (auto readOp = dyn_cast<ReadPodOp>(&op)) {
1373 if (std::optional<size_t> slotIdx =
1374 findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) {
1375 mapping.map(readOp.getResult(), slotValues[*slotIdx]);
1376 continue;
1377 }
1378 }
1379
1380 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
1381 if (std::optional<size_t> slotIdx =
1382 findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1383 slotValues[*slotIdx] = mapping.lookupOrDefault(writeOp.getValue());
1384 continue;
1385 }
1386 }
1387
1388 rewriter.clone(op, mapping);
1389 }
1390
1391 rewriter.setInsertionPointAfter(newFor);
1392 for (auto [idx, slot] : llvm::enumerate(slots)) {
1393 genWrite(
1394 loc, slot.podRef, slot.recordName, newFor.getResult(forOp.getNumResults() + idx), rewriter
1395 );
1396 }
1397
1398 rewriter.replaceOp(forOp, newFor.getResults().take_front(forOp.getNumResults()));
1399 return success();
1400 }
1401};
1402
1405class LiftPodAccessesFromWhileLoopPattern final : public OpRewritePattern<scf::WhileOp> {
1406public:
1407 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
1408
1409 LogicalResult matchAndRewrite(scf::WhileOp whileOp, PatternRewriter &rewriter) const override {
1410 Block &beforeBody = *whileOp.getBeforeBody();
1411 Block &afterBody = *whileOp.getAfterBody();
1412
1413 SmallVector<LoopPodSlot> slots;
1414 collectDirectLoopPodSlots(beforeBody, whileOp.getOperation(), slots);
1415 collectDirectLoopPodSlots(afterBody, whileOp.getOperation(), slots);
1416 if (slots.empty() || hasUnliftableLoopPodUses(beforeBody, slots) ||
1417 hasUnliftableLoopPodUses(afterBody, slots)) {
1418 return failure();
1419 }
1420
1421 Location loc = whileOp.getLoc();
1422
1423 SmallVector<Value> newInits = llvm::to_vector(whileOp.getInits());
1424 SmallVector<Type> newResultTypes = llvm::to_vector(whileOp.getResultTypes());
1425 rewriter.setInsertionPoint(whileOp);
1426 for (const LoopPodSlot &slot : slots) {
1427 newInits.push_back(genRead(loc, slot.podRef, slot.recordName, rewriter).getResult());
1428 newResultTypes.push_back(slot.type);
1429 }
1430
1431 auto newWhile = rewriter.create<scf::WhileOp>(loc, newResultTypes, newInits, nullptr, nullptr);
1432 newWhile->setAttrs(whileOp->getAttrs());
1433
1434 Block &newBeforeBody = *newWhile.getBeforeBody();
1435 Block &newAfterBody = *newWhile.getAfterBody();
1436 dropTerminatorIfPresent(newBeforeBody);
1437 dropTerminatorIfPresent(newAfterBody);
1438
1439 IRMapping beforeMapping;
1440 for (auto [oldArg, newArg] : llvm::zip_equal(
1441 whileOp.getBeforeArguments(),
1442 newWhile.getBeforeArguments().take_front(whileOp.getBeforeArguments().size())
1443 )) {
1444 beforeMapping.map(oldArg, newArg);
1445 }
1446
1447 SmallVector<Value> beforeSlotValues = llvm::map_to_vector(
1448 llvm::seq<size_t>(0, slots.size()),
1449 [base = whileOp.getBeforeArguments().size(), &newWhile](size_t idx) -> Value {
1450 return newWhile.getBeforeArguments()[llzk::checkedCast<unsigned>(base + idx)];
1451 }
1452 );
1453
1454 rewriter.setInsertionPointToEnd(&newBeforeBody);
1455 for (Operation &op : beforeBody) {
1456 if (auto conditionOp = dyn_cast<scf::ConditionOp>(&op)) {
1457 SmallVector<Value> conditionArgs =
1458 llvm::map_to_vector(conditionOp.getArgs(), [&beforeMapping](Value a) {
1459 return beforeMapping.lookupOrDefault(a);
1460 });
1461 llvm::append_range(conditionArgs, beforeSlotValues);
1462 rewriter.create<scf::ConditionOp>(
1463 conditionOp.getLoc(), beforeMapping.lookupOrDefault(conditionOp.getCondition()),
1464 conditionArgs
1465 );
1466 continue;
1467 }
1468
1469 if (auto readOp = dyn_cast<ReadPodOp>(&op)) {
1470 if (std::optional<size_t> slotIdx =
1471 findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) {
1472 beforeMapping.map(readOp.getResult(), beforeSlotValues[*slotIdx]);
1473 continue;
1474 }
1475 }
1476
1477 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
1478 if (std::optional<size_t> slotIdx =
1479 findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1480 beforeSlotValues[*slotIdx] = beforeMapping.lookupOrDefault(writeOp.getValue());
1481 continue;
1482 }
1483 }
1484
1485 rewriter.clone(op, beforeMapping);
1486 }
1487
1488 IRMapping afterMapping;
1489 for (auto [oldArg, newArg] : llvm::zip_equal(
1490 whileOp.getAfterArguments(),
1491 newWhile.getAfterArguments().take_front(whileOp.getAfterArguments().size())
1492 )) {
1493 afterMapping.map(oldArg, newArg);
1494 }
1495
1496 SmallVector<Value> afterSlotValues = llvm::map_to_vector(
1497 llvm::seq<size_t>(0, slots.size()),
1498 [base = whileOp.getAfterArguments().size(), &newWhile](size_t idx) -> Value {
1499 return newWhile.getAfterArguments()[llzk::checkedCast<unsigned>(base + idx)];
1500 }
1501 );
1502
1503 rewriter.setInsertionPointToEnd(&newAfterBody);
1504 for (Operation &op : afterBody) {
1505 if (auto yieldOp = dyn_cast<scf::YieldOp>(&op)) {
1506 SmallVector<Value> yieldValues =
1507 llvm::map_to_vector(yieldOp.getOperands(), [&afterMapping](Value v) {
1508 return afterMapping.lookupOrDefault(v);
1509 });
1510 llvm::append_range(yieldValues, afterSlotValues);
1511 rewriter.create<scf::YieldOp>(yieldOp.getLoc(), yieldValues);
1512 continue;
1513 }
1514
1515 if (auto readOp = dyn_cast<ReadPodOp>(&op)) {
1516 if (std::optional<size_t> slotIdx =
1517 findLoopSlotIndex(slots, readOp.getPodRef(), getRecordNameAsStringAttr(readOp))) {
1518 afterMapping.map(readOp.getResult(), afterSlotValues[*slotIdx]);
1519 continue;
1520 }
1521 }
1522
1523 if (auto writeOp = dyn_cast<WritePodOp>(&op)) {
1524 if (std::optional<size_t> slotIdx =
1525 findLoopSlotIndex(slots, writeOp.getPodRef(), getRecordNameAsStringAttr(writeOp))) {
1526 afterSlotValues[*slotIdx] = afterMapping.lookupOrDefault(writeOp.getValue());
1527 continue;
1528 }
1529 }
1530
1531 rewriter.clone(op, afterMapping);
1532 }
1533
1534 rewriter.setInsertionPointAfter(newWhile);
1535 for (auto [idx, slot] : llvm::enumerate(slots)) {
1536 genWrite(
1537 loc, slot.podRef, slot.recordName, newWhile.getResult(whileOp.getNumResults() + idx),
1538 rewriter
1539 );
1540 }
1541
1542 rewriter.replaceOp(whileOp, newWhile.getResults().take_front(whileOp.getNumResults()));
1543 return success();
1544 }
1545};
1546
1548static LogicalResult
1549applyGreedily(ModuleOp modOp, RewritePatternSet &&patterns, bool *changed = nullptr) {
1550 return applyPatternsGreedily(
1551 modOp->getRegion(0), std::move(patterns),
1552 GreedyRewriteConfig {.fold = false, .cseConstants = false}, changed
1553 );
1554}
1555
1558static LogicalResult step3(ModuleOp modOp) {
1559 RewritePatternSet patterns(modOp.getContext());
1560 patterns.add<
1561 ReplaceIfReadPattern, LiftPodWritesFromIfBlocksPattern, LiftPodAccessesFromForLoopPattern,
1562 LiftPodAccessesFromWhileLoopPattern, FoldIfCarriedPodReadAfterWritePattern>(
1563 patterns.getContext()
1564 );
1565
1566 LLVM_DEBUG(llvm::dbgs() << "Begin step 3: refactor pod ops within SCF regions\n";);
1567 return applyGreedily(modOp, std::move(patterns));
1568}
1569
1572static bool applyIfCarriedPodReadAfterWritePatterns(ModuleOp modOp) {
1573 RewritePatternSet patterns(modOp.getContext());
1574 patterns.add<FoldIfCarriedPodReadAfterWritePattern>(patterns.getContext());
1575
1576 bool changed = false;
1577 if (failed(applyGreedily(modOp, std::move(patterns), &changed))) {
1578 return false;
1579 }
1580 return changed;
1581}
1582
1585static size_t podTypeScalarizationWeight(Type type) {
1586 auto podTy = dyn_cast<PodType>(type);
1587 if (!podTy) {
1588 return 0;
1589 }
1590
1591 size_t weight = 1;
1592 for (RecordAttr record : podTy.getRecords()) {
1593 weight += podTypeScalarizationWeight(record.getType());
1594 }
1595 return weight;
1596}
1597
1601static size_t podAllocScalarizationWeight(ModuleOp modOp) {
1602 size_t weight = 0;
1603 modOp.walk([&weight](NewPodOp newPodOp) {
1604 weight += podTypeScalarizationWeight(newPodOp.getType());
1605 });
1606 return weight;
1607}
1608
1610class PassImpl : public llzk::pod::impl::PodToScalarPassBase<PassImpl> {
1611 using Base = PodToScalarPassBase<PassImpl>;
1612 using Base::Base;
1613
1614 void runOnOperation() override {
1615 ModuleOp module = getOperation();
1616
1617 if (failed(step0(module))) {
1618 return signalPassFailure();
1619 }
1620 LLVM_DEBUG({
1621 llvm::dbgs() << "After step 0:\n";
1622 module.dump();
1623 });
1624
1625 {
1626 // This is divided into 2 steps to simplify the implementation for member-related ops. The
1627 // issue is that the conversions for member read/write expect the mapping of record name to
1628 // member name+type to already be populated for the referenced member (although this could be
1629 // computed on demand if desired but it complicates the implementation a bit).
1630 SymbolTableCollection symTables;
1631 MemberReplacementMap memberRepMap;
1632 if (failed(step1(module, symTables, memberRepMap))) {
1633 return signalPassFailure();
1634 }
1635 LLVM_DEBUG({
1636 llvm::dbgs() << "After step 1:\n";
1637 module.dump();
1638 });
1639
1640 if (failed(step2(module, symTables, memberRepMap))) {
1641 return signalPassFailure();
1642 }
1643 LLVM_DEBUG({
1644 llvm::dbgs() << "After step 2:\n";
1645 module.dump();
1646 });
1647 }
1648
1649 if (failed(step3(module))) {
1650 return signalPassFailure();
1651 }
1652 LLVM_DEBUG({
1653 llvm::dbgs() << "After step 3:\n";
1654 module.dump();
1655 });
1656
1657 // 1. Use SROA (Destructurable* interfaces) to split each pod with `N` records into `N` pods
1658 // with 1 record each. This is necessary because the mem2reg pass cannot deal with splitting
1659 // up memory, i.e., it can only convert scalar memory access into SSA values.
1660 // 2. The mem2reg pass converts the size 1 pod allocations and accesses into SSA values.
1661 OpPassManager scalarizePM(ModuleOp::getOperationName());
1662 scalarizePM.addPass(createSpecializedSROAPass<NewPodOp>());
1663 scalarizePM.addPass(createSpecializedMem2RegPass<NewPodOp>());
1664
1665 // Cleanup SSA values made dead by the transformations
1666 OpPassManager cleanupPM(ModuleOp::getOperationName());
1667 cleanupPM.addPass(createRemoveDeadValuesPass());
1668
1669 size_t podAllocWeight = podAllocScalarizationWeight(module);
1670 while (podAllocWeight != 0) {
1671 if (failed(runPipeline(scalarizePM, module))) {
1672 signalPassFailure();
1673 return;
1674 }
1675
1676 // SROA+mem2reg can expose `scf.if`-carried POD values that become redundant after a
1677 // same-record write from another `scf.if` result. Fold those reads and clean up before
1678 // checking convergence.
1679 bool foldedIfCarriedRead = applyIfCarriedPodReadAfterWritePatterns(module);
1680 if (failed(runPipeline(cleanupPM, module))) {
1681 signalPassFailure();
1682 return;
1683 }
1684
1685 // Nested PODs can become visible only after an outer single-record POD has been promoted,
1686 // and SROA can transiently increase allocation count while splitting aggregates. Keep
1687 // iterating until the allocation-weight heuristic reaches a fixed point.
1688 size_t nextPodAllocWeight = podAllocScalarizationWeight(module);
1689 if (!foldedIfCarriedRead && nextPodAllocWeight == podAllocWeight) {
1690 break;
1691 }
1692 podAllocWeight = nextPodAllocWeight;
1693 }
1694 LLVM_DEBUG({
1695 llvm::dbgs() << "After SROA+Mem2Reg pipeline:\n";
1696 module.dump();
1697 });
1698 }
1699};
1700
1701} // namespace
within a display generated by the Derivative if and wherever such third party notices normally appear The contents of the NOTICE file are for informational purposes only and do not modify the License You may add Your own attribution notices within Derivative Works that You alongside or as an addendum to the NOTICE text from the provided that such additional attribution notices cannot be construed as modifying the License You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for or distribution of Your or for any such Derivative Works as a provided Your and distribution of the Work otherwise complies with the conditions stated in this License Submission of Contributions Unless You explicitly state any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this without any additional terms or conditions Notwithstanding the nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions Trademarks This License does not grant permission to use the trade names
Definition LICENSE.txt:139
Provides SpecializedSROA<AllocOpTy> and SpecializedMem2Reg<AllocOpTy>: pass templates that replicate ...
Common implementation for handling MemberWriteOp and MemberReadOp while destructuring an aggregate ty...
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:568
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:386
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:956
::llvm::SmallVector< RangeT > getMapOperands()
Definition Ops.h.inc:175
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:223
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
CallOpAdaptor Adaptor
Definition Ops.h.inc:205
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:850
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:854
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1169
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:991
::mlir::Operation::operand_range getInitialValues()
Definition Ops.h.inc:237
::mlir::TypedValue<::llzk::pod::PodType > getResult()
Definition Ops.h.inc:257
void setInitializedRecordsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:285
::mlir::MutableOperandRange getInitialValuesMutable()
Definition Ops.cpp.inc:185
::llvm::ArrayRef<::llzk::pod::RecordAttr > getRecords() const
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:499
::mlir::FlatSymbolRefAttr getRecordNameAttr()
Definition Ops.h.inc:512
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:480
::mlir::FlatSymbolRefAttr getRecordNameAttr()
Definition Ops.h.inc:746
::mlir::TypedValue<::mlir::Type > getValue()
Definition Ops.h.inc:713
::mlir::TypedValue<::llzk::pod::PodType > getPodRef()
Definition Ops.h.inc:709
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
Definition Ops.h:34
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
Definition Ops.h:37
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
std::unique_ptr< SpecializedMem2Reg< AllocOpTy > > createSpecializedMem2RegPass()
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::unique_ptr< SpecializedSROA< AllocOpTy > > createSpecializedSROAPass()
static unsigned getHashValue(const RecordChain &chain)
static bool isEqual(const RecordChain &lhs, const RecordChain &rhs)
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes