LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Boolean operation implementations ----------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
15
16#include <mlir/IR/BuiltinAttributes.h>
17
18// TableGen'd implementation files
19#define GET_OP_CLASSES
21
22using namespace mlir;
23
24namespace llzk::boolean {
25
26//===------------------------------------------------------------------===//
27// AssertOp
28//===------------------------------------------------------------------===//
29
30// This side effect models "program termination". Based on
31// https://github.com/llvm/llvm-project/blob/f325e4b2d836d6e65a4d0cf3efc6b0996ccf3765/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp#L92-L97
33 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &effects
34) {
35 effects.emplace_back(MemoryEffects::Write::get());
36}
37
38//===------------------------------------------------------------------===//
39// Fold helpers
40//===------------------------------------------------------------------===//
41
42namespace {
43
46static FailureOr<bool> getBoolValue(Attribute attr) {
47 auto ia = llvm::dyn_cast_or_null<IntegerAttr>(attr);
48 if (!ia || !ia.getType().isInteger(1)) {
49 return failure();
50 }
51 return ia.getValue().getBoolValue();
52}
53
55static IntegerAttr makeBoolAttr(MLIRContext *ctx, bool val) {
56 auto i1Ty = IntegerType::get(ctx, 1);
57 return IntegerAttr::get(i1Ty, val ? 1 : 0);
58}
59
60} // namespace
61
62//===------------------------------------------------------------------===//
63// AndBoolOp
64//===------------------------------------------------------------------===//
65
66OpFoldResult AndBoolOp::fold(FoldAdaptor adaptor) {
67 auto lhs = getBoolValue(adaptor.getLhs());
68 auto rhs = getBoolValue(adaptor.getRhs());
69 if (failed(lhs) || failed(rhs)) {
70 return {};
71 }
72 return makeBoolAttr(getContext(), *lhs && *rhs);
73}
74
75//===------------------------------------------------------------------===//
76// OrBoolOp
77//===------------------------------------------------------------------===//
78
79OpFoldResult OrBoolOp::fold(FoldAdaptor adaptor) {
80 auto lhs = getBoolValue(adaptor.getLhs());
81 auto rhs = getBoolValue(adaptor.getRhs());
82 if (failed(lhs) || failed(rhs)) {
83 return {};
84 }
85 return makeBoolAttr(getContext(), *lhs || *rhs);
86}
87
88//===------------------------------------------------------------------===//
89// XorBoolOp
90//===------------------------------------------------------------------===//
91
92OpFoldResult XorBoolOp::fold(FoldAdaptor adaptor) {
93 auto lhs = getBoolValue(adaptor.getLhs());
94 auto rhs = getBoolValue(adaptor.getRhs());
95 if (failed(lhs) || failed(rhs)) {
96 return {};
97 }
98 return makeBoolAttr(getContext(), *lhs != *rhs);
99}
100
101//===------------------------------------------------------------------===//
102// NotBoolOp
103//===------------------------------------------------------------------===//
104
105OpFoldResult NotBoolOp::fold(FoldAdaptor adaptor) {
106 auto val = getBoolValue(adaptor.getOperand());
107 if (failed(val)) {
108 return {};
109 }
110 return makeBoolAttr(getContext(), !*val);
111}
112
113//===------------------------------------------------------------------===//
114// CmpOp
115//===------------------------------------------------------------------===//
116
/// Evaluates the comparison selected by `pred` on `lval` and `rval`.
///
/// Equality uses `==` / `!=`; every ordering comparison is unsigned
/// (`ult`, `ule`, `ugt`, `uge`). APInt comparisons expect operands of equal
/// bit width; CmpOp::fold zero-extends both values to a common width before
/// calling this.
/// Reaching the end of the switch (a predicate value with no matching case)
/// hits `llvm_unreachable`.
///
/// NOTE(review): the `case` labels of the switch are not present in this
/// rendering of the file. Each `return` presumably sits under one
/// FeltCmpPredicate enumerator, in the order: equal, not-equal, less-than,
/// less-or-equal, greater-than, greater-or-equal. Confirm against the
/// original source.
inline static bool eval(FeltCmpPredicate pred, const llvm::APInt &lval, const llvm::APInt &rval) {
  switch (pred) {
    return lval == rval;
    return lval != rval;
    return lval.ult(rval);
    return lval.ule(rval);
    return lval.ugt(rval);
    return lval.uge(rval);
  }
  llvm_unreachable("invalid FeltCmpPredicate");
}
134
135OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
136 auto lhsAttr = llvm::dyn_cast_or_null<felt::FeltConstAttr>(adaptor.getLhs());
137 auto rhsAttr = llvm::dyn_cast_or_null<felt::FeltConstAttr>(adaptor.getRhs());
138 if (!lhsAttr || !rhsAttr) {
139 return {};
140 }
141
142 // Normalize to a common bit width for unsigned comparison.
143 llvm::APInt lval = lhsAttr.getValue();
144 llvm::APInt rval = rhsAttr.getValue();
145 unsigned w = std::max(lval.getBitWidth(), rval.getBitWidth());
146 if (lval.getBitWidth() < w) {
147 lval = lval.zext(w);
148 }
149 if (rval.getBitWidth() < w) {
150 rval = rval.zext(w);
151 }
152 return makeBoolAttr(getContext(), eval(getPredicate(), lval, rval));
153}
154
155} // namespace llzk::boolean
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:66
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:127
void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect > > &effects)
Definition Ops.cpp:32
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:551
::llzk::boolean::FeltCmpPredicate getPredicate()
Definition Ops.cpp.inc:601
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:135
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:770
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:105
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:939
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:79
GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute > > FoldAdaptor
Definition Ops.h.inc:1117
::mlir::OpFoldResult fold(FoldAdaptor adaptor)
Definition Ops.cpp:92