LLZK 0.1.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Struct op implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// SPDX-License-Identifier: Apache-2.0
7//
8//===----------------------------------------------------------------------===//
9
16#include "llzk/Util/Constants.h"
19
20#include <mlir/IR/IRMapping.h>
21#include <mlir/IR/OpImplementation.h>
22
23#include <llvm/ADT/MapVector.h>
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/StringRef.h>
26#include <llvm/ADT/StringSet.h>
27
28#include <optional>
29
30// TableGen'd implementation files
32
33// TableGen'd implementation files
34#define GET_OP_CLASSES
36
37using namespace mlir;
38using namespace llzk::felt;
39using namespace llzk::array;
40using namespace llzk::felt;
41using namespace llzk::function;
42
43namespace llzk::component {
44
45bool isInStruct(Operation *op) { return succeeded(getParentOfType<StructDefOp>(op)); }
46
47FailureOr<StructDefOp> verifyInStruct(Operation *op) {
48 FailureOr<StructDefOp> res = getParentOfType<StructDefOp>(op);
49 if (failed(res)) {
50 return op->emitOpError() << "only valid within a '" << StructDefOp::getOperationName()
51 << "' ancestor";
52 }
53 return res;
54}
55
56bool isInStructFunctionNamed(Operation *op, char const *funcName) {
57 FailureOr<FuncDefOp> parentFuncOpt = getParentOfType<FuncDefOp>(op);
58 if (succeeded(parentFuncOpt)) {
59 FuncDefOp parentFunc = parentFuncOpt.value();
60 if (isInStruct(parentFunc.getOperation())) {
61 if (parentFunc.getSymName().compare(funcName) == 0) {
62 return true;
63 }
64 }
65 }
66 return false;
67}
68
69// Again, only valid/implemented for StructDefOp
70template <> LogicalResult SetFuncAllowAttrs<StructDefOp>::verifyTrait(Operation *structOp) {
71 assert(llvm::isa<StructDefOp>(structOp));
72 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
73 if (!bodyRegion.empty()) {
74 bodyRegion.front().walk([](FuncDefOp funcDef) {
75 if (funcDef.nameIsConstrain()) {
76 funcDef.setAllowConstraintAttr();
77 funcDef.setAllowWitnessAttr(false);
78 } else if (funcDef.nameIsCompute()) {
79 funcDef.setAllowConstraintAttr(false);
80 funcDef.setAllowWitnessAttr();
81 } else if (funcDef.nameIsProduct()) {
82 funcDef.setAllowConstraintAttr();
83 funcDef.setAllowWitnessAttr();
84 }
85 });
86 }
87 return success();
88}
89
90InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect) {
91 std::string prefix = std::string();
92 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
93 prefix += "\"@";
94 prefix += symbol.getName();
95 prefix += "\" ";
96 }
97 return origin->emitOpError().append(
98 prefix, "must use type of its ancestor '", StructDefOp::getOperationName(), "' \"",
99 expected.getHeaderString(), "\" as ", aspect, " type"
100 );
101}
102
103static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
104 return origin->emitError() << '\'' << StructDefOp::getOperationName() << "' op "
105 << "must define either only a \"@" << FUNC_NAME_PRODUCT
106 << "\" function, or both \"@" << FUNC_NAME_COMPUTE << "\" and \"@"
107 << FUNC_NAME_CONSTRAIN << "\" functions; ";
108}
109
112LogicalResult checkSelfType(
113 SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin,
114 const char *aspect
115) {
116 if (StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
117 auto actualStructOpt =
118 lookupTopLevelSymbol<StructDefOp>(tables, actualStructType.getNameRef(), origin);
119 if (failed(actualStructOpt)) {
120 return origin->emitError().append(
121 "could not find '", StructDefOp::getOperationName(), "' named \"",
122 actualStructType.getNameRef(), '"'
123 );
124 }
125 StructDefOp actualStruct = actualStructOpt.value().get();
126 if (actualStruct != expectedStruct) {
127 return genCompareErr(expectedStruct, origin, aspect)
128 .attachNote(actualStruct.getLoc())
129 .append("uses this type instead");
130 }
131 // Check for an EXACT match in the parameter list since it must reference the "self" type.
132 if (expectedStruct.getConstParamsAttr() != actualStructType.getParams()) {
133 // To make error messages more consistent and meaningful, if the parameters don't match
134 // because the actual type uses symbols that are not defined, generate an error about the
135 // undefined symbol(s).
136 if (ArrayAttr tyParams = actualStructType.getParams()) {
137 if (failed(verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
138 return failure();
139 }
140 }
141 // Otherwise, generate an error stating the parent struct type must be used.
142 return genCompareErr(expectedStruct, origin, aspect)
143 .attachNote(actualStruct.getLoc())
144 .append("should be type of this '", StructDefOp::getOperationName(), '\'');
145 }
146 } else {
147 return genCompareErr(expectedStruct, origin, aspect);
148 }
149 return success();
150}
151
152//===------------------------------------------------------------------===//
153// StructDefOp
154//===------------------------------------------------------------------===//
155
156StructType StructDefOp::getType(std::optional<ArrayAttr> constParams) {
157 auto pathRes = getPathFromRoot(*this);
158 assert(succeeded(pathRes)); // consistent with StructType::get() with invalid args
159 return StructType::get(pathRes.value(), constParams.value_or(getConstParamsAttr()));
160}
161
163 return buildStringViaCallback([this](llvm::raw_ostream &ss) {
164 FailureOr<SymbolRefAttr> pathToExpected = getPathFromRoot(*this);
165 if (succeeded(pathToExpected)) {
166 ss << pathToExpected.value();
167 } else {
168 // When there is a failure trying to get the resolved name of the struct,
169 // just print its symbol name directly.
170 ss << '@' << this->getSymName();
171 }
172 if (auto attr = this->getConstParamsAttr()) {
173 ss << '<' << attr << '>';
174 }
175 });
176}
177
178bool StructDefOp::hasParamNamed(StringAttr find) {
179 if (ArrayAttr params = this->getConstParamsAttr()) {
180 for (Attribute attr : params) {
181 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
182 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
183 return true;
184 }
185 }
186 }
187 return false;
188}
189
191 auto res = getPathFromRoot(*this);
192 assert(succeeded(res));
193 return res.value();
194}
195
196LogicalResult StructDefOp::verifySymbolUses(SymbolTableCollection &tables) {
197 if (ArrayAttr params = this->getConstParamsAttr()) {
198 // Ensure struct parameter names are unique
199 llvm::StringSet<> uniqNames;
200 for (Attribute attr : params) {
201 assert(llvm::isa<FlatSymbolRefAttr>(attr)); // per ODS
202 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
203 if (!uniqNames.insert(name).second) {
204 return this->emitOpError().append("has more than one parameter named \"@", name, '"');
205 }
206 }
207 // Ensure they do not conflict with existing symbols
208 for (Attribute attr : params) {
209 auto res = lookupTopLevelSymbol(tables, llvm::cast<FlatSymbolRefAttr>(attr), *this, false);
210 if (succeeded(res)) {
211 return this->emitOpError()
212 .append("parameter name \"@")
213 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
214 .append("\" conflicts with an existing symbol")
215 .attachNote(res->get()->getLoc())
216 .append("symbol already defined here");
217 }
218 }
219 }
220 return success();
221}
222
223namespace {
224
225inline LogicalResult
226checkMainFuncParamType(Type pType, FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
227 if (llvm::isa<FeltType>(pType)) {
228 return success();
229 } else if (auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
230 if (llvm::isa<FeltType>(arrayParamTy.getElementType())) {
231 return success();
232 }
233 }
234
235 std::string message = buildStringViaCallback([&inFunc, appendSelfType](llvm::raw_ostream &ss) {
236 ss << "main entry component \"@" << inFunc.getSymName()
237 << "\" function parameters must be one of: {";
238 if (appendSelfType.has_value()) {
239 ss << appendSelfType.value() << ", ";
240 }
241 ss << '!' << FeltType::name << ", ";
242 ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
243 });
244 return inFunc.emitError(message);
245}
246
247inline LogicalResult verifyStructComputeConstrain(
248 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
249) {
250 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly.
251 assert(constrainFunc.hasAllowConstraintAttr());
252 assert(!computeFunc.hasAllowConstraintAttr());
253 assert(!constrainFunc.hasAllowWitnessAttr());
254 assert(computeFunc.hasAllowWitnessAttr());
255
256 // Verify parameter types are valid. Skip the first parameter of the "constrain" function; it is
257 // already checked via verifyFuncTypeConstrain() in Function/IR/Ops.cpp.
258 ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
259 ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
260 if (structDef.isMainComponent()) {
261 // Verify the input parameter types are legal. The error message is explicit about what types
262 // are allowed so there is no benefit to report multiple errors if more than one parameter in
263 // the referenced function has an illegal type.
264 for (Type t : computeParams) {
265 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
266 return failure(); // checkMainFuncParamType() already emits a sufficient error message
267 }
268 }
269 auto appendSelf = std::make_optional(structDef.getType());
270 for (Type t : constrainParams) {
271 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
272 return failure(); // checkMainFuncParamType() already emits a sufficient error message
273 }
274 }
275 }
276
277 if (!typeListsUnify(computeParams, constrainParams)) {
278 return constrainFunc.emitError()
279 .append(
280 "expected \"@", FUNC_NAME_CONSTRAIN,
281 "\" function argument types (sans the first one) to match \"@", FUNC_NAME_COMPUTE,
282 "\" function argument types"
283 )
284 .attachNote(computeFunc.getLoc())
285 .append("\"@", FUNC_NAME_COMPUTE, "\" function defined here");
286 }
287
288 return success();
289}
290
291inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp productFunc) {
292 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly
293 assert(productFunc.hasAllowConstraintAttr());
294 assert(productFunc.hasAllowWitnessAttr());
295
296 // Verify parameter types are valid
297 if (structDef.isMainComponent()) {
298 ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
299 // Verify the input parameter types are legal. The error message is explicit about what types
300 // are allowed so there is no benefit to report multiple errors if more than one parameter in
301 // the referenced function has an illegal type.
302 for (Type t : productParams) {
303 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
304 return failure(); // checkMainFuncParamType() already emits a sufficient error message
305 }
306 }
307 }
308
309 return success();
310}
311
312} // namespace
313
315 std::optional<FuncDefOp> foundCompute = std::nullopt;
316 std::optional<FuncDefOp> foundConstrain = std::nullopt;
317 std::optional<FuncDefOp> foundProduct = std::nullopt;
318 {
319 // Verify the following:
320 // 1. The only ops within the body are member and function definitions
321 // 2. The only functions defined in the struct are `@compute()` and `@constrain()`, or
322 // `@product()`
323 OwningEmitErrorFn emitError = getEmitOpErrFn(this);
324 Region &bodyRegion = getBodyRegion();
325 if (!bodyRegion.empty()) {
326 for (Operation &op : bodyRegion.front()) {
327 if (!llvm::isa<MemberDefOp>(op)) {
328 if (FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
329 if (funcDef.nameIsCompute()) {
330 if (foundProduct) {
331 return structFuncDefError(funcDef.getOperation())
332 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
333 << "\" functions";
334 }
335 if (foundCompute) {
336 return structFuncDefError(funcDef.getOperation())
337 << "found multiple \"@" << FUNC_NAME_COMPUTE << "\" functions";
338 }
339 foundCompute = std::make_optional(funcDef);
340 } else if (funcDef.nameIsConstrain()) {
341 if (foundProduct) {
342 return structFuncDefError(funcDef.getOperation())
343 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@"
344 << FUNC_NAME_PRODUCT << "\" functions";
345 }
346 if (foundConstrain) {
347 return structFuncDefError(funcDef.getOperation())
348 << "found multiple \"@" << FUNC_NAME_CONSTRAIN << "\" functions";
349 }
350 foundConstrain = std::make_optional(funcDef);
351 } else if (funcDef.nameIsProduct()) {
352 if (foundCompute) {
353 return structFuncDefError(funcDef.getOperation())
354 << "found both \"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_PRODUCT
355 << "\" functions";
356 }
357 if (foundConstrain) {
358 return structFuncDefError(funcDef.getOperation())
359 << "found both \"@" << FUNC_NAME_CONSTRAIN << "\" and \"@"
360 << FUNC_NAME_PRODUCT << "\" functions";
361 }
362 if (foundProduct) {
363 return structFuncDefError(funcDef.getOperation())
364 << "found multiple \"@" << FUNC_NAME_PRODUCT << "\" functions";
365 }
366 foundProduct = std::make_optional(funcDef);
367 } else {
368 // Must do a little more than a simple call to '?.emitOpError()' to
369 // tag the error with correct location and correct op name.
370 return structFuncDefError(funcDef.getOperation())
371 << "found \"@" << funcDef.getSymName() << '"';
372 }
373 } else {
374 return op.emitOpError()
375 << "invalid operation in '" << StructDefOp::getOperationName() << "'; only '"
376 << MemberDefOp::getOperationName() << '\'' << " and '"
377 << FuncDefOp::getOperationName() << "' operations are permitted";
378 }
379 }
380 }
381 }
382
383 if (!foundCompute.has_value() && foundConstrain.has_value()) {
384 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_CONSTRAIN
385 << "\", missing \"@" << FUNC_NAME_COMPUTE << "\"";
386 }
387 if (!foundConstrain.has_value() && foundCompute.has_value()) {
388 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_COMPUTE
389 << "\", missing \"@" << FUNC_NAME_CONSTRAIN << "\"";
390 }
391 }
392
393 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
394 return structFuncDefError(getOperation())
395 << "could not find \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
396 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\"";
397 }
398
399 if (foundCompute && foundConstrain) {
400 return verifyStructComputeConstrain(*this, *foundCompute, *foundConstrain);
401 }
402 return verifyStructProduct(*this, *foundProduct);
403}
404
406 for (Operation &op : *getBody()) {
407 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
408 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
409 return memberDef;
410 }
411 }
412 }
413 return nullptr;
414}
415
416std::vector<MemberDefOp> StructDefOp::getMemberDefs() {
417 std::vector<MemberDefOp> res;
418 for (Operation &op : *getBody()) {
419 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
420 res.push_back(memberDef);
421 }
422 }
423 return res;
424}
425
427 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_COMPUTE));
428}
429
431 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_CONSTRAIN));
432}
433
435 if (auto *computeFunc = lookupSymbol(FUNC_NAME_COMPUTE)) {
436 return llvm::dyn_cast<FuncDefOp>(computeFunc);
437 }
438 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
439}
440
442 if (auto *constrainFunc = lookupSymbol(FUNC_NAME_CONSTRAIN)) {
443 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
444 }
445 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
446}
447
449 FailureOr<StructType> mainTypeOpt = getMainInstanceType(this->getOperation());
450 if (succeeded(mainTypeOpt)) {
451 if (StructType mainType = mainTypeOpt.value()) {
452 return structTypesUnify(mainType, this->getType());
453 }
454 }
455 return false;
456}
457
458//===------------------------------------------------------------------===//
459// MemberDefOp
460//===------------------------------------------------------------------===//
461
463 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
464 bool isSignal, bool isColumn
465) {
466 Properties &props = odsState.getOrAddProperties<Properties>();
467 props.setSymName(sym_name);
468 props.setType(type);
469 if (isColumn) {
470 props.column = odsBuilder.getUnitAttr();
471 }
472 if (isSignal) {
473 props.signal = odsBuilder.getUnitAttr();
474 }
475}
476
478 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type, bool isSignal,
479 bool isColumn
480) {
481 build(
482 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
483 isColumn
484 );
485}
486
488 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
489 ArrayRef<NamedAttribute> attributes, bool isSignal, bool isColumn
490) {
491 assert(operands.size() == 0u && "mismatched number of parameters");
492 odsState.addOperands(operands);
493 odsState.addAttributes(attributes);
494 assert(resultTypes.size() == 0u && "mismatched number of return types");
495 odsState.addTypes(resultTypes);
496 if (isColumn) {
497 odsState.getOrAddProperties<Properties>().column = odsBuilder.getUnitAttr();
498 }
499 if (isSignal) {
500 odsState.getOrAddProperties<Properties>().signal = odsBuilder.getUnitAttr();
501 }
502}
503
504void MemberDefOp::setPublicAttr(bool newValue) {
505 if (newValue) {
506 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
507 } else {
508 getOperation()->removeAttr(PublicAttr::name);
509 }
510}
511
512static LogicalResult
513verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
514 if (StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
515 // Special case for StructType verifies that the member type can resolve and that it is NOT the
516 // parent struct (i.e., struct members cannot create circular references).
517 auto memberTypeRes = verifyStructTypeResolution(tables, memberStructType, origin);
518 if (failed(memberTypeRes)) {
519 return failure(); // above already emits a sufficient error message
520 }
521 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(origin);
522 assert(succeeded(parentRes) && "MemberDefOp parent is always StructDefOp"); // per ODS def
523 if (memberTypeRes.value() == parentRes.value()) {
524 return origin->emitOpError()
525 .append("type is circular")
526 .attachNote(parentRes.value().getLoc())
527 .append("references parent component defined here");
528 }
529 return success();
530 } else {
531 return verifyTypeResolution(tables, origin, memberType);
532 }
533}
534
535LogicalResult MemberDefOp::verifySymbolUses(SymbolTableCollection &tables) {
536 Type memberType = this->getType();
537 if (failed(verifyMemberDefTypeImpl(memberType, tables, *this))) {
538 return failure();
539 }
540
541 if (!getColumn()) {
542 return success();
543 }
544 // If the member is marked as a column only a small subset of types are allowed.
545 if (!isValidColumnType(getType(), tables, *this)) {
546 return emitOpError() << "marked as column can only contain felts, arrays of column types, or "
547 "structs with columns, but has type "
548 << getType();
549 }
550 return success();
551}
552
553//===------------------------------------------------------------------===//
554// MemberRefOp implementations
555//===------------------------------------------------------------------===//
556namespace {
557
// Resolve the MemberDefOp named by `refOp` within the struct that defines `tyStruct`.
// Returns failure when the struct definition cannot be obtained, or when the member lookup
// fails (in which case an error naming the member and the struct is emitted on `refOp`).
// On success the struct lookup's namespace is prepended to the returned lookup result.
558FailureOr<SymbolLookupResult<MemberDefOp>>
559getMemberDefOpImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables, StructType tyStruct) {
560 Operation *op = refOp.getOperation();
561 auto structDefRes = tyStruct.getDefinition(tables, op);
562 if (failed(structDefRes)) {
563 return failure(); // getDefinition() already emits a sufficient error message
564 }
565 // Copy namespace because we will need it later.
566 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
// NOTE(review): listing line 567 is missing here (numbering jumps from 566 to 568). It presumably
// declares `res` and opens the lookup call whose arguments follow -- restore it from the
// original file before building.
568 tables, SymbolRefAttr::get(refOp->getContext(), refOp.getMemberName()),
569 std::move(*structDefRes), op
570 );
571 if (failed(res)) {
572 return refOp->emitError() << "could not find '" << MemberDefOp::getOperationName()
573 << "' named \"@" << refOp.getMemberName() << "\" in \""
574 << tyStruct.getNameRef() << '"';
575 }
576 // Prepend the namespace of the struct lookup since the type of the member is meant to be resolved
577 // within that scope.
578 res->prependNamespace(structDefOpNs);
579 return std::move(res.value());
580}
581
582static FailureOr<SymbolLookupResult<MemberDefOp>>
583findMember(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
584 // Ensure the base component/struct type reference can be resolved.
585 StructType tyStruct = refOp.getStructType();
586 if (failed(tyStruct.verifySymbolRef(tables, refOp.getOperation()))) {
587 return failure();
588 }
589 // Ensure the member name can be resolved in that struct.
590 return getMemberDefOpImpl(refOp, tables, tyStruct);
591}
592
593static LogicalResult verifySymbolUsesImpl(
594 MemberRefOpInterface refOp, SymbolTableCollection &tables,
595 SymbolLookupResult<MemberDefOp> &member
596) {
597 // Ensure the type of the referenced member declaration matches the type used in this op.
598 Type actualType = refOp.getVal().getType();
599 Type memberType = member.get().getType();
600 if (!typesUnify(actualType, memberType, member.getNamespace())) {
601 return refOp->emitOpError() << "has wrong type; expected " << memberType << ", got "
602 << actualType;
603 }
604 // Ensure any SymbolRef used in the type are valid
605 return verifyTypeResolution(tables, refOp.getOperation(), actualType);
606}
607
608LogicalResult verifySymbolUsesImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
609 // Ensure the member name can be resolved in that struct.
610 auto member = findMember(refOp, tables);
611 if (failed(member)) {
612 return member; // getMemberDefOp() already emits a sufficient error message
613 }
614 return verifySymbolUsesImpl(refOp, tables, *member);
615}
616
617} // namespace
618
619FailureOr<SymbolLookupResult<MemberDefOp>>
620MemberRefOpInterface::getMemberDefOp(SymbolTableCollection &tables) {
621 return getMemberDefOpImpl(*this, tables, getStructType());
622}
623
624LogicalResult MemberReadOp::verifySymbolUses(SymbolTableCollection &tables) {
625 auto member = findMember(*this, tables);
626 if (failed(member)) {
627 return failure();
628 }
629 if (failed(verifySymbolUsesImpl(*this, tables, *member))) {
630 return failure();
631 }
632 // If the member is not a column and an offset was specified then fail to validate
633 if (!member->get().getColumn() && getTableOffset().has_value()) {
634 return emitOpError("cannot read with table offset from a member that is not a column")
635 .attachNote(member->get().getLoc())
636 .append("member defined here");
637 }
638 // If the member is private and this read is outside the struct, then fail to validate.
639 // The current op may be inside a struct or a free function, but the
640 // member op (the member definition) is always inside a struct.
641 FailureOr<StructDefOp> parentRes = getParentOfType<StructDefOp>(*this);
642 FailureOr<StructDefOp> memberParentRes = verifyInStruct(member->get());
643 if (failed(memberParentRes)) {
644 return failure(); // verifyInStruct() already emits a sufficient error message
645 }
646 StructDefOp memberParentStruct = memberParentRes.value();
647 if (!member->get().hasPublicAttr() &&
648 (failed(parentRes) || parentRes.value() != memberParentStruct)) {
649 return emitOpError()
650 .append(
651 "cannot read from private member of struct \"", memberParentStruct.getHeaderString(),
652 "\""
653 )
654 .attachNote(member->get().getLoc())
655 .append("member defined here");
656 }
657 return success();
658}
659
660LogicalResult MemberWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
661 // Ensure the write op only targets members in the current struct.
662 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
663 if (failed(getParentRes)) {
664 return failure(); // verifyInStruct() already emits a sufficient error message
665 }
666 if (failed(checkSelfType(tables, *getParentRes, getComponent().getType(), *this, "base value"))) {
667 return failure(); // checkSelfType() already emits a sufficient error message
668 }
669 // Perform the standard member ref checks.
670 return verifySymbolUsesImpl(*this, tables);
671}
672
673//===------------------------------------------------------------------===//
674// MemberReadOp
675//===------------------------------------------------------------------===//
676
// Builder taking the result type, the component value, and the member name: sets the member
// name property from `member`, then adds the result type and the component operand.
// NOTE(review): listing line 677 (the definition header, presumably `void MemberReadOp::build(`)
// and line 684 (one statement before the closing brace) are missing from this listing --
// restore both from the original file.
678 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member
679) {
680 Properties &props = state.getOrAddProperties<Properties>();
681 props.setMemberName(FlatSymbolRefAttr::get(member));
682 state.addTypes(resultType);
683 state.addOperands(component);
685}
686
// Builder that additionally takes a table offset attribute (`dist`), map operands, and an
// optional dimension count. Adds the component operand and result type, then sets the member
// name and table offset properties.
// NOTE(review): listing lines 687 (definition header), 696 and 700 (the callee of each branch)
// are missing. The calls are presumably to the buildInstantiationAttrs*NoSegments helpers --
// confirm against the original file.
688 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member,
689 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
690) {
691 // '!mapOperands.empty()' implies 'numDims.has_value()'
692 assert(mapOperands.empty() || numDims.has_value());
693 state.addOperands(component);
694 state.addTypes(resultType);
695 if (numDims.has_value()) {
697 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
698 );
699 } else {
701 }
702 Properties &props = state.getOrAddProperties<Properties>();
703 props.setMemberName(FlatSymbolRefAttr::get(member));
704 props.setTableOffset(dist);
705}
706
// Generic builder: forwards the result types, operands, and attributes to the operation state
// unchanged. The builder argument is unused.
// NOTE(review): listing line 707 (the definition header) is missing; given the surrounding
// section this is presumably another `MemberReadOp::build` overload -- confirm.
708 OpBuilder & /*odsBuilder*/, OperationState &odsState, TypeRange resultTypes,
709 ValueRange operands, ArrayRef<NamedAttribute> attrs
710) {
711 odsState.addTypes(resultTypes);
712 odsState.addOperands(operands);
713 odsState.addAttributes(attrs);
714}
715
// Verifier for MemberReadOp: collects the table offset into `mapAttrs` when it is present and
// is an AffineMapAttr, then passes the map operands, the per-map dimension counts, and
// `mapAttrs` to a verification helper.
716LogicalResult MemberReadOp::verify() {
717 SmallVector<AffineMapAttr, 1> mapAttrs;
// A missing offset is turned into a null Attribute so dyn_cast_if_present yields null for it.
718 if (AffineMapAttr map =
719 llvm::dyn_cast_if_present<AffineMapAttr>(getTableOffset().value_or(nullptr))) {
720 mapAttrs.push_back(map);
721 }
// NOTE(review): listing line 722 is missing here (numbering jumps from 721 to 723). It presumably
// holds the `return <helper>(` that receives the arguments below -- restore from the original.
723 getMapOperands(), getNumDimsPerMap(), mapAttrs, *this
724 );
725}
726
727//===------------------------------------------------------------------===//
728// CreateStructOp
729//===------------------------------------------------------------------===//
730
731void CreateStructOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
732 setNameFn(getResult(), "self");
733}
734
735LogicalResult CreateStructOp::verifySymbolUses(SymbolTableCollection &tables) {
736 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
737 if (failed(getParentRes)) {
738 return failure(); // verifyInStruct() already emits a sufficient error message
739 }
740 if (failed(checkSelfType(tables, *getParentRes, this->getType(), *this, "result"))) {
741 return failure();
742 }
743 return success();
744}
745
746} // namespace llzk::component
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:735
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:731
::mlir::TypedValue<::llzk::component::StructType > getResult()
Definition Ops.h.inc:143
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:504
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:535
FoldAdaptor::Properties Properties
Definition Ops.h.inc:315
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:991
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:691
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:624
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:996
::llvm::LogicalResult verify()
Definition Ops.cpp:716
FoldAdaptor::Properties Properties
Definition Ops.h.inc:638
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:620
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:951
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:660
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1213
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:196
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1189
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:441
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1676
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
Definition Ops.cpp:190
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
Definition Ops.cpp:416
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
Definition Ops.cpp:405
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:430
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:434
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:314
::mlir::ArrayAttr getConstParamsAttr()
Definition Ops.h.inc:1231
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:426
bool isMainComponent()
Return true iff this StructDefOp is the main struct. See llzk::MAIN_ATTR_NAME.
Definition Ops.cpp:448
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
Definition Ops.cpp:162
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:46
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
Definition Types.cpp:77
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:26
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:208
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:952
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:778
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:729
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:786
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:947
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:782
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:583
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:200
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:721
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
Definition Ops.cpp:45
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
Definition Ops.cpp:90
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:112
FailureOr< StructDefOp > verifyInStruct(Operation *op)
Definition Ops.cpp:47
bool isInStructFunctionNamed(Operation *op, char const *funcName)
Definition Ops.cpp:56
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp. constraint generation) functions within a component.
Definition Constants.h:16
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:223
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:45
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
Definition Ops.h.inc:200
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)
Definition Ops.h.inc:496