LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
ArrayToScalarPass.cpp
Go to the documentation of this file.
1//===-- ArrayToScalarPass.cpp -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
17/// 2. Run a dialect conversion that does the following:
18///
19/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1
20/// so they instead perform scalar reads and writes from the new members. The transformation is
21/// local to the current op. Therefore, when replacing the `MemberReadOp` a new array is
22/// created locally and all uses of the `MemberReadOp` are replaced with the new array Value,
23/// then each scalar member read is followed by scalar write into the new array. Similarly,
24/// when replacing a `MemberWriteOp`, each element in the array operand needs a scalar read
25/// from the array followed by a scalar write to the new member. Making only local changes
26/// keeps this step simple and later steps will optimize.
27///
28/// - Replace `ArrayLengthOp` with the constant size of the selected dimension.
29///
30/// - Remove element initialization from `CreateArrayOp` and instead insert a list of
31/// `WriteArrayOp` immediately following.
32///
33/// - Desugar `InsertArrayOp` and `ExtractArrayOp` into their element-wise scalar reads/writes.
34///
35/// - Split arrays to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
36/// create/read/write ops so the changes are as local as possible (just as described for
37/// `MemberReadOp` and `MemberWriteOp`)
38///
/// 3. Run MLIR "sroa" pass to split each array with linear size `N` into `N` arrays of size 1 (to
/// prepare for "mem2reg" pass because its API does not allow for indexing to split aggregates).
41///
42/// 4. Run MLIR "mem2reg" pass to convert all of the size 1 array allocation and access into SSA
43/// values. This pass also runs several standard optimizations so the final result is condensed.
44///
45/// Note: This transformation imposes a "last write wins" semantics on array elements. If
46/// different/configurable semantics are added in the future, some additional transformation would
47/// be necessary before/during this pass so that multiple writes to the same index can be handled
48/// properly while they still exist.
49///
50/// Note: This transformation will introduce a `nondet` op when there exists a read from an array
51/// index that was not earlier written to.
52///
53//===----------------------------------------------------------------------===//
54
#include "llzk/Util/Compare.h"
#include "llzk/Util/Concepts.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>

#include <llvm/ADT/MapVector.h>
#include <llvm/Support/Debug.h>
72
73// Include the generated base pass class definitions.
74namespace llzk::array {
75#define GEN_PASS_DEF_ARRAYTOSCALARPASS
77} // namespace llzk::array
78
79using namespace mlir;
80using namespace llzk;
81using namespace llzk::array;
82using namespace llzk::component;
83using namespace llzk::function;
84
85#define DEBUG_TYPE "llzk-array-to-scalar"
86
87namespace {
88
90inline ArrayType splittableArray(ArrayType at) { return at.hasStaticShape() ? at : nullptr; }
91
93inline ArrayType splittableArray(Type t) {
94 if (ArrayType at = dyn_cast<ArrayType>(t)) {
95 return splittableArray(at);
96 } else {
97 return nullptr;
98 }
99}
100
102inline bool containsSplittableArrayType(Type t) {
103 return t
104 .walk([](ArrayType a) {
105 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
106 }).wasInterrupted();
107}
108
110template <typename T> bool containsSplittableArrayType(ValueTypeRange<T> types) {
111 for (Type t : types) {
112 if (containsSplittableArrayType(t)) {
113 return true;
114 }
115 }
116 return false;
117}
118
121size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
122 if (ArrayType at = splittableArray(t)) {
123 size_t size = llzk::checkedCast<size_t>(at.getNumElements());
124 collect.append(size, at.getElementType());
125 return size;
126 } else {
127 collect.push_back(t);
128 return 1;
129 }
130}
131
133template <typename TypeCollection>
134inline void splitArrayTypeTo(
135 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
136) {
137 for (Type t : types) {
138 size_t count = splitArrayTypeTo(t, collect);
139 if (originalIdxToSize) {
140 originalIdxToSize->push_back(count);
141 }
142 }
143}
144
147template <typename TypeCollection>
148inline SmallVector<Type>
149splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize = nullptr) {
150 SmallVector<Type> collect;
151 splitArrayTypeTo(types, collect, originalIdxToSize);
152 return collect;
153}
154
157SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
158 SmallVector<Value> operands;
159 for (Attribute a : index) {
160 // ASSERT: Attributes are index constants, created by ArrayType::getSubelementIndices().
161 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
162 assert(ia && ia.getType().isIndex());
163 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
164 }
165 return operands;
166}
167
168inline WriteArrayOp
169genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
170 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
171 return rewriter.create<WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
172}
173
177CallOp newCallOpWithSplitResults(
178 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
179) {
180 OpBuilder::InsertionGuard guard(rewriter);
181 rewriter.setInsertionPointAfter(oldCall);
182
183 Operation::result_range oldResults = oldCall.getResults();
184 CallOp newCall = rewriter.create<CallOp>(
185 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.getCallee(),
186 adaptor.getArgOperands()
187 );
188
189 auto newResults = newCall.getResults().begin();
190 for (Value oldVal : oldResults) {
191 if (ArrayType at = splittableArray(oldVal.getType())) {
192 Location loc = oldVal.getLoc();
193 // Generate `CreateArrayOp` and replace uses of the result with it.
194 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
195 rewriter.replaceAllUsesWith(oldVal, newArray);
196
197 // For all indices in the ArrayType (i.e., the element count), write the next
198 // result from the new CallOp to the new array.
199 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
200 assert(allIndices); // follows from legal() check
201 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
202 for (ArrayAttr subIdx : allIndices.value()) {
203 genWrite(loc, newArray, subIdx, *newResults, rewriter);
204 newResults++;
205 }
206 } else {
207 newResults++;
208 }
209 }
210 // erase the original CallOp
211 rewriter.eraseOp(oldCall);
212
213 return newCall;
214}
215
216inline ReadArrayOp
217genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
218 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
219 return rewriter.create<ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
220}
221
222// If the operand has ArrayType, add N reads from the array to the `newOperands` list otherwise add
223// the original operand to the list.
224void processInputOperand(
225 Location loc, Value operand, SmallVector<Value> &newOperands,
226 ConversionPatternRewriter &rewriter
227) {
228 if (ArrayType at = splittableArray(operand.getType())) {
229 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
230 assert(indices.has_value() && "passed earlier hasStaticShape() check");
231 for (ArrayAttr index : indices.value()) {
232 newOperands.push_back(genRead(loc, operand, index, rewriter));
233 }
234 } else {
235 newOperands.push_back(operand);
236 }
237}
238
239// For each operand with ArrayType, add N reads from the array and use those N values instead.
240void processInputOperands(
241 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
242 ConversionPatternRewriter &rewriter
243) {
244 SmallVector<Value> newOperands;
245 for (Value v : operands) {
246 processInputOperand(op->getLoc(), v, newOperands, rewriter);
247 }
248 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
249 outputOpRef.assign(ValueRange(newOperands));
250 });
251}
252
253namespace {
254
255enum Direction : std::uint8_t {
257 SMALL_TO_LARGE,
259 LARGE_TO_SMALL,
260};
261
264template <Direction dir>
265inline void rewriteImpl(
266 ArrayAccessOpInterface op, ArrayType smallType, Value smallArr, Value largeArr,
267 ConversionPatternRewriter &rewriter
268) {
269 assert(smallType); // follows from legal() check
270 Location loc = op.getLoc();
271 MLIRContext *ctx = op.getContext();
272
273 ArrayAttr indexAsAttr = op.indexOperandsToAttributeArray();
274 assert(indexAsAttr); // follows from legal() check
275
276 // For all indices in the ArrayType (i.e., the element count), read from one array into the other
277 // (depending on direction flag).
278 std::optional<SmallVector<ArrayAttr>> subIndices = smallType.getSubelementIndices();
279 assert(subIndices); // follows from legal() check
280 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
281 for (ArrayAttr indexingTail : subIndices.value()) {
282 SmallVector<Attribute> joined;
283 joined.append(indexAsAttr.begin(), indexAsAttr.end());
284 joined.append(indexingTail.begin(), indexingTail.end());
285 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
286
287 if constexpr (dir == Direction::SMALL_TO_LARGE) {
288 auto init = genRead(loc, smallArr, indexingTail, rewriter);
289 genWrite(loc, largeArr, fullIndex, init, rewriter);
290 } else if constexpr (dir == Direction::LARGE_TO_SMALL) {
291 auto init = genRead(loc, largeArr, fullIndex, rewriter);
292 genWrite(loc, smallArr, indexingTail, init, rewriter);
293 }
294 }
295}
296
297} // namespace
298
299class SplitInsertArrayOp : public OpConversionPattern<InsertArrayOp> {
300public:
301 using OpConversionPattern<InsertArrayOp>::OpConversionPattern;
302
303 static bool legal(InsertArrayOp op) {
304 return !containsSplittableArrayType(op.getRvalue().getType());
305 }
306
307 LogicalResult match(InsertArrayOp op) const override { return failure(legal(op)); }
308
309 void
310 rewrite(InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
311 ArrayType at = splittableArray(op.getRvalue().getType());
312 rewriteImpl<SMALL_TO_LARGE>(
313 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
314 adaptor.getArrRef(), rewriter
315 );
316 rewriter.eraseOp(op);
317 }
318};
319
320class SplitExtractArrayOp : public OpConversionPattern<ExtractArrayOp> {
321public:
322 using OpConversionPattern<ExtractArrayOp>::OpConversionPattern;
323
324 static bool legal(ExtractArrayOp op) {
325 return !containsSplittableArrayType(op.getResult().getType());
326 }
327
328 LogicalResult match(ExtractArrayOp op) const override { return failure(legal(op)); }
329
330 void rewrite(
331 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
332 ) const override {
333 ArrayType at = splittableArray(op.getResult().getType());
334 // Generate `CreateArrayOp` in place of the current op.
335 auto newArray = rewriter.replaceOpWithNewOp<CreateArrayOp>(op, at);
336 rewriteImpl<LARGE_TO_SMALL>(
337 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
338 rewriter
339 );
340 }
341};
342
343class SplitInitFromCreateArrayOp : public OpConversionPattern<CreateArrayOp> {
344public:
345 using OpConversionPattern<CreateArrayOp>::OpConversionPattern;
346
347 static bool legal(CreateArrayOp op) { return op.getElements().empty(); }
348
349 LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); }
350
351 void
352 rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
353 // Remove elements from `op`
354 rewriter.modifyOpInPlace(op, [&op]() { op.getElementsMutable().clear(); });
355 // Generate an individual write for each initialization element
356 rewriter.setInsertionPointAfter(op);
357 Location loc = op.getLoc();
358 ArrayIndexGen idxGen = ArrayIndexGen::from(op.getType());
359 for (auto [i, init] : llvm::enumerate(adaptor.getElements())) {
360 // Convert the linear index 'i' into a multi-dim index
361 std::optional<SmallVector<Value>> multiDimIdxVals =
362 idxGen.delinearize(llzk::checkedCast<int64_t>(i), loc, rewriter);
363 // ASSERT: CreateArrayOp verifier ensures the number of elements provided matches the full
364 // linear array size so delinearization of `i` will not fail.
365 assert(multiDimIdxVals.has_value());
366 // Create the write
367 rewriter.create<WriteArrayOp>(loc, op.getResult(), ValueRange(*multiDimIdxVals), init);
368 }
369 }
370};
371
372class SplitArrayInFuncDefOp : public OpConversionPattern<FuncDefOp> {
373public:
374 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
375
376 inline static bool legal(FuncDefOp op) {
377 return !containsSplittableArrayType(op.getFunctionType());
378 }
379
380 // Create a new ArrayAttr like the one given but with repetitions of the elements according to the
381 // mapping defined by `originalIdxToSize`. In other words, if `originalIdxToSize[i] = n`, then `n`
382 // copies of `origAttrs[i]` are appended in its place.
383 static ArrayAttr replicateAttributesAsNeeded(
384 ArrayAttr origAttrs, const SmallVector<size_t> &originalIdxToSize,
385 const SmallVector<Type> &newTypes
386 ) {
387 if (origAttrs) {
388 assert(originalIdxToSize.size() == origAttrs.size());
389 if (originalIdxToSize.size() != newTypes.size()) {
390 SmallVector<Attribute> newArgAttrs;
391 for (auto [i, s] : llvm::enumerate(originalIdxToSize)) {
392 newArgAttrs.append(s, origAttrs[i]);
393 }
394 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
395 }
396 }
397 return nullptr;
398 }
399
400 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
401
402 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
403 // Update in/out types of the function to replace arrays with scalars
404 class Impl : public FunctionTypeConverter {
405 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
406
407 protected:
408 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
409 return splitArrayType(origTypes, &originalInputIdxToSize);
410 }
411 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
412 return splitArrayType(origTypes, &originalResultIdxToSize);
413 }
414 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
415 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
416 }
417 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
418 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
419 }
420
425 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
426 OpBuilder::InsertionGuard guard(rewriter);
427 rewriter.setInsertionPointToStart(&entryBlock);
428
429 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
430 Value oldV = entryBlock.getArgument(i);
431 if (ArrayType at = splittableArray(oldV.getType())) {
432 Location loc = oldV.getLoc();
433 // Generate `CreateArrayOp` and replace uses of the argument with it.
434 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
435 rewriter.replaceAllUsesWith(oldV, newArray);
436 // Remove the argument from the block
437 entryBlock.eraseArgument(i);
438 // For all indices in the ArrayType (i.e., the element count), generate a new block
439 // argument and a write of that argument to the new array.
440 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
441 assert(allIndices); // follows from legal() check
442 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
443 for (ArrayAttr subIdx : allIndices.value()) {
444 BlockArgument newArg = entryBlock.insertArgument(i, at.getElementType(), loc);
445 genWrite(loc, newArray, subIdx, newArg, rewriter);
446 ++i;
447 }
448 } else {
449 ++i;
450 }
451 }
452 }
453 };
454 Impl().convert(op, rewriter);
455 }
456};
457
458class SplitArrayInReturnOp : public OpConversionPattern<ReturnOp> {
459public:
460 using OpConversionPattern<ReturnOp>::OpConversionPattern;
461
462 inline static bool legal(ReturnOp op) {
463 return !containsSplittableArrayType(op.getOperands().getTypes());
464 }
465
466 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
467
468 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
469 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
470 }
471};
472
473class SplitArrayInCallOp : public OpConversionPattern<CallOp> {
474public:
475 using OpConversionPattern<CallOp>::OpConversionPattern;
476
477 inline static bool legal(CallOp op) {
478 return !containsSplittableArrayType(op.getArgOperands().getTypes()) &&
479 !containsSplittableArrayType(op.getResultTypes());
480 }
481
482 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
483
484 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
485 assert(isNullOrEmpty(op.getMapOpGroupSizesAttr()) && "structs must be previously flattened");
486
487 // Create new CallOp with split results first so, then process its inputs to split types
488 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
489 processInputOperands(
490 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
491 );
492 }
493};
494
495class ReplaceKnownArrayLengthOp : public OpConversionPattern<ArrayLengthOp> {
496public:
497 using OpConversionPattern<ArrayLengthOp>::OpConversionPattern;
498
500 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx, ArrayType baseArrType) {
501 if (splittableArray(baseArrType)) {
502 llvm::APInt idxAP;
503 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
504 size_t idx = llzk::checkedCast<size_t>(idxAP.getZExtValue());
505 Attribute dimSizeAttr = baseArrType.getDimensionSizes()[idx];
506 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
507 return idxAP;
508 }
509 }
510 }
511 return std::nullopt;
512 }
513
514 inline static bool legal(ArrayLengthOp op) {
515 // rewrite() can only work with constant dim size, i.e., must consider it legal otherwise
516 return !getDimSizeIfKnown(op.getDim(), op.getArrRefType()).has_value();
517 }
518
519 LogicalResult match(ArrayLengthOp op) const override { return failure(legal(op)); }
520
521 void
522 rewrite(ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
523 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
524 assert(arrTy); // must have array type per ODS spec of ArrayLengthOp
525 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
526 assert(len.has_value()); // follows from legal() check
527 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, llzk::fromAPInt(len.value()));
528 }
529};
530
532using MemberInfo = std::pair<StringAttr, Type>;
534using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
536using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
537
538class SplitArrayInMemberDefOp : public OpConversionPattern<MemberDefOp> {
539 SymbolTableCollection &tables;
540 MemberReplacementMap &repMapRef;
541
542public:
543 SplitArrayInMemberDefOp(
544 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
545 )
546 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
547
548 inline static bool legal(MemberDefOp op) { return !containsSplittableArrayType(op.getType()); }
549
550 LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); }
551
552 void rewrite(MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
553 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
554 assert(inStruct);
555 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
556
557 ArrayType arrTy = dyn_cast<ArrayType>(op.getType());
558 assert(arrTy); // follows from legal() check
559 auto subIdxs = arrTy.getSubelementIndices();
560 assert(subIdxs.has_value());
561 Type elemTy = arrTy.getElementType();
562
563 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
564 for (ArrayAttr idx : subIdxs.value()) {
565 // Create scalar version of the member
566 MemberDefOp newMember = rewriter.create<MemberDefOp>(
567 op.getLoc(), op.getSymNameAttr(), elemTy, op.getColumn(), op.getSignal()
568 );
569 newMember.setPublicAttr(op.hasPublicAttr());
570 // Use SymbolTable to give it a unique name and store to the replacement map
571 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
572 }
573 rewriter.eraseOp(op);
574 }
575};
576
582template <
583 typename ImplClass, HasInterface<MemberRefOpInterface> MemberRefOpClass, typename GenHeaderType>
584class SplitArrayInMemberRefOp : public OpConversionPattern<MemberRefOpClass> {
585 SymbolTableCollection &tables;
586 const MemberReplacementMap &repMapRef;
587
588 // static check to ensure the functions are implemented in all subclasses
589 inline static void ensureImplementedAtCompile() {
590 static_assert(
591 sizeof(MemberRefOpClass) == 0, "SplitArrayInMemberRefOp not implemented for requested type."
592 );
593 }
594
595protected:
596 using OpAdaptor = typename MemberRefOpClass::Adaptor;
597
600 static GenHeaderType genHeader(MemberRefOpClass, ConversionPatternRewriter &) {
601 ensureImplementedAtCompile();
602 llvm_unreachable("must have concrete instantiation");
603 }
604
607 static void
608 forIndex(Location, GenHeaderType, ArrayAttr, MemberInfo, OpAdaptor, ConversionPatternRewriter &) {
609 ensureImplementedAtCompile();
610 llvm_unreachable("must have concrete instantiation");
611 }
612
613public:
614 // Suppress false positive from `clang-tidy`
615 // NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
616 SplitArrayInMemberRefOp(
617 MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap
618 )
619 : OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables), repMapRef(memberRepMap) {}
620
621 static bool legal(MemberRefOpClass) {
622 ensureImplementedAtCompile();
623 llvm_unreachable("must have concrete instantiation");
624 return false;
625 }
626
627 LogicalResult match(MemberRefOpClass op) const override { return failure(ImplClass::legal(op)); }
628
629 void rewrite(
630 MemberRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
631 ) const override {
632 StructType tgtStructTy = llvm::cast<MemberRefOpInterface>(op.getOperation()).getStructType();
633 assert(tgtStructTy);
634 auto tgtStructDef = tgtStructTy.getDefinition(tables, op);
635 assert(succeeded(tgtStructDef));
636
637 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
638
639 const LocalMemberReplacementMap &idxToName =
640 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
641 // Split the array member write into a series of read array + write scalar member
642 for (auto [idx, newMember] : idxToName) {
643 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newMember, adaptor, rewriter);
644 }
645 rewriter.eraseOp(op);
646 }
647};
648
649class SplitArrayInMemberWriteOp
650 : public SplitArrayInMemberRefOp<SplitArrayInMemberWriteOp, MemberWriteOp, void *> {
651public:
652 using SplitArrayInMemberRefOp<
653 SplitArrayInMemberWriteOp, MemberWriteOp, void *>::SplitArrayInMemberRefOp;
654
655 static bool legal(MemberWriteOp op) {
656 return !containsSplittableArrayType(op.getVal().getType());
657 }
658
659 static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; }
660
661 static void forIndex(
662 Location loc, void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
663 ConversionPatternRewriter &rewriter
664 ) {
665 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
666 rewriter.create<MemberWriteOp>(
667 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
668 );
669 }
670};
671
672class SplitArrayInMemberReadOp
673 : public SplitArrayInMemberRefOp<SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp> {
674public:
675 using SplitArrayInMemberRefOp<
676 SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp>::SplitArrayInMemberRefOp;
677
678 static bool legal(MemberReadOp op) {
679 return !containsSplittableArrayType(op.getResult().getType());
680 }
681
682 static CreateArrayOp genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) {
683 CreateArrayOp newArray =
684 rewriter.create<CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
685 rewriter.replaceAllUsesWith(op, newArray);
686 return newArray;
687 }
688
689 static void forIndex(
690 Location loc, CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter
692 ) {
693 MemberReadOp scalarRead = rewriter.create<MemberReadOp>(
694 loc, newMember.second, adaptor.getComponent(), newMember.first
695 );
696 genWrite(loc, newArray, idx, scalarRead, rewriter);
697 }
698};
699
700static void baseTargetSetup(ConversionTarget &target) {
701 target.addLegalDialect<
705 scf::SCFDialect>();
706 target.addLegalOp<ModuleOp>();
707}
708
709static LogicalResult
710step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
711 MLIRContext *ctx = modOp.getContext();
712
713 RewritePatternSet patterns(ctx);
714
715 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
716
717 ConversionTarget target(*ctx);
718 baseTargetSetup(target);
719 target.addDynamicallyLegalOp<MemberDefOp>(SplitArrayInMemberDefOp::legal);
720
721 LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split array members\n";);
722 return applyFullConversion(modOp, target, std::move(patterns));
723}
724
725static LogicalResult
726step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) {
727 MLIRContext *ctx = modOp.getContext();
728
729 RewritePatternSet patterns(ctx);
730 patterns.add<
731 // clang-format off
732 SplitInitFromCreateArrayOp,
733 SplitInsertArrayOp,
734 SplitExtractArrayOp,
735 SplitArrayInFuncDefOp,
736 SplitArrayInReturnOp,
737 SplitArrayInCallOp,
738 ReplaceKnownArrayLengthOp
739 // clang-format on
740 >(ctx);
741
742 patterns.add<
743 // clang-format off
744 SplitArrayInMemberWriteOp,
745 SplitArrayInMemberReadOp
746 // clang-format on
747 >(ctx, symTables, memberRepMap);
748
749 ConversionTarget target(*ctx);
750 baseTargetSetup(target);
751 target.addDynamicallyLegalOp<CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
752 target.addDynamicallyLegalOp<InsertArrayOp>(SplitInsertArrayOp::legal);
753 target.addDynamicallyLegalOp<ExtractArrayOp>(SplitExtractArrayOp::legal);
754 target.addDynamicallyLegalOp<FuncDefOp>(SplitArrayInFuncDefOp::legal);
755 target.addDynamicallyLegalOp<ReturnOp>(SplitArrayInReturnOp::legal);
756 target.addDynamicallyLegalOp<CallOp>(SplitArrayInCallOp::legal);
757 target.addDynamicallyLegalOp<ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
758 target.addDynamicallyLegalOp<MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
759 target.addDynamicallyLegalOp<MemberReadOp>(SplitArrayInMemberReadOp::legal);
760
761 LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other array ops\n";);
762 return applyFullConversion(modOp, target, std::move(patterns));
763}
764
765LogicalResult splitArrayCreateInit(ModuleOp modOp) {
766 SymbolTableCollection symTables;
767 MemberReplacementMap memberRepMap;
768
769 // This is divided into 2 steps to simplify the implementation for member-related ops. The issue
770 // is that the conversions for member read/write expect the mapping of array index to member
771 // name+type to already be populated for the referenced member (although this could be computed on
772 // demand if desired but it complicates the implementation a bit).
773 if (failed(step1(modOp, symTables, memberRepMap))) {
774 return failure();
775 }
776 LLVM_DEBUG({
777 llvm::dbgs() << "After step 1:\n";
778 modOp.dump();
779 });
780 if (failed(step2(modOp, symTables, memberRepMap))) {
781 return failure();
782 }
783 LLVM_DEBUG({
784 llvm::dbgs() << "After step 2:\n";
785 modOp.dump();
786 });
787 return success();
788}
789
790class ArrayToScalarPass : public llzk::array::impl::ArrayToScalarPassBase<ArrayToScalarPass> {
791 void runOnOperation() override {
792 ModuleOp module = getOperation();
793 // Separate array initialization from creation by removing the initialization list from
794 // CreateArrayOp and inserting the corresponding WriteArrayOp following it.
795 if (failed(splitArrayCreateInit(module))) {
796 signalPassFailure();
797 return;
798 }
799 OpPassManager nestedPM(ModuleOp::getOperationName());
800 // Use SROA (Destructurable* interfaces) to split each array with linear size N into N arrays of
801 // size 1. This is necessary because the mem2reg pass cannot deal with indexing and splitting up
802 // memory, i.e., it can only convert scalar memory access into SSA values.
803 nestedPM.addPass(createSROA());
804 // The mem2reg pass converts all of the size 1 array allocation and access into SSA values.
805 nestedPM.addPass(createMem2Reg());
806 // Cleanup SSA values made dead by the transformations
807 nestedPM.addPass(createRemoveDeadValuesPass());
808 if (failed(runPipeline(nestedPM, module))) {
809 signalPassFailure();
810 return;
811 }
812 LLVM_DEBUG({
813 llvm::dbgs() << "After SROA+Mem2Reg pipeline:\n";
814 module.dump();
815 });
816 }
817};
818
819} // namespace
820
822 return std::make_unique<ArrayToScalarPass>();
823};
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Definition Ops.cpp:213
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
Definition Ops.h.inc:192
::mlir::TypedValue<::mlir::IndexType > getDim()
Definition Ops.h.inc:150
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:113
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::MutableOperandRange getElementsMutable()
Definition Ops.cpp.inc:334
::mlir::Operation::operand_range getElements()
Definition Ops.h.inc:388
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:613
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
Definition Ops.h.inc:757
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:566
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:386
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:956
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:223
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
CallOpAdaptor Adaptor
Definition Ops.h.inc:205
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
Definition Ops.h.inc:307
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1169
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:979
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
constexpr T checkedCast(U u) noexcept
Definition Compare.h:81
int64_t fromAPInt(const llvm::APInt &i)