LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
//===-- FlatteningPass.cpp - Implements -llzk-flatten pass ------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
29#include "llzk/Util/Concepts.h"
30#include "llzk/Util/Debug.h"
35
36#include <mlir/Dialect/Affine/IR/AffineOps.h>
37#include <mlir/Dialect/Affine/LoopUtils.h>
38#include <mlir/Dialect/Arith/IR/Arith.h>
39#include <mlir/Dialect/SCF/IR/SCF.h>
40#include <mlir/Dialect/SCF/Utils/Utils.h>
41#include <mlir/Dialect/Utils/StaticValueUtils.h>
42#include <mlir/IR/Attributes.h>
43#include <mlir/IR/BuiltinAttributes.h>
44#include <mlir/IR/BuiltinOps.h>
45#include <mlir/IR/BuiltinTypes.h>
46#include <mlir/Interfaces/InferTypeOpInterface.h>
47#include <mlir/Pass/PassManager.h>
48#include <mlir/Support/LLVM.h>
49#include <mlir/Support/LogicalResult.h>
50#include <mlir/Transforms/DialectConversion.h>
51#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
52#include <mlir/Transforms/WalkPatternRewriteDriver.h>
53
54#include <llvm/ADT/APInt.h>
55#include <llvm/ADT/DenseMap.h>
56#include <llvm/ADT/DepthFirstIterator.h>
57#include <llvm/ADT/STLExtras.h>
58#include <llvm/ADT/SmallVector.h>
59#include <llvm/ADT/TypeSwitch.h>
60#include <llvm/Support/Debug.h>
61
62// Include the generated base pass class definitions.
63namespace llzk::polymorphic {
64#define GEN_PASS_DEF_FLATTENINGPASS
66} // namespace llzk::polymorphic
67
68#include "SharedImpl.h"
69
70#define DEBUG_TYPE "llzk-flatten"
71
72using namespace mlir;
73using namespace llzk;
74using namespace llzk::array;
75using namespace llzk::component;
76using namespace llzk::constrain;
77using namespace llzk::felt;
78using namespace llzk::function;
79using namespace llzk::polymorphic;
80using namespace llzk::polymorphic::detail;
81
82namespace {
83
84static void reportDelayedDiagnostics(CallOp caller, SmallVector<Diagnostic> &&diagnostics) {
85 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
86 for (Diagnostic &diag : diagnostics) {
87 // Update any notes referencing an UnknownLoc to use the CallOp location.
88 for (Diagnostic &note : diag.getNotes()) {
89 assert(note.getNotes().empty() && "notes cannot have notes attached");
90 if (llvm::isa<UnknownLoc>(note.getLocation())) {
91 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
92 }
93 }
94 // Report. Based on InFlightDiagnostic::report().
95 engine.emit(std::move(diag));
96 }
97}
98
99class ConversionTracker {
101 bool modified;
104 DenseMap<StructType, StructType> structInstantiations;
106 DenseMap<StructType, StructType> reverseInstantiations;
108 DenseSet<SymbolRefAttr> funcInstantiations;
111 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
112
113public:
114 bool isModified() const { return modified; }
115 void resetModifiedFlag() { modified = false; }
116 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
117
118 void recordInstantiation(StructType oldType, StructType newType) {
119 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
120
121 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
122 if (forwardResult.second) {
123 // Insertion was successful
124 // ASSERT: The reverse map does not contain this mapping either
125 assert(!reverseInstantiations.contains(newType));
126 reverseInstantiations[newType] = oldType;
127 // Set the modified flag
128 modified = true;
129 } else {
130 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
131 assert(forwardResult.first->getSecond() == newType);
132 // ASSERT: The reverse mapping is already present as well
133 assert(reverseInstantiations.lookup(newType) == oldType);
135 assert(structInstantiations.size() == reverseInstantiations.size());
136 }
139 std::optional<StructType> getInstantiation(StructType oldType) const {
140 auto cachedResult = structInstantiations.find(oldType);
141 if (cachedResult != structInstantiations.end()) {
142 return cachedResult->second;
143 }
144 return std::nullopt;
146
148 void recordInstantiation(SymbolRefAttr funcName) {
149 funcInstantiations.insert(funcName);
150 modified = true;
151 }
152
153 /// Collect the fully-qualified names of all structs and free functions that were instantiated.
154 DenseSet<SymbolRefAttr> getInstantiatedDefinitionNames() const {
155 DenseSet<SymbolRefAttr> instantiatedNames = funcInstantiations;
156 for (const auto &[origRemoteTy, _] : structInstantiations) {
157 instantiatedNames.insert(origRemoteTy.getNameRef());
159 return instantiatedNames;
160 }
161
162 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
163 auto res = delayedDiagnostics.find(newType);
164 if (res != delayedDiagnostics.end()) {
165 ::reportDelayedDiagnostics(caller, std::move(res->second));
166
167 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
168 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
169 // `compute()` calls it will only be reported at one of those locations, not all.
170 delayedDiagnostics.erase(newType);
171 }
172 }
173
174 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
175 return delayedDiagnostics[newType];
176 }
177
179 /// than the old type with additional allowance for the results of struct flattening conversions.
180 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
181 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
182 // Check if `oTy` is a struct with a known instantiation to `nTy`
183 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
184 // Note: The values in `structInstantiations` must be no-parameter struct types
185 // so there is no need for recursive check, simple equality is sufficient.
186 if (this->structInstantiations.lookup(oldStructType) == nTy) {
187 return true;
188 }
189 }
190 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
191 // that instantiation (i.e., the parameterized version of the instantiated struct)
192 // is a more concrete unification of `oTy`.
193 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
194 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
195 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
196 return true;
197 }
198 }
199 }
200 return false;
201 };
202
203 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
204 return true;
205 }
206 LLVM_DEBUG(
207 llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
208 << " with new type " << newType
209 << " because it does not define a compatible and more concrete type.\n";
210 );
211 return false;
212 }
213
214 template <typename T, typename U>
215 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
216 return llvm::all_of(
217 llvm::zip_equal(oldTypes, newTypes), [this, &patName](std::tuple<Type, Type> oldThenNew) {
218 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
219 }
220 );
221 }
222};
223
224template <typename Impl, typename Op, typename... HandledAttrs>
225class SymbolUserHelper : public OpConversionPattern<Op> {
226private:
227 const DenseMap<Attribute, Attribute> &paramNameToValue;
228
229 SymbolUserHelper(
230 TypeConverter &converter, MLIRContext *ctx, unsigned patternBenefit,
231 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
232 )
233 : OpConversionPattern<Op>(converter, ctx, patternBenefit),
234 paramNameToValue(paramNameToInstantiatedValue) {}
235
236public:
237 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
238
239 virtual Attribute getNameAttr(Op) const = 0;
240
241 virtual LogicalResult handleDefaultRewrite(
242 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
243 ) const {
244 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
245 }
246
247 LogicalResult
248 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
249 LLVM_DEBUG(llvm::dbgs() << "[SymbolUserHelper] op: " << op << '\n');
250 auto res = this->paramNameToValue.find(getNameAttr(op));
251 if (res == this->paramNameToValue.end()) {
252 LLVM_DEBUG(llvm::dbgs() << "[SymbolUserHelper] no instantiation for " << op << '\n');
253 return failure();
254 }
255 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
256 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
257
258 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
259 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
260 }))),
261 ...);
262
263 return TS.Default([&](Attribute a) {
264 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
265 });
266 }
267 friend Impl;
268};
269
270class ClonedBodyConstReadOpPattern
271 : public SymbolUserHelper<
272 ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
273 SmallVector<Diagnostic> &diagnostics;
274
275 using super =
276 SymbolUserHelper<ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
277
278public:
279 ClonedBodyConstReadOpPattern(
280 TypeConverter &converter, MLIRContext *ctx,
281 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
282 SmallVector<Diagnostic> &instantiationDiagnostics
283 )
284 // benefit>0 so this applies instead of GeneralTypeReplacePattern<ConstReadOp>
285 : super(converter, ctx, /*patternBenefit=*/1, paramNameToInstantiatedValue),
286 diagnostics(instantiationDiagnostics) {}
287
288 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
289
290 LogicalResult handleRewrite(
291 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
292 ) const {
293 APInt attrValue = a.getValue();
294 Type origResTy = op.getType();
295 if (FeltType ty = llvm::dyn_cast<FeltType>(origResTy)) {
297 rewriter, op, FeltConstAttr::get(getContext(), attrValue, ty)
298 );
299 return success();
300 }
301
302 if (llvm::isa<IndexType>(origResTy)) {
304 return success();
305 }
306
307 if (origResTy.isSignlessInteger(1)) {
308 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
309 if (attrValue.isZero()) {
310 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
311 return success();
312 }
313 if (!attrValue.isOne()) {
314 Location opLoc = op.getLoc();
315 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
316 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
317 if (getContext()->shouldPrintOpOnDiagnostic()) {
318 diag.attachNote(opLoc) << "see current operation: " << *op;
319 }
320 diag.attachNote(UnknownLoc::get(getContext()))
321 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \"" << sym
322 << "\" for this call";
323 diagnostics.push_back(std::move(diag));
324 }
325 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
326 return success();
327 }
328 return op->emitOpError().append("unexpected result type ", origResTy);
329 }
330
331 LogicalResult handleRewrite(
332 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
333 ) const {
334 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
335 return success();
336 }
337};
338
341struct MatchFailureListener : public RewriterBase::Listener {
342 bool hadFailure = false;
343
344 ~MatchFailureListener() override {}
345
346 void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
347 hadFailure = true;
348
349 InFlightDiagnostic diag = emitError(loc);
350 reasonCallback(*diag.getUnderlyingDiagnostic());
351 diag.report();
352 }
353};
354
355static LogicalResult
356applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
357 bool currStepModified = false;
358 MatchFailureListener failureListener;
359 LogicalResult result = applyPatternsGreedily(
360 modOp->getRegion(0), std::move(patterns),
361 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
362 &currStepModified
363 );
364 tracker.updateModifiedFlag(currStepModified);
365 return failure(result.failed() || failureListener.hadFailure);
366}
367
368template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
369 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
370 return isConcreteType(tyAttr.getValue(), AllowStructParams);
371 }
372 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
373 return !isDynamic(intAttr);
374 }
375 return false;
376}
377
378static SymbolRefAttr
379convertCalleeSymRefs(SymbolRefAttr callee, const DenseMap<Attribute, Attribute> &paramNameToValue) {
380 auto it = paramNameToValue.find(FlatSymbolRefAttr::get(callee.getRootReference()));
381 if (it == paramNameToValue.end()) {
382 return callee;
383 }
384
385 auto tyAttr = llvm::dyn_cast<TypeAttr>(it->second);
386 if (!tyAttr) {
387 return callee;
388 }
389
390 auto structTy = llvm::dyn_cast<StructType>(tyAttr.getValue());
391 if (!structTy) {
392 return callee;
393 }
394
395 SmallVector<FlatSymbolRefAttr> newPieces = getPieces(structTy.getNameRef());
396 llvm::append_range(newPieces, callee.getNestedReferences());
397 return asSymbolRefAttr(newPieces);
398}
399
400static void
401convertCalleesInPlace(Operation *op, const DenseMap<Attribute, Attribute> &paramNameToValue) {
402 op->walk([&paramNameToValue](CallOp callOp) {
403 callOp.setCalleeAttr(convertCalleeSymRefs(callOp.getCalleeAttr(), paramNameToValue));
404 });
405}
406
407static bool calleeReferencesTemplateParam(CallOp op) {
408 SymbolRefAttr callee = op.getCalleeAttr();
409 if (!callee || callee.getNestedReferences().size() != 1) {
410 return false;
411 }
412 TemplateOp parentTemplate = getParentOfType<TemplateOp>(op);
413 if (!parentTemplate) {
414 return false;
415 }
416 return parentTemplate.hasConstNamed<TemplateParamOp>(callee.getRootReference());
417}
418
423static std::optional<Attribute>
424evaluateExpr(TemplateExprOp exprOp, const DenseMap<Attribute, Attribute> &paramNameToConcrete) {
425 // Map from SSA value in the expr body to its concrete Attribute.
426 DenseMap<Value, Attribute> valueMap;
427 for (Operation &bodyOp : exprOp.getInitializerRegion().front()) {
428 if (auto yieldOp = llvm::dyn_cast<YieldOp>(bodyOp)) {
429 auto it = valueMap.find(yieldOp.getVal());
430 return it != valueMap.end() ? std::make_optional(it->second) : std::nullopt;
431 }
432
433 if (auto constReadOp = llvm::dyn_cast<ConstReadOp>(bodyOp)) {
434 auto it = paramNameToConcrete.find(constReadOp.getConstNameAttr());
435 if (it == paramNameToConcrete.end()) {
436 return std::nullopt; // a referenced param is not concrete
437 }
438 // If the attribute type is `FeltType` but it's stored as an IntegerAttr, promote to
439 // a `FeltConstAttr`.
440 Attribute val = it->second;
441 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(val)) {
442 if (auto feltTy = llvm::dyn_cast<FeltType>(constReadOp.getResult().getType())) {
443 val = FeltConstAttr::get(bodyOp.getContext(), intAttr.getValue(), feltTy);
444 }
445 }
446 valueMap[constReadOp.getResult()] = val;
447 continue;
448 }
449
450 // Gather constant attributes for all operands.
451 SmallVector<Attribute> operandAttrs;
452 operandAttrs.reserve(bodyOp.getNumOperands());
453 for (Value operand : bodyOp.getOperands()) {
454 auto it = valueMap.find(operand);
455 if (it == valueMap.end()) {
456 return std::nullopt; // operand not known as a constant
457 }
458 operandAttrs.push_back(it->second);
459 }
460
461 // Try constant folding.
462 SmallVector<OpFoldResult> foldResults;
463 if (succeeded(bodyOp.fold(operandAttrs, foldResults)) &&
464 foldResults.size() == bodyOp.getNumResults()) {
465 for (auto [result, fr] : llvm::zip_equal(bodyOp.getResults(), foldResults)) {
466 if (Attribute a = llvm::dyn_cast<Attribute>(fr)) {
467 valueMap[result] = a;
468 } else {
469 return std::nullopt;
470 }
471 }
472 }
473 }
474 return std::nullopt; // no YieldOp found (shouldn't happen in a valid expr)
475}
476
480static void
481evaluateTemplateExprs(TemplateOp templateOp, DenseMap<Attribute, Attribute> &paramNameToConcrete) {
482 LLVM_DEBUG(
483 llvm::dbgs() << "[evaluateTemplateExprs] before: " << debug::toStringList(paramNameToConcrete)
484 << '\n'
485 );
486 for (TemplateExprOp exprOp : templateOp.getConstOps<TemplateExprOp>()) {
487 std::optional<Attribute> result = evaluateExpr(exprOp, paramNameToConcrete);
488 if (result.has_value()) {
489 auto exprNameAttr = FlatSymbolRefAttr::get(exprOp.getSymNameAttr());
490 paramNameToConcrete.try_emplace(exprNameAttr, *result);
491 LLVM_DEBUG(
492 llvm::dbgs() << "[evaluateTemplateExprs] expr @" << exprOp.getSymName()
493 << " evaluated to " << *result << '\n'
494 );
495 }
496 }
497 LLVM_DEBUG(
498 llvm::dbgs() << "[evaluateTemplateExprs] after: " << debug::toStringList(paramNameToConcrete)
499 << '\n'
500 );
501}
502
504
505static inline bool tableOffsetIsntSymbol(MemberReadOp op) {
506 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
507}
508
511class StructCloner {
512 ConversionTracker &tracker_;
513 ModuleOp rootMod;
514 SymbolTableCollection symTables;
515 bool reportMissing = true;
516
517 class MappedTypeConverter : public TypeConverter {
518 StructType origTy;
519 StructType newTy;
520 const DenseMap<Attribute, Attribute> &paramNameToValue;
521
522 inline Attribute convertIfPossible(Attribute a) const {
523 auto res = this->paramNameToValue.find(a);
524 return (res != this->paramNameToValue.end()) ? res->second : a;
525 }
526
527 public:
528 MappedTypeConverter(
529 StructType originalType, StructType newType,
531 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
532 )
533 : TypeConverter(), origTy(originalType), newTy(newType),
534 paramNameToValue(paramNameToInstantiatedValue) {
535
536 addConversion([](Type inputTy) { return inputTy; });
537
538 addConversion([this](StructType inputTy) {
539 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
540
541 // Check for replacement of the full type
542 if (inputTy == this->origTy) {
543 return this->newTy;
544 }
545 // Check for replacement of parameter symbol names with concrete values
546 if (ArrayAttr inputTyParams = inputTy.getParams()) {
547 SmallVector<Attribute> updated;
548 for (Attribute a : inputTyParams) {
549 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
550 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
551 } else {
552 updated.push_back(convertIfPossible(a));
553 }
554 }
555 return StructType::get(
556 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
557 );
558 }
559 // Otherwise, return the type unchanged
560 return inputTy;
561 });
562
563 addConversion([this](ArrayType inputTy) {
564 // Check for replacement of parameter symbol names with concrete values
565 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
566 if (!dimSizes.empty()) {
567 SmallVector<Attribute> updated;
568 for (Attribute a : dimSizes) {
569 updated.push_back(convertIfPossible(a));
570 }
571 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
572 }
573 // Otherwise, return the type unchanged
574 return inputTy;
575 });
576
577 addConversion([this](TypeVarType inputTy) -> Type {
578 // Check for replacement of parameter symbol name with a concrete type
579 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
580 Type convertedType = tyAttr.getValue();
581 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
582 // different struct references a parameter name from that other struct, not from the
583 // current struct so the reference would be invalid.
584 if (isConcreteType(convertedType)) {
585 return convertedType;
586 }
587 }
588 return inputTy;
589 });
590 }
591 };
592
593 class ClonedStructMemberReadOpPattern
594 : public SymbolUserHelper<
595 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
596 using super =
597 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
598
599 public:
600 ClonedStructMemberReadOpPattern(
601 TypeConverter &converter, MLIRContext *ctx,
602 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
603 )
604 // benefit>0 so this applies instead of GeneralTypeReplacePattern<MemberReadOp>
605 : super(converter, ctx, /*patternBenefit=*/1, paramNameToInstantiatedValue) {}
606
607 Attribute getNameAttr(MemberReadOp op) const override {
608 return op.getTableOffset().value_or(nullptr);
609 }
610
611 template <typename Attr>
612 LogicalResult handleRewrite(
613 Attribute, MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
614 ) const {
615 rewriter.modifyOpInPlace(op, [&]() {
616 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
617 });
618
619 return success();
620 }
621
622 LogicalResult matchAndRewrite(
623 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
624 ) const override {
625 LLVM_DEBUG(
626 llvm::dbgs() << "[ClonedStructMemberReadOpPattern] MemberReadOp: " << op << '\n';
627 );
628 if (tableOffsetIsntSymbol(op)) {
629 return failure();
630 }
631
632 return super::matchAndRewrite(op, adaptor, rewriter);
633 }
634 };
635
636 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
637 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] attempting clone of " << typeAtCaller << '\n');
638 // Find the StructDefOp for the original StructType
639 FailureOr<SymbolLookupResult<StructDefOp>> r =
640 typeAtCaller.getDefinition(symTables, rootMod, reportMissing);
641 if (failed(r)) {
642 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
643 return failure(); // getDefinition() already emits a sufficient error message
644 }
645 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] found definition\n";);
646
647 StructDefOp origStruct = r->get();
648 StructType typeAtDef = origStruct.getType();
649 MLIRContext *ctx = origStruct.getContext();
650
651 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
652 DenseMap<Attribute, Attribute> paramNameToConcrete;
653 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
654 // where the original attribute from the current instantiation site was not concrete. This is
655 // used for generating the new struct name. See `BuildShortTypeString::from()`.
656 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
657 // List of template const param names that must be preserved because they
658 // were not assigned concrete values at the current instantiation site.
659 SmallVector<Attribute> remainingNames;
660 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
661 ArrayAttr reducedCallerParams = nullptr;
662 {
663 ArrayAttr paramNames = typeAtDef.getParams();
664
665 // pre-conditions
666 assert(!isNullOrEmpty(paramNames));
667 assert(paramNames.size() == typeAtCallerParams.size());
668
669 SmallVector<Attribute> nonConcreteParams;
670 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
671 Attribute next = typeAtCallerParams[i];
672 if (isConcreteAttr<false>(next)) {
673 paramNameToConcrete[paramNames[i]] = next;
674 attrsForInstantiatedNameSuffix.push_back(next);
675 } else {
676 remainingNames.push_back(paramNames[i]);
677 nonConcreteParams.push_back(next);
678 attrsForInstantiatedNameSuffix.push_back(nullptr);
679 }
680 }
681 // post-conditions
682 assert(remainingNames.size() == nonConcreteParams.size());
683 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
684 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
685
686 if (paramNameToConcrete.empty()) {
687 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
688 return failure();
689 }
690 if (!remainingNames.empty()) {
691 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
692 }
693 }
694
695 // This list will be used to build the new remote/external type.
696 SmallVector<FlatSymbolRefAttr> typeAtCallerSymPieces = getPieces(typeAtCaller.getNameRef());
697 typeAtCallerSymPieces.pop_back(); // drop struct name
698 // Name of template with instantiated parameter values.
699 std::string templateNameWithAttrs = BuildShortTypeString::from(
700 typeAtCallerSymPieces.back().getValue().str(), attrsForInstantiatedNameSuffix
701 );
702
703 // Get parent refs
704 TemplateOp parentTemplate = getParentOfType<TemplateOp>(origStruct);
705 assert(parentTemplate && "parameterized struct must be nested in a TemplateOp");
706 ModuleOp parentModule = getParentOfType<ModuleOp>(parentTemplate);
707 assert(parentModule && "TemplateOp must be nested in a ModuleOp");
708
709 // Evaluate any poly.expr symbols whose param dependencies are now concrete; add them to the
710 // map so ClonedBodyConstReadOpPattern can replace uses of those symbols too.
711 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
712
713 // Clone the original struct.
714 StructDefOp newStruct = origStruct.clone();
715 convertCalleesInPlace(newStruct, paramNameToConcrete);
716 if (remainingNames.empty()) { // FULL INSTANTIATION CASE
717 // Set name of the new struct by prepending its name with instantiated template name.
718 newStruct.setSymName(
719 (templateNameWithAttrs + mlir::Twine('_') + newStruct.getSymName()).str()
720 );
721 // Insert 'newStruct' into the parent ModuleOp of the original TemplateOp. Use the
722 // `SymbolTable::insert()` function so that the name will be made unique if necessary.
723 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(parentTemplate));
724 // Drop the old template name from the list.
725 typeAtCallerSymPieces.pop_back();
726 } else { // PARTIAL INSTANTIATION CASE
727 // Clone the template and set instantiated name.
728 TemplateOp newTemplate = parentTemplate.cloneWithoutRegions();
729 newTemplate.setSymName(templateNameWithAttrs);
730 assert(newTemplate->getNumRegions() > 0 && "region exists"); // it just doesn't have a block
731 newTemplate.getBodyRegion().emplaceBlock();
732
733 // Clone preserved const param/expr ops.
734 for (Attribute name : remainingNames) {
735 FlatSymbolRefAttr nameSym = llvm::dyn_cast<FlatSymbolRefAttr>(name);
736 assert(nameSym && "expected FlatSymbolRefAttr");
737
738 Operation *symOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
739 assert(symOp && "symbol must exist");
740 newTemplate.insert(newTemplate.begin(), symOp->clone());
741 }
742
743 // Insert the struct into the template and the template into the module. Use the
744 // `SymbolTable::insert()` function so that the name will be made unique if necessary.
745 symTables.getSymbolTable(newTemplate).insert(newStruct);
746 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
747
748 // Replace the old template name in the list with the new one (get template name after
749 // symbol table insertion since it may be modified to make it unique).
750 typeAtCallerSymPieces.back() = FlatSymbolRefAttr::get(newTemplate.getSymNameAttr());
751 }
752
753 // Retrieve the new type AFTER inserting since the struct name may be appended to make
754 // it unique and use the remaining non-concrete parameters from the original type.
755 StructType newLocalType = newStruct.getType(reducedCallerParams);
756 typeAtCallerSymPieces.push_back(
757 FlatSymbolRefAttr::get(newLocalType.getNameRef().getLeafReference())
758 );
759 StructType newRemoteType =
760 StructType::get(asSymbolRefAttr(typeAtCallerSymPieces), newLocalType.getParams());
761 LLVM_DEBUG({
762 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
763 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
764 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
765 llvm::dbgs() << "[StructCloner] cloned local type: " << newLocalType << '\n';
766 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
767 });
768
769 // Within the new struct, replace all references to the original StructType (i.e., the
770 // locally-parameterized version) with the new locally-parameterized StructType,
771 // and replace all uses of the removed struct parameters with the concrete values.
772 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
773 ConversionTarget target =
774 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
775 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
776 // Legal if it's not in the map of concrete attribute instantiations
777 return !paramNameToConcrete.contains(op.getConstNameAttr());
778 });
779
780 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
781 patterns.add<ClonedBodyConstReadOpPattern>(
782 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
783 );
784 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
785 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
786 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
787 return failure();
788 }
789 return newRemoteType;
790 }
791
792public:
793 StructCloner(ConversionTracker &tracker, ModuleOp root)
794 : tracker_(tracker), rootMod(root), symTables() {}
795
796 FailureOr<StructType> createInstantiatedClone(StructType orig) {
797 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
798 if (ArrayAttr params = orig.getParams()) {
799 return genClone(orig, params.getValue());
800 }
801 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
802 return failure();
803 }
804
805 void enableReportMissing() { reportMissing = true; }
806
807 void disableReportMissing() { reportMissing = false; }
808};
809
810class DisableReportMissing;
811
812class ParameterizedStructUseTypeConverter : public TypeConverter {
813 ConversionTracker &tracker_;
814 StructCloner cloner;
815
816 friend DisableReportMissing;
817
818public:
819 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
820 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
821
822 addConversion([](Type inputTy) { return inputTy; });
823
824 addConversion([this](StructType inputTy) -> StructType {
825 LLVM_DEBUG(
826 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] attempting conversion of "
827 << inputTy << '\n';
828 );
829 // First check for a cached entry
830 if (auto opt = tracker_.getInstantiation(inputTy)) {
831 return opt.value();
832 }
833
834 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
835 // done, return the original type to indicate that it's still legal (for this step at least).
836 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
837 if (failed(cloneRes)) {
838 return inputTy;
839 }
840 StructType newTy = cloneRes.value();
841 LLVM_DEBUG(
842 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
843 << " as " << newTy << '\n'
844 );
845 tracker_.recordInstantiation(inputTy, newTy);
846 return newTy;
847 });
848
849 addConversion([this](ArrayType inputTy) {
850 return inputTy.cloneWith(convertType(inputTy.getElementType()));
851 });
852 }
853};
854
/// Conversion pattern for `CallOp` that converts the call's result types with the type converter
/// and, when the callee is a struct's compute or constrain function, rewrites the callee symbol so
/// it is rooted at the converted struct type's name.
class CallStructFuncPattern : public OpConversionPattern<CallOp> {
  ConversionTracker &tracker_;

public:
  CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
      // benefit>0 so this applies instead of CallOpClassReplacePattern
      : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/1), tracker_(tracker) {}

  LogicalResult matchAndRewrite(
      CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
  ) const override {
    LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');

    // Convert the result types of the CallOp
    SmallVector<Type> newResultTypes;
    if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
      return op->emitError("Could not convert Op result types.");
    }
    LLVM_DEBUG({
      llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
                   << debug::toStringList(newResultTypes) << '\n';
    });

    // Update the callee to reflect the new struct target if necessary. These checks are based on
    // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
    // Instead they must come from the converted versions.
    SymbolRefAttr calleeAttr = op.getCalleeAttr();
    if (op.calleeIsStructCompute()) {
      // For compute(), the struct type is taken from the single converted result type.
      if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
        LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
        calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
        tracker_.reportDelayedDiagnostics(newStTy, op);
      }
    } else if (op.calleeIsStructConstrain()) {
      // For constrain(), the struct type is taken from the type of the first (already converted)
      // argument operand.
      if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
        LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
        calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
      }
    }

    // Replace the call with one using the converted result types, the (possibly) updated callee,
    // and the adaptor's (converted) map and argument operands.
    LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
        rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
        op.getNumDimsPerMapAttr(), adapter.getArgOperands()
    );
    (void)newOp; // tell compiler it's intentionally unused in release builds
    LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
    return success();
  }
};
905
906// This one ensures MemberDefOp types are converted even if there are no reads/writes to them.
907class MemberDefOpPattern : public OpConversionPattern<MemberDefOp> {
908public:
909 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
910 // benefit>0 so this applies instead of GeneralTypeReplacePattern<MemberDefOp>
911 : OpConversionPattern<MemberDefOp>(converter, ctx, /*benefit=*/1) {}
912
913 LogicalResult matchAndRewrite(
914 MemberDefOp op, OpAdaptor /*adapter*/, ConversionPatternRewriter &rewriter
915 ) const override {
916 LLVM_DEBUG(llvm::dbgs() << "[MemberDefOpPattern] MemberDefOp: " << op << '\n');
917
918 Type oldMemberType = op.getType();
919 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
920 if (oldMemberType == newMemberType) {
921 return failure(); // nothing changed
922 }
923 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.setType(newMemberType); });
924 return success();
925 }
926};
927
930class DisableReportMissing : public LegalityCheckCallback {
931 ParameterizedStructUseTypeConverter &tyConv;
932
933public:
934 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
935
936 void checkStarted() override { tyConv.cloner.disableReportMissing(); }
937
938 void checkEnded(bool) override { tyConv.cloner.enableReportMissing(); }
939};
940
941LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
942 MLIRContext *ctx = modOp.getContext();
943 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
944 DisableReportMissing drm(tyConv);
945 ConversionTarget target = newConverterDefinedTargetWithCallback<>(tyConv, ctx, drm);
946 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
947 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
948 return applyPartialConversion(modOp, target, std::move(patterns));
949}
950
951} // namespace Step1A_InstantiateStructs
952
954
957class FuncInstTypeConverter : public TypeConverter {
958 DenseMap<Attribute, Attribute> paramNameToValue;
959
960 Attribute convertIfPossible(Attribute a) const {
961 auto res = paramNameToValue.find(a);
962 return (res != paramNameToValue.end()) ? res->second : a;
963 }
964
965public:
966 explicit FuncInstTypeConverter(DenseMap<Attribute, Attribute> paramNameToConcrete)
967 : TypeConverter(), paramNameToValue(std::move(paramNameToConcrete)) {
968 addConversion([](Type t) { return t; });
969
970 addConversion([this](TypeVarType inputTy) -> Type {
971 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
972 Type convertedType = tyAttr.getValue();
973 if (isConcreteType(convertedType)) {
974 return convertedType;
975 }
976 }
977 return inputTy;
978 });
979
980 addConversion([this](ArrayType inputTy) {
981 SmallVector<Attribute> updated;
982 bool changed = false;
983 for (Attribute a : inputTy.getDimensionSizes()) {
984 Attribute converted = convertIfPossible(a);
985 updated.push_back(converted);
986 if (converted != a) {
987 changed = true;
988 }
989 }
990 Type newElemTy = this->convertType(inputTy.getElementType());
991 if (!changed && newElemTy == inputTy.getElementType()) {
992 return inputTy;
993 }
994 return ArrayType::get(newElemTy, updated);
995 });
996
997 addConversion([this](StructType inputTy) -> StructType {
998 if (ArrayAttr params = inputTy.getParams()) {
999 SmallVector<Attribute> updated;
1000 bool changed = false;
1001 for (Attribute a : params) {
1002 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
1003 Type newTy = this->convertType(ta.getValue());
1004 if (newTy != ta.getValue()) {
1005 updated.push_back(TypeAttr::get(newTy));
1006 changed = true;
1007 continue;
1008 }
1009 } else {
1010 Attribute converted = convertIfPossible(a);
1011 if (converted != a) {
1012 updated.push_back(converted);
1013 changed = true;
1014 continue;
1015 }
1016 }
1017 updated.push_back(a);
1018 }
1019 if (changed) {
1020 return StructType::get(
1021 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
1022 );
1023 }
1024 }
1025 return inputTy;
1026 });
1027 }
1028
1029 bool containsParam(Attribute nameAttr) const { return paramNameToValue.contains(nameAttr); }
1030 const DenseMap<Attribute, Attribute> &getParamMap() const { return paramNameToValue; }
1031};
1032
1033class InstantiateFuncAtCallOp final : public OpRewritePattern<CallOp> {
1034 ConversionTracker &tracker_;
1035
1036public:
1037 InstantiateFuncAtCallOp(MLIRContext *ctx, ConversionTracker &tracker)
1038 : OpRewritePattern<CallOp>(ctx), tracker_(tracker) {}
1039
1040 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1041 LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] op: " << op << '\n');
1042
1043 if (calleeReferencesTemplateParam(op)) {
1044 return failure();
1045 }
1046
1047 // Lookup callee target function
1048 SymbolTableCollection symTables;
1049 FailureOr<SymbolLookupResult<FuncDefOp>> callTgtOpt = op.getCalleeTarget(symTables);
1050 if (failed(callTgtOpt)) {
1051 return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
1052 diag << "could not find target function for call";
1053 });
1054 }
1055 FuncDefOp callTgt = callTgtOpt->get();
1056
1057 // Check if callee is within a TemplateOp
1058 TemplateOp parentTemplate = llvm::dyn_cast<TemplateOp>(callTgt->getParentOp());
1059 if (!parentTemplate) {
1060 return failure(); // nothing to do if not parameterized
1061 }
1062 LLVM_DEBUG(
1063 llvm::dbgs() << "[InstantiateFuncAtCallOp] target function in template "
1064 << parentTemplate.getSymName() << '\n'
1065 );
1066
1067 // Perform type unification with tracking to infer the instantiated type(s). Even though
1068 // `CallOp` verification already checked that caller and callee types unify, the progress of
1069 // instantiation so far may have brought together a chain of calls across templates where each
1070 // individual unification check passed due to permissive type variables and/or symbols in the
1071 // middle but the overall chain does not unify. Hence, this unification may fail and should
1072 // produce a meaningful error message if it does.
1073 // See: `test/Transforms/Flattening/instantiate_funcs_fail.llzk`
1074 FailureOr<UnificationMap> unifyResult = op.unifyTypeSignature(callTgt.getFunctionType());
1075 if (failed(unifyResult)) {
1076 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1077 diag.append("target function type does not unify with call type ")
1078 .append(op.getTypeSignature())
1079 .attachNote(callTgt.getLoc())
1080 .append("target function declared here");
1081 });
1082 }
1083 LLVM_DEBUG(
1084 llvm::dbgs() << "[InstantiateFuncAtCallOp] unifications of types: "
1085 << debug::toStringList(unifyResult.value()) << '\n'
1086 );
1087
1088 // Maps template parameter symbols to the instantiation value at the call site.
1089 DenseMap<Attribute, Attribute> paramNameToConcrete;
1090 // If template instantiation list is given, must use that. Otherwise, infer.
1091 auto realParams = parentTemplate.getConstOps<TemplateParamOp>();
1092 ArrayAttr callParams = op.getTemplateParamsAttr();
1093 LLVM_DEBUG(
1094 llvm::dbgs() << "[InstantiateFuncAtCallOp] TemplateParamsAttr: " << callParams << '\n'
1095 );
1096 if (isNullOrEmpty(callParams)) {
1097 for (auto paramOp : realParams) {
1098 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1099 auto it = unifyResult->find({paramName, Side::RHS});
1100 if (it == unifyResult->end()) {
1101 LLVM_DEBUG(
1102 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1103 << "': not found\n"
1104 );
1105 continue;
1106 }
1107 Attribute inferredVal = it->second;
1108 if (!isConcreteAttr(inferredVal)) {
1109 LLVM_DEBUG(
1110 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1111 << "': not concrete, " << inferredVal << '\n'
1112 );
1113 continue;
1114 }
1115 // Ensure it's a valid value for the optional type restriction on the TemplateParamOp
1116 if (failed(op.verifyTemplateParamCompatibility(inferredVal, paramOp))) {
1117 LLVM_DEBUG(
1118 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1119 << "': incompatible with specified param type. MUST FAIL!\n"
1120 );
1121 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1122 diag.append("inferred value for parameter '")
1123 .append(paramName)
1124 .append("' is incompatible with specified param type")
1125 .attachNote(paramOp.getLoc())
1126 .append("template parameter declared here");
1127 });
1128 }
1129 paramNameToConcrete[paramName] = inferredVal;
1130 }
1131 } else {
1132 // As stated earlier, need to run the verification checks again to ensure the
1133 // instantiation is valid, except for the size check becuase that cannot change.
1134 assert((callParams.size() == llvm::range_size(realParams)) && "per CallOpVerifier");
1135 if (failed(op.verifyTemplateParamCompatibility(realParams))) {
1136 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1137 diag.append("incompatible with specified param type(s)");
1138 });
1139 }
1140 if (failed(op.verifyTemplateParamsMatchInferred(realParams, unifyResult.value()))) {
1141 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1142 diag.append("incompatible with inferred param value(s)");
1143 });
1144 }
1145 // Add the mappings
1146 for (auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) {
1147 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1148 if (!isConcreteAttr(attr)) {
1149 LLVM_DEBUG(
1150 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1151 << "': not concrete, " << attr << '\n'
1152 );
1153 continue;
1154 }
1155 paramNameToConcrete[paramName] = attr;
1156 }
1157 }
1158
1159 if (paramNameToConcrete.empty()) {
1160 LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] skip: no concrete params\n");
1161 return failure();
1162 }
1163
1164 // Evaluate any poly.expr symbols whose param dependencies are now concrete; add them to the
1165 // map so ClonedFuncConstReadOpPattern can replace uses of those symbols too.
1166 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
1167
1168 // Classify each template parameter as concrete (to be inlined) or remaining (to be preserved).
1169 SmallVector<Attribute> remainingNames;
1170 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
1171 for (Attribute paramName : parentTemplate.getConstNames<TemplateParamOp>()) {
1172 auto it = paramNameToConcrete.find(paramName);
1173 if (it != paramNameToConcrete.end()) {
1174 attrsForInstantiatedNameSuffix.push_back(it->second);
1175 } else {
1176 attrsForInstantiatedNameSuffix.push_back(nullptr); // placeholder for non-concrete param
1177 remainingNames.push_back(paramName);
1178 }
1179 }
1180
1181 MLIRContext *ctx = op.getContext();
1182 ModuleOp parentModule = getParentOfType<ModuleOp>(parentTemplate);
1183 assert(parentModule && "TemplateOp must be nested in a ModuleOp");
1184
1185 // Build the (partially-)instantiated template name, e.g., "TemplateName_8_\x1A" where \x1A
1186 // is a placeholder character at the position of a non-concrete parameter.
1187 std::string templateNameWithAttrs = BuildShortTypeString::from(
1188 parentTemplate.getSymName().str(), attrsForInstantiatedNameSuffix
1189 );
1190
1191 // Helper lambda to:
1192 // 1. build the FuncInstTypeConverter and apply it to a cloned function
1193 // 2. verify CallOp in the converted function are valid for their respective targets
1194 // and emit a more helpful error at this point rather than discovering it later
1195 // when verifying the entire module.
1196 auto applyBodyConversions = [&](FuncDefOp newFunc) -> LogicalResult {
1197 FuncInstTypeConverter tyConv(paramNameToConcrete);
1198 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
1199 target.addDynamicallyLegalOp<ConstReadOp>([&tyConv](ConstReadOp p) {
1200 // Legal if it's not in the map of concrete attribute instantiations
1201 return !tyConv.containsParam(p.getConstNameAttr());
1202 });
1203 SmallVector<Diagnostic> delayedDiagnostics;
1204 RewritePatternSet bodyPatterns = newGeneralRewritePatternSet(tyConv, ctx, target);
1205 bodyPatterns.add<ClonedBodyConstReadOpPattern>(
1206 tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics
1207 );
1208 if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) {
1209 return failure();
1210 }
1211 LLVM_DEBUG(
1212 llvm::dbgs() << "[InstantiateFuncAtCallOp] instantiated clone: " << newFunc << '\n'
1213 );
1214 ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics));
1215
1216 // Verify CallOp match targets
1217 SymbolTableCollection tables;
1218 WalkResult res = newFunc.walk([&tables](CallOp nestedCall) {
1219 return WalkResult(nestedCall.verifySymbolUses(tables));
1220 });
1221 return failure(res.wasInterrupted());
1222 };
1223
1224 SmallVector<FlatSymbolRefAttr> symPieces = getPieces(op.getCalleeAttr());
1225 assert(symPieces.size() >= 2 && "callee must include at least template and function names");
1226 SymbolRefAttr originalCalleeAttr = asSymbolRefAttr(symPieces);
1227 if (remainingNames.empty()) {
1228 // FULL INSTANTIATION: place the cloned function directly in the parent module.
1229 // New function name encodes all parameter values, e.g., "TemplateName_8_12_funcName".
1230 std::string newFuncName =
1231 (mlir::Twine(templateNameWithAttrs) + "_" + callTgt.getSymName()).str();
1232 StringRef actualNewFuncName = newFuncName;
1233 if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) {
1234 FuncDefOp newFunc = callTgt.clone();
1235 newFunc.setSymName(newFuncName);
1236 convertCalleesInPlace(newFunc, paramNameToConcrete);
1237 // Insert before the TemplateOp; symbol table may adjust the name to ensure uniqueness.
1238 symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate));
1239 actualNewFuncName = newFunc.getSymName();
1240 LLVM_DEBUG(
1241 llvm::dbgs() << "[InstantiateFuncAtCallOp] created full instantiation function: "
1242 << actualNewFuncName << '\n'
1243 );
1244 if (failed(applyBodyConversions(newFunc))) {
1245 LLVM_DEBUG(
1246 llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for "
1247 << actualNewFuncName << '\n'
1248 );
1249 newFunc->erase();
1250 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1251 diag.append("failure while creating instantiated function '", actualNewFuncName, '\'');
1252 });
1253 }
1254 } else {
1255 LLVM_DEBUG(
1256 llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing full instantiation function: "
1257 << actualNewFuncName << '\n'
1258 );
1259 }
1260 // Callee: drop template & original function names, add the new module-level function name.
1261 // Original: @[prefix...]::@TemplateName::@funcName
1262 // New: @[prefix...]::@newFuncName
1263 symPieces.pop_back(); // remove original function name
1264 symPieces.pop_back(); // remove template name
1265 symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName)));
1266 } else {
1267 // PARTIAL INSTANTIATION: place the cloned function in a new partially-instantiated
1268 // TemplateOp that retains only the non-concrete parameters.
1269 // New template name encodes the concrete values and uses placeholder chars for the rest,
1270 // e.g., "TemplateName_8_\x1A" where \x1A marks the position of a non-concrete param.
1271 TemplateOp newTemplate;
1272 if (Operation *existing =
1273 symTables.getSymbolTable(parentModule).lookup(templateNameWithAttrs)) {
1274 newTemplate = llvm::dyn_cast<TemplateOp>(existing);
1275 }
1276 if (!newTemplate) {
1277 // Clone the TemplateOp structure without its body and set the new name.
1278 newTemplate = parentTemplate.cloneWithoutRegions();
1279 newTemplate.setSymName(templateNameWithAttrs);
1280 assert(newTemplate->getNumRegions() > 0 && "region exists");
1281 newTemplate.getBodyRegion().emplaceBlock();
1282
1283 // Clone the preserved (non-concrete) param/expr ops into the new template in order.
1284 Block &newTemplateBody = newTemplate.getBodyRegion().front();
1285 for (Attribute name : remainingNames) {
1286 FlatSymbolRefAttr nameSym = llvm::cast<FlatSymbolRefAttr>(name);
1287 Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
1288 assert(paramOp && "symbol must exist");
1289 newTemplateBody.push_back(paramOp->clone());
1290 }
1291
1292 // Clone and partially convert the function (concretize only the concrete params).
1293 FuncDefOp newFunc = callTgt.clone();
1294 convertCalleesInPlace(newFunc, paramNameToConcrete);
1295
1296 // Insert before body conversion so nested concrete callees verify from the root module. Use
1297 // the `SymbolTable::insert()` function so that the name will be made unique if necessary.
1298 symTables.getSymbolTable(newTemplate).insert(newFunc);
1299 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
1300 if (failed(applyBodyConversions(newFunc))) {
1301 StringRef newFuncName = newFunc.getSymName();
1302 LLVM_DEBUG(
1303 llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for "
1304 << newFuncName << '\n'
1305 );
1306 newTemplate->erase();
1307 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1308 diag.append("failure while creating instantiated function '", newFuncName, '\'');
1309 });
1310 }
1311
1312 LLVM_DEBUG(
1313 llvm::dbgs() << "[InstantiateFuncAtCallOp] created partial instantiation template: "
1314 << newTemplate.getSymName() << '\n'
1315 );
1316 } else {
1317 LLVM_DEBUG(
1318 llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing partial instantiation template: "
1319 << newTemplate.getSymName() << '\n'
1320 );
1321 }
1322 // Callee: replace old template name with new template name, keep the function name.
1323 // Original: @[prefix...]::@TemplateName::@funcName
1324 // New: @[prefix...]::@newTemplateName::@funcName
1325 symPieces.pop_back(); // remove original function name (will be re-appended)
1326 symPieces.pop_back(); // remove original template name
1327 symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.getSymNameAttr()));
1328 symPieces.push_back(FlatSymbolRefAttr::get(callTgt.getSymNameAttr()));
1329 }
1330
1331 tracker_.recordInstantiation(originalCalleeAttr);
1332
1333 // Update the CallOp to point to the instantiated function and mark the module as modified.
1334 rewriter.modifyOpInPlace(op, [&op, &symPieces]() {
1335 // Update callee attribute.
1336 SymbolRefAttr newCalleeAttr = asSymbolRefAttr(symPieces);
1337 LLVM_DEBUG({
1338 llvm::dbgs() << "[InstantiateFuncAtCallOp] updating callee from " << op.getCalleeAttr()
1339 << " to " << newCalleeAttr << '\n';
1340 });
1341 op.setCalleeAttr(newCalleeAttr);
1342 // Also drop template param list. If it was present, it was fully used (no partial case).
1343 op.setTemplateParamsAttr(nullptr);
1344 });
1345 tracker_.updateModifiedFlag(true);
1346 return success();
1347 }
1348};
1349
1350LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1351 MLIRContext *ctx = modOp.getContext();
1352 RewritePatternSet patterns(ctx);
1353 patterns.add<InstantiateFuncAtCallOp>(ctx, tracker);
1354 MatchFailureListener failureListener;
1355 walkAndApplyPatterns(modOp, std::move(patterns), &failureListener);
1356 return failure(failureListener.hadFailure);
1357}
1358
1359} // namespace Step1B_InstantiateFunctions
1360
1361namespace Step2_Unroll {
1362
1363// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
1364template <HasInterface<LoopLikeOpInterface> OpClass>
1365class LoopUnrollPattern : public OpRewritePattern<OpClass> {
1366public:
1367 using OpRewritePattern<OpClass>::OpRewritePattern;
1368
1369 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
1370 if (auto maybeConstant = getConstantTripCount(loopOp)) {
1371 uint64_t tripCount = *maybeConstant;
1372 if (tripCount == 0) {
1373 rewriter.eraseOp(loopOp);
1374 return success();
1375 } else if (tripCount == 1) {
1376 return loopOp.promoteIfSingleIteration(rewriter);
1377 }
1378 return loopUnrollByFactor(loopOp, tripCount);
1379 }
1380 return failure();
1381 }
1382
1383private:
1386 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
1387 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
1388 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
1389 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
1390 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
1391 return std::nullopt;
1392 }
1393 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
1394 }
1395};
1396
1397LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1398 MLIRContext *ctx = modOp.getContext();
1399 RewritePatternSet patterns(ctx);
1400 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
1401 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
1402
1403 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1404}
1405} // namespace Step2_Unroll
1406
1408
1409// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
1410// version uses a basic loop instead of llvm::map_to_vector().
1411std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
1412 SmallVector<int64_t> res;
1413 for (OpFoldResult ofr : ofrs) {
1414 std::optional<int64_t> cv = getConstantIntValue(ofr);
1415 if (!cv.has_value()) {
1416 return std::nullopt;
1417 }
1418 res.push_back(cv.value());
1419 }
1420 return res;
1421}
1422
1423struct AffineMapFolder {
1424 struct Input {
1425 OperandRangeRange mapOpGroups;
1426 DenseI32ArrayAttr dimsPerGroup;
1427 ArrayRef<Attribute> paramsOfStructTy;
1428 };
1429
1430 struct Output {
1431 SmallVector<SmallVector<Value>> mapOpGroups;
1432 SmallVector<int32_t> dimsPerGroup;
1433 SmallVector<Attribute> paramsOfStructTy;
1434 };
1435
1436 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
1437 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
1438 return ValueRange(grp);
1439 });
1440 }
1441
1442 static LogicalResult
1443 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
1444 if (in.mapOpGroups.empty()) {
1445 // No affine map operands so nothing to do
1446 return failure();
1447 }
1448
1449 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
1450 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
1451
1452 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
1453 for (Attribute sizeAttr : in.paramsOfStructTy) {
1454 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
1455 ValueRange currMapOps = in.mapOpGroups[idx++];
1456 LLVM_DEBUG(
1457 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
1458 << '\n'
1459 );
1460 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
1461 LLVM_DEBUG(
1462 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
1463 << debug::toStringList(currMapOpsCast) << '\n'
1464 );
1465 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
1466 SmallVector<Attribute> result;
1467 bool hasPoison = false; // indicates divide by 0 or mod by <1
1468 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
1469 return rewriter.getIndexAttr(v);
1470 });
1471 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
1472 if (hasPoison) {
1473 // Diagnostic remark: could be removed for release builds if too noisy
1474 op->emitRemark()
1475 .append(
1476 "Cannot fold affine_map for ", aspect, ' ', out.paramsOfStructTy.size(),
1477 " due to divide by 0 or modulus with negative divisor"
1478 )
1479 .report();
1480 return failure();
1481 }
1482 if (failed(foldResult)) {
1483 // Diagnostic remark: could be removed for release builds if too noisy
1484 op->emitRemark()
1485 .append(
1486 "Folding affine_map for ", aspect, ' ', out.paramsOfStructTy.size(), " failed"
1487 )
1488 .report();
1489 return failure();
1490 }
1491 if (result.size() != 1) {
1492 // Diagnostic remark: could be removed for release builds if too noisy
1493 op->emitRemark()
1494 .append(
1495 "Folding affine_map for ", aspect, ' ', out.paramsOfStructTy.size(),
1496 " produced ", result.size(), " results but expected 1"
1497 )
1498 .report();
1499 return failure();
1500 }
1501 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
1502 out.paramsOfStructTy.push_back(result[0]);
1503 continue;
1504 }
1505 // If affine but not foldable, preserve the map ops
1506 out.mapOpGroups.emplace_back(currMapOps);
1507 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
1508 }
1509 // If not affine and foldable, preserve the original
1510 out.paramsOfStructTy.push_back(sizeAttr);
1511 }
1512 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
1513 assert(
1514 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
1515 "produced wrong number of dimensions"
1516 );
1517
1518 return success();
1519 }
1520};
1521
/// Rewrite pattern for `CreateArrayOp` that folds `affine_map` dimension sizes of the result
/// array type into constants (via `AffineMapFolder`) when their map operands are constant, and
/// replaces the op with one producing the more concrete array type.
class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
  // Only used inside `assert`, hence unused in release builds.
  [[maybe_unused]]
  ConversionTracker &tracker_;

