27#include <mlir/IR/IRMapping.h>
28#include <mlir/IR/OpImplementation.h>
30#include <llvm/ADT/MapVector.h>
31#include <llvm/ADT/STLExtras.h>
32#include <llvm/ADT/StringRef.h>
33#include <llvm/ADT/StringSet.h>
34#include <llvm/ADT/TypeSwitch.h>
68 if (parentFunc.getSymName().compare(funcName) == 0) {
78 assert(llvm::isa<StructDefOp>(structOp));
79 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
80 if (!bodyRegion.empty()) {
81 bodyRegion.front().walk([](
FuncDefOp funcDef) {
98 std::string prefix = std::string();
99 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
101 prefix += symbol.getName();
104 return origin->emitOpError().append(
110static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
120 SymbolTableCollection &tables,
StructDefOp expectedStruct, Type actualType, Operation *origin,
123 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
124 auto actualStructOpt =
126 if (failed(actualStructOpt)) {
127 return origin->emitError().append(
129 actualStructType.getNameRef(),
'"'
132 StructDefOp actualStruct = actualStructOpt.value().get();
133 if (actualStruct != expectedStruct) {
135 .attachNote(actualStruct.getLoc())
136 .append(
"uses this type instead");
139 ArrayAttr actualTypeParamsAttr = actualStructType.getParams();
140 ArrayRef<Attribute> actualTypeParams =
141 actualTypeParamsAttr ? actualTypeParamsAttr.getValue() : ArrayRef<Attribute> {};
151 .attachNote(actualStruct.getLoc())
166 assert(succeeded(pathRes));
168 if (constParams.has_value()) {
182 if (succeeded(pathToExpected)) {
183 ss << pathToExpected.value();
204 return SmallVector<Attribute>();
212 return SmallVector<Attribute>();
218 assert(succeeded(res));
225checkMainFuncParamType(Type pType,
FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
231 ss <<
"main entry component \"@" << inFunc.
getSymName()
232 <<
"\" function parameters must be one of: {";
233 if (appendSelfType.has_value()) {
234 ss << appendSelfType.value() <<
", ";
239 return inFunc.emitError(message);
242inline LogicalResult checkMainFuncOutputSignalType(Type pType,
StructDefOp structOp) {
248 ss <<
"main entry component output signals must be one of: {";
252 return structOp.emitError(message);
255inline LogicalResult verifyStructComputeConstrain(
256 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
266 ArrayRef<Type> computeParams = computeFunc.
getFunctionType().getInputs();
267 ArrayRef<Type> constrainParams = constrainFunc.
getFunctionType().getInputs().drop_front();
272 for (Type t : computeParams) {
273 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
277 auto appendSelf = std::make_optional(structDef.
getType());
278 for (Type t : constrainParams) {
279 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
286 return constrainFunc.emitError()
289 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
290 "\" function argument types"
292 .attachNote(computeFunc.getLoc())
299inline LogicalResult verifyStructProduct(
StructDefOp structDef, FuncDefOp productFunc) {
306 ArrayRef<Type> productParams = productFunc.
getFunctionType().getInputs();
310 for (Type t : productParams) {
311 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
323 std::optional<FuncDefOp> foundCompute = std::nullopt;
324 std::optional<FuncDefOp> foundConstrain = std::nullopt;
325 std::optional<FuncDefOp> foundProduct = std::nullopt;
333 if (!bodyRegion.empty()) {
334 for (Operation &op : bodyRegion.front()) {
335 auto member = llvm::dyn_cast<MemberDefOp>(op);
337 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
338 if (funcDef.nameIsCompute()) {
340 return structFuncDefError(funcDef.getOperation())
343 foundCompute = std::make_optional(funcDef);
344 }
else if (funcDef.nameIsConstrain()) {
345 if (foundConstrain) {
346 return structFuncDefError(funcDef.getOperation())
349 foundConstrain = std::make_optional(funcDef);
350 }
else if (funcDef.nameIsProduct()) {
352 return structFuncDefError(funcDef.getOperation())
355 foundProduct = std::make_optional(funcDef);
359 return structFuncDefError(funcDef.getOperation())
360 <<
"found \"@" << funcDef.getSymName() <<
'"';
363 return op.emitOpError()
371 failed(checkMainFuncOutputSignalType(member.getType(), *
this))) {
378 if (!foundCompute.has_value() && foundConstrain.has_value()) {
382 if (!foundConstrain.has_value() && foundCompute.has_value()) {
388 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
389 return structFuncDefError(getOperation())
395 auto nonderived = [](std::optional<FuncDefOp> op) ->
bool {
399 auto attachDerivedNotes = [&foundCompute, &foundConstrain,
400 &foundProduct](InFlightDiagnostic &&error) {
402 error.attachNote(foundProduct->getLoc()) <<
"derived \"@" <<
FUNC_NAME_PRODUCT <<
"\" here";
405 error.attachNote(foundCompute->getLoc()) <<
"derived \"@" <<
FUNC_NAME_COMPUTE <<
"\" here";
408 error.attachNote(foundConstrain->getLoc())
418 if (!nonderived(foundCompute) && !nonderived(foundConstrain) && !nonderived(foundProduct)) {
419 return attachDerivedNotes(
420 structFuncDefError(getOperation())
427 if (nonderived(foundCompute) ^ nonderived(foundConstrain)) {
428 return attachDerivedNotes(
429 structFuncDefError(getOperation())
431 <<
"\" must both be either derived or non-derived"
437 if (nonderived(foundCompute) && nonderived(foundConstrain) && !nonderived(foundProduct)) {
438 return verifyStructComputeConstrain(*
this, *foundCompute, *foundConstrain);
441 assert(!nonderived(foundCompute) && !nonderived(foundConstrain) && nonderived(foundProduct));
442 return verifyStructProduct(*
this, *foundProduct);
446 for (Operation &op : *getBody()) {
447 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
448 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
457 std::vector<MemberDefOp> res;
458 for (Operation &op : *getBody()) {
459 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
460 res.push_back(memberDef);
480 if (succeeded(mainTypeOpt)) {
481 if (
StructType mainType = mainTypeOpt.value()) {
491 auto &prop = state.getOrAddProperties<
Properties>();
494 if (succeeded(versionOpt)) {
496 if (ver.majorVersion < 2) {
499 ArrayAttr constParams;
500 if (failed(reader.readOptionalAttribute(constParams))) {
504 state.addAttribute(llzk::kV1ConstParamsAttr, constParams);
506 return reader.readAttribute(prop.sym_name);
511 return reader.readAttribute(prop.sym_name);
516 auto &prop = getProperties();
517 writer.writeAttribute(prop.sym_name);
525 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
526 bool isSignal,
bool isColumn
532 props.column = odsBuilder.getUnitAttr();
535 props.signal = odsBuilder.getUnitAttr();
540 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isSignal,
544 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
550 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
551 ArrayRef<NamedAttribute> attributes,
bool isSignal,
bool isColumn
553 assert(operands.size() == 0u &&
"mismatched number of parameters");
554 odsState.addOperands(operands);
555 odsState.addAttributes(attributes);
556 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
557 odsState.addTypes(resultTypes);
559 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
562 odsState.getOrAddProperties<
Properties>().signal = odsBuilder.getUnitAttr();
568 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
570 getOperation()->removeAttr(PublicAttr::name);
575verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
576 if (
StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
580 if (failed(memberTypeRes)) {
584 assert(parentRes &&
"MemberDefOp parent is always StructDefOp");
585 if (memberTypeRes.value() == parentRes) {
586 return origin->emitOpError()
587 .append(
"type is circular")
588 .attachNote(parentRes.getLoc())
589 .append(
"references parent component defined here");
598 Type memberType = this->
getType();
599 if (failed(verifyMemberDefTypeImpl(memberType, tables, *
this))) {
608 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
609 "structs with columns, but has type "
617 return emitOpError() <<
"with type " <<
getType() <<
" cannot have the signal attribute";
627FailureOr<SymbolLookupResult<MemberDefOp>>
629 Operation *op = refOp.getOperation();
631 if (failed(structDefRes)) {
635 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
637 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getMemberName()),
638 std::move(*structDefRes), op
647 res->prependNamespace(structDefOpNs);
648 return std::move(res.value());
651static FailureOr<SymbolLookupResult<MemberDefOp>>
659 return getMemberDefOpImpl(refOp, tables, tyStruct);
662static LogicalResult verifySymbolUsesImpl(
664 SymbolLookupResult<MemberDefOp> &member
667 Type actualType = refOp.
getVal().getType();
668 Type memberType = member.
get().getType();
670 return refOp->emitOpError() <<
"has wrong type; expected " << memberType <<
", got "
679 auto member = findMember(refOp, tables);
680 if (failed(member)) {
683 return verifySymbolUsesImpl(refOp, tables, *member);
688FailureOr<SymbolLookupResult<MemberDefOp>>
694 auto member = findMember(*
this, tables);
695 if (failed(member)) {
698 if (failed(verifySymbolUsesImpl(*
this, tables, *member))) {
703 return emitOpError(
"cannot read with table offset from a member that is not a column")
704 .attachNote(member->
get().getLoc())
705 .append(
"member defined here");
711 if (failed(memberParentRes)) {
715 StructDefOp memberParentStruct = memberParentRes.value();
716 if (!member->
get().hasPublicAttr() && (!thisParent || thisParent != memberParentStruct)) {
719 "cannot read from private member of struct \"", memberParentStruct.
getHeaderString(),
722 .attachNote(member->
get().getLoc())
723 .append(
"member defined here");
731 if (failed(getParentRes)) {
738 return verifySymbolUsesImpl(*
this, tables);
746 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member
750 state.addTypes(resultType);
756 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member,
757 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
760 assert(mapOperands.empty() || numDims.has_value());
762 state.addTypes(resultType);
763 if (numDims.has_value()) {
765 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
771 props.setMemberName(FlatSymbolRefAttr::get(member));
772 props.setTableOffset(dist);
776 OpBuilder & , OperationState &odsState, TypeRange resultTypes,
777 ValueRange operands, ArrayRef<NamedAttribute> attrs
779 odsState.addTypes(resultTypes);
780 odsState.addOperands(operands);
781 odsState.addAttributes(attrs);
785 SmallVector<AffineMapAttr, 1> mapAttrs;
786 if (AffineMapAttr map =
787 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
788 mapAttrs.push_back(map);
805 if (failed(getParentRes)) {
808 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
void setPublicAttr(bool newValue=true)
static constexpr ::llvm::StringLiteral getOperationName()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::llvm::LogicalResult verify()
FoldAdaptor::Properties Properties
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::OperandRangeRange getMapOperands()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::llvm::LogicalResult verify()
FoldAdaptor::Properties Properties
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
::llvm::SmallVector<::mlir::Attribute > getTemplateParamOpNames()
If this struct.def is within a poly.template, return names of all poly.param within the poly....
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
static constexpr ::llvm::StringLiteral getOperationName()
::llvm::LogicalResult readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state)
::llvm::SmallVector<::mlir::Attribute > getTemplateExprOpNames()
If this struct.def is within a poly.template, return names of all poly.expr within the poly....
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
FoldAdaptor::Properties Properties
::llzk::function::FuncDefOp getProductFuncOp()
Gets the FuncDefOp that defines the product function in this structure, if present,...
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
void writeProperties(::mlir::DialectBytecodeWriter &writer)
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
bool hasTemplateSymbolBindings()
Return true iff the struct.def appears within a poly.template that defines constant parameters and/or...
::llvm::LogicalResult verifyRegions()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool isMainComponent()
Return true iff this struct.def is the main struct. See llzk::MAIN_ATTR_NAME.
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
::mlir::FunctionType getFunctionType()
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
std::string toStringList(InputIt begin, InputIt end)
Generate a comma-separated string representation by traversing elements from begin to end where the e...
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isFeltOrSimpleFeltAggregate(Type ty)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
bool isValidMainSignalType(Type pType)
OpClass getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
constexpr char FUNC_NAME_PRODUCT[]
constexpr char DERIVED_ATTR_NAME[]
Name of the attribute on a @product func that has been automatically aligned from @compute + @constra...
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)