69 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
77static LogicalResult verifyArgOrResNameAttrs(
78 ArrayAttr attrs, StringRef ownAttrName, StringRef crossAttrName, StringRef ownLabel,
84 llvm::DenseSet<StringAttr> seenNames;
85 for (
auto [i, attr] : llvm::enumerate(attrs)) {
86 auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr);
90 if (dictAttr.contains(crossAttrName)) {
91 return emitFn().append(
92 '\'', crossAttrName,
"' is only valid on function ", crossLabel,
"s but found on ",
96 Attribute nameAttr = dictAttr.get(ownAttrName);
100 auto name = llvm::dyn_cast<StringAttr>(nameAttr);
102 return emitFn().append(
103 '\'', ownAttrName,
"' on ", ownLabel,
' ', i,
" must be a string attribute"
106 if (!llvm::isa<NoneType>(name.getType())) {
107 return emitFn().append(
108 '\'', ownAttrName,
"' on ", ownLabel,
' ', i,
" must not have an explicit type"
111 if (name.getValue().empty()) {
112 return emitFn().append(
'\'', ownAttrName,
"' on ", ownLabel,
' ', i,
" must not be empty");
114 if (!seenNames.insert(name).second) {
115 return emitFn().append(
116 "duplicate '", ownAttrName,
"' value \"", name.getValue(),
"\" on ", ownLabel,
' ', i
129 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
135 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
137 SmallVector<NamedAttribute, 8> attrRef(attrs);
138 return create(location, name, type, llvm::ArrayRef(attrRef));
142 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
143 ArrayRef<DictionaryAttr> argAttrs
145 FuncDefOp func =
create(location, name, type, attrs);
146 func.setAllArgAttrs(argAttrs);
151 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
152 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
154 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
156 state.attributes.append(attrs.begin(), attrs.end());
159 if (argAttrs.empty()) {
162 assert(type.getNumInputs() == argAttrs.size());
163 function_interface_impl::addArgAndResultAttrs(
170 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
171 function_interface_impl::VariadicFlag,
172 std::string &) {
return builder.getFunctionType(argTypes, results); };
174 return function_interface_impl::parseFunctionOp(
181 function_interface_impl::printFunctionOp(
191 llvm::MapVector<StringAttr, Attribute> newAttrMap;
192 for (
const auto &attr : dest->getAttrs()) {
193 newAttrMap.insert({attr.getName(), attr.getValue()});
195 for (
const auto &attr : (*this)->getAttrs()) {
196 newAttrMap.insert({attr.getName(), attr.getValue()});
200 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
201 return NamedAttribute(attrPair.first, attrPair.second);
203 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
216 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
224 unsigned oldNumArgs = oldType.getNumInputs();
225 SmallVector<Type, 4> newInputs;
226 newInputs.reserve(oldNumArgs);
227 for (
unsigned i = 0; i != oldNumArgs; ++i) {
228 if (!mapper.contains(getArgument(i))) {
229 newInputs.push_back(oldType.getInput(i));
235 if (newInputs.size() != oldNumArgs) {
236 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
238 if (ArrayAttr argAttrs = getAllArgAttrs()) {
239 SmallVector<Attribute> newArgAttrs;
240 newArgAttrs.reserve(newInputs.size());
241 for (
unsigned i = 0; i != oldNumArgs; ++i) {
242 if (!mapper.contains(getArgument(i))) {
243 newArgAttrs.push_back(argAttrs[i]);
246 newFunc.setAllArgAttrs(newArgAttrs);
258 return clone(mapper);
263 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
265 getOperation()->removeAttr(AllowConstraintAttr::name);
271 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
273 getOperation()->removeAttr(AllowWitnessAttr::name);
279 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
281 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
286 if (index < this->getNumArguments()) {
287 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
288 return res ? res.contains(PublicAttr::name) :
false;
298 if (index >= getNumArguments()) {
308 assert(index < getNumArguments() &&
"argument index out of range");
320 return emitErrorFunc() <<
'\'' <<
ARG_NAME_ATTR_NAME <<
"' is only valid on function arguments";
323 return emitErrorFunc() <<
'\'' <<
RES_NAME_ATTR_NAME <<
"' is only valid on function results";
326 if (failed(verifyArgOrResNameAttrs(
333 if (failed(verifyArgOrResNameAttrs(
345 for (Type t : type.getInputs()) {
350 return emitErrorFunc().append(
351 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", t
355 for (Type t : type.getResults()) {
363 WalkResult res = this->walk<WalkOrder::PreOrder>([
this](ModuleOp nestedMod) {
365 "cannot be nested within '", getOperation()->getName(),
"' operations"
367 return WalkResult::interrupt();
369 if (res.wasInterrupted()) {
381 llvm::ArrayRef<Type> resTypes = funcType.getResults();
383 if (resTypes.size() != 1) {
384 return origin.emitOpError().append(
388 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
399verifyFuncTypeProduct(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
401 return verifyFuncTypeCompute(origin, tables, parent);
405verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
408 if (funcType.getResults().size() != 0) {
409 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
413 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
414 if (inputTypes.size() < 1) {
416 <<
"\" must have at least one input type";
418 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
436 return verifyFuncTypeCompute(*
this, tables, parentStructOpt);
438 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt);
440 return verifyFuncTypeProduct(*
this, tables, parentStructOpt);
449 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
450 return SymbolRefAttr::get(getOperation());
453 assert(succeeded(res));
461 assert(!body.empty() &&
"compute() function body is empty");
462 Block &block = body.back();
465 Operation *terminator = block.getTerminator();
466 assert(terminator &&
"compute() function has no terminator");
467 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
470 << terminator->getName() <<
"'\n";
471 llvm_unreachable(
"compute() function must end with ReturnOp");
473 return retOp.getOperands().front();
478 return getArguments().front();
491 auto function = getParentOp<FuncDefOp>();
494 const auto results =
function.getFunctionType().getResults();
495 if (getNumOperands() != results.size()) {
496 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
497 <<
function.getName() <<
") returns " << results.size();
500 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
501 if (!
typesUnify(getOperand(i).getType(), results[i])) {
502 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
503 <<
") doesn't match function result type (" << results[i] <<
")"
504 <<
" in function @" <<
function.getName();
518 auto &prop = state.getOrAddProperties<
Properties>();
519 if (failed(reader.readAttribute(prop.callee)) ||
520 failed(reader.readAttribute(prop.mapOpGroupSizes)) ||
521 failed(reader.readOptionalAttribute(prop.numDimsPerMap))) {
525 if (reader.getBytecodeVersion() < 6) {
526 auto &propStorage = prop.operandSegmentSizes;
527 DenseI32ArrayAttr attr;
528 if (failed(reader.readAttribute(attr))) {
531 if (attr.size() >
static_cast<int64_t
>(
sizeof(propStorage) /
sizeof(int32_t))) {
532 reader.emitError(
"size mismatch for operand/result_segment_size");
535 llvm::copy(ArrayRef<int32_t>(attr), propStorage.begin());
540 if (succeeded(versionOpt)) {
542 if (ver.majorVersion >= 2) {
543 if (failed(reader.readOptionalAttribute(prop.templateParams))) {
549 if (reader.getBytecodeVersion() >= 6) {
550 return reader.readSparseArray(MutableArrayRef(prop.operandSegmentSizes));
557 auto &prop = getProperties();
558 writer.writeAttribute(prop.callee);
559 writer.writeAttribute(prop.mapOpGroupSizes);
560 writer.writeOptionalAttribute(prop.numDimsPerMap);
562 if (writer.getBytecodeVersion() < 6) {
563 auto &propStorage = prop.operandSegmentSizes;
564 writer.writeAttribute(DenseI32ArrayAttr::get(this->getContext(), propStorage));
567 writer.writeOptionalAttribute(prop.templateParams);
569 auto &propStorage = prop.operandSegmentSizes;
570 if (writer.getBytecodeVersion() >= 6) {
571 writer.writeSparseArray(ArrayRef(propStorage));
576 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
577 ValueRange argOperands, ArrayRef<Attribute> templateParams
579 odsState.addTypes(resultTypes);
580 odsState.addOperands(argOperands);
584 props.setCallee(callee);
589 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
590 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands,
591 ArrayRef<Attribute> templateParams
593 odsState.addTypes(resultTypes);
594 odsState.addOperands(argOperands);
596 odsBuilder, odsState, mapOperands, numDimsPerMap,
599 props.setCallee(callee);
607 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(paramFromCallOp)) {
609 std::optional<Type> declaredType = targetParam.
getTypeOpt();
610 if (!declaredType || !llvm::isa<TypeVarType>(*declaredType)) {
611 auto diag = this->emitOpError().append(
612 "wildcard `?` can only be used for template parameters with `!poly.tvar` "
613 "type restriction, but parameter \"@",
614 targetParam.getName(),
"\" has "
617 diag.append(
"type restriction ", *declaredType);
619 diag.append(
"no type restriction");
626 if (std::optional<Type> declaredType = targetParam.
getTypeOpt()) {
628 bool compatible =
false;
629 if (llvm::isa<TypeVarType>(*declaredType)) {
630 compatible = llvm::isa<TypeAttr>(paramFromCallOp);
631 }
else if (llvm::isa<FeltType>(*declaredType)) {
632 compatible = llvm::isa<FeltConstAttr, IntegerAttr>(paramFromCallOp) &&
634 }
else if (llvm::isa<IndexType, IntegerType>(*declaredType)) {
638 compatible = llvm::isa<IntegerAttr>(paramFromCallOp) &&
641 llvm_unreachable(
"inconsistent with `isValidConstReadType()`");
645 return this->emitOpError().append(
646 "instantiation value '", paramFromCallOp,
"' is not compatible with parameter \"@",
647 targetParam.getName(),
"\" type restriction ", *declaredType
655 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs
659 assert((callParams.size() == llvm::range_size(targetParamDefs)) &&
"pre-condition");
661 for (
auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
670 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs,
675 assert((callParams.size() == llvm::range_size(targetParamDefs)) &&
"pre-condition");
677 for (
auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
679 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
684 auto it = unifications.find({FlatSymbolRefAttr::get(paramOp.getNameAttr()),
Side::RHS});
685 if (it != unifications.end() && !
typeParamsUnify({attr}, {it->second})) {
687 return this->emitOpError().append(
688 "template instantiation value '", attr,
"' for parameter \"@", paramOp.getName(),
689 "\" conflicts with value '", it->second,
"' inferred from function type signature"
698struct CallOpVerifier {
699 CallOpVerifier(
CallOp *c,
FunctionKind tgtFuncKind) : callOp(c), tgtKind(tgtFuncKind) {}
700 CallOpVerifier(
CallOp *c, StringRef tgtName) : CallOpVerifier(c,
fnNameToKind(tgtName)) {}
701 virtual ~CallOpVerifier() =
default;
703 LogicalResult verify() {
706 LogicalResult aggregateResult = success();
707 if (failed(verifyTargetAttributes())) {
708 aggregateResult = failure();
710 if (failed(verifyInputs())) {
711 aggregateResult = failure();
713 if (failed(verifyOutputs())) {
714 aggregateResult = failure();
716 if (failed(verifyTemplateParams())) {
717 aggregateResult = failure();
719 if (failed(verifyAffineMapParams())) {
720 aggregateResult = failure();
722 return aggregateResult;
729 virtual LogicalResult verifyTargetAttributes() = 0;
730 virtual LogicalResult verifyInputs() = 0;
731 virtual LogicalResult verifyOutputs() = 0;
732 virtual LogicalResult verifyTemplateParams() = 0;
733 virtual LogicalResult verifyAffineMapParams() = 0;
736 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
737 LogicalResult aggregateRes = success();
738 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
739 auto emitAttrErr = [&](StringLiteral attrName) {
740 aggregateRes = callOp->emitOpError()
741 <<
"target '@" << target.getName() <<
"' has '" << attrName
742 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
747 emitAttrErr(AllowConstraintAttr::name);
750 emitAttrErr(AllowWitnessAttr::name);
753 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
759 LogicalResult verifyNoTemplateInstantiations() {
762 return callOp->emitOpError().append(
763 "can only have template instantiations when targeting a templated free function"
769 LogicalResult verifyNoAffineMapInstantiations() {
772 return callOp->emitOpError().append(
773 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
779 assert(callOp->getMapOperands().empty());
784struct KnownTargetVerifier :
public CallOpVerifier {
785 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
786 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
787 includeSymNames(tgtRes.getNamespace()) {}
789 LogicalResult verifyTargetAttributes()
override {
790 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
793 LogicalResult verifyInputs()
override {
794 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
797 LogicalResult verifyOutputs()
override {
798 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
801 LogicalResult verifyTemplateParams()
override {
802 Operation *tgtOp = tgt.getOperation();
805 return verifyNoTemplateInstantiations();
815 auto realParams = tgtOpParent.getConstOps<TemplateParamOp>();
820 llvm::SmallDenseSet<SymbolRefAttr> referencedInSignature;
824 bool allParamsReferenced = llvm::all_of(realParams, [&](TemplateParamOp p) {
825 return referencedInSignature.contains(FlatSymbolRefAttr::get(p.getNameAttr()));
827 if (allParamsReferenced) {
831 return callOp->emitOpError().append(
832 "must provide template instantiation parameters when calling \"@", tgt.getSymName(),
833 "\" because not all template parameters of \"@", tgtOpParent.getSymName(),
834 "\" appear in the function type signature"
840 return llzk::InFlightDiagnosticWrapper(this->callOp->emitOpError());
846 size_t numTemplateParams = llvm::range_size(realParams);
847 if (callParams.size() != numTemplateParams) {
849 return callOp->emitOpError().append(
850 "template instantiation has ", callParams.size(),
" parameter(s) but \"@",
851 tgtOpParent.getSymName(),
"\" expects ", numTemplateParams,
" template parameter(s)"
863 assert(succeeded(unifyResult) &&
"already checked by `verifyInputs()` and `verifyOutputs()`");
867 return verifyNoTemplateInstantiations();
871 LogicalResult verifyAffineMapParams()
override {
879 if (ArrayAttr params = retTy.getParams()) {
881 SmallVector<AffineMapAttr> mapAttrs;
882 for (Attribute a : params) {
883 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
884 mapAttrs.push_back(m);
895 return verifyNoAffineMapInstantiations();
900 template <
typename T>
902 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
903 if (tgtTypes.size() != callOpTypes.size()) {
904 return callOp->emitOpError()
905 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
906 .attachNote(tgt.getLoc())
907 .append(
"callee defined here");
909 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
910 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
911 return callOp->emitOpError().append(
912 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
913 " for ", aspect,
" number ", i
921 FunctionType tgtType;
922 std::vector<llvm::StringRef> includeSymNames;
927LogicalResult checkSelfTypeUnknownTarget(
928 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
930 if (!llvm::isa<TypeVarType>(actualType) ||
931 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
937 return origin->emitOpError().append(
938 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
939 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
955struct UnknownTargetVerifier :
public CallOpVerifier {
956 UnknownTargetVerifier(CallOp *c,
FunctionKind tgtFuncKind, SymbolRefAttr callee)
957 : CallOpVerifier(c, tgtFuncKind), calleeAttr(callee) {
964 LogicalResult verifyTargetAttributes()
override {
967 LogicalResult aggregateRes = success();
968 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
969 auto emitAttrErr = [&](StringLiteral attrName) {
970 aggregateRes = callOp->emitOpError()
971 <<
"target '" << calleeAttr <<
"' has '" << attrName
972 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
978 if (!caller.hasAllowConstraintAttr()) {
979 emitAttrErr(AllowConstraintAttr::name);
983 if (!caller.hasAllowWitnessAttr()) {
984 emitAttrErr(AllowWitnessAttr::name);
988 if (!caller.hasAllowWitnessAttr()) {
989 emitAttrErr(AllowWitnessAttr::name);
991 if (!caller.hasAllowConstraintAttr()) {
992 emitAttrErr(AllowConstraintAttr::name);
1002 LogicalResult verifyInputs()
override {
1008 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
1009 if (inputTypes.size() < 1) {
1011 return callOp->emitOpError()
1014 return checkSelfTypeUnknownTarget(
1015 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
1021 LogicalResult verifyOutputs()
override {
1025 Operation::result_type_range resTypes = callOp->getResultTypes();
1026 if (resTypes.size() != 1) {
1028 return callOp->emitOpError().append(
1032 return checkSelfTypeUnknownTarget(
1033 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
1037 if (callOp->getNumResults() != 0) {
1039 return callOp->emitOpError()
1046 LogicalResult verifyTemplateParams()
override {
1048 return verifyNoTemplateInstantiations();
1051 LogicalResult verifyAffineMapParams()
override {
1056 return verifyNoAffineMapInstantiations();
1062 SymbolRefAttr calleeAttr;
1076 return emitOpError(
"requires a 'callee' symbol reference attribute");
1081 if (calleeAttr.getNestedReferences().size() == 1) {
1083 if (parent.hasConstNamed<
TemplateParamOp>(calleeAttr.getRootReference())) {
1086 return UnknownTargetVerifier(
this, tgtKind, calleeAttr).verify();
1088 return this->emitError(
"expected parameterized callee to target a struct function")
1100 if (failed(tgtOpt)) {
1102 << calleeAttr <<
'"';
1104 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
1108 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
1114 return unifications;
1122bool calleeIsStructFunctionImpl(
1123 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
1125 if (callee.getLeafReference() == funcName) {
1153 return getResults().front();
1162 Operation *thisOp = this->getOperation();
1164 assert(succeeded(root));
1187 llvm::SmallVector<ValueRange, 4> output;
1188 output.reserve(input.size());
1189 for (OperandRange r : input) {
1190 output.push_back(r);
1196 FailureOr<SymbolLookupResult<FuncDefOp>> res =
1198 if (failed(res) || res->isManaged()) {
1206 SymbolTableCollection tables;