LLZK 2.1.1
An open-source IR for Zero Knowledge (ZK) circuits
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===-- Ops.cpp - Struct op implementations ---------------------*- C++ -*-===//
2//
3// Part of the LLZK Project, under the Apache License v2.0.
4// See LICENSE.txt for license information.
5// Copyright 2025 Veridise Inc.
6// Copyright 2026 Project LLZK
7// SPDX-License-Identifier: Apache-2.0
8//
9//===----------------------------------------------------------------------===//
10
12
23#include "llzk/Util/Constants.h"
24#include "llzk/Util/Debug.h"
27
28#include <mlir/IR/IRMapping.h>
29#include <mlir/IR/OpImplementation.h>
30
31#include <llvm/ADT/MapVector.h>
32#include <llvm/ADT/STLExtras.h>
33#include <llvm/ADT/StringRef.h>
34#include <llvm/ADT/StringSet.h>
35#include <llvm/ADT/TypeSwitch.h>
36
37#include <optional>
38
39// TableGen'd implementation files
41
42// TableGen'd implementation files
43#define GET_OP_CLASSES
45
46using namespace mlir;
47using namespace llzk::felt;
48using namespace llzk::array;
49using namespace llzk::felt;
50using namespace llzk::function;
51using namespace llzk::pod;
52using namespace llzk::polymorphic;
53using namespace llzk::verif;
54
55namespace llzk::component {
56
57bool isInStruct(Operation *op) { return getParentOfType<StructDefOp>(op); }
58
59FailureOr<StructDefOp> verifyInStruct(Operation *op) {
61 return res;
62 }
63 return op->emitOpError() << "only valid within a '" << StructDefOp::getOperationName()
64 << "' ancestor";
65}
66
67bool isInStructFunctionNamed(Operation *op, char const *funcName) {
68 if (FuncDefOp parentFunc = getParentOfType<FuncDefOp>(op)) {
69 if (isInStruct(parentFunc.getOperation())) {
70 if (parentFunc.getSymName().compare(funcName) == 0) {
71 return true;
72 }
73 }
74 }
75 return false;
76}
77
78// Again, only valid/implemented for StructDefOp
79template <> LogicalResult SetFuncAllowAttrs<StructDefOp>::verifyTrait(Operation *structOp) {
80 assert(llvm::isa<StructDefOp>(structOp));
81 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
82 if (!bodyRegion.empty()) {
83 bodyRegion.front().walk([](FuncDefOp funcDef) {
84 if (funcDef.nameIsConstrain()) {
85 funcDef.setAllowConstraintAttr();
86 funcDef.setAllowWitnessAttr(false);
87 } else if (funcDef.nameIsCompute()) {
88 funcDef.setAllowConstraintAttr(false);
89 funcDef.setAllowWitnessAttr();
90 } else if (funcDef.nameIsProduct()) {
91 funcDef.setAllowConstraintAttr();
92 funcDef.setAllowWitnessAttr();
93 }
94 });
95 }
96 return success();
97}
98
99InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect) {
100 std::string prefix = std::string();
101 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
102 prefix += "\"@";
103 prefix += symbol.getName();
104 prefix += "\" ";
105 }
106 return origin->emitOpError().append(
107 prefix, "must use type of its ancestor '", StructDefOp::getOperationName(), "' \"",
108 expected.getHeaderString(), "\" as ", aspect, " type"
109 );
110}
111
112static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
113 return origin->emitError() << '\'' << StructDefOp::getOperationName() << "' op "
114 << "must define either only a non-derived \"@" << FUNC_NAME_PRODUCT
115 << "\" function, or both non-derived \"@" << FUNC_NAME_COMPUTE
116 << "\" and \"@" << FUNC_NAME_CONSTRAIN << "\" functions; ";
117}
118
121LogicalResult checkSelfType(
122 SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin,
123 const char *aspect
124) {
125 if (StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
126 auto actualStructOpt =
127 lookupTopLevelSymbol<StructDefOp>(tables, actualStructType.getNameRef(), origin);
128 if (failed(actualStructOpt)) {
129 return origin->emitError().append(
130 "could not find '", StructDefOp::getOperationName(), "' named \"",
131 actualStructType.getNameRef(), '"'
132 );
133 }
134 StructDefOp actualStruct = actualStructOpt.value().get();
135 if (actualStruct != expectedStruct) {
136 return genCompareErr(expectedStruct, origin, aspect)
137 .attachNote(actualStruct.getLoc())
138 .append("uses this type instead");
139 }
140 // Check for an EXACT match in the parameter list since it must reference the "self" type.
141 ArrayAttr actualTypeParamsAttr = actualStructType.getParams(); // may be nullptr
142 ArrayRef<Attribute> actualTypeParams =
143 actualTypeParamsAttr ? actualTypeParamsAttr.getValue() : ArrayRef<Attribute> {};
144 if (ArrayRef(expectedStruct.getTemplateParamOpNames()) != actualTypeParams) {
145 // To make error messages more consistent and meaningful, if the parameters don't match
146 // because the actual type uses symbols that are not defined, generate an error about the
147 // undefined symbol(s).
148 if (failed(verifyParamsOfType(tables, actualTypeParams, actualStructType, origin))) {
149 return failure();
150 }
151 // Otherwise, generate an error stating the parent struct type must be used.
152 return genCompareErr(expectedStruct, origin, aspect)
153 .attachNote(actualStruct.getLoc())
154 .append("should be type of this '", StructDefOp::getOperationName(), '\'');
155 }
156 } else {
157 return genCompareErr(expectedStruct, origin, aspect);
158 }
159 return success();
160}
161
162//===------------------------------------------------------------------===//
163// StructDefOp
164//===------------------------------------------------------------------===//
165
166StructType StructDefOp::getType(std::optional<ArrayAttr> constParams) {
167 auto pathRes = getPathFromRoot(*this);
168 assert(succeeded(pathRes)); // consistent with StructType::get() with invalid args
169 // Use the specified parameters if provided.
170 if (constParams.has_value()) {
171 return StructType::get(pathRes.value(), constParams.value());
172 }
173 // Check if there is an enclosing `TemplateOp` defining parameters, else there are none.
174 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
175 return StructType::get(pathRes.value(), parent.getConstNames<TemplateParamOp>());
176 } else {
177 return StructType::get(pathRes.value());
178 }
179}
180
182 return buildStringViaCallback([this](llvm::raw_ostream &ss) {
183 FailureOr<SymbolRefAttr> pathToExpected = getPathFromRoot(*this);
184 if (succeeded(pathToExpected)) {
185 ss << pathToExpected.value();
186 } else {
187 // When there is a failure trying to get the resolved name of the struct,
188 // just print its symbol name directly.
189 ss << '@' << this->getSymName();
190 }
191 ss << '<' << debug::toStringList(this->getTemplateParamOpNames()) << '>';
192 });
193}
194
196 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
197 return parent.hasConstOps<TemplateSymbolBindingOpInterface>();
198 }
199 return false;
200}
201
202SmallVector<Attribute> StructDefOp::getTemplateParamOpNames() {
203 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
204 return parent.getConstNames<TemplateParamOp>();
205 } else {
206 return SmallVector<Attribute>();
207 }
208}
209
210SmallVector<Attribute> StructDefOp::getTemplateExprOpNames() {
211 if (TemplateOp parent = getParentOfType<TemplateOp>(*this)) {
212 return parent.getConstNames<TemplateExprOp>();
213 } else {
214 return SmallVector<Attribute>();
215 }
216}
217
219 auto res = getPathFromRoot(*this);
220 assert(succeeded(res));
221 return res.value();
222}
223
224namespace {
225
226inline LogicalResult
227checkMainFuncParamType(Type pType, FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
228 if (isValidMainSignalType(pType)) {
229 return success();
230 }
231
232 std::string message = buildStringViaCallback([&inFunc, appendSelfType](llvm::raw_ostream &ss) {
233 ss << "main entry component \"@" << inFunc.getSymName()
234 << "\" function parameters must be one of: {";
235 if (appendSelfType.has_value()) {
236 ss << appendSelfType.value() << ", ";
237 }
238 ss << '!' << FeltType::name << ", ";
239 ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
240 });
241 return inFunc.emitError(message);
242}
243
244inline LogicalResult checkMainFuncOutputSignalType(Type pType, StructDefOp structOp) {
245 if (isValidMainSignalType(pType)) {
246 return success();
247 }
248
249 std::string message = buildStringViaCallback([](llvm::raw_ostream &ss) {
250 ss << "main entry component output signals must be one of: {";
251 ss << '!' << FeltType::name << ", ";
252 ss << '!' << ArrayType::name << "<.. x !" << FeltType::name << ">}";
253 });
254 return structOp.emitError(message);
255}
256
257inline LogicalResult verifyStructComputeConstrain(
258 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
259) {
260 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly.
261 assert(constrainFunc.hasAllowConstraintAttr());
262 assert(!computeFunc.hasAllowConstraintAttr());
263 assert(!constrainFunc.hasAllowWitnessAttr());
264 assert(computeFunc.hasAllowWitnessAttr());
265
266 // Verify parameter types are valid. Skip the first parameter of the "constrain" function; it is
267 // already checked via verifyFuncTypeConstrain() in Function/IR/Ops.cpp.
268 ArrayRef<Type> computeParams = computeFunc.getFunctionType().getInputs();
269 ArrayRef<Type> constrainParams = constrainFunc.getFunctionType().getInputs().drop_front();
270 if (structDef.isMainComponent()) {
271 // Verify the input parameter types are legal. The error message is explicit about what types
272 // are allowed so there is no benefit to report multiple errors if more than one parameter in
273 // the referenced function has an illegal type.
274 for (Type t : computeParams) {
275 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
276 return failure(); // checkMainFuncParamType() already emits a sufficient error message
277 }
278 }
279 auto appendSelf = std::make_optional(structDef.getType());
280 for (Type t : constrainParams) {
281 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
282 return failure(); // checkMainFuncParamType() already emits a sufficient error message
283 }
284 }
285 }
286
287 if (!typeListsUnify(computeParams, constrainParams)) {
288 return constrainFunc.emitError()
289 .append(
290 "expected \"@", FUNC_NAME_CONSTRAIN,
291 "\" function argument types (sans the first one) to match \"@", FUNC_NAME_COMPUTE,
292 "\" function argument types"
293 )
294 .attachNote(computeFunc.getLoc())
295 .append("\"@", FUNC_NAME_COMPUTE, "\" function defined here");
296 }
297
298 return success();
299}
300
301inline LogicalResult verifyStructProduct(StructDefOp structDef, FuncDefOp productFunc) {
302 // ASSERT: The `SetFuncAllowAttrs` trait on StructDefOp set the attributes correctly
303 assert(productFunc.hasAllowConstraintAttr());
304 assert(productFunc.hasAllowWitnessAttr());
305
306 // Verify parameter types are valid
307 if (structDef.isMainComponent()) {
308 ArrayRef<Type> productParams = productFunc.getFunctionType().getInputs();
309 // Verify the input parameter types are legal. The error message is explicit about what types
310 // are allowed so there is no benefit to report multiple errors if more than one parameter in
311 // the referenced function has an illegal type.
312 for (Type t : productParams) {
313 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
314 return failure(); // checkMainFuncParamType() already emits a sufficient error message
315 }
316 }
317 }
318
319 return success();
320}
321
322} // namespace
323
325 std::optional<FuncDefOp> foundCompute = std::nullopt;
326 std::optional<FuncDefOp> foundConstrain = std::nullopt;
327 std::optional<FuncDefOp> foundProduct = std::nullopt;
328 {
329 // Verify the following:
330 // 1. The only ops within the body are member and function definitions
331 // 2. The only functions defined in the struct are `@compute()` and `@constrain()`, or
332 // `@product()`
333 OwningEmitErrorFn emitError = getEmitOpErrFn(this);
334 Region &bodyRegion = getBodyRegion();
335 if (!bodyRegion.empty()) {
336 for (Operation &op : bodyRegion.front()) {
337 auto member = llvm::dyn_cast<MemberDefOp>(op);
338 if (!member) {
339 if (FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
340 if (funcDef.nameIsCompute()) {
341 if (foundCompute) {
342 return structFuncDefError(funcDef.getOperation())
343 << "found multiple \"@" << FUNC_NAME_COMPUTE << "\" functions";
344 }
345 foundCompute = std::make_optional(funcDef);
346 } else if (funcDef.nameIsConstrain()) {
347 if (foundConstrain) {
348 return structFuncDefError(funcDef.getOperation())
349 << "found multiple \"@" << FUNC_NAME_CONSTRAIN << "\" functions";
350 }
351 foundConstrain = std::make_optional(funcDef);
352 } else if (funcDef.nameIsProduct()) {
353 if (foundProduct) {
354 return structFuncDefError(funcDef.getOperation())
355 << "found multiple \"@" << FUNC_NAME_PRODUCT << "\" functions";
356 }
357 foundProduct = std::make_optional(funcDef);
358 } else {
359 // Must do a little more than a simple call to '?.emitOpError()' to
360 // tag the error with correct location and correct op name.
361 return structFuncDefError(funcDef.getOperation())
362 << "found \"@" << funcDef.getSymName() << '"';
363 }
364 } else {
365 return op.emitOpError()
366 << "invalid operation in '" << StructDefOp::getOperationName() << "'; only '"
367 << MemberDefOp::getOperationName() << '\'' << " and '"
368 << FuncDefOp::getOperationName() << "' operations are permitted";
369 }
370 }
371 // Also check if the member complies with output signal restrictions
372 else if (isMainComponent() && member.hasPublicAttr() &&
373 failed(checkMainFuncOutputSignalType(member.getType(), *this))) {
374 // checkMainFuncOutputSignalType already emits a sufficient error message
375 return failure();
376 }
377 }
378 }
379
380 if (!foundCompute.has_value() && foundConstrain.has_value()) {
381 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_CONSTRAIN
382 << "\", missing \"@" << FUNC_NAME_COMPUTE << "\"";
383 }
384 if (!foundConstrain.has_value() && foundCompute.has_value()) {
385 return structFuncDefError(getOperation()) << "found \"@" << FUNC_NAME_COMPUTE
386 << "\", missing \"@" << FUNC_NAME_CONSTRAIN << "\"";
387 }
388 }
389
390 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
391 return structFuncDefError(getOperation())
392 << "could not find \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
393 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\"";
394 }
395
396 // Check which funcs are present and not marked with {llzk.derived}
397 auto nonderived = [](std::optional<FuncDefOp> op) -> bool {
398 return op && !(*op)->hasAttr(DERIVED_ATTR_NAME);
399 };
400
401 auto attachDerivedNotes = [&foundCompute, &foundConstrain,
402 &foundProduct](InFlightDiagnostic &&error) {
403 if (foundProduct && (*foundProduct)->hasAttr(DERIVED_ATTR_NAME)) {
404 error.attachNote(foundProduct->getLoc()) << "derived \"@" << FUNC_NAME_PRODUCT << "\" here";
405 }
406 if (foundCompute && (*foundCompute)->hasAttr(DERIVED_ATTR_NAME)) {
407 error.attachNote(foundCompute->getLoc()) << "derived \"@" << FUNC_NAME_COMPUTE << "\" here";
408 }
409 if (foundConstrain && (*foundConstrain)->hasAttr(DERIVED_ATTR_NAME)) {
410 error.attachNote(foundConstrain->getLoc())
411 << "derived \"@" << FUNC_NAME_CONSTRAIN << "\" here";
412 }
413 return error;
414 };
415
416 // We know that (@compute+@constrain) is present or @product is present, or both
417
418 // Error cases:
419 // Everything is derived
420 if (!nonderived(foundCompute) && !nonderived(foundConstrain) && !nonderived(foundProduct)) {
421 return attachDerivedNotes(
422 structFuncDefError(getOperation())
423 << "could not find non-derived \"@" << FUNC_NAME_PRODUCT << "\", \"@" << FUNC_NAME_COMPUTE
424 << "\", or \"@" << FUNC_NAME_CONSTRAIN << "\""
425 );
426 }
427
428 // Only one of @compute/@constrain is non-derived
429 if (nonderived(foundCompute) ^ nonderived(foundConstrain)) {
430 return attachDerivedNotes(
431 structFuncDefError(getOperation())
432 << "\"@" << FUNC_NAME_COMPUTE << "\" and \"@" << FUNC_NAME_CONSTRAIN
433 << "\" must both be either derived or non-derived"
434 );
435 }
436
437 // Here, at least one thing is non-derived, and @compute/@constrain are derived or non-derived
438 // together so everything is fine
439 if (nonderived(foundCompute) && nonderived(foundConstrain) && !nonderived(foundProduct)) {
440 return verifyStructComputeConstrain(*this, *foundCompute, *foundConstrain);
441 }
442
443 assert(!nonderived(foundCompute) && !nonderived(foundConstrain) && nonderived(foundProduct));
444 return verifyStructProduct(*this, *foundProduct);
445}
446
448 for (Operation &op : *getBody()) {
449 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
450 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
451 return memberDef;
452 }
453 }
454 }
455 return nullptr;
456}
457
458std::vector<MemberDefOp> StructDefOp::getMemberDefs() {
459 std::vector<MemberDefOp> res;
460 for (Operation &op : *getBody()) {
461 if (MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
462 res.push_back(memberDef);
463 }
464 }
465 return res;
466}
467
469 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_COMPUTE));
470}
471
473 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_CONSTRAIN));
474}
475
477 return llvm::dyn_cast_if_present<FuncDefOp>(lookupSymbol(FUNC_NAME_PRODUCT));
478}
479
481 FailureOr<StructType> mainTypeOpt = getMainInstanceType(this->getOperation());
482 if (succeeded(mainTypeOpt)) {
483 if (StructType mainType = mainTypeOpt.value()) {
484 return structTypesUnify(mainType, this->getType());
485 }
486 }
487 return false;
488}
489
490// Custom implementation to deserialize bytecode produced prior to version 2 when `StructDefOp` had
491// an optional `const_params` attribute serialized before `sym_name`.
492LogicalResult StructDefOp::readProperties(DialectBytecodeReader &reader, OperationState &state) {
493 auto &prop = state.getOrAddProperties<Properties>();
494
495 auto versionOpt = reader.getDialectVersion<StructDialect>();
496 if (succeeded(versionOpt)) {
497 const auto &ver = static_cast<const LLZKDialectVersion &>(**versionOpt);
498 if (ver.majorVersion < 2) {
499 // Read and stash the old `const_params` as a temporary attribute so `upgradeFromVersion()`
500 // can wrap this `StructDefOp` in a `TemplateOp` with the corresponding `TemplateParamOps`.
501 ArrayAttr constParams;
502 if (failed(reader.readOptionalAttribute(constParams))) {
503 return failure();
504 }
505 if (constParams) {
506 state.addAttribute(llzk::kV1ConstParamsAttr, constParams);
507 }
508 return reader.readAttribute(prop.sym_name);
509 }
510 }
511
512 // Same as tablegen would generate to deserialize current-version IR.
513 return reader.readAttribute(prop.sym_name);
514}
515
516// Same as tablegen would generate to serialize version 2 IR.
517void StructDefOp::writeProperties(DialectBytecodeWriter &writer) {
518 auto &prop = getProperties();
519 writer.writeAttribute(prop.sym_name);
520}
521
522//===------------------------------------------------------------------===//
523// MemberDefOp
524//===------------------------------------------------------------------===//
525
527 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
528 bool isSignal, bool isColumn
529) {
530 Properties &props = odsState.getOrAddProperties<Properties>();
531 props.setSymName(sym_name);
532 props.setType(type);
533 if (isColumn) {
534 props.column = odsBuilder.getUnitAttr();
535 }
536 if (isSignal) {
537 props.signal = odsBuilder.getUnitAttr();
538 }
539}
540
542 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type, bool isSignal,
543 bool isColumn
544) {
545 build(
546 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
547 isColumn
548 );
549}
550
552 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
553 ArrayRef<NamedAttribute> attributes, bool isSignal, bool isColumn
554) {
555 assert(operands.size() == 0u && "mismatched number of parameters");
556 odsState.addOperands(operands);
557 odsState.addAttributes(attributes);
558 assert(resultTypes.size() == 0u && "mismatched number of return types");
559 odsState.addTypes(resultTypes);
560 if (isColumn) {
561 odsState.getOrAddProperties<Properties>().column = odsBuilder.getUnitAttr();
562 }
563 if (isSignal) {
564 odsState.getOrAddProperties<Properties>().signal = odsBuilder.getUnitAttr();
565 }
566}
567
568void MemberDefOp::setPublicAttr(bool newValue) {
569 if (newValue) {
570 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
571 } else {
572 getOperation()->removeAttr(PublicAttr::name);
573 }
574}
575
576static LogicalResult
577verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
578 if (StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
579 // Special case for StructType verifies that the member type can resolve and that it is NOT the
580 // parent struct (i.e., struct members cannot create circular references).
581 auto memberTypeRes = verifyStructTypeResolution(tables, memberStructType, origin);
582 if (failed(memberTypeRes)) {
583 return failure(); // above already emits a sufficient error message
584 }
585 StructDefOp parentRes = getParentOfType<StructDefOp>(origin);
586 assert(parentRes && "MemberDefOp parent is always StructDefOp"); // per ODS def
587 if (memberTypeRes.value() == parentRes) {
588 return origin->emitOpError()
589 .append("type is circular")
590 .attachNote(parentRes.getLoc())
591 .append("references parent component defined here");
592 }
593 return success();
594 } else {
595 return verifyTypeResolution(tables, origin, memberType);
596 }
597}
598
599LogicalResult MemberDefOp::verifySymbolUses(SymbolTableCollection &tables) {
600 Type memberType = this->getType();
601 if (failed(verifyMemberDefTypeImpl(memberType, tables, *this))) {
602 return failure();
603 }
604
605 if (!getColumn()) {
606 return success();
607 }
608 // If the member is marked as a column only a small subset of types are allowed.
609 if (!isValidColumnType(getType(), tables, *this)) {
610 return emitOpError() << "marked as column can only contain felts, arrays of column types, or "
611 "structs with columns, but has type "
612 << getType();
613 }
614 return success();
615}
616
617LogicalResult MemberDefOp::verify() {
619 return emitOpError() << "with type " << getType() << " cannot have the signal attribute";
620 }
621 return success();
622}
623
624//===------------------------------------------------------------------===//
625// MemberRefOp implementations
626//===------------------------------------------------------------------===//
627namespace {
628
629FailureOr<SymbolLookupResult<MemberDefOp>>
630getMemberDefOpImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables, StructType tyStruct) {
631 Operation *op = refOp.getOperation();
632 auto structDefRes = tyStruct.getDefinition(tables, op);
633 if (failed(structDefRes)) {
634 return failure(); // getDefinition() already emits a sufficient error message
635 }
636 // Copy namespace because we will need it later.
637 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
639 tables, SymbolRefAttr::get(refOp->getContext(), refOp.getMemberName()),
640 std::move(*structDefRes), op
641 );
642 if (failed(res)) {
643 return refOp->emitError() << "could not find '" << MemberDefOp::getOperationName()
644 << "' named \"@" << refOp.getMemberName() << "\" in \""
645 << tyStruct.getNameRef() << '"';
646 }
647 // Prepend the namespace of the struct lookup since the type of the member is meant to be resolved
648 // within that scope.
649 res->prependNamespace(structDefOpNs);
650 return std::move(res.value());
651}
652
653static FailureOr<SymbolLookupResult<MemberDefOp>>
654findMember(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
655 // Ensure the base component/struct type reference can be resolved.
656 StructType tyStruct = refOp.getStructType();
657 if (failed(tyStruct.verifySymbolRef(tables, refOp.getOperation()))) {
658 return failure();
659 }
660 // Ensure the member name can be resolved in that struct.
661 return getMemberDefOpImpl(refOp, tables, tyStruct);
662}
663
664static LogicalResult verifySymbolUsesImpl(
665 MemberRefOpInterface refOp, SymbolTableCollection &tables,
666 SymbolLookupResult<MemberDefOp> &member
667) {
668 // Ensure the type of the referenced member declaration matches the type used in this op.
669 Type actualType = refOp.getVal().getType();
670 Type memberType = member.get().getType();
671 if (!typesUnify(actualType, memberType, member.getNamespace())) {
672 return refOp->emitOpError() << "has wrong type; expected " << memberType << ", got "
673 << actualType;
674 }
675 // Ensure any SymbolRef used in the type are valid
676 return verifyTypeResolution(tables, refOp.getOperation(), actualType);
677}
678
679LogicalResult verifySymbolUsesImpl(MemberRefOpInterface refOp, SymbolTableCollection &tables) {
680 // Ensure the member name can be resolved in that struct.
681 auto member = findMember(refOp, tables);
682 if (failed(member)) {
683 return member; // getMemberDefOp() already emits a sufficient error message
684 }
685 return verifySymbolUsesImpl(refOp, tables, *member);
686}
687
688} // namespace
689
690FailureOr<SymbolLookupResult<MemberDefOp>>
691MemberRefOpInterface::getMemberDefOp(SymbolTableCollection &tables) {
692 return getMemberDefOpImpl(*this, tables, getStructType());
693}
694
695LogicalResult MemberReadOp::verifySymbolUses(SymbolTableCollection &tables) {
696 auto member = findMember(*this, tables);
697 if (failed(member)) {
698 return failure();
699 }
700 if (failed(verifySymbolUsesImpl(*this, tables, *member))) {
701 return failure();
702 }
703 // If the member is not a column and an offset was specified then fail to validate
704 if (!member->get().getColumn() && getTableOffset().has_value()) {
705 return emitOpError("cannot read with table offset from a member that is not a column")
706 .attachNote(member->get().getLoc())
707 .append("member defined here");
708 }
709 // If the member is private and this read is outside the struct, then fail to validate.
710 // The current op may be inside a struct or a free function, but the
711 // member op (the member definition) is always inside a struct.
712 FailureOr<StructDefOp> memberParentRes = verifyInStruct(member->get());
713 if (failed(memberParentRes)) {
714 return failure(); // verifyInStruct() already emits a sufficient error message
715 }
716 // Can only read private members within the defining struct or from a verif
717 // contract targeting the struct.
718 StructDefOp thisParent = getParentOfType<StructDefOp>(*this);
719 // Defaults to failure
720 FailureOr<SymbolLookupResult<StructDefOp>> contractTarget;
721 if (auto contractParent = getParentOfType<ContractOp>(*this)) {
722 contractTarget = contractParent.getStructTarget(tables);
723 }
724 StructDefOp memberParentStruct = memberParentRes.value();
725 bool correctContractTarget =
726 succeeded(contractTarget) && memberParentStruct == contractTarget->get();
727 bool inMemberParent = thisParent && (thisParent == memberParentStruct);
728 bool validParent = inMemberParent || correctContractTarget;
729 if (!member->get().hasPublicAttr() && !validParent) {
730 return emitOpError()
731 .append(
732 "cannot read from private member of struct \"", memberParentStruct.getHeaderString(),
733 "\""
734 )
735 .attachNote(member->get().getLoc())
736 .append("member defined here");
737 }
738 return success();
739}
740
741LogicalResult MemberWriteOp::verifySymbolUses(SymbolTableCollection &tables) {
742 // Ensure the write op only targets members in the current struct.
743 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
744 if (failed(getParentRes)) {
745 return failure(); // verifyInStruct() already emits a sufficient error message
746 }
747 if (failed(checkSelfType(tables, *getParentRes, getComponent().getType(), *this, "base value"))) {
748 return failure(); // checkSelfType() already emits a sufficient error message
749 }
750 // Perform the standard member ref checks.
751 return verifySymbolUsesImpl(*this, tables);
752}
753
754//===------------------------------------------------------------------===//
755// MemberReadOp
756//===------------------------------------------------------------------===//
757
759 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member
760) {
761 Properties &props = state.getOrAddProperties<Properties>();
762 props.setMemberName(FlatSymbolRefAttr::get(member));
763 state.addTypes(resultType);
764 state.addOperands(component);
766}
767
769 OpBuilder &builder, OperationState &state, Type resultType, Value component, StringAttr member,
770 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
771) {
772 // '!mapOperands.empty()' implies 'numDims.has_value()'
773 assert(mapOperands.empty() || numDims.has_value());
774 state.addOperands(component);
775 state.addTypes(resultType);
776 if (numDims.has_value()) {
778 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
779 );
780 } else {
782 }
783 Properties &props = state.getOrAddProperties<Properties>();
784 props.setMemberName(FlatSymbolRefAttr::get(member));
785 props.setTableOffset(dist);
786}
787
789 OpBuilder & /*odsBuilder*/, OperationState &odsState, TypeRange resultTypes,
790 ValueRange operands, ArrayRef<NamedAttribute> attrs
791) {
792 odsState.addTypes(resultTypes);
793 odsState.addOperands(operands);
794 odsState.addAttributes(attrs);
795}
796
797LogicalResult MemberReadOp::verify() {
798 SmallVector<AffineMapAttr, 1> mapAttrs;
799 if (AffineMapAttr map =
800 llvm::dyn_cast_if_present<AffineMapAttr>(getTableOffset().value_or(nullptr))) {
801 mapAttrs.push_back(map);
802 }
804 getMapOperands(), getNumDimsPerMap(), mapAttrs, *this
805 );
806}
807
808//===------------------------------------------------------------------===//
809// CreateStructOp
810//===------------------------------------------------------------------===//
811
812void CreateStructOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
813 setNameFn(getResult(), "self");
814}
815
816LogicalResult CreateStructOp::verifySymbolUses(SymbolTableCollection &tables) {
817 FailureOr<StructDefOp> getParentRes = verifyInStruct(*this);
818 if (failed(getParentRes)) {
819 return failure(); // verifyInStruct() already emits a sufficient error message
820 }
821 if (failed(checkSelfType(tables, *getParentRes, this->getType(), *this, "result"))) {
822 return failure();
823 }
824 return success();
825}
826
827} // namespace llzk::component
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:51
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:816
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
Definition Ops.cpp:812
::mlir::TypedValue<::llzk::component::StructType > getResult()
Definition Ops.h.inc:143
void setPublicAttr(bool newValue=true)
Definition Ops.cpp:568
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:353
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:599
::llvm::LogicalResult verify()
Definition Ops.cpp:617
FoldAdaptor::Properties Properties
Definition Ops.h.inc:315
::std::optional<::mlir::Attribute > getTableOffset()
Definition Ops.cpp.inc:979
::mlir::OperandRangeRange getMapOperands()
Definition Ops.h.inc:692
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:695
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
Definition Ops.cpp.inc:984
::llvm::LogicalResult verify()
Definition Ops.cpp:797
FoldAdaptor::Properties Properties
Definition Ops.h.inc:639
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
Definition Ops.cpp:691
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
Definition Ops.h.inc:952
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
Definition Ops.cpp:741
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
::llvm::SmallVector<::mlir::Attribute > getTemplateParamOpNames()
If this struct.def is within a poly.template, return names of all poly.param within the poly....
Definition Ops.cpp:202
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
Definition Ops.h.inc:1189
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:1165
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
Definition Ops.cpp:492
::llvm::SmallVector<::mlir::Attribute > getTemplateExprOpNames()
If this struct.def is within a poly.template, return names of all poly.expr within the poly....
Definition Ops.cpp:210
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:1600
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
Definition Ops.cpp:218
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
Definition Ops.cpp:458
FoldAdaptor::Properties Properties
Definition Ops.h.inc:1151
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
Definition Ops.cpp:476
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
Definition Ops.cpp:447
void writeProperties(::mlir::DialectBytecodeWriter &writer)
Definition Ops.cpp:517
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
Definition Ops.cpp:472
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
Definition Ops.cpp:195
::llvm::LogicalResult verifyRegions()
Definition Ops.cpp:324
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
Definition Ops.cpp:468
bool isMainComponent()
Return true iff this struct.def is the main struct. See llzk::MAIN_ATTR_NAME.
Definition Ops.cpp:480
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
Definition Ops.cpp:181
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
Definition Types.cpp.inc:79
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
Definition Types.cpp:26
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
Definition Types.cpp:60
static constexpr ::llvm::StringLiteral name
Definition Types.h.inc:32
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
Definition Ops.cpp:269
::mlir::FunctionType getFunctionType()
Definition Ops.cpp.inc:984
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:865
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
Definition Ops.h.inc:812
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
Definition Ops.h.inc:873
::llvm::StringRef getSymName()
Definition Ops.cpp.inc:979
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
Definition Ops.h.inc:869
static constexpr ::llvm::StringLiteral getOperationName()
Definition Ops.h.inc:666
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
Definition Ops.cpp:261
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
Definition Ops.h.inc:804
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
Definition Ops.cpp:57
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
Definition Ops.cpp:99
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
Definition Ops.cpp:121
FailureOr< StructDefOp > verifyInStruct(Operation *op)
Definition Ops.cpp:59
bool isInStructFunctionNamed(Operation *op, char const *funcName)
Definition Ops.cpp:67
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
Definition Debug.h:156
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
Definition Constants.h:16
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
Definition TypeHelper.h:240
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
Definition Constants.h:17
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isFeltOrSimpleFeltAggregate(Type ty)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
Definition OpHelpers.h:51
constexpr char FUNC_NAME_PRODUCT[]
Definition Constants.h:18
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
Definition Constants.h:28
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
Definition Ops.h.inc:200
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)
Definition Ops.h.inc:497