public:
  InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
      : OpRewritePattern(ctx), tracker_(tracker) {}

  LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
    ArrayType oldResultType = op.getType();

    AffineMapFolder::Output out;
    AffineMapFolder::Input in = {
        op.getMapOperands(),
        oldResultType.getDimensionSizes(),
    };
    if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
      return failure();
    }

    ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
    if (newResultType == oldResultType) {
      return failure(); // nothing changed
    }
    // ASSERT: folding only preserves the original Attribute or converts affine to integer
    assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
    LLVM_DEBUG(
        llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
                     << newResultType << " in \"" << op << "\"\n"
    );
    // Replace with a new CreateArrayOp using the folded type and only the unfolded map operands.
        rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
    );
    return success();
  }
};
1560
1562class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
1563 ConversionTracker &tracker_;
1564
1565public:
1566 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
1567 : OpRewritePattern(ctx), tracker_(tracker) {}
1568
1569 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1570 if (!op.calleeIsStructCompute()) {
1571 // this pattern only applies when the callee is "compute()" within a struct
1572 return failure();
1573 }
1574 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
1576 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
1577 ArrayAttr params = oldRetTy.getParams();
1578 if (isNullOrEmpty(params)) {
1579 // nothing to do if the StructType is not parameterized
1580 return failure();
1581 }
1582
1583 AffineMapFolder::Output out;
1584 AffineMapFolder::Input in = {
1585 op.getMapOperands(),
1587 params.getValue(),
1588 };
1589 if (!in.mapOpGroups.empty()) {
1590 // If there are affine map operands, attempt to fold them to a constant.
1591 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
1592 return failure();
1593 }
1594 LLVM_DEBUG({
1595 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
1596 });
1597 } else {
1598 // If there are no affine map operands, attempt to refine the result type of the CallOp using
1599 // the function argument types and the type of the target function.
1600 auto callArgTypes = op.getArgOperands().getTypes();
1601 if (callArgTypes.empty()) {
1602 // no refinement possible if no function arguments
1603 return failure();
1604 }
1605 if (calleeReferencesTemplateParam(op)) {
1606 return failure();
1607 }
1608 SymbolTableCollection tables;
1609 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1610 if (failed(lookupRes)) {
1611 return failure();
1612 }
1613 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1614 return failure();
1615 }
1616 LLVM_DEBUG({
1617 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1618 "result type params: "
1619 << debug::toStringList(out.paramsOfStructTy) << '\n';
1620 });
1621 }
1622
1623 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
1624 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
1625 if (newRetTy == oldRetTy) {
1626 return failure(); // nothing changed
1627 }
1628 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
1629 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
1630 // the conversion here is illegal it means there is a type conflict within the LLZK code that
1631 // prevents instantiation of the struct with the requested type.
1632 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
1633 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1634 diag.append(
1635 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1636 ", but found ", oldRetTy
1637 );
1638 });
1639 }
1640 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
1642 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1643 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1644 );
1645 (void)newOp; // tell compiler it's intentionally unused in release builds
1646 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1647 return success();
1648 }
1649
1650private:
1653 inline LogicalResult instantiateViaTargetType(
1654 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1655 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1656 ) const {
1657 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1658 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1659 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1660 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1661
1662 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1663 // Nothing can change if everything is already concrete
1664 return failure();
1665 }
1666
1667 LLVM_DEBUG({
1668 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1669 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1670 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1671 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1672 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1673 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1674 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1675 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1676 });
1677
1678 UnificationMap unifications;
1679 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1680 (void)unifies; // tell compiler it's intentionally unused in builds without assertions
1681 assert(unifies && "should have been checked by verifiers");
1682
1683 LLVM_DEBUG({
1684 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1685 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1686 });
1687
1688 // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
1689 // from the call argument types) without any struct parameters (because the type with concrete
1690 // struct parameters will be used to instantiate the target struct rather than the fully
1691 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1692 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1693 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1694 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1695 [&unifications](std::tuple<Attribute, Attribute> p) {
1696 Attribute fromCall = std::get<1>(p);
1697 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1698 // non-parameterized concrete unification for the target struct parameter symbol.
1699 if (!isConcreteAttr<>(fromCall)) {
1700 Attribute fromTgt = std::get<0>(p);
1701 LLVM_DEBUG({
1702 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1703 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1704 });
1705 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1706 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1707 if (it != unifications.end()) {
1708 Attribute unifiedAttr = it->second;
1709 LLVM_DEBUG({
1710 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1711 });
1712 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1713 return unifiedAttr;
1714 }
1715 }
1716 }
1717 return fromCall;
1718 }
1719 );
1720
1721 out.paramsOfStructTy = newReturnStructParams;
1722 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1723 assert(out.mapOpGroups.empty() && "post-condition");
1724 assert(out.dimsPerGroup.empty() && "post-condition");
1725 return success();
1726 }
1727};
1728
1729LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1730 MLIRContext *ctx = modOp.getContext();
1731 RewritePatternSet patterns(ctx);
1732 patterns.add<
1733 InstantiateAtCreateArrayOp, // CreateArrayOp
1734 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1735 >(ctx, tracker);
1736
1737 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1738}
1739
1740} // namespace Step3_InstantiateAffineMaps
1741
1743
1745class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1746 ConversionTracker &tracker_;
1747
1748public:
1749 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1750 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1751
1752 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1753 Value createResult = op.getResult();
1754 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1755 assert(createResultType && "CreateArrayOp must produce ArrayType");
1756 Type oldResultElemType = createResultType.getElementType();
1757
1758 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1759 // element type is different.
1760 Type newResultElemType = nullptr;
1761 for (Operation *user : createResult.getUsers()) {
1762 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1763 if (writeOp.getArrRef() != createResult) {
1764 continue;
1765 }
1766 Type writeRValueType = writeOp.getRvalue().getType();
1767 if (writeRValueType == oldResultElemType) {
1768 continue;
1769 }
1770 if (newResultElemType && newResultElemType != writeRValueType) {
1771 LLVM_DEBUG(
1772 llvm::dbgs()
1773 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1774 << newResultElemType << " vs " << writeRValueType << '\n'
1775 );
1776 return failure();
1777 }
1778 newResultElemType = writeRValueType;
1779 }
1780 }
1781 if (!newResultElemType) {
1782 // no replacement type found
1783 return failure();
1784 }
1785 if (!tracker_.isLegalConversion(
1786 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1787 )) {
1788 return failure();
1789 }
1790 ArrayType newType = createResultType.cloneWith(newResultElemType);
1791 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1792 LLVM_DEBUG(
1793 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1794 );
1795 return success();
1796 }
1797};
1798
1799namespace {
1800
1801LogicalResult updateArrayElemFromArrAccessOp(
1802 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1803 PatternRewriter &rewriter
1804) {
1805 ArrayType oldArrType = op.getArrRefType();
1806 if (oldArrType.getElementType() == scalarElemTy) {
1807 return failure(); // no change needed
1808 }
1809 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1810 if (oldArrType == newArrType ||
1811 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1812 return failure();
1813 }
1814 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1815 LLVM_DEBUG(
1816 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1817 );
1818 return success();
1819}
1820
1821} // namespace
1822
1823class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1824 ConversionTracker &tracker_;
1825
1826public:
1827 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1828 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1829
1830 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1831 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1832 }
1833};
1834
1835class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1836 ConversionTracker &tracker_;
1837
1838public:
1839 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1840 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1841
1842 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1843 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1844 }
1845};
1846
/// Updates the type of a MemberDefOp to match the value type written to that member by the
/// MemberWriteOp(s) found among the member symbol's uses within the parent struct.
///
/// When writes store different types, the more concrete one (the target of a legal conversion
/// per ConversionTracker::isLegalConversion) is chosen. If neither type converts to the other,
/// the match fails with a diagnostic that notes both write locations.
class UpdateMemberDefTypeFromWrite final : public OpRewritePattern<MemberDefOp> {
  ConversionTracker &tracker_;

public:
  UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
      : OpRewritePattern(ctx, 3), tracker_(tracker) {}

