LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
28#include "llzk/Util/Concepts.h"
29#include "llzk/Util/Debug.h"
34
35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51
52#include <llvm/ADT/APInt.h>
53#include <llvm/ADT/DenseMap.h>
54#include <llvm/ADT/DepthFirstIterator.h>
55#include <llvm/ADT/STLExtras.h>
56#include <llvm/ADT/SmallVector.h>
57#include <llvm/ADT/TypeSwitch.h>
58#include <llvm/Support/Debug.h>
59
60// Include the generated base pass class definitions.
61namespace llzk::polymorphic {
62#define GEN_PASS_DECL_FLATTENINGPASS
63#define GEN_PASS_DEF_FLATTENINGPASS
65} // namespace llzk::polymorphic
66
67#include "SharedImpl.h"
68
69#define DEBUG_TYPE "llzk-flatten"
70
71using namespace mlir;
72using namespace llzk;
73using namespace llzk::array;
74using namespace llzk::component;
75using namespace llzk::constrain;
76using namespace llzk::felt;
77using namespace llzk::function;
78using namespace llzk::polymorphic;
79using namespace llzk::polymorphic::detail;
80
81namespace {
82
83class ConversionTracker {
85 bool modified;
88 DenseMap<StructType, StructType> structInstantiations;
89 /// Contains the reverse of mappings in `structInstantiations` for use in legal conversion check.
90 DenseMap<StructType, StructType> reverseInstantiations;
93 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
94
95public:
96 bool isModified() const { return modified; }
97 void resetModifiedFlag() { modified = false; }
98 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
100 void recordInstantiation(StructType oldType, StructType newType) {
101 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
102
103 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
104 if (forwardResult.second) {
105 // Insertion was successful
106 // ASSERT: The reverse map does not contain this mapping either
107 assert(!reverseInstantiations.contains(newType));
108 reverseInstantiations[newType] = oldType;
109 // Set the modified flag
110 modified = true;
111 } else {
112 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
113 assert(forwardResult.first->getSecond() == newType);
114 // ASSERT: The reverse mapping is already present as well
115 assert(reverseInstantiations.lookup(newType) == oldType);
117 assert(structInstantiations.size() == reverseInstantiations.size());
118 }
121 std::optional<StructType> getInstantiation(StructType oldType) const {
122 auto cachedResult = structInstantiations.find(oldType);
123 if (cachedResult != structInstantiations.end()) {
124 return cachedResult->second;
126 return std::nullopt;
127 }
128
130 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
131 DenseSet<SymbolRefAttr> instantiatedNames;
132 for (const auto &[origRemoteTy, _] : structInstantiations) {
133 instantiatedNames.insert(origRemoteTy.getNameRef());
134 }
135 return instantiatedNames;
136 }
137
138 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
139 auto res = delayedDiagnostics.find(newType);
140 if (res == delayedDiagnostics.end()) {
141 return;
142 }
143
144 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
145 for (Diagnostic &diag : res->second) {
146 // Update any notes referencing an UnknownLoc to use the CallOp location.
147 for (Diagnostic &note : diag.getNotes()) {
148 assert(note.getNotes().empty() && "notes cannot have notes attached");
149 if (llvm::isa<UnknownLoc>(note.getLocation())) {
150 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
151 }
152 }
153 // Report. Based on InFlightDiagnostic::report().
154 engine.emit(std::move(diag));
155 }
156 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
157 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
158 // `compute()` calls it will only be reported at one of those locations, not all.
159 delayedDiagnostics.erase(newType);
160 }
162 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
163 return delayedDiagnostics[newType];
164 }
165
168 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
169 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
170 // Check if `oTy` is a struct with a known instantiation to `nTy`
171 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
172 // Note: The values in `structInstantiations` must be no-parameter struct types
173 // so there is no need for recursive check, simple equality is sufficient.
174 if (this->structInstantiations.lookup(oldStructType) == nTy) {
175 return true;
176 }
177 }
178 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
179 // that instantiation (i.e., the parameterized version of the instantiated struct)
180 // is a more concrete unification of `oTy`.
181 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
182 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
183 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
184 return true;
185 }
186 }
187 }
188 return false;
189 };
190
191 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
192 return true;
193 }
194 LLVM_DEBUG(
195 llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
196 << " with new type " << newType
197 << " because it does not define a compatible and more concrete type.\n";
198 );
199 return false;
200 }
201
202 template <typename T, typename U>
203 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
204 return llvm::all_of(
205 llvm::zip_equal(oldTypes, newTypes), [this, &patName](std::tuple<Type, Type> oldThenNew) {
206 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
207 }
208 );
209 }
210};
211
214struct MatchFailureListener : public RewriterBase::Listener {
215 bool hadFailure = false;
216
217 ~MatchFailureListener() override {}
218
219 void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
220 hadFailure = true;
221
222 InFlightDiagnostic diag = emitError(loc);
223 reasonCallback(*diag.getUnderlyingDiagnostic());
224 diag.report();
225 }
226};
227
228static LogicalResult
229applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
230 bool currStepModified = false;
231 MatchFailureListener failureListener;
232 LogicalResult result = applyPatternsGreedily(
233 modOp->getRegion(0), std::move(patterns),
234 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
235 &currStepModified
236 );
237 tracker.updateModifiedFlag(currStepModified);
238 return failure(result.failed() || failureListener.hadFailure);
239}
240
241template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
242 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
243 return isConcreteType(tyAttr.getValue(), AllowStructParams);
244 }
245 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
246 return !isDynamic(intAttr);
247 }
248 return false;
249}
250
252
253static inline bool tableOffsetIsntSymbol(MemberReadOp op) {
254 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
255}
256
259class StructCloner {
260 ConversionTracker &tracker_;
261 ModuleOp rootMod;
262 SymbolTableCollection symTables;
263 bool reportMissing = true;
264
265 class MappedTypeConverter : public TypeConverter {
266 StructType origTy;
267 StructType newTy;
268 const DenseMap<Attribute, Attribute> &paramNameToValue;
269
270 inline Attribute convertIfPossible(Attribute a) const {
271 auto res = this->paramNameToValue.find(a);
272 return (res != this->paramNameToValue.end()) ? res->second : a;
273 }
274
275 public:
276 MappedTypeConverter(
277 StructType originalType, StructType newType,
279 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
280 )
281 : TypeConverter(), origTy(originalType), newTy(newType),
282 paramNameToValue(paramNameToInstantiatedValue) {
283
284 addConversion([](Type inputTy) { return inputTy; });
285
286 addConversion([this](StructType inputTy) {
287 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
288
289 // Check for replacement of the full type
290 if (inputTy == this->origTy) {
291 return this->newTy;
292 }
293 // Check for replacement of parameter symbol names with concrete values
294 if (ArrayAttr inputTyParams = inputTy.getParams()) {
295 SmallVector<Attribute> updated;
296 for (Attribute a : inputTyParams) {
297 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
298 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
299 } else {
300 updated.push_back(convertIfPossible(a));
301 }
302 }
303 return StructType::get(
304 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
305 );
306 }
307 // Otherwise, return the type unchanged
308 return inputTy;
309 });
310
311 addConversion([this](ArrayType inputTy) {
312 // Check for replacement of parameter symbol names with concrete values
313 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
314 if (!dimSizes.empty()) {
315 SmallVector<Attribute> updated;
316 for (Attribute a : dimSizes) {
317 updated.push_back(convertIfPossible(a));
318 }
319 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
320 }
321 // Otherwise, return the type unchanged
322 return inputTy;
323 });
324
325 addConversion([this](TypeVarType inputTy) -> Type {
326 // Check for replacement of parameter symbol name with a concrete type
327 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
328 Type convertedType = tyAttr.getValue();
329 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
330 // different struct references a parameter name from that other struct, not from the
331 // current struct so the reference would be invalid.
332 if (isConcreteType(convertedType)) {
333 return convertedType;
334 }
335 }
336 return inputTy;
337 });
338 }
339 };
340
341 template <typename Impl, typename Op, typename... HandledAttrs>
342 class SymbolUserHelper : public OpConversionPattern<Op> {
343 private:
344 const DenseMap<Attribute, Attribute> &paramNameToValue;
345
346 SymbolUserHelper(
347 TypeConverter &converter, MLIRContext *ctx, unsigned Benefit,
348 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
349 )
350 : OpConversionPattern<Op>(converter, ctx, Benefit),
351 paramNameToValue(paramNameToInstantiatedValue) {}
352
353 public:
354 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
355
356 virtual Attribute getNameAttr(Op) const = 0;
357
358 virtual LogicalResult handleDefaultRewrite(
359 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
360 ) const {
361 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
362 }
363
364 LogicalResult
365 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
366 LLVM_DEBUG(llvm::dbgs() << "[SymbolUserHelper] op: " << op << '\n');
367 auto res = this->paramNameToValue.find(getNameAttr(op));
368 if (res == this->paramNameToValue.end()) {
369 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] no instantiation for " << op << '\n');
370 return failure();
371 }
372 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
373 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
374
375 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
376 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
377 }))),
378 ...);
379
380 return TS.Default([&](Attribute a) {
381 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
382 });
383 }
384 friend Impl;
385 };
386
// Conversion pattern for ConstReadOp. The op's const name attribute is the lookup key (see
// getNameAttr) and the mapped value, an IntegerAttr or a FeltConstAttr, determines the
// replacement. Warnings produced along the way are appended to `diagnostics` rather than being
// emitted immediately.
387 class ClonedStructConstReadOpPattern
388 : public SymbolUserHelper<
389 ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
// Caller-owned list that collects the warnings built in handleRewrite().
390 SmallVector<Diagnostic> &diagnostics;
391
392 using super =
393 SymbolUserHelper<ClonedStructConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
394
395 public:
396 ClonedStructConstReadOpPattern(
397 TypeConverter &converter, MLIRContext *ctx,
398 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
399 SmallVector<Diagnostic> &instantiationDiagnostics
400 )
401 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
402 // instead of the GeneralTypeReplacePattern<ConstReadOp> from newGeneralRewritePatternSet().
403 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue),
404 diagnostics(instantiationDiagnostics) {}
405
406 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
407
// IntegerAttr case: the replacement depends on the original result type of the op
// (felt, index, or i1); any other result type is reported as an op error.
408 LogicalResult handleRewrite(
409 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
410 ) const {
411 APInt attrValue = a.getValue();
412 Type origResTy = op.getType();
413 if (llvm::isa<FeltType>(origResTy)) {
// NOTE(review): listing line 414 is absent from this view; presumably it opens the replacement
// call that receives the arguments below -- confirm against the original file.
415 rewriter, op, FeltConstAttr::get(getContext(), attrValue)
416 );
417 return success();
418 }
419
420 if (llvm::isa<IndexType>(origResTy)) {
// NOTE(review): listing line 421 is absent from this view; presumably it replaces the op with
// an index constant -- confirm against the original file.
422 return success();
423 }
424
425 if (origResTy.isSignlessInteger(1)) {
426 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
427 if (attrValue.isZero()) {
428 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
429 return success();
430 }
431 if (!attrValue.isOne()) {
432 Location opLoc = op.getLoc();
433 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
434 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
435 if (getContext()->shouldPrintOpOnDiagnostic()) {
436 diag.attachNote(opLoc) << "see current operation: " << *op;
437 }
// This note is attached with an UnknownLoc because no call location is available here.
438 diag.attachNote(UnknownLoc::get(getContext()))
439 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \""
440 << sym << "\" for this call";
// Stored instead of emitted; the owner of `diagnostics` decides when to report it.
441 diagnostics.push_back(std::move(diag));
442 }
443 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
444 return success();
445 }
446 return op->emitOpError().append("unexpected result type ", origResTy);
447 }
448
// FeltConstAttr case: the op is replaced by a FeltConstantOp built from the attribute.
449 LogicalResult handleRewrite(
450 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
451 ) const {
452 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
453 return success();
454 }
455 };
456
457 class ClonedStructMemberReadOpPattern
458 : public SymbolUserHelper<
459 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
460 using super =
461 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
462
463 public:
464 ClonedStructMemberReadOpPattern(
465 TypeConverter &converter, MLIRContext *ctx,
466 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
467 )
468 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
469 // instead of the GeneralTypeReplacePattern<MemberReadOp> from
470 // newGeneralRewritePatternSet().
471 : super(converter, ctx, /*benefit=*/2, paramNameToInstantiatedValue) {}
472
473 Attribute getNameAttr(MemberReadOp op) const override {
474 return op.getTableOffset().value_or(nullptr);
475 }
476
477 template <typename Attr>
478 LogicalResult handleRewrite(
479 Attribute, MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
480 ) const {
481 rewriter.modifyOpInPlace(op, [&]() {
482 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
483 });
484
485 return success();
486 }
487
488 LogicalResult matchAndRewrite(
489 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
490 ) const override {
491 LLVM_DEBUG(
492 llvm::dbgs() << "[ClonedStructMemberReadOpPattern] MemberReadOp: " << op << '\n';
493 );
494 if (tableOffsetIsntSymbol(op)) {
495 return failure();
496 }
497
498 return super::matchAndRewrite(op, adaptor, rewriter);
499 }
500 };
501
// Creates a clone of the struct definition behind `typeAtCaller` in which every parameter that
// has a concrete value in `typeAtCallerParams` is replaced by that value.
//
// typeAtCaller:       the struct type as written at the use site.
// typeAtCallerParams: the parameter attributes of `typeAtCaller`; asserted to have the same
//                     length as the parameter list of the struct definition.
// Returns the type that refers to the new clone through the same symbol path as
// `typeAtCaller`, or failure when the definition cannot be found, no parameter is concrete,
// or the conversion of the cloned body fails.
502 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
503 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] attempting clone of " << typeAtCaller << '\n');
504 // Find the StructDefOp for the original StructType
505 FailureOr<SymbolLookupResult<StructDefOp>> r =
506 typeAtCaller.getDefinition(symTables, rootMod, reportMissing);
507 if (failed(r)) {
508 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
509 return failure(); // getDefinition() already emits a sufficient error message
510 }
511 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] found definition\n";);
512
513 StructDefOp origStruct = r->get();
514 StructType typeAtDef = origStruct.getType();
515 MLIRContext *ctx = origStruct.getContext();
516
517 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
518 DenseMap<Attribute, Attribute> paramNameToConcrete;
519 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
520 // where the original attribute from the current instantiation site was not concrete. This is
521 // used for generating the new struct name. See `BuildShortTypeString::from()`.
522 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
523 // Parameter list for the new StructDefOp containing the names that must be preserved because
524 // they were not assigned concrete values at the current instantiation site.
525 ArrayAttr reducedParamNameList = nullptr;
526 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
527 ArrayAttr reducedCallerParams = nullptr;
// The scope below partitions the parameters into concrete and non-concrete ones and fills the
// four variables declared above.
528 {
529 ArrayAttr paramNames = typeAtDef.getParams();
530
531 // pre-conditions
532 assert(!isNullOrEmpty(paramNames));
533 assert(paramNames.size() == typeAtCallerParams.size());
534
535 SmallVector<Attribute> remainingNames;
536 SmallVector<Attribute> nonConcreteParams;
537 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
538 Attribute next = typeAtCallerParams[i];
539 if (isConcreteAttr<false>(next)) {
540 paramNameToConcrete[paramNames[i]] = next;
541 attrsForInstantiatedNameSuffix.push_back(next);
542 } else {
543 remainingNames.push_back(paramNames[i]);
544 nonConcreteParams.push_back(next);
545 attrsForInstantiatedNameSuffix.push_back(nullptr);
546 }
547 }
548 // post-conditions
549 assert(remainingNames.size() == nonConcreteParams.size());
550 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
551 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
552
// Nothing can be instantiated when none of the parameters is concrete.
553 if (paramNameToConcrete.empty()) {
554 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
555 return failure();
556 }
// Both reduced lists stay null when every parameter was concrete.
557 if (!remainingNames.empty()) {
558 reducedParamNameList = ArrayAttr::get(ctx, remainingNames);
559 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
560 }
561 }
562
563 // Clone the original struct, apply the new name, and set the parameter list of the new struct
564 // to contain only those that did not have concrete instantiated values.
565 StructDefOp newStruct = origStruct.clone();
566 newStruct.setConstParamsAttr(reducedParamNameList);
567 newStruct.setSymName(
// NOTE(review): listing line 568 is absent from this view; presumably it opens the call that
// builds the new name from the two arguments below (see the `BuildShortTypeString::from()`
// mention above) -- confirm against the original file.
569 typeAtCaller.getNameRef().getLeafReference().str(), attrsForInstantiatedNameSuffix
570 )
571 );
572
573 // Insert 'newStruct' into the parent ModuleOp of the original StructDefOp. Use the
574 // `SymbolTable::insert()` function directly so that the name will be made unique.
575 ModuleOp parentModule = origStruct.getParentOp<ModuleOp>(); // parent is ModuleOp per ODS
576 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(origStruct));
577 // Retrieve the new type AFTER inserting since the name may be appended to make it unique and
578 // use the remaining non-concrete parameters from the original type.
579 StructType newLocalType = newStruct.getType(reducedCallerParams);
580 auto typeAtCallerSym = typeAtCaller.getNameRef();
581 // Copy the leafs of the type at the caller.
582 SmallVector<FlatSymbolRefAttr> newLeafs(typeAtCallerSym.getNestedReferences());
583 auto rootSym = typeAtCallerSym.getRootReference();
584 if (!newLeafs.empty()) {
585 // Replace the last one with the new name.
586 newLeafs.back() = FlatSymbolRefAttr::get(newLocalType.getNameRef().getLeafReference());
587 } else {
588 // If there's only one symbol then write the new name on the root.
589 rootSym = newLocalType.getNameRef().getLeafReference();
590 }
// The "remote" type keeps the symbol path used at the caller, with only the final component
// replaced by the name of the clone.
591 StructType newRemoteType =
592 StructType::get(SymbolRefAttr::get(rootSym, newLeafs), newLocalType.getParams());
593 LLVM_DEBUG({
594 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
595 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
596 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
597 llvm::dbgs() << "[StructCloner] cloned local type: " << newLocalType << '\n';
598 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
599 });
600
601 // Within the new struct, replace all references to the original StructType (i.e., the
602 // locally-parameterized version) with the new locally-parameterized StructType,
603 // and replace all uses of the removed struct parameters with the concrete values.
604 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
605 ConversionTarget target =
606 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
607 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
608 // Legal if it's not in the map of concrete attribute instantiations
609 return paramNameToConcrete.find(op.getConstNameAttr()) == paramNameToConcrete.end();
610 });
611
612 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
// Warnings raised while rewriting ConstReadOps go into the tracker's delayed list for the new
// local type.
613 patterns.add<ClonedStructConstReadOpPattern>(
614 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
615 );
616 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
617 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
618 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
619 return failure();
620 }
621 return newRemoteType;
622 }
623
624public:
625 StructCloner(ConversionTracker &tracker, ModuleOp root)
626 : tracker_(tracker), rootMod(root), symTables() {}
627
628 FailureOr<StructType> createInstantiatedClone(StructType orig) {
629 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
630 if (ArrayAttr params = orig.getParams()) {
631 return genClone(orig, params.getValue());
632 }
633 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
634 return failure();
635 }
636
637 void enableReportMissing() { reportMissing = true; }
638
639 void disableReportMissing() { reportMissing = false; }
640};
641
642class DisableReportMissing;
643
644class ParameterizedStructUseTypeConverter : public TypeConverter {
645 ConversionTracker &tracker_;
646 StructCloner cloner;
647
648 friend DisableReportMissing;
649
650public:
651 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
652 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
653
654 addConversion([](Type inputTy) { return inputTy; });
655
656 addConversion([this](StructType inputTy) -> StructType {
657 LLVM_DEBUG(
658 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] attempting conversion of "
659 << inputTy << '\n';
660 );
661 // First check for a cached entry
662 if (auto opt = tracker_.getInstantiation(inputTy)) {
663 return opt.value();
664 }
665
666 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
667 // done, return the original type to indicate that it's still legal (for this step at least).
668 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
669 if (failed(cloneRes)) {
670 return inputTy;
671 }
672 StructType newTy = cloneRes.value();
673 LLVM_DEBUG(
674 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
675 << " as " << newTy << '\n'
676 );
677 tracker_.recordInstantiation(inputTy, newTy);
678 return newTy;
679 });
680
681 addConversion([this](ArrayType inputTy) {
682 return inputTy.cloneWith(convertType(inputTy.getElementType()));
683 });
684 }
685};
686
// Conversion pattern for CallOp. It converts the result types and, when the callee is a struct
// `compute` or `constrain` function, retargets the callee symbol to the struct type obtained
// after conversion. For `compute` calls it also asks the tracker to report the diagnostics that
// were delayed for the new struct type.
687class CallStructFuncPattern : public OpConversionPattern<CallOp> {
688 ConversionTracker &tracker_;
689
690public:
691 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
692 // Must use higher benefit than CallOpClassReplacePattern so this pattern will be applied
693 // instead of the CallOpClassReplacePattern from newGeneralRewritePatternSet().
694 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/2), tracker_(tracker) {}
695
696 LogicalResult matchAndRewrite(
697 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
698 ) const override {
699 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
700
701 // Convert the result types of the CallOp
702 SmallVector<Type> newResultTypes;
703 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
704 return op->emitError("Could not convert Op result types.");
705 }
706 LLVM_DEBUG({
707 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
708 << debug::toStringList(newResultTypes) << '\n';
709 });
710
711 // Update the callee to reflect the new struct target if necessary. These checks are based on
712 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
713 // Instead they must come from the converted versions.
714 SymbolRefAttr calleeAttr = op.getCalleeAttr();
715 if (op.calleeIsStructCompute()) {
// For `compute`, the struct type is taken from the single converted result type.
716 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
717 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
718 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
719 tracker_.reportDelayedDiagnostics(newStTy, op);
720 }
721 } else if (op.calleeIsStructConstrain()) {
// For `constrain`, the struct type is taken from the first adapted argument operand.
722 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
723 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
724 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
725 }
726 }
727
728 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
// NOTE(review): listing line 729 is absent from this view; presumably it declares `newOp` and
// opens the replacement call that receives the arguments below -- confirm against the original
// file.
730 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
731 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
732 );
733 (void)newOp; // tell compiler it's intentionally unused in release builds
734 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
735 return success();
736 }
737};
738
739// This one ensures MemberDefOp types are converted even if there are no reads/writes to them.
740class MemberDefOpPattern : public OpConversionPattern<MemberDefOp> {
741public:
742 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
743 // Must use higher benefit than GeneralTypeReplacePattern so this pattern will be applied
744 // instead of the GeneralTypeReplacePattern<MemberDefOp> from newGeneralRewritePatternSet().
745 : OpConversionPattern<MemberDefOp>(converter, ctx, /*benefit=*/2) {}
746
747 LogicalResult matchAndRewrite(
748 MemberDefOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
749 ) const override {
750 LLVM_DEBUG(llvm::dbgs() << "[MemberDefOpPattern] MemberDefOp: " << op << '\n');
751
752 Type oldMemberType = op.getType();
753 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
754 if (oldMemberType == newMemberType) {
755 // nothing changed
756 return failure();
757 }
758 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.setType(newMemberType); });
759 return success();
760 }
761};
762
765class DisableReportMissing : public LegalityCheckCallback {
766 ParameterizedStructUseTypeConverter &tyConv;
767
768public:
769 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
770
771 void checkStarted() override { tyConv.cloner.disableReportMissing(); }
772
773 void checkEnded(bool) override { tyConv.cloner.enableReportMissing(); }
774};
775
776LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
777 MLIRContext *ctx = modOp.getContext();
778 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
779 DisableReportMissing drm(tyConv);
780 ConversionTarget target = newConverterDefinedTargetWithCallback<>(tyConv, ctx, drm);
781 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
782 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
783 return applyPartialConversion(modOp, target, std::move(patterns));
784}
785
786} // namespace Step1_InstantiateStructs
787
788namespace Step2_Unroll {
789
790// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
791template <HasInterface<LoopLikeOpInterface> OpClass>
792class LoopUnrollPattern : public OpRewritePattern<OpClass> {
793public:
794 using OpRewritePattern<OpClass>::OpRewritePattern;
795
796 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
797 if (auto maybeConstant = getConstantTripCount(loopOp)) {
798 uint64_t tripCount = *maybeConstant;
799 if (tripCount == 0) {
800 rewriter.eraseOp(loopOp);
801 return success();
802 } else if (tripCount == 1) {
803 return loopOp.promoteIfSingleIteration(rewriter);
804 }
805 return loopUnrollByFactor(loopOp, tripCount);
806 }
807 return failure();
808 }
809
810private:
813 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
814 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
815 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
816 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
817 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
818 return std::nullopt;
819 }
820 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
821 }
822};
823
824LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
825 MLIRContext *ctx = modOp.getContext();
826 RewritePatternSet patterns(ctx);
827 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
828 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
829
830 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
831}
832} // namespace Step2_Unroll
833
835
836// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
837// version uses a basic loop instead of llvm::map_to_vector().
838std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
839 SmallVector<int64_t> res;
840 for (OpFoldResult ofr : ofrs) {
841 std::optional<int64_t> cv = getConstantIntValue(ofr);
842 if (!cv.has_value()) {
843 return std::nullopt;
844 }
845 res.push_back(cv.value());
846 }
847 return res;
848}
849
850struct AffineMapFolder {
851 struct Input {
852 OperandRangeRange mapOpGroups;
853 DenseI32ArrayAttr dimsPerGroup;
854 ArrayRef<Attribute> paramsOfStructTy;
855 };
856
857 struct Output {
858 SmallVector<SmallVector<Value>> mapOpGroups;
859 SmallVector<int32_t> dimsPerGroup;
860 SmallVector<Attribute> paramsOfStructTy;
861 };
862
863 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
864 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
865 return ValueRange(grp);
866 });
867 }
868
869 static LogicalResult
870 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
871 if (in.mapOpGroups.empty()) {
872 // No affine map operands so nothing to do
873 return failure();
874 }
875
876 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
877 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
878
879 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
880 for (Attribute sizeAttr : in.paramsOfStructTy) {
881 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
882 ValueRange currMapOps = in.mapOpGroups[idx++];
883 LLVM_DEBUG(
884 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
885 << '\n'
886 );
887 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
888 LLVM_DEBUG(
889 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
890 << debug::toStringList(currMapOpsCast) << '\n'
891 );
892 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
893 SmallVector<Attribute> result;
894 bool hasPoison = false; // indicates divide by 0 or mod by <1
895 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
896 return rewriter.getIndexAttr(v);
897 });
898 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
899 if (hasPoison) {
900 // Diagnostic remark: could be removed for release builds if too noisy
901 op->emitRemark()
902 .append(
903 "Cannot fold affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
904 " due to divide by 0 or modulus with negative divisor"
905 )
906 .report();
907 return failure();
908 }
909 if (failed(foldResult)) {
910 // Diagnostic remark: could be removed for release builds if too noisy
911 op->emitRemark()
912 .append(
913 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " failed"
914 )
915 .report();
916 return failure();
917 }
918 if (result.size() != 1) {
919 // Diagnostic remark: could be removed for release builds if too noisy
920 op->emitRemark()
921 .append(
922 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
923 " produced ", result.size(), " results but expected 1"
924 )
925 .report();
926 return failure();
927 }
928 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
929 out.paramsOfStructTy.push_back(result[0]);
930 continue;
931 }
932 // If affine but not foldable, preserve the map ops
933 out.mapOpGroups.emplace_back(currMapOps);
934 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
935 }
936 // If not affine and foldable, preserve the original
937 out.paramsOfStructTy.push_back(sizeAttr);
938 }
939 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
940 assert(
941 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
942 "produced wrong number of dimensions"
943 );
944
945 return success();
946 }
947};
948
// Rewrite pattern for CreateArrayOp: runs AffineMapFolder::fold() over the dimension sizes of the
// result ArrayType and, when the resulting ArrayType differs from the original, rewrites the op
// to produce the new type along with the operand groups that were not folded.
950class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
 // Only referenced from within an assert() in this class, hence [[maybe_unused]].
