LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Operation implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
11
20
21// Include TableGen'd declarations
23
24// TableGen'd implementation files
25#define GET_OP_CLASSES
27
28using namespace mlir;
29using namespace llzk::component;
30using namespace llzk::verif;
31
33
34bool isInTemplate(Operation *op) { return getParentOfType<TemplateOp>(op); }
35
36FailureOr<TemplateOp> verifyInTemplate(Operation *op) {
38 return res;
39 }
40 return op->emitOpError() << "only valid within a '" << TemplateOp::getOperationName()
41 << "' ancestor";
42}
43
44//===------------------------------------------------------------------===//
45// TemplateParamOp
46//===------------------------------------------------------------------===//
47
48namespace {
49
50LogicalResult checkForNameConflict(SymbolTableCollection &tables, SymbolOpInterface op) {
51 // Ensure parameter name does not conflict with an existing top-level symbol
52 // because that would cause an ambiguity in symbol resolution within structs.
53 auto res = lookupTopLevelSymbol(tables, FlatSymbolRefAttr::get(op.getNameAttr()), op, false);
54 if (succeeded(res)) {
55 return op.emitOpError()
56 .append("name conflicts with an existing symbol")
57 .attachNote(res->get()->getLoc())
58 .append("symbol already defined here");
59 }
60 return success();
61}
62
63} // namespace
64
65LogicalResult TemplateParamOp::verifySymbolUses(SymbolTableCollection &tables) {
66 return checkForNameConflict(tables, *this);
67}
68
69//===------------------------------------------------------------------===//
70// TemplateExprOp
71//===------------------------------------------------------------------===//
72
73LogicalResult TemplateExprOp::verifySymbolUses(SymbolTableCollection &tables) {
74 if (failed(checkForNameConflict(tables, *this))) {
75 return failure(); // checkForNameConflict() already emits a sufficient error message
76 }
77 // Ensure no symbol used within the initializer region is defined via a `TemplateExprOp`.
78 // This prevents cyclic definitions of `TemplateExprOp`. Searches all symbol uses within
79 // this op and also within any nested symbol tables.
80 Operation *thisOp = this->getOperation();
81 TemplateOp parentTemplate = getParentOfType<TemplateOp>(thisOp);
82 assert(parentTemplate && "per ODS");
83 LogicalResult errorState = success();
84 auto checkUses = [this, &parentTemplate, &errorState](Operation *symTableOp, bool) {
85 if (auto uses = llzk::getSymbolUses(symTableOp)) {
86 for (SymbolTable::SymbolUse use : uses.value()) {
87 // Only need to check flat refs since `TemplateExprOp` refs must be flat
88 auto usedSym = llvm::dyn_cast<FlatSymbolRefAttr>(use.getSymbolRef());
89 if (usedSym && parentTemplate.hasConstNamed<TemplateExprOp>(usedSym)) {
90 InFlightDiagnostic diag = this->emitOpError().append(
91 "initialization cannot use a symbol defined by another `",
92 TemplateExprOp::getOperationName(), "` within this template"
93 );
94 diag.attachNote(use.getUser()->getLoc()).append("symbol ", usedSym, " used here");
95 auto def = parentTemplate.getConstNamed<TemplateExprOp>(usedSym);
96 diag.attachNote(def.getLoc()).append("defined here");
97 errorState = diag; // transformation to LogicalResult reports the error
98 return;
99 }
100 }
101 }
102 };
103 checkUses(thisOp, true);
104 if (succeeded(errorState)) {
105 SymbolTable::walkSymbolTables(thisOp, /*allSymUsesVisible=*/true, checkUses);
106 }
107 return errorState;
108}
109
111 Region &region = getInitializerRegion();
112 if (!region.hasOneBlock()) {
113 return emitOpError("expected initializer region with a single block");
114 }
115 Block &block = region.back();
116 if (!llvm::isa<YieldOp>(block.getTerminator())) {
117 return emitOpError("expected initializer region to end with a '")
118 << YieldOp::getOperationName() << '\'';
119 }
120 // Check for ops with side-effects that are not allowed within `poly.expr`.
121 Operation *illegalOp = nullptr;
122 auto walkRes = block.walk([&illegalOp](Operation *p) {
123 // Note: If side-effect traits are added to ops in the future, this check should
124 // be updated to check for those traits instead of specific op types.
125 if (llvm::isa<global::GlobalRefOpInterface, function::CallOp>(p)) {
126 illegalOp = p;
127 return WalkResult::interrupt();
128 }
129 return WalkResult::advance();
130 });
131 if (walkRes.wasInterrupted()) {
132 assert(illegalOp); // was set in the walk above
133 return illegalOp->emitOpError().append(
134 "is not allowed within a `", TemplateExprOp::getOperationName(), "` initializer"
135 );
136 }
137 return success();
138}
139
141 Region &region = getInitializerRegion();
142 assert(region.hasOneBlock() && "per `verifyRegions()`");
143 YieldOp yieldOp = llvm::dyn_cast<YieldOp>(region.back().getTerminator());
144 assert(yieldOp && "per `verifyRegions()`");
145 return yieldOp.getVal().getType();
146}
147
148std::optional<Type> TemplateExprOp::getTypeOpt() { return getType(); }
149
150//===------------------------------------------------------------------===//
151// ConstReadOp
152//===------------------------------------------------------------------===//
153
154LogicalResult ConstReadOp::verifySymbolUses(SymbolTableCollection &tables) {
155 FailureOr<TemplateOp> getParentRes = getConstResolutionTemplate(tables, *this);
156 if (failed(getParentRes)) {
157 return failure(); // getConstResolutionTemplate() failure cases emit a sufficient error message
158 }
159 if (!*getParentRes) {
160 return this->emitOpError() << "only valid within a '" << TemplateOp::getOperationName()
161 << "' ancestor or '" << ContractOp::getOperationName()
162 << "' that targets an operation with a '"
163 << TemplateOp::getOperationName() << "' ancestor";
164 }
165 // Ensure the named constant is a parameter of the parent struct
166 FlatSymbolRefAttr name = this->getConstNameAttr();
167 auto constParam = getParentRes->getConstNamed<TemplateSymbolBindingOpInterface>(name);
168 if (!constParam) {
169 return this->emitOpError()
170 .append("references unknown symbol \"", name, '"')
171 .attachNote(getParentRes->getLoc())
172 .append("must reference a param or expr of this template");
173 }
174 // Ensure the type of the constant read matches the type of the referenced parameter (if any).
175 if (std::optional<Type> paramType = constParam.getTypeOpt()) {
176 if (this->getType() != *paramType) {
177 return this->emitOpError().append(
178 "type ", this->getType(), " does not match constant param type ", *paramType
179 );
180 }
181 }
182
183 // Ensure any SymbolRef used in the type are valid
184 return verifyTypeResolution(tables, *this, getType());
185}
186
187//===------------------------------------------------------------------===//
188// ApplyMapOp
189//===------------------------------------------------------------------===//
190
191LogicalResult ApplyMapOp::verify() {
192 // Check input and output dimensions match.
193 AffineMap map = getMap();
194
195 // Verify that the map only produces one result.
196 if (map.getNumResults() != 1) {
197 return emitOpError("must produce exactly one value");
198 }
199
200 // Verify that operand count matches affine map dimension and symbol count.
201 unsigned mapDims = map.getNumDims();
202 if (getNumOperands() != mapDims + map.getNumSymbols()) {
203 return emitOpError("operand count must equal affine map dimension+symbol count");
204 } else if (mapDims != getNumDimsAttr().getInt()) {
205 return emitOpError("dimension operand count must equal affine map dimension count");
206 }
207
208 return success();
209}
210
211//===------------------------------------------------------------------===//
212// UnifiableCastOp
213//===------------------------------------------------------------------===//
214
215LogicalResult UnifiableCastOp::verify() {
216 if (!typesUnify(getInput().getType(), getResult().getType())) {
217 return emitOpError() << "input type " << getInput().getType() << " and output type "
218 << getResult().getType() << " are not unifiable";
219 }
220
221 return success();
222}
223
224} // namespace llzk::polymorphic
Apache License January AND DISTRIBUTION Definitions License shall mean the terms and conditions for use
Definition LICENSE.txt:9
::mlir::IntegerAttr getNumDimsAttr()
Definition Ops.h.inc:238
::llvm::LogicalResult verify()
Definition Ops.cpp:191
::mlir::AffineMap getMap()
Definition Ops.cpp.inc:358
::mlir::FlatSymbolRefAttr getConstNameAttr()
Definition Ops.h.inc:464
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:154
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:110
::mlir::Region & getInitializerRegion()
Definition Ops.h.inc:660
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:73
::mlir::Type getType()
Returns the type of the poly.yield op in the initializer region.
Definition Ops.cpp:140
::std::optional<::mlir::Type > getTypeOpt()
Definition Ops.cpp:148
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:636
OpT getConstNamed(::mlir::StringRef find)
Return the op of type OpT with the given name within the body region if it exists,...
Definition Ops.h.inc:969
bool hasConstNamed(::mlir::StringRef find)
Return true if there is an op of type OpT with the given name within the body region.
Definition Ops.h.inc:949
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:848
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:65
::llvm::LogicalResult verify()
Definition Ops.cpp:215
::mlir::TypedValue<::mlir::Type > getInput()
Definition Ops.h.inc:1327
::mlir::TypedValue<::mlir::Type > getResult()
Definition Ops.h.inc:1346
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1450
::mlir::TypedValue<::mlir::Type > getVal()
Definition Ops.h.inc:1464
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:374
bool isInTemplate(Operation *op)
Definition Ops.cpp:34
FailureOr< TemplateOp > verifyInTemplate(Operation *op)
Definition Ops.cpp:36
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
std::optional< mlir::SymbolTable::UseRange > getSymbolUses(mlir::Operation *from)
Get an iterator range for all of the uses, for any symbol, that are nested within the given operation...
FailureOr< TemplateOp > getConstResolutionTemplate(SymbolTableCollection &tables, Operation *origin)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)