LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
SharedImpl.h
Go to the documentation of this file.
1//===-- SharedImpl.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
13//===----------------------------------------------------------------------===//
14
15#pragma once
16
24
25#include <mlir/Dialect/Arith/IR/Arith.h>
26#include <mlir/Dialect/SCF/IR/SCF.h>
27#include <mlir/Dialect/SCF/Transforms/Patterns.h>
28#include <mlir/IR/Attributes.h>
29#include <mlir/IR/BuiltinAttributes.h>
30#include <mlir/IR/MLIRContext.h>
31#include <mlir/IR/Operation.h>
32#include <mlir/IR/PatternMatch.h>
33#include <mlir/Transforms/DialectConversion.h>
34
35#include <llvm/ADT/STLExtras.h>
36#include <llvm/ADT/SmallVector.h>
37#include <llvm/Support/Debug.h>
38
39#include <tuple>
40
41#define DEBUG_TYPE "poly-dialect-shared"
42
44
45namespace {
46
48static struct OpClassesWithStructTypes {
49
52 const std::tuple<
53 // clang-format off
72 // clang-format on
73 >
74 WithGeneralBuilder {};
75
80 const std::tuple<llzk::function::CallOp, llzk::array::CreateArrayOp> NoGeneralBuilder {};
81
82} OpClassesWithStructTypes;
83
/// Invoke `inserter` a single time, passing it one value-initialized instance of every
/// explicitly supplied op class (in the order given). When no op classes are supplied,
/// `inserter` is not invoked at all.
template <typename I, typename... OpClasses> inline void applyToMoreTypes(I inserter) {
  if constexpr (sizeof...(OpClasses) > 0) {
    std::apply(inserter, std::tuple<OpClasses...> {});
  }
}
89
90inline bool defaultLegalityCheck(const mlir::TypeConverter &tyConv, mlir::Operation *op) {
91 // Check operand types and result types
92 if (!tyConv.isLegal(op)) {
93 return false;
94 }
95 // Check type attributes
96 // Extend lifetime of temporary to suppress warnings.
97 mlir::DictionaryAttr dictAttr = op->getAttrDictionary();
98 for (mlir::NamedAttribute n : dictAttr.getValue()) {
99 if (mlir::TypeAttr tyAttr = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
100 mlir::Type t = tyAttr.getValue();
101 if (mlir::FunctionType funcTy = llvm::dyn_cast<mlir::FunctionType>(t)) {
102 if (!tyConv.isSignatureLegal(funcTy)) {
103 return false;
104 }
105 } else {
106 if (!tyConv.isLegal(t)) {
107 return false;
108 }
109 }
110 }
111 }
112 return true;
113}
114
115// Default to true if the check is not for that particular operation type.
116template <typename Check> inline bool runCheck(mlir::Operation *op, Check check) {
117 if (auto specificOp =
118 llvm::dyn_cast_if_present<typename llvm::function_traits<Check>::template arg_t<0>>(op)) {
119 return check(specificOp);
120 }
121 return true;
122}
123
124} // namespace
125
128template <typename OpClass, typename Rewriter, typename... Args>
129inline OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args) {
130 mlir::DictionaryAttr attrs = op->getDiscardableAttrDictionary();
131 OpClass newOp = rewriter.template replaceOpWithNewOp<OpClass>(op, std::forward<Args>(args)...);
132 newOp->setDiscardableAttrs(attrs);
133 return newOp;
134}
135
136// NOTE: This pattern will produce a compile error if `OpClass` does not define the general
137// `build(OpBuilder&, OperationState&, TypeRange, ValueRange, ArrayRef<NamedAttribute>)` function
138// because that function is required by the `replaceOpWithNewOp()` call.
139template <typename OpClass>
140class GeneralTypeReplacePattern : public mlir::OpConversionPattern<OpClass> {
141public:
142 using mlir::OpConversionPattern<OpClass>::OpConversionPattern;
143
144 mlir::LogicalResult matchAndRewrite(
145 OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter
146 ) const override {
147 const mlir::TypeConverter *converter = mlir::OpConversionPattern<OpClass>::getTypeConverter();
148 assert(converter);
149 // Convert result types
150 mlir::SmallVector<mlir::Type> newResultTypes;
151 if (mlir::failed(converter->convertTypes(op->getResultTypes(), newResultTypes))) {
152 return op->emitError("Could not convert Op result types.");
153 }
154 // ASSERT: 'adaptor.getAttributes()' is empty or subset of 'op->getAttrDictionary()' so the
155 // former can be ignored without losing anything.
156 assert(
157 adaptor.getAttributes().empty() ||
158 llvm::all_of(
159 adaptor.getAttributes(), [d = op->getAttrDictionary()](mlir::NamedAttribute a) {
160 return d.contains(a.getName());
161 }
162 )
163 );
164 // Convert any TypeAttr in the attribute list.
165 mlir::SmallVector<mlir::NamedAttribute> newAttrs(op->getAttrDictionary().getValue());
166 for (mlir::NamedAttribute &n : newAttrs) {
167 if (mlir::TypeAttr t = llvm::dyn_cast<mlir::TypeAttr>(n.getValue())) {
168 if (mlir::Type newType = converter->convertType(t.getValue())) {
169 n.setValue(mlir::TypeAttr::get(newType));
170 } else {
171 return op->emitError().append("Could not convert type in attribute: ", t);
172 }
173 }
174 }
175 // Build a new Op in place of the current one
177 rewriter, op, mlir::TypeRange(newResultTypes), adaptor.getOperands(),
178 mlir::ArrayRef(newAttrs)
179 );
180 return mlir::success();
181 }
182};
183
185 : public mlir::OpConversionPattern<llzk::array::CreateArrayOp> {
186public:
187 using mlir::OpConversionPattern<llzk::array::CreateArrayOp>::OpConversionPattern;
188
189 mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override {
190 if (mlir::Type newType = getTypeConverter()->convertType(op.getType())) {
191 return mlir::success();
192 } else {
193 return op->emitError("Could not convert Op result type.");
194 }
195 }
196
198 llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
199 ) const override {
200 mlir::Type newType = getTypeConverter()->convertType(op.getType());
201 assert(
202 llvm::isa<llzk::array::ArrayType>(newType) && "CreateArrayOp must produce ArrayType result"
203 );
204 mlir::DenseI32ArrayAttr numDimsPerMap = op.getNumDimsPerMapAttr();
205 if (isNullOrEmpty(numDimsPerMap)) {
207 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getElements()
208 );
209 } else {
211 rewriter, op, llvm::cast<llzk::array::ArrayType>(newType), adapter.getMapOperands(),
212 numDimsPerMap
213 );
214 }
215 }
216};
217
218class CallOpClassReplacePattern : public mlir::OpConversionPattern<llzk::function::CallOp> {
219public:
220 using mlir::OpConversionPattern<llzk::function::CallOp>::OpConversionPattern;
221
222 mlir::LogicalResult matchAndRewrite(
223 llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter
224 ) const override {
225 // Convert the result types of the CallOp
226 mlir::SmallVector<mlir::Type> newResultTypes;
227 if (mlir::failed(getTypeConverter()->convertTypes(op.getResultTypes(), newResultTypes))) {
228 return op->emitError("Could not convert Op result types.");
229 }
231 rewriter, op, newResultTypes, op.getCalleeAttr(), adapter.getMapOperands(),
232 op.getNumDimsPerMapAttr(), adapter.getArgOperands()
233 );
234 return mlir::success();
235 }
236};
237
242template <typename... AdditionalOpClasses>
243mlir::RewritePatternSet newGeneralRewritePatternSet(
244 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target
245) {
246 mlir::RewritePatternSet patterns(ctx);
247 auto inserter = [&](auto... opClasses) {
248 patterns.add<GeneralTypeReplacePattern<decltype(opClasses)>...>(tyConv, ctx);
249 };
250 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
251 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
252 // Special cases for ops where GeneralTypeReplacePattern doesn't work
254 // Add builtin FunctionType converter
255 mlir::populateFunctionOpInterfaceTypeConversionPattern<llzk::function::FuncDefOp>(
256 patterns, tyConv
257 );
258 mlir::scf::populateSCFStructuralTypeConversionsAndLegality(tyConv, patterns, target);
259 return patterns;
260}
261
263mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx);
264
266public:
267 virtual ~LegalityCheckCallback() = default;
268 virtual void checkStarted() = 0;
269 virtual void checkEnded(bool outcome) = 0;
270};
271
273public:
274 void checkStarted() override {}
275 void checkEnded(bool) override {}
276};
277
284template <typename... AdditionalOpClasses, typename... AdditionalChecks>
285mlir::ConversionTarget newConverterDefinedTarget(
286 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks
287) {
288 static EmptyLegalityCheckCallback empty;
289 return newConverterDefinedTargetWithCallback<AdditionalOpClasses...>(
290 tyConv, ctx, empty, (std::forward<AdditionalChecks>(checks))...
291 );
292}
293
300template <typename... AdditionalOpClasses, typename... AdditionalChecks>
302 mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb,
303 AdditionalChecks &&...checks
304) {
305 mlir::ConversionTarget target = newBaseTarget(ctx);
306 auto inserter = [&](auto... opClasses) {
307 target.addDynamicallyLegalOp<decltype(opClasses)...>([&cb, &tyConv,
308 &checks...](mlir::Operation *op) {
309 LLVM_DEBUG(if (op) {
310 llvm::dbgs() << "[newConverterDefinedTarget] checking legality of ";
311 op->dump();
312 });
313 cb.checkStarted();
314 auto legality =
315 defaultLegalityCheck(tyConv, op) && (runCheck<AdditionalChecks>(op, checks) && ...);
316
317 cb.checkEnded(legality);
318 LLVM_DEBUG(if (legality) { llvm::dbgs() << "[newConverterDefinedTarget] is legal\n"; } else {
319 llvm::dbgs() << "[newConverterDefinedTarget] is not legal\n";
320 });
321 return legality;
322 });
323 };
324 std::apply(inserter, OpClassesWithStructTypes.NoGeneralBuilder);
325 std::apply(inserter, OpClassesWithStructTypes.WithGeneralBuilder);
326 applyToMoreTypes<decltype(inserter), AdditionalOpClasses...>(inserter);
327 return target;
328}
329
330} // namespace llzk::polymorphic::detail
331
332#undef DEBUG_TYPE
#define check(x)
Definition Ops.cpp:171
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:421
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:392
::mlir::Operation::operand_range getElements()
Definition Ops.h.inc:388
::mlir::SymbolRefAttr getCalleeAttr()
Definition Ops.h.inc:267
::mlir::DenseI32ArrayAttr getNumDimsPerMapAttr()
Definition Ops.h.inc:272
mlir::LogicalResult matchAndRewrite(llzk::function::CallOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:222
mlir::LogicalResult match(llzk::array::CreateArrayOp op) const override
Definition SharedImpl.h:189
void rewrite(llzk::array::CreateArrayOp op, OpAdaptor adapter, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:197
mlir::LogicalResult matchAndRewrite(OpClass op, OpClass::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override
Definition SharedImpl.h:144
mlir::ConversionTarget newConverterDefinedTargetWithCallback(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, LegalityCheckCallback &cb, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:301
mlir::RewritePatternSet newGeneralRewritePatternSet(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, mlir::ConversionTarget &target)
Return a new RewritePatternSet that includes a GeneralTypeReplacePattern for all of OpClassesWithStru...
Definition SharedImpl.h:243
mlir::ConversionTarget newConverterDefinedTarget(mlir::TypeConverter &tyConv, mlir::MLIRContext *ctx, AdditionalChecks &&...checks)
Return a new ConversionTarget allowing all LLZK-required dialects and defining Op legality based on t...
Definition SharedImpl.h:285
OpClass replaceOpWithNewOp(Rewriter &rewriter, mlir::Operation *op, Args &&...args)
Wrapper for PatternRewriter.replaceOpWithNewOp() that automatically copies discardable attributes (i....
Definition SharedImpl.h:129
mlir::ConversionTarget newBaseTarget(mlir::MLIRContext *ctx)
Return a new ConversionTarget allowing all LLZK-required dialects.
bool isNullOrEmpty(mlir::ArrayAttr a)