LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
ArrayToScalarPass.cpp
Go to the documentation of this file.
1//===-- ArrayToScalarPass.cpp -----------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
17/// 2. Run a dialect conversion that does the following:
18///
19/// - Replace `MemberReadOp` and `MemberWriteOp` targeting the members that were split in step 1
20/// so they instead perform scalar reads and writes from the new members. The transformation is
21/// local to the current op. Therefore, when replacing the `MemberReadOp` a new array is
22/// created locally and all uses of the `MemberReadOp` are replaced with the new array Value,
/// then each scalar member read is followed by a scalar write into the new array. Similarly,
24/// when replacing a `MemberWriteOp`, each element in the array operand needs a scalar read
25/// from the array followed by a scalar write to the new member. Making only local changes
26/// keeps this step simple and later steps will optimize.
27///
28/// - Replace `ArrayLengthOp` with the constant size of the selected dimension.
29///
30/// - Remove element initialization from `CreateArrayOp` and instead insert a list of
31/// `WriteArrayOp` immediately following.
32///
33/// - Desugar `InsertArrayOp` and `ExtractArrayOp` into their element-wise scalar reads/writes.
34///
35/// - Split arrays to scalars in `FuncDefOp`, `CallOp`, and `ReturnOp` and insert the necessary
36/// create/read/write ops so the changes are as local as possible (just as described for
37/// `MemberReadOp` and `MemberWriteOp`)
38///
/// 3. Run MLIR "sroa" pass to split each array with linear size `N` into `N` arrays of size 1 (to
/// prepare for "mem2reg" pass because its API does not allow for indexing to split aggregates).
41///
/// 4. Run MLIR "mem2reg" pass to convert all of the size 1 array allocations and accesses into SSA
/// values. Values left dead by these transformations are then removed so the final result is condensed.
44///
45/// Note: This transformation imposes a "last write wins" semantics on array elements. If
46/// different/configurable semantics are added in the future, some additional transformation would
47/// be necessary before/during this pass so that multiple writes to the same index can be handled
48/// properly while they still exist.
49///
50/// Note: This transformation will introduce a `nondet` op when there exists a read from an array
51/// index that was not earlier written to.
52///
53//===----------------------------------------------------------------------===//
54
#include "llzk/Util/Concepts.h"

#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/DialectConversion.h>
#include <mlir/Transforms/Passes.h>

#include <llvm/Support/Debug.h>