  LogicalResult matchAndRewrite(MemberDefOp op, PatternRewriter &rewriter) const override {
    // Find all uses of the member symbol name within its parent struct.
    assert(parentRes && "MemberDefOp parent is always StructDefOp"); // per ODS def

    // If the symbol is used by a MemberWriteOp with a different result type then change
    // the type of the MemberDefOp to match the MemberWriteOp result type.
    Type newType = nullptr;
    if (auto memberUsers = llzk::getSymbolUses(op, parentRes)) {
      // Location of the write that supplied the current `newType`; used to annotate diagnostics.
      std::optional<Location> newTypeLoc = std::nullopt;
      for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
        if (MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
          Type writeToType = writeOp.getVal().getType();
          LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] checking " << writeOp << '\n');
          if (!newType) {
            // If a new type has not yet been discovered, store the new type.
            newType = writeToType;
            newTypeLoc = writeOp.getLoc();
          } else if (writeToType != newType) {
            // Typically, there will only be one write for each member of a struct but do not rely
            // on that assumption. If multiple writes with different types A and B are found where
            // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
            // type B with the assumption that the write with type A will be updated by another
            // pattern to also use type B.
            if (!tracker_.isLegalConversion(writeToType, newType, "UpdateMemberDefTypeFromWrite")) {
              if (tracker_.isLegalConversion(
                      newType, writeToType, "UpdateMemberDefTypeFromWrite"
                  )) {
                // 'writeToType' is the more concrete type
                newType = writeToType;
                newTypeLoc = writeOp.getLoc();
              } else {
                // Give an error if the types are incompatible.
                return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
                  diag.append(
                      "Cannot update type of '", MemberDefOp::getOperationName(),
                      "' because there are multiple '", MemberWriteOp::getOperationName(),
                      "' with different value types"
                  );
                  if (newTypeLoc) {
                    diag.attachNote(newTypeLoc).append("type written here is ", newType);
                  }
                  diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
                });
              }
            }
            // Otherwise 'newType' is already the more concrete type; keep it.
          }
        }
      }
    }
    if (!newType || newType == op.getType()) {
      return failure(); // nothing changed
    }
    // The definition may only move toward a more concrete type.
    if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateMemberDefTypeFromWrite")) {
      return failure();
    }
    rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
    LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] updated type of " << op << '\n');
    return success();
  }
};
1916
1917namespace {
1918
1919SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1920 SmallVector<std::unique_ptr<Region>> newRegions;
1921 for (Region &region : op->getRegions()) {
1922 auto newRegion = std::make_unique<Region>();
1923 newRegion->takeBody(region);
1924 newRegions.push_back(std::move(newRegion));
1925 }
1926 return newRegions;
1927}
1928
1929} // namespace
1930
1933class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1934 ConversionTracker &tracker_;
1935
1936public:
1937 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1938 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1939
1940 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1941 SmallVector<Type, 1> inferredResultTypes;
1942 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1943 LogicalResult result = retTypeFn.inferReturnTypes(
1944 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1945 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1946 );
1947 if (failed(result)) {
1948 return failure();
1949 }
1950 if (op->getResultTypes() == inferredResultTypes) {
1951 return failure(); // nothing changed
1952 }
1953 if (!tracker_.areLegalConversions(
1954 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1955 )) {
1956 return failure();
1957 }
1958
1959 // Move nested region bodies and replace the original op with the updated types list.
1960 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1961 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1962 Operation *newOp = rewriter.create(
1963 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1964 op->getAttrs(), op->getSuccessors(), newRegions
1965 );
1966 rewriter.replaceOp(op, newOp);
1967 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1968 return success();
1969 }
1970};
1971
1973class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1974 ConversionTracker &tracker_;
1975
1976public:
1977 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1978 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1979
1980 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1981 Region &body = op.getFunctionBody();
1982 if (body.empty()) {
1983 return failure();
1984 }
1985 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1986 assert(retOp && "final op in body region must be return");
1987 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1988
1989 FunctionType oldFuncTy = op.getFunctionType();
1990 if (oldFuncTy.getResults() == tyFromReturnOp) {
1991 return failure(); // nothing changed
1992 }
1993 if (!tracker_.areLegalConversions(
1994 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1995 )) {
1996 return failure();
1997 }
1998
1999 rewriter.modifyOpInPlace(op, [&]() {
2000 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
2001 });
2002 LLVM_DEBUG(
2003 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
2004 << oldFuncTy << " to " << op.getFunctionType() << '\n'
2005 );
2006 return success();
2007 }
2008};
2009
2014class UpdateFreeFuncCallOpTypes final : public OpRewritePattern<CallOp> {
2015 ConversionTracker &tracker_;
2016
2017public:
2018 UpdateFreeFuncCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
2019 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2020
2021 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
2022 if (calleeReferencesTemplateParam(op)) {
2023 return failure();
2024 }
2025 SymbolTableCollection tables;
2026 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
2027 if (failed(lookupRes)) {
2028 return failure();
2029 }
2030 FuncDefOp targetFunc = lookupRes->get();
2031 if (targetFunc.isInStruct()) {
2032 // this pattern only applies when the callee is NOT in a struct
2033 return failure();
2034 }
2035 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
2036 return failure(); // nothing changed
2037 }
2038 if (!tracker_.areLegalConversions(
2039 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
2040 "UpdateFreeFuncCallOpTypes"
2041 )) {
2042 return failure();
2043 }
2044
2045 LLVM_DEBUG(llvm::dbgs() << "[UpdateFreeFuncCallOpTypes] replaced " << op);
2046 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
2047 (void)newOp; // tell compiler it's intentionally unused in release builds
2048 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
2049 return success();
2050 }
2051};
2052
2053namespace {
2054
2055LogicalResult updateMemberRefValFromMemberDef(
2056 MemberRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
2057) {
2058 SymbolTableCollection tables;
2059 auto def = op.getMemberDefOp(tables);
2060 if (failed(def)) {
2061 return failure();
2062 }
2063 Type oldResultType = op.getVal().getType();
2064 Type newResultType = def->get().getType();
2065 if (oldResultType == newResultType ||
2066 !tracker.isLegalConversion(oldResultType, newResultType, "updateMemberRefValFromMemberDef")) {
2067 return failure();
2068 }
2069 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
2070 LLVM_DEBUG(
2071 llvm::dbgs() << "[updateMemberRefValFromMemberDef] updated value type in " << op << '\n'
2072 );
2073 return success();
2074}
2075
2076} // namespace
2077
2079class UpdateMemberReadValFromDef final : public OpRewritePattern<MemberReadOp> {
2080 ConversionTracker &tracker_;
2081
2082public:
2083 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2084 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2085
2086 LogicalResult matchAndRewrite(MemberReadOp op, PatternRewriter &rewriter) const override {
2087 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2088 }
2089};
2090
2092class UpdateMemberWriteValFromDef final : public OpRewritePattern<MemberWriteOp> {
2093 ConversionTracker &tracker_;
2094
2095public:
2096 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2097 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2098
2099 LogicalResult matchAndRewrite(MemberWriteOp op, PatternRewriter &rewriter) const override {
2100 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2101 }
2102};
2103
2104LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
2105 MLIRContext *ctx = modOp.getContext();
2106 RewritePatternSet patterns(ctx);
2107 patterns.add<
2108 // Benefit of this one must be higher than rules that would propagate the type in the opposite
2109 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
2110 // benefit = 6
2111 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
2112 // benefit = 3
2113 UpdateFreeFuncCallOpTypes, // CallOp, targeting non-struct functions
2114 UpdateFuncTypeFromReturn, // FuncDefOp
2115 UpdateNewArrayElemFromWrite, // CreateArrayOp
2116 UpdateArrayElemFromArrRead, // ReadArrayOp
2117 UpdateArrayElemFromArrWrite, // WriteArrayOp
2118 UpdateMemberDefTypeFromWrite, // MemberDefOp
2119 UpdateMemberReadValFromDef, // MemberReadOp
2120 UpdateMemberWriteValFromDef // MemberWriteOp
2121 >(ctx, tracker);
2122
2123 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
2124}
2125} // namespace Step4_PropagateTypes
2126
2127namespace Step5_Cleanup {
2128
2129class CleanupBase {
2130public:
2131 SymbolTableCollection tables;
2132
2133 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
2134 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
2135
2136protected:
2137 ModuleOp rootMod;
2138 const SymbolDefTree &defTree;
2139 const SymbolUseGraph &useGraph;
2140};
2141
struct FromKeepSet : public CleanupBase {
  using CleanupBase::CleanupBase;

  /// Returns true if `op` is a StructDefOp with template symbol bindings, or a FuncDefOp nested
  /// in a TemplateOp for which `hasConstOps<TemplateSymbolBindingOpInterface>()` holds.
  static bool hasTemplateSymbolBindings(Operation *op) {
    if (StructDefOp sdef = llvm::dyn_cast<StructDefOp>(op)) {
      return sdef.hasTemplateSymbolBindings();
    }
    if (llvm::isa<function::FuncDefOp>(op)) {
      if (TemplateOp parent = getParentOfType<TemplateOp>(op)) {
        return parent.hasConstOps<TemplateSymbolBindingOpInterface>();
      }
    }
    return false;
  }

  /// Returns true if `op` is a definition the cleanup is allowed to erase: any StructDefOp, or a
  /// FuncDefOp that is not defined within a struct.
  static bool isErasableDefinition(Operation *op) {
    if (llvm::isa<StructDefOp>(op)) {
      return true;
    }
    if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
      return !fdef.isInStruct();
    }
    return false;
  }

  /// Erases every erasable definition (see isErasableDefinition()) in `rootMod` that is not
  /// reachable from the symbols in `keep` or from any GlobalDefOp in `rootMod`.
  ///
  /// Reachability follows both nesting (the SymbolDefTree) and symbol references (the
  /// SymbolUseGraph). Each erasable definition found to be referenced becomes an additional
  /// root, so everything it references is preserved too. Returns failure() if a referenced
  /// symbol cannot be resolved; in that case nothing has been erased yet.
  LogicalResult eraseUnreachableFrom(ArrayRef<SymbolOpInterface> keep) {
    // Initialize roots from the given symbol definitions.
    SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
    // Add GlobalDefOp to the set of roots.
    rootMod.walk([&roots](global::GlobalDefOp gdef) { roots.insert(gdef); });

    // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
    // collect all Symbol uses reachable from those def nodes. These are the symbols that should
    // be preserved. All other symbol defs should be removed.
    DenseSet<Operation *> defsToKeep;
    // Shared visited-set for all 'depth_first_ext()' traversals below. A use node that has
    // already been reached is not traversed again.
    llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
    for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
      SymbolOpInterface keepRoot = roots[i];
      LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
      const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
      assert(keepRootNode && "every symbol def must be in the def tree");
      for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
        LLVM_DEBUG({
          llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
        });
        if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
          if (isErasableDefinition(reachableDef.getOperation())) {
            defsToKeep.insert(reachableDef.getOperation());
          }
          // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
          // node. There are no uses if the node is not in the graph. Within the loop that populates
          // 'depth_first_ext()', also check if the symbol is an erasable definition and ensure it
          // is in 'roots' so the outer loop preserves all symbols reachable from it.
          if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
            for (const SymbolUseGraphNode *usedSymbolNode :
                 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
              LLVM_DEBUG({
                llvm::dbgs() << "[EraseUnreachable] uses symbol: "
                             << usedSymbolNode->getSymbolPath() << '\n';
              });
              // Ignore struct/template parameter symbols (before doing the lookup below because it
              // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
              if (usedSymbolNode->isTemplateSymbolBinding()) {
                continue;
              }
              // If `usedSymbolNode` references an erasable definition, ensure it's considered in
              // the roots so symbols reachable from its body are preserved too.
              auto lookupRes = usedSymbolNode->lookupSymbol(tables);
              if (failed(lookupRes)) {
                LLVM_DEBUG(useGraph.dumpToDotFile());
                return failure();
              }
              // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
              if (lookupRes->viaInclude()) {
                continue;
              }
              Operation *usedOp = lookupRes->get();
              if (isErasableDefinition(usedOp)) {
                SymbolOpInterface asSymbol = llvm::cast<SymbolOpInterface>(usedOp);
                bool insertRes = roots.insert(asSymbol);
                (void)insertRes; // tell compiler it's intentionally unused in release builds
                LLVM_DEBUG({
                  if (insertRes) {
                    llvm::dbgs() << "[EraseUnreachable] found another root: " << asSymbol << '\n';
                  }
                });
              }
            }
          }
        }
      }
    }

    // Gather the erase targets first and erase only after the walk finishes, so the walk never
    // visits an already-erased operation. A definition survives if it was reached via the def
    // tree or if its use node was reached via the use graph.
    SmallVector<SymbolOpInterface> toErase;
    rootMod.walk([this, &defsToKeep, &symbolsToKeep, &toErase](Operation *op) {
      if (!isErasableDefinition(op) || defsToKeep.contains(op)) {
        return;
      }
      SymbolOpInterface symOp = llvm::cast<SymbolOpInterface>(op);
      const SymbolUseGraphNode *n = this->useGraph.lookupNode(symOp);
      if (!n || !symbolsToKeep.contains(n)) {
        LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << symOp.getNameAttr() << '\n');
        toErase.push_back(symOp);
      }
    });
    for (SymbolOpInterface symOp : toErase) {
      symOp.erase();
    }

    return success();
  }
};
2260
2261struct FromEraseSet : public CleanupBase {
2262
2264 FromEraseSet(
2265 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
2266 DenseSet<SymbolRefAttr> &&tryToErasePaths
2267 )
2268 : CleanupBase(root, symDefTree, symUseGraph) {
2269 // Convert the set of paths targeted for erasure into a set of cleanup-candidate definitions.
2270 for (SymbolRefAttr path : tryToErasePaths) {
2271 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] path to erase: " << path << '\n';);
2272 Operation *lookupFrom = rootMod.getOperation();
2273 auto res = lookupSymbolIn(tables, path, Within(), lookupFrom);
2274 assert(succeeded(res) && "inputs must be valid symbol references");
2275 assert(FromKeepSet::isErasableDefinition(res->get()) && "inputs must be cleanup candidates");
2276 if (!res->viaInclude()) { // do not remove if it's from another source file
2277 SymbolOpInterface op = llvm::cast<SymbolOpInterface>(res->get());
2278 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] added op to the erase set: " << op << '\n';);
2279 tryToErase.insert(op);
2280 } else {
2281 LLVM_DEBUG(
2282 llvm::dbgs() << "[FromEraseSet] ignored op because it comes from an include: "
2283 << res->get() << '\n';
2284 );
2285 }
2286 }
2287 }
2288
2289 LogicalResult eraseUnusedDefinitions() {
2290 // Collect the subset of 'tryToErase' that has no remaining uses.
2291 for (SymbolOpInterface sym : tryToErase) {
2292 collectSafeToErase(sym);
2293 }
2294 // The `visitedPlusSafetyResult` may contain child FuncDefOp within an erased StructDefOp, so
2295 // reduce the map to only top-level erase targets before erasing in a separate loop.
2296 for (auto &it : llvm::make_early_inc_range(visitedPlusSafetyResult)) {
2297 if (!it.second || !tryToErase.contains(it.first)) {
2298 visitedPlusSafetyResult.erase(it.first);
2299 }
2300 }
2301 for (auto &[sym, _] : visitedPlusSafetyResult) {
2302 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
2303 sym.erase();
2304 }
2305 return success();
2306 }
2307
2308 const DenseSet<SymbolOpInterface> &getTryToEraseSet() const { return tryToErase; }
2309
2310private:
2312 DenseSet<SymbolOpInterface> tryToErase;
2316 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
2318 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
2319
2322 bool collectSafeToErase(SymbolOpInterface check) {
2323 assert(check); // pre-condition
2324
2325 // If previously visited, return the safety result.
2326 auto visited = visitedPlusSafetyResult.find(check);
2327 if (visited != visitedPlusSafetyResult.end()) {
2328 return visited->second;
2329 }
2330
2331 // If it's an erasable definition that is not in `tryToErase` then it cannot be erased.
2332 if (FromKeepSet::isErasableDefinition(check.getOperation()) && !tryToErase.contains(check)) {
2333 visitedPlusSafetyResult[check] = false;
2334 return false;
2335 }
2336
2337 // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
2338 // the recursion from getting stuck in an infinite loop).
2339 visitedPlusSafetyResult[check] = true;
2340
2341 // Check if it's safe according to both the def tree and use graph.
2342 // Note: Every symbol must have a def node but ModuleOp and TemplateOp symbols may not have a
2343 // use node since they are not "terminal" symbols (i.e. they are not referred to directly).
2344 if (collectSafeToErase(defTree.lookupNode(check))) {
2345 const auto *useNode = useGraph.lookupNode(check);
2346 assert(useNode || (llvm::isa<ModuleOp, TemplateOp>(check.getOperation())));
2347 if (!useNode || collectSafeToErase(useNode)) {
2348 return true;
2349 }
2350 }
2351
2352 // Otherwise, revert the safety decision and return it.
2353 visitedPlusSafetyResult[check] = false;
2354 return false;
2355 }
2356
2358 bool collectSafeToErase(const SymbolDefTreeNode *check) {
2359 assert(check); // pre-condition
2360 if (const SymbolDefTreeNode *p = check->getParent()) {
2361 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
2362 return collectSafeToErase(checkOp);
2363 }
2364 }
2365 return true;
2366 }
2367
2369 bool collectSafeToErase(const SymbolUseGraphNode *check) {
2370 assert(check); // pre-condition
2371 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
2372 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
2373 if (!collectSafeToErase(checkOp)) {
2374 return false;
2375 }
2376 }
2377 }
2378 return true;
2379 }
2380
2385 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
2386 assert(node && "must provide a node"); // pre-condition
2387 // Check for cached result
2388 auto fromCache = lookupCache.find(node);
2389 if (fromCache != lookupCache.end()) {
2390 return fromCache->second;
2391 }
2392 // Otherwise, perform lookup and cache
2393 auto lookupRes = node->lookupSymbol(tables);
2394 assert(succeeded(lookupRes) && "graph contains node with invalid path");
2395 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
2396 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
2397 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
2398 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
2399 // therefore will not be traversed by this analysis so directly casting is fine.
2400 SymbolOpInterface actualRes =
2401 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
2402 // Cache and return
2403 lookupCache[node] = actualRes;
2404 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
2405 return actualRes;
2406 }
2407};
2408
2409} // namespace Step5_Cleanup
2410
2411class PassImpl : public llzk::polymorphic::impl::FlatteningPassBase<PassImpl> {
2412 using Base = FlatteningPassBase<PassImpl>;
2413 using Base::Base;
2414
2415 void runOnOperation() override {
2416 ModuleOp modOp = getOperation();
2417 if (failed(runOn(modOp))) {
2418 LLVM_DEBUG({
2419 // If the pass failed, dump the current IR.
2420 llvm::dbgs() << "=====================================================================\n";
2421 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
2422 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2423 llvm::dbgs() << "=====================================================================\n";
2424 });
2425 signalPassFailure();
2426 }
2427 }
2428
2429 inline LogicalResult runOn(ModuleOp modOp) {
2430 // If the cleanup mode is set to remove anything not reachable from the main struct, do an
2431 // initial pass to remove things that are not reachable (as an optimization) because creating
2432 // an instantiated version of a struct will not cause something to become reachable that was
2433 // not already reachable in parameterized form.
2434 if (cleanupMode == FlatteningCleanupMode::MainAsRoot) {
2435 if (failed(eraseUnreachableFromMainStruct(modOp))) {
2436 return failure();
2437 }
2438 }
2439
2440 // Pass Manager to run some standard cleanup passes that are always beneficial:
2441 // - Remove templates that contain no struct or function definitions
2442 // - Convert templates with no constant parameters or expressions into modules
2443 OpPassManager universalCleanup(ModuleOp::getOperationName());
2444 universalCleanup.addPass(createEmptyTemplateRemovalPass());
2445
2446 // Run universal cleanup as a preliminary step to satisfy the
2447 // `assert(!isNullOrEmpty(paramNames))` precondition in `genClone()`.
2448 if (failed(runPipeline(universalCleanup, modOp))) {
2449 return failure();
2450 }
2451
2452 ConversionTracker tracker;
2453 unsigned loopCount = 0;
2454 do {
2455 ++loopCount;
2456 if (loopCount > iterationLimit) {
2457 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
2458 << " iterations!\n";
2459 return failure();
2460 }
2461 tracker.resetModifiedFlag();
2462
2463 LLVM_DEBUG({
2464 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2465 << ")] Running step 1: struct instantiation\n";
2466 });
2467 // Find calls to "compute()" that return a parameterized struct type and replace it to call an
2468 // instantiated version of the struct that has parameters replaced with the constant values.
2469 // Create the necessary instantiated/flattened struct in the same location as the original.
2470 if (failed(Step1A_InstantiateStructs::run(modOp, tracker))) {
2471 llvm::errs() << DEBUG_TYPE << " failed while instantiating structs in templates\n";
2472 return failure();
2473 }
2474 // Instantiate calls to templated functions.
2475 if (failed(Step1B_InstantiateFunctions::run(modOp, tracker))) {
2476 llvm::errs() << DEBUG_TYPE << " failed while instantiating functions in templates\n";
2477 return failure();
2478 }
2479
2480 LLVM_DEBUG({
2481 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2482 << ")] Running step 2: loop unrolling\n";
2483 });
2484 // Unroll loops with known iterations.
2485 if (failed(Step2_Unroll::run(modOp, tracker))) {
2486 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
2487 return failure();
2488 }
2489
2490 LLVM_DEBUG({
2491 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2492 << ")] Running step 3: affine maps instantiation\n";
2493 });
2494 // Instantiate affine_map parameters of StructType and ArrayType.
2495 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
2496 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
2497 return failure();
2498 }
2499
2500 LLVM_DEBUG({
2501 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2502 << ")] Running step 4: type propagation\n";
2503 });
2504 // Propagate updated types using the semantics of various ops.
2505 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
2506 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
2507 return failure();
2508 }
2509
2510 LLVM_DEBUG(if (tracker.isModified()) {
2511 llvm::dbgs() << "=====================================================================\n";
2512 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
2513 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2514 llvm::dbgs() << "=====================================================================\n";
2515 });
2516 } while (tracker.isModified());
2517
2518 // Run user-selected cleanup first.
2519 if (failed(cleanupSwitch(modOp, tracker))) {
2520 return failure();
2521 }
2522 // Run universal cleanup again since no-param or param-only structs may exist now.
2523 if (failed(runPipeline(universalCleanup, modOp))) {
2524 return failure();
2525 }
2526
2527 OpPassManager allocationCleanup(ModuleOp::getOperationName());
2528 allocationCleanup.addPass(createRemoveUnusedDiscardableAllocationsPass(
2529 RemoveUnusedDiscardableAllocationsPassOptions {
2530 .allocatorOpName = CreateArrayOp::getOperationName().str()
2531 }
2532 ));
2533 return runPipeline(allocationCleanup, modOp);
2534 }
2535
2536 // Perform cleanup according to the 'cleanupMode' option.
2537 LogicalResult cleanupSwitch(ModuleOp modOp, const ConversionTracker &tracker) {
2538 LLVM_DEBUG({ llvm::dbgs() << "[FlatteningPass] Running step 5: cleanup "; });
2539 switch (cleanupMode) {
2540 case FlatteningCleanupMode::MainAsRoot:
2541 LLVM_DEBUG(llvm::dbgs() << "(main as root mode)\n");
2542 return eraseUnreachableFromMainStruct(modOp, false);
2543 case FlatteningCleanupMode::ConcreteAsRoot:
2544 LLVM_DEBUG(llvm::dbgs() << "(concrete definitions mode)\n");
2545 return eraseUnreachableFromConcreteDefinitions(modOp);
2546 case FlatteningCleanupMode::Preimage:
2547 LLVM_DEBUG(llvm::dbgs() << "(preimage mode)\n");
2548 return erasePreimageOfInstantiations(modOp, tracker);
2549 default:
2550 LLVM_DEBUG(llvm::dbgs() << "(disabled)\n");
2551 return success();
2552 }
2553 }
2554
2555 // Erase parameterized definitions that were replaced with concrete instantiations.
2556 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
2557 // TODO: The names from getInstantiatedDefinitionNames() are NOT guaranteed to be paths from the
2558 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
2559 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
2560 // paths from the "top root". See [LLZK-286].
2561 Step5_Cleanup::FromEraseSet cleaner(
2562 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
2563 tracker.getInstantiatedDefinitionNames()
2564 );
2565 LogicalResult res = cleaner.eraseUnusedDefinitions();
2566 if (succeeded(res)) {
2567 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] success\n";);
2568 // Warn about any definitions that were instantiated but still have uses elsewhere.
2569 const SymbolUseGraph *useGraph = nullptr;
2570 rootMod->walk([this, &cleaner, &useGraph](Operation *walkedOp) {
2571 SymbolOpInterface op = llvm::dyn_cast<SymbolOpInterface>(walkedOp);
2572 if (!op || !cleaner.getTryToEraseSet().contains(op)) {
2573 return;
2574 }
2575 // If needed, rebuild use graph to reflect deletions.
2576 if (!useGraph) {
2577 useGraph = &getAnalysis<SymbolUseGraph>();
2578 }
2579 // If the op has any users, report the warning.
2580 if (useGraph->lookupNode(op)->hasPredecessor()) {
2581 op.emitWarning("Parameterized definition still has uses!").report();
2582 }
2583 });
2584 } else {
2585 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] failed\n";);
2586 }
2587 return res;
2588 }
2589
2590 LogicalResult eraseUnreachableFromConcreteDefinitions(ModuleOp rootMod) {
2591 SmallVector<SymbolOpInterface> roots;
2592 rootMod.walk([&roots](Operation *op) {
2593 if (Step5_Cleanup::FromKeepSet::isErasableDefinition(op) &&
2594 !Step5_Cleanup::FromKeepSet::hasTemplateSymbolBindings(op)) {
2595 roots.push_back(llvm::cast<SymbolOpInterface>(op));
2596 }
2597 });
2598
2599 Step5_Cleanup::FromKeepSet cleaner(
2600 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2601 );
2602 return cleaner.eraseUnreachableFrom(roots);
2603 }
2604
2605 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
2606 Step5_Cleanup::FromKeepSet cleaner(
2607 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2608 );
2609 FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
2610 getMainInstanceDef(cleaner.tables, rootMod.getOperation());
2611 if (failed(mainOpt)) {
2612 return failure();
2613 }
2614 SymbolLookupResult<StructDefOp> main = mainOpt.value();
2615 if (emitWarning && !main) {
2616 // Emit warning if there is no main specified because all cleanup-candidate definitions not
2617 // reachable from global defs may be removed.
2618 rootMod.emitWarning()
2619 .append(
2620 "using option '", cleanupMode.getArgStr(), '=',
2621 stringifyFlatteningCleanupMode(FlatteningCleanupMode::MainAsRoot), "' with no \"",
2623 "\" attribute on the top-level module may remove all cleanup-candidate definitions!"
2624 )
2625 .report();
2626 }
2627 SmallVector<SymbolOpInterface> roots;
2628 if (main) {
2629 roots.push_back(*main);
2630 }
2631 return cleaner.eraseUnreachableFrom(roots);
2632 }
2633};
2634
2635} // namespace
#define DEBUG_TYPE
#define check(x)
Definition Ops.cpp:285
Common private implementation for poly dialect passes.
This file defines methods symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:55
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
Builds a graph structure representing the relationships between symbols and their uses.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:377
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:923
::mlir::TypedValue<::mlir::Type > getRvalue()
Definition Ops.h.inc:1078
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:556
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:979
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.h.inc:750
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:691
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:938
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1165
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1605
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::ArrayAttr getParams() const
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1145
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:1168
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:1067
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1139
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
Definition Ops.cpp:1107
void setTemplateParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:316
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::ArrayAttr getTemplateParamsAttr()
Definition Ops.h.inc:297
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:312
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unfication between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
Definition Ops.cpp:1111
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
Definition Ops.cpp:669
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1161
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:302
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:850
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:481
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:979
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:879
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:876
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1003
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:999
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:703
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:991
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:464
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:673
::mlir::Region & getInitializerRegion()
Definition Ops.h.inc:660
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:838
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:872
::llvm::SmallVector<::mlir::Attribute > getConstNames()
Return the names of all ops of type OpT within the body region in the order they are defined.
Definition Ops.h.inc:941
bool hasConstNamed(::mlir::StringRef find)
Return true if there is an op of type OpT with the given name within the body region.
Definition Ops.h.inc:949
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:885
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1064
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1059
inline ::llvm::iterator_range<::mlir::Region::op_iterator< OpT > > getConstOps()
Return ops of type OpT within the body region.
Definition Ops.h.inc:921
::mlir::FlatSymbolRefAttr getNameRef() const
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:156
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:97
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:81
std::unique_ptr<::mlir::Pass > createEmptyTemplateRemovalPass()
::llvm::StringRef stringifyFlatteningCleanupMode(FlatteningCleanupMode val)
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
std::unique_ptr<::mlir::Pass > createRemoveUnusedDiscardableAllocationsPass()
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
bool isConcreteType(Type type, bool allowStructParams)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:270
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:274
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool isDynamic(IntegerAttr intAttr)
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
int64_t fromAPInt(const llvm::APInt &i)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
constexpr char MAIN_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
Definition Constants.h:37