951 [[maybe_unused]]
952 ConversionTracker &tracker_;
953
954public:
955 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
956 : OpRewritePattern(ctx), tracker_(tracker) {}
957
 // Fails (no match) when folding fails or when folding leaves the result type unchanged.
958 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
959 ArrayType oldResultType = op.getType();
960
961 AffineMapFolder::Output out;
962 AffineMapFolder::Input in = {
963 op.getMapOperands(),
 // NOTE(review): a line appears to be missing from this listing here; `Input` has three
 // fields (`dimsPerGroup` sits between the two shown) but only two initializers are visible.
965 oldResultType.getDimensionSizes(),
966 };
967 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
968 return failure();
969 }
970
 // Build the candidate result type from the original element type and the folded dimensions.
971 ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
972 if (newResultType == oldResultType) {
973 // nothing changed
974 return failure();
975 }
976 // ASSERT: folding only preserves the original Attribute or converts affine to integer
 // The isLegalConversion() call below is evaluated only in builds where assert() is enabled.
977 assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
978 LLVM_DEBUG(
979 llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
980 << newResultType << " in \"" << op << "\"\n"
981 );
 // NOTE(review): the line that opens the call taking the arguments below is not present in
 // this listing.
983 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
984 );
985 return success();
986 }
987};
988
// Rewrite pattern for CallOp whose callee is a struct "compute()" function. It computes a new
// list of parameters for the struct result type, either by folding affine map operands (when the
// call has any) or by unifying the call argument types with the argument types of the target
// function, and rewrites the call when the resulting StructType differs from the original.
990class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
991 ConversionTracker &tracker_;
992
993public:
994 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
995 : OpRewritePattern(ctx), tracker_(tracker) {}
996
 // Fails (no match) when the callee is not a struct "compute()", the result struct type has no
 // parameters, folding/unification fails, the type is unchanged, or the tracker rejects the
 // conversion (in which case a match-failure diagnostic is reported).
