LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Dialect.cpp
Go to the documentation of this file.
1//===-- Dialect.cpp - Dialect method implementations ------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
22
23#include <mlir/IR/Builders.h>
24#include <mlir/IR/BuiltinTypes.h>
25#include <mlir/IR/DialectImplementation.h>
26#include <mlir/Transforms/DialectConversion.h>
27
28#include <llvm/ADT/DenseMap.h>
29#include <llvm/ADT/TypeSwitch.h>
30
31// TableGen'd implementation files
33
34#define GET_TYPEDEF_CLASSES
36
37using namespace mlir;
38using namespace llzk;
39using namespace llzk::array;
40using namespace llzk::component;
41using namespace llzk::constrain;
42using namespace llzk::function;
43using namespace llzk::global;
44using namespace llzk::polymorphic;
45
46namespace {
47
49class V1StructNameTypeConverter : public TypeConverter {
50 const DenseMap<SymbolRefAttr, SymbolRefAttr> &fqnMap;
51
52public:
53 explicit V1StructNameTypeConverter(const DenseMap<SymbolRefAttr, SymbolRefAttr> &renamingMap)
54 : fqnMap(renamingMap) {
55
56 addConversion([](Type t) { return t; });
57
58 addConversion([this](StructType t) {
59 auto it = fqnMap.find(t.getNameRef());
60 SymbolRefAttr newRef = (it != fqnMap.end()) ? it->second : t.getNameRef();
61 bool changed = (newRef != t.getNameRef());
62 ArrayAttr params = t.getParams();
63 if (params) {
64 SmallVector<Attribute> updated;
65 bool paramsChanged = false;
66 for (Attribute a : params) {
67 if (auto ta = dyn_cast<TypeAttr>(a)) {
68 Type inner = convertType(ta.getValue());
69 updated.push_back(TypeAttr::get(inner));
70 paramsChanged |= (inner != ta.getValue());
71 } else {
72 updated.push_back(a);
73 }
74 }
75 if (paramsChanged) {
76 params = ArrayAttr::get(t.getContext(), updated);
77 changed = true;
78 }
79 }
80 return changed ? StructType::get(newRef, params) : t;
81 });
82 }
83};
84
87class V1CallOpPattern : public OpConversionPattern<CallOp> {
88 const DenseMap<SymbolRefAttr, SymbolRefAttr> &fqnMap;
89
90public:
91 V1CallOpPattern(
92 TypeConverter &converter, MLIRContext *ctx,
93 const DenseMap<SymbolRefAttr, SymbolRefAttr> &renamingMap
94 )
95 : OpConversionPattern<CallOp>(converter, ctx), fqnMap(renamingMap) {}
96
97 LogicalResult matchAndRewrite(
98 CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter
99 ) const override {
100 SmallVector<Type> newResultTypes;
101 if (failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
102 return op->emitError("Could not convert Op result types.");
103 }
104 // Remap callee if its path prefix matches a struct FQN that was wrapped.
105 SymbolRefAttr calleeAttr = op.getCalleeAttr();
106 SmallVector<FlatSymbolRefAttr> calleePieces = getPieces(calleeAttr);
107 for (auto &[oldFQN, newFQN] : fqnMap) {
108 SmallVector<FlatSymbolRefAttr> oldPieces = getPieces(oldFQN);
109 if (calleePieces.size() > oldPieces.size() &&
110 std::equal(
111 calleePieces.begin(), calleePieces.begin() + (ptrdiff_t)oldPieces.size(),
112 oldPieces.begin()
113 )) {
114 SmallVector<FlatSymbolRefAttr> newPieces = getPieces(newFQN);
115 newPieces.append(calleePieces.begin() + (ptrdiff_t)oldPieces.size(), calleePieces.end());
116 calleeAttr = asSymbolRefAttr(newPieces);
117 break;
118 }
119 }
121 rewriter, op, newResultTypes, calleeAttr, adaptor.getMapOperands(),
122 op.getNumDimsPerMapAttr(), adaptor.getArgOperands()
123 );
124 return success();
125 }
126};
127
128// Prior to version 2, `StructDefOp` had a `const_params` attribute containing the list of struct
129// parameters. In version 2, those parameters are represented explicitly as `poly.param` ops inside
130// a `poly.template` that wraps the `StructDefOp`.
131//
132// If the `StructDefOp::readProperties()` function encounters the old `const_params` attribute, it
133// stores them in a temporary `llzk::kV1ConstParamsAttr` attribute. This migration function creates
134// a `poly.param` op for each such parameter and creates a `poly.template` to wrap these followed by
135// the `StructDefOp`.
136LogicalResult migrateToV2(Operation *rootOp) {
137 // Ensure the Polymorphic dialect is loaded so we can create `poly.template` ops.
138 rootOp->getContext()->loadDialect<polymorphic::PolymorphicDialect>();
139
140 // Collect mappings from old to new FQN (fully-qualified name) for each updated struct.
141 llvm::DenseMap<SymbolRefAttr, SymbolRefAttr> oldToNewFQN;
142
143 // Visit all StructDefOp and perform the necessary transformation.
144 rootOp->walk<WalkOrder::PreOrder>([&oldToNewFQN](StructDefOp structOp) -> WalkResult {
145 Attribute constParamsAttr = structOp->getAttr(llzk::kV1ConstParamsAttr);
146 if (!constParamsAttr) {
147 return WalkResult::advance();
148 }
149 structOp->removeAttr(llzk::kV1ConstParamsAttr);
150
151 // Create the TemplateOp at the position of the StructDefOp, using the
152 // struct's own name (it becomes the outer template name).
153 OpBuilder builder(structOp);
154 auto templateOp =
155 builder.create<polymorphic::TemplateOp>(structOp.getLoc(), structOp.getSymName());
156
157 // Populate TemplateParamOps (in order) before the struct inside the template.
158 Block &templateBody = templateOp.getBodyRegion().emplaceBlock();
159 OpBuilder templateBuilder = OpBuilder::atBlockBegin(&templateBody);
160 auto constParams = llvm::cast<ArrayAttr>(constParamsAttr);
161 for (auto paramRef : constParams.getAsRange<FlatSymbolRefAttr>()) {
162 templateBuilder.create<polymorphic::TemplateParamOp>(
163 structOp.getLoc(), paramRef.getValue(),
164 /*type_opt=*/TypeAttr {}
165 );
166 }
167
168 // Compute the old FQN before the struct is moved into the template.
169 SymbolRefAttr oldFQN = structOp.getFullyQualifiedName();
170
171 // Move the StructDefOp into the template body (after the params).
172 structOp->moveBefore(&templateBody, templateBody.end());
173
174 // Record the FQN change: since the template name equals the old struct name, the
175 // new FQN is the old FQN with the struct name appended as one more nesting level.
176 oldToNewFQN[oldFQN] = appendLeaf(oldFQN, structOp.getSymNameAttr());
177
178 // Skip descending into the now-moved StructDefOp.
179 return WalkResult::skip();
180 });
181
182 // Done if no structs were updated.
183 if (oldToNewFQN.empty()) {
184 return success();
185 }
186
187 // Update all references to the old struct FQNs using the dialect conversion framework
188 // so that every StructType and CallOp callee is updated if necessary.
189 MLIRContext *ctx = rootOp->getContext();
190 V1StructNameTypeConverter tyConv(oldToNewFQN);
191 ConversionTarget target(*ctx);
192 target.markUnknownOpDynamicallyLegal([&tyConv](Operation *op) {
193 return defaultLegalityCheck(tyConv, op);
194 });
195
196 // Build pattern set for all LLZK op types. V1CallOpPattern (benefit 1) overrides
197 // the default CallOpClassReplacePattern (benefit 0) for CallOp.
198 RewritePatternSet patterns = newGeneralRewritePatternSet(tyConv, ctx, target);
199 patterns.add<V1CallOpPattern>(tyConv, ctx, oldToNewFQN);
200 return applyPartialConversion(rootOp, target, std::move(patterns));
201}
202
203} // namespace
204
205//===------------------------------------------------------------------===//
206// StructDialect
207//===------------------------------------------------------------------===//
208
210
214
215 LogicalResult upgradeFromVersion(
216 mlir::Operation *rootOp, const LLZKDialectVersion &current,
217 const LLZKDialectVersion &requested
218 ) const override {
219 assert(requested < current && "pre-condition");
220 if (requested.majorVersion < 2) {
221 if (failed(migrateToV2(rootOp))) {
222 return failure();
223 }
224 }
225 // Future migrations can be added here if necessary.
226 return success();
227 }
228};
229
230} // namespace llzk::component
231
// Dialect initialization hook: calls addOperations<>, addTypes<> and
// addInterfaces<StructDialectBytecodeInterface>() to register the dialect's contents.
// NOTE(review): the lines that should follow each `#define GET_*_LIST` (presumably the
// TableGen-generated includes supplying the op/type lists) are not visible here — confirm
// against the actual file before editing this function.
232auto llzk::component::StructDialect::initialize() -> void {
233 // clang-format off
234 addOperations<
235 #define GET_OP_LIST
237 >();
238
239 addTypes<
240 #define GET_TYPEDEF_LIST
242 >();
243 // clang-format on
244 addInterfaces<StructDialectBytecodeInterface>();
245}
Reusable MLIR dialect conversion functions for LLZK StructType replacement.
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:292
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:270
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:302
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter::replaceOpWithNewOp() that automatically copies discardable attributes (i...
SymbolRefAttr appendLeaf(SymbolRefAttr orig, FlatSymbolRefAttr newLeaf)
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet covering all LLZK op types that may contain a StructType.
mlir::SymbolRefAttr asSymbolRefAttr(mlir::StringAttr root, mlir::SymbolRefAttr tail)
Build a SymbolRefAttr that prepends tail with root, i.e., root::tail.
llvm::SmallVector< FlatSymbolRefAttr > getPieces(SymbolRefAttr ref)
bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op)
Check whether an op is legal with respect to the given type converter, including TypeAttr attributes ...
LLZKDialectBytecodeInterface(mlir::Dialect *dia)
Definition Versioning.h:42
Implement version upgrade for StructDialect.
Definition Dialect.cpp:212
LogicalResult upgradeFromVersion(mlir::Operation *rootOp, const LLZKDialectVersion &current, const LLZKDialectVersion &requested) const override
Definition Dialect.cpp:215