LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Operation implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
19
20// Include TableGen'd declarations
22
23// TableGen'd implementation files
24#define GET_OP_CLASSES
26
27using namespace mlir;
28using namespace llzk::component;
29
31
32bool isInTemplate(Operation *op) { return getParentOfType<TemplateOp>(op); }
33
34FailureOr<TemplateOp> verifyInTemplate(Operation *op) {
36 return res;
37 }
38 return op->emitOpError() << "only valid within a '" << TemplateOp::getOperationName()
39 << "' ancestor";
40}
41
42//===------------------------------------------------------------------===//
43// TemplateParamOp
44//===------------------------------------------------------------------===//
45
46namespace {
47
48LogicalResult checkForNameConflict(SymbolTableCollection &tables, SymbolOpInterface op) {
49 // Ensure parameter name does not conflict with an existing top-level symbol
50 // because that would cause an ambiguity in symbol resolution within structs.
51 auto res = lookupTopLevelSymbol(tables, FlatSymbolRefAttr::get(op.getNameAttr()), op, false);
52 if (succeeded(res)) {
53 return op.emitOpError()
54 .append("name conflicts with an existing symbol")
55 .attachNote(res->get()->getLoc())
56 .append("symbol already defined here");
57 }
58 return success();
59}
60
61} // namespace
62
63LogicalResult TemplateParamOp::verifySymbolUses(SymbolTableCollection &tables) {
64 return checkForNameConflict(tables, *this);
65}
66
67//===------------------------------------------------------------------===//
68// TemplateExprOp
69//===------------------------------------------------------------------===//
70
71LogicalResult TemplateExprOp::verifySymbolUses(SymbolTableCollection &tables) {
72 if (failed(checkForNameConflict(tables, *this))) {
73 return failure(); // checkForNameConflict() already emits a sufficient error message
74 }
75 // Ensure no symbol used within the initializer region is defined via a `TemplateExprOp`.
76 // This prevents cyclic definitions of `TemplateExprOp`. Searches all symbol uses within
77 // this op and also within any nested symbol tables.
78 Operation *thisOp = this->getOperation();
79 TemplateOp parentTemplate = getParentOfType<TemplateOp>(thisOp);
80 assert(parentTemplate && "per ODS");
81 LogicalResult errorState = success();
82 auto checkUses = [this, &parentTemplate, &errorState](Operation *symTableOp, bool) {
83 if (auto uses = llzk::getSymbolUses(symTableOp)) {
84 for (SymbolTable::SymbolUse use : uses.value()) {
85 // Only need to check flat refs since `TemplateExprOp` refs must be flat
86 auto usedSym = llvm::dyn_cast<FlatSymbolRefAttr>(use.getSymbolRef());
87 if (usedSym && parentTemplate.hasConstNamed<TemplateExprOp>(usedSym)) {
88 InFlightDiagnostic diag = this->emitOpError().append(
89 "initialization cannot use a symbol defined by another `",
90 TemplateExprOp::getOperationName(), "` within this template"
91 );
92 diag.attachNote(use.getUser()->getLoc()).append("symbol ", usedSym, " used here");
93 auto def = parentTemplate.getConstNamed<TemplateExprOp>(usedSym);
94 diag.attachNote(def.getLoc()).append("defined here");
95 errorState = diag; // transformation to LogicalResult reports the error
96 return;
97 }
98 }
99 }
100 };
101 checkUses(thisOp, true);
102 if (succeeded(errorState)) {
103 SymbolTable::walkSymbolTables(thisOp, /*allSymUsesVisible=*/true, checkUses);
104 }
105 return errorState;
106}
107
109 Region &region = getInitializerRegion();
110 if (!region.hasOneBlock()) {
111 return emitOpError("expected initializer region with a single block");
112 }
113 Block &block = region.back();
114 if (!llvm::isa<YieldOp>(block.getTerminator())) {
115 return emitOpError("expected initializer region to end with a '")
116 << YieldOp::getOperationName() << '\'';
117 }
118 // Check or ops with side-effects that are not allowed within `poly.expr`.
119 Operation *illegalOp = nullptr;
120 auto walkRes = block.walk([&illegalOp](Operation *p) {
121 // Note: If side-effect traits are added to ops in the future, this check should
122 // be updated to check for those traits instead of specific op types.
123 if (llvm::isa<global::GlobalRefOpInterface, function::CallOp>(p)) {
124 illegalOp = p;
125 return WalkResult::interrupt();
126 }
127 return WalkResult::advance();
128 });
129 if (walkRes.wasInterrupted()) {
130 assert(illegalOp); // was set in the walk above
131 return illegalOp->emitOpError().append(
132 "is not allowed within a `", TemplateExprOp::getOperationName(), "` initializer"
133 );
134 }
135 return success();
136}
137
139 Region &region = getInitializerRegion();
140 assert(region.hasOneBlock() && "per `verifyRegions()`");
141 YieldOp yieldOp = llvm::dyn_cast<YieldOp>(region.back().getTerminator());
142 assert(yieldOp && "per `verifyRegions()`");
143 return yieldOp.getVal().getType();
144}
145
146std::optional<Type> TemplateExprOp::getTypeOpt() { return getType(); }
147
148//===------------------------------------------------------------------===//
149// ConstReadOp
150//===------------------------------------------------------------------===//
151
152LogicalResult ConstReadOp::verifySymbolUses(SymbolTableCollection &tables) {
153 FailureOr<TemplateOp> getParentRes = verifyInTemplate(*this);
154 if (failed(getParentRes)) {
155 return failure(); // verifyInTemplate() already emits a sufficient error message
156 }
157 // Ensure the named constant is a parameter of the parent struct
158 FlatSymbolRefAttr name = this->getConstNameAttr();
159 auto constParam = getParentRes->getConstNamed<TemplateSymbolBindingOpInterface>(name);
160 if (!constParam) {
161 return this->emitOpError()
162 .append("references unknown symbol \"", name, '"')
163 .attachNote(getParentRes->getLoc())
164 .append("must reference a param or expr of this template");
165 }
166 // Ensure the type of the constant read matches the type of the referenced parameter (if any).
167 if (std::optional<Type> paramType = constParam.getTypeOpt()) {
168 if (this->getType() != *paramType) {
169 return this->emitOpError().append(
170 "type ", this->getType(), " does not match constant param type ", *paramType
171 );
172 }
173 }
174
175 // Ensure any SymbolRef used in the type are valid
176 return verifyTypeResolution(tables, *this, getType());
177}
178
179//===------------------------------------------------------------------===//
180// ApplyMapOp
181//===------------------------------------------------------------------===//
182
183LogicalResult ApplyMapOp::verify() {
184 // Check input and output dimensions match.
185 AffineMap map = getMap();
186
187 // Verify that the map only produces one result.
188 if (map.getNumResults() != 1) {
189 return emitOpError("must produce exactly one value");
190 }
191
192 // Verify that operand count matches affine map dimension and symbol count.
193 unsigned mapDims = map.getNumDims();
194 if (getNumOperands() != mapDims + map.getNumSymbols()) {
195 return emitOpError("operand count must equal affine map dimension+symbol count");
196 } else if (mapDims != getNumDimsAttr().getInt()) {
197 return emitOpError("dimension operand count must equal affine map dimension count");
198 }
199
200 return success();
201}
202
203//===------------------------------------------------------------------===//
204// UnifiableCastOp
205//===------------------------------------------------------------------===//
206
207LogicalResult UnifiableCastOp::verify() {
208 if (!typesUnify(getInput().getType(), getResult().getType())) {
209 return emitOpError() << "input type " << getInput().getType() << " and output type "
210 << getResult().getType() << " are not unifiable";
211 }
212
213 return success();
214}
215
216} // namespace llzk::polymorphic
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
::mlir::IntegerAttr getNumDimsAttr()
Definition Ops.h.inc:238
::llvm::LogicalResult verify()
Definition Ops.cpp:183
::mlir::AffineMap getMap()
Definition Ops.cpp.inc:358
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:464
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:152
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:108
::mlir::Region & getInitializerRegion()
Definition Ops.h.inc:660
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:71
::mlir::Type getType()
Returns the type of the poly.yield op in the initializer region.
Definition Ops.cpp:138
::std::optional<::mlir::Type > getTypeOpt()
Definition Ops.cpp:146
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:636
OpT getConstNamed(::mlir::StringRef find)
Return the op of type OpT with the given name within the body region if it exists,...
Definition Ops.h.inc:969
bool hasConstNamed(::mlir::StringRef find)
Return true if there is an op of type OpT with the given name within the body region.
Definition Ops.h.inc:949
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:848
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:63
::llvm::LogicalResult verify()
Definition Ops.cpp:207
::mlir::TypedValue<::mlir::Type > getInput()
Definition Ops.h.inc:1327
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:1346
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1450
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:1464
bool isInTemplate(Operation *op)
Definition Ops.cpp:32
FailureOr< TemplateOp > verifyInTemplate(Operation *op)
Definition Ops.cpp:34
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:69
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)