997 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
998 if (!op.calleeIsStructCompute()) {
999 // this pattern only applies when the callee is "compute()" within a struct
1000 return failure();
1001 }
1002 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
 // NOTE(review): the declaration of `oldRetTy` is not present in this listing.
1004 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
1005 ArrayAttr params = oldRetTy.getParams();
1006 if (isNullOrEmpty(params)) {
1007 // nothing to do if the StructType is not parameterized
1008 return failure();
1009 }
1010
1011 AffineMapFolder::Output out;
1012 AffineMapFolder::Input in = {
1013 op.getMapOperands(),
 // NOTE(review): a line appears to be missing from this listing here; `Input` has three
 // fields (`dimsPerGroup` sits between the two shown) but only two initializers are visible.
1015 params.getValue(),
1016 };
1017 if (!in.mapOpGroups.empty()) {
1018 // If there are affine map operands, attempt to fold them to a constant.
1019 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
1020 return failure();
1021 }
1022 LLVM_DEBUG({
1023 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
1024 });
1025 } else {
1026 // If there are no affine map operands, attempt to refine the result type of the CallOp using
1027 // the function argument types and the type of the target function.
1028 auto callArgTypes = op.getArgOperands().getTypes();
1029 if (callArgTypes.empty()) {
1030 // no refinement possible if no function arguments
1031 return failure();
1032 }
1033 SymbolTableCollection tables;
1034 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1035 if (failed(lookupRes)) {
1036 return failure();
1037 }
1038 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1039 return failure();
1040 }
1041 LLVM_DEBUG({
1042 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1043 "result type params: "
1044 << debug::toStringList(out.paramsOfStructTy) << '\n';
1045 });
1046 }
1047
 // Same struct name as before, with the parameter list computed above.
