68 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
78 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
84 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
86 SmallVector<NamedAttribute, 8> attrRef(attrs);
87 return create(location, name, type, llvm::ArrayRef(attrRef));
91 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
92 ArrayRef<DictionaryAttr> argAttrs
94 FuncDefOp func =
create(location, name, type, attrs);
95 func.setAllArgAttrs(argAttrs);
100 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
101 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
103 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
105 state.attributes.append(attrs.begin(), attrs.end());
108 if (argAttrs.empty()) {
111 assert(type.getNumInputs() == argAttrs.size());
112 function_interface_impl::addArgAndResultAttrs(
119 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
120 function_interface_impl::VariadicFlag,
121 std::string &) {
return builder.getFunctionType(argTypes, results); };
123 return function_interface_impl::parseFunctionOp(
130 function_interface_impl::printFunctionOp(
140 llvm::MapVector<StringAttr, Attribute> newAttrMap;
141 for (
const auto &attr : dest->getAttrs()) {
142 newAttrMap.insert({attr.getName(), attr.getValue()});
144 for (
const auto &attr : (*this)->getAttrs()) {
145 newAttrMap.insert({attr.getName(), attr.getValue()});
149 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
150 return NamedAttribute(attrPair.first, attrPair.second);
152 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
165 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
173 unsigned oldNumArgs = oldType.getNumInputs();
174 SmallVector<Type, 4> newInputs;
175 newInputs.reserve(oldNumArgs);
176 for (
unsigned i = 0; i != oldNumArgs; ++i) {
177 if (!mapper.contains(getArgument(i))) {
178 newInputs.push_back(oldType.getInput(i));
184 if (newInputs.size() != oldNumArgs) {
185 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
187 if (ArrayAttr argAttrs = getAllArgAttrs()) {
188 SmallVector<Attribute> newArgAttrs;
189 newArgAttrs.reserve(newInputs.size());
190 for (
unsigned i = 0; i != oldNumArgs; ++i) {
191 if (!mapper.contains(getArgument(i))) {
192 newArgAttrs.push_back(argAttrs[i]);
195 newFunc.setAllArgAttrs(newArgAttrs);
207 return clone(mapper);
212 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
214 getOperation()->removeAttr(AllowConstraintAttr::name);
220 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
222 getOperation()->removeAttr(AllowWitnessAttr::name);
228 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
230 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
235 if (index < this->getNumArguments()) {
236 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
237 return res ? res.contains(PublicAttr::name) :
false;
251 for (Type t : type.getInputs()) {
256 return emitErrorFunc().append(
257 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", t
261 for (Type t : type.getResults()) {
269 WalkResult res = this->walk<WalkOrder::PreOrder>([
this](ModuleOp nestedMod) {
271 "cannot be nested within '", getOperation()->getName(),
"' operations"
273 return WalkResult::interrupt();
275 if (res.wasInterrupted()) {
287 llvm::ArrayRef<Type> resTypes = funcType.getResults();
289 if (resTypes.size() != 1) {
290 return origin.emitOpError().append(
294 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
305verifyFuncTypeProduct(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
307 return verifyFuncTypeCompute(origin, tables, parent);
311verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
314 if (funcType.getResults().size() != 0) {
315 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
319 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
320 if (inputTypes.size() < 1) {
322 <<
"\" must have at least one input type";
324 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
342 return verifyFuncTypeCompute(*
this, tables, parentStructOpt);
344 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt);
346 return verifyFuncTypeProduct(*
this, tables, parentStructOpt);
355 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
356 return SymbolRefAttr::get(getOperation());
359 assert(succeeded(res));
367 assert(!body.empty() &&
"compute() function body is empty");
368 Block &block = body.back();
371 Operation *terminator = block.getTerminator();
372 assert(terminator &&
"compute() function has no terminator");
373 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
376 << terminator->getName() <<
"'\n";
377 llvm_unreachable(
"compute() function must end with ReturnOp");
379 return retOp.getOperands().front();
384 return getArguments().front();
397 auto function = getParentOp<FuncDefOp>();
400 const auto results =
function.getFunctionType().getResults();
401 if (getNumOperands() != results.size()) {
402 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
403 <<
function.getName() <<
") returns " << results.size();
406 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
407 if (!
typesUnify(getOperand(i).getType(), results[i])) {
408 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
409 <<
") doesn't match function result type (" << results[i] <<
")"
410 <<
" in function @" <<
function.getName();
424 auto &prop = state.getOrAddProperties<
Properties>();
425 if (failed(reader.readAttribute(prop.callee)) ||
426 failed(reader.readAttribute(prop.mapOpGroupSizes)) ||
427 failed(reader.readOptionalAttribute(prop.numDimsPerMap))) {
431 if (reader.getBytecodeVersion() < 6) {
432 auto &propStorage = prop.operandSegmentSizes;
433 DenseI32ArrayAttr attr;
434 if (failed(reader.readAttribute(attr))) {
437 if (attr.size() >
static_cast<int64_t
>(
sizeof(propStorage) /
sizeof(int32_t))) {
438 reader.emitError(
"size mismatch for operand/result_segment_size");
441 llvm::copy(ArrayRef<int32_t>(attr), propStorage.begin());
446 if (succeeded(versionOpt)) {
448 if (ver.majorVersion >= 2) {
449 if (failed(reader.readOptionalAttribute(prop.templateParams))) {
455 if (reader.getBytecodeVersion() >= 6) {
456 return reader.readSparseArray(MutableArrayRef(prop.operandSegmentSizes));
463 auto &prop = getProperties();
464 writer.writeAttribute(prop.callee);
465 writer.writeAttribute(prop.mapOpGroupSizes);
466 writer.writeOptionalAttribute(prop.numDimsPerMap);
468 if (writer.getBytecodeVersion() < 6) {
469 auto &propStorage = prop.operandSegmentSizes;
470 writer.writeAttribute(DenseI32ArrayAttr::get(this->getContext(), propStorage));
473 writer.writeOptionalAttribute(prop.templateParams);
475 auto &propStorage = prop.operandSegmentSizes;
476 if (writer.getBytecodeVersion() >= 6) {
477 writer.writeSparseArray(ArrayRef(propStorage));
481static void addTemplateParams(
484 if (!templateParams.empty()) {
490 ArrayRef<Attribute> converted = succeeded(r) ? r.value() : templateParams;
496 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
497 ValueRange argOperands, ArrayRef<Attribute> templateParams
499 odsState.addTypes(resultTypes);
500 odsState.addOperands(argOperands);
504 props.setCallee(callee);
505 addTemplateParams(odsBuilder, props, templateParams);
509 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
510 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands,
511 ArrayRef<Attribute> templateParams
513 odsState.addTypes(resultTypes);
514 odsState.addOperands(argOperands);
516 odsBuilder, odsState, mapOperands, numDimsPerMap,
519 props.setCallee(callee);
520 addTemplateParams(odsBuilder, props, templateParams);
525 if (std::optional<Type> declaredType = targetParam.
getTypeOpt()) {
527 bool compatible =
false;
528 if (llvm::isa<TypeVarType>(*declaredType)) {
529 compatible = llvm::isa<TypeAttr>(paramFromCallOp);
530 }
else if (llvm::isa<FeltType>(*declaredType)) {
531 compatible = llvm::isa<FeltConstAttr, IntegerAttr>(paramFromCallOp) &&
533 }
else if (llvm::isa<IndexType, IntegerType>(*declaredType)) {
537 compatible = llvm::isa<IntegerAttr>(paramFromCallOp) &&
540 llvm_unreachable(
"inconsistent with `isValidConstReadType()`");
544 return this->emitOpError().append(
545 "instantiation value '", paramFromCallOp,
"' is not compatible with parameter \"@",
546 targetParam.getName(),
"\" type restriction ", *declaredType
554 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs
558 assert((callParams.size() == llvm::range_size(targetParamDefs)) &&
"pre-condition");
560 for (
auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
569 llvm::iterator_range<Region::op_iterator<TemplateParamOp>> targetParamDefs,
574 assert((callParams.size() == llvm::range_size(targetParamDefs)) &&
"pre-condition");
576 for (
auto [paramOp, attr] : llvm::zip_equal(targetParamDefs, callParams.getValue())) {
577 auto it = unifications.find({FlatSymbolRefAttr::get(paramOp.getNameAttr()),
Side::RHS});
578 if (it != unifications.end() && !
typeParamsUnify({attr}, {it->second})) {
580 return this->emitOpError().append(
581 "template instantiation value '", attr,
"' for parameter \"@", paramOp.getName(),
582 "\" conflicts with value '", it->second,
"' inferred from function type signature"
591struct CallOpVerifier {
592 CallOpVerifier(
CallOp *c,
FunctionKind tgtFuncKind) : callOp(c), tgtKind(tgtFuncKind) {}
593 CallOpVerifier(
CallOp *c, StringRef tgtName) : CallOpVerifier(c,
fnNameToKind(tgtName)) {}
594 virtual ~CallOpVerifier() =
default;
596 LogicalResult verify() {
599 LogicalResult aggregateResult = success();
600 if (failed(verifyTargetAttributes())) {
601 aggregateResult = failure();
603 if (failed(verifyInputs())) {
604 aggregateResult = failure();
606 if (failed(verifyOutputs())) {
607 aggregateResult = failure();
609 if (failed(verifyTemplateParams())) {
610 aggregateResult = failure();
612 if (failed(verifyAffineMapParams())) {
613 aggregateResult = failure();
615 return aggregateResult;
622 virtual LogicalResult verifyTargetAttributes() = 0;
623 virtual LogicalResult verifyInputs() = 0;
624 virtual LogicalResult verifyOutputs() = 0;
625 virtual LogicalResult verifyTemplateParams() = 0;
626 virtual LogicalResult verifyAffineMapParams() = 0;
629 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
630 LogicalResult aggregateRes = success();
631 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
632 auto emitAttrErr = [&](StringLiteral attrName) {
633 aggregateRes = callOp->emitOpError()
634 <<
"target '@" << target.getName() <<
"' has '" << attrName
635 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
640 emitAttrErr(AllowConstraintAttr::name);
643 emitAttrErr(AllowWitnessAttr::name);
646 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
652 LogicalResult verifyNoTemplateInstantiations() {
655 return callOp->emitOpError().append(
656 "can only have template instantiations when targeting a templated free function"
662 LogicalResult verifyNoAffineMapInstantiations() {
665 return callOp->emitOpError().append(
666 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
672 assert(callOp->getMapOperands().empty());
677struct KnownTargetVerifier :
public CallOpVerifier {
678 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
679 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
680 includeSymNames(tgtRes.getNamespace()) {}
682 LogicalResult verifyTargetAttributes()
override {
683 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
686 LogicalResult verifyInputs()
override {
687 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
690 LogicalResult verifyOutputs()
override {
691 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
694 LogicalResult verifyTemplateParams()
override {
695 auto tgtOp = tgt.getOperation();
698 return verifyNoTemplateInstantiations();
708 auto realParams = tgtOpParent.getConstOps<TemplateParamOp>();
713 llvm::SmallDenseSet<SymbolRefAttr> referencedInSignature;
717 bool allParamsReferenced = llvm::all_of(realParams, [&](TemplateParamOp p) {
718 return referencedInSignature.contains(FlatSymbolRefAttr::get(p.getNameAttr()));
720 if (allParamsReferenced) {
724 return callOp->emitOpError().append(
725 "must provide template instantiation parameters when calling \"@", tgt.getSymName(),
726 "\" because not all template parameters of \"@", tgtOpParent.getSymName(),
727 "\" appear in the function type signature"
733 return llzk::InFlightDiagnosticWrapper(this->callOp->emitOpError());
739 size_t numTemplateParams = llvm::range_size(realParams);
740 if (callParams.size() != numTemplateParams) {
742 return callOp->emitOpError().append(
743 "template instantiation has ", callParams.size(),
" parameter(s) but \"@",
744 tgtOpParent.getSymName(),
"\" expects ", numTemplateParams,
" template parameter(s)"
756 assert(succeeded(unifyResult) &&
"already checked by `verifyInputs()` and `verifyOutputs()`");
760 return verifyNoTemplateInstantiations();
764 LogicalResult verifyAffineMapParams()
override {
772 if (ArrayAttr params = retTy.getParams()) {
774 SmallVector<AffineMapAttr> mapAttrs;
775 for (Attribute a : params) {
776 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
777 mapAttrs.push_back(m);
788 return verifyNoAffineMapInstantiations();
793 template <
typename T>
795 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
796 if (tgtTypes.size() != callOpTypes.size()) {
797 return callOp->emitOpError()
798 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
799 .attachNote(tgt.getLoc())
800 .append(
"callee defined here");
802 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
803 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
804 return callOp->emitOpError().append(
805 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
806 " for ", aspect,
" number ", i
814 FunctionType tgtType;
815 std::vector<llvm::StringRef> includeSymNames;
820LogicalResult checkSelfTypeUnknownTarget(
821 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
823 if (!llvm::isa<TypeVarType>(actualType) ||
824 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
830 return origin->emitOpError().append(
831 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
832 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
848struct UnknownTargetVerifier :
public CallOpVerifier {
849 UnknownTargetVerifier(CallOp *c,
FunctionKind tgtFuncKind, SymbolRefAttr callee)
850 : CallOpVerifier(c, tgtFuncKind), calleeAttr(callee) {
857 LogicalResult verifyTargetAttributes()
override {
860 LogicalResult aggregateRes = success();
861 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
862 auto emitAttrErr = [&](StringLiteral attrName) {
863 aggregateRes = callOp->emitOpError()
864 <<
"target '" << calleeAttr <<
"' has '" << attrName
865 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
871 if (!caller.hasAllowConstraintAttr()) {
872 emitAttrErr(AllowConstraintAttr::name);
876 if (!caller.hasAllowWitnessAttr()) {
877 emitAttrErr(AllowWitnessAttr::name);
881 if (!caller.hasAllowWitnessAttr()) {
882 emitAttrErr(AllowWitnessAttr::name);
884 if (!caller.hasAllowConstraintAttr()) {
885 emitAttrErr(AllowConstraintAttr::name);
895 LogicalResult verifyInputs()
override {
901 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
902 if (inputTypes.size() < 1) {
904 return callOp->emitOpError()
907 return checkSelfTypeUnknownTarget(
908 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
914 LogicalResult verifyOutputs()
override {
918 Operation::result_type_range resTypes = callOp->getResultTypes();
919 if (resTypes.size() != 1) {
921 return callOp->emitOpError().append(
925 return checkSelfTypeUnknownTarget(
926 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
930 if (callOp->getNumResults() != 0) {
932 return callOp->emitOpError()
939 LogicalResult verifyTemplateParams()
override {
941 return verifyNoTemplateInstantiations();
944 LogicalResult verifyAffineMapParams()
override {
949 return verifyNoAffineMapInstantiations();
955 SymbolRefAttr calleeAttr;
969 return emitOpError(
"requires a 'callee' symbol reference attribute");
974 if (calleeAttr.getNestedReferences().size() == 1) {
976 if (parent.hasConstNamed<
TemplateParamOp>(calleeAttr.getRootReference())) {
979 return UnknownTargetVerifier(
this, tgtKind, calleeAttr).verify();
981 return this->emitError(
"expected parameterized callee to target a struct function")
993 if (failed(tgtOpt)) {
995 << calleeAttr <<
'"';
997 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
1001 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
1007 return unifications;
1015bool calleeIsStructFunctionImpl(
1016 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
1018 if (callee.getLeafReference() == funcName) {
1046 return getResults().front();
1055 Operation *thisOp = this->getOperation();
1057 assert(succeeded(root));
1080 llvm::SmallVector<ValueRange, 4> output;
1081 output.reserve(input.size());
1082 for (OperandRange r : input) {
1083 output.push_back(r);
1089 FailureOr<SymbolLookupResult<FuncDefOp>> res =
1091 if (failed(res) || res->isManaged()) {
1099 SymbolTableCollection tables;