LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
FlatteningPass.cpp
Go to the documentation of this file.
1//===-- LLZKFlatteningPass.cpp - Implements -llzk-flatten pass --*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
28#include "llzk/Util/Concepts.h"
29#include "llzk/Util/Debug.h"
34
35#include <mlir/Dialect/Affine/IR/AffineOps.h>
36#include <mlir/Dialect/Affine/LoopUtils.h>
37#include <mlir/Dialect/Arith/IR/Arith.h>
38#include <mlir/Dialect/SCF/IR/SCF.h>
39#include <mlir/Dialect/SCF/Utils/Utils.h>
40#include <mlir/Dialect/Utils/StaticValueUtils.h>
41#include <mlir/IR/Attributes.h>
42#include <mlir/IR/BuiltinAttributes.h>
43#include <mlir/IR/BuiltinOps.h>
44#include <mlir/IR/BuiltinTypes.h>
45#include <mlir/Interfaces/InferTypeOpInterface.h>
46#include <mlir/Pass/PassManager.h>
47#include <mlir/Support/LLVM.h>
48#include <mlir/Support/LogicalResult.h>
49#include <mlir/Transforms/DialectConversion.h>
50#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
51#include <mlir/Transforms/WalkPatternRewriteDriver.h>
52
53#include <llvm/ADT/APInt.h>
54#include <llvm/ADT/DenseMap.h>
55#include <llvm/ADT/DepthFirstIterator.h>
56#include <llvm/ADT/STLExtras.h>
57#include <llvm/ADT/SmallVector.h>
58#include <llvm/ADT/TypeSwitch.h>
59#include <llvm/Support/Debug.h>
60
61// Include the generated base pass class definitions.
62namespace llzk::polymorphic {
63#define GEN_PASS_DECL_FLATTENINGPASS
64#define GEN_PASS_DEF_FLATTENINGPASS
66} // namespace llzk::polymorphic
67
68#include "SharedImpl.h"
69
70#define DEBUG_TYPE "llzk-flatten"
71
72using namespace mlir;
73using namespace llzk;
74using namespace llzk::array;
75using namespace llzk::component;
76using namespace llzk::constrain;
77using namespace llzk::felt;
78using namespace llzk::function;
79using namespace llzk::polymorphic;
80using namespace llzk::polymorphic::detail;
81
82namespace {
83
84static void reportDelayedDiagnostics(CallOp caller, SmallVector<Diagnostic> &&diagnostics) {
85 DiagnosticEngine &engine = caller.getContext()->getDiagEngine();
86 for (Diagnostic &diag : diagnostics) {
87 // Update any notes referencing an UnknownLoc to use the CallOp location.
88 for (Diagnostic &note : diag.getNotes()) {
89 assert(note.getNotes().empty() && "notes cannot have notes attached");
90 if (llvm::isa<UnknownLoc>(note.getLocation())) {
91 note = std::move(Diagnostic(caller.getLoc(), note.getSeverity()).append(note.str()));
92 }
93 }
94 // Report. Based on InFlightDiagnostic::report().
95 engine.emit(std::move(diag));
96 }
97}
98
99class ConversionTracker {
101 bool modified;
104 DenseMap<StructType, StructType> structInstantiations;
106 DenseMap<StructType, StructType> reverseInstantiations;
109 DenseMap<StructType, SmallVector<Diagnostic>> delayedDiagnostics;
110
111public:
112 bool isModified() const { return modified; }
113 void resetModifiedFlag() { modified = false; }
114 void updateModifiedFlag(bool currStepModified) { modified |= currStepModified; }
115
116 void recordInstantiation(StructType oldType, StructType newType) {
117 assert(!isNullOrEmpty(oldType.getParams()) && "cannot instantiate with no params");
118
119 auto forwardResult = structInstantiations.try_emplace(oldType, newType);
120 if (forwardResult.second) {
121 // Insertion was successful
122 // ASSERT: The reverse map does not contain this mapping either
123 assert(!reverseInstantiations.contains(newType));
124 reverseInstantiations[newType] = oldType;
125 // Set the modified flag
126 modified = true;
127 } else {
128 // ASSERT: If a mapping already existed for `oldType` it must be `newType`
129 assert(forwardResult.first->getSecond() == newType);
130 // ASSERT: The reverse mapping is already present as well
131 assert(reverseInstantiations.lookup(newType) == oldType);
132 }
133 assert(structInstantiations.size() == reverseInstantiations.size());
134 }
137 std::optional<StructType> getInstantiation(StructType oldType) const {
138 auto cachedResult = structInstantiations.find(oldType);
139 if (cachedResult != structInstantiations.end()) {
140 return cachedResult->second;
141 }
142 return std::nullopt;
143 }
144
146 DenseSet<SymbolRefAttr> getInstantiatedStructNames() const {
147 DenseSet<SymbolRefAttr> instantiatedNames;
148 for (const auto &[origRemoteTy, _] : structInstantiations) {
149 instantiatedNames.insert(origRemoteTy.getNameRef());
150 }
151 return instantiatedNames;
152 }
153
154 void reportDelayedDiagnostics(StructType newType, CallOp caller) {
155 auto res = delayedDiagnostics.find(newType);
156 if (res != delayedDiagnostics.end()) {
157 ::reportDelayedDiagnostics(caller, std::move(res->second));
158
159 // Emitting a Diagnostic consumes it (per DiagnosticEngine::emit) so remove them from the map.
160 // Unfortunately, this means if the key StructType is the result of instantiation at multiple
161 // `compute()` calls it will only be reported at one of those locations, not all.
162 delayedDiagnostics.erase(newType);
163 }
164 }
165
166 SmallVector<Diagnostic> &delayedDiagnosticSet(StructType newType) {
167 return delayedDiagnostics[newType];
168 }
169
172 bool isLegalConversion(Type oldType, Type newType, const char *patName) const {
173 std::function<bool(Type, Type)> checkInstantiations = [&](Type oTy, Type nTy) {
174 // Check if `oTy` is a struct with a known instantiation to `nTy`
175 if (StructType oldStructType = llvm::dyn_cast<StructType>(oTy)) {
176 // Note: The values in `structInstantiations` must be no-parameter struct types
177 // so there is no need for recursive check, simple equality is sufficient.
178 if (this->structInstantiations.lookup(oldStructType) == nTy) {
179 return true;
180 }
181 }
182 // Check if `nTy` is the result of a struct instantiation and if the pre-image of
183 // that instantiation (i.e., the parameterized version of the instantiated struct)
184 // is a more concrete unification of `oTy`.
185 if (StructType newStructType = llvm::dyn_cast<StructType>(nTy)) {
186 if (auto preImage = this->reverseInstantiations.lookup(newStructType)) {
187 if (isMoreConcreteUnification(oTy, preImage, checkInstantiations)) {
188 return true;
189 }
190 }
191 }
192 return false;
193 };
194
195 if (isMoreConcreteUnification(oldType, newType, checkInstantiations)) {
196 return true;
197 }
198 LLVM_DEBUG(
199 llvm::dbgs() << "[" << patName << "] Cannot replace old type " << oldType
200 << " with new type " << newType
201 << " because it does not define a compatible and more concrete type.\n";
202 );
203 return false;
204 }
205
206 template <typename T, typename U>
207 inline bool areLegalConversions(T oldTypes, U newTypes, const char *patName) const {
208 return llvm::all_of(
209 llvm::zip_equal(oldTypes, newTypes), [this, &patName](std::tuple<Type, Type> oldThenNew) {
210 return this->isLegalConversion(std::get<0>(oldThenNew), std::get<1>(oldThenNew), patName);
211 }
212 );
213 }
214};
215
216template <typename Impl, typename Op, typename... HandledAttrs>
217class SymbolUserHelper : public OpConversionPattern<Op> {
218private:
219 const DenseMap<Attribute, Attribute> &paramNameToValue;
220
221 SymbolUserHelper(
222 TypeConverter &converter, MLIRContext *ctx, unsigned patternBenefit,
223 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
224 )
225 : OpConversionPattern<Op>(converter, ctx, patternBenefit),
226 paramNameToValue(paramNameToInstantiatedValue) {}
227
228public:
229 using OpAdaptor = typename mlir::OpConversionPattern<Op>::OpAdaptor;
230
231 virtual Attribute getNameAttr(Op) const = 0;
232
233 virtual LogicalResult handleDefaultRewrite(
234 Attribute, Op op, OpAdaptor, ConversionPatternRewriter &, Attribute a
235 ) const {
236 return op->emitOpError().append("expected value with type ", op.getType(), " but found ", a);
237 }
238
239 LogicalResult
240 matchAndRewrite(Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override {
241 LLVM_DEBUG(llvm::dbgs() << "[SymbolUserHelper] op: " << op << '\n');
242 auto res = this->paramNameToValue.find(getNameAttr(op));
243 if (res == this->paramNameToValue.end()) {
244 LLVM_DEBUG(llvm::dbgs() << "[SymbolUserHelper] no instantiation for " << op << '\n');
245 return failure();
246 }
247 llvm::TypeSwitch<Attribute, LogicalResult> TS(res->second);
248 llvm::TypeSwitch<Attribute, LogicalResult> *ptr = &TS;
249
250 ((ptr = &(ptr->template Case<HandledAttrs>([&](HandledAttrs a) {
251 return static_cast<const Impl *>(this)->handleRewrite(res->first, op, adaptor, rewriter, a);
252 }))),
253 ...);
254
255 return TS.Default([&](Attribute a) {
256 return handleDefaultRewrite(res->first, op, adaptor, rewriter, a);
257 });
258 }
259 friend Impl;
260};
261
262class ClonedBodyConstReadOpPattern
263 : public SymbolUserHelper<
264 ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr> {
265 SmallVector<Diagnostic> &diagnostics;
266
267 using super =
268 SymbolUserHelper<ClonedBodyConstReadOpPattern, ConstReadOp, IntegerAttr, FeltConstAttr>;
269
270public:
271 ClonedBodyConstReadOpPattern(
272 TypeConverter &converter, MLIRContext *ctx,
273 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue,
274 SmallVector<Diagnostic> &instantiationDiagnostics
275 )
276 // benefit>0 so this applies instead of GeneralTypeReplacePattern<ConstReadOp>
277 : super(converter, ctx, /*patternBenefit=*/1, paramNameToInstantiatedValue),
278 diagnostics(instantiationDiagnostics) {}
279
280 Attribute getNameAttr(ConstReadOp op) const override { return op.getConstNameAttr(); }
281
282 LogicalResult handleRewrite(
283 Attribute sym, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, IntegerAttr a
284 ) const {
285 APInt attrValue = a.getValue();
286 Type origResTy = op.getType();
287 if (FeltType ty = llvm::dyn_cast<FeltType>(origResTy)) {
289 rewriter, op, FeltConstAttr::get(getContext(), attrValue, ty)
290 );
291 return success();
292 }
293
294 if (llvm::isa<IndexType>(origResTy)) {
296 return success();
297 }
298
299 if (origResTy.isSignlessInteger(1)) {
300 // Treat 0 as false and any other value as true (but give a warning if it's not 1)
301 if (attrValue.isZero()) {
302 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, false, origResTy);
303 return success();
304 }
305 if (!attrValue.isOne()) {
306 Location opLoc = op.getLoc();
307 Diagnostic diag(opLoc, DiagnosticSeverity::Warning);
308 diag << "Interpreting non-zero value " << stringWithoutType(a) << " as true";
309 if (getContext()->shouldPrintOpOnDiagnostic()) {
310 diag.attachNote(opLoc) << "see current operation: " << *op;
311 }
312 diag.attachNote(UnknownLoc::get(getContext()))
313 << "when instantiating '" << StructDefOp::getOperationName() << "' parameter \"" << sym
314 << "\" for this call";
315 diagnostics.push_back(std::move(diag));
316 }
317 replaceOpWithNewOp<arith::ConstantIntOp>(rewriter, op, true, origResTy);
318 return success();
319 }
320 return op->emitOpError().append("unexpected result type ", origResTy);
321 }
322
323 LogicalResult handleRewrite(
324 Attribute, ConstReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, FeltConstAttr a
325 ) const {
326 replaceOpWithNewOp<FeltConstantOp>(rewriter, op, a);
327 return success();
328 }
329};
330
333struct MatchFailureListener : public RewriterBase::Listener {
334 bool hadFailure = false;
335
336 ~MatchFailureListener() override {}
337
338 void notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override {
339 hadFailure = true;
340
341 InFlightDiagnostic diag = emitError(loc);
342 reasonCallback(*diag.getUnderlyingDiagnostic());
343 diag.report();
344 }
345};
346
347static LogicalResult
348applyAndFoldGreedily(ModuleOp modOp, ConversionTracker &tracker, RewritePatternSet &&patterns) {
349 bool currStepModified = false;
350 MatchFailureListener failureListener;
351 LogicalResult result = applyPatternsGreedily(
352 modOp->getRegion(0), std::move(patterns),
353 GreedyRewriteConfig {.maxIterations = 20, .listener = &failureListener, .fold = true},
354 &currStepModified
355 );
356 tracker.updateModifiedFlag(currStepModified);
357 return failure(result.failed() || failureListener.hadFailure);
358}
359
360template <bool AllowStructParams = true> bool isConcreteAttr(Attribute a) {
361 if (TypeAttr tyAttr = dyn_cast<TypeAttr>(a)) {
362 return isConcreteType(tyAttr.getValue(), AllowStructParams);
363 }
364 if (IntegerAttr intAttr = dyn_cast<IntegerAttr>(a)) {
365 return !isDynamic(intAttr);
366 }
367 return false;
368}
369
374static std::optional<Attribute>
375evaluateExpr(TemplateExprOp exprOp, const DenseMap<Attribute, Attribute> &paramNameToConcrete) {
376 // Map from SSA value in the expr body to its concrete Attribute.
377 DenseMap<Value, Attribute> valueMap;
378 for (Operation &bodyOp : exprOp.getInitializerRegion().front()) {
379 if (auto yieldOp = llvm::dyn_cast<YieldOp>(bodyOp)) {
380 auto it = valueMap.find(yieldOp.getVal());
381 return it != valueMap.end() ? std::make_optional(it->second) : std::nullopt;
382 }
383
384 if (auto constReadOp = llvm::dyn_cast<ConstReadOp>(bodyOp)) {
385 auto it = paramNameToConcrete.find(constReadOp.getConstNameAttr());
386 if (it == paramNameToConcrete.end()) {
387 return std::nullopt; // a referenced param is not concrete
388 }
389 // If the attribute type is `FeltType` but it's stored as an IntegerAttr, promote to
390 // a `FeltConstAttr`.
391 Attribute val = it->second;
392 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(val)) {
393 if (auto feltTy = llvm::dyn_cast<FeltType>(constReadOp.getResult().getType())) {
394 val = FeltConstAttr::get(bodyOp.getContext(), intAttr.getValue(), feltTy);
395 }
396 }
397 valueMap[constReadOp.getResult()] = val;
398 continue;
399 }
400
401 // Gather constant attributes for all operands.
402 SmallVector<Attribute> operandAttrs;
403 operandAttrs.reserve(bodyOp.getNumOperands());
404 for (Value operand : bodyOp.getOperands()) {
405 auto it = valueMap.find(operand);
406 if (it == valueMap.end()) {
407 return std::nullopt; // operand not known as a constant
408 }
409 operandAttrs.push_back(it->second);
410 }
411
412 // Try constant folding.
413 SmallVector<OpFoldResult> foldResults;
414 if (succeeded(bodyOp.fold(operandAttrs, foldResults)) &&
415 foldResults.size() == bodyOp.getNumResults()) {
416 for (auto [result, fr] : llvm::zip_equal(bodyOp.getResults(), foldResults)) {
417 if (Attribute a = llvm::dyn_cast<Attribute>(fr)) {
418 valueMap[result] = a;
419 } else {
420 return std::nullopt;
421 }
422 }
423 }
424 }
425 return std::nullopt; // no YieldOp found (shouldn't happen in a valid expr)
426}
427
431static void
432evaluateTemplateExprs(TemplateOp templateOp, DenseMap<Attribute, Attribute> &paramNameToConcrete) {
433 LLVM_DEBUG(
434 llvm::dbgs() << "[evaluateTemplateExprs] before: " << debug::toStringList(paramNameToConcrete)
435 << '\n'
436 );
437 for (TemplateExprOp exprOp : templateOp.getConstOps<TemplateExprOp>()) {
438 std::optional<Attribute> result = evaluateExpr(exprOp, paramNameToConcrete);
439 if (result.has_value()) {
440 auto exprNameAttr = FlatSymbolRefAttr::get(exprOp.getSymNameAttr());
441 paramNameToConcrete.try_emplace(exprNameAttr, *result);
442 LLVM_DEBUG(
443 llvm::dbgs() << "[evaluateTemplateExprs] expr @" << exprOp.getSymName()
444 << " evaluated to " << *result << '\n'
445 );
446 }
447 }
448 LLVM_DEBUG(
449 llvm::dbgs() << "[evaluateTemplateExprs] after: " << debug::toStringList(paramNameToConcrete)
450 << '\n'
451 );
452}
453
455
456static inline bool tableOffsetIsntSymbol(MemberReadOp op) {
457 return !llvm::isa_and_present<SymbolRefAttr>(op.getTableOffset().value_or(nullptr));
458}
459
462class StructCloner {
463 ConversionTracker &tracker_;
464 ModuleOp rootMod;
465 SymbolTableCollection symTables;
466 bool reportMissing = true;
467
468 class MappedTypeConverter : public TypeConverter {
469 StructType origTy;
470 StructType newTy;
471 const DenseMap<Attribute, Attribute> &paramNameToValue;
472
473 inline Attribute convertIfPossible(Attribute a) const {
474 auto res = this->paramNameToValue.find(a);
475 return (res != this->paramNameToValue.end()) ? res->second : a;
476 }
477
478 public:
479 MappedTypeConverter(
480 StructType originalType, StructType newType,
482 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
483 )
484 : TypeConverter(), origTy(originalType), newTy(newType),
485 paramNameToValue(paramNameToInstantiatedValue) {
486
487 addConversion([](Type inputTy) { return inputTy; });
488
489 addConversion([this](StructType inputTy) {
490 LLVM_DEBUG(llvm::dbgs() << "[MappedTypeConverter] convert " << inputTy << '\n');
491
492 // Check for replacement of the full type
493 if (inputTy == this->origTy) {
494 return this->newTy;
495 }
496 // Check for replacement of parameter symbol names with concrete values
497 if (ArrayAttr inputTyParams = inputTy.getParams()) {
498 SmallVector<Attribute> updated;
499 for (Attribute a : inputTyParams) {
500 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
501 updated.push_back(TypeAttr::get(this->convertType(ta.getValue())));
502 } else {
503 updated.push_back(convertIfPossible(a));
504 }
505 }
506 return StructType::get(
507 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
508 );
509 }
510 // Otherwise, return the type unchanged
511 return inputTy;
512 });
513
514 addConversion([this](ArrayType inputTy) {
515 // Check for replacement of parameter symbol names with concrete values
516 ArrayRef<Attribute> dimSizes = inputTy.getDimensionSizes();
517 if (!dimSizes.empty()) {
518 SmallVector<Attribute> updated;
519 for (Attribute a : dimSizes) {
520 updated.push_back(convertIfPossible(a));
521 }
522 return ArrayType::get(this->convertType(inputTy.getElementType()), updated);
523 }
524 // Otherwise, return the type unchanged
525 return inputTy;
526 });
527
528 addConversion([this](TypeVarType inputTy) -> Type {
529 // Check for replacement of parameter symbol name with a concrete type
530 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
531 Type convertedType = tyAttr.getValue();
532 // Use the new type unless it contains a TypeVarType because a TypeVarType from a
533 // different struct references a parameter name from that other struct, not from the
534 // current struct so the reference would be invalid.
535 if (isConcreteType(convertedType)) {
536 return convertedType;
537 }
538 }
539 return inputTy;
540 });
541 }
542 };
543
544 class ClonedStructMemberReadOpPattern
545 : public SymbolUserHelper<
546 ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr> {
547 using super =
548 SymbolUserHelper<ClonedStructMemberReadOpPattern, MemberReadOp, IntegerAttr, FeltConstAttr>;
549
550 public:
551 ClonedStructMemberReadOpPattern(
552 TypeConverter &converter, MLIRContext *ctx,
553 const DenseMap<Attribute, Attribute> &paramNameToInstantiatedValue
554 )
555 // benefit>0 so this applies instead of GeneralTypeReplacePattern<MemberReadOp>
556 : super(converter, ctx, /*patternBenefit=*/1, paramNameToInstantiatedValue) {}
557
558 Attribute getNameAttr(MemberReadOp op) const override {
559 return op.getTableOffset().value_or(nullptr);
560 }
561
562 template <typename Attr>
563 LogicalResult handleRewrite(
564 Attribute, MemberReadOp op, OpAdaptor, ConversionPatternRewriter &rewriter, Attr a
565 ) const {
566 rewriter.modifyOpInPlace(op, [&]() {
567 op.setTableOffsetAttr(rewriter.getIndexAttr(fromAPInt(a.getValue())));
568 });
569
570 return success();
571 }
572
573 LogicalResult matchAndRewrite(
574 MemberReadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
575 ) const override {
576 LLVM_DEBUG(
577 llvm::dbgs() << "[ClonedStructMemberReadOpPattern] MemberReadOp: " << op << '\n';
578 );
579 if (tableOffsetIsntSymbol(op)) {
580 return failure();
581 }
582
583 return super::matchAndRewrite(op, adaptor, rewriter);
584 }
585 };
586
587 FailureOr<StructType> genClone(StructType typeAtCaller, ArrayRef<Attribute> typeAtCallerParams) {
588 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] attempting clone of " << typeAtCaller << '\n');
589 // Find the StructDefOp for the original StructType
590 FailureOr<SymbolLookupResult<StructDefOp>> r =
591 typeAtCaller.getDefinition(symTables, rootMod, reportMissing);
592 if (failed(r)) {
593 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: cannot find StructDefOp \n");
594 return failure(); // getDefinition() already emits a sufficient error message
595 }
596 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] found definition\n";);
597
598 StructDefOp origStruct = r->get();
599 StructType typeAtDef = origStruct.getType();
600 MLIRContext *ctx = origStruct.getContext();
601
602 // Map of StructDefOp parameter name to concrete Attribute at the current instantiation site.
603 DenseMap<Attribute, Attribute> paramNameToConcrete;
604 // List of concrete Attributes from the struct instantiation with `nullptr` at any positions
605 // where the original attribute from the current instantiation site was not concrete. This is
606 // used for generating the new struct name. See `BuildShortTypeString::from()`.
607 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
608 // List of template const param names that must be preserved because they
609 // were not assigned concrete values at the current instantiation site.
610 SmallVector<Attribute> remainingNames;
611 // Reduced from `typeAtCallerParams` to contain only the non-concrete Attributes.
612 ArrayAttr reducedCallerParams = nullptr;
613 {
614 ArrayAttr paramNames = typeAtDef.getParams();
615
616 // pre-conditions
617 assert(!isNullOrEmpty(paramNames));
618 assert(paramNames.size() == typeAtCallerParams.size());
619
620 SmallVector<Attribute> nonConcreteParams;
621 for (size_t i = 0, e = paramNames.size(); i < e; ++i) {
622 Attribute next = typeAtCallerParams[i];
623 if (isConcreteAttr<false>(next)) {
624 paramNameToConcrete[paramNames[i]] = next;
625 attrsForInstantiatedNameSuffix.push_back(next);
626 } else {
627 remainingNames.push_back(paramNames[i]);
628 nonConcreteParams.push_back(next);
629 attrsForInstantiatedNameSuffix.push_back(nullptr);
630 }
631 }
632 // post-conditions
633 assert(remainingNames.size() == nonConcreteParams.size());
634 assert(attrsForInstantiatedNameSuffix.size() == paramNames.size());
635 assert(remainingNames.size() + paramNameToConcrete.size() == paramNames.size());
636
637 if (paramNameToConcrete.empty()) {
638 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: no concrete params \n");
639 return failure();
640 }
641 if (!remainingNames.empty()) {
642 reducedCallerParams = ArrayAttr::get(ctx, nonConcreteParams);
643 }
644 }
645
646 // This list will be used to build the new remote/external type.
647 SmallVector<FlatSymbolRefAttr> typeAtCallerSymPieces = getPieces(typeAtCaller.getNameRef());
648 typeAtCallerSymPieces.pop_back(); // drop struct name
649 // Name of template with instantiated parameter values.
650 std::string templateNameWithAttrs = BuildShortTypeString::from(
651 typeAtCallerSymPieces.back().getValue().str(), attrsForInstantiatedNameSuffix
652 );
653
654 // Get parent refs
655 TemplateOp parentTemplate = getParentOfType<TemplateOp>(origStruct);
656 assert(parentTemplate && "parameterized struct must be nested in a TemplateOp");
657 ModuleOp parentModule = getParentOfType<ModuleOp>(parentTemplate);
658 assert(parentModule && "TemplateOp must be nested in a ModuleOp");
659
660 // Evaluate any poly.expr symbols whose param dependencies are now concrete; add them to the
661 // map so ClonedBodyConstReadOpPattern can replace uses of those symbols too.
662 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
663
664 // Clone the original struct.
665 StructDefOp newStruct = origStruct.clone();
666 if (remainingNames.empty()) { // FULL INSTANTIATION CASE
667 // Set name of the new struct by prepending its name with instantiated template name.
668 newStruct.setSymName(
669 (templateNameWithAttrs + mlir::Twine('_') + newStruct.getSymName()).str()
670 );
671 // Insert 'newStruct' into the parent ModuleOp of the original TemplateOp. Use the
672 // `SymbolTable::insert()` function so that the name will be made unique if necessary.
673 symTables.getSymbolTable(parentModule).insert(newStruct, Block::iterator(parentTemplate));
674 // Drop the old template name from the list.
675 typeAtCallerSymPieces.pop_back();
676 } else { // PARTIAL INSTANTIATION CASE
677 // Clone the template and set instantiated name.
678 TemplateOp newTemplate = parentTemplate.cloneWithoutRegions();
679 newTemplate.setSymName(templateNameWithAttrs);
680 assert(newTemplate->getNumRegions() > 0 && "region exists"); // it just doesn't have a block
681 newTemplate.getBodyRegion().emplaceBlock();
682
683 // Clone preserved const param/expr ops.
684 for (Attribute name : remainingNames) {
685 FlatSymbolRefAttr nameSym = llvm::dyn_cast<FlatSymbolRefAttr>(name);
686 assert(nameSym && "expected FlatSymbolRefAttr");
687
688 Operation *symOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
689 assert(symOp && "symbol must exist");
690 newTemplate.insert(newTemplate.begin(), symOp->clone());
691 }
692
693 // Insert the struct into the template and the template into the module. Use the
694 // `SymbolTable::insert()` function so that the name will be made unique if necessary.
695 symTables.getSymbolTable(newTemplate).insert(newStruct);
696 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
697
698 // Replace the old template name in the list with the new one (get template name after
699 // symbol table insertion since it may be modified to make it unique).
700 typeAtCallerSymPieces.back() = FlatSymbolRefAttr::get(newTemplate.getSymNameAttr());
701 }
702
703 // Retrieve the new type AFTER inserting since the struct name may be appended to make
704 // it unique and use the remaining non-concrete parameters from the original type.
705 StructType newLocalType = newStruct.getType(reducedCallerParams);
706 typeAtCallerSymPieces.push_back(
707 FlatSymbolRefAttr::get(newLocalType.getNameRef().getLeafReference())
708 );
709 StructType newRemoteType =
710 StructType::get(asSymbolRefAttr(typeAtCallerSymPieces), newLocalType.getParams());
711 LLVM_DEBUG({
712 llvm::dbgs() << "[StructCloner] original def type: " << typeAtDef << '\n';
713 llvm::dbgs() << "[StructCloner] cloned def type: " << newStruct.getType() << '\n';
714 llvm::dbgs() << "[StructCloner] original remote type: " << typeAtCaller << '\n';
715 llvm::dbgs() << "[StructCloner] cloned local type: " << newLocalType << '\n';
716 llvm::dbgs() << "[StructCloner] cloned remote type: " << newRemoteType << '\n';
717 });
718
719 // Within the new struct, replace all references to the original StructType (i.e., the
720 // locally-parameterized version) with the new locally-parameterized StructType,
721 // and replace all uses of the removed struct parameters with the concrete values.
722 MappedTypeConverter tyConv(typeAtDef, newStruct.getType(), paramNameToConcrete);
723 ConversionTarget target =
724 newConverterDefinedTarget<EmitEqualityOp>(tyConv, ctx, tableOffsetIsntSymbol);
725 target.addDynamicallyLegalOp<ConstReadOp>([&paramNameToConcrete](ConstReadOp op) {
726 // Legal if it's not in the map of concrete attribute instantiations
727 return !paramNameToConcrete.contains(op.getConstNameAttr());
728 });
729
730 RewritePatternSet patterns = newGeneralRewritePatternSet<EmitEqualityOp>(tyConv, ctx, target);
731 patterns.add<ClonedBodyConstReadOpPattern>(
732 tyConv, ctx, paramNameToConcrete, tracker_.delayedDiagnosticSet(newLocalType)
733 );
734 patterns.add<ClonedStructMemberReadOpPattern>(tyConv, ctx, paramNameToConcrete);
735 if (failed(applyFullConversion(newStruct, target, std::move(patterns)))) {
736 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] instantiating body of struct failed \n");
737 return failure();
738 }
739 return newRemoteType;
740 }
741
742public:
743 StructCloner(ConversionTracker &tracker, ModuleOp root)
744 : tracker_(tracker), rootMod(root), symTables() {}
745
746 FailureOr<StructType> createInstantiatedClone(StructType orig) {
747 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] orig: " << orig << '\n');
748 if (ArrayAttr params = orig.getParams()) {
749 return genClone(orig, params.getValue());
750 }
751 LLVM_DEBUG(llvm::dbgs() << "[StructCloner] skip: nullptr for params \n");
752 return failure();
753 }
754
755 void enableReportMissing() { reportMissing = true; }
756
757 void disableReportMissing() { reportMissing = false; }
758};
759
760class DisableReportMissing;
761
762class ParameterizedStructUseTypeConverter : public TypeConverter {
763 ConversionTracker &tracker_;
764 StructCloner cloner;
765
766 friend DisableReportMissing;
767
768public:
769 ParameterizedStructUseTypeConverter(ConversionTracker &tracker, ModuleOp root)
770 : TypeConverter(), tracker_(tracker), cloner(tracker, root) {
771
772 addConversion([](Type inputTy) { return inputTy; });
773
774 addConversion([this](StructType inputTy) -> StructType {
775 LLVM_DEBUG(
776 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] attempting conversion of "
777 << inputTy << '\n';
778 );
779 // First check for a cached entry
780 if (auto opt = tracker_.getInstantiation(inputTy)) {
781 return opt.value();
782 }
783
784 // Otherwise, try to create a clone of the struct with instantiated params. If that can't be
785 // done, return the original type to indicate that it's still legal (for this step at least).
786 FailureOr<StructType> cloneRes = cloner.createInstantiatedClone(inputTy);
787 if (failed(cloneRes)) {
788 return inputTy;
789 }
790 StructType newTy = cloneRes.value();
791 LLVM_DEBUG(
792 llvm::dbgs() << "[ParameterizedStructUseTypeConverter] instantiating " << inputTy
793 << " as " << newTy << '\n'
794 );
795 tracker_.recordInstantiation(inputTy, newTy);
796 return newTy;
797 });
798
799 addConversion([this](ArrayType inputTy) {
800 return inputTy.cloneWith(convertType(inputTy.getElementType()));
801 });
802 }
803};
804
805class CallStructFuncPattern : public OpConversionPattern<CallOp> {
806 ConversionTracker &tracker_;
807
808public:
809 CallStructFuncPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &tracker)
810 // benefit>0 so this applies instead of CallOpClassReplacePattern
811 : OpConversionPattern<CallOp>(converter, ctx, /*benefit=*/1), tracker_(tracker) {}
812
813 LogicalResult matchAndRewrite(
814 CallOp op, OpAdaptor adapter, ConversionPatternRewriter &rewriter
815 ) const override {
816 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] CallOp: " << op << '\n');
817
818 // Convert the result types of the CallOp
819 SmallVector<Type> newResultTypes;
820 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
821 return op->emitError("Could not convert Op result types.");
822 }
823 LLVM_DEBUG({
824 llvm::dbgs() << "[CallStructFuncPattern] newResultTypes: "
825 << debug::toStringList(newResultTypes) << '\n';
826 });
827
828 // Update the callee to reflect the new struct target if necessary. These checks are based on
829 // `CallOp::calleeIsStructC*()` but the types must not come from the CallOp in this case.
830 // Instead they must come from the converted versions.
831 SymbolRefAttr calleeAttr = op.getCalleeAttr();
832 if (op.calleeIsStructCompute()) {
833 if (StructType newStTy = getIfSingleton<StructType>(newResultTypes)) {
834 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
835 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
836 tracker_.reportDelayedDiagnostics(newStTy, op);
837 }
838 } else if (op.calleeIsStructConstrain()) {
839 if (StructType newStTy = getAtIndex<StructType>(adapter.getArgOperands().getTypes(), 0)) {
840 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] newStTy: " << newStTy << '\n');
841 calleeAttr = appendLeaf(newStTy.getNameRef(), calleeAttr.getLeafReference());
842 }
843 }
844
845 LLVM_DEBUG(llvm::dbgs() << "[CallStructFuncPattern] replaced " << op);
847 rewriter, op, newResultTypes, calleeAttr, adapter.getMapOperands(),
848 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
849 );
850 (void)newOp; // tell compiler it's intentionally unused in release builds
851 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
852 return success();
853 }
854};
855
856// This one ensures MemberDefOp types are converted even if there are no reads/writes to them.
857class MemberDefOpPattern : public OpConversionPattern<MemberDefOp> {
858public:
859 MemberDefOpPattern(TypeConverter &converter, MLIRContext *ctx, ConversionTracker &)
860 // benefit>0 so this applies instead of GeneralTypeReplacePattern<MemberDefOp>
861 : OpConversionPattern<MemberDefOp>(converter, ctx, /*benefit=*/1) {}
862
863 LogicalResult matchAndRewrite(
864 MemberDefOp op, OpAdaptor /*adapter*/, ConversionPatternRewriter &rewriter
865 ) const override {
866 LLVM_DEBUG(llvm::dbgs() << "[MemberDefOpPattern] MemberDefOp: " << op << '\n');
867
868 Type oldMemberType = op.getType();
869 Type newMemberType = getTypeConverter()->convertType(oldMemberType);
870 if (oldMemberType == newMemberType) {
871 return failure(); // nothing changed
872 }
873 rewriter.modifyOpInPlace(op, [&op, &newMemberType]() { op.setType(newMemberType); });
874 return success();
875 }
876};
877
880class DisableReportMissing : public LegalityCheckCallback {
881 ParameterizedStructUseTypeConverter &tyConv;
882
883public:
884 explicit DisableReportMissing(ParameterizedStructUseTypeConverter &tc) : tyConv(tc) {}
885
886 void checkStarted() override { tyConv.cloner.disableReportMissing(); }
887
888 void checkEnded(bool) override { tyConv.cloner.enableReportMissing(); }
889};
890
891LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
892 MLIRContext *ctx = modOp.getContext();
893 ParameterizedStructUseTypeConverter tyConv(tracker, modOp);
894 DisableReportMissing drm(tyConv);
895 ConversionTarget target = newConverterDefinedTargetWithCallback<>(tyConv, ctx, drm);
896 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
897 patterns.add<CallStructFuncPattern, MemberDefOpPattern>(tyConv, ctx, tracker);
898 return applyPartialConversion(modOp, target, std::move(patterns));
899}
900
901} // namespace Step1A_InstantiateStructs
902
904
907class FuncInstTypeConverter : public TypeConverter {
908 DenseMap<Attribute, Attribute> paramNameToValue;
909
910 Attribute convertIfPossible(Attribute a) const {
911 auto res = paramNameToValue.find(a);
912 return (res != paramNameToValue.end()) ? res->second : a;
913 }
914
915public:
916 explicit FuncInstTypeConverter(DenseMap<Attribute, Attribute> paramNameToConcrete)
917 : TypeConverter(), paramNameToValue(std::move(paramNameToConcrete)) {
918 addConversion([](Type t) { return t; });
919
920 addConversion([this](TypeVarType inputTy) -> Type {
921 if (TypeAttr tyAttr = llvm::dyn_cast<TypeAttr>(convertIfPossible(inputTy.getNameRef()))) {
922 Type convertedType = tyAttr.getValue();
923 if (isConcreteType(convertedType)) {
924 return convertedType;
925 }
926 }
927 return inputTy;
928 });
929
930 addConversion([this](ArrayType inputTy) {
931 SmallVector<Attribute> updated;
932 bool changed = false;
933 for (Attribute a : inputTy.getDimensionSizes()) {
934 Attribute converted = convertIfPossible(a);
935 updated.push_back(converted);
936 if (converted != a) {
937 changed = true;
938 }
939 }
940 Type newElemTy = this->convertType(inputTy.getElementType());
941 if (!changed && newElemTy == inputTy.getElementType()) {
942 return inputTy;
943 }
944 return ArrayType::get(newElemTy, updated);
945 });
946
947 addConversion([this](StructType inputTy) -> StructType {
948 if (ArrayAttr params = inputTy.getParams()) {
949 SmallVector<Attribute> updated;
950 bool changed = false;
951 for (Attribute a : params) {
952 if (TypeAttr ta = dyn_cast<TypeAttr>(a)) {
953 Type newTy = this->convertType(ta.getValue());
954 if (newTy != ta.getValue()) {
955 updated.push_back(TypeAttr::get(newTy));
956 changed = true;
957 continue;
958 }
959 } else {
960 Attribute converted = convertIfPossible(a);
961 if (converted != a) {
962 updated.push_back(converted);
963 changed = true;
964 continue;
965 }
966 }
967 updated.push_back(a);
968 }
969 if (changed) {
970 return StructType::get(
971 inputTy.getNameRef(), ArrayAttr::get(inputTy.getContext(), updated)
972 );
973 }
974 }
975 return inputTy;
976 });
977 }
978
979 bool containsParam(Attribute nameAttr) const { return paramNameToValue.contains(nameAttr); }
980 const DenseMap<Attribute, Attribute> &getParamMap() const { return paramNameToValue; }
981};
982
983class InstantiateFuncAtCallOp final : public OpRewritePattern<CallOp> {
984 ConversionTracker &tracker_;
985
986public:
987 InstantiateFuncAtCallOp(MLIRContext *ctx, ConversionTracker &tracker)
988 : OpRewritePattern<CallOp>(ctx), tracker_(tracker) {}
989
990 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
991 LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] op: " << op << '\n');
992
993 // Lookup callee target function
994 SymbolTableCollection symTables;
995 FailureOr<SymbolLookupResult<FuncDefOp>> callTgtOpt = op.getCalleeTarget(symTables);
996 if (failed(callTgtOpt)) {
997 return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) {
998 diag << "could not find target function for call";
999 });
1000 }
1001 FuncDefOp callTgt = callTgtOpt->get();
1002
1003 // Check if callee is within a TemplateOp
1004 TemplateOp parentTemplate = llvm::dyn_cast<TemplateOp>(callTgt->getParentOp());
1005 if (!parentTemplate) {
1006 return failure(); // nothing to do if not parameterized
1007 }
1008 LLVM_DEBUG(
1009 llvm::dbgs() << "[InstantiateFuncAtCallOp] target function in template "
1010 << parentTemplate.getSymName() << '\n'
1011 );
1012
1013 // Perform type unification with tracking to infer the instantiated type(s). Even though
1014 // `CallOp` verification already checked that caller and callee types unify, the progress of
1015 // instantiation so far may have brought together a chain of calls across templates where each
1016 // individual unification check passed due to permissive type variables and/or symbols in the
1017 // middle but the overall chain does not unify. Hence, this unification may fail and should
1018 // produce a meaningful error message if it does.
1019 // See: `test/Transforms/Flattening/instantiate_funcs_fail.llzk`
1020 FailureOr<UnificationMap> unifyResult = op.unifyTypeSignature(callTgt.getFunctionType());
1021 if (failed(unifyResult)) {
1022 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1023 diag.append("target function type does not unify with call type ")
1024 .append(op.getTypeSignature())
1025 .attachNote(callTgt.getLoc())
1026 .append("target function declared here");
1027 });
1028 }
1029 LLVM_DEBUG(
1030 llvm::dbgs() << "[InstantiateFuncAtCallOp] unifications of types: "
1031 << debug::toStringList(unifyResult.value()) << '\n'
1032 );
1033
1034 // Maps template parameter symbols to the instantiation value at the call site.
1035 DenseMap<Attribute, Attribute> paramNameToConcrete;
1036 // If template instantiation list is given, must use that. Otherwise, infer.
1037 auto realParams = parentTemplate.getConstOps<TemplateParamOp>();
1038 ArrayAttr callParams = op.getTemplateParamsAttr();
1039 LLVM_DEBUG(
1040 llvm::dbgs() << "[InstantiateFuncAtCallOp] TemplateParamsAttr: " << callParams << '\n'
1041 );
1042 if (isNullOrEmpty(callParams)) {
1043 for (auto paramOp : realParams) {
1044 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1045 auto it = unifyResult->find({paramName, Side::RHS});
1046 if (it == unifyResult->end()) {
1047 LLVM_DEBUG(
1048 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1049 << "': not found\n"
1050 );
1051 continue;
1052 }
1053 Attribute inferredVal = it->second;
1054 if (!isConcreteAttr(inferredVal)) {
1055 LLVM_DEBUG(
1056 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1057 << "': not concrete, " << inferredVal << '\n'
1058 );
1059 continue;
1060 }
1061 // Ensure it's a valid value for the optional type restriction on the TemplateParamOp
1062 if (failed(op.verifyTemplateParamCompatibility(inferredVal, paramOp))) {
1063 LLVM_DEBUG(
1064 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1065 << "': incompatible with specified param type. MUST FAIL!\n"
1066 );
1067 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1068 diag.append("inferred value for parameter '")
1069 .append(paramName)
1070 .append("' is incompatible with specified param type")
1071 .attachNote(paramOp.getLoc())
1072 .append("template parameter declared here");
1073 });
1074 }
1075 paramNameToConcrete[paramName] = inferredVal;
1076 }
1077 } else {
1078 // As stated earlier, need to run the verification checks again to ensure the
1079 // instantiation is valid, except for the size check becuase that cannot change.
1080 assert((callParams.size() == llvm::range_size(realParams)) && "per CallOpVerifier");
1081 if (failed(op.verifyTemplateParamCompatibility(realParams))) {
1082 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1083 diag.append("incompatible with specified param type(s)");
1084 });
1085 }
1086 if (failed(op.verifyTemplateParamsMatchInferred(realParams, unifyResult.value()))) {
1087 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1088 diag.append("incompatible with inferred param value(s)");
1089 });
1090 }
1091 // Add the mappings
1092 for (auto [paramOp, attr] : llvm::zip_equal(realParams, callParams.getValue())) {
1093 auto paramName = FlatSymbolRefAttr::get(paramOp.getSymNameAttr());
1094 if (!isConcreteAttr(attr)) {
1095 LLVM_DEBUG(
1096 llvm::dbgs() << "[InstantiateFuncAtCallOp] unification for param '" << paramName
1097 << "': not concrete, " << attr << '\n'
1098 );
1099 continue;
1100 }
1101 paramNameToConcrete[paramName] = attr;
1102 }
1103 }
1104
1105 if (paramNameToConcrete.empty()) {
1106 LLVM_DEBUG(llvm::dbgs() << "[InstantiateFuncAtCallOp] skip: no concrete params\n");
1107 return failure();
1108 }
1109
1110 // Evaluate any poly.expr symbols whose param dependencies are now concrete; add them to the
1111 // map so ClonedFuncConstReadOpPattern can replace uses of those symbols too.
1112 evaluateTemplateExprs(parentTemplate, paramNameToConcrete);
1113
1114 // Classify each template parameter as concrete (to be inlined) or remaining (to be preserved).
1115 SmallVector<Attribute> remainingNames;
1116 SmallVector<Attribute> attrsForInstantiatedNameSuffix;
1117 for (Attribute paramName : parentTemplate.getConstNames<TemplateParamOp>()) {
1118 auto it = paramNameToConcrete.find(paramName);
1119 if (it != paramNameToConcrete.end()) {
1120 attrsForInstantiatedNameSuffix.push_back(it->second);
1121 } else {
1122 attrsForInstantiatedNameSuffix.push_back(nullptr); // placeholder for non-concrete param
1123 remainingNames.push_back(paramName);
1124 }
1125 }
1126
1127 MLIRContext *ctx = op.getContext();
1128 ModuleOp parentModule = getParentOfType<ModuleOp>(parentTemplate);
1129 assert(parentModule && "TemplateOp must be nested in a ModuleOp");
1130
1131 // Build the (partially-)instantiated template name, e.g., "TemplateName_8_\x1A" where \x1A
1132 // is a placeholder character at the position of a non-concrete parameter.
1133 std::string templateNameWithAttrs = BuildShortTypeString::from(
1134 parentTemplate.getSymName().str(), attrsForInstantiatedNameSuffix
1135 );
1136
1137 // Helper lambda to:
1138 // 1. build the FuncInstTypeConverter and apply it to a cloned function
1139 // 2. verify CallOp in the converted function are valid for their respective targets
1140 // and emit a more helpful error at this point rather than discovering it later
1141 // when verifying the entire module.
1142 auto applyBodyConversions = [&](FuncDefOp newFunc) -> LogicalResult {
1143 FuncInstTypeConverter tyConv(paramNameToConcrete);
1144 ConversionTarget target = newConverterDefinedTarget<>(tyConv, ctx);
1145 target.addDynamicallyLegalOp<ConstReadOp>([&tyConv](ConstReadOp p) {
1146 // Legal if it's not in the map of concrete attribute instantiations
1147 return !tyConv.containsParam(p.getConstNameAttr());
1148 });
1149 SmallVector<Diagnostic> delayedDiagnostics;
1150 RewritePatternSet bodyPatterns = newGeneralRewritePatternSet(tyConv, ctx, target);
1151 bodyPatterns.add<ClonedBodyConstReadOpPattern>(
1152 tyConv, ctx, tyConv.getParamMap(), delayedDiagnostics
1153 );
1154 if (failed(applyFullConversion(newFunc, target, std::move(bodyPatterns)))) {
1155 return failure();
1156 }
1157 LLVM_DEBUG(
1158 llvm::dbgs() << "[InstantiateFuncAtCallOp] instantiated clone: " << newFunc << '\n'
1159 );
1160 ::reportDelayedDiagnostics(op, std::move(delayedDiagnostics));
1161
1162 // Verify CallOp match targets
1163 SymbolTableCollection tables;
1164 WalkResult res = newFunc.walk([&tables](CallOp nestedCall) {
1165 return WalkResult(nestedCall.verifySymbolUses(tables));
1166 });
1167 return failure(res.wasInterrupted());
1168 };
1169
1170 SmallVector<FlatSymbolRefAttr> symPieces = getPieces(op.getCalleeAttr());
1171 assert(symPieces.size() >= 2 && "callee must include at least template and function names");
1172 if (remainingNames.empty()) {
1173 // FULL INSTANTIATION: place the cloned function directly in the parent module.
1174 // New function name encodes all parameter values, e.g., "TemplateName_8_12_funcName".
1175 std::string newFuncName =
1176 (mlir::Twine(templateNameWithAttrs) + "_" + callTgt.getSymName()).str();
1177 StringRef actualNewFuncName = newFuncName;
1178 if (!symTables.getSymbolTable(parentModule).lookup(newFuncName)) {
1179 FuncDefOp newFunc = callTgt.clone();
1180 newFunc.setSymName(newFuncName);
1181 // Insert before the TemplateOp; symbol table may adjust the name to ensure uniqueness.
1182 symTables.getSymbolTable(parentModule).insert(newFunc, Block::iterator(parentTemplate));
1183 actualNewFuncName = newFunc.getSymName();
1184 LLVM_DEBUG(
1185 llvm::dbgs() << "[InstantiateFuncAtCallOp] created full instantiation function: "
1186 << actualNewFuncName << '\n'
1187 );
1188 if (failed(applyBodyConversions(newFunc))) {
1189 LLVM_DEBUG(
1190 llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for "
1191 << actualNewFuncName << '\n'
1192 );
1193 newFunc->erase();
1194 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1195 diag.append("failure while creating instantiated function '", actualNewFuncName, '\'');
1196 });
1197 }
1198 } else {
1199 LLVM_DEBUG(
1200 llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing full instantiation function: "
1201 << actualNewFuncName << '\n'
1202 );
1203 }
1204 // Callee: drop template & original function names, add the new module-level function name.
1205 // Original: @[prefix...]::@TemplateName::@funcName
1206 // New: @[prefix...]::@newFuncName
1207 symPieces.pop_back(); // remove original function name
1208 symPieces.pop_back(); // remove template name
1209 symPieces.push_back(FlatSymbolRefAttr::get(StringAttr::get(ctx, actualNewFuncName)));
1210 } else {
1211 // PARTIAL INSTANTIATION: place the cloned function in a new partially-instantiated
1212 // TemplateOp that retains only the non-concrete parameters.
1213 // New template name encodes the concrete values and uses placeholder chars for the rest,
1214 // e.g., "TemplateName_8_\x1A" where \x1A marks the position of a non-concrete param.
1215 TemplateOp newTemplate;
1216 if (Operation *existing =
1217 symTables.getSymbolTable(parentModule).lookup(templateNameWithAttrs)) {
1218 newTemplate = llvm::dyn_cast<TemplateOp>(existing);
1219 }
1220 if (!newTemplate) {
1221 // Clone the TemplateOp structure without its body and set the new name.
1222 newTemplate = parentTemplate.cloneWithoutRegions();
1223 newTemplate.setSymName(templateNameWithAttrs);
1224 assert(newTemplate->getNumRegions() > 0 && "region exists");
1225 newTemplate.getBodyRegion().emplaceBlock();
1226
1227 // Clone the preserved (non-concrete) param/expr ops into the new template in order.
1228 Block &newTemplateBody = newTemplate.getBodyRegion().front();
1229 for (Attribute name : remainingNames) {
1230 FlatSymbolRefAttr nameSym = llvm::cast<FlatSymbolRefAttr>(name);
1231 Operation *paramOp = symTables.getSymbolTable(parentTemplate).lookup(nameSym.getAttr());
1232 assert(paramOp && "symbol must exist");
1233 newTemplateBody.push_back(paramOp->clone());
1234 }
1235
1236 // Clone and partially convert the function (concretize only the concrete params).
1237 FuncDefOp newFunc = callTgt.clone();
1238 if (failed(applyBodyConversions(newFunc))) {
1239 StringRef newFuncName = newFunc.getSymName();
1240 LLVM_DEBUG(
1241 llvm::dbgs() << "[InstantiateFuncAtCallOp] body conversion failed for "
1242 << newFuncName << '\n'
1243 );
1244 newTemplate->erase();
1245 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1246 diag.append("failure while creating instantiated function '", newFuncName, '\'');
1247 });
1248 }
1249
1250 // Insert the function into the new template, then the template into the module. Use the
1251 // `SymbolTable::insert()` function so that the name will be made unique if necessary.
1252 symTables.getSymbolTable(newTemplate).insert(newFunc);
1253 symTables.getSymbolTable(parentModule).insert(newTemplate, Block::iterator(parentTemplate));
1254 LLVM_DEBUG(
1255 llvm::dbgs() << "[InstantiateFuncAtCallOp] created partial instantiation template: "
1256 << newTemplate.getSymName() << '\n'
1257 );
1258 } else {
1259 LLVM_DEBUG(
1260 llvm::dbgs() << "[InstantiateFuncAtCallOp] reusing partial instantiation template: "
1261 << newTemplate.getSymName() << '\n'
1262 );
1263 }
1264 // Callee: replace old template name with new template name, keep the function name.
1265 // Original: @[prefix...]::@TemplateName::@funcName
1266 // New: @[prefix...]::@newTemplateName::@funcName
1267 symPieces.pop_back(); // remove original function name (will be re-appended)
1268 symPieces.pop_back(); // remove original template name
1269 symPieces.push_back(FlatSymbolRefAttr::get(newTemplate.getSymNameAttr()));
1270 symPieces.push_back(FlatSymbolRefAttr::get(callTgt.getSymNameAttr()));
1271 }
1272
1273 // Update the CallOp to point to the instantiated function and mark the module as modified.
1274 rewriter.modifyOpInPlace(op, [&op, &symPieces]() {
1275 // Update callee attribute.
1276 SymbolRefAttr newCalleeAttr = asSymbolRefAttr(symPieces);
1277 LLVM_DEBUG({
1278 llvm::dbgs() << "[InstantiateFuncAtCallOp] updating callee from " << op.getCalleeAttr()
1279 << " to " << newCalleeAttr << '\n';
1280 });
1281 op.setCalleeAttr(newCalleeAttr);
1282 // Also drop template param list. If it was present, it was fully used (no partial case).
1283 op.setTemplateParamsAttr(nullptr);
1284 });
1285 tracker_.updateModifiedFlag(true);
1286 return success();
1287 }
1288};
1289
1290LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1291 MLIRContext *ctx = modOp.getContext();
1292 RewritePatternSet patterns(ctx);
1293 patterns.add<InstantiateFuncAtCallOp>(ctx, tracker);
1294 MatchFailureListener failureListener;
1295 walkAndApplyPatterns(modOp, std::move(patterns), &failureListener);
1296 return failure(failureListener.hadFailure);
1297}
1298
1299} // namespace Step1B_InstantiateFunctions
1300
1301namespace Step2_Unroll {
1302
1303// TODO: not guaranteed to work with WhileOp, can try with our custom attributes though.
1304template <HasInterface<LoopLikeOpInterface> OpClass>
1305class LoopUnrollPattern : public OpRewritePattern<OpClass> {
1306public:
1307 using OpRewritePattern<OpClass>::OpRewritePattern;
1308
1309 LogicalResult matchAndRewrite(OpClass loopOp, PatternRewriter &rewriter) const override {
1310 if (auto maybeConstant = getConstantTripCount(loopOp)) {
1311 uint64_t tripCount = *maybeConstant;
1312 if (tripCount == 0) {
1313 rewriter.eraseOp(loopOp);
1314 return success();
1315 } else if (tripCount == 1) {
1316 return loopOp.promoteIfSingleIteration(rewriter);
1317 }
1318 return loopUnrollByFactor(loopOp, tripCount);
1319 }
1320 return failure();
1321 }
1322
1323private:
1326 static std::optional<int64_t> getConstantTripCount(LoopLikeOpInterface loopOp) {
1327 std::optional<OpFoldResult> lbVal = loopOp.getSingleLowerBound();
1328 std::optional<OpFoldResult> ubVal = loopOp.getSingleUpperBound();
1329 std::optional<OpFoldResult> stepVal = loopOp.getSingleStep();
1330 if (!lbVal.has_value() || !ubVal.has_value() || !stepVal.has_value()) {
1331 return std::nullopt;
1332 }
1333 return constantTripCount(lbVal.value(), ubVal.value(), stepVal.value());
1334 }
1335};
1336
1337LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1338 MLIRContext *ctx = modOp.getContext();
1339 RewritePatternSet patterns(ctx);
1340 patterns.add<LoopUnrollPattern<scf::ForOp>>(ctx);
1341 patterns.add<LoopUnrollPattern<affine::AffineForOp>>(ctx);
1342
1343 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1344}
1345} // namespace Step2_Unroll
1346
1348
1349// Adapted from `mlir::getConstantIntValues()` but that one failed in CI for an unknown reason. This
1350// version uses a basic loop instead of llvm::map_to_vector().
1351std::optional<SmallVector<int64_t>> getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
1352 SmallVector<int64_t> res;
1353 for (OpFoldResult ofr : ofrs) {
1354 std::optional<int64_t> cv = getConstantIntValue(ofr);
1355 if (!cv.has_value()) {
1356 return std::nullopt;
1357 }
1358 res.push_back(cv.value());
1359 }
1360 return res;
1361}
1362
1363struct AffineMapFolder {
1364 struct Input {
1365 OperandRangeRange mapOpGroups;
1366 DenseI32ArrayAttr dimsPerGroup;
1367 ArrayRef<Attribute> paramsOfStructTy;
1368 };
1369
1370 struct Output {
1371 SmallVector<SmallVector<Value>> mapOpGroups;
1372 SmallVector<int32_t> dimsPerGroup;
1373 SmallVector<Attribute> paramsOfStructTy;
1374 };
1375
1376 static inline SmallVector<ValueRange> getConvertedMapOpGroups(Output out) {
1377 return llvm::map_to_vector(out.mapOpGroups, [](const SmallVector<Value> &grp) {
1378 return ValueRange(grp);
1379 });
1380 }
1381
1382 static LogicalResult
1383 fold(PatternRewriter &rewriter, const Input &in, Output &out, Operation *op, const char *aspect) {
1384 if (in.mapOpGroups.empty()) {
1385 // No affine map operands so nothing to do
1386 return failure();
1387 }
1388
1389 assert(in.mapOpGroups.size() <= in.paramsOfStructTy.size());
1390 assert(std::cmp_equal(in.mapOpGroups.size(), in.dimsPerGroup.size()));
1391
1392 size_t idx = 0; // index in `mapOpGroups`, i.e., the number of AffineMapAttr encountered
1393 for (Attribute sizeAttr : in.paramsOfStructTy) {
1394 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(sizeAttr)) {
1395 ValueRange currMapOps = in.mapOpGroups[idx++];
1396 LLVM_DEBUG(
1397 llvm::dbgs() << "[AffineMapFolder] currMapOps: " << debug::toStringList(currMapOps)
1398 << '\n'
1399 );
1400 SmallVector<OpFoldResult> currMapOpsCast = getAsOpFoldResult(currMapOps);
1401 LLVM_DEBUG(
1402 llvm::dbgs() << "[AffineMapFolder] currMapOps as fold results: "
1403 << debug::toStringList(currMapOpsCast) << '\n'
1404 );
1405 if (auto constOps = Step3_InstantiateAffineMaps::getConstantIntValues(currMapOpsCast)) {
1406 SmallVector<Attribute> result;
1407 bool hasPoison = false; // indicates divide by 0 or mod by <1
1408 auto constAttrs = llvm::map_to_vector(*constOps, [&rewriter](int64_t v) -> Attribute {
1409 return rewriter.getIndexAttr(v);
1410 });
1411 LogicalResult foldResult = m.getAffineMap().constantFold(constAttrs, result, &hasPoison);
1412 if (hasPoison) {
1413 // Diagnostic remark: could be removed for release builds if too noisy
1414 op->emitRemark()
1415 .append(
1416 "Cannot fold affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
1417 " due to divide by 0 or modulus with negative divisor"
1418 )
1419 .report();
1420 return failure();
1421 }
1422 if (failed(foldResult)) {
1423 // Diagnostic remark: could be removed for release builds if too noisy
1424 op->emitRemark()
1425 .append(
1426 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(), " failed"
1427 )
1428 .report();
1429 return failure();
1430 }
1431 if (result.size() != 1) {
1432 // Diagnostic remark: could be removed for release builds if too noisy
1433 op->emitRemark()
1434 .append(
1435 "Folding affine_map for ", aspect, " ", out.paramsOfStructTy.size(),
1436 " produced ", result.size(), " results but expected 1"
1437 )
1438 .report();
1439 return failure();
1440 }
1441 assert(!llvm::isa<AffineMapAttr>(result[0]) && "not converted");
1442 out.paramsOfStructTy.push_back(result[0]);
1443 continue;
1444 }
1445 // If affine but not foldable, preserve the map ops
1446 out.mapOpGroups.emplace_back(currMapOps);
1447 out.dimsPerGroup.push_back(in.dimsPerGroup[idx - 1]); // idx was already incremented
1448 }
1449 // If not affine and foldable, preserve the original
1450 out.paramsOfStructTy.push_back(sizeAttr);
1451 }
1452 assert(idx == in.mapOpGroups.size() && "all affine_map not processed");
1453 assert(
1454 in.paramsOfStructTy.size() == out.paramsOfStructTy.size() &&
1455 "produced wrong number of dimensions"
1456 );
1457
1458 return success();
1459 }
1460};
1461
1463class InstantiateAtCreateArrayOp final : public OpRewritePattern<CreateArrayOp> {
1464 [[maybe_unused]]
1465 ConversionTracker &tracker_;
1466
1467public:
1468 InstantiateAtCreateArrayOp(MLIRContext *ctx, ConversionTracker &tracker)
1469 : OpRewritePattern(ctx), tracker_(tracker) {}
1470
1471 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1472 ArrayType oldResultType = op.getType();
1473
1474 AffineMapFolder::Output out;
1475 AffineMapFolder::Input in = {
1476 op.getMapOperands(),
1478 oldResultType.getDimensionSizes(),
1479 };
1480 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "array dimension"))) {
1481 return failure();
1482 }
1483
1484 ArrayType newResultType = ArrayType::get(oldResultType.getElementType(), out.paramsOfStructTy);
1485 if (newResultType == oldResultType) {
1486 return failure(); // nothing changed
1487 }
1488 // ASSERT: folding only preserves the original Attribute or converts affine to integer
1489 assert(tracker_.isLegalConversion(oldResultType, newResultType, "InstantiateAtCreateArrayOp"));
1490 LLVM_DEBUG(
1491 llvm::dbgs() << "[InstantiateAtCreateArrayOp] instantiating " << oldResultType << " as "
1492 << newResultType << " in \"" << op << "\"\n"
1493 );
1495 rewriter, op, newResultType, AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup
1496 );
1497 return success();
1498 }
1499};
1500
1502class InstantiateAtCallOpCompute final : public OpRewritePattern<CallOp> {
1503 ConversionTracker &tracker_;
1504
1505public:
1506 InstantiateAtCallOpCompute(MLIRContext *ctx, ConversionTracker &tracker)
1507 : OpRewritePattern(ctx), tracker_(tracker) {}
1508
1509 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1510 if (!op.calleeIsStructCompute()) {
1511 // this pattern only applies when the callee is "compute()" within a struct
1512 return failure();
1513 }
1514 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] target: " << op.getCallee() << '\n');
1516 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] oldRetTy: " << oldRetTy << '\n');
1517 ArrayAttr params = oldRetTy.getParams();
1518 if (isNullOrEmpty(params)) {
1519 // nothing to do if the StructType is not parameterized
1520 return failure();
1521 }
1522
1523 AffineMapFolder::Output out;
1524 AffineMapFolder::Input in = {
1525 op.getMapOperands(),
1527 params.getValue(),
1528 };
1529 if (!in.mapOpGroups.empty()) {
1530 // If there are affine map operands, attempt to fold them to a constant.
1531 if (failed(AffineMapFolder::fold(rewriter, in, out, op, "struct parameter"))) {
1532 return failure();
1533 }
1534 LLVM_DEBUG({
1535 llvm::dbgs() << "[InstantiateAtCallOpCompute] folded affine_map in result type params\n";
1536 });
1537 } else {
1538 // If there are no affine map operands, attempt to refine the result type of the CallOp using
1539 // the function argument types and the type of the target function.
1540 auto callArgTypes = op.getArgOperands().getTypes();
1541 if (callArgTypes.empty()) {
1542 // no refinement possible if no function arguments
1543 return failure();
1544 }
1545 SymbolTableCollection tables;
1546 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1547 if (failed(lookupRes)) {
1548 return failure();
1549 }
1550 if (failed(instantiateViaTargetType(in, out, callArgTypes, lookupRes->get()))) {
1551 return failure();
1552 }
1553 LLVM_DEBUG({
1554 llvm::dbgs() << "[InstantiateAtCallOpCompute] propagated instantiations via symrefs in "
1555 "result type params: "
1556 << debug::toStringList(out.paramsOfStructTy) << '\n';
1557 });
1558 }
1559
1560 StructType newRetTy = StructType::get(oldRetTy.getNameRef(), out.paramsOfStructTy);
1561 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] newRetTy: " << newRetTy << '\n');
1562 if (newRetTy == oldRetTy) {
1563 return failure(); // nothing changed
1564 }
1565 // The `newRetTy` is computed via instantiateViaTargetType() which can only preserve the
1566 // original Attribute or convert to a concrete attribute via the unification process. Thus, if
1567 // the conversion here is illegal it means there is a type conflict within the LLZK code that
1568 // prevents instantiation of the struct with the requested type.
1569 if (!tracker_.isLegalConversion(oldRetTy, newRetTy, "InstantiateAtCallOpCompute")) {
1570 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1571 diag.append(
1572 "result type mismatch: due to struct instantiation, expected type ", newRetTy,
1573 ", but found ", oldRetTy
1574 );
1575 });
1576 }
1577 LLVM_DEBUG(llvm::dbgs() << "[InstantiateAtCallOpCompute] replaced " << op);
1579 rewriter, op, TypeRange {newRetTy}, op.getCallee(),
1580 AffineMapFolder::getConvertedMapOpGroups(out), out.dimsPerGroup, op.getArgOperands()
1581 );
1582 (void)newOp; // tell compiler it's intentionally unused in release builds
1583 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1584 return success();
1585 }
1586
1587private:
1590 inline LogicalResult instantiateViaTargetType(
1591 const AffineMapFolder::Input &in, AffineMapFolder::Output &out,
1592 OperandRange::type_range callArgTypes, FuncDefOp targetFunc
1593 ) const {
1594 assert(targetFunc.isStructCompute()); // since `op.calleeIsStructCompute()`
1595 ArrayAttr targetResTyParams = targetFunc.getSingleResultTypeOfCompute().getParams();
1596 assert(!isNullOrEmpty(targetResTyParams)); // same cardinality as `in.paramsOfStructTy`
1597 assert(in.paramsOfStructTy.size() == targetResTyParams.size()); // verifier ensures this
1598
1599 if (llvm::all_of(in.paramsOfStructTy, isConcreteAttr<>)) {
1600 // Nothing can change if everything is already concrete
1601 return failure();
1602 }
1603
1604 LLVM_DEBUG({
1605 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1606 << " call arg types: " << debug::toStringList(callArgTypes) << '\n';
1607 llvm::dbgs() << '[' << __FUNCTION__ << ']' << " target func arg types: "
1608 << debug::toStringList(targetFunc.getArgumentTypes()) << '\n';
1609 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1610 << " struct params @ call: " << debug::toStringList(in.paramsOfStructTy) << '\n';
1611 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1612 << " target struct params: " << debug::toStringList(targetResTyParams) << '\n';
1613 });
1614
1615 UnificationMap unifications;
1616 bool unifies = typeListsUnify(targetFunc.getArgumentTypes(), callArgTypes, {}, &unifications);
1617 (void)unifies; // tell compiler it's intentionally unused in builds without assertions
1618 assert(unifies && "should have been checked by verifiers");
1619
1620 LLVM_DEBUG({
1621 llvm::dbgs() << '[' << __FUNCTION__ << ']'
1622 << " unifications of arg types: " << debug::toStringList(unifications) << '\n';
1623 });
1624
1625 // Check for LHS SymRef (i.e., from the target function) that have RHS concrete Attributes (i.e.
1626 // from the call argument types) without any struct parameters (because the type with concrete
1627 // struct parameters will be used to instantiate the target struct rather than the fully
1628 // flattened struct type resulting in type mismatch of the callee to target) and perform those
1629 // replacements in the `targetFunc` return type to produce the new result type for the CallOp.
1630 SmallVector<Attribute> newReturnStructParams = llvm::map_to_vector(
1631 llvm::zip_equal(targetResTyParams.getValue(), in.paramsOfStructTy),
1632 [&unifications](std::tuple<Attribute, Attribute> p) {
1633 Attribute fromCall = std::get<1>(p);
1634 // Preserve attributes that are already concrete at the call site. Otherwise attempt to lookup
1635 // non-parameterized concrete unification for the target struct parameter symbol.
1636 if (!isConcreteAttr<>(fromCall)) {
1637 Attribute fromTgt = std::get<0>(p);
1638 LLVM_DEBUG({
1639 llvm::dbgs() << "[instantiateViaTargetType] fromCall = " << fromCall << '\n';
1640 llvm::dbgs() << "[instantiateViaTargetType] fromTgt = " << fromTgt << '\n';
1641 });
1642 assert(llvm::isa<SymbolRefAttr>(fromTgt));
1643 auto it = unifications.find(std::make_pair(llvm::cast<SymbolRefAttr>(fromTgt), Side::LHS));
1644 if (it != unifications.end()) {
1645 Attribute unifiedAttr = it->second;
1646 LLVM_DEBUG({
1647 llvm::dbgs() << "[instantiateViaTargetType] unifiedAttr = " << unifiedAttr << '\n';
1648 });
1649 if (unifiedAttr && isConcreteAttr<false>(unifiedAttr)) {
1650 return unifiedAttr;
1651 }
1652 }
1653 }
1654 return fromCall;
1655 }
1656 );
1657
1658 out.paramsOfStructTy = newReturnStructParams;
1659 assert(out.paramsOfStructTy.size() == in.paramsOfStructTy.size() && "post-condition");
1660 assert(out.mapOpGroups.empty() && "post-condition");
1661 assert(out.dimsPerGroup.empty() && "post-condition");
1662 return success();
1663 }
1664};
1665
1666LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
1667 MLIRContext *ctx = modOp.getContext();
1668 RewritePatternSet patterns(ctx);
1669 patterns.add<
1670 InstantiateAtCreateArrayOp, // CreateArrayOp
1671 InstantiateAtCallOpCompute // CallOp, targeting struct "compute()"
1672 >(ctx, tracker);
1673
1674 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
1675}
1676
1677} // namespace Step3_InstantiateAffineMaps
1678
1680
1682class UpdateNewArrayElemFromWrite final : public OpRewritePattern<CreateArrayOp> {
1683 ConversionTracker &tracker_;
1684
1685public:
1686 UpdateNewArrayElemFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1687 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1688
1689 LogicalResult matchAndRewrite(CreateArrayOp op, PatternRewriter &rewriter) const override {
1690 Value createResult = op.getResult();
1691 ArrayType createResultType = dyn_cast<ArrayType>(createResult.getType());
1692 assert(createResultType && "CreateArrayOp must produce ArrayType");
1693 Type oldResultElemType = createResultType.getElementType();
1694
1695 // Look for WriteArrayOp where the array reference is the result of the CreateArrayOp and the
1696 // element type is different.
1697 Type newResultElemType = nullptr;
1698 for (Operation *user : createResult.getUsers()) {
1699 if (WriteArrayOp writeOp = dyn_cast<WriteArrayOp>(user)) {
1700 if (writeOp.getArrRef() != createResult) {
1701 continue;
1702 }
1703 Type writeRValueType = writeOp.getRvalue().getType();
1704 if (writeRValueType == oldResultElemType) {
1705 continue;
1706 }
1707 if (newResultElemType && newResultElemType != writeRValueType) {
1708 LLVM_DEBUG(
1709 llvm::dbgs()
1710 << "[UpdateNewArrayElemFromWrite] multiple possible element types for CreateArrayOp "
1711 << newResultElemType << " vs " << writeRValueType << '\n'
1712 );
1713 return failure();
1714 }
1715 newResultElemType = writeRValueType;
1716 }
1717 }
1718 if (!newResultElemType) {
1719 // no replacement type found
1720 return failure();
1721 }
1722 if (!tracker_.isLegalConversion(
1723 oldResultElemType, newResultElemType, "UpdateNewArrayElemFromWrite"
1724 )) {
1725 return failure();
1726 }
1727 ArrayType newType = createResultType.cloneWith(newResultElemType);
1728 rewriter.modifyOpInPlace(op, [&createResult, &newType]() { createResult.setType(newType); });
1729 LLVM_DEBUG(
1730 llvm::dbgs() << "[UpdateNewArrayElemFromWrite] updated result type of " << op << '\n'
1731 );
1732 return success();
1733 }
1734};
1735
1736namespace {
1737
1738LogicalResult updateArrayElemFromArrAccessOp(
1739 ArrayAccessOpInterface op, Type scalarElemTy, ConversionTracker &tracker,
1740 PatternRewriter &rewriter
1741) {
1742 ArrayType oldArrType = op.getArrRefType();
1743 if (oldArrType.getElementType() == scalarElemTy) {
1744 return failure(); // no change needed
1745 }
1746 ArrayType newArrType = oldArrType.cloneWith(scalarElemTy);
1747 if (oldArrType == newArrType ||
1748 !tracker.isLegalConversion(oldArrType, newArrType, "updateArrayElemFromArrAccessOp")) {
1749 return failure();
1750 }
1751 rewriter.modifyOpInPlace(op, [&op, &newArrType]() { op.getArrRef().setType(newArrType); });
1752 LLVM_DEBUG(
1753 llvm::dbgs() << "[updateArrayElemFromArrAccessOp] updated base array type in " << op << '\n'
1754 );
1755 return success();
1756}
1757
1758} // namespace
1759
1760class UpdateArrayElemFromArrWrite final : public OpRewritePattern<WriteArrayOp> {
1761 ConversionTracker &tracker_;
1762
1763public:
1764 UpdateArrayElemFromArrWrite(MLIRContext *ctx, ConversionTracker &tracker)
1765 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1766
1767 LogicalResult matchAndRewrite(WriteArrayOp op, PatternRewriter &rewriter) const override {
1768 return updateArrayElemFromArrAccessOp(op, op.getRvalue().getType(), tracker_, rewriter);
1769 }
1770};
1771
1772class UpdateArrayElemFromArrRead final : public OpRewritePattern<ReadArrayOp> {
1773 ConversionTracker &tracker_;
1774
1775public:
1776 UpdateArrayElemFromArrRead(MLIRContext *ctx, ConversionTracker &tracker)
1777 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1778
1779 LogicalResult matchAndRewrite(ReadArrayOp op, PatternRewriter &rewriter) const override {
1780 return updateArrayElemFromArrAccessOp(op, op.getResult().getType(), tracker_, rewriter);
1781 }
1782};
1783
/// Pattern rooted at `MemberDefOp`: inspects the `MemberWriteOp` users of the member symbol
/// (searched within `parentRes`) and, when they determine a single type that differs from the
/// declared member type and the tracker accepts the conversion, updates the `MemberDefOp` type
/// in place. Reports a match failure with notes when the writes use incompatible types.
1785class UpdateMemberDefTypeFromWrite final : public OpRewritePattern<MemberDefOp> {
1786 ConversionTracker &tracker_;
1787
1788public:
1789 UpdateMemberDefTypeFromWrite(MLIRContext *ctx, ConversionTracker &tracker)
1790 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1791
1792 LogicalResult matchAndRewrite(MemberDefOp op, PatternRewriter &rewriter) const override {
1793 // Find all uses of the member symbol name within its parent struct.
// NOTE(review): the statement that declares `parentRes` is not present in this listing (the
// numbering jumps from 1793 to 1795); presumably it obtains the enclosing StructDefOp of `op`
// -- confirm against the original source file.
1795 assert(parentRes && "MemberDefOp parent is always StructDefOp"); // per ODS def
1796
1797 // If the symbol is used by a MemberWriteOp with a different result type then change
1798 // the type of the MemberDefOp to match the MemberWriteOp result type.
1799 Type newType = nullptr;
1800 if (auto memberUsers = llzk::getSymbolUses(op, parentRes)) {
// Location of the write that supplied the current `newType`; used only for diagnostics.
1801 std::optional<Location> newTypeLoc = std::nullopt;
1802 for (SymbolTable::SymbolUse symUse : memberUsers.value()) {
1803 if (MemberWriteOp writeOp = llvm::dyn_cast<MemberWriteOp>(symUse.getUser())) {
1804 Type writeToType = writeOp.getVal().getType();
1805 LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] checking " << writeOp << '\n');
1806 if (!newType) {
1807 // If a new type has not yet been discovered, store the new type.
1808 newType = writeToType;
1809 newTypeLoc = writeOp.getLoc();
1810 } else if (writeToType != newType) {
1811 // Typically, there will only be one write for each member of a struct but do not rely
1812 // on that assumption. If multiple writes with different types A and B are found where
1813 // A->B is a legal conversion (i.e., more concrete unification), then it is safe to use
1814 // type B with the assumption that the write with type A will be updated by another
1815 // pattern to also use type B.
// Case 1: `writeToType` -> `newType` is legal, so the current `newType` is kept as is.
1816 if (!tracker_.isLegalConversion(writeToType, newType, "UpdateMemberDefTypeFromWrite")) {
// Case 2: only the opposite direction is legal, so switch to `writeToType`.
1817 if (tracker_.isLegalConversion(
1818 newType, writeToType, "UpdateMemberDefTypeFromWrite"
1819 )) {
1820 // 'writeToType' is the more concrete type
1821 newType = writeToType;
1822 newTypeLoc = writeOp.getLoc();
1823 } else {
1824 // Give an error if the types are incompatible.
1825 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1826 diag.append(
1827 "Cannot update type of '", MemberDefOp::getOperationName(),
1828 "' because there are multiple '", MemberWriteOp::getOperationName(),
1829 "' with different value types"
1830 );
1831 if (newTypeLoc) {
1832 diag.attachNote(newTypeLoc).append("type written here is ", newType);
1833 }
1834 diag.attachNote(writeOp.getLoc()).append("type written here is ", writeToType);
1835 });
1836 }
1837 }
1838 }
1839 }
1840 }
1841 }
// No write was found, or the writes agree with the type already declared on the member.
1842 if (!newType || newType == op.getType()) {
1843 return failure(); // nothing changed
1844 }
// The tracker has the final say on whether the declared type may change to `newType`.
1845 if (!tracker_.isLegalConversion(op.getType(), newType, "UpdateMemberDefTypeFromWrite")) {
1846 return failure();
1847 }
1848 rewriter.modifyOpInPlace(op, [&op, &newType]() { op.setType(newType); });
1849 LLVM_DEBUG(llvm::dbgs() << "[UpdateMemberDefTypeFromWrite] updated type of " << op << '\n');
1850 return success();
1851 }
1852};
1853
1854namespace {
1855
1856SmallVector<std::unique_ptr<Region>> moveRegions(Operation *op) {
1857 SmallVector<std::unique_ptr<Region>> newRegions;
1858 for (Region &region : op->getRegions()) {
1859 auto newRegion = std::make_unique<Region>();
1860 newRegion->takeBody(region);
1861 newRegions.push_back(std::move(newRegion));
1862 }
1863 return newRegions;
1864}
1865
1866} // namespace
1867
1870class UpdateInferredResultTypes final : public OpTraitRewritePattern<OpTrait::InferTypeOpAdaptor> {
1871 ConversionTracker &tracker_;
1872
1873public:
1874 UpdateInferredResultTypes(MLIRContext *ctx, ConversionTracker &tracker)
1875 : OpTraitRewritePattern(ctx, 6), tracker_(tracker) {}
1876
1877 LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
1878 SmallVector<Type, 1> inferredResultTypes;
1879 InferTypeOpInterface retTypeFn = llvm::cast<InferTypeOpInterface>(op);
1880 LogicalResult result = retTypeFn.inferReturnTypes(
1881 op->getContext(), op->getLoc(), op->getOperands(), op->getRawDictionaryAttrs(),
1882 op->getPropertiesStorage(), op->getRegions(), inferredResultTypes
1883 );
1884 if (failed(result)) {
1885 return failure();
1886 }
1887 if (op->getResultTypes() == inferredResultTypes) {
1888 return failure(); // nothing changed
1889 }
1890 if (!tracker_.areLegalConversions(
1891 op->getResultTypes(), inferredResultTypes, "UpdateInferredResultTypes"
1892 )) {
1893 return failure();
1894 }
1895
1896 // Move nested region bodies and replace the original op with the updated types list.
1897 LLVM_DEBUG(llvm::dbgs() << "[UpdateInferredResultTypes] replaced " << *op);
1898 SmallVector<std::unique_ptr<Region>> newRegions = moveRegions(op);
1899 Operation *newOp = rewriter.create(
1900 op->getLoc(), op->getName().getIdentifier(), op->getOperands(), inferredResultTypes,
1901 op->getAttrs(), op->getSuccessors(), newRegions
1902 );
1903 rewriter.replaceOp(op, newOp);
1904 LLVM_DEBUG(llvm::dbgs() << " with " << *newOp << '\n');
1905 return success();
1906 }
1907};
1908
1910class UpdateFuncTypeFromReturn final : public OpRewritePattern<FuncDefOp> {
1911 ConversionTracker &tracker_;
1912
1913public:
1914 UpdateFuncTypeFromReturn(MLIRContext *ctx, ConversionTracker &tracker)
1915 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1916
1917 LogicalResult matchAndRewrite(FuncDefOp op, PatternRewriter &rewriter) const override {
1918 Region &body = op.getFunctionBody();
1919 if (body.empty()) {
1920 return failure();
1921 }
1922 ReturnOp retOp = llvm::dyn_cast<ReturnOp>(body.back().getTerminator());
1923 assert(retOp && "final op in body region must be return");
1924 OperandRange::type_range tyFromReturnOp = retOp.getOperands().getTypes();
1925
1926 FunctionType oldFuncTy = op.getFunctionType();
1927 if (oldFuncTy.getResults() == tyFromReturnOp) {
1928 return failure(); // nothing changed
1929 }
1930 if (!tracker_.areLegalConversions(
1931 oldFuncTy.getResults(), tyFromReturnOp, "UpdateFuncTypeFromReturn"
1932 )) {
1933 return failure();
1934 }
1935
1936 rewriter.modifyOpInPlace(op, [&]() {
1937 op.setFunctionType(rewriter.getFunctionType(oldFuncTy.getInputs(), tyFromReturnOp));
1938 });
1939 LLVM_DEBUG(
1940 llvm::dbgs() << "[UpdateFuncTypeFromReturn] changed " << op.getSymName() << " from "
1941 << oldFuncTy << " to " << op.getFunctionType() << '\n'
1942 );
1943 return success();
1944 }
1945};
1946
1951class UpdateFreeFuncCallOpTypes final : public OpRewritePattern<CallOp> {
1952 ConversionTracker &tracker_;
1953
1954public:
1955 UpdateFreeFuncCallOpTypes(MLIRContext *ctx, ConversionTracker &tracker)
1956 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
1957
1958 LogicalResult matchAndRewrite(CallOp op, PatternRewriter &rewriter) const override {
1959 SymbolTableCollection tables;
1960 auto lookupRes = lookupTopLevelSymbol<FuncDefOp>(tables, op.getCalleeAttr(), op);
1961 if (failed(lookupRes)) {
1962 return failure();
1963 }
1964 FuncDefOp targetFunc = lookupRes->get();
1965 if (targetFunc.isInStruct()) {
1966 // this pattern only applies when the callee is NOT in a struct
1967 return failure();
1968 }
1969 if (op.getResultTypes() == targetFunc.getFunctionType().getResults()) {
1970 return failure(); // nothing changed
1971 }
1972 if (!tracker_.areLegalConversions(
1973 op.getResultTypes(), targetFunc.getFunctionType().getResults(),
1974 "UpdateFreeFuncCallOpTypes"
1975 )) {
1976 return failure();
1977 }
1978
1979 LLVM_DEBUG(llvm::dbgs() << "[UpdateFreeFuncCallOpTypes] replaced " << op);
1980 CallOp newOp = replaceOpWithNewOp<CallOp>(rewriter, op, targetFunc, op.getArgOperands());
1981 (void)newOp; // tell compiler it's intentionally unused in release builds
1982 LLVM_DEBUG(llvm::dbgs() << " with " << newOp << '\n');
1983 return success();
1984 }
1985};
1986
1987namespace {
1988
1989LogicalResult updateMemberRefValFromMemberDef(
1990 MemberRefOpInterface op, ConversionTracker &tracker, PatternRewriter &rewriter
1991) {
1992 SymbolTableCollection tables;
1993 auto def = op.getMemberDefOp(tables);
1994 if (failed(def)) {
1995 return failure();
1996 }
1997 Type oldResultType = op.getVal().getType();
1998 Type newResultType = def->get().getType();
1999 if (oldResultType == newResultType ||
2000 !tracker.isLegalConversion(oldResultType, newResultType, "updateMemberRefValFromMemberDef")) {
2001 return failure();
2002 }
2003 rewriter.modifyOpInPlace(op, [&op, &newResultType]() { op.getVal().setType(newResultType); });
2004 LLVM_DEBUG(
2005 llvm::dbgs() << "[updateMemberRefValFromMemberDef] updated value type in " << op << '\n'
2006 );
2007 return success();
2008}
2009
2010} // namespace
2011
2013class UpdateMemberReadValFromDef final : public OpRewritePattern<MemberReadOp> {
2014 ConversionTracker &tracker_;
2015
2016public:
2017 UpdateMemberReadValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2018 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2019
2020 LogicalResult matchAndRewrite(MemberReadOp op, PatternRewriter &rewriter) const override {
2021 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2022 }
2023};
2024
2026class UpdateMemberWriteValFromDef final : public OpRewritePattern<MemberWriteOp> {
2027 ConversionTracker &tracker_;
2028
2029public:
2030 UpdateMemberWriteValFromDef(MLIRContext *ctx, ConversionTracker &tracker)
2031 : OpRewritePattern(ctx, 3), tracker_(tracker) {}
2032
2033 LogicalResult matchAndRewrite(MemberWriteOp op, PatternRewriter &rewriter) const override {
2034 return updateMemberRefValFromMemberDef(op, tracker_, rewriter);
2035 }
2036};
2037
2038LogicalResult run(ModuleOp modOp, ConversionTracker &tracker) {
2039 MLIRContext *ctx = modOp.getContext();
2040 RewritePatternSet patterns(ctx);
2041 patterns.add<
2042 // Benefit of this one must be higher than rules that would propagate the type in the opposite
2043 // direction (ex: `UpdateArrayElemFromArrRead`) else the greedy conversion would not converge.
2044 // benefit = 6
2045 UpdateInferredResultTypes, // OpTrait::InferTypeOpAdaptor (ReadArrayOp, ExtractArrayOp)
2046 // benefit = 3
2047 UpdateFreeFuncCallOpTypes, // CallOp, targeting non-struct functions
2048 UpdateFuncTypeFromReturn, // FuncDefOp
2049 UpdateNewArrayElemFromWrite, // CreateArrayOp
2050 UpdateArrayElemFromArrRead, // ReadArrayOp
2051 UpdateArrayElemFromArrWrite, // WriteArrayOp
2052 UpdateMemberDefTypeFromWrite, // MemberDefOp
2053 UpdateMemberReadValFromDef, // MemberReadOp
2054 UpdateMemberWriteValFromDef // MemberWriteOp
2055 >(ctx, tracker);
2056
2057 return applyAndFoldGreedily(modOp, tracker, std::move(patterns));
2058}
2059} // namespace Step4_PropagateTypes
2060
2061namespace Step5_Cleanup {
2062
2063class CleanupBase {
2064public:
2065 SymbolTableCollection tables;
2066
2067 CleanupBase(ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph)
2068 : rootMod(root), defTree(symDefTree), useGraph(symUseGraph) {}
2069
2070protected:
2071 ModuleOp rootMod;
2072 const SymbolDefTree &defTree;
2073 const SymbolUseGraph &useGraph;
2074};
2075
2076struct FromKeepSet : public CleanupBase {
2077 using CleanupBase::CleanupBase;
2078
2082 LogicalResult eraseUnreachableFrom(ArrayRef<StructDefOp> keep) {
2083 // Initialize roots from the given StructDefOp instances
2084 SetVector<SymbolOpInterface> roots(keep.begin(), keep.end());
2085 // Add GlobalDefOp and "free functions" to the set of roots
2086 rootMod.walk([&roots](Operation *op) {
2087 if (global::GlobalDefOp gdef = llvm::dyn_cast<global::GlobalDefOp>(op)) {
2088 roots.insert(gdef);
2089 } else if (function::FuncDefOp fdef = llvm::dyn_cast<function::FuncDefOp>(op)) {
2090 if (!fdef.isInStruct()) {
2091 roots.insert(fdef);
2092 }
2093 }
2094 });
2095
2096 // Use a SymbolDefTree to find all Symbol defs reachable from one of the root nodes. Then
2097 // collect all Symbol uses reachable from those def nodes. These are the symbols that should
2098 // be preserved. All other symbol defs should be removed.
2099 llvm::df_iterator_default_set<const SymbolUseGraphNode *> symbolsToKeep;
2100 for (size_t i = 0; i < roots.size(); ++i) { // iterate for safe insertion
2101 SymbolOpInterface keepRoot = roots[i];
2102 LLVM_DEBUG({ llvm::dbgs() << "[EraseUnreachable] root: " << keepRoot << '\n'; });
2103 const SymbolDefTreeNode *keepRootNode = defTree.lookupNode(keepRoot);
2104 assert(keepRootNode && "every struct def must be in the def tree");
2105 for (const SymbolDefTreeNode *reachableDefNode : llvm::depth_first(keepRootNode)) {
2106 LLVM_DEBUG({
2107 llvm::dbgs() << "[EraseUnreachable] can reach: " << reachableDefNode->getOp() << '\n';
2108 });
2109 if (SymbolOpInterface reachableDef = reachableDefNode->getOp()) {
2110 // Use 'depth_first_ext()' to get all symbol uses reachable from the current Symbol def
2111 // node. There are no uses if the node is not in the graph. Within the loop that populates
2112 // 'depth_first_ext()', also check if the symbol is a StructDefOp and ensure it is in
2113 // 'roots' so the outer loop will ensure that all symbols reachable from it are preserved.
2114 if (const SymbolUseGraphNode *useGraphNodeForDef = useGraph.lookupNode(reachableDef)) {
2115 for (const SymbolUseGraphNode *usedSymbolNode :
2116 depth_first_ext(useGraphNodeForDef, symbolsToKeep)) {
2117 LLVM_DEBUG({
2118 llvm::dbgs() << "[EraseUnreachable] uses symbol: "
2119 << usedSymbolNode->getSymbolPath() << '\n';
2120 });
2121 // Ignore struct/template parameter symbols (before doing the lookup below because it
2122 // would fail anyway and then cause the "failed" case to be triggered unnecessarily).
2123 if (usedSymbolNode->isTemplateSymbolBinding()) {
2124 continue;
2125 }
2126 // If `usedSymbolNode` references a StructDefOp, ensure it's considered in the roots.
2127 auto lookupRes = usedSymbolNode->lookupSymbol(tables);
2128 if (failed(lookupRes)) {
2129 LLVM_DEBUG(useGraph.dumpToDotFile());
2130 return failure();
2131 }
2132 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
2133 if (lookupRes->viaInclude()) {
2134 continue;
2135 }
2136 if (StructDefOp asStruct = llvm::dyn_cast<StructDefOp>(lookupRes->get())) {
2137 bool insertRes = roots.insert(asStruct);
2138 (void)insertRes; // tell compiler it's intentionally unused in release builds
2139 LLVM_DEBUG({
2140 if (insertRes) {
2141 llvm::dbgs() << "[EraseUnreachable] found another root: " << asStruct << '\n';
2142 }
2143 });
2144 }
2145 }
2146 }
2147 }
2148 }
2149 }
2150
2151 rootMod.walk([this, &symbolsToKeep](StructDefOp op) {
2152 const SymbolUseGraphNode *n = this->useGraph.lookupNode(op);
2153 assert(n);
2154 if (!symbolsToKeep.contains(n)) {
2155 LLVM_DEBUG(llvm::dbgs() << "[EraseUnreachable] removing: " << op.getSymName() << '\n');
2156 op.erase();
2157 }
2158
2159 return WalkResult::skip(); // StructDefOp cannot be nested
2160 });
2161
2162 return success();
2163 }
2164};
2165
2166struct FromEraseSet : public CleanupBase {
2167
2169 FromEraseSet(
2170 ModuleOp root, const SymbolDefTree &symDefTree, const SymbolUseGraph &symUseGraph,
2171 DenseSet<SymbolRefAttr> &&tryToErasePaths
2172 )
2173 : CleanupBase(root, symDefTree, symUseGraph) {
2174 // Convert the set of paths targeted for erasure into a set of the StructDefOp
2175 for (SymbolRefAttr path : tryToErasePaths) {
2176 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] path to erase: " << path << '\n';);
2177 Operation *lookupFrom = rootMod.getOperation();
2178 auto res = lookupSymbolIn<StructDefOp>(tables, path, lookupFrom, lookupFrom);
2179 assert(succeeded(res) && "inputs must be valid StructDefOp references");
2180 if (!res->viaInclude()) { // do not remove if it's from another source file
2181 auto op = res->get();
2182 LLVM_DEBUG(llvm::dbgs() << "[FromEraseSet] added op to the erase set: " << op << '\n';);
2183 tryToErase.insert(op);
2184 } else {
2185 LLVM_DEBUG(
2186 llvm::dbgs() << "[FromEraseSet] ignored op because it comes from an include: "
2187 << res->get() << '\n';
2188 );
2189 }
2190 }
2191 }
2192
2193 LogicalResult eraseUnusedStructs() {
2194 // Collect the subset of 'tryToErase' that has no remaining uses.
2195 for (StructDefOp sd : tryToErase) {
2196 collectSafeToErase(sd);
2197 }
2198 // The `visitedPlusSafetyResult` will contain FuncDefOp w/in the StructDefOp so just a single
2199 // loop to `dyn_cast` and `erase()` will cause `use-after-free` errors w/in the `dyn_cast`.
2200 // Instead, reduce the map to only those that should be erased and erase in a separate loop.
2201 for (auto it = visitedPlusSafetyResult.begin(); it != visitedPlusSafetyResult.end(); ++it) {
2202 if (!it->second || !llvm::isa<StructDefOp>(it->first.getOperation())) {
2203 visitedPlusSafetyResult.erase(it);
2204 }
2205 }
2206 for (auto &[sym, _] : visitedPlusSafetyResult) {
2207 LLVM_DEBUG(llvm::dbgs() << "[EraseIfUnused] removing: " << sym.getNameAttr() << '\n');
2208 sym.erase();
2209 }
2210 return success();
2211 }
2212
2213 const DenseSet<StructDefOp> &getTryToEraseSet() const { return tryToErase; }
2214
2215private:
2217 DenseSet<StructDefOp> tryToErase;
2221 DenseMap<SymbolOpInterface, bool> visitedPlusSafetyResult;
2223 DenseMap<const SymbolUseGraphNode *, SymbolOpInterface> lookupCache;
2224
2227 bool collectSafeToErase(SymbolOpInterface check) {
2228 assert(check); // pre-condition
2229
2230 // If previously visited, return the safety result.
2231 auto visited = visitedPlusSafetyResult.find(check);
2232 if (visited != visitedPlusSafetyResult.end()) {
2233 return visited->second;
2234 }
2235
2236 // If it's a StructDefOp that is not in `tryToErase` then it cannot be erased.
2237 if (StructDefOp sd = llvm::dyn_cast<StructDefOp>(check.getOperation())) {
2238 if (!tryToErase.contains(sd)) {
2239 visitedPlusSafetyResult[check] = false;
2240 return false;
2241 }
2242 }
2243
2244 // Otherwise, temporarily mark as safe b/c a node cannot keep itself live (and this prevents
2245 // the recursion from getting stuck in an infinite loop).
2246 visitedPlusSafetyResult[check] = true;
2247
2248 // Check if it's safe according to both the def tree and use graph.
2249 // Note: Every symbol must have a def node but ModuleOp and TemplateOp symbols may not have a
2250 // use node since they are not "terminal" symbols (i.e. they are not referred to directly).
2251 if (collectSafeToErase(defTree.lookupNode(check))) {
2252 const auto *useNode = useGraph.lookupNode(check);
2253 assert(useNode || (llvm::isa<ModuleOp, TemplateOp>(check.getOperation())));
2254 if (!useNode || collectSafeToErase(useNode)) {
2255 return true;
2256 }
2257 }
2258
2259 // Otherwise, revert the safety decision and return it.
2260 visitedPlusSafetyResult[check] = false;
2261 return false;
2262 }
2263
2265 bool collectSafeToErase(const SymbolDefTreeNode *check) {
2266 assert(check); // pre-condition
2267 if (const SymbolDefTreeNode *p = check->getParent()) {
2268 if (SymbolOpInterface checkOp = p->getOp()) { // safe if parent is root
2269 return collectSafeToErase(checkOp);
2270 }
2271 }
2272 return true;
2273 }
2274
2276 bool collectSafeToErase(const SymbolUseGraphNode *check) {
2277 assert(check); // pre-condition
2278 for (const SymbolUseGraphNode *p : check->predecessorIter()) {
2279 if (SymbolOpInterface checkOp = cachedLookup(p)) { // safe if via IncludeOp
2280 if (!collectSafeToErase(checkOp)) {
2281 return false;
2282 }
2283 }
2284 }
2285 return true;
2286 }
2287
2292 SymbolOpInterface cachedLookup(const SymbolUseGraphNode *node) {
2293 assert(node && "must provide a node"); // pre-condition
2294 // Check for cached result
2295 auto fromCache = lookupCache.find(node);
2296 if (fromCache != lookupCache.end()) {
2297 return fromCache->second;
2298 }
2299 // Otherwise, perform lookup and cache
2300 auto lookupRes = node->lookupSymbol(tables);
2301 assert(succeeded(lookupRes) && "graph contains node with invalid path");
2302 assert(lookupRes->get() != nullptr && "lookup must return an Operation");
2303 // If loaded via an IncludeOp it's not in the current AST anyway so ignore.
2304 // NOTE: The SymbolUseGraph does contain nodes for struct parameters which cannot cast to
2305 // SymbolOpInterface. However, those will always be leaf nodes in the SymbolUseGraph and
2306 // therefore will not be traversed by this analysis so directly casting is fine.
2307 SymbolOpInterface actualRes =
2308 lookupRes->viaInclude() ? nullptr : llvm::cast<SymbolOpInterface>(lookupRes->get());
2309 // Cache and return
2310 lookupCache[node] = actualRes;
2311 assert((!actualRes == lookupRes->viaInclude()) && "not found iff included"); // post-condition
2312 return actualRes;
2313 }
2314};
2315
2316} // namespace Step5_Cleanup
2317
2318class FlatteningPass : public llzk::polymorphic::impl::FlatteningPassBase<FlatteningPass> {
2319
2320 void runOnOperation() override {
2321 ModuleOp modOp = getOperation();
2322 if (failed(runOn(modOp))) {
2323 LLVM_DEBUG({
2324 // If the pass failed, dump the current IR.
2325 llvm::dbgs() << "=====================================================================\n";
2326 llvm::dbgs() << " Dumping module after failure of pass " << DEBUG_TYPE << '\n';
2327 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2328 llvm::dbgs() << "=====================================================================\n";
2329 });
2330 signalPassFailure();
2331 }
2332 }
2333
2334 inline LogicalResult runOn(ModuleOp modOp) {
2335 // If the cleanup mode is set to remove anything not reachable from the main struct, do an
2336 // initial pass to remove things that are not reachable (as an optimization) because creating
2337 // an instantiated version of a struct will not cause something to become reachable that was
2338 // not already reachable in parameterized form.
2339 if (cleanupMode == StructCleanupMode::MainAsRoot) {
2340 if (failed(eraseUnreachableFromMainStruct(modOp))) {
2341 return failure();
2342 }
2343 }
2344
2345 // Pass Manager to run some standard cleanup passes that are always beneficial:
2346 // - Remove templates that contain no struct or function definitions
2347 // - Convert templates with no constant parameters or expressions into modules
2348 OpPassManager universalCleanup(ModuleOp::getOperationName());
2349 universalCleanup.addPass(createEmptyTemplateRemoval());
2350
2351 // Run universal cleanup as a preliminary step to satisfy the
2352 // `assert(!isNullOrEmpty(paramNames))` precondition in `genClone()`.
2353 if (failed(runPipeline(universalCleanup, modOp))) {
2354 return failure();
2355 }
2356
2357 ConversionTracker tracker;
2358 unsigned loopCount = 0;
2359 do {
2360 ++loopCount;
2361 if (loopCount > iterationLimit) {
2362 llvm::errs() << DEBUG_TYPE << " exceeded the limit of " << iterationLimit
2363 << " iterations!\n";
2364 return failure();
2365 }
2366 tracker.resetModifiedFlag();
2367
2368 LLVM_DEBUG({
2369 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2370 << ")] Running step 1: struct instantiation\n";
2371 });
2372 // Find calls to "compute()" that return a parameterized struct type and replace it to call an
2373 // instantiated version of the struct that has parameters replaced with the constant values.
2374 // Create the necessary instantiated/flattened struct in the same location as the original.
2375 if (failed(Step1A_InstantiateStructs::run(modOp, tracker))) {
2376 llvm::errs() << DEBUG_TYPE << " failed while instantiating structs in templates\n";
2377 return failure();
2378 }
2379 // Instantiate calls to templated functions.
2380 if (failed(Step1B_InstantiateFunctions::run(modOp, tracker))) {
2381 llvm::errs() << DEBUG_TYPE << " failed while instantiating functions in templates\n";
2382 return failure();
2383 }
2384
2385 LLVM_DEBUG({
2386 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2387 << ")] Running step 2: loop unrolling\n";
2388 });
2389 // Unroll loops with known iterations.
2390 if (failed(Step2_Unroll::run(modOp, tracker))) {
2391 llvm::errs() << DEBUG_TYPE << " failed while unrolling loops\n";
2392 return failure();
2393 }
2394
2395 LLVM_DEBUG({
2396 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2397 << ")] Running step 3: affine maps instantiation\n";
2398 });
2399 // Instantiate affine_map parameters of StructType and ArrayType.
2400 if (failed(Step3_InstantiateAffineMaps::run(modOp, tracker))) {
2401 llvm::errs() << DEBUG_TYPE << " failed while instantiating `affine_map` parameters\n";
2402 return failure();
2403 }
2404
2405 LLVM_DEBUG({
2406 llvm::dbgs() << "[FlatteningPass(count=" << loopCount
2407 << ")] Running step 4: type propagation\n";
2408 });
2409 // Propagate updated types using the semantics of various ops.
2410 if (failed(Step4_PropagateTypes::run(modOp, tracker))) {
2411 llvm::errs() << DEBUG_TYPE << " failed while propagating instantiated types\n";
2412 return failure();
2413 }
2414
2415 LLVM_DEBUG(if (tracker.isModified()) {
2416 llvm::dbgs() << "=====================================================================\n";
2417 llvm::dbgs() << " Dumping module between iterations of " << DEBUG_TYPE << '\n';
2418 modOp.print(llvm::dbgs(), OpPrintingFlags().assumeVerified());
2419 llvm::dbgs() << "=====================================================================\n";
2420 });
2421 } while (tracker.isModified());
2422
2423 // Run user-selected cleanup first.
2424 if (failed(cleanupSwitch(modOp, tracker))) {
2425 return failure();
2426 }
2427 // Run universal cleanup again since no-param or param-only structs may exist now.
2428 if (failed(runPipeline(universalCleanup, modOp))) {
2429 return failure();
2430 }
2431 return success();
2432 }
2433
2434 // Perform cleanup according to the 'cleanupMode' option.
2435 LogicalResult cleanupSwitch(ModuleOp modOp, const ConversionTracker &tracker) {
2436 LLVM_DEBUG({ llvm::dbgs() << "[FlatteningPass] Running step 5: cleanup "; });
2437 switch (cleanupMode) {
2438 case StructCleanupMode::MainAsRoot:
2439 LLVM_DEBUG(llvm::dbgs() << "(main as root mode)\n");
2440 return eraseUnreachableFromMainStruct(modOp, false);
2441 case StructCleanupMode::ConcreteAsRoot:
2442 LLVM_DEBUG(llvm::dbgs() << "(concrete structs mode)\n");
2443 return eraseUnreachableFromConcreteStructs(modOp);
2444 case StructCleanupMode::Preimage:
2445 LLVM_DEBUG(llvm::dbgs() << "(preimage mode)\n");
2446 return erasePreimageOfInstantiations(modOp, tracker);
2447 default:
2448 LLVM_DEBUG(llvm::dbgs() << "(disabled)\n");
2449 return success();
2450 }
2451 }
2452
2453 // Erase parameterized structs that were replaced with concrete instantiations.
2454 LogicalResult erasePreimageOfInstantiations(ModuleOp rootMod, const ConversionTracker &tracker) {
2455 // TODO: The names from getInstantiatedStructNames() are NOT guaranteed to be paths from the
2456 // "top root" and they also do not indicate a root module so there could be ambiguity. This is a
2457 // broader problem in the FlatteningPass itself so let's just assume, for now, that these are
2458 // paths from the "top root". See [LLZK-286].
2459 Step5_Cleanup::FromEraseSet cleaner(
2460 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>(),
2461 tracker.getInstantiatedStructNames()
2462 );
2463 LogicalResult res = cleaner.eraseUnusedStructs();
2464 if (succeeded(res)) {
2465 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] success\n";);
2466 // Warn about any structs that were instantiated but still have uses elsewhere.
2467 const SymbolUseGraph *useGraph = nullptr;
2468 rootMod->walk([this, &cleaner, &useGraph](StructDefOp op) {
2469 if (cleaner.getTryToEraseSet().contains(op)) {
2470 // If needed, rebuild use graph to reflect deletions.
2471 if (!useGraph) {
2472 useGraph = &getAnalysis<SymbolUseGraph>();
2473 }
2474 // If the op has any users, report the warning.
2475 if (useGraph->lookupNode(op)->hasPredecessor()) {
2476 op.emitWarning("Parameterized struct still has uses!").report();
2477 }
2478 }
2479 return WalkResult::skip(); // StructDefOp cannot be nested
2480 });
2481 } else {
2482 LLVM_DEBUG(llvm::dbgs() << "[Cleanup(preimage)] failed\n";);
2483 }
2484 return res;
2485 }
2486
2487 LogicalResult eraseUnreachableFromConcreteStructs(ModuleOp rootMod) {
2488 SmallVector<StructDefOp> roots;
2489 rootMod.walk([&roots](StructDefOp op) {
2490 if (!op.hasTemplateSymbolBindings()) {
2491 roots.push_back(op);
2492 }
2493 return WalkResult::skip(); // StructDefOp cannot be nested
2494 });
2495
2496 Step5_Cleanup::FromKeepSet cleaner(
2497 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2498 );
2499 return cleaner.eraseUnreachableFrom(roots);
2500 }
2501
2502 LogicalResult eraseUnreachableFromMainStruct(ModuleOp rootMod, bool emitWarning = true) {
2503 Step5_Cleanup::FromKeepSet cleaner(
2504 rootMod, getAnalysis<SymbolDefTree>(), getAnalysis<SymbolUseGraph>()
2505 );
2506 FailureOr<SymbolLookupResult<StructDefOp>> mainOpt =
2507 getMainInstanceDef(cleaner.tables, rootMod.getOperation());
2508 if (failed(mainOpt)) {
2509 return failure();
2510 }
2511 SymbolLookupResult<StructDefOp> main = mainOpt.value();
2512 if (emitWarning && !main) {
2513 // Emit warning if there is no main specified because all structs may be removed (only
2514 // structs that are reachable from a global def or free function will be preserved since
2515 // those constructs are not candidate for removal in this pass).
2516 rootMod.emitWarning()
2517 .append(
2518 "using option '", cleanupMode.getArgStr(), '=',
2519 stringifyStructCleanupMode(StructCleanupMode::MainAsRoot), "' with no \"",
2520 MAIN_ATTR_NAME, "\" attribute on the top-level module may remove all structs!"
2521 )
2522 .report();
2523 }
2524 return cleaner.eraseUnreachableFrom(
2525 main ? ArrayRef<StructDefOp> {*main} : ArrayRef<StructDefOp> {}
2526 );
2527 }
2528};
2529
2530} // namespace
2531
2533 return std::make_unique<FlatteningPass>();
2534};
#define DEBUG_TYPE
#define DEBUG_TYPE
#define check(x)
Definition Ops.cpp:175
Common private implementation for poly dialect passes.
This file defines methods for symbol lookup across LLZK operations and included files.
static std::string from(mlir::Type type)
Return a brief string representation of the given LLZK type.
Definition TypeHelper.h:55
Builds a tree structure representing the symbol table structure.
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbol(mlir::SymbolTableCollection &tables, bool reportMissing=true) const
Builds a graph structure representing the relationships between symbols and their uses.
::mlir::TypedValue<::llzk::array::ArrayType > getArrRef()
Gets the SSA Value for the referenced array.
inline ::llzk::array::ArrayType getArrRefType()
Gets the type of the referenced array.
ArrayType cloneWith(std::optional<::llvm::ArrayRef< int64_t > > shape, ::mlir::Type elementType) const
Clone this type with the given shape and element type.
::mlir::Type getElementType() const
static ArrayType get(::mlir::Type elementType, ::llvm::ArrayRef<::mlir::Attribute > dimensionSizes)
Definition Types.cpp.inc:83
::llvm::ArrayRef<::mlir::Attribute > getDimensionSizes() const
::mlir::TypedValue<::llzk::array::ArrayType > getResult()
Definition Ops.h.inc:408
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:923
::mlir::TypedValue<::mlir::Type > getRvalue()
Definition Ops.h.inc:1075
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
void setType(::mlir::Type attrValue)
Definition Ops.cpp.inc:556
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:979
void setTableOffsetAttr(::mlir::Attribute attr)
Definition Ops.h.inc:750
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:689
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:938
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1165
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1605
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
Definition Ops.cpp:193
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::ArrayAttr getParams() const
bool calleeIsStructConstrain()
Return true iff the callee function name is FUNC_NAME_CONSTRAIN within a StructDefOp.
Definition Ops.cpp:1038
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the callee is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:1061
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:960
bool calleeIsStructCompute()
Return true iff the callee function name is FUNC_NAME_COMPUTE within a StructDefOp.
Definition Ops.cpp:1032
::mlir::SymbolRefAttr getCallee()
Definition Ops.cpp.inc:470
::mlir::FunctionType getTypeSignature()
Return the FunctionType inferred from the arg operands and result types of this CallOp.
Definition Ops.cpp:1000
void setTemplateParamsAttr(::mlir::ArrayAttr attr)
Definition Ops.h.inc:316
::mlir::Operation::operand_range getArgOperands()
Definition Ops.h.inc:266
::mlir::ArrayAttr getTemplateParamsAttr()
Definition Ops.h.inc:297
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
void setCalleeAttr(::mlir::SymbolRefAttr attr)
Definition Ops.h.inc:312
::mlir::FailureOr< UnificationMap > unifyTypeSignature(::mlir::FunctionType other)
Attempt type unification between the inferred FunctionType from this CallOp (as LHS) and the given Fun...
Definition Ops.cpp:1004
::mlir::LogicalResult verifyTemplateParamsMatchInferred(::llvm::iterator_range<::mlir::Region::op_iterator<::llzk::polymorphic::TemplateParamOp > > targetParamDefs, const UnificationMap &unifications)
Verify that each template parameter value provided in this CallOp is consistent with the value inferr...
Definition Ops.cpp:568
::mlir::LogicalResult verifyTemplateParamCompatibility(::mlir::Attribute paramFromCallOp, ::llzk::polymorphic::TemplateParamOp targetParam)
Check type compatibility of the given template parameter value from this CallOp against the declared ...
::mlir::FailureOr<::llzk::SymbolLookupResult<::llzk::function::FuncDefOp > > getCalleeTarget(::mlir::SymbolTableCollection &tables)
Resolve and return the target FuncDefOp for this CallOp.
Definition Ops.cpp:1054
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:302
FuncDefOp clone(::mlir::IRMapping &mapper)
Create a deep copy of this function and all of its blocks, remapping any operands that use values out...
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
::llvm::ArrayRef<::mlir::Type > getArgumentTypes()
Required by FunctionOpInterface.
Definition Ops.h.inc:838
::llzk::component::StructType getSingleResultTypeOfCompute()
Assuming the name is FUNC_NAME_COMPUTE, return the single StructType result.
Definition Ops.cpp:387
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:979
bool isStructCompute()
Return true iff the function is within a StructDefOp and named FUNC_NAME_COMPUTE.
Definition Ops.h.inc:867
bool isInStruct()
Return true iff the function is within a StructDefOp.
Definition Ops.h.inc:864
void setFunctionType(::mlir::FunctionType attrValue)
Definition Ops.cpp.inc:1003
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:999
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:703
::mlir::Operation::operand_range getOperands()
Definition Ops.h.inc:979
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:464
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:673
::mlir::Region & getInitializerRegion()
Definition Ops.h.inc:660
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:838
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:872
::llvm::SmallVector<::mlir::Attribute > getConstNames()
Return the names of all ops of type OpT within the body region in the order they are defined.
Definition Ops.h.inc:941
::mlir::StringAttr getSymNameAttr()
Definition Ops.h.inc:885
void setSymName(::llvm::StringRef attrValue)
Definition Ops.cpp.inc:1064
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1059
inline ::llvm::iterator_range<::mlir::Region::op_iterator< OpT > > getConstOps()
Return ops of type OpT within the body region.
Definition Ops.h.inc:921
::mlir::FlatSymbolRefAttr getNameRef() const
int main(int argc, char **argv)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:156
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:97
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:81
std::unique_ptr< mlir::Pass > createFlatteningPass()
::llvm::StringRef stringifyStructCleanupMode(StructCleanupMode val)
std::unique_ptr< mlir::Pass > createEmptyTemplateRemoval()
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
llvm::SMTExprRef tripCount(mlir::scf::ForOp op, llvm::SMTSolver *solver)
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
bool isConcreteType(Type type, bool allowStructParams)
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
TypeClass getIfSingleton(mlir::TypeRange types)
Definition TypeHelper.h:270
mlir::DenseMap< std::pair< mlir::SymbolRefAttr, Side >, mlir::Attribute > UnificationMap
Optional result from type unifications.
Definition TypeHelper.h:186
std::string stringWithoutType(mlir::Attribute a)
bool isNullOrEmpty(mlir::ArrayAttr a)
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:69
TypeClass getAtIndex(mlir::TypeRange types, size_t index)
Definition TypeHelper.h:274
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
bool isDynamic(IntegerAttr intAttr)
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
int64_t fromAPInt(const llvm::APInt &i)
FailureOr< SymbolLookupResult< StructDefOp > > getMainInstanceDef(SymbolTableCollection &symbolTable, Operation *lookupFrom)
bool isMoreConcreteUnification(Type oldTy, Type newTy, llvm::function_ref< bool(Type oldTy, Type newTy)> knownOldToNew)
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
constexpr char MAIN_ATTR_NAME[]
Name of the attribute on the top-level ModuleOp that specifies the type of the main struct.
Definition Constants.h:37