28#include <mlir/IR/IRMapping.h>
29#include <mlir/IR/OpImplementation.h>
31#include <llvm/ADT/MapVector.h>
32#include <llvm/ADT/STLExtras.h>
33#include <llvm/ADT/StringRef.h>
34#include <llvm/ADT/StringSet.h>
35#include <llvm/ADT/TypeSwitch.h>
70 if (parentFunc.getSymName().compare(funcName) == 0) {
80 assert(llvm::isa<StructDefOp>(structOp));
81 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
82 if (!bodyRegion.empty()) {
83 bodyRegion.front().walk([](
FuncDefOp funcDef) {
100 std::string prefix = std::string();
101 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
103 prefix += symbol.getName();
106 return origin->emitOpError().append(
112static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
122 SymbolTableCollection &tables,
StructDefOp expectedStruct, Type actualType, Operation *origin,
125 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
126 auto actualStructOpt =
128 if (failed(actualStructOpt)) {
129 return origin->emitError().append(
131 actualStructType.getNameRef(),
'"'
134 StructDefOp actualStruct = actualStructOpt.value().get();
135 if (actualStruct != expectedStruct) {
137 .attachNote(actualStruct.getLoc())
138 .append(
"uses this type instead");
141 ArrayAttr actualTypeParamsAttr = actualStructType.getParams();
142 ArrayRef<Attribute> actualTypeParams =
143 actualTypeParamsAttr ? actualTypeParamsAttr.getValue() : ArrayRef<Attribute> {};
153 .attachNote(actualStruct.getLoc())
168 assert(succeeded(pathRes));
170 if (constParams.has_value()) {
184 if (succeeded(pathToExpected)) {
185 ss << pathToExpected.value();
206 return SmallVector<Attribute>();
214 return SmallVector<Attribute>();
220 assert(succeeded(res));
227checkMainFuncParamType(Type pType,
FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
233 ss <<
"main entry component \"@" << inFunc.
getSymName()
234 <<
"\" function parameters must be one of: {";
235 if (appendSelfType.has_value()) {
236 ss << appendSelfType.value() <<
", ";
241 return inFunc.emitError(message);
244inline LogicalResult checkMainFuncOutputSignalType(Type pType,
StructDefOp structOp) {
250 ss <<
"main entry component output signals must be one of: {";
254 return structOp.emitError(message);
257inline LogicalResult verifyStructComputeConstrain(
258 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
268 ArrayRef<Type> computeParams = computeFunc.
getFunctionType().getInputs();
269 ArrayRef<Type> constrainParams = constrainFunc.
getFunctionType().getInputs().drop_front();
274 for (Type t : computeParams) {
275 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
279 auto appendSelf = std::make_optional(structDef.
getType());
280 for (Type t : constrainParams) {
281 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
288 return constrainFunc.emitError()
291 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
292 "\" function argument types"
294 .attachNote(computeFunc.getLoc())
301inline LogicalResult verifyStructProduct(
StructDefOp structDef, FuncDefOp productFunc) {
308 ArrayRef<Type> productParams = productFunc.
getFunctionType().getInputs();
312 for (Type t : productParams) {
313 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
325 std::optional<FuncDefOp> foundCompute = std::nullopt;
326 std::optional<FuncDefOp> foundConstrain = std::nullopt;
327 std::optional<FuncDefOp> foundProduct = std::nullopt;
335 if (!bodyRegion.empty()) {
336 for (Operation &op : bodyRegion.front()) {
337 auto member = llvm::dyn_cast<MemberDefOp>(op);
339 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
340 if (funcDef.nameIsCompute()) {
342 return structFuncDefError(funcDef.getOperation())
345 foundCompute = std::make_optional(funcDef);
346 }
else if (funcDef.nameIsConstrain()) {
347 if (foundConstrain) {
348 return structFuncDefError(funcDef.getOperation())
351 foundConstrain = std::make_optional(funcDef);
352 }
else if (funcDef.nameIsProduct()) {
354 return structFuncDefError(funcDef.getOperation())
357 foundProduct = std::make_optional(funcDef);
361 return structFuncDefError(funcDef.getOperation())
362 <<
"found \"@" << funcDef.getSymName() <<
'"';
365 return op.emitOpError()
373 failed(checkMainFuncOutputSignalType(member.getType(), *
this))) {
380 if (!foundCompute.has_value() && foundConstrain.has_value()) {
384 if (!foundConstrain.has_value() && foundCompute.has_value()) {
390 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
391 return structFuncDefError(getOperation())
397 auto nonderived = [](std::optional<FuncDefOp> op) ->
bool {
401 auto attachDerivedNotes = [&foundCompute, &foundConstrain,
402 &foundProduct](InFlightDiagnostic &&error) {
404 error.attachNote(foundProduct->getLoc()) <<
"derived \"@" <<
FUNC_NAME_PRODUCT <<
"\" here";
407 error.attachNote(foundCompute->getLoc()) <<
"derived \"@" <<
FUNC_NAME_COMPUTE <<
"\" here";
410 error.attachNote(foundConstrain->getLoc())
420 if (!nonderived(foundCompute) && !nonderived(foundConstrain) && !nonderived(foundProduct)) {
421 return attachDerivedNotes(
422 structFuncDefError(getOperation())
429 if (nonderived(foundCompute) ^ nonderived(foundConstrain)) {
430 return attachDerivedNotes(
431 structFuncDefError(getOperation())
433 <<
"\" must both be either derived or non-derived"
439 if (nonderived(foundCompute) && nonderived(foundConstrain) && !nonderived(foundProduct)) {
440 return verifyStructComputeConstrain(*
this, *foundCompute, *foundConstrain);
443 assert(!nonderived(foundCompute) && !nonderived(foundConstrain) && nonderived(foundProduct));
444 return verifyStructProduct(*
this, *foundProduct);
448 for (Operation &op : *getBody()) {
449 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
450 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
459 std::vector<MemberDefOp> res;
460 for (Operation &op : *getBody()) {
461 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
462 res.push_back(memberDef);
482 if (succeeded(mainTypeOpt)) {
483 if (
StructType mainType = mainTypeOpt.value()) {
493 auto &prop = state.getOrAddProperties<
Properties>();
496 if (succeeded(versionOpt)) {
498 if (ver.majorVersion < 2) {
501 ArrayAttr constParams;
502 if (failed(reader.readOptionalAttribute(constParams))) {
506 state.addAttribute(llzk::kV1ConstParamsAttr, constParams);
508 return reader.readAttribute(prop.sym_name);
513 return reader.readAttribute(prop.sym_name);
518 auto &prop = getProperties();
519 writer.writeAttribute(prop.sym_name);
527 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
528 bool isSignal,
bool isColumn
534 props.column = odsBuilder.getUnitAttr();
537 props.signal = odsBuilder.getUnitAttr();
542 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isSignal,
546 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
552 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
553 ArrayRef<NamedAttribute> attributes,
bool isSignal,
bool isColumn
555 assert(operands.size() == 0u &&
"mismatched number of parameters");
556 odsState.addOperands(operands);
557 odsState.addAttributes(attributes);
558 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
559 odsState.addTypes(resultTypes);
561 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
564 odsState.getOrAddProperties<
Properties>().signal = odsBuilder.getUnitAttr();
570 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
572 getOperation()->removeAttr(PublicAttr::name);
577verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
578 if (
StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
582 if (failed(memberTypeRes)) {
586 assert(parentRes &&
"MemberDefOp parent is always StructDefOp");
587 if (memberTypeRes.value() == parentRes) {
588 return origin->emitOpError()
589 .append(
"type is circular")
590 .attachNote(parentRes.getLoc())
591 .append(
"references parent component defined here");
600 Type memberType = this->
getType();
601 if (failed(verifyMemberDefTypeImpl(memberType, tables, *
this))) {
610 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
611 "structs with columns, but has type "
619 return emitOpError() <<
"with type " <<
getType() <<
" cannot have the signal attribute";
629FailureOr<SymbolLookupResult<MemberDefOp>>
631 Operation *op = refOp.getOperation();
633 if (failed(structDefRes)) {
637 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
639 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getMemberName()),
640 std::move(*structDefRes), op
649 res->prependNamespace(structDefOpNs);
650 return std::move(res.value());
653static FailureOr<SymbolLookupResult<MemberDefOp>>
661 return getMemberDefOpImpl(refOp, tables, tyStruct);
664static LogicalResult verifySymbolUsesImpl(
666 SymbolLookupResult<MemberDefOp> &member
669 Type actualType = refOp.
getVal().getType();
670 Type memberType = member.
get().getType();
672 return refOp->emitOpError() <<
"has wrong type; expected " << memberType <<
", got "
681 auto member = findMember(refOp, tables);
682 if (failed(member)) {
685 return verifySymbolUsesImpl(refOp, tables, *member);
690FailureOr<SymbolLookupResult<MemberDefOp>>
696 auto member = findMember(*
this, tables);
697 if (failed(member)) {
700 if (failed(verifySymbolUsesImpl(*
this, tables, *member))) {
705 return emitOpError(
"cannot read with table offset from a member that is not a column")
706 .attachNote(member->
get().getLoc())
707 .append(
"member defined here");
713 if (failed(memberParentRes)) {
720 FailureOr<SymbolLookupResult<StructDefOp>> contractTarget;
722 contractTarget = contractParent.getStructTarget(tables);
724 StructDefOp memberParentStruct = memberParentRes.value();
725 bool correctContractTarget =
726 succeeded(contractTarget) && memberParentStruct == contractTarget->get();
727 bool inMemberParent = thisParent && (thisParent == memberParentStruct);
728 bool validParent = inMemberParent || correctContractTarget;
729 if (!member->
get().hasPublicAttr() && !validParent) {
732 "cannot read from private member of struct \"", memberParentStruct.
getHeaderString(),
735 .attachNote(member->
get().getLoc())
736 .append(
"member defined here");
744 if (failed(getParentRes)) {
751 return verifySymbolUsesImpl(*
this, tables);
759 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member
763 state.addTypes(resultType);
769 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member,
770 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
773 assert(mapOperands.empty() || numDims.has_value());
775 state.addTypes(resultType);
776 if (numDims.has_value()) {
778 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
784 props.setMemberName(FlatSymbolRefAttr::get(member));
785 props.setTableOffset(dist);
789 OpBuilder & , OperationState &odsState, TypeRange resultTypes,
790 ValueRange operands, ArrayRef<NamedAttribute> attrs
792 odsState.addTypes(resultTypes);
793 odsState.addOperands(operands);
794 odsState.addAttributes(attrs);
798 SmallVector<AffineMapAttr, 1> mapAttrs;
799 if (AffineMapAttr map =
800 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
801 mapAttrs.push_back(map);
818 if (failed(getParentRes)) {
821 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
void setPublicAttr(bool newValue=true)
static constexpr ::llvm::StringLiteral getOperationName()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::llvm::LogicalResult verify()
FoldAdaptor::Properties Properties
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::OperandRangeRange getMapOperands()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::llvm::LogicalResult verify()
FoldAdaptor::Properties Properties
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
::llvm::SmallVector<::mlir::Attribute > getTemplateParamOpNames()
If this struct.def is within a poly.template, return names of all poly.param within the poly....
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
::llvm::SmallVector<::mlir::Attribute > getTemplateExprOpNames()
If this struct.def is within a poly.template, return names of all poly.expr within the poly....
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
FoldAdaptor::Properties Properties
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
void writeProperties(::mlir::DialectBytecodeWriter &writer)
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
::llvm::LogicalResult verifyRegions()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool isMainComponent()
Return true iff this struct.def is the main struct. See llzk::MAIN_ATTR_NAME.
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
::mlir::FunctionType getFunctionType()
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isFeltOrSimpleFeltAggregate(Type ty)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent/ancestor operation that is of type 'OpClass'.
constexpr char FUNC_NAME_PRODUCT[]
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)