1048 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
1049 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
1050 if (newRetTy == oldRetTy) {
1051 // nothing changed
1052 return failure();
1053 }
1054 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
1055 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
1056 // the conversion here is illegal it means there is a type conflict within the LLZK code that
1057 // prevents instantiation of the struct with the requested type.
1058 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
1059 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1060 diag.append(
1061 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1062 ", but found ", oldRetTy
1063 );
1064 });
1065 }
1066 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
 // NOTE(review): the line that declares `newOp` and opens the call taking the arguments below
 // is not present in this listing.
1068 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1069 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1070 );
1071 (void)newOp; // tell compiler it's intentionally unused in release builds
1072 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1073 return success();
1074 }
1075
1076private:
 // Compute `out.paramsOfStructTy` for a call without affine map operands. The argument types of
 // `targetFunc` are unified with `callArgTypes`; each parameter from `in.paramsOfStructTy` that
 // is not already concrete is replaced by the attribute unified with the corresponding target
 // struct parameter symbol when that attribute satisfies `isConcreteAttr<false>`. All other
 // parameters are kept as given at the call site. Returns failure when every parameter at the
 // call site is already concrete. Leaves `out.mapOpGroups` and `out.dimsPerGroup` untouched.
1079 inline LogicalResult instantiateViaTargetType(
1080 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1081 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1082 ) const {
1083 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1084 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1085 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1086 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1087
1088 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1089 // Nothing can change if everything is already concrete
1090 return failure();
1091 }
1092
1093 LLVM_DEBUG({
1094 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1095 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1096 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1097 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1098 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1099 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1100 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1101 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1102 });
1103
 // The unification result is only checked via assert(); `unifications` is what is used below.
