1//===-- Ops.td ---------------------------------------------*- tablegen -*-===//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
8//===----------------------------------------------------------------------===//
10#ifndef LLZK_POLYMORPHIC_OPS
11#define LLZK_POLYMORPHIC_OPS
13include "llzk/Dialect/Polymorphic/IR/Dialect.td"
14include "llzk/Dialect/Polymorphic/IR/Types.td"
15include "llzk/Dialect/Polymorphic/IR/OpInterfaces.td"
16include "llzk/Dialect/Shared/OpTraits.td"
18include "mlir/IR/OpBase.td"
19include "mlir/IR/RegionKindInterface.td"
20include "mlir/IR/SymbolInterfaces.td"
21include "mlir/Interfaces/ControlFlowInterfaces.td"
22include "mlir/Interfaces/SideEffectInterfaces.td"
// Base record for all ops of the polymorphic (`poly`) dialect: binds the op to
// PolymorphicDialect and forwards the mnemonic and optional trait list to `Op`.
24class PolymorphicDialectOp<string mnemonic, list<Trait> traits = []>
25 : Op<PolymorphicDialect, mnemonic, traits>;
// `poly.template`: a Symbol (and LLZK symbol table) that must be nested directly
// in a ModuleOp and is isolated from above. Its single, terminator-free body block
// holds the template's parameter/expression declarations and the templated
// function/struct definitions.
28 : PolymorphicDialectOp<"template", [HasParent<"::mlir::ModuleOp">, Symbol,
29 LLZKSymbolTable, IsolatedFromAbove,
30 NoRegionArguments, NoTerminator,
32 let summary = "defines polymorphic functions or structs";
34 The `poly.template` op allows defining polymorphic templated functions and structs.
35 The body contains the definitions of the template's parameters along with the
36 function and/or struct definitions that utilize those parameters in their bodies.
40 poly.template @TemplateName {
44 function.def @f(%inp: !array.type<8,5,@N x !felt.type>) -> !array.type<@N x !felt.type> {
45 // function header and body can use parameters @N and @T
48 struct.def @StructName {
49 function.def @compute() -> !struct.type<@TemplateName::@StructName<[@N, @T]>> {
52 function.def @constrain(%self: !struct.type<@TemplateName::@StructName<[@N, @T]>>) {
59 The order of `poly.param` definitions in the template body determines the order that
60 template parameters must be listed in the parameter list of a `struct.type` referring
61 to a struct nested within the template. In the example above, the type of `@StructName`
62 is `!struct.type<@TemplateName::@StructName<[@N, @T]>>`.
// Symbol name of the template; nested definitions are referenced through it,
// e.g. `@TemplateName::@StructName`.
65 let arguments = (ins SymbolNameAttr:$sym_name);
// Exactly one block (`SizedRegion<1>`); with `NoTerminator` it needs no terminator op.
67 let regions = (region SizedRegion<1>:$bodyRegion);
69 let assemblyFormat = [{ $sym_name $bodyRegion attr-dict }];
// C++ query helpers over the template body's binding ops. `OpT` is constrained by
// the `TemplateSymbolBindingOp` concept (presumably satisfied by `poly.param` and
// `poly.expr`, which implement TemplateSymbolBindingOpInterface — confirm).
// Name lookups are linear scans over the body's ops of type `OpT`.
71 let extraClassDeclaration = [{
72 /// Return ops of type `OpT` within the body region.
73 /// Ops are returned in the order they are defined in the IR.
74 template <TemplateSymbolBindingOp OpT>
75 inline ::llvm::iterator_range<::mlir::Region::op_iterator<OpT>> getConstOps() {
76 return getBodyRegion().getOps<OpT>();
79 /// Return `true` if there are ops of type `OpT` within the body region.
80 template <TemplateSymbolBindingOp OpT>
81 inline bool hasConstOps() {
82 return !getConstOps<OpT>().empty();
85 /// Return the number of ops of type `OpT` within the body region.
86 template <TemplateSymbolBindingOp OpT>
87 inline size_t numConstOps() {
88 return llvm::range_size(getConstOps<OpT>());
91 /// Return the names of all ops of type `OpT` within the body region in the order they
92 /// are defined. The names are returned as `FlatSymbolRefAttr` but the more general
93 /// `Attribute` type is used in the return type since that's usually what's needed.
94 template <TemplateSymbolBindingOp OpT>
95 ::llvm::SmallVector<::mlir::Attribute> getConstNames() {
96 return ::llvm::to_vector(::llvm::map_range(getConstOps<OpT>(), [](auto p) -> ::mlir::Attribute {
97 return ::mlir::FlatSymbolRefAttr::get(p.getNameAttr());
101 /// Return `true` if there is an op of type `OpT` with the given name within the body region.
102 template <TemplateSymbolBindingOp OpT>
103 inline bool hasConstNamed(::mlir::StringRef find) {
104 return ::llvm::any_of(getConstOps<OpT>(), [&](OpT op) {
105 return op.getName() == find;
109 /// Overload of `hasConstNamed` taking the name as a `StringAttr`; compares its string value.
110 template <TemplateSymbolBindingOp OpT>
111 inline bool hasConstNamed(::mlir::StringAttr find) {
112 return hasConstNamed<OpT>(find.strref());
115 /// Overload of `hasConstNamed` taking a `FlatSymbolRefAttr`; compares its root reference.
116 template <TemplateSymbolBindingOp OpT>
117 inline bool hasConstNamed(::mlir::FlatSymbolRefAttr find) {
118 return hasConstNamed<OpT>(find.getRootReference());
121 /// Return the op of type `OpT` with the given name within the body region if it exists, else a null `OpT` (`OpT{}`).
122 template <TemplateSymbolBindingOp OpT>
123 inline OpT getConstNamed(::mlir::StringRef find) {
124 auto range = getConstOps<OpT>();
125 auto it = ::llvm::find_if(range, [&find](OpT op) { return op.getName() == find; });
126 return it != range.end() ? *it : OpT{};
129 /// Overload of `getConstNamed` taking the name as a `StringAttr`; returns a null `OpT` if not found.
130 template <TemplateSymbolBindingOp OpT>
131 inline OpT getConstNamed(::mlir::StringAttr find) {
132 return getConstNamed<OpT>(find.strref());
135 /// Overload of `getConstNamed` taking a `FlatSymbolRefAttr` (its root reference is compared); returns a null `OpT` if not found.
136 template <TemplateSymbolBindingOp OpT>
137 inline OpT getConstNamed(::mlir::FlatSymbolRefAttr find) {
138 return getConstNamed<OpT>(find.getRootReference());
// `poly.param`: declares one template parameter; only valid directly inside a
// `poly.template`. Implements TemplateSymbolBindingOpInterface and declares the
// SymbolUserOpInterface hook (presumably to verify symbols referenced by
// `type_opt`, e.g. `!poly.tvar<@T>` — confirm).
143def LLZK_TemplateParamOp
144 : PolymorphicDialectOp<
145 "param", [HasParent<"::llzk::polymorphic::TemplateOp">,
146 DeclareOpInterfaceMethods<TemplateSymbolBindingOpInterface>,
147 DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
148 let summary = "declares a parameter of a polymorphic template";
150 Declares a parameter of a `poly.template` that can be used by the function and/or struct
151 definitions within the template. Each parameter can have an optional type restriction.
155 poly.template @TemplateName {
157 poly.param @F : !felt.type
158 // To restrict a parameter to accept only a type, use `!poly.tvar` with the parameter's own name.
159 poly.param @T : !poly.tvar<@T>
// `type_opt` is the optional type restriction printed after `:`; when it is
// absent, no type restriction is declared for the parameter.
164 let arguments = (ins SymbolNameAttr:$sym_name,
165 OptionalAttr<TypeAttrOf<ConstReadType>>:$type_opt);
167 let assemblyFormat = [{ $sym_name (`:` $type_opt^)? attr-dict }];
// `poly.expr`: declares a named expression inside a `poly.template`. Its
// initializer is an isolated, argument-free, single-block region whose
// terminator is `poly.yield`.
170def LLZK_TemplateExprOp
171 : PolymorphicDialectOp<
172 "expr", [HasParent<"::llzk::polymorphic::TemplateOp">,
173 DeclareOpInterfaceMethods<TemplateSymbolBindingOpInterface>,
174 DeclareOpInterfaceMethods<SymbolUserOpInterface>,
175 IsolatedFromAbove, NoRegionArguments, SingleBlock]> {
176 let summary = "declares a named expression in a polymorphic template";
178 Declares an expression over parameters of a `poly.template` that can be used just like
179 a parameter within the function and/or struct definitions within the template.
181 The body of a `poly.expr` cannot contain any symbols defined via `poly.expr` to prevent
182 cyclic initialization. The body must also have no side effects to ensure it can be safely
183 duplicated if needed. This means operations such as read/write to globals or function
184 calls are not allowed in the body of a `poly.expr`.
188 poly.template @TemplateName {
189 poly.expr @ExprName {
190 %0 = some_op %param1, %param2 : (!felt.type, !felt.type) -> !felt.type
191 poly.yield %0 : !felt.type
197 let arguments = (ins SymbolNameAttr:$sym_name);
198 let regions = (region SizedRegion<1>:$initializerRegion);
200 let assemblyFormat = [{ $sym_name $initializerRegion attr-dict }];
// Requests a C++ `verifyRegions()` hook; presumably it enforces the body rules
// stated in the description (no `poly.expr` symbols, no side effects) — confirm.
202 let hasRegionVerifier = 1;
204 let extraClassDeclaration = [{
205 /// Returns the type of the value yielded by the `poly.yield` terminator of the initializer region.
206 ::mlir::Type getType();
// `poly.yield`: ReturnLike terminator that is only valid inside a `poly.expr`;
// its single operand (restricted to ConstReadType) is the value the expression yields.
211 : PolymorphicDialectOp<
212 "yield", [HasParent<"::llzk::polymorphic::TemplateExprOp">,
213 ReturnLike, Terminator]> {
214 let summary = "expr initialization yield and termination operation";
216 This operation yields an SSA value from a `poly.expr` initialization
217 region and terminates the region.
220 let arguments = (ins ConstReadType:$val);
222 let assemblyFormat = [{ $val `:` type($val) attr-dict }];
// `poly.read_const`: Pure op that reads the parameter named by `const_name` and
// produces it as a ConstReadType value. Resolution of the referenced symbol goes
// through the declared SymbolUserOpInterface method.
226 : PolymorphicDialectOp<"read_const", [Pure, DeclareOpInterfaceMethods<
227 SymbolUserOpInterface>]> {
228 let summary = "read value of a struct parameter";
230 This operation reads the value from the named constant parameter of
231 the struct/component in which this op appears. The op itself puts
232 some restriction on the type of this value, but leaves it to a later
233 type-checking pass to ensure the struct parameters are instantiated
234 with types matching the uses of the parameter within the struct.
238 // Read a `!felt.type` value from struct parameter "@A"
239 %0 = poly.read_const @A : !felt.type
240 // Read a value from struct parameter "@B" where its type is
241 // specified by struct parameter "@T"
242 %1 = poly.read_const @B : !poly.tvar<@T>
246 let arguments = (ins FlatSymbolRefAttr:$const_name);
247 let results = (outs ConstReadType:$val);
249 let assemblyFormat = [{ $const_name `:` type($val) attr-dict }];
// `poly.unifiable_cast`: Pure cast from one AnyLLZKType value to another whose
// types are required (per the description) to be unifiable.
252def LLZK_UnifiableCastOp : PolymorphicDialectOp<"unifiable_cast", [Pure]> {
253 let summary = "cast between two unifiable types";
255 This operation reinterprets a value as a different type with the restriction
256 that the input and output types of the cast are unifiable.
258 Most ops that accept LLZK types accept unifiable types as input and thus there
259 is no need for casting between types. This op is meant to be used in situations where
260 it is not possible to modify the given or the target type and they are different but unifiable.
261 For example, inside a conversion pattern the driver may introduce `unrealized_conversion_cast`
262 operations if the types are not equal. This will happen regardless of whether the two types unify.
263 This cast can be introduced instead of the default cast operation to satisfy MLIR's assumptions
268 %0 = some_other_op : !array.type<@N x !felt.type>
269 %1 = poly.unifiable_cast %0 : (!array.type<@N x !felt.type>) -> !array.type<affine_map<()[s0, s1] -> (s0 + s1)> x !felt.type>
// NOTE(review): the unifiability requirement is not expressed by these type
// constraints; presumably it is enforced by a C++ verifier outside this view — confirm.
273 let arguments = (ins AnyLLZKType:$input);
274 let results = (outs AnyLLZKType:$result);
275 let assemblyFormat = [{
276 $input `:` functional-type($input, results) attr-dict
// `poly.applymap`: Pure op applying a single-result AffineMap to `index` operands
// (dimension operands first, then symbol operands) and producing one `index` result.
282def LLZK_ApplyMapOp : PolymorphicDialectOp<"applymap", [Pure]> {
283 let summary = "apply an AffineMap";
285 This operation applies an AffineMap to a list of SSA values, yielding a single
286 SSA value. The number of dimension and symbol arguments must be equal to the
287 respective number of dimensional and symbolic inputs to the AffineMap; the
288 AffineMap has to be one-dimensional, and so this operation always returns one
289 value. The input operands and result all have `index` type.
293 #map10 = affine_map<(d0, d1) -> (d0 floordiv 8 + d1 floordiv 128)>
295 %1 = poly.applymap(%s, %t) #map10
300 %2 = poly.applymap(%42)[%n] affine_map<(i)[s0] -> (i+s0)>
304 let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands,
306 let results = (outs Index);
308 // Define builders manually so inference of `numDims` attribute is not
310 let skipDefaultBuilders = 1;
// The AffineMapAttr builder adds the operands, records the map's dimension count
// in the `numDims` property, and adds one `index` result type. The AffineMap and
// AffineExpr builders delegate to it (via AffineMapAttr::get and
// AffineMap::inferFromExprList respectively).
311 let builders = [OpBuilder<(ins "::mlir::AffineMapAttr":$map,
312 CArg<"::mlir::ValueRange", "{}">:$mapOperands),
314 $_state.addOperands(mapOperands);
315 Properties &props = $_state.getOrAddProperties<Properties>();
317 props.setNumDims($_builder.getIntegerAttr($_builder.getIndexType(),
318 map.getAffineMap().getNumDims()));
319 $_state.addTypes($_builder.getIndexType());
321 OpBuilder<(ins "::mlir::AffineMap":$map,
322 CArg<"::mlir::ValueRange", "{}">:$mapOperands),
324 build($_builder, $_state, ::mlir::AffineMapAttr::get(map), mapOperands);
326 OpBuilder<(ins "::mlir::AffineExpr":$expr,
327 CArg<"::mlir::ValueRange", "{}">:$mapOperands),
329 auto map = ::mlir::AffineMap::inferFromExprList({expr}, $_builder.getContext()).front();
330 build($_builder, $_state, map, mapOperands);
333 let assemblyFormat = [{
334 custom<DimAndSymbolList>($mapOperands, $numDims) $map attr-dict
339 let extraClassDeclaration = [{
340 /// Returns the affine map to be applied by this operation.
341 ::mlir::AffineMap inline getAffineMap() { return getMap(); }
343 /// Returns the affine value map computed from this operation.
344 ::mlir::affine::AffineValueMap getAffineValueMap() {
345 return ::mlir::affine::AffineValueMap(getAffineMap(), getOperands(), getResult());
348 /// Returns all dimension operands: the first `getMap().getNumDims()` operands.
349 ::mlir::ValueRange getDimOperands() {
350 return ::mlir::OperandRange{
351 getOperands().begin(),
352 getOperands().begin() + getMap().getNumDims()};
355 /// Returns all symbol operands: the operands following the dimension operands.
356 ::mlir::ValueRange getSymbolOperands() {
357 return ::mlir::OperandRange{
358 getOperands().begin() + getMap().getNumDims(),
359 getOperands().end()};
364#endif // LLZK_POLYMORPHIC_OPS