LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Global value operation implementations --------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
18
19// TableGen'd implementation files
21
22// TableGen'd implementation files
23#define GET_OP_CLASSES
25
26using namespace mlir;
27using namespace llzk::array;
28using namespace llzk::felt;
29using namespace llzk::string;
30
31namespace llzk::global {
32
33//===------------------------------------------------------------------===//
34// GlobalDefOp
35//===------------------------------------------------------------------===//
36
37ParseResult GlobalDefOp::parseGlobalInitialValue(
38 OpAsmParser &parser, Attribute &initialValue, TypeAttr typeAttr
39) {
40 if (parser.parseOptionalEqual()) {
41 // When there's no equal sign, there's no initial value to parse.
42 return success();
43 }
44 Type specifiedType = typeAttr.getValue();
45
46 // Special case for parsing LLZK FeltType to match format of FeltConstantOp.
47 // Not actually necessary but the default format is verbose. ex: "#felt<const 35>"
48 if (isa<FeltType>(specifiedType)) {
49 FeltConstAttr feltConstAttr;
50 if (parser.parseCustomAttributeWithFallback<FeltConstAttr>(feltConstAttr)) {
51 return failure();
52 }
53 initialValue = feltConstAttr;
54 return success();
55 }
56 // Fallback to default parser for all other types.
57 if (failed(parser.parseAttribute(initialValue, specifiedType))) {
58 return failure();
59 }
60 return success();
61}
62
63void GlobalDefOp::printGlobalInitialValue(
64 OpAsmPrinter &p, GlobalDefOp /*op*/, Attribute initialValue, TypeAttr /*typeAttr*/
65) {
66 if (initialValue) {
67 p << " = ";
68 // Special case for LLZK FeltType to match format of FeltConstantOp.
69 // Not actually necessary but the default format is verbose. ex: "#felt<const 35>"
70 if (FeltConstAttr feltConstAttr = llvm::dyn_cast<FeltConstAttr>(initialValue)) {
71 p.printStrippedAttrOrType<FeltConstAttr>(feltConstAttr);
72 } else {
73 p.printAttributeWithoutType(initialValue);
74 }
75 }
76}
77
78LogicalResult GlobalDefOp::verifySymbolUses(SymbolTableCollection &tables) {
79 // Ensure any SymbolRef used in the type are valid
80 return verifyTypeResolution(tables, *this, getType());
81}
82
83namespace {
84
85inline InFlightDiagnosticWrapper reportMismatch(
86 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, const Twine &found
87) {
88 return errFn().append(
89 "with type ", rootType, " expected ", expected, " ", aspect, " but found ", found
90 );
91}
92
93inline InFlightDiagnosticWrapper reportMismatch(
94 EmitErrorFn errFn, Type rootType, const Twine &aspect, const Twine &expected, Attribute found
95) {
96 return reportMismatch(errFn, rootType, aspect, expected, found.getAbstractAttribute().getName());
97}
98
99LogicalResult ensureAttrTypeMatch(
100 Type type, Attribute valAttr, const OwningEmitErrorFn &errFn, Type rootType, const Twine &aspect
101) {
102 if (!isValidGlobalType(type)) {
103 // Same error message ODS-generated code would produce
104 return errFn().append(
105 "attribute 'type' failed to satisfy constraint: type attribute of "
106 "any LLZK type except non-constant types"
107 );
108 }
109 if (type.isSignlessInteger(1)) {
110 if (IntegerAttr ia = llvm::dyn_cast<IntegerAttr>(valAttr)) {
111 APInt val = ia.getValue();
112 if (!val.isZero() && !val.isOne()) {
113 return errFn().append("integer constant out of range for attribute");
114 }
115 } else if (!llvm::isa<BoolAttr>(valAttr)) {
116 return reportMismatch(errFn, rootType, aspect, "builtin.bool or builtin.integer", valAttr);
117 }
118 } else if (llvm::isa<IndexType>(type)) {
119 // The explicit check for BoolAttr is needed because the LLVM isa/cast functions treat
120 // BoolAttr as a subtype of IntegerAttr but this scenario should not allow BoolAttr.
121 bool isBool = llvm::isa<BoolAttr>(valAttr);
122 if (isBool || !llvm::isa<IntegerAttr>(valAttr)) {
123 return reportMismatch(
124 errFn, rootType, aspect, "builtin.index",
125 isBool ? "builtin.bool" : valAttr.getAbstractAttribute().getName()
126 );
127 }
128 } else if (llvm::isa<FeltType>(type)) {
129 if (!llvm::isa<FeltConstAttr, IntegerAttr>(valAttr)) {
130 return reportMismatch(errFn, rootType, aspect, "felt.type", valAttr);
131 }
132 } else if (llvm::isa<StringType>(type)) {
133 if (!llvm::isa<StringAttr>(valAttr)) {
134 return reportMismatch(errFn, rootType, aspect, "builtin.string", valAttr);
135 }
136 } else if (ArrayType arrTy = llvm::dyn_cast<ArrayType>(type)) {
137 if (ArrayAttr arrVal = llvm::dyn_cast<ArrayAttr>(valAttr)) {
138 // Ensure the number of elements is correct for the ArrayType
139 assert(arrTy.hasStaticShape() && "implied by earlier isValidGlobalType() check");
140 int64_t expectedCount = arrTy.getNumElements();
141 size_t actualCount = arrVal.size();
142 if (std::cmp_not_equal(actualCount, expectedCount)) {
143 return reportMismatch(
144 errFn, rootType, Twine(aspect) + " to contain " + Twine(expectedCount) + " elements",
145 "builtin.array", Twine(actualCount)
146 );
147 }
148 // Ensure the type of each element is correct for the ArrayType.
149 // Rather than immediately returning on failure, check all elements and aggregate to provide
150 // as many errors are possible in a single verifier run.
151 bool hasFailure = false;
152 Type expectedElemTy = arrTy.getElementType();
153 for (Attribute e : arrVal.getValue()) {
154 hasFailure |=
155 failed(ensureAttrTypeMatch(expectedElemTy, e, errFn, rootType, "array element"));
156 }
157 if (hasFailure) {
158 return failure();
159 }
160 } else {
161 return reportMismatch(errFn, rootType, aspect, "builtin.array", valAttr);
162 }
163 } else {
164 return errFn().append("expected a valid LLZK type but found ", type);
165 }
166 return success();
167}
168
169} // namespace
170
171LogicalResult GlobalDefOp::verify() {
172 if (Attribute initValAttr = getInitialValueAttr()) {
173 Type ty = getType();
174 OwningEmitErrorFn errFn = getEmitOpErrFn(this);
175 return ensureAttrTypeMatch(ty, initValAttr, errFn, ty, "attribute value");
176 }
177 // If there is no initial value, it cannot have "const".
178 if (isConstant()) {
179 return emitOpError("marked as 'const' must be assigned a value");
180 }
181 return success();
182}
183
184//===------------------------------------------------------------------===//
185// GlobalReadOp / GlobalWriteOp
186//===------------------------------------------------------------------===//
187
188FailureOr<SymbolLookupResult<GlobalDefOp>>
189GlobalRefOpInterface::getGlobalDefOp(SymbolTableCollection &tables) {
190 return lookupTopLevelSymbol<GlobalDefOp>(tables, getNameRef(), getOperation());
191}
192
193namespace {
194
195FailureOr<SymbolLookupResult<GlobalDefOp>>
196verifySymbolUsesImpl(GlobalRefOpInterface refOp, SymbolTableCollection &tables) {
197 // Ensure this op references a valid GlobalDefOp name
198 auto tgt = refOp.getGlobalDefOp(tables);
199 if (failed(tgt)) {
200 return failure();
201 }
202 // Ensure the SSA Value type matches the GlobalDefOp type
203 Type globalType = tgt->get().getType();
204 if (!typesUnify(refOp.getVal().getType(), globalType, tgt->getIncludeSymNames())) {
205 return refOp->emitOpError() << "has wrong type; expected " << globalType << ", got "
206 << refOp.getVal().getType();
207 }
208 return tgt;
209}
210
211} // namespace
212
213LogicalResult GlobalReadOp::verifySymbolUses(SymbolTableCollection &tables) {
214 if (failed(verifySymbolUsesImpl(*this, tables))) {
215 return failure();
216 }
217 // Ensure any SymbolRef used in the type are valid
218 return verifyTypeResolution(tables, *this, getType());
219}
220
221LogicalResult GlobalWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
222 auto tgt = verifySymbolUsesImpl(*this, tables);
223 if (failed(tgt)) {
224 return failure();
225 }
226 if (tgt->get().isConstant()) {
227 return emitOpError().append(
228 "cannot target '", GlobalDefOp::getOperationName(), "' marked as 'const'"
229 );
230 }
231 return success();
232}
233
234} // namespace llzk::global
Wrapper around InFlightDiagnostic that can either be a regular InFlightDiagnostic or a special versio...
Definition ErrorHelper.h:26
InFlightDiagnosticWrapper & append(Args &&...args) &
Append arguments to the diagnostic.
Definition ErrorHelper.h:90
::mlir::Type getType()
Definition Ops.cpp.inc:401
::mlir::Attribute getInitialValueAttr()
Definition Ops.h.inc:267
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:78
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:219
::llvm::LogicalResult verify()
Definition Ops.cpp:171
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:213
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the GlobalRefOp.
::mlir::FailureOr< SymbolLookupResult< GlobalDefOp > > getGlobalDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the global referenced in this op.
Definition Ops.cpp:189
::mlir::SymbolRefAttr getNameRef()
Gets the global name attribute from the GlobalRefOp.
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:221
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
bool isValidGlobalType(Type type)
llvm::function_ref< InFlightDiagnosticWrapper()> EmitErrorFn
Callback to produce an error diagnostic.
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)