LLZK 2.0.0
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Struct op implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
10
12
22#include "llzk/Util/Constants.h"
23#include "llzk/Util/Debug.h"
26
27#include <mlir/IR/IRMapping.h>
28#include <mlir/IR/OpImplementation.h>
29
30#include <llvm/ADT/MapVector.h>
31#include <llvm/ADT/STLExtras.h>
32#include <llvm/ADT/StringRef.h>
33#include <llvm/ADT/StringSet.h>
34#include <llvm/ADT/TypeSwitch.h>
35
36#include <optional>
37
38// TableGen'd implementation files
40
41// TableGen'd implementation files
42#define GET_OP_CLASSES
44
45using namespace mlir;
46using namespace llzk::felt;
47using namespace llzk::array;
48using namespace llzk::felt;
49using namespace llzk::function;
50using namespace llzk::pod;
51using namespace llzk::polymorphic;
52
53namespace llzk::component {
54
55bool isInStruct(Operation *op) { return getParentOfType<StructDefOp>(op); }
56
57FailureOr<StructDefOp> verifyInStruct(Operation *op) {
59 return res;
60 }
61 return op->emitOpError() << "only valid within a '" << StructDefOp::getOperationName()
62 << "' ancestor";
63}
64
65bool isInStructFunctionNamed(Operation *op, char const *funcName) {
66 if (FuncDefOp parentFunc = getParentOfType<FuncDefOp>(op)) {
67 if (isInStruct(parentFunc.getOperation())) {
68 if (parentFunc.getSymName().compare(funcName) == 0) {
69 return true;
70 }
71 }
72 }
73 return false;
74}
75
76// Again, only valid/implemented for StructDefOp
77template <> LogicalResult SetFuncAllowAttrs<StructDefOp>::verifyTrait(Operation *structOp) {
78 assert(llvm::isa<StructDefOp>(structOp));
79 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
80 if (!bodyRegion.empty()) {
81 bodyRegion.front().walk([](FuncDefOp funcDef) {
82 if (funcDef.nameIsConstrain()) {
83 funcDef.setAllowConstraintAttr();
84 funcDef.setAllowWitnessAttr(false);
85 } else if (funcDef.nameIsCompute()) {
86 funcDef.setAllowConstraintAttr(false);
87 funcDef.setAllowWitnessAttr();
88 } else if (funcDef.nameIsProduct()) {
89 funcDef.setAllowConstraintAttr();
90 funcDef.setAllowWitnessAttr();
91 }
92 });
93 }
94 return success();
95}
96
97InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect) {
98 std::string prefix = std::string();
99 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
100 prefix += "\"@";
101 prefix += symbol.getName();
102 prefix += "\" ";
103 }
104 return origin->emitOpError().append(
105 prefix, "must use type of its ancestor '", StructDefOp::getOperationName(), "' \"",
106 expected.getHeaderString(), "\" as ", aspect, " type"
107 );
108}
109
110static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
111 return origin->emitError() << '\'' << StructDefOp::getOperationName() << "' op "
112 << "must define either only a non-derived \"@" << FUNC_NAME_PRODUCT
113 << "\" function, or both non-derived \"@" << FUNC_NAME_COMPUTE
114 << "\" and \"@" << FUNC_NAME_CONSTRAIN << "\" functions; ";
115}
116
119LogicalResult checkSelfType(
120 SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin,
121 const char *aspect
122) {
123 if (StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
124 auto actualStructOpt =
125 lookupTopLevelSymbol<StructDefOp>(tables, actualStructType.getNameRef(), origin);
126 if (failed(actualStructOpt)) {
127 return origin->emitError().append(
128 "could not find '", StructDefOp::getOperationName(), "' named \"",
129 actualStructType.getNameRef(), '"'
130 );
131 }
132 StructDefOp actualStruct = actualStructOpt.value().get();
133 if (actualStruct != expectedStruct) {
134 return genCompareErr(expectedStruct, origin, aspect)
135 .attachNote(actualStruct.getLoc())
136 .append("uses this type instead");
137 }
138 // Check for an EXACT match in the parameter list since it must reference the "self" type.
139 ArrayAttr actualTypeParamsAttr = actualStructType.getParams(); // may be nullptr
140 ArrayRef<Attribute> actualTypeParams =
141 actualTypeParamsAttr ? actualTypeParamsAttr.getValue() : ArrayRef<Attribute> {};
142 if (ArrayRef(expectedStruct.getTemplateParamOpNames()) != actualTypeParams) {
143 // To make error messages more consistent and meaningful, if the parameters don't match
144 // because the actual type uses symbols that are not defined, generate an error about the
145 // undefined symbol(s).
146 if (failed(verifyParamsOfType(tables, actualTypeParams, actualStructType, origin))) {
147 return failure();
148 }
149 // Otherwise, generate an error stating the parent struct type must be used.
150 return genCompareErr(expectedStruct, origin, aspect)
151 .attachNote(actualStruct.getLoc())
152 .append("should be type of this '", StructDefOp::getOperationName(), '\'');
153 }
154 } else {
155 return genCompareErr(expectedStruct, origin, aspect);
156 }
157 return success();
158}
159
160//===------------------------------------------------------------------===//
161// StructDefOp
162//===------------------------------------------------------------------===//
163
164StructType StructDefOp::getType(std::optional<ArrayAttr> constParams) {
165 auto pathRes = getPathFromRoot(*this);
166 assert(succeeded(pathRes)); // consistent with StructType::get() with invalid args
167 // Use the specified parameters if provided.
168 if (constParams.has_value()) {
169 return StructType::get(pathRes.value(), constParams.value());
170 }
171 // Check if there is an enclosing `TemplateOp` defining parameters, else there are none.
172 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
173 return StructType::get(pathRes.value(), parent.getConstNames<TemplateParamOp>());
174 } else {
175 return StructType::get(pathRes.value());
176 }
177}
178
180 return buildStringViaCallback([this](llvm::raw_ostream &ss) {
181 FailureOr<SymbolRefAttr> pathToExpected = getPathFromRoot(*this);
182 if (succeeded(pathToExpected)) {
183 ss << pathToExpected.value();
184 } else {
185 // When there is a failure trying to get the resolved name of the struct,
186 // just print its symbol name directly.
187 ss << '@' << this->getSymName();
188 }
189 ss << '<' << debug::toStringList(this->getTemplateParamOpNames()) << '>';
190 });
191}
192
194 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
195 return parent.hasConstOps<TemplateSymbolBindingOpInterface>();
196 }
197 return false;
198}
199
200SmallVector<Attribute> StructDefOp::getTemplateParamOpNames() {
201 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
202 return parent.getConstNames<TemplateParamOp>();
203 } else {
204 return SmallVector<Attribute>();
205 }
206}
207
208SmallVector<Attribute> StructDefOp::getTemplateExprOpNames() {
209 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
210 return parent.getConstNames<TemplateExprOp>();
211 } else {
212 return SmallVector<Attribute>();
213 }
214}
215
217 auto res = getPathFromRoot(*this);
218 assert(succeeded(res));
219 return res.value();
220}
221
222namespace {
223
224inline LogicalResult
225checkMainFuncParamType(Type pType, FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
226 if (isValidMainSignalType(pType)) {
227 return success();
228 }
229
230 std::string message = buildStringViaCallback([&inFunc, appendSelfType](llvm::raw_ostream &ss) {
231 ss << "main entry component \"@" << inFunc.getSymName()
232 << "\" function parameters must be one of: {";
233 if (appendSelfType.has_value()) {
234 ss << appendSelfType.value() << ", ";
235 }
236 ss << '!' << FeltType::name << ", ";
237 ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
238 });
239 return inFunc.emitError(message);
240}
241
242inline LogicalResult checkMainFuncOutputSignalType(Type pType, StructDefOp structOp) {
243 if (isValidMainSignalType(pType)) {
244 return success();
245 }
246
247 std::string message = buildStringViaCallback([](llvm::raw_ostream &ss) {
248 ss << "main entry component output signals must be one of: {";
249 ss << '!' << FeltType::name << ", ";
250 ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
251 });
252 return structOp.emitError(message);
253}
254
255inline LogicalResult verifyStructComputeConstrain(
256 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
257) {
258 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly.
259 assert(constrainFunc.hasAllowConstraintAttr());
260 assert(!computeFunc.hasAllowConstraintAttr());
261 assert(!constrainFunc.hasAllowWitnessAttr());
262 assert(computeFunc.hasAllowWitnessAttr());
263
264 // Verify parameter types are valid. Skip the first parameter of the "constrain" function; it is
265 // already checked via verifyFuncTypeConstrain() in Function/IR/Ops.cpp.
266 ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
267 ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
268 if (structDef.isMainComponent()) {
269 // Verify the input parameter types are legal. The error message is explicit about what types
270 // are allowed so there is no benefit to report multiple errors if more than one parameter in
271 // the referenced function has an illegal type.
272 for (Type t : computeParams) {
273 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
274 return failure(); // checkMainFuncParamType() already emits a sufficient error message
275 }
276 }
277 auto appendSelf = std::make_optional(structDef.getType());
278 for (Type t : constrainParams) {
279 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
280 return failure(); // checkMainFuncParamType() already emits a sufficient error message
281 }
282 }
283 }
284
285 if (!typeListsUnify(computeParams, constrainParams)) {
286 return constrainFunc.emitError()
287 .append(
288 "expected \"@", FUNC_NAME_CONSTRAIN,
289 "\" function argument types (sans the first one) to match \"@", FUNC_NAME_COMPUTE,
290 "\" function argument types"
291 )
292 .attachNote(computeFunc.getLoc())
293 .append("\"@", FUNC_NAME_COMPUTE, "\" function defined here");
294 }
295
296 return success();
297}
298
299inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp productFunc) {
300 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly
301 assert(productFunc.hasAllowConstraintAttr());
302 assert(productFunc.hasAllowWitnessAttr());
303
304 // Verify parameter types are valid
305 if (structDef.isMainComponent()) {
306 ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
307 // Verify the input parameter types are legal. The error message is explicit about what types
308 // are allowed so there is no benefit to report multiple errors if more than one parameter in
309 // the referenced function has an illegal type.
310 for (Type t : productParams) {
311 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
312 return failure(); // checkMainFuncParamType() already emits a sufficient error message
313 }
314 }
315 }
316
317 return success();
318}
319
320} // namespace
321
323 std::optional<FuncDefOp> foundCompute = std::nullopt;
324 std::optional<FuncDefOp> foundConstrain = std::nullopt;
325 std::optional<FuncDefOp> foundProduct = std::nullopt;
326 {
327 // Verify the following:
328 // 1. The only ops within the body are member and function definitions
329 // 2. The only functions defined in the struct are `@compute()` and `@constrain()`, or
330 // `@product()`
331 OwningEmitErrorFn emitError = getEmitOpErrFn(this);
332 Region &bodyRegion = getBodyRegion();
333 if (!bodyRegion.empty()) {
334 for (Operation &op : bodyRegion.front()) {
335 auto member = llvm::dyn_cast<MemberDefOp>(op);
336 if (!member) {
337 if (FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
338 if (funcDef.nameIsCompute()) {
339 if (foundCompute) {
340 return structFuncDefError(funcDef.getOperation())
341 << "found multiple \"@" << FUNC_NAME_COMPUTE << "\" functions";
342 }
343 foundCompute = std::make_optional(funcDef);
344 } else if (funcDef.nameIsConstrain()) {
345 if (foundConstrain) {
346 return structFuncDefError(funcDef.getOperation())
347 << "found multiple \"@" << FUNC_NAME_CONSTRAIN << "\" functions";
348 }
349 foundConstrain = std::make_optional(funcDef);
350 } else if (funcDef.nameIsProduct()) {
351 if (foundProduct) {
352 return structFuncDefError(funcDef.getOperation())
353 << "found multiple \"@" << FUNC_NAME_PRODUCT << "\" functions";
354 }
355 foundProduct = std::make_optional(funcDef);
356 } else {
357 // Must do a little more than a simple call to '?.emitOpError()' to
358 // tag the error with correct location and correct op name.
359 return structFuncDefError(funcDef.getOperation())
360 << "found \"@" << funcDef.getSymName() << '"';
361 }
362 } else {
363 return op.emitOpError()
364 << "invalid operation in '" << StructDefOp::getOperationName() << "'; only '"
365 << MemberDefOp::getOperationName() << '\'' << " and '"
366 << FuncDefOp::getOperationName() << "' operations are permitted";
367 }
368 }
369 // Also check if the member complies with output signal restrictions
370 else if (isMainComponent() && member.hasPublicAttr() &&
371 failed(checkMainFuncOutputSignalType(member.getType(), *this))) {
372 // checkMainFuncOutputSignalType already emits a sufficient error message
373 return failure();
374 }
375 }
376 }
377
378 if (!foundCompute.has_value() && foundConstrain.has_value()) {
379 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_CONSTRAIN
380 << "\", missing \"@" << FUNC_NAME_COMPUTE << "\"";
381 }
382 if (!foundConstrain.has_value() && foundCompute.has_value()) {
383 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_COMPUTE
384 << "\", missing \"@" << FUNC_NAME_CONSTRAIN << "\"";
385 }
386 }
387
388 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
389 return structFuncDefError(getOperation())
390 << "could not find \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
391 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\"";
392 }
393
394 // Check which funcs are present and not marked with {llzk.derived}
395 auto nonderived = [](std::optional<FuncDefOp> op) -> bool {
396 return op && !(*op)->hasAttr(DERIVED_ATTR_NAME);
397 };
398
399 auto attachDerivedNotes = [&foundCompute, &foundConstrain,
400 &foundProduct](InFlightDiagnostic &&error) {
401 if (foundProduct && (*foundProduct)->hasAttr(DERIVED_ATTR_NAME)) {
402 error.attachNote(foundProduct->getLoc()) << "derived \"@" << FUNC_NAME_PRODUCT << "\" here";
403 }
404 if (foundCompute && (*foundCompute)->hasAttr(DERIVED_ATTR_NAME)) {
405 error.attachNote(foundCompute->getLoc()) << "derived \"@" << FUNC_NAME_COMPUTE << "\" here";
406 }
407 if (foundConstrain && (*foundConstrain)->hasAttr(DERIVED_ATTR_NAME)) {
408 error.attachNote(foundConstrain->getLoc())
409 << "derived \"@" << FUNC_NAME_CONSTRAIN << "\" here";
410 }
411 return error;
412 };
413
414 // We know that (@compute+@constrain) is present or @product is present, or both
415
416 // Error cases:
417 // Everything is derived
418 if (!nonderived(foundCompute) && !nonderived(foundConstrain) && !nonderived(foundProduct)) {
419 return attachDerivedNotes(
420 structFuncDefError(getOperation())
421 << "could not find non-derived \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
422 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\""
423 );
424 }
425
426 // Only one of @compute/@constrain is non-derived
427 if (nonderived(foundCompute) ^ nonderived(foundConstrain)) {
428 return attachDerivedNotes(
429 structFuncDefError(getOperation())
430 << "\"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_CONSTRAIN
431 << "\" must both be either derived or non-derived"
432 );
433 }
434
435 // Here, at least one thing is non-derived, and @compute/@constrain are derived or non-derived
436 // together so everything is fine
437 if (nonderived(foundCompute) && nonderived(foundConstrain) && !nonderived(foundProduct)) {
438 return verifyStructComputeConstrain(*this, *foundCompute, *foundConstrain);
439 }
440
441 assert(!nonderived(foundCompute) && !nonderived(foundConstrain) && nonderived(foundProduct));
442 return verifyStructProduct(*this, *foundProduct);
443}
444
446 for (Operation &op : *getBody()) {
447 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
448 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
449 return memberDef;
450 }
451 }
452 }
453 return nullptr;
454}
455
456std::vector<MemberDefOp> StructDefOp::getMemberDefs() {
457 std::vector<MemberDefOp> res;
458 for (Operation &op : *getBody()) {
459 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
460 res.push_back(memberDef);
461 }
462 }
463 return res;
464}
465
467 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_COMPUTE));
468}
469
471 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_CONSTRAIN));
472}
473
475 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
476}
477
479 FailureOr<StructType> mainTypeOpt = getMainInstanceType(this->getOperation());
480 if (succeeded(mainTypeOpt)) {
481 if (StructType mainType = mainTypeOpt.value()) {
482 return structTypesUnify(mainType, this->getType());
483 }
484 }
485 return false;
486}
487
488// Custom implementation to deserialize bytecode produced prior to version 2 when `StructDefOp` had
489// an optional `const_params` attribute serialized before `sym_name`.
490LogicalResult StructDefOp::readProperties(DialectBytecodeReader &reader, OperationState &state) {
491 auto &prop = state.getOrAddProperties<Properties>();
492
493 auto versionOpt = reader.getDialectVersion<StructDialect>();
494 if (succeeded(versionOpt)) {
495 const auto &ver = static_cast<const LLZKDialectVersion &>(**versionOpt);
496 if (ver.majorVersion < 2) {
497 // Read and stash the old `const_params` as a temporary attribute so `upgradeFromVersion()`
498 // can wrap this `StructDefOp` in a `TemplateOp` with the corresponding `TemplateParamOps`.
499 ArrayAttr constParams;
500 if (failed(reader.readOptionalAttribute(constParams))) {
501 return failure();
502 }
503 if (constParams) {
504 state.addAttribute(llzk::kV1ConstParamsAttr, constParams);
505 }
506 return reader.readAttribute(prop.sym_name);
507 }
508 }
509
510 // Same as tablegen would generate to deserialize current-version IR.
511 return reader.readAttribute(prop.sym_name);
512}
513
514// Same as tablegen would generate to serialize version 2 IR.
515void StructDefOp::writeProperties(DialectBytecodeWriter &writer) {
516 auto &prop = getProperties();
517 writer.writeAttribute(prop.sym_name);
518}
519
520//===------------------------------------------------------------------===//
521// MemberDefOp
522//===------------------------------------------------------------------===//
523
525 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
526 bool isSignal, bool isColumn
527) {
528 Properties &props = odsState.getOrAddProperties<Properties>();
529 props.setSymName(sym_name);
530 props.setType(type);
531 if (isColumn) {
532 props.column = odsBuilder.getUnitAttr();
533 }
534 if (isSignal) {
535 props.signal = odsBuilder.getUnitAttr();
536 }
537}
538
540 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type, bool isSignal,
541 bool isColumn
542) {
543 build(
544 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
545 isColumn
546 );
547}
548
550 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
551 ArrayRef<NamedAttribute> attributes, bool isSignal, bool isColumn
552) {
553 assert(operands.size() == 0u && "mismatched number of parameters");
554 odsState.addOperands(operands);
555 odsState.addAttributes(attributes);
556 assert(resultTypes.size() == 0u && "mismatched number of return types");
557 odsState.addTypes(resultTypes);
558 if (isColumn) {
559 odsState.getOrAddProperties<Properties>().column = odsBuilder.getUnitAttr();
560 }
561 if (isSignal) {
562 odsState.getOrAddProperties<Properties>().signal = odsBuilder.getUnitAttr();
563 }
564}
565
566void MemberDefOp::setPublicAttr(bool newValue) {
567 if (newValue) {
568 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
569 } else {
570 getOperation()->removeAttr(PublicAttr::name);
571 }
572}
573
574static LogicalResult
575verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
576 if (StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
577 // Special case for StructType verifies that the member type can resolve and that it is NOT the
578 // parent struct (i.e., struct members cannot create circular references).
579 auto memberTypeRes = verifyStructTypeResolution(tables, memberStructType, origin);
580 if (failed(memberTypeRes)) {
581 return failure(); // above already emits a sufficient error message
582 }
583 StructDefOp parentRes = getParentOfType<StructDefOp>(origin);
584 assert(parentRes && "MemberDefOp parent is always StructDefOp"); // per ODS def
585 if (memberTypeRes.value() == parentRes) {
586 return origin->emitOpError()
587 .append("type is circular")
588 .attachNote(parentRes.getLoc())
589 .append("references parent component defined here");
590 }
591 return success();
592 } else {
593 return verifyTypeResolution(tables, origin, memberType);
594 }
595}
596
597LogicalResult MemberDefOp::verifySymbolUses(SymbolTableCollection &tables) {
598 Type memberType = this->getType();
599 if (failed(verifyMemberDefTypeImpl(memberType, tables, *this))) {
600 return failure();
601 }
602
603 if (!getColumn()) {
604 return success();
605 }
606 // If the member is marked as a column only a small subset of types are allowed.
607 if (!isValidColumnType(getType(), tables, *this)) {
608 return emitOpError() << "marked as column can only contain felts, arrays of column types, or "
609 "structs with columns, but has type "
610 << getType();
611 }
612 return success();
613}
614
615LogicalResult MemberDefOp::verify() {
617 return emitOpError() << "with type " << getType() << " cannot have the signal attribute";
618 }
619 return success();
620}
621
622//===------------------------------------------------------------------===//
623// MemberRefOp implementations
624//===------------------------------------------------------------------===//
625namespace {
626
627FailureOr<SymbolLookupResult<MemberDefOp>>
628getMemberDefOpImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables, StructType tyStruct) {
629 Operation *op = refOp.getOperation();
630 auto structDefRes = tyStruct.getDefinition(tables, op);
631 if (failed(structDefRes)) {
632 return failure(); // getDefinition() already emits a sufficient error message
633 }
634 // Copy namespace because we will need it later.
635 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
637 tables, SymbolRefAttr::get(refOp->getContext(), refOp.getMemberName()),
638 std::move(*structDefRes), op
639 );
640 if (failed(res)) {
641 return refOp->emitError() << "could not find '" << MemberDefOp::getOperationName()
642 << "' named \"@" << refOp.getMemberName() << "\" in \""
643 << tyStruct.getNameRef() << '"';
644 }
645 // Prepend the namespace of the struct lookup since the type of the member is meant to be resolved
646 // within that scope.
647 res->prependNamespace(structDefOpNs);
648 return std::move(res.value());
649}
650
651static FailureOr<SymbolLookupResult<MemberDefOp>>
652findMember(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
653 // Ensure the base component/struct type reference can be resolved.
654 StructType tyStruct = refOp.getStructType();
655 if (failed(tyStruct.verifySymbolRef(tables, refOp.getOperation()))) {
656 return failure();
657 }
658 // Ensure the member name can be resolved in that struct.
659 return getMemberDefOpImpl(refOp, tables, tyStruct);
660}
661
662static LogicalResult verifySymbolUsesImpl(
663 MemberRefOpInterface refOp, SymbolTableCollection &tables,
664 SymbolLookupResult<MemberDefOp> &member
665) {
666 // Ensure the type of the referenced member declaration matches the type used in this op.
667 Type actualType = refOp.getVal().getType();
668 Type memberType = member.get().getType();
669 if (!typesUnify(actualType, memberType, member.getNamespace())) {
670 return refOp->emitOpError() << "has wrong type; expected " << memberType << ", got "
671 << actualType;
672 }
673 // Ensure any SymbolRef used in the type are valid
674 return verifyTypeResolution(tables, refOp.getOperation(), actualType);
675}
676
677LogicalResult verifySymbolUsesImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
678 // Ensure the member name can be resolved in that struct.
679 auto member = findMember(refOp, tables);
680 if (failed(member)) {
681 return member; // getMemberDefOp() already emits a sufficient error message
682 }
683 return verifySymbolUsesImpl(refOp, tables, *member);
684}
685
686} // namespace
687
688FailureOr<SymbolLookupResult<MemberDefOp>>
689MemberRefOpInterface::getMemberDefOp(SymbolTableCollection &tables) {
690 return getMemberDefOpImpl(*this, tables, getStructType());
691}
692
693LogicalResult MemberReadOp::verifySymbolUses(SymbolTableCollection &tables) {
694 auto member = findMember(*this, tables);
695 if (failed(member)) {
696 return failure();
697 }
698 if (failed(verifySymbolUsesImpl(*this, tables, *member))) {
699 return failure();
700 }
701 // If the member is not a column and an offset was specified then fail to validate
702 if (!member->get().getColumn() && getTableOffset().has_value()) {
703 return emitOpError("cannot read with table offset from a member that is not a column")
704 .attachNote(member->get().getLoc())
705 .append("member defined here");
706 }
707 // If the member is private and this read is outside the struct, then fail to validate.
708 // The current op may be inside a struct or a free function, but the
709 // member op (the member definition) is always inside a struct.
710 FailureOr<StructDefOp> memberParentRes = verifyInStruct(member->get());
711 if (failed(memberParentRes)) {
712 return failure(); // verifyInStruct() already emits a sufficient error message
713 }
714 StructDefOp thisParent = getParentOfType<StructDefOp>(*this);
715 StructDefOp memberParentStruct = memberParentRes.value();
716 if (!member->get().hasPublicAttr() && (!thisParent || thisParent != memberParentStruct)) {
717 return emitOpError()
718 .append(
719 "cannot read from private member of struct \"", memberParentStruct.getHeaderString(),
720 "\""
721 )
722 .attachNote(member->get().getLoc())
723 .append("member defined here");
724 }
725 return success();
726}
727
728LogicalResult MemberWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
729 // Ensure the write op only targets members in the current struct.
730 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
731 if (failed(getParentRes)) {
732 return failure(); // verifyInStruct() already emits a sufficient error message
733 }
734 if (failed(checkSelfType(tables, *getParentRes, getComponent().getType(), *this, "base value"))) {
735 return failure(); // checkSelfType() already emits a sufficient error message
736 }
737 // Perform the standard member ref checks.
738 return verifySymbolUsesImpl(*this, tables);
739}
740
741//===------------------------------------------------------------------===//
742// MemberReadOp
743//===------------------------------------------------------------------===//
744
746 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member
747) {
748 Properties &props = state.getOrAddProperties<Properties>();
749 props.setMemberName(FlatSymbolRefAttr::get(member));
750 state.addTypes(resultType);
751 state.addOperands(component);
753}
754
756 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member,
757 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
758) {
759 // '!mapOperands.empty()' implies 'numDims.has_value()'
760 assert(mapOperands.empty() || numDims.has_value());
761 state.addOperands(component);
762 state.addTypes(resultType);
763 if (numDims.has_value()) {
765 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
766 );
767 } else {
769 }
770 Properties &props = state.getOrAddProperties<Properties>();
771 props.setMemberName(FlatSymbolRefAttr::get(member));
772 props.setTableOffset(dist);
773}
774
776 OpBuilder & /*odsBuilder*/, OperationState &odsState, TypeRange resultTypes,
777 ValueRange operands, ArrayRef<NamedAttribute> attrs
778) {
779 odsState.addTypes(resultTypes);
780 odsState.addOperands(operands);
781 odsState.addAttributes(attrs);
782}
783
784LogicalResult MemberReadOp::verify() {
785 SmallVector<AffineMapAttr, 1> mapAttrs;
786 if (AffineMapAttr map =
787 llvm::dyn_cast_if_present<AffineMapAttr>(getTableOffset().value_or(nullptr))) {
788 mapAttrs.push_back(map);
789 }
791 getMapOperands(), getNumDimsPerMap(), mapAttrs, *this
792 );
793}
794
795//===------------------------------------------------------------------===//
796// CreateStructOp
797//===------------------------------------------------------------------===//
798
799void CreateStructOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
800 setNameFn(getResult(), "self");
801}
802
803LogicalResult CreateStructOp::verifySymbolUses(SymbolTableCollection &tables) {
804 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
805 if (failed(getParentRes)) {
806 return failure(); // verifyInStruct() already emits a sufficient error message
807 }
808 if (failed(checkSelfType(tables, *getParentRes, this->getType(), *this, "result"))) {
809 return failure();
810 }
811 return success();
812}
813
814} // namespace llzk::component
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:803
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:799
::mlir::TypedValue<::llzk::component::StructType > getResult()
Definition Ops.h.inc:143
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:566
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:597
::llvm::LogicalResult verify()
Definition Ops.cpp:615
FoldAdaptor::Properties Properties
Definition Ops.h.inc:315
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:979
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:692
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:693
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:984
::llvm::LogicalResult verify()
Definition Ops.cpp:784
FoldAdaptor::Properties Properties
Definition Ops.h.inc:639
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:689
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:952
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:728
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
::llvm::SmallVector<::mlir::Attribute > getTemplateParamOpNames()
If this struct.def is within a poly.template, return names of all poly.param within the poly....
Definition Ops.cpp:200
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1189
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1165
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
Definition Ops.cpp:490
::llvm::SmallVector<::mlir::Attribute > getTemplateExprOpNames()
If this struct.def is within a poly.template, return names of all poly.expr within the poly....
Definition Ops.cpp:208
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
Definition Ops.cpp:216
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
Definition Ops.cpp:456
FoldAdaptor::Properties Properties
Definition Ops.h.inc:1151
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
Definition Ops.cpp:474
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
Definition Ops.cpp:445
void writeProperties(::mlir::DialectBytecodeWriter &writer)
Definition Ops.cpp:515
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:470
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
Definition Ops.cpp:193
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:322
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:466
bool isMainComponent()
Return true iff this struct.def is the main struct. See llzk::MAIN_ATTR_NAME.
Definition Ops.cpp:478
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
Definition Ops.cpp:179
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
Definition Types.cpp:60
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:32
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:218
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:853
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:812
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:861
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:979
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:857
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:666
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:210
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:804
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
Definition Ops.cpp:55
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
Definition Ops.cpp:97
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:119
FailureOr< StructDefOp > verifyInStruct(Operation *op)
Definition Ops.cpp:57
bool isInStructFunctionNamed(Operation *op, char const *funcName)
Definition Ops.cpp:65
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:156
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isFeltOrSimpleFeltAggregate(Type ty)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
Definition OpHelpers.h:69
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
Definition Constants.h:28
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
Definition Ops.h.inc:200
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)
Definition Ops.h.inc:497