LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
ArrayToScalarPass.cpp
Go to the documentation of this file.
1//===-- ArrayToScalarPass.cpp -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
19///
20/// 2. Run a dialect conversion that does the following:
21///
22/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1
23/// so they instead perform scalar reads and writes from the new members. The transformation is
24/// local to the current op. Therefore, when replacing the `MemberReadOp` a new array is
25/// created locally and all uses of the `MemberReadOp` are replaced with the new array Value,
26/// then each scalar member read is followed by scalar write into the new array. Similarly,
27/// when replacing a `MemberWriteOp`, each element in the array operand needs a scalar read
28/// from the array followed by a scalar write to the new member. Making only local changes
29/// keeps this step simple; later steps optimize the result.
30///
31/// - Replace `ArrayLengthOp` with the constant size of the selected dimension.
32///
33/// - Remove element initialization from `CreateArrayOp` and instead insert a list of
34/// `WriteArrayOp` immediately following.
35///
36/// - Desugar `InsertArrayOp` and `ExtractArrayOp` into their element-wise scalar reads/writes.
37///
38/// - Split arrays to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
39/// create/read/write ops so the changes are as local as possible (just as described for
40/// `MemberReadOp` and `MemberWriteOp`)
41///
42/// 3. Replace branch-local reads (in `scf.if`) with the value written by a same-index write op that
43/// dominates the parent `scf.if` (because the passes below cannot handle that case).
44///
45/// 4. Run MLIR "sroa" pass to split each array with linear size `N` into `N` arrays of size 1
46/// (to prepare for "mem2reg" pass because its API cannot deal with splitting up memory).
47///
48/// 5. Run MLIR "mem2reg" pass to convert all of the size 1 array allocation and access into SSA
49/// values. This pass also runs several standard optimizations so the final result is condensed.
50///
51/// 6. Remove array allocations that become unread after memory promotion, then remove SSA values
52/// made dead by that cleanup.
53///
54/// Note: This transformation imposes "last write wins" semantics on array elements. If
55/// different/configurable semantics are added in the future, some additional transformation would
56/// be necessary before/during this pass so that multiple writes to the same index can be handled
57/// properly while they still exist.
58///
59/// Note: This transformation will introduce a `nondet` op whenever an array index is read
60/// without having been written to earlier.
88
89#include <mlir/Dialect/SCF/IR/SCF.h>
90#include <mlir/IR/BuiltinOps.h>
91#include <mlir/Pass/PassManager.h>
92#include <mlir/Transforms/DialectConversion.h>
93#include <mlir/Transforms/Passes.h>
94
95#include <llvm/Support/Debug.h>
96
97#include <optional>
98
99// Include the generated base pass class definitions.
100namespace llzk::array {
101#define GEN_PASS_DEF_ARRAYTOSCALARPASS
103} // namespace llzk::array
104
105using namespace mlir;
106using namespace llzk;
107using namespace llzk::array;
108using namespace llzk::component;
109using namespace llzk::function;
110
111#define DEBUG_TYPE "llzk-array-to-scalar"
112
113namespace {
114
116inline ArrayType splittableArray(ArrayType at) { return at.hasStaticShape() ? at : nullptr; }
117
119inline ArrayType splittableArray(Type t) {
120 if (ArrayType at = dyn_cast<ArrayType>(t)) {
121 return splittableArray(at);
122 } else {
123 return nullptr;
124 }
125}
126
129inline bool containsSplittableArrayType(ArrayRef<Type> types) {
130 for (Type t : types) {
131 if (splittableArray(t)) {
132 return true;
133 }
134 }
135 return false;
136}
137
139template <typename T> bool containsSplittableArrayType(ValueTypeRange<T> types) {
140 for (Type t : types) {
141 if (splittableArray(t)) {
142 return true;
143 }
144 }
145 return false;
146}
147
150size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
151 if (ArrayType at = splittableArray(t)) {
152 size_t size = llzk::checkedCast<size_t>(at.getNumElements());
153 collect.append(size, at.getElementType());
154 return size;
155 } else {
156 collect.push_back(t);
157 return 1;
158 }
159}
160
162template <typename TypeCollection>
163inline void splitArrayTypeTo(
164 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
165) {
166 for (Type t : types) {
167 size_t count = splitArrayTypeTo(t, collect);
168 if (originalIdxToSize) {
169 originalIdxToSize->push_back(count);
170 }
171 }
172}
173
176template <typename TypeCollection>
177inline SmallVector<Type>
178splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize = nullptr) {
179 SmallVector<Type> collect;
180 splitArrayTypeTo(types, collect, originalIdxToSize);
181 return collect;
182}
183
186SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
187 SmallVector<Value> operands;
188 for (Attribute a : index) {
189 // ASSERT: Attributes are index constants, created by ArrayType::getSubelementIndices().
190 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
191 assert(ia && ia.getType().isIndex());
192 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
193 }
194 return operands;
195}
196
198inline WriteArrayOp
199genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
200 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
201 return rewriter.create<WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
202}
203
205static std::string formatSplitArrayIndexSuffix(ArrayAttr index) {
206 std::string suffix;
207 llvm::raw_string_ostream os(suffix);
208 for (Attribute attr : index) {
209 os << '[';
210 attr.print(os, true);
211 os << ']';
212 }
213 return suffix;
214}
215
217static SmallVector<std::string> getSplitArrayIndexSuffixes(Type type) {
218 SmallVector<std::string> suffixes;
219 if (ArrayType at = splittableArray(type)) {
220 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
221 assert(indices.has_value() && "static-shape arrays must provide subelement indices");
222 suffixes.reserve(indices->size());
223 for (ArrayAttr index : *indices) {
224 suffixes.push_back(formatSplitArrayIndexSuffix(index));
225 }
226 }
227 return suffixes;
228}
229
/// Replaces `oldCall` with a call whose result types are the scalarized form of the original
/// result types, re-creates each original array result locally from the new scalar results,
/// and erases `oldCall`. Returns the new call.
/// NOTE(review): the first line of the statement that defines `newCall` is missing from this
/// listing, so the builder/helper that constructs it cannot be confirmed from here.
231CallOp newCallOpWithSplitResults(
232 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
233) {
// Restore the caller's insertion point on exit; new ops are placed right after `oldCall`.
234 OpBuilder::InsertionGuard guard(rewriter);
235 rewriter.setInsertionPointAfter(oldCall);
236
237 Operation::result_range oldResults = oldCall.getResults();
239 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall, adaptor.getMapOperands(),
240 adaptor.getArgOperands(), rewriter
241 );
242
// `newResults` is advanced once per scalar consumed, so it stays aligned with the flattened
// result list while iterating over the original (unflattened) results.
243 auto newResults = newCall.getResults().begin();
244 for (Value oldVal : oldResults) {
245 if (ArrayType at = splittableArray(oldVal.getType())) {
246 Location loc = oldVal.getLoc();
247 // Generate `CreateArrayOp` and replace uses of the result with it.
248 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
249 rewriter.replaceAllUsesWith(oldVal, newArray);
250
251 // For all indices in the ArrayType (i.e., the element count), write the next
252 // result from the new CallOp to the new array.
253 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
254 assert(allIndices); // follows from legal() check
255 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
256 for (ArrayAttr subIdx : allIndices.value()) {
257 genWrite(loc, newArray, subIdx, *newResults, rewriter);
258 newResults++;
259 }
260 } else {
// Non-array (or non-splittable) results map one-to-one onto the new call's results.
261 rewriter.replaceAllUsesWith(oldVal, *newResults);
262 newResults++;
263 }
264 }
265 // erase the original CallOp
266 rewriter.eraseOp(oldCall);
267
268 return newCall;
269}
270
272inline ReadArrayOp
273genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
274 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
275 return rewriter.create<ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
276}
277
280void processInputOperand(
281 Location loc, Value operand, SmallVector<Value> &newOperands,
282 ConversionPatternRewriter &rewriter
283) {
284 if (ArrayType at = splittableArray(operand.getType())) {
285 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
286 assert(indices.has_value() && "passed earlier hasStaticShape() check");
287 for (ArrayAttr index : indices.value()) {
288 newOperands.push_back(genRead(loc, operand, index, rewriter));
289 }
290 } else {
291 newOperands.push_back(operand);
292 }
293}
294
296void processInputOperands(
297 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
298 ConversionPatternRewriter &rewriter
299) {
300 SmallVector<Value> newOperands;
301 for (Value v : operands) {
302 processInputOperand(op->getLoc(), v, newOperands, rewriter);
303 }
304 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
305 outputOpRef.assign(ValueRange(newOperands));
306 });
307}
308
309namespace {
310
311enum Direction : std::uint8_t {
313 SMALL_TO_LARGE,
315 LARGE_TO_SMALL,
316};
317
320template <Direction dir>
321inline void rewriteImpl(
322 ArrayAccessOpInterface op, ArrayType smallType, Value smallArr, Value largeArr,
323 ConversionPatternRewriter &rewriter
324) {
325 assert(smallType); // follows from legal() check
326 Location loc = op.getLoc();
327 MLIRContext *ctx = op.getContext();
328
329 ArrayAttr indexAsAttr = op.indexOperandsToAttributeArray();
330 assert(indexAsAttr); // follows from legal() check
331
332 // For all indices in the ArrayType (i.e., the element count), read from one array into the other
333 // (depending on direction flag).
334 std::optional<SmallVector<ArrayAttr>> subIndices = smallType.getSubelementIndices();
335 assert(subIndices); // follows from legal() check
336 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
337 for (ArrayAttr indexingTail : subIndices.value()) {
338 SmallVector<Attribute> joined;
339 joined.append(indexAsAttr.begin(), indexAsAttr.end());
340 joined.append(indexingTail.begin(), indexingTail.end());
341 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
342
343 if constexpr (dir == Direction::SMALL_TO_LARGE) {
344 auto init = genRead(loc, smallArr, indexingTail, rewriter);
345 genWrite(loc, largeArr, fullIndex, init, rewriter);
346 } else if constexpr (dir == Direction::LARGE_TO_SMALL) {
347 auto init = genRead(loc, largeArr, fullIndex, rewriter);
348 genWrite(loc, smallArr, indexingTail, init, rewriter);
349 }
350 }
351}
352
353} // namespace
354
356class SplitInsertArrayOp : public OpConversionPattern<InsertArrayOp> {
357public:
358 using OpConversionPattern<InsertArrayOp>::OpConversionPattern;
359
360 static bool legal(InsertArrayOp op) {
361 return !containsSplittableArrayType(op.getRvalue().getType());
362 }
363
364 LogicalResult match(InsertArrayOp op) const override { return failure(legal(op)); }
365
366 void
367 rewrite(InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
368 ArrayType at = splittableArray(op.getRvalue().getType());
369 rewriteImpl<SMALL_TO_LARGE>(
370 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
371 adaptor.getArrRef(), rewriter
372 );
373 rewriter.eraseOp(op);
374 }
375};
376
378class SplitExtractArrayOp : public OpConversionPattern<ExtractArrayOp> {
379public:
380 using OpConversionPattern<ExtractArrayOp>::OpConversionPattern;
381
382 static bool legal(ExtractArrayOp op) {
383 return !containsSplittableArrayType(op.getResult().getType());
384 }
385
386 LogicalResult match(ExtractArrayOp op) const override { return failure(legal(op)); }
387
388 void rewrite(
389 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
390 ) const override {
391 ArrayType at = splittableArray(op.getResult().getType());
392 // Generate `CreateArrayOp` in place of the current op.
393 auto newArray = rewriter.replaceOpWithNewOp<CreateArrayOp>(op, at);
394 rewriteImpl<LARGE_TO_SMALL>(
395 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
396 rewriter
397 );
398 }
399};
400
402class SplitInitFromCreateArrayOp : public OpConversionPattern<CreateArrayOp> {
403public:
404 using OpConversionPattern<CreateArrayOp>::OpConversionPattern;
405
406 static bool legal(CreateArrayOp op) { return op.getElements().empty(); }
407
408 LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); }
409
410 void
411 rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
412 // Remove elements from `op`
413 rewriter.modifyOpInPlace(op, [&op]() { op.getElementsMutable().clear(); });
414 // Generate an individual write for each initialization element
415 rewriter.setInsertionPointAfter(op);
416 Location loc = op.getLoc();
417 ArrayIndexGen idxGen = ArrayIndexGen::from(op.getType());
418 for (auto [i, init] : llvm::enumerate(adaptor.getElements())) {
419 // Convert the linear index 'i' into a multi-dim index
420 std::optional<SmallVector<Value>> multiDimIdxVals =
421 idxGen.delinearize(llzk::checkedCast<int64_t>(i), loc, rewriter);
422 // ASSERT: CreateArrayOp verifier ensures the number of elements provided matches the full
423 // linear array size so delinearization of `i` will not fail.
424 assert(multiDimIdxVals.has_value());
425 // Create the write
426 rewriter.create<WriteArrayOp>(loc, op.getResult(), ValueRange(*multiDimIdxVals), init);
427 }
428 }
429};
430
/// Conversion pattern that rewrites a `FuncDefOp` so that every splittable array in its
/// argument and result types is replaced by one scalar per array element.
/// NOTE(review): two lines of this class (the first lines of the `convertInputAttrs` and
/// `convertResultAttrs` bodies) are missing from this listing; the helper they call is not
/// visible here.
432class SplitArrayInFuncDefOp : public OpConversionPattern<FuncDefOp> {
433public:
434 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
435
/// The function is legal when neither its argument types nor its result types contain a
/// splittable array.
436 inline static bool legal(FuncDefOp op) {
437 return !containsSplittableArrayType(op.getArgumentTypes()) &&
438 !containsSplittableArrayType(op.getResultTypes());
439 }
440
/// Matches exactly the ops that are not legal.
441 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
442
443 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
444 // Update in/out types of the function to replace arrays with scalars
445 class Impl : public FunctionTypeConverter {
// Per original argument/result: how many types it expanded to (filled by the convert*
// overrides below and then passed on when converting the attributes).
446 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
447 SplitFunctionNameInfo inputNameInfo;
448 SplitFunctionNameInfo resultNameInfo;
449
450 protected:
451 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
452 return splitArrayType(origTypes, &originalInputIdxToSize);
453 }
454 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
455 return splitArrayType(origTypes, &originalResultIdxToSize);
456 }
457 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
459 origAttrs, originalInputIdxToSize, newTypes, ARG_NAME_ATTR_NAME,
460 inputNameInfo.originalNames, inputNameInfo.existingNames,
461 inputNameInfo.splitNameSuffixes
462 );
463 }
464 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
466 origAttrs, originalResultIdxToSize, newTypes, RES_NAME_ATTR_NAME,
467 resultNameInfo.originalNames, resultNameInfo.existingNames,
468 resultNameInfo.splitNameSuffixes
469 );
470 }
471
/// Replaces each splittable-array block argument of `entryBlock` with one scalar argument
/// per element and rebuilds the array at the start of the block from those scalars.
476 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
477 OpBuilder::InsertionGuard guard(rewriter);
478 rewriter.setInsertionPointToStart(&entryBlock);
479
// `i` is advanced manually because the argument list changes length while iterating.
480 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
481 Value oldV = entryBlock.getArgument(i);
482 if (ArrayType at = splittableArray(oldV.getType())) {
483 Location loc = oldV.getLoc();
484 // Generate `CreateArrayOp` and replace uses of the argument with it.
485 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
486 rewriter.replaceAllUsesWith(oldV, newArray);
487 // Remove the argument from the block
488 entryBlock.eraseArgument(i);
489 // For all indices in the ArrayType (i.e., the element count), generate a new block
490 // argument and a write of that argument to the new array.
491 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
492 assert(allIndices); // follows from legal() check
493 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
494 for (ArrayAttr subIdx : allIndices.value()) {
495 BlockArgument newArg = entryBlock.insertArgument(i, at.getElementType(), loc);
496 genWrite(loc, newArray, subIdx, newArg, rewriter);
497 ++i;
498 }
499 } else {
500 ++i;
501 }
502 }
503 }
504
505 public:
/// Collects the name information for the arguments and results of `op` up front.
506 Impl(FuncDefOp op) {
507 ArrayAttr resultAttrs = op.getAllResultAttrs();
508 inputNameInfo = collectSplitFunctionNameInfo(op.getArgumentTypes(), [&](unsigned i) {
509 return op.getArgNameAttr(i);
510 }, getSplitArrayIndexSuffixes);
511 resultNameInfo = collectSplitFunctionNameInfo(op.getResultTypes(), [&](unsigned i) {
512 return getAttrAtIndexWithName(resultAttrs, i, RES_NAME_ATTR_NAME);
513 }, getSplitArrayIndexSuffixes);
514 }
515 };
516 Impl(op).convert(op, rewriter);
517 }
518};
519
521class SplitArrayInReturnOp : public OpConversionPattern<ReturnOp> {
522public:
523 using OpConversionPattern<ReturnOp>::OpConversionPattern;
524
525 inline static bool legal(ReturnOp op) {
526 return !containsSplittableArrayType(op.getOperands().getTypes());
527 }
528
529 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
530
531 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
532 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
533 }
534};
535
537class SplitArrayInCallOp : public OpConversionPattern<CallOp> {
538public:
539 using OpConversionPattern<CallOp>::OpConversionPattern;
540
541 inline static bool legal(CallOp op) {
542 return !containsSplittableArrayType(op.getArgOperands().getTypes()) &&
543 !containsSplittableArrayType(op.getResultTypes());
544 }
545
546 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
547
548 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
549 // Create new CallOp with split results first so, then process its inputs to split types
550 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
551 processInputOperands(
552 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
553 );
554 }
555};
556
558class ReplaceKnownArrayLengthOp : public OpConversionPattern<ArrayLengthOp> {
559public:
560 using OpConversionPattern<ArrayLengthOp>::OpConversionPattern;
561
563 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx, ArrayType baseArrType) {
564 if (splittableArray(baseArrType)) {
565 llvm::APInt idxAP;
566 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
567 std::optional<int64_t> signedIdx = idxAP.trySExtValue();
568 if (!signedIdx || *signedIdx < 0) {
569 return std::nullopt;
570 }
571 size_t idx = llzk::checkedCast<size_t>(*signedIdx);
572 ArrayRef<Attribute> dimSizes = baseArrType.getDimensionSizes();
573 if (idx >= dimSizes.size()) {
574 return std::nullopt;
575 }
576 Attribute dimSizeAttr = dimSizes[idx];
577 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
578 return idxAP;
579 }
580 }
581 }
582 return std::nullopt;
583 }
584
585 inline static bool legal(ArrayLengthOp op) {
586 // rewrite() can only work with constant dim size, i.e., must consider it legal otherwise
587 return !getDimSizeIfKnown(op.getDim(), op.getArrRefType()).has_value();
588 }
589
590 LogicalResult match(ArrayLengthOp op) const override { return failure(legal(op)); }
591
592 void
593 rewrite(ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
594 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
595 assert(arrTy); // must have array type per ODS spec of ArrayLengthOp
596 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
597 assert(len.has_value()); // follows from legal() check
598 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, llzk::fromAPInt(len.value()));
599 }
600};
601
/// (symbol name, type) of one scalar member created when an array member is split.
603using MemberInfo = std::pair<StringAttr, Type>;
/// Maps an array element index to the scalar member that replaces that element.
605using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
/// Maps a struct, then the name of one of its original array members, to the per-index
/// replacements of that member.
607using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
608
610class SplitArrayInMemberDefOp : public OpConversionPattern<MemberDefOp> {
611 SymbolTableCollection &tables;
612 MemberReplacementMap &repMapRef;
613
614public:
615 SplitArrayInMemberDefOp(
616 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
617 )
618 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
619
620 inline static bool legal(MemberDefOp op) { return !containsSplittableArrayType(op.getType()); }
621
622 LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); }
623
624 void rewrite(MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
625 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
626 assert(inStruct);
627 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
628
629 ArrayType arrTy = dyn_cast<ArrayType>(op.getType());
630 assert(arrTy); // follows from legal() check
631 auto subIdxs = arrTy.getSubelementIndices();
632 assert(subIdxs.has_value());
633 Type elemTy = arrTy.getElementType();
634
635 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
636 for (ArrayAttr idx : subIdxs.value()) {
637 // Create scalar version of the member
638 MemberDefOp newMember = rewriter.create<MemberDefOp>(
639 op.getLoc(), op.getSymNameAttr(), elemTy, op.getSignal(), op.getColumn()
640 );
641 newMember.setPublicAttr(op.hasPublicAttr());
642 // Use SymbolTable to give it a unique name and store to the replacement map
643 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
644 }
645 rewriter.eraseOp(op);
646 }
647};
648
/// Pattern for a `MemberWriteOp` whose written value is a splittable array: for each array
/// element it emits a scalar array read followed by a `MemberWriteOp` to the scalar member
/// that replaced that element.
/// NOTE(review): the first line of the inherited-constructor using-declaration is missing
/// from this listing, and the `SplitAggregateInMemberRefOp` base that drives `genHeader`/`forId`
/// is defined elsewhere.
650class SplitArrayInMemberWriteOp : public SplitAggregateInMemberRefOp<
651 SplitArrayInMemberWriteOp, MemberWriteOp, void *, ArrayAttr> {
652public:
654 SplitArrayInMemberWriteOp, MemberWriteOp, void *, ArrayAttr>::SplitAggregateInMemberRefOp;
655
/// The op is legal when the written value is not a splittable array.
656 static bool legal(MemberWriteOp op) {
657 return !containsSplittableArrayType(op.getVal().getType());
658 }
659
/// No header op is generated for writes; always returns nullptr.
660 static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; }
661
/// Reads element `idx` of the written array value and writes it to the scalar member named
/// by `newMember.first`.
662 static void forId(
663 Location loc, void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
664 ConversionPatternRewriter &rewriter
665 ) {
666 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
667 rewriter.create<MemberWriteOp>(
668 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
669 );
670 }
671};
672
/// Pattern for a `MemberReadOp` whose result is an array: it creates a new local array that
/// replaces all uses of the read and, for each element, emits a scalar `MemberReadOp` of the
/// replacement member followed by a write of that scalar into the new array.
/// NOTE(review): the base-class line and parts of the inherited-constructor using-declaration
/// are missing from this listing.
675class SplitArrayInMemberReadOp
677 SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp, ArrayAttr> {
678public:
680 SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp,
682
/// The op is legal when its result is not a splittable array.
683 static bool legal(MemberReadOp op) {
684 return !containsSplittableArrayType(op.getResult().getType());
685 }
686
/// Creates the array that stands in for the original read result and redirects all uses of
/// `op` to it.
687 static CreateArrayOp genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) {
688 CreateArrayOp newArray =
689 rewriter.create<CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
690 rewriter.replaceAllUsesWith(op, newArray);
691 return newArray;
692 }
693
/// Reads the scalar member described by `newMember` and stores it at `idx` of `newArray`.
694 static void forId(
695 Location loc, CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
696 ConversionPatternRewriter &rewriter
697 ) {
698 MemberReadOp scalarRead = rewriter.create<MemberReadOp>(
699 loc, newMember.second, adaptor.getComponent(), newMember.first
700 );
701 genWrite(loc, newArray, idx, scalarRead, rewriter);
702 }
703};
704
/// Marks the dialects and ops that every conversion step in this file treats as legal.
/// NOTE(review): most of the dialect list is missing from this listing; only `scf::SCFDialect`
/// and `ModuleOp` are visible.
706static void baseTargetSetup(ConversionTarget &target) {
707 target.addLegalDialect<
712 scf::SCFDialect>();
713 target.addLegalOp<ModuleOp>();
714}
715
717class NondetToNewArray : public OpConversionPattern<NonDetOp> {
718 using OpConversionPattern<NonDetOp>::OpConversionPattern;
719 LogicalResult matchAndRewrite(
720 NonDetOp nondetOp, OpAdaptor, ConversionPatternRewriter &rewriter
721 ) const override {
722 if (auto at = dyn_cast<ArrayType>(nondetOp.getType())) {
723 rewriter.replaceOpWithNewOp<CreateArrayOp>(nondetOp, at);
724 return success();
725 }
726 return failure();
727 }
728};
729
731static LogicalResult step0(ModuleOp modOp) {
732 MLIRContext *ctx = modOp.getContext();
733 RewritePatternSet patterns {ctx};
734 patterns.add<NondetToNewArray>(ctx);
735 ConversionTarget target {*ctx};
736
737 baseTargetSetup(target);
738 target.addDynamicallyLegalOp<NonDetOp>([](NonDetOp op) { return !isa<ArrayType>(op.getType()); });
739
740 return applyFullConversion(modOp, target, std::move(patterns));
741}
742
744static LogicalResult
745step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
746 MLIRContext *ctx = modOp.getContext();
747
748 RewritePatternSet patterns(ctx);
749
750 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
751
752 ConversionTarget target(*ctx);
753 baseTargetSetup(target);
754 target.addDynamicallyLegalOp<MemberDefOp>(SplitArrayInMemberDefOp::legal);
755
756 LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split array-type members\n";);
757 return applyFullConversion(modOp, target, std::move(patterns));
758}
759
762static LogicalResult
763step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) {
764 MLIRContext *ctx = modOp.getContext();
765
766 RewritePatternSet patterns(ctx);
767 patterns.add<
768 // clang-format off
769 SplitInitFromCreateArrayOp,
770 SplitInsertArrayOp,
771 SplitExtractArrayOp,
772 SplitArrayInFuncDefOp,
773 SplitArrayInReturnOp,
774 SplitArrayInCallOp,
775 ReplaceKnownArrayLengthOp
776 // clang-format on
777 >(ctx);
778
779 patterns.add<
780 // clang-format off
781 SplitArrayInMemberWriteOp,
782 SplitArrayInMemberReadOp
783 // clang-format on
784 >(ctx, symTables, memberRepMap);
785
786 ConversionTarget target(*ctx);
787 baseTargetSetup(target);
788 target.addDynamicallyLegalOp<CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
789 target.addDynamicallyLegalOp<InsertArrayOp>(SplitInsertArrayOp::legal);
790 target.addDynamicallyLegalOp<ExtractArrayOp>(SplitExtractArrayOp::legal);
791 target.addDynamicallyLegalOp<FuncDefOp>(SplitArrayInFuncDefOp::legal);
792 target.addDynamicallyLegalOp<ReturnOp>(SplitArrayInReturnOp::legal);
793 target.addDynamicallyLegalOp<CallOp>(SplitArrayInCallOp::legal);
794 target.addDynamicallyLegalOp<ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
795 target.addDynamicallyLegalOp<MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
796 target.addDynamicallyLegalOp<MemberReadOp>(SplitArrayInMemberReadOp::legal);
797
798 LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other array ops\n";);
799 return applyFullConversion(modOp, target, std::move(patterns));
800}
801
/// Returns the index of an array access op as an `ArrayAttr`; callers in this file treat a
/// null result as "index not statically known".
/// NOTE(review): the body line is missing from this listing; it presumably forwards to
/// `op.indexOperandsToAttributeArray()` -- confirm against the original source.
803inline static ArrayAttr getIndexAsAttr(ArrayAccessOpInterface op) {
805}
806
808static bool mayWriteToIndex(WriteArrayOp writeOp, ArrayAttr index) {
809 ArrayAttr writeIndex = getIndexAsAttr(writeOp);
810 return !writeIndex || writeIndex == index;
811}
812
814static bool hasEarlierWriteInBlock(ReadArrayOp readOp, ArrayAttr readIndex) {
815 Value arrRef = readOp.getArrRef();
816 for (Operation &op : *readOp->getBlock()) {
817 if (&op == readOp.getOperation()) {
818 return false;
819 }
820
821 if (auto writeOp = dyn_cast<WriteArrayOp>(&op)) {
822 if (writeOp.getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
823 return true;
824 }
825 continue;
826 }
827
828 // Writes nested inside earlier operations may conditionally clobber the read's value.
829 if (op.walk([arrRef, readIndex](WriteArrayOp writeOp) {
830 if (writeOp.getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
831 return WalkResult::interrupt();
832 }
833 return WalkResult::advance();
834 }).wasInterrupted()) {
835 return true;
836 }
837 }
838 return false;
839}
840
842static std::optional<WriteArrayOp> findPrecedingWriteForIfRead(ReadArrayOp readOp) {
843 ArrayAttr readIndex = getIndexAsAttr(readOp);
844 if (!readIndex) {
845 return std::nullopt;
846 }
847
848 // Only handle reads that are direct children of an `scf.if` branch.
849 auto ifOp = readOp->getParentOfType<scf::IfOp>();
850 if (!ifOp || readOp->getBlock()->getParentOp() != ifOp.getOperation()) {
851 return std::nullopt;
852 }
853 if (hasEarlierWriteInBlock(readOp, readIndex)) {
854 return std::nullopt;
855 }
856
857 Block *ifBlock = ifOp->getBlock();
858 if (!ifBlock) {
859 return std::nullopt;
860 }
861
862 Value arrRef = readOp.getArrRef();
863 WriteArrayOp replacement;
864 for (Operation &op : *ifBlock) {
865 if (&op == ifOp.getOperation()) {
866 break;
867 }
868
869 if (auto writeOp = dyn_cast<WriteArrayOp>(&op)) {
870 if (writeOp.getArrRef() != arrRef) {
871 continue;
872 }
873
874 if (mayWriteToIndex(writeOp, readIndex)) {
875 ArrayAttr writeIndex = getIndexAsAttr(writeOp);
876 replacement = writeIndex == readIndex ? writeOp : WriteArrayOp();
877 }
878 continue;
879 }
880
881 // A nested write before the `scf.if` may overwrite the current candidate.
882 if (op.walk([arrRef, readIndex, &replacement](WriteArrayOp writeOp) {
883 if (writeOp.getArrRef() == arrRef && mayWriteToIndex(writeOp, readIndex)) {
884 replacement = WriteArrayOp();
885 return WalkResult::interrupt();
886 }
887 return WalkResult::advance();
888 }).wasInterrupted()) {
889 continue;
890 }
891 }
892
893 return replacement ? std::make_optional(replacement) : std::nullopt;
894}
895
898static void step3(ModuleOp modOp) {
899 SmallVector<std::pair<ReadArrayOp, Value>> replacements;
900 modOp.walk([&replacements](ReadArrayOp readOp) {
901 if (std::optional<WriteArrayOp> writeOp = findPrecedingWriteForIfRead(readOp)) {
902 replacements.emplace_back(readOp, writeOp->getRvalue());
903 }
904 });
905
906 for (auto [readOp, value] : replacements) {
907 readOp.getResult().replaceAllUsesWith(value);
908 readOp.erase();
909 }
910}
911
/// Pass implementation: runs steps 0-3 on the module and then a nested cleanup pipeline,
/// signalling pass failure as soon as any stage fails.
/// NOTE(review): several `nestedPM.addPass(...)` lines of the pipeline setup are missing from
/// this listing; only their comments and the final two passes are visible.
913class PassImpl : public llzk::array::impl::ArrayToScalarPassBase<PassImpl> {
914 using Base = ArrayToScalarPassBase<PassImpl>;
915 using Base::Base;
916
917 void runOnOperation() override {
918 ModuleOp module = getOperation();
919
920 if (failed(step0(module))) {
921 return signalPassFailure();
922 }
923 LLVM_DEBUG({
924 llvm::dbgs() << "After step 0:\n";
925 module.dump();
926 });
927
// Scope limits the lifetime of the symbol tables and replacement map to steps 1 and 2.
928 {
929 // This is divided into 2 steps to simplify the implementation for member-related ops. The
930 // issue is that the conversions for member read/write expect the mapping of array index to
931 // member name+type to already be populated for the referenced member (although this could be
932 // computed on demand if desired but it complicates the implementation a bit).
933 SymbolTableCollection symTables;
934 MemberReplacementMap memberRepMap;
935 if (failed(step1(module, symTables, memberRepMap))) {
936 return signalPassFailure();
937 }
938 LLVM_DEBUG({
939 llvm::dbgs() << "After step 1:\n";
940 module.dump();
941 });
942
943 if (failed(step2(module, symTables, memberRepMap))) {
944 return signalPassFailure();
945 }
946 LLVM_DEBUG({
947 llvm::dbgs() << "After step 2:\n";
948 module.dump();
949 });
950 }
951
// step3 returns void, so there is no failure to check here.
952 step3(module);
953 LLVM_DEBUG({
954 llvm::dbgs() << "After step 3:\n";
955 module.dump();
956 });
957
958 OpPassManager nestedPM(ModuleOp::getOperationName());
959 // Use SROA (Destructurable* interfaces) to split each array with linear size `N` into `N`
960 // arrays of size 1. This is necessary because the mem2reg pass cannot deal with indexing
961 // and splitting up memory, i.e., it can only convert scalar memory access into SSA values.
963 // The mem2reg pass converts all of the size-1 array allocation and access into SSA values.
965 // Cleanup allocations made dead by memory promotion.
967 RemoveUnusedDiscardableAllocationsPassOptions {
968 .allocatorOpName = CreateArrayOp::getOperationName().str()
969 }
970 ));
971 // Cleanup SSA values made dead by removing allocations and writes.
972 nestedPM.addPass(createRemoveDeadValuesPass());
973 if (failed(runPipeline(nestedPM, module))) {
974 signalPassFailure();
975 return;
976 }
977 LLVM_DEBUG({
978 llvm::dbgs() << "After SROA+Mem2Reg pipeline:\n";
979 module.dump();
980 });
981 }
982};
983
984} // namespace
Provides SpecializedSROA<AllocOpTy> and SpecializedMem2Reg<AllocOpTy>: pass templates that replicate ...
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
Common implementation for handling MemberWriteOp and MemberReadOp while destructuring an aggregate ty...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Definition Ops.cpp:216
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
Definition Ops.h.inc:192
::mlir::TypedValue<::mlir::IndexType > getDim()
Definition Ops.h.inc:150
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:113
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:377
::mlir::MutableOperandRange getElementsMutable()
Definition Ops.cpp.inc:334
::mlir::Operation::operand_range getElements()
Definition Ops.h.inc:388
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:613
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
Definition Ops.h.inc:757
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:923
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Definition Ops.h.inc:899
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Definition Ops.h.inc:1070
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:568
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:386
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:956
::llvm::SmallVector< RangeT > getMapOperands()
Definition Ops.h.inc:175
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:223
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
CallOpAdaptor Adaptor
Definition Ops.h.inc:205
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:850
::llvm::ArrayRef<::mlir::Type > getResultTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:854
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1169
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:991
constexpr char ARG_NAME_ATTR_NAME[]
Attribute name for source-level function argument names.
Definition Ops.h:34
constexpr char RES_NAME_ATTR_NAME[]
Attribute name for source-level function result names.
Definition Ops.h:37
std::unique_ptr<::mlir::Pass > createRemoveUnusedDiscardableAllocationsPass()
mlir::ArrayAttr replicateFunctionNameAttrsAsNeeded(mlir::ArrayAttr origAttrs, const llvm::SmallVector< size_t > &originalIdxToSize, const llvm::SmallVector< mlir::Type > &newTypes, llvm::StringRef functionNameAttrName, llvm::ArrayRef< std::optional< llvm::StringRef > > origNames={}, llvm::ArrayRef< llvm::StringRef > existingNames={}, llvm::ArrayRef< llvm::SmallVector< std::string > > splitNameSuffixes={})
Expand function arg/result attribute arrays to match a split signature, rewriting name attrs with the...
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
std::unique_ptr< SpecializedMem2Reg< AllocOpTy > > createSpecializedMem2RegPass()
int64_t fromAPInt(const llvm::APInt &i)
function::CallOp createCallPreservingInstantiationOperands(mlir::Location loc, mlir::TypeRange newResultTypes, function::CallOp oldCall, llvm::ArrayRef< mlir::ValueRange > mapOperands, mlir::ValueRange argOperands, mlir::ConversionPatternRewriter &rewriter)
Rebuild a function.call while preserving explicit instantiation state from oldCall.
SplitFunctionNameInfo collectSplitFunctionNameInfo(mlir::ArrayRef< mlir::Type > origTypes, GetNameAttrFn &&getNameAttr, GetSplitSuffixesFn &&getSplitSuffixes)
Collect function arg/result names and split suffixes from a list of original types.
std::unique_ptr< SpecializedSROA< AllocOpTy > > createSpecializedSROAPass()
Cached function arg/result names and split suffixes used while rewriting a function signature.
llvm::SmallVector< std::optional< llvm::StringRef > > originalNames
llvm::SmallVector< llvm::StringRef > existingNames
llvm::SmallVector< llvm::SmallVector< std::string > > splitNameSuffixes