1104 UnificationMap unifications;
1105 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1106 (void)unifies; // tell compiler it's intentionally unused in builds without assertions
1107 assert(unifies && "should have been checked by verifiers");
1108
1109 LLVM_DEBUG({
1110 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1111 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1112 });
1113
1114 // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
1115 // from the call argument types) without any struct parameters (because the type with concrete
1116 // struct parameters will be used to instantiate the target struct rather than the fully
1117 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1118 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1119 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1120 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1121 [&unifications](std::tuple<Attribute, Attribute> p) {
1122 Attribute fromCall = std::get<1>(p);
1123 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1124 // non-parameterized concrete unification for the target struct parameter symbol.
1125 if (!isConcreteAttr<>(fromCall)) {
1126 Attribute fromTgt = std::get<0>(p);
1127 LLVM_DEBUG({
1128 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1129 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1130 });
1131 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1132 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1133 if (it != unifications.end()) {
1134 Attribute unifiedAttr = it->second;
1135 LLVM_DEBUG({
1136 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1137 });
1138 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1139 return unifiedAttr;
1140 }
1141 }
1142 }
1143 return fromCall;
1144 }
1145 );
1146
1147 out.paramsOfStructTy = newReturnStructParams;
1148 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1149 assert(out.mapOpGroups.empty() && "post-condition");
1150 assert(out.dimsPerGroup.empty() && "post-condition");
1151 return success();
1152 }
1153};
1154
1155LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1156 MLIRContext *ctx = modOp.getContext();
1157 RewritePatternSet patterns(ctx);
1158 patterns.add<
1159 InstantiateAtCreateArrayOp, // CreateArrayOp
1160 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1161 >(ctx, tracker);
1162
1163 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1164}
1165
1166} // namespace Step3_InstantiateAffineMaps
1167
1169
1171class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1172 ConversionTracker &tracker_;
1173
1174public:
1175 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1176 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1177
1178 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1179 Value createResult = op.getResult();
1180 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1181 assert(createResultType && "CreateArrayOp must produce ArrayType");
1182 Type oldResultElemType = createResultType.getElementType();
1183
1184 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1185 // element type is different.
1186 Type newResultElemType = nullptr;
1187 for (Operation *user : createResult.getUsers()) {
1188 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1189 if (writeOp.getArrRef() != createResult) {
1190 continue;
1191 }
1192 Type writeRValueType = writeOp.getRvalue().getType();
1193 if (writeRValueType == oldResultElemType) {
1194 continue;
1195 }
1196 if (newResultElemType && newResultElemType != writeRValueType) {
1197 LLVM_DEBUG(
1198 llvm::dbgs()
1199 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1200 << newResultElemType << " vs " << writeRValueType << '\n'
1201 );
1202 return failure();
1203 }
1204 newResultElemType = writeRValueType;
1205 }
1206 }
1207 if (!newResultElemType) {
1208 // no replacement type found
1209 return failure();
1210 }
1211 if (!tracker_.isLegalConversion(
1212 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1213 )) {
1214 return failure();
1215 }
1216 ArrayType newType = createResultType.cloneWith(newResultElemType);
1217 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1218 LLVM_DEBUG(
1219 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1220 );
1221 return success();
1222 }
1223};
1224
1225namespace {
1226
1227LogicalResult updateArrayElemFromArrAccessOp(
1228 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1229 PatternRewriter &rewriter
1230) {
1231 ArrayType oldArrType = op.getArrRefType();
1232 if (oldArrType.getElementType() == scalarElemTy) {
1233 return failure(); // no change needed
1234 }
1235 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1236 if (oldArrType == newArrType ||
1237 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1238 return failure();
1239 }
1240 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1241 LLVM_DEBUG(
1242 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1243 );
1244 return success();
1245}
1246
1247} // namespace
1248
1249class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1250 ConversionTracker &tracker_;
1251
1252public:
1253 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1254 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1255
1256 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1257 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1258 }
1259};
1260
1261class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1262 ConversionTracker &tracker_;
1263
1264public:
1265 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1266 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1267
1268 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1269 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1270 }
1271};
1272
1274class UpdateMemberDefTypeFromWrite final : public OpRewritePattern<MemberDefOp> {
1275 ConversionTracker &tracker_;
1276
1277public:
1278 UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1279 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1280
1281 LogicalResult matchAndRewrite(MemberDefOp op, PatternRewriter &rewriter) const override {
1282 // Find all uses of the member symbol name within its parent struct.
1283 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(op);
1284 assert(succeeded(parentRes) && "MemberDefOp parent is always StructDefOp"); // per ODS def
1285
1286 // If the symbol is used by a MemberWriteOp with a different result type then change
1287 // the type of the MemberDefOp to match the MemberWriteOp result type.
1288 Type newType = nullptr;
1289 if (auto memberUsers = llzk::getSymbolUses(op, parentRes.value())) {
1290 std::optional<Location> newTypeLoc = std::nullopt;
1291 for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
1292 if (MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
1293 Type writeToType = writeOp.getVal().getType();
1294 LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] checking " << writeOp << '\n');
1295 if (!newType) {
1296 // If a new type has not yet been discovered, store the new type.
1297 newType = writeToType;
1298 newTypeLoc = writeOp.getLoc();
1299 } else if (writeToType != newType) {
1300 // Typically, there will only be one write for each member of a struct but do not rely
1301 // on that assumption. If multiple writes with a different types A and B are found where
1302 // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
1303 // type B with the assumption that the write with type A will be updated by another
1304 // pattern to also use type B.
1305 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateMemberDefTypeFromWrite")) {
1306 if (tracker_.isLegalConversion(
1307 newType, writeToType, "UpdateMemberDefTypeFromWrite"
1308 )) {
1309 // 'writeToType' is the more concrete type
1310 newType = writeToType;
1311 newTypeLoc = writeOp.getLoc();
1312 } else {
1313 // Give an error if the types are incompatible.
1314 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1315 diag.append(
1316 "Cannot update type of '", MemberDefOp::getOperationName(),
1317 "' because there are multiple '", MemberWriteOp::getOperationName(),
1318 "' with different value types"
1319 );
1320 if (newTypeLoc) {
1321 diag.attachNote(*newTypeLoc).append("type written here is ", newType);
1322 }
1323 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1324 });
1325 }
1326 }
1327 }
1328 }
1329 }
1330 }
1331 if (!newType || newType == op.getType()) {
1332 // nothing changed
1333 return failure();
1334 }
1335 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateMemberDefTypeFromWrite")) {
1336 return failure();
1337 }
1338 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1339 LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] updated type of " << op << '\n');
1340 return success();
1341 }
1342};
1343
1344namespace {
1345
1346SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1347 SmallVector<std::unique_ptr<Region>> newRegions;
1348 for (Region &region : op->getRegions()) {
1349 auto newRegion = std::make_unique<Region>();
1350 newRegion->takeBody(region);
1351 newRegions.push_back(std::move(newRegion));
1352 }
1353 return newRegions;
1354}
1355
1356} // namespace
1357
1360class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1361 ConversionTracker &tracker_;
1362
1363public:
1364 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1365 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1366
1367 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1368 SmallVector<Type, 1> inferredResultTypes;
1369 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1370 LogicalResult result = retTypeFn.inferReturnTypes(
1371 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1372 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1373 );
1374 if (failed(result)) {
1375 return failure();
1376 }
1377 if (op->getResultTypes() == inferredResultTypes) {
1378 // nothing changed
1379 return failure();
1380 }
1381 if (!tracker_.areLegalConversions(
1382 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1383 )) {
1384 return failure();
1385 }
1386
1387 // Move nested region bodies and replace the original op with the updated types list.
1388 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1389 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1390 Operation *newOp = rewriter.create(
1391 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1392 op->getAttrs(), op->getSuccessors(), newRegions
1393 );
1394 rewriter.replaceOp(op, newOp);
1395 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1396 return success();
1397 }
1398};
1399
1401class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1402 ConversionTracker &tracker_;
1403
1404public:
1405 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1406 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1407
1408 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1409 Region &body = op.getFunctionBody();
1410 if (body.empty()) {
1411 return failure();
1412 }
1413 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1414 assert(retOp && "final op in body region must be return");
1415 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1416
1417 FunctionType oldFuncTy = op.getFunctionType();
1418 if (oldFuncTy.getResults() == tyFromReturnOp) {
1419 // nothing changed
1420 return failure();
1421 }
1422 if (!tracker_.areLegalConversions(
1423 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1424 )) {
1425 return failure();
1426 }
1427
1428 rewriter.modifyOpInPlace(op, [&]() {
1429 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1430 });
1431 LLVM_DEBUG(
1432 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1433 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1434 );
1435 return success();
1436 }
1437};
1438
1443class UpdateGlobalCallOpTypes final : public OpRewritePattern<CallOp> {
1444 ConversionTracker &tracker_;
1445
1446public:
1447 UpdateGlobalCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1448 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1449
1450 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1451 SymbolTableCollection tables;
1452 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1453 if (failed(lookupRes)) {
1454 return failure();
1455 }
1456 FuncDefOp targetFunc = lookupRes->get();
1457 if (targetFunc.isInStruct()) {
1458 // this pattern only applies when the callee is NOT in a struct
1459 return failure();
1460 }
1461 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1462 // nothing changed
1463 return failure();
1464 }
1465 if (!tracker_.areLegalConversions(
1466 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1467 "UpdateGlobalCallOpTypes"
1468 )) {
1469 return failure();
1470 }
1471
1472 LLVM_DEBUG(llvm::dbgs() << "[UpdateGlobalCallOpTypes] replaced " << op);
1473 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1474 (void)newOp; // tell compiler it's intentionally unused in release builds
1475 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1476 return success();
1477 }
1478};
1479
1480namespace {
1481
1482LogicalResult updateMemberRefValFromMemberDef(
1483 MemberRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1484) {
1485 SymbolTableCollection tables;
1486 auto def = op.getMemberDefOp(tables);
1487 if (failed(def)) {
1488 return failure();
1489 }
1490 Type oldResultType = op.getVal().getType();
1491 Type newResultType = def->get().getType();
1492 if (oldResultType == newResultType ||
1493 !tracker.isLegalConversion(oldResultType, newResultType, "updateMemberRefValFromMemberDef")) {
1494 return failure();
1495 }
1496 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
1497 LLVM_DEBUG(
1498 llvm::dbgs() << "[updateMemberRefValFromMemberDef] updated value type in " << op << '\n'
1499 );
1500 return success();
1501}
1502
1503} // namespace
1504
1506class UpdateMemberReadValFromDef final : public OpRewritePattern<MemberReadOp> {
1507 ConversionTracker &tracker_;
1508
1509public:
1510 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1511 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1512
1513 LogicalResult matchAndRewrite(MemberReadOp op, PatternRewriter &rewriter) const override {
1514 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
1515 }
1516};
1517
1519class UpdateMemberWriteValFromDef final : public OpRewritePattern<MemberWriteOp> {
1520 ConversionTracker &tracker_;
1521
1522public:
1523 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
1524 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1525
1526 LogicalResult matchAndRewrite(MemberWriteOp op, PatternRewriter &rewriter) const override {
1527 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
1528 }
1529};
1530
1531LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1532 MLIRContext *ctx = modOp.getContext();
1533 RewritePatternSet patterns(ctx);
1534 patterns.add<
1535 // Benefit of this one must be higher than rules that would propagate the type in the opposite
1536 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
1537 // benefit = 6
1538 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
1539 // benefit = 3
1540 UpdateGlobalCallOpTypes, // CallOp, targeting non-struct functions
1541 UpdateFuncTypeFromReturn, // FuncDefOp
1542 UpdateNewArrayElemFromWrite, // CreateArrayOp
1543 UpdateArrayElemFromArrRead, // ReadArrayOp
1544 UpdateArrayElemFromArrWrite, // WriteArrayOp
1545 UpdateMemberDefTypeFromWrite, // MemberDefOp
1546 UpdateMemberReadValFromDef, // MemberReadOp
1547 UpdateMemberWriteValFromDef // MemberWriteOp
1548 >(ctx, tracker);
1549
1550 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1551}
1552} // namespace Step4_PropagateTypes
1553
1554namespace Step5_Cleanup {
1555
1556class CleanupBase {
1557public:
1558 SymbolTableCollection tables;
1559
1560 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
1561 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
1562
1563protected:
1564 ModuleOp rootMod;
1565 const SymbolDefTree &defTree;
1566 const SymbolUseGraph &useGraph;
1567};
1568
1569struct FromKeepSet : public CleanupBase {
1570 using CleanupBase::CleanupBase;
1571
1575 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
1576 // Initialize roots from the given StructDefOp instances
1577 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
1578 // Add GlobalDefOp and "free functions" to the set of roots
1579 rootMod.walk([&roots](Operation *op) {
1580 if (global::GlobalDefOp gdef = llvm::dyn_cast<global::GlobalDefOp>(op)) {
1581 roots.insert(gdef);
1582 } else if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
1583 if (!fdef.isInStruct()) {
1584 roots.insert(fdef);
1585 }
1586 }
1587 });
1588
1589 // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
1590 // collect all Symbol uses reachable from those def nodes. These are the symbols that should
1591 // be preserved. All other symbol defs should be removed.
1592 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
1593 for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
1594 SymbolOpInterface keepRoot = roots[i];
1595 LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
1596 const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
1597 assert(keepRootNode && "every struct def must be in the def tree");
1598 for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
1599 LLVM_DEBUG({
1600 llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
1601 });
1602 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
1603 // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
1604 // node. There are no uses if the node is not in the graph. Within the loop that populates
1605 // 'depth_first_ext()', also check if the symbol is a StructDefOp and ensure it is in
1606 // 'roots' so the outer loop will ensure that all symbols reachable from it are preserved.
1607 if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
1608 for (const SymbolUseGraphNode *usedSymbolNode :
1609 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
1610 LLVM_DEBUG({
1611 llvm::dbgs() << "[EraseUnreachable] uses symbol: "
1612 << usedSymbolNode->getSymbolPath() << '\n';
1613 });
1614 // Ignore struct/template parameter symbols (before doing the lookup below because it
1615 // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
1616 if (usedSymbolNode->isStructParam()) {
1617 continue;
1618 }
1619 // If `usedSymbolNode` references a StructDefOp, ensure it's considered in the roots.
1620 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
1621 if (failed(lookupRes)) {
1622 LLVM_DEBUG(useGraph.dumpToDotFile());
1623 return failure();
1624 }
1625 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1626 if (lookupRes->viaInclude()) {
1627 continue;
1628 }
1629 if (StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
1630 bool insertRes = roots.insert(asStruct);
1631 (void)insertRes; // tell compiler it's intentionally unused in release builds
1632 LLVM_DEBUG({
1633 if (insertRes) {
1634 llvm::dbgs() << "[EraseUnreachable] found another root: " << asStruct << '\n';
1635 }
1636 });
1637 }
1638 }
1639 }
1640 }
1641 }
1642 }
1643
1644 rootMod.walk([this, &symbolsToKeep](StructDefOp op) {
1645 const SymbolUseGraphNode *n = this->useGraph.lookupNode(op);
1646 assert(n);
1647 if (!symbolsToKeep.contains(n)) {
1648 LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << op.getSymName() << '\n');
1649 op.erase();
1650 }
1651
1652 return WalkResult::skip(); // StructDefOp cannot be nested
1653 });
1654
1655 return success();
1656 }
1657};
1658
1659struct FromEraseSet : public CleanupBase {
1660
1662 FromEraseSet(
1663 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
1664 DenseSet<SymbolRefAttr> &&tryToErasePaths
1665 )
1666 : CleanupBase(root, symDefTree, symUseGraph) {
1667 // Convert the set of paths targeted for erasure into a set of the StructDefOp
1668 for (SymbolRefAttr path : tryToErasePaths) {
1669 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] path to erase: " << path << '\n';);
1670 Operation *lookupFrom = rootMod.getOperation();
1671 auto res = lookupSymbolIn<StructDefOp>(tables, path, lookupFrom, lookupFrom);
1672 assert(succeeded(res) && "inputs must be valid StructDefOp references");
1673 if (!res->viaInclude()) { // do not remove if it's from another source file
1674 auto op = res->get();
1675 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] added op to the erase set: " << op << '\n';);
1676 tryToErase.insert(op);
1677 } else {
1678 LLVM_DEBUG(
1679 llvm::dbgs() << "[FromEraseSet] ignored op because it comes from an include: "
1680 << res->get() << '\n';
1681 );
1682 }
1683 }
1684 }
1685
1686 LogicalResult eraseUnusedStructs() {
1687 // Collect the subset of 'tryToErase' that has no remaining uses.
1688 for (StructDefOp sd : tryToErase) {
1689 collectSafeToErase(sd);
1690 }
1691 // The `visitedPlusSafetyResult` will contain FuncDefOp w/in the StructDefOp so just a single
1692 // loop to `dyn_cast` and `erase()` will cause `use-after-free` errors w/in the `dyn_cast`.
1693 // Instead, reduce the map to only those that should be erased and erase in a separate loop.
1694 for (auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
1695 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
1696 visitedPlusSafetyResult.erase(it);
1697 }
1698 }
1699 for (auto &[sym, _] : visitedPlusSafetyResult) {
1700 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
1701 sym.erase();
1702 }
1703 return success();
1704 }
1705
1706 const DenseSet<StructDefOp> &getTryToEraseSet() const { return tryToErase; }
1707
1708private:
1710 DenseSet<StructDefOp> tryToErase;
1714 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
1716 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
1717
1720 bool collectSafeToErase(SymbolOpInterface check) {
1721 assert(check); // pre-condition
1722
1723 // If previously visited, return the safety result.
1724 auto visited = visitedPlusSafetyResult.find(check);
1725 if (visited != visitedPlusSafetyResult.end()) {
1726 return visited->second;
1727 }
1728
1729 // If it's a StructDefOp that is not in `tryToErase` then it cannot be erased.
1730 if (StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
1731 if (!tryToErase.contains(sd)) {
1732 visitedPlusSafetyResult[check] = false;
1733 return false;
1734 }
1735 }
1736
1737 // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
1738 // the recursion from getting stuck in an infinite loop).
1739 visitedPlusSafetyResult[check] = true;
1740
1741 // Check if it's safe according to both the def tree and use graph.
1742 // Note: every symbol must have a def node but module symbols may not have a use node.
1743 if (collectSafeToErase(defTree.lookupNode(check))) {
1744 auto useNode = useGraph.lookupNode(check);
1745 assert(useNode || llvm::isa<ModuleOp>(check.getOperation()));
1746 if (!useNode || collectSafeToErase(useNode)) {
1747 return true;
1748 }
1749 }
1750
1751 // Otherwise, revert the safety decision and return it.
1752 visitedPlusSafetyResult[check] = false;
1753 return false;
1754 }
1755
1757 bool collectSafeToErase(const SymbolDefTreeNode *check) {
1758 assert(check); // pre-condition
1759 if (const SymbolDefTreeNode *p = check->getParent()) {
1760 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
1761 return collectSafeToErase(checkOp);
1762 }
1763 }
1764 return true;
1765 }
1766
1768 bool collectSafeToErase(const SymbolUseGraphNode *check) {
1769 assert(check); // pre-condition
1770 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
1771 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
1772 if (!collectSafeToErase(checkOp)) {
1773 return false;
1774 }
1775 }
1776 }
1777 return true;
1778 }
1779
1784 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
1785 assert(node && "must provide a node"); // pre-condition
1786 // Check for cached result
1787 auto fromCache = lookupCache.find(node);
1788 if (fromCache != lookupCache.end()) {
1789 return fromCache->second;
1790 }
1791 // Otherwise, perform lookup and cache
1792 auto lookupRes = node->lookupSymbol(tables);
1793 assert(succeeded(lookupRes) && "graph contains node with invalid path");
1794 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
1795 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
1796 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
1797 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
1798 // therefore will not be traversed by this analysis so directly casting is fine.
1799 SymbolOpInterface actualRes =
1800 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
1801 // Cache and return
1802 lookupCache[node] = actualRes;
1803 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
1804 return actualRes;
1805 }
1806};
1807
1808} // namespace Step5_Cleanup
1809
1810class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
1811
1812 void runOnOperation() override {
1813 ModuleOp modOp = getOperation();
1814 if (failed(runOn(modOp))) {
1815 LLVM_DEBUG({
1816 // If the pass failed, dump the current IR.
1817 llvm::dbgs() << "=====================================================================\n";
1818 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
1819 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1820 llvm::dbgs() << "=====================================================================\n";
1821 });
1822 signalPassFailure();
1823 }
1824 }
1825
1826 inline LogicalResult runOn(ModuleOp modOp) {
1827 // If the cleanup mode is set to remove anything not reachable from the main struct, do an
1828 // initial pass to remove things that are not reachable (as an optimization) because creating
1829 // an instantiated version of a struct will not cause something to become reachable that was
1830 // not already reachable in parameterized form.
1831 if (cleanupMode == StructCleanupMode::MainAsRoot) {
1832 if (failed(eraseUnreachableFromMainStruct(modOp))) {
1833 return failure();
1834 }
1835 }
1836
1837 {
1838 // Preliminary step: remove empty parameter lists from structs
1839 OpPassManager nestedPM(ModuleOp::getOperationName());
1840 nestedPM.addPass(createEmptyParamListRemoval());
1841 if (failed(runPipeline(nestedPM, modOp))) {
1842 return failure();
1843 }
1844 }
1845
1846 ConversionTracker tracker;
1847 unsigned loopCount = 0;
1848 do {
1849 ++loopCount;
1850 if (loopCount > iterationLimit) {
1851 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
1852 << " iterations!\n";
1853 return failure();
1854 }
1855 tracker.resetModifiedFlag();
1856
1857 LLVM_DEBUG({
1858 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
1859 << ")] Running step 1: struct instantiation\n";
1860 });
1861 // Find calls to "compute()" that return a parameterized struct and replace it to call a
1862 // flattened version of the struct that has parameters replaced with the constant values.
1863 // Create the necessary instantiated/flattened struct in the same location as the original.
1864 if (failed(Step1_InstantiateStructs::run(modOp, tracker))) {
1865 llvm::errs() << DEBUG_TYPE << " failed while replacing concrete-parameter struct types\n";
1866 return failure();
1867 }
1868
1869 LLVM_DEBUG({
1870 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
1871 << ")] Running step 2: loop unrolling\n";
1872 });
1873 // Unroll loops with known iterations.
1874 if (failed(Step2_Unroll::run(modOp, tracker))) {
1875 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
1876 return failure();
1877 }
1878
1879 LLVM_DEBUG({
1880 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
1881 << ")] Running step 3: affine maps instantiation\n";
1882 });
1883 // Instantiate affine_map parameters of StructType and ArrayType.
1884 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
1885 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
1886 return failure();
1887 }
1888
1889 LLVM_DEBUG({
1890 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
1891 << ")] Running step 4: type propagation\n";
1892 });
1893 // Propagate updated types using the semantics of various ops.
1894 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
1895 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
1896 return failure();
1897 }
1898
1899 LLVM_DEBUG(if (tracker.isModified()) {
1900 llvm::dbgs() << "=====================================================================\n";
1901 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
1902 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
1903 llvm::dbgs() << "=====================================================================\n";
1904 });
1905 } while (tracker.isModified());
1906
1907 LLVM_DEBUG({ llvm::dbgs() << "[FlatteningPass] Running step 5: cleanup "; });
1908 // Perform cleanup according to the 'cleanupMode' option.
1909 switch (cleanupMode) {
1910 case StructCleanupMode::MainAsRoot:
1911 LLVM_DEBUG(llvm::dbgs() << "(main as root mode)\n");
1912 return eraseUnreachableFromMainStruct(modOp, false);
1913 case StructCleanupMode::ConcreteAsRoot:
1914 LLVM_DEBUG(llvm::dbgs() << "(concrete structs mode)\n");
1915 return eraseUnreachableFromConcreteStructs(modOp);
1916 case StructCleanupMode::Preimage:
1917 LLVM_DEBUG(llvm::dbgs() << "(preimage mode)\n");
1918 return erasePreimageOfInstantiations(modOp, tracker);
1919 case StructCleanupMode::Disabled:
1920 LLVM_DEBUG(llvm::dbgs() << "(disabled)\n");
1921 return success();
1922 }
1923 llvm_unreachable("switch cases cover all options");
1924 }
1925
1926 // Erase parameterized structs that were replaced with concrete instantiations.
1927 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
1928 // TODO: The names from getInstantiatedStructNames() are NOT guaranteed to be paths from the
1929 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
1930 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
1931 // paths from the "top root". See [LLZK-286].
1932 Step5_Cleanup::FromEraseSet cleaner(
1933 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
1934 tracker.getInstantiatedStructNames()
1935 );
1936 LogicalResult res = cleaner.eraseUnusedStructs();
1937 if (succeeded(res)) {
1938 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] success\n";);
1939 // Warn about any structs that were instantiated but still have uses elsewhere.
1940 const SymbolUseGraph *useGraph = nullptr;
1941 rootMod->walk([this, &cleaner, &useGraph](StructDefOp op) {
1942 if (cleaner.getTryToEraseSet().contains(op)) {
1943 // If needed, rebuild use graph to reflect deletions.
1944 if (!useGraph) {
1945 useGraph = &getAnalysis<SymbolUseGraph>();
1946 }
1947 // If the op has any users, report the warning.
1948 if (useGraph->lookupNode(op)->hasPredecessor()) {
1949 op.emitWarning("Parameterized struct still has uses!").report();
1950 }
1951 }
1952 return WalkResult::skip(); // StructDefOp cannot be nested
1953 });
1954 } else {
1955 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] failed\n";);
1956 }
1957 return res;
1958 }
1959
1960 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
1961 SmallVector<StructDefOp> roots;
1962 rootMod.walk([&roots](StructDefOp op) {
1963 // Note: no need to check if the ConstParamsAttr is empty since `EmptyParamRemovalPass`
1964 // ran earlier.
1965 if (!op.hasConstParamsAttr()) {
1966 roots.push_back(op);
1967 }
1968 return WalkResult::skip(); // StructDefOp cannot be nested
1969 });
1970
1971 Step5_Cleanup::FromKeepSet cleaner(
1972 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1973 );
1974 return cleaner.eraseUnreachableFrom(roots);
1975 }
1976
1977 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
1978 Step5_Cleanup::FromKeepSet cleaner(
1979 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
1980 );
1981 FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
1982 getMainInstanceDef(cleaner.tables, rootMod.getOperation());
1983 if (failed(mainOpt)) {
1984 return failure();
1985 }
1986 SymbolLookupResult<StructDefOp> main = mainOpt.value();
1987 if (emitWarning && !main) {
1988 // Emit warning if there is no main specified because all structs may be removed (only
1989 // structs that are reachable from a global def or free function will be preserved since
1990 // those constructs are not candidate for removal in this pass).
1991 rootMod.emitWarning()
1992 .append(
1993 "using option '", cleanupMode.getArgStr(), '=',
1994 stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
1995 MAIN_ATTR_NAME, "\" attribute on the top-level module may remove all structs!"
1996 )
1997 .report();
1998 }
1999 return cleaner.eraseUnreachableFrom(
2000 main ? ArrayRef<StructDefOp> {*main} : ArrayRef<StructDefOp> {}
2001 );
2002 }
2003};
2004
2005} // namespace
2006
2008 return std::make_unique<FlatteningPass>();
2009};
#define DEBUG_TYPE
#define DEBUG_TYPE
#define check(x)
Definition Ops.cpp:171
Common private implementation for poly dialect passes.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:52
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
Builds a graph structure representing the relationships between symbols and their uses.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:923
::mlir::TypedValue<::mlir::Type > getRvalue()
Definition Ops.h.inc:1075
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:570
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:991
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.h.inc:749
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:620
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:937
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1189
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1686
bool hasConstParamsAttr()
Return false iff getConstParamsAttr() returns nullptr
Definition Ops.h.inc:1324
void setConstParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:1241
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:46
::mlir::ArrayAttr getParams() const
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:772
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:795
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:766
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:467
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:241
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:245
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:272
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Returns the argument types of this function.
Definition Ops.h.inc:757
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:378
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:792
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:789
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:971
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:904
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:443
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:149
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:301
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:243
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:285
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:129
std::unique_ptr< mlir::Pass > createFlatteningPass()
std::unique_ptr< mlir::Pass > createEmptyParamListRemoval()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:223
bool isConcreteType(Type type, bool allowStructParams)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:253
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:183
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:257
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
bool isDynamic(IntegerAttr intAttr)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
int64_t fromAPInt(const llvm::APInt &i)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
constexpr char MAIN_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
Definition Constants.h:28