20#include <mlir/IR/IRMapping.h>
21#include <mlir/IR/OpImplementation.h>
23#include <llvm/ADT/MapVector.h>
24#include <llvm/ADT/STLExtras.h>
25#include <llvm/ADT/StringRef.h>
26#include <llvm/ADT/StringSet.h>
58 if (succeeded(parentFuncOpt)) {
59 FuncDefOp parentFunc = parentFuncOpt.value();
61 if (parentFunc.
getSymName().compare(funcName) == 0) {
71 assert(llvm::isa<StructDefOp>(structOp));
72 Region &bodyRegion = llvm::cast<StructDefOp>(structOp).getBodyRegion();
73 if (!bodyRegion.empty()) {
74 bodyRegion.front().walk([](
FuncDefOp funcDef) {
91 std::string prefix = std::string();
92 if (SymbolOpInterface symbol = llvm::dyn_cast<SymbolOpInterface>(origin)) {
94 prefix += symbol.getName();
97 return origin->emitOpError().append(
103static inline InFlightDiagnostic structFuncDefError(Operation *origin) {
113 SymbolTableCollection &tables,
StructDefOp expectedStruct, Type actualType, Operation *origin,
116 if (
StructType actualStructType = llvm::dyn_cast<StructType>(actualType)) {
117 auto actualStructOpt =
119 if (failed(actualStructOpt)) {
120 return origin->emitError().append(
122 actualStructType.getNameRef(),
'"'
125 StructDefOp actualStruct = actualStructOpt.value().get();
126 if (actualStruct != expectedStruct) {
128 .attachNote(actualStruct.getLoc())
129 .append(
"uses this type instead");
136 if (ArrayAttr tyParams = actualStructType.getParams()) {
137 if (failed(
verifyParamsOfType(tables, tyParams.getValue(), actualStructType, origin))) {
143 .attachNote(actualStruct.getLoc())
158 assert(succeeded(pathRes));
165 if (succeeded(pathToExpected)) {
166 ss << pathToExpected.value();
173 ss <<
'<' << attr <<
'>';
180 for (Attribute attr : params) {
181 assert(llvm::isa<FlatSymbolRefAttr>(attr));
182 if (llvm::cast<FlatSymbolRefAttr>(attr).getRootReference() == find) {
192 assert(succeeded(res));
199 llvm::StringSet<> uniqNames;
200 for (Attribute attr : params) {
201 assert(llvm::isa<FlatSymbolRefAttr>(attr));
202 StringRef name = llvm::cast<FlatSymbolRefAttr>(attr).getValue();
203 if (!uniqNames.insert(name).second) {
204 return this->emitOpError().append(
"has more than one parameter named \"@", name,
'"');
208 for (Attribute attr : params) {
210 if (succeeded(res)) {
211 return this->emitOpError()
212 .append(
"parameter name \"@")
213 .append(llvm::cast<FlatSymbolRefAttr>(attr).getValue())
214 .append(
"\" conflicts with an existing symbol")
215 .attachNote(res->get()->getLoc())
216 .append(
"symbol already defined here");
226checkMainFuncParamType(Type pType,
FuncDefOp inFunc, std::optional<StructType> appendSelfType) {
227 if (llvm::isa<FeltType>(pType)) {
229 }
else if (
auto arrayParamTy = llvm::dyn_cast<ArrayType>(pType)) {
230 if (llvm::isa<FeltType>(arrayParamTy.getElementType())) {
236 ss <<
"main entry component \"@" << inFunc.
getSymName()
237 <<
"\" function parameters must be one of: {";
238 if (appendSelfType.has_value()) {
239 ss << appendSelfType.value() <<
", ";
244 return inFunc.emitError(message);
247inline LogicalResult verifyStructComputeConstrain(
248 StructDefOp structDef, FuncDefOp computeFunc, FuncDefOp constrainFunc
258 ArrayRef<Type> computeParams = computeFunc.
getFunctionType().getInputs();
259 ArrayRef<Type> constrainParams = constrainFunc.
getFunctionType().getInputs().drop_front();
264 for (Type t : computeParams) {
265 if (failed(checkMainFuncParamType(t, computeFunc, std::nullopt))) {
269 auto appendSelf = std::make_optional(structDef.
getType());
270 for (Type t : constrainParams) {
271 if (failed(checkMainFuncParamType(t, constrainFunc, appendSelf))) {
278 return constrainFunc.emitError()
281 "\" function argument types (sans the first one) to match \"@",
FUNC_NAME_COMPUTE,
282 "\" function argument types"
284 .attachNote(computeFunc.getLoc())
291inline LogicalResult verifyStructProduct(
StructDefOp structDef, FuncDefOp productFunc) {
298 ArrayRef<Type> productParams = productFunc.
getFunctionType().getInputs();
302 for (Type t : productParams) {
303 if (failed(checkMainFuncParamType(t, productFunc, std::nullopt))) {
315 std::optional<FuncDefOp> foundCompute = std::nullopt;
316 std::optional<FuncDefOp> foundConstrain = std::nullopt;
317 std::optional<FuncDefOp> foundProduct = std::nullopt;
325 if (!bodyRegion.empty()) {
326 for (Operation &op : bodyRegion.front()) {
327 if (!llvm::isa<MemberDefOp>(op)) {
328 if (
FuncDefOp funcDef = llvm::dyn_cast<FuncDefOp>(op)) {
329 if (funcDef.nameIsCompute()) {
331 return structFuncDefError(funcDef.getOperation())
336 return structFuncDefError(funcDef.getOperation())
339 foundCompute = std::make_optional(funcDef);
340 }
else if (funcDef.nameIsConstrain()) {
342 return structFuncDefError(funcDef.getOperation())
346 if (foundConstrain) {
347 return structFuncDefError(funcDef.getOperation())
350 foundConstrain = std::make_optional(funcDef);
351 }
else if (funcDef.nameIsProduct()) {
353 return structFuncDefError(funcDef.getOperation())
357 if (foundConstrain) {
358 return structFuncDefError(funcDef.getOperation())
363 return structFuncDefError(funcDef.getOperation())
366 foundProduct = std::make_optional(funcDef);
370 return structFuncDefError(funcDef.getOperation())
371 <<
"found \"@" << funcDef.getSymName() <<
'"';
374 return op.emitOpError()
383 if (!foundCompute.has_value() && foundConstrain.has_value()) {
387 if (!foundConstrain.has_value() && foundCompute.has_value()) {
393 if (!foundCompute.has_value() && !foundConstrain.has_value() && !foundProduct.has_value()) {
394 return structFuncDefError(getOperation())
399 if (foundCompute && foundConstrain) {
400 return verifyStructComputeConstrain(*
this, *foundCompute, *foundConstrain);
402 return verifyStructProduct(*
this, *foundProduct);
406 for (Operation &op : *getBody()) {
407 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
408 if (memberName.compare(memberDef.getSymNameAttr()) == 0) {
417 std::vector<MemberDefOp> res;
418 for (Operation &op : *getBody()) {
419 if (
MemberDefOp memberDef = llvm::dyn_cast_if_present<MemberDefOp>(op)) {
420 res.push_back(memberDef);
436 return llvm::dyn_cast<FuncDefOp>(computeFunc);
443 return llvm::dyn_cast<FuncDefOp>(constrainFunc);
450 if (succeeded(mainTypeOpt)) {
451 if (
StructType mainType = mainTypeOpt.value()) {
463 OpBuilder &odsBuilder, OperationState &odsState, StringAttr sym_name, TypeAttr type,
464 bool isSignal,
bool isColumn
470 props.column = odsBuilder.getUnitAttr();
473 props.signal = odsBuilder.getUnitAttr();
478 OpBuilder &odsBuilder, OperationState &odsState, StringRef sym_name, Type type,
bool isSignal,
482 odsBuilder, odsState, odsBuilder.getStringAttr(sym_name), TypeAttr::get(type), isSignal,
488 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands,
489 ArrayRef<NamedAttribute> attributes,
bool isSignal,
bool isColumn
491 assert(operands.size() == 0u &&
"mismatched number of parameters");
492 odsState.addOperands(operands);
493 odsState.addAttributes(attributes);
494 assert(resultTypes.size() == 0u &&
"mismatched number of return types");
495 odsState.addTypes(resultTypes);
497 odsState.getOrAddProperties<
Properties>().column = odsBuilder.getUnitAttr();
500 odsState.getOrAddProperties<
Properties>().signal = odsBuilder.getUnitAttr();
506 getOperation()->setAttr(PublicAttr::name, UnitAttr::get(getContext()));
508 getOperation()->removeAttr(PublicAttr::name);
513verifyMemberDefTypeImpl(Type memberType, SymbolTableCollection &tables, Operation *origin) {
514 if (
StructType memberStructType = llvm::dyn_cast<StructType>(memberType)) {
518 if (failed(memberTypeRes)) {
522 assert(succeeded(parentRes) &&
"MemberDefOp parent is always StructDefOp");
523 if (memberTypeRes.value() == parentRes.value()) {
524 return origin->emitOpError()
525 .append(
"type is circular")
526 .attachNote(parentRes.value().getLoc())
527 .append(
"references parent component defined here");
536 Type memberType = this->
getType();
537 if (failed(verifyMemberDefTypeImpl(memberType, tables, *
this))) {
546 return emitOpError() <<
"marked as column can only contain felts, arrays of column types, or "
547 "structs with columns, but has type "
558FailureOr<SymbolLookupResult<MemberDefOp>>
560 Operation *op = refOp.getOperation();
562 if (failed(structDefRes)) {
566 llvm::SmallVector<llvm::StringRef> structDefOpNs(structDefRes->getNamespace());
568 tables, SymbolRefAttr::get(refOp->getContext(), refOp.
getMemberName()),
569 std::move(*structDefRes), op
578 res->prependNamespace(structDefOpNs);
579 return std::move(res.value());
582static FailureOr<SymbolLookupResult<MemberDefOp>>
590 return getMemberDefOpImpl(refOp, tables, tyStruct);
593static LogicalResult verifySymbolUsesImpl(
595 SymbolLookupResult<MemberDefOp> &member
598 Type actualType = refOp.
getVal().getType();
599 Type memberType = member.
get().getType();
601 return refOp->emitOpError() <<
"has wrong type; expected " << memberType <<
", got "
610 auto member = findMember(refOp, tables);
611 if (failed(member)) {
614 return verifySymbolUsesImpl(refOp, tables, *member);
619FailureOr<SymbolLookupResult<MemberDefOp>>
625 auto member = findMember(*
this, tables);
626 if (failed(member)) {
629 if (failed(verifySymbolUsesImpl(*
this, tables, *member))) {
634 return emitOpError(
"cannot read with table offset from a member that is not a column")
635 .attachNote(member->
get().getLoc())
636 .append(
"member defined here");
643 if (failed(memberParentRes)) {
646 StructDefOp memberParentStruct = memberParentRes.value();
647 if (!member->
get().hasPublicAttr() &&
648 (failed(parentRes) || parentRes.value() != memberParentStruct)) {
651 "cannot read from private member of struct \"", memberParentStruct.
getHeaderString(),
654 .attachNote(member->
get().getLoc())
655 .append(
"member defined here");
663 if (failed(getParentRes)) {
670 return verifySymbolUsesImpl(*
this, tables);
678 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member
682 state.addTypes(resultType);
688 OpBuilder &builder, OperationState &state, Type resultType, Value
component, StringAttr member,
689 Attribute dist, ValueRange mapOperands, std::optional<int32_t> numDims
692 assert(mapOperands.empty() || numDims.has_value());
694 state.addTypes(resultType);
695 if (numDims.has_value()) {
697 builder, state, ArrayRef({mapOperands}), builder.getDenseI32ArrayAttr({*numDims})
703 props.setMemberName(FlatSymbolRefAttr::get(member));
704 props.setTableOffset(dist);
708 OpBuilder & , OperationState &odsState, TypeRange resultTypes,
709 ValueRange operands, ArrayRef<NamedAttribute> attrs
711 odsState.addTypes(resultTypes);
712 odsState.addOperands(operands);
713 odsState.addAttributes(attrs);
717 SmallVector<AffineMapAttr, 1> mapAttrs;
718 if (AffineMapAttr map =
719 llvm::dyn_cast_if_present<AffineMapAttr>(
getTableOffset().value_or(
nullptr))) {
720 mapAttrs.push_back(map);
737 if (failed(getParentRes)) {
740 if (failed(
checkSelfType(tables, *getParentRes, this->getType(), *
this,
"result"))) {
llvm::ArrayRef< llvm::StringRef > getNamespace() const
Return the stack of symbol names from either IncludeOp or ModuleOp that were traversed to load this r...
static constexpr ::llvm::StringLiteral name
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn)
::mlir::TypedValue<::llzk::component::StructType > getResult()
void setPublicAttr(bool newValue=true)
static constexpr ::llvm::StringLiteral getOperationName()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::StringAttr sym_name, ::mlir::TypeAttr type, bool isSignal=false, bool isColumn=false)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
FoldAdaptor::Properties Properties
::std::optional<::mlir::Attribute > getTableOffset()
::mlir::OperandRangeRange getMapOperands()
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType, ::mlir::Value component, ::mlir::StringAttr member)
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
::llvm::ArrayRef< int32_t > getNumDimsPerMap()
::llvm::LogicalResult verify()
FoldAdaptor::Properties Properties
::mlir::Value getVal()
Gets the SSA Value that holds the read/write data for the MemberRefOp.
::mlir::FailureOr< SymbolLookupResult< MemberDefOp > > getMemberDefOp(::mlir::SymbolTableCollection &tables)
Gets the definition for the member referenced in this op.
::llvm::StringRef getMemberName()
Gets the member name attribute value from the MemberRefOp.
::llzk::component::StructType getStructType()
Gets the struct type of the target component.
::mlir::TypedValue<::llzk::component::StructType > getComponent()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static mlir::LogicalResult verifyTrait(mlir::Operation *op)
StructType getType(::std::optional<::mlir::ArrayAttr > constParams={})
Gets the StructType representing this struct.
::mlir::Region & getBodyRegion()
::llvm::LogicalResult verifySymbolUses(::mlir::SymbolTableCollection &symbolTable)
static constexpr ::llvm::StringLiteral getOperationName()
::llzk::function::FuncDefOp getConstrainOrProductFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llvm::StringRef getSymName()
::mlir::SymbolRefAttr getFullyQualifiedName()
Return the full name for this struct from the root module, including any surrounding module scopes.
::std::vector< MemberDefOp > getMemberDefs()
Get all MemberDefOp in this structure.
MemberDefOp getMemberDef(::mlir::StringAttr memberName)
Gets the MemberDefOp that defines the member in this structure with the given name,...
::llzk::function::FuncDefOp getConstrainFuncOp()
Gets the FuncDefOp that defines the constrain function in this structure, if present,...
::llzk::function::FuncDefOp getComputeOrProductFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool hasParamNamed(::mlir::StringAttr find)
Return true iff this StructDefOp has a parameter with the given name.
::llvm::LogicalResult verifyRegions()
::mlir::ArrayAttr getConstParamsAttr()
::llzk::function::FuncDefOp getComputeFuncOp()
Gets the FuncDefOp that defines the compute function in this structure, if present,...
bool isMainComponent()
Return true iff this StructDefOp is the main struct. See llzk::MAIN_ATTR_NAME.
::std::string getHeaderString()
Generate header string, in the same format as the assemblyFormat.
::mlir::SymbolRefAttr getNameRef() const
static StructType get(::mlir::SymbolRefAttr structName)
::mlir::FailureOr< SymbolLookupResult< StructDefOp > > getDefinition(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op, bool reportMissing=true) const
Gets the struct op that defines this struct.
::mlir::LogicalResult verifySymbolRef(::mlir::SymbolTableCollection &symbolTable, ::mlir::Operation *op)
static constexpr ::llvm::StringLiteral name
void setAllowWitnessAttr(bool newValue=true)
Add (resp. remove) the allow_witness attribute to (resp. from) the function def.
::mlir::FunctionType getFunctionType()
bool nameIsCompute()
Return true iff the function name is FUNC_NAME_COMPUTE (if needed, a check that this FuncDefOp is loc...
bool hasAllowWitnessAttr()
Return true iff the function def has the allow_witness attribute.
bool nameIsProduct()
Return true iff the function name is FUNC_NAME_PRODUCT (if needed, a check that this FuncDefOp is loc...
::llvm::StringRef getSymName()
bool nameIsConstrain()
Return true iff the function name is FUNC_NAME_CONSTRAIN (if needed, a check that this FuncDefOp is l...
static constexpr ::llvm::StringLiteral getOperationName()
void setAllowConstraintAttr(bool newValue=true)
Add (resp. remove) the allow_constraint attribute to (resp. from) the function def.
bool hasAllowConstraintAttr()
Return true iff the function def has the allow_constraint attribute.
OpClass::Properties & buildInstantiationAttrsEmptyNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
void buildInstantiationAttrsNoSegments(mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, mlir::ArrayRef< mlir::ValueRange > mapOperands, mlir::DenseI32ArrayAttr numDimsPerMap)
Utility for build() functions that initializes the mapOpGroupSizes, and numDimsPerMap attributes for ...
LogicalResult verifyAffineMapInstantiations(OperandRangeRange mapOps, ArrayRef< int32_t > numDimsPerMap, ArrayRef< AffineMapAttr > mapAttrs, Operation *origin)
bool isInStruct(Operation *op)
InFlightDiagnostic genCompareErr(StructDefOp expected, Operation *origin, const char *aspect)
LogicalResult checkSelfType(SymbolTableCollection &tables, StructDefOp expectedStruct, Type actualType, Operation *origin, const char *aspect)
Verifies that the given actualType matches the StructDefOp given (i.e., for the "self" type parameter...
FailureOr< StructDefOp > verifyInStruct(Operation *op)
bool isInStructFunctionNamed(Operation *op, char const *funcName)
constexpr char FUNC_NAME_COMPUTE[]
Symbol name for the witness generation (and resp.
bool typeListsUnify(Iter1 lhs, Iter2 rhs, mlir::ArrayRef< llvm::StringRef > rhsReversePrefix={}, UnificationMap *unifications=nullptr)
Return true iff the two lists of Type instances are equivalent or could be equivalent after full inst...
mlir::FailureOr< SymbolLookupResultUntyped > lookupTopLevelSymbol(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, mlir::Operation *origin, bool reportMissing=true)
FailureOr< StructType > getMainInstanceType(Operation *lookupFrom)
constexpr char FUNC_NAME_CONSTRAIN[]
bool structTypesUnify(StructType lhs, StructType rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
bool isValidColumnType(Type type, SymbolTableCollection &symbolTable, Operation *op)
constexpr char FUNC_NAME_PRODUCT[]
FailureOr< StructDefOp > verifyStructTypeResolution(SymbolTableCollection &tables, StructType ty, Operation *origin)
LogicalResult verifyTypeResolution(SymbolTableCollection &tables, Operation *origin, Type ty)
OwningEmitErrorFn getEmitOpErrFn(mlir::Operation *op)
mlir::FailureOr< SymbolLookupResultUntyped > lookupSymbolIn(mlir::SymbolTableCollection &tables, mlir::SymbolRefAttr symbol, Within &&lookupWithin, mlir::Operation *origin, bool reportMissing=true)
LogicalResult verifyParamsOfType(SymbolTableCollection &tables, ArrayRef< Attribute > tyParams, Type parameterizedType, Operation *origin)
mlir::FailureOr< OpClass > getParentOfType(mlir::Operation *op)
Return the closest surrounding parent operation that is of type 'OpClass'.
std::function< InFlightDiagnosticWrapper()> OwningEmitErrorFn
This type is required in cases like the functions below to take ownership of the lambda so it is not ...
bool typesUnify(Type lhs, Type rhs, ArrayRef< StringRef > rhsReversePrefix, UnificationMap *unifications)
std::string buildStringViaCallback(Func &&appendFn, Args &&...args)
Generate a string by calling the given appendFn with an llvm::raw_ostream & as the first argument fol...
FailureOr< SymbolRefAttr > getPathFromRoot(SymbolOpInterface to, ModuleOp *foundRoot)
void setSymName(const ::mlir::StringAttr &propValue)
void setMemberName(const ::mlir::FlatSymbolRefAttr &propValue)