58 tables, origin, ArrayRef<ArrayRef<Type>> {funcType.getInputs(), funcType.getResults()}
68 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs
74 Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs
76 SmallVector<NamedAttribute, 8> attrRef(attrs);
77 return create(location, name, type, llvm::ArrayRef(attrRef));
81 Location location, StringRef name, FunctionType type, ArrayRef<NamedAttribute> attrs,
82 ArrayRef<DictionaryAttr> argAttrs
84 FuncDefOp func =
create(location, name, type, attrs);
85 func.setAllArgAttrs(argAttrs);
90 OpBuilder &builder, OperationState &state, StringRef name, FunctionType type,
91 ArrayRef<NamedAttribute> attrs, ArrayRef<DictionaryAttr> argAttrs
93 state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name));
95 state.attributes.append(attrs.begin(), attrs.end());
98 if (argAttrs.empty()) {
101 assert(type.getNumInputs() == argAttrs.size());
102 function_interface_impl::addArgAndResultAttrs(
109 auto buildFuncType = [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
110 function_interface_impl::VariadicFlag,
111 std::string &) {
return builder.getFunctionType(argTypes, results); };
113 return function_interface_impl::parseFunctionOp(
120 function_interface_impl::printFunctionOp(
130 llvm::MapVector<StringAttr, Attribute> newAttrMap;
131 for (
const auto &attr : dest->getAttrs()) {
132 newAttrMap.insert({attr.getName(), attr.getValue()});
134 for (
const auto &attr : (*this)->getAttrs()) {
135 newAttrMap.insert({attr.getName(), attr.getValue()});
139 llvm::to_vector(llvm::map_range(newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
140 return NamedAttribute(attrPair.first, attrPair.second);
142 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
155 FuncDefOp newFunc = llvm::cast<FuncDefOp>(getOperation()->cloneWithoutRegions());
163 unsigned oldNumArgs = oldType.getNumInputs();
164 SmallVector<Type, 4> newInputs;
165 newInputs.reserve(oldNumArgs);
166 for (
unsigned i = 0; i != oldNumArgs; ++i) {
167 if (!mapper.contains(getArgument(i))) {
168 newInputs.push_back(oldType.getInput(i));
174 if (newInputs.size() != oldNumArgs) {
175 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults()));
177 if (ArrayAttr argAttrs = getAllArgAttrs()) {
178 SmallVector<Attribute> newArgAttrs;
179 newArgAttrs.reserve(newInputs.size());
180 for (
unsigned i = 0; i != oldNumArgs; ++i) {
181 if (!mapper.contains(getArgument(i))) {
182 newArgAttrs.push_back(argAttrs[i]);
185 newFunc.setAllArgAttrs(newArgAttrs);
197 return clone(mapper);
202 getOperation()->setAttr(AllowConstraintAttr::name, UnitAttr::get(getContext()));
204 getOperation()->removeAttr(AllowConstraintAttr::name);
210 getOperation()->setAttr(AllowWitnessAttr::name, UnitAttr::get(getContext()));
212 getOperation()->removeAttr(AllowWitnessAttr::name);
218 getOperation()->setAttr(AllowNonNativeFieldOpsAttr::name, UnitAttr::get(getContext()));
220 getOperation()->removeAttr(AllowNonNativeFieldOpsAttr::name);
225 if (index < this->getNumArguments()) {
226 DictionaryAttr res = function_interface_impl::getArgAttrDict(*
this, index);
227 return res ? res.contains(PublicAttr::name) :
false;
241 for (Type t : type.getInputs()) {
246 return emitErrorFunc().append(
247 "\"@", getName(),
"\" parameters cannot contain affine map attributes but found ", t
251 for (Type t : type.getResults()) {
259 WalkResult res = this->walk<WalkOrder::PreOrder>([
this](ModuleOp nestedMod) {
261 "cannot be nested within '", getOperation()->getName(),
"' operations"
263 return WalkResult::interrupt();
265 if (res.wasInterrupted()) {
277 llvm::ArrayRef<Type> resTypes = funcType.getResults();
279 if (resTypes.size() != 1) {
280 return origin.emitOpError().append(
284 if (failed(
checkSelfType(tables, parent, resTypes.front(), origin,
"return"))) {
295verifyFuncTypeProduct(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
297 return verifyFuncTypeCompute(origin, tables, parent);
301verifyFuncTypeConstrain(
FuncDefOp &origin, SymbolTableCollection &tables, StructDefOp &parent) {
304 if (funcType.getResults().size() != 0) {
305 return origin.emitOpError() <<
"\"@" <<
FUNC_NAME_CONSTRAIN <<
"\" must have no return type";
309 llvm::ArrayRef<Type> inputTypes = funcType.getInputs();
310 if (inputTypes.size() < 1) {
312 <<
"\" must have at least one input type";
314 if (failed(
checkSelfType(tables, parent, inputTypes.front(), origin,
"first input"))) {
330 if (succeeded(parentStructOpt)) {
333 return verifyFuncTypeCompute(*
this, tables, parentStructOpt.value());
335 return verifyFuncTypeConstrain(*
this, tables, parentStructOpt.value());
337 return verifyFuncTypeProduct(*
this, tables, parentStructOpt.value());
346 if (!requireParent && getOperation()->getParentOp() ==
nullptr) {
347 return SymbolRefAttr::get(getOperation());
350 assert(succeeded(res));
358 assert(!body.empty() &&
"compute() function body is empty");
359 Block &block = body.back();
362 Operation *terminator = block.getTerminator();
363 assert(terminator &&
"compute() function has no terminator");
364 auto retOp = llvm::dyn_cast<ReturnOp>(terminator);
367 << terminator->getName() <<
"'\n";
368 llvm_unreachable(
"compute() function must end with ReturnOp");
370 return retOp.getOperands().front();
375 return getArguments().front();
388 auto function = getParentOp<FuncDefOp>();
391 const auto results =
function.getFunctionType().getResults();
392 if (getNumOperands() != results.size()) {
393 return emitOpError(
"has ") << getNumOperands() <<
" operands, but enclosing function (@"
394 <<
function.getName() <<
") returns " << results.size();
397 for (
unsigned i = 0, e = results.size(); i != e; ++i) {
398 if (!
typesUnify(getOperand(i).getType(), results[i])) {
399 return emitError() <<
"type of return operand " << i <<
" (" << getOperand(i).getType()
400 <<
") doesn't match function result type (" << results[i] <<
")"
401 <<
" in function @" <<
function.getName();
413 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
414 ValueRange argOperands
416 odsState.addTypes(resultTypes);
417 odsState.addOperands(argOperands);
419 odsBuilder, odsState,
static_cast<int32_t
>(argOperands.size())
421 props.setCallee(callee);
425 OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, SymbolRefAttr callee,
426 ArrayRef<ValueRange> mapOperands, DenseI32ArrayAttr numDimsPerMap, ValueRange argOperands
428 odsState.addTypes(resultTypes);
429 odsState.addOperands(argOperands);
431 odsBuilder, odsState, mapOperands, numDimsPerMap, argOperands.size()
433 props.setCallee(callee);
438struct CallOpVerifier {
439 CallOpVerifier(
CallOp *c, StringRef tgtName) : callOp(c), tgtKind(
fnNameToKind(tgtName)) {}
440 virtual ~CallOpVerifier() =
default;
442 LogicalResult verify() {
445 LogicalResult aggregateResult = success();
446 if (failed(verifyTargetAttributes())) {
447 aggregateResult = failure();
449 if (failed(verifyInputs())) {
450 aggregateResult = failure();
452 if (failed(verifyOutputs())) {
453 aggregateResult = failure();
455 if (failed(verifyAffineMapParams())) {
456 aggregateResult = failure();
458 return aggregateResult;
465 virtual LogicalResult verifyTargetAttributes() = 0;
466 virtual LogicalResult verifyInputs() = 0;
467 virtual LogicalResult verifyOutputs() = 0;
468 virtual LogicalResult verifyAffineMapParams() = 0;
471 LogicalResult verifyTargetAttributesMatch(FuncDefOp target) {
472 LogicalResult aggregateRes = success();
473 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
474 auto emitAttrErr = [&](StringLiteral attrName) {
475 aggregateRes = callOp->emitOpError()
476 <<
"target '@" << target.getName() <<
"' has '" << attrName
477 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
482 emitAttrErr(AllowConstraintAttr::name);
485 emitAttrErr(AllowWitnessAttr::name);
488 emitAttrErr(AllowNonNativeFieldOpsAttr::name);
494 LogicalResult verifyNoAffineMapInstantiations() {
497 return callOp->emitOpError().append(
498 "can only have affine map instantiations when targeting a \"@",
FUNC_NAME_COMPUTE,
504 assert(callOp->getMapOperands().empty());
509struct KnownTargetVerifier :
public CallOpVerifier {
510 KnownTargetVerifier(CallOp *c, SymbolLookupResult<FuncDefOp> &&tgtRes)
511 : CallOpVerifier(c, tgtRes.get().getSymName()), tgt(*tgtRes), tgtType(tgt.getFunctionType()),
512 includeSymNames(tgtRes.getNamespace()) {}
514 LogicalResult verifyTargetAttributes()
override {
515 return CallOpVerifier::verifyTargetAttributesMatch(tgt);
518 LogicalResult verifyInputs()
override {
519 return verifyTypesMatch(callOp->
getArgOperands().getTypes(), tgtType.getInputs(),
"operand");
522 LogicalResult verifyOutputs()
override {
523 return verifyTypesMatch(callOp->getResultTypes(), tgtType.getResults(),
"result");
526 LogicalResult verifyAffineMapParams()
override {
534 if (ArrayAttr params = retTy.getParams()) {
536 SmallVector<AffineMapAttr> mapAttrs;
537 for (Attribute a : params) {
538 if (AffineMapAttr m = dyn_cast<AffineMapAttr>(a)) {
539 mapAttrs.push_back(m);
550 return verifyNoAffineMapInstantiations();
555 template <
typename T>
557 verifyTypesMatch(ValueTypeRange<T> callOpTypes, ArrayRef<Type> tgtTypes,
const char *aspect) {
558 if (tgtTypes.size() != callOpTypes.size()) {
559 return callOp->emitOpError()
560 .append(
"incorrect number of ", aspect,
"s for callee, expected ", tgtTypes.size())
561 .attachNote(tgt.getLoc())
562 .append(
"callee defined here");
564 for (
unsigned i = 0, e = tgtTypes.size(); i != e; ++i) {
565 if (!
typesUnify(callOpTypes[i], tgtTypes[i], includeSymNames)) {
566 return callOp->emitOpError().append(
567 aspect,
" type mismatch: expected type ", tgtTypes[i],
", but found ", callOpTypes[i],
568 " for ", aspect,
" number ", i
576 FunctionType tgtType;
577 std::vector<llvm::StringRef> includeSymNames;
582LogicalResult checkSelfTypeUnknownTarget(
583 StringAttr expectedParamName, Type actualType,
CallOp *origin,
const char *aspect
585 if (!llvm::isa<TypeVarType>(actualType) ||
586 llvm::cast<TypeVarType>(actualType).getRefName() != expectedParamName) {
592 return origin->emitOpError().append(
593 "target \"@", origin->
getCallee().getLeafReference().getValue(),
"\" expected ", aspect,
594 " type '!",
TypeVarType::name,
"<@", expectedParamName.getValue(),
">' but found ",
610struct UnknownTargetVerifier :
public CallOpVerifier {
611 UnknownTargetVerifier(CallOp *c, SymbolRefAttr callee)
612 : CallOpVerifier(c, callee.getLeafReference().getValue()), calleeAttr(callee) {}
614 LogicalResult verifyTargetAttributes()
override {
617 LogicalResult aggregateRes = success();
618 if (FuncDefOp caller = (*callOp)->getParentOfType<FuncDefOp>()) {
619 auto emitAttrErr = [&](StringLiteral attrName) {
620 aggregateRes = callOp->emitOpError()
621 <<
"target '" << calleeAttr <<
"' has '" << attrName
622 <<
"' attribute, which is not specified by the caller '@" << caller.getName()
628 if (!caller.hasAllowConstraintAttr()) {
629 emitAttrErr(AllowConstraintAttr::name);
633 if (!caller.hasAllowWitnessAttr()) {
634 emitAttrErr(AllowWitnessAttr::name);
638 if (!caller.hasAllowWitnessAttr()) {
639 emitAttrErr(AllowWitnessAttr::name);
641 if (!caller.hasAllowConstraintAttr()) {
642 emitAttrErr(AllowConstraintAttr::name);
652 LogicalResult verifyInputs()
override {
658 Operation::operand_type_range inputTypes = callOp->
getArgOperands().getTypes();
659 if (inputTypes.size() < 1) {
661 return callOp->emitOpError()
664 return checkSelfTypeUnknownTarget(
665 calleeAttr.getRootReference(), inputTypes.front(), callOp,
"first input"
671 LogicalResult verifyOutputs()
override {
675 Operation::result_type_range resTypes = callOp->getResultTypes();
676 if (resTypes.size() != 1) {
678 return callOp->emitOpError().append(
682 return checkSelfTypeUnknownTarget(
683 calleeAttr.getRootReference(), resTypes.front(), callOp,
"return"
687 if (callOp->getNumResults() != 0) {
689 return callOp->emitOpError()
696 LogicalResult verifyAffineMapParams()
override {
701 return verifyNoAffineMapInstantiations();
707 SymbolRefAttr calleeAttr;
721 return emitOpError(
"requires a 'callee' symbol reference attribute");
726 if (calleeAttr.getNestedReferences().size() == 1) {
728 if (succeeded(parent) && parent->hasParamNamed(calleeAttr.getRootReference())) {
729 return UnknownTargetVerifier(
this, calleeAttr).verify();
736 if (failed(tgtOpt)) {
738 << calleeAttr <<
'"';
740 return KnownTargetVerifier(
this, std::move(*tgtOpt)).verify();
744 return FunctionType::get(getContext(),
getArgOperands().getTypes(), getResultTypes());
749bool calleeIsStructFunctionImpl(
750 const char *funcName, SymbolRefAttr callee, llvm::function_ref<
StructType()> getType
752 if (callee.getLeafReference() == funcName) {
780 return getResults().front();
789 Operation *thisOp = this->getOperation();
791 assert(succeeded(root));
814 llvm::SmallVector<ValueRange, 4> output;
815 output.reserve(input.size());
816 for (OperandRange r : input) {