#include <algorithm>
71
72// Include the generated base pass class definitions.
73namespace llzk::array {
74#define GEN_PASS_DEF_ARRAYTOSCALARPASS
76} // namespace llzk::array
77
78using namespace mlir;
79using namespace llzk;
80using namespace llzk::array;
81using namespace llzk::component;
82using namespace llzk::function;
83
84#define DEBUG_TYPE "llzk-array-to-scalar"
85
86namespace {
87
89inline ArrayType splittableArray(ArrayType at) { return at.hasStaticShape() ? at : nullptr; }
90
92inline ArrayType splittableArray(Type t) {
93 if (ArrayType at = dyn_cast<ArrayType>(t)) {
94 return splittableArray(at);
95 } else {
96 return nullptr;
97 }
98}
99
101inline bool containsSplittableArrayType(Type t) {
102 return t
103 .walk([](ArrayType a) {
104 return splittableArray(a) ? WalkResult::interrupt() : WalkResult::skip();
105 }).wasInterrupted();
106}
107
109template <typename T> bool containsSplittableArrayType(ValueTypeRange<T> types) {
110 for (Type t : types) {
111 if (containsSplittableArrayType(t)) {
112 return true;
113 }
114 }
115 return false;
116}
117
120size_t splitArrayTypeTo(Type t, SmallVector<Type> &collect) {
121 if (ArrayType at = splittableArray(t)) {
122 int64_t n = at.getNumElements();
123 assert(n >= 0);
124 assert(std::cmp_less_equal(n, std::numeric_limits<size_t>::max()));
125 size_t size = n;
126 collect.append(size, at.getElementType());
127 return size;
128 } else {
129 collect.push_back(t);
130 return 1;
131 }
132}
133
135template <typename TypeCollection>
136inline void splitArrayTypeTo(
137 TypeCollection types, SmallVector<Type> &collect, SmallVector<size_t> *originalIdxToSize
138) {
139 for (Type t : types) {
140 size_t count = splitArrayTypeTo(t, collect);
141 if (originalIdxToSize) {
142 originalIdxToSize->push_back(count);
143 }
144 }
145}
146
149template <typename TypeCollection>
150inline SmallVector<Type>
151splitArrayType(TypeCollection types, SmallVector<size_t> *originalIdxToSize = nullptr) {
152 SmallVector<Type> collect;
153 splitArrayTypeTo(types, collect, originalIdxToSize);
154 return collect;
155}
156
159SmallVector<Value> genIndexConstants(ArrayAttr index, Location loc, RewriterBase &rewriter) {
160 SmallVector<Value> operands;
161 for (Attribute a : index) {
162 // ASSERT: Attributes are index constants, created by ArrayType::getSubelementIndices().
163 IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(a);
164 assert(ia && ia.getType().isIndex());
165 operands.push_back(rewriter.create<arith::ConstantOp>(loc, ia));
166 }
167 return operands;
168}
169
170inline WriteArrayOp
171genWrite(Location loc, Value baseArrayOp, ArrayAttr index, Value init, RewriterBase &rewriter) {
172 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
173 return rewriter.create<WriteArrayOp>(loc, baseArrayOp, ValueRange(readOperands), init);
174}
175
179CallOp newCallOpWithSplitResults(
180 CallOp oldCall, CallOp::Adaptor adaptor, ConversionPatternRewriter &rewriter
181) {
182 OpBuilder::InsertionGuard guard(rewriter);
183 rewriter.setInsertionPointAfter(oldCall);
184
185 Operation::result_range oldResults = oldCall.getResults();
186 CallOp newCall = rewriter.create<CallOp>(
187 oldCall.getLoc(), splitArrayType(oldResults.getTypes()), oldCall.getCallee(),
188 adaptor.getArgOperands()
189 );
190
191 auto newResults = newCall.getResults().begin();
192 for (Value oldVal : oldResults) {
193 if (ArrayType at = splittableArray(oldVal.getType())) {
194 Location loc = oldVal.getLoc();
195 // Generate `CreateArrayOp` and replace uses of the result with it.
196 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
197 rewriter.replaceAllUsesWith(oldVal, newArray);
198
199 // For all indices in the ArrayType (i.e., the element count), write the next
200 // result from the new CallOp to the new array.
201 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
202 assert(allIndices); // follows from legal() check
203 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
204 for (ArrayAttr subIdx : allIndices.value()) {
205 genWrite(loc, newArray, subIdx, *newResults, rewriter);
206 newResults++;
207 }
208 } else {
209 newResults++;
210 }
211 }
212 // erase the original CallOp
213 rewriter.eraseOp(oldCall);
214
215 return newCall;
216}
217
218inline ReadArrayOp
219genRead(Location loc, Value baseArrayOp, ArrayAttr index, ConversionPatternRewriter &rewriter) {
220 SmallVector<Value> readOperands = genIndexConstants(index, loc, rewriter);
221 return rewriter.create<ReadArrayOp>(loc, baseArrayOp, ValueRange(readOperands));
222}
223
224// If the operand has ArrayType, add N reads from the array to the `newOperands` list otherwise add
225// the original operand to the list.
226void processInputOperand(
227 Location loc, Value operand, SmallVector<Value> &newOperands,
228 ConversionPatternRewriter &rewriter
229) {
230 if (ArrayType at = splittableArray(operand.getType())) {
231 std::optional<SmallVector<ArrayAttr>> indices = at.getSubelementIndices();
232 assert(indices.has_value() && "passed earlier hasStaticShape() check");
233 for (ArrayAttr index : indices.value()) {
234 newOperands.push_back(genRead(loc, operand, index, rewriter));
235 }
236 } else {
237 newOperands.push_back(operand);
238 }
239}
240
241// For each operand with ArrayType, add N reads from the array and use those N values instead.
242void processInputOperands(
243 ValueRange operands, MutableOperandRange outputOpRef, Operation *op,
244 ConversionPatternRewriter &rewriter
245) {
246 SmallVector<Value> newOperands;
247 for (Value v : operands) {
248 processInputOperand(op->getLoc(), v, newOperands, rewriter);
249 }
250 rewriter.modifyOpInPlace(op, [&outputOpRef, &newOperands]() {
251 outputOpRef.assign(ValueRange(newOperands));
252 });
253}
254
255namespace {
256
257enum Direction : std::uint8_t {
259 SMALL_TO_LARGE,
261 LARGE_TO_SMALL,
262};
263
266template <Direction dir>
267inline void rewriteImpl(
268 ArrayAccessOpInterface op, ArrayType smallType, Value smallArr, Value largeArr,
269 ConversionPatternRewriter &rewriter
270) {
271 assert(smallType); // follows from legal() check
272 Location loc = op.getLoc();
273 MLIRContext *ctx = op.getContext();
274
275 ArrayAttr indexAsAttr = op.indexOperandsToAttributeArray();
276 assert(indexAsAttr); // follows from legal() check
277
278 // For all indices in the ArrayType (i.e., the element count), read from one array into the other
279 // (depending on direction flag).
280 std::optional<SmallVector<ArrayAttr>> subIndices = smallType.getSubelementIndices();
281 assert(subIndices); // follows from legal() check
282 assert(std::cmp_equal(subIndices->size(), smallType.getNumElements()));
283 for (ArrayAttr indexingTail : subIndices.value()) {
284 SmallVector<Attribute> joined;
285 joined.append(indexAsAttr.begin(), indexAsAttr.end());
286 joined.append(indexingTail.begin(), indexingTail.end());
287 ArrayAttr fullIndex = ArrayAttr::get(ctx, joined);
288
289 if constexpr (dir == Direction::SMALL_TO_LARGE) {
290 auto init = genRead(loc, smallArr, indexingTail, rewriter);
291 genWrite(loc, largeArr, fullIndex, init, rewriter);
292 } else if constexpr (dir == Direction::LARGE_TO_SMALL) {
293 auto init = genRead(loc, largeArr, fullIndex, rewriter);
294 genWrite(loc, smallArr, indexingTail, init, rewriter);
295 }
296 }
297}
298
299} // namespace
300
301class SplitInsertArrayOp : public OpConversionPattern<InsertArrayOp> {
302public:
303 using OpConversionPattern<InsertArrayOp>::OpConversionPattern;
304
305 static bool legal(InsertArrayOp op) {
306 return !containsSplittableArrayType(op.getRvalue().getType());
307 }
308
309 LogicalResult match(InsertArrayOp op) const override { return failure(legal(op)); }
310
311 void
312 rewrite(InsertArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
313 ArrayType at = splittableArray(op.getRvalue().getType());
314 rewriteImpl<SMALL_TO_LARGE>(
315 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, adaptor.getRvalue(),
316 adaptor.getArrRef(), rewriter
317 );
318 rewriter.eraseOp(op);
319 }
320};
321
322class SplitExtractArrayOp : public OpConversionPattern<ExtractArrayOp> {
323public:
324 using OpConversionPattern<ExtractArrayOp>::OpConversionPattern;
325
326 static bool legal(ExtractArrayOp op) {
327 return !containsSplittableArrayType(op.getResult().getType());
328 }
329
330 LogicalResult match(ExtractArrayOp op) const override { return failure(legal(op)); }
331
332 void rewrite(
333 ExtractArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
334 ) const override {
335 ArrayType at = splittableArray(op.getResult().getType());
336 // Generate `CreateArrayOp` in place of the current op.
337 auto newArray = rewriter.replaceOpWithNewOp<CreateArrayOp>(op, at);
338 rewriteImpl<LARGE_TO_SMALL>(
339 llvm::cast<ArrayAccessOpInterface>(op.getOperation()), at, newArray, adaptor.getArrRef(),
340 rewriter
341 );
342 }
343};
344
345class SplitInitFromCreateArrayOp : public OpConversionPattern<CreateArrayOp> {
346public:
347 using OpConversionPattern<CreateArrayOp>::OpConversionPattern;
348
349 static bool legal(CreateArrayOp op) { return op.getElements().empty(); }
350
351 LogicalResult match(CreateArrayOp op) const override { return failure(legal(op)); }
352
353 void
354 rewrite(CreateArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
355 // Remove elements from `op`
356 rewriter.modifyOpInPlace(op, [&op]() { op.getElementsMutable().clear(); });
357 // Generate an individual write for each initialization element
358 rewriter.setInsertionPointAfter(op);
359 Location loc = op.getLoc();
360 ArrayIndexGen idxGen = ArrayIndexGen::from(op.getType());
361 for (auto [i, init] : llvm::enumerate(adaptor.getElements())) {
362 // Convert the linear index 'i' into a multi-dim index
363 assert(std::cmp_less_equal(i, std::numeric_limits<int64_t>::max()));
364 std::optional<SmallVector<Value>> multiDimIdxVals =
365 idxGen.delinearize(static_cast<int64_t>(i), loc, rewriter);
366 // ASSERT: CreateArrayOp verifier ensures the number of elements provided matches the full
367 // linear array size so delinearization of `i` will not fail.
368 assert(multiDimIdxVals.has_value());
369 // Create the write
370 rewriter.create<WriteArrayOp>(loc, op.getResult(), ValueRange(*multiDimIdxVals), init);
371 }
372 }
373};
374
375class SplitArrayInFuncDefOp : public OpConversionPattern<FuncDefOp> {
376public:
377 using OpConversionPattern<FuncDefOp>::OpConversionPattern;
378
379 inline static bool legal(FuncDefOp op) {
380 return !containsSplittableArrayType(op.getFunctionType());
381 }
382
383 // Create a new ArrayAttr like the one given but with repetitions of the elements according to the
384 // mapping defined by `originalIdxToSize`. In other words, if `originalIdxToSize[i] = n`, then `n`
385 // copies of `origAttrs[i]` are appended in its place.
386 static ArrayAttr replicateAttributesAsNeeded(
387 ArrayAttr origAttrs, const SmallVector<size_t> &originalIdxToSize,
388 const SmallVector<Type> &newTypes
389 ) {
390 if (origAttrs) {
391 assert(originalIdxToSize.size() == origAttrs.size());
392 if (originalIdxToSize.size() != newTypes.size()) {
393 SmallVector<Attribute> newArgAttrs;
394 for (auto [i, s] : llvm::enumerate(originalIdxToSize)) {
395 newArgAttrs.append(s, origAttrs[i]);
396 }
397 return ArrayAttr::get(origAttrs.getContext(), newArgAttrs);
398 }
399 }
400 return nullptr;
401 }
402
403 LogicalResult match(FuncDefOp op) const override { return failure(legal(op)); }
404
405 void rewrite(FuncDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
406 // Update in/out types of the function to replace arrays with scalars
407 class Impl : public FunctionTypeConverter {
408 SmallVector<size_t> originalInputIdxToSize, originalResultIdxToSize;
409
410 protected:
411 SmallVector<Type> convertInputs(ArrayRef<Type> origTypes) override {
412 return splitArrayType(origTypes, &originalInputIdxToSize);
413 }
414 SmallVector<Type> convertResults(ArrayRef<Type> origTypes) override {
415 return splitArrayType(origTypes, &originalResultIdxToSize);
416 }
417 ArrayAttr convertInputAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
418 return replicateAttributesAsNeeded(origAttrs, originalInputIdxToSize, newTypes);
419 }
420 ArrayAttr convertResultAttrs(ArrayAttr origAttrs, SmallVector<Type> newTypes) override {
421 return replicateAttributesAsNeeded(origAttrs, originalResultIdxToSize, newTypes);
422 }
423
428 void processBlockArgs(Block &entryBlock, RewriterBase &rewriter) override {
429 OpBuilder::InsertionGuard guard(rewriter);
430 rewriter.setInsertionPointToStart(&entryBlock);
431
432 for (unsigned i = 0; i < entryBlock.getNumArguments();) {
433 Value oldV = entryBlock.getArgument(i);
434 if (ArrayType at = splittableArray(oldV.getType())) {
435 Location loc = oldV.getLoc();
436 // Generate `CreateArrayOp` and replace uses of the argument with it.
437 auto newArray = rewriter.create<CreateArrayOp>(loc, at);
438 rewriter.replaceAllUsesWith(oldV, newArray);
439 // Remove the argument from the block
440 entryBlock.eraseArgument(i);
441 // For all indices in the ArrayType (i.e., the element count), generate a new block
442 // argument and a write of that argument to the new array.
443 std::optional<SmallVector<ArrayAttr>> allIndices = at.getSubelementIndices();
444 assert(allIndices); // follows from legal() check
445 assert(std::cmp_equal(allIndices->size(), at.getNumElements()));
446 for (ArrayAttr subIdx : allIndices.value()) {
447 BlockArgument newArg = entryBlock.insertArgument(i, at.getElementType(), loc);
448 genWrite(loc, newArray, subIdx, newArg, rewriter);
449 ++i;
450 }
451 } else {
452 ++i;
453 }
454 }
455 }
456 };
457 Impl().convert(op, rewriter);
458 }
459};
460
461class SplitArrayInReturnOp : public OpConversionPattern<ReturnOp> {
462public:
463 using OpConversionPattern<ReturnOp>::OpConversionPattern;
464
465 inline static bool legal(ReturnOp op) {
466 return !containsSplittableArrayType(op.getOperands().getTypes());
467 }
468
469 LogicalResult match(ReturnOp op) const override { return failure(legal(op)); }
470
471 void rewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
472 processInputOperands(adaptor.getOperands(), op.getOperandsMutable(), op, rewriter);
473 }
474};
475
476class SplitArrayInCallOp : public OpConversionPattern<CallOp> {
477public:
478 using OpConversionPattern<CallOp>::OpConversionPattern;
479
480 inline static bool legal(CallOp op) {
481 return !containsSplittableArrayType(op.getArgOperands().getTypes()) &&
482 !containsSplittableArrayType(op.getResultTypes());
483 }
484
485 LogicalResult match(CallOp op) const override { return failure(legal(op)); }
486
487 void rewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
488 assert(isNullOrEmpty(op.getMapOpGroupSizesAttr()) && "structs must be previously flattened");
489
490 // Create new CallOp with split results first so, then process its inputs to split types
491 CallOp newCall = newCallOpWithSplitResults(op, adaptor, rewriter);
492 processInputOperands(
493 newCall.getArgOperands(), newCall.getArgOperandsMutable(), newCall, rewriter
494 );
495 }
496};
497
498class ReplaceKnownArrayLengthOp : public OpConversionPattern<ArrayLengthOp> {
499public:
500 using OpConversionPattern<ArrayLengthOp>::OpConversionPattern;
501
503 static std::optional<llvm::APInt> getDimSizeIfKnown(Value dimIdx, ArrayType baseArrType) {
504 if (splittableArray(baseArrType)) {
505 llvm::APInt idxAP;
506 if (mlir::matchPattern(dimIdx, mlir::m_ConstantInt(&idxAP))) {
507 uint64_t idx64 = idxAP.getZExtValue();
508 assert(std::cmp_less_equal(idx64, std::numeric_limits<size_t>::max()));
509 Attribute dimSizeAttr = baseArrType.getDimensionSizes()[static_cast<size_t>(idx64)];
510 if (mlir::matchPattern(dimSizeAttr, mlir::m_ConstantInt(&idxAP))) {
511 return idxAP;
512 }
513 }
514 }
515 return std::nullopt;
516 }
517
518 inline static bool legal(ArrayLengthOp op) {
519 // rewrite() can only work with constant dim size, i.e., must consider it legal otherwise
520 return !getDimSizeIfKnown(op.getDim(), op.getArrRefType()).has_value();
521 }
522
523 LogicalResult match(ArrayLengthOp op) const override { return failure(legal(op)); }
524
525 void
526 rewrite(ArrayLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
527 ArrayType arrTy = dyn_cast<ArrayType>(adaptor.getArrRef().getType());
528 assert(arrTy); // must have array type per ODS spec of ArrayLengthOp
529 std::optional<llvm::APInt> len = getDimSizeIfKnown(adaptor.getDim(), arrTy);
530 assert(len.has_value()); // follows from legal() check
531 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, llzk::fromAPInt(len.value()));
532 }
533};
534
536using MemberInfo = std::pair<StringAttr, Type>;
538using LocalMemberReplacementMap = DenseMap<ArrayAttr, MemberInfo>;
540using MemberReplacementMap = DenseMap<StructDefOp, DenseMap<StringAttr, LocalMemberReplacementMap>>;
541
542class SplitArrayInMemberDefOp : public OpConversionPattern<MemberDefOp> {
543 SymbolTableCollection &tables;
544 MemberReplacementMap &repMapRef;
545
546public:
547 SplitArrayInMemberDefOp(
548 MLIRContext *ctx, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap
549 )
550 : OpConversionPattern<MemberDefOp>(ctx), tables(symTables), repMapRef(memberRepMap) {}
551
552 inline static bool legal(MemberDefOp op) { return !containsSplittableArrayType(op.getType()); }
553
554 LogicalResult match(MemberDefOp op) const override { return failure(legal(op)); }
555
556 void rewrite(MemberDefOp op, OpAdaptor, ConversionPatternRewriter &rewriter) const override {
557 StructDefOp inStruct = op->getParentOfType<StructDefOp>();
558 assert(inStruct);
559 LocalMemberReplacementMap &localRepMapRef = repMapRef[inStruct][op.getSymNameAttr()];
560
561 ArrayType arrTy = dyn_cast<ArrayType>(op.getType());
562 assert(arrTy); // follows from legal() check
563 auto subIdxs = arrTy.getSubelementIndices();
564 assert(subIdxs.has_value());
565 Type elemTy = arrTy.getElementType();
566
567 SymbolTable &structSymbolTable = tables.getSymbolTable(inStruct);
568 for (ArrayAttr idx : subIdxs.value()) {
569 // Create scalar version of the member
570 MemberDefOp newMember = rewriter.create<MemberDefOp>(
571 op.getLoc(), op.getSymNameAttr(), elemTy, op.getColumn(), op.getSignal()
572 );
573 newMember.setPublicAttr(op.hasPublicAttr());
574 // Use SymbolTable to give it a unique name and store to the replacement map
575 localRepMapRef[idx] = std::make_pair(structSymbolTable.insert(newMember), elemTy);
576 }
577 rewriter.eraseOp(op);
578 }
579};
580
586template <
587 typename ImplClass, HasInterface<MemberRefOpInterface> MemberRefOpClass, typename GenHeaderType>
588class SplitArrayInMemberRefOp : public OpConversionPattern<MemberRefOpClass> {
589 SymbolTableCollection &tables;
590 const MemberReplacementMap &repMapRef;
591
592 // static check to ensure the functions are implemented in all subclasses
593 inline static void ensureImplementedAtCompile() {
594 static_assert(
595 sizeof(MemberRefOpClass) == 0, "SplitArrayInMemberRefOp not implemented for requested type."
596 );
597 }
598
599protected:
600 using OpAdaptor = typename MemberRefOpClass::Adaptor;
601
604 static GenHeaderType genHeader(MemberRefOpClass, ConversionPatternRewriter &) {
605 ensureImplementedAtCompile();
606 llvm_unreachable("must have concrete instantiation");
607 }
608
611 static void
612 forIndex(Location, GenHeaderType, ArrayAttr, MemberInfo, OpAdaptor, ConversionPatternRewriter &) {
613 ensureImplementedAtCompile();
614 llvm_unreachable("must have concrete instantiation");
615 }
616
617public:
618 SplitArrayInMemberRefOp(
619 MLIRContext *ctx, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap
620 )
621 : OpConversionPattern<MemberRefOpClass>(ctx), tables(symTables), repMapRef(memberRepMap) {}
622
623 static bool legal(MemberRefOpClass) {
624 ensureImplementedAtCompile();
625 llvm_unreachable("must have concrete instantiation");
626 return false;
627 }
628
629 LogicalResult match(MemberRefOpClass op) const override { return failure(ImplClass::legal(op)); }
630
631 void rewrite(
632 MemberRefOpClass op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
633 ) const override {
634 StructType tgtStructTy = llvm::cast<MemberRefOpInterface>(op.getOperation()).getStructType();
635 assert(tgtStructTy);
636 auto tgtStructDef = tgtStructTy.getDefinition(tables, op);
637 assert(succeeded(tgtStructDef));
638
639 GenHeaderType prefixResult = ImplClass::genHeader(op, rewriter);
640
641 const LocalMemberReplacementMap &idxToName =
642 repMapRef.at(tgtStructDef->get()).at(op.getMemberNameAttr().getAttr());
643 // Split the array member write into a series of read array + write scalar member
644 for (auto [idx, newMember] : idxToName) {
645 ImplClass::forIndex(op.getLoc(), prefixResult, idx, newMember, adaptor, rewriter);
646 }
647 rewriter.eraseOp(op);
648 }
649};
650
651class SplitArrayInMemberWriteOp
652 : public SplitArrayInMemberRefOp<SplitArrayInMemberWriteOp, MemberWriteOp, void *> {
653public:
654 using SplitArrayInMemberRefOp<
655 SplitArrayInMemberWriteOp, MemberWriteOp, void *>::SplitArrayInMemberRefOp;
656
657 static bool legal(MemberWriteOp op) {
658 return !containsSplittableArrayType(op.getVal().getType());
659 }
660
661 static void *genHeader(MemberWriteOp, ConversionPatternRewriter &) { return nullptr; }
662
663 static void forIndex(
664 Location loc, void *, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
665 ConversionPatternRewriter &rewriter
666 ) {
667 ReadArrayOp scalarRead = genRead(loc, adaptor.getVal(), idx, rewriter);
668 rewriter.create<MemberWriteOp>(
669 loc, adaptor.getComponent(), FlatSymbolRefAttr::get(newMember.first), scalarRead
670 );
671 }
672};
673
674class SplitArrayInMemberReadOp
675 : public SplitArrayInMemberRefOp<SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp> {
676public:
677 using SplitArrayInMemberRefOp<
678 SplitArrayInMemberReadOp, MemberReadOp, CreateArrayOp>::SplitArrayInMemberRefOp;
679
680 static bool legal(MemberReadOp op) {
681 return !containsSplittableArrayType(op.getResult().getType());
682 }
683
684 static CreateArrayOp genHeader(MemberReadOp op, ConversionPatternRewriter &rewriter) {
685 CreateArrayOp newArray =
686 rewriter.create<CreateArrayOp>(op.getLoc(), llvm::cast<ArrayType>(op.getType()));
687 rewriter.replaceAllUsesWith(op, newArray);
688 return newArray;
689 }
690
691 static void forIndex(
692 Location loc, CreateArrayOp newArray, ArrayAttr idx, MemberInfo newMember, OpAdaptor adaptor,
693 ConversionPatternRewriter &rewriter
694 ) {
695 MemberReadOp scalarRead = rewriter.create<MemberReadOp>(
696 loc, newMember.second, adaptor.getComponent(), newMember.first
697 );
698 genWrite(loc, newArray, idx, scalarRead, rewriter);
699 }
700};
701
702LogicalResult
703step1(ModuleOp modOp, SymbolTableCollection &symTables, MemberReplacementMap &memberRepMap) {
704 MLIRContext *ctx = modOp.getContext();
705
706 RewritePatternSet patterns(ctx);
707
708 patterns.add<SplitArrayInMemberDefOp>(ctx, symTables, memberRepMap);
709
710 ConversionTarget target(*ctx);
711 target.addLegalDialect<
715 scf::SCFDialect>();
716 target.addLegalOp<ModuleOp>();
717 target.addDynamicallyLegalOp<MemberDefOp>(SplitArrayInMemberDefOp::legal);
718
719 LLVM_DEBUG(llvm::dbgs() << "Begin step 1: split array members\n";);
720 return applyFullConversion(modOp, target, std::move(patterns));
721}
722
723LogicalResult
724step2(ModuleOp modOp, SymbolTableCollection &symTables, const MemberReplacementMap &memberRepMap) {
725 MLIRContext *ctx = modOp.getContext();
726
727 RewritePatternSet patterns(ctx);
728 patterns.add<
729 // clang-format off
730 SplitInitFromCreateArrayOp,
731 SplitInsertArrayOp,
732 SplitExtractArrayOp,
733 SplitArrayInFuncDefOp,
734 SplitArrayInReturnOp,
735 SplitArrayInCallOp,
736 ReplaceKnownArrayLengthOp
737 // clang-format on
738 >(ctx);
739
740 patterns.add<
741 // clang-format off
742 SplitArrayInMemberWriteOp,
743 SplitArrayInMemberReadOp
744 // clang-format on
745 >(ctx, symTables, memberRepMap);
746
747 ConversionTarget target(*ctx);
748 target.addLegalDialect<
751 global::GlobalDialect, include::IncludeDialect, arith::ArithDialect, scf::SCFDialect>();
752 target.addLegalOp<ModuleOp>();
753 target.addDynamicallyLegalOp<CreateArrayOp>(SplitInitFromCreateArrayOp::legal);
754 target.addDynamicallyLegalOp<InsertArrayOp>(SplitInsertArrayOp::legal);
755 target.addDynamicallyLegalOp<ExtractArrayOp>(SplitExtractArrayOp::legal);
756 target.addDynamicallyLegalOp<FuncDefOp>(SplitArrayInFuncDefOp::legal);
757 target.addDynamicallyLegalOp<ReturnOp>(SplitArrayInReturnOp::legal);
758 target.addDynamicallyLegalOp<CallOp>(SplitArrayInCallOp::legal);
759 target.addDynamicallyLegalOp<ArrayLengthOp>(ReplaceKnownArrayLengthOp::legal);
760 target.addDynamicallyLegalOp<MemberWriteOp>(SplitArrayInMemberWriteOp::legal);
761 target.addDynamicallyLegalOp<MemberReadOp>(SplitArrayInMemberReadOp::legal);
762
763 LLVM_DEBUG(llvm::dbgs() << "Begin step 2: update/split other array ops\n";);
764 return applyFullConversion(modOp, target, std::move(patterns));
765}
766
767LogicalResult splitArrayCreateInit(ModuleOp modOp) {
768 SymbolTableCollection symTables;
769 MemberReplacementMap memberRepMap;
770
771 // This is divided into 2 steps to simplify the implementation for member-related ops. The issue
772 // is that the conversions for member read/write expect the mapping of array index to member
773 // name+type to already be populated for the referenced member (although this could be computed on
774 // demand if desired but it complicates the implementation a bit).
775 if (failed(step1(modOp, symTables, memberRepMap))) {
776 return failure();
777 }
778 LLVM_DEBUG({
779 llvm::dbgs() << "After step 1:\n";
780 modOp.dump();
781 });
782 if (failed(step2(modOp, symTables, memberRepMap))) {
783 return failure();
784 }
785 LLVM_DEBUG({
786 llvm::dbgs() << "After step 2:\n";
787 modOp.dump();
788 });
789 return success();
790}
791
792class ArrayToScalarPass : public llzk::array::impl::ArrayToScalarPassBase<ArrayToScalarPass> {
793 void runOnOperation() override {
794 ModuleOp module = getOperation();
795 // Separate array initialization from creation by removing the initialization list from
796 // CreateArrayOp and inserting the corresponding WriteArrayOp following it.
797 if (failed(splitArrayCreateInit(module))) {
798 signalPassFailure();
799 return;
800 }
801 OpPassManager nestedPM(ModuleOp::getOperationName());
802 // Use SROA (Destructurable* interfaces) to split each array with linear size N into N arrays of
803 // size 1. This is necessary because the mem2reg pass cannot deal with indexing and splitting up
804 // memory, i.e., it can only convert scalar memory access into SSA values.
805 nestedPM.addPass(createSROA());
806 // The mem2reg pass converts all of the size 1 array allocation and access into SSA values.
807 nestedPM.addPass(createMem2Reg());
808 // Cleanup SSA values made dead by the transformations
809 nestedPM.addPass(createRemoveDeadValuesPass());
810 if (failed(runPipeline(nestedPM, module))) {
811 signalPassFailure();
812 return;
813 }
814 LLVM_DEBUG({
815 llvm::dbgs() << "After SROA+Mem2Reg pipeline:\n";
816 module.dump();
817 });
818 }
819};
820
821} // namespace
822
824 return std::make_unique<ArrayToScalarPass>();
825};
General helper for converting a FuncDefOp by changing its input and/or result types and the associate...
::mlir::ArrayAttr indexOperandsToAttributeArray()
Returns the multi-dimensional indices of the array access as an Attribute array or a null pointer if ...
Definition Ops.cpp:211
Helper for converting between linear and multi-dimensional indexing with checks to ensure indices are...
static ArrayIndexGen from(ArrayType)
Construct new ArrayIndexGen. Will assert if hasStaticShape() is false.
std::optional< llvm::SmallVector< mlir::Value > > delinearize(int64_t, mlir::Location, mlir::OpBuilder &) const
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced base array.
Definition Ops.h.inc:192
::mlir::TypedValue<::mlir::IndexType > getDim()
Definition Ops.h.inc:150
std::optional<::llvm::SmallVector<::mlir::ArrayAttr > > getSubelementIndices() const
Return a list of all valid indices for this ArrayType.
Definition Types.cpp:111
::mlir::Type getElementType() const
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::MutableOperandRange getElementsMutable()
Definition Ops.cpp.inc:334
::mlir::Operation::operand_range getElements()
Definition Ops.h.inc:388
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:613
::mlir::TypedValue<::llzk::array::ArrayType > getRvalue()
Definition Ops.h.inc:757
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:504
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:386
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:955
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:46
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::MutableOperandRange getArgOperandsMutable()
Definition Ops.cpp.inc:200
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
CallOpAdaptor Adaptor
Definition Ops.h.inc:188
::mlir::DenseI32ArrayAttr getMapOpGroupSizesAttr()
Definition Ops.h.inc:277
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
::mlir::MutableOperandRange getOperandsMutable()
Definition Ops.cpp.inc:1137
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:904
Restricts a template parameter to Op classes that implement the given OpInterface.
Definition Concepts.h:20
std::unique_ptr< mlir::Pass > createArrayToScalarPass()
bool isNullOrEmpty(mlir::ArrayAttr a)
int64_t fromAPInt(const llvm::APInt &i)