Skip to content

Commit 1459eca

Browse files
authored
Merge pull request #81214 from j-hui/base-6.2/swift-function-as-template-arg
2 parents 322a376 + 3f8e9cd commit 1459eca

File tree

6 files changed

+301
-37
lines changed

6 files changed

+301
-37
lines changed

lib/AST/ASTContext.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -6536,13 +6536,16 @@ const clang::Type *
65366536
ASTContext::getClangFunctionType(ArrayRef<AnyFunctionType::Param> params,
65376537
Type resultTy,
65386538
FunctionTypeRepresentation trueRep) {
6539-
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep);
6539+
return getClangTypeConverter().getFunctionType(params, resultTy, trueRep,
6540+
/*templateArgument=*/false);
65406541
}
65416542

65426543
const clang::Type *ASTContext::getCanonicalClangFunctionType(
65436544
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
65446545
SILFunctionType::Representation trueRep) {
6545-
auto *ty = getClangTypeConverter().getFunctionType(params, result, trueRep);
6546+
auto *ty =
6547+
getClangTypeConverter().getFunctionType(params, result, trueRep,
6548+
/*templateArgument=*/false);
65466549
return ty ? ty->getCanonicalTypeInternal().getTypePtr() : nullptr;
65476550
}
65486551

lib/AST/ClangTypeConverter.cpp

+49-21
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "clang/Basic/TargetInfo.h"
4141
#include "clang/Sema/Sema.h"
4242

43+
#include "llvm/ADT/STLExtras.h"
4344
#include "llvm/ADT/StringSwitch.h"
4445
#include "llvm/Support/Compiler.h"
4546

@@ -124,17 +125,18 @@ const clang::ASTContext &clangCtx) {
124125

125126
const clang::Type *ClangTypeConverter::getFunctionType(
126127
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
127-
AnyFunctionType::Representation repr) {
128-
129-
auto resultClangTy = convert(resultTy);
128+
AnyFunctionType::Representation repr, bool templateArgument) {
129+
auto resultClangTy =
130+
templateArgument ? convertTemplateArgument(resultTy) : convert(resultTy);
130131
if (resultClangTy.isNull())
131132
return nullptr;
132133

133134
SmallVector<clang::FunctionProtoType::ExtParameterInfo, 4> extParamInfos;
134135
SmallVector<clang::QualType, 4> paramsClangTy;
135136
bool someParamIsConsumed = false;
136137
for (auto p : params) {
137-
auto pc = convert(p.getPlainType());
138+
auto pc = templateArgument ? convertTemplateArgument(p.getPlainType())
139+
: convert(p.getPlainType());
138140
if (pc.isNull())
139141
return nullptr;
140142
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
@@ -165,16 +167,19 @@ const clang::Type *ClangTypeConverter::getFunctionType(
165167
llvm_unreachable("invalid representation");
166168
}
167169

168-
const clang::Type *
169-
ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
170-
std::optional<SILResultInfo> result,
171-
SILFunctionType::Representation repr) {
172-
173-
// Using the interface type is sufficient as type parameters get mapped to
174-
// `id`, since ObjC lightweight generics use type erasure. (See also: SE-0057)
175-
auto resultClangTy = result.has_value()
176-
? convert(result.value().getInterfaceType())
177-
: ClangASTContext.VoidTy;
170+
const clang::Type *ClangTypeConverter::getFunctionType(
171+
ArrayRef<SILParameterInfo> params, std::optional<SILResultInfo> result,
172+
SILFunctionType::Representation repr, bool templateArgument) {
173+
clang::QualType resultClangTy = ClangASTContext.VoidTy;
174+
if (result) {
175+
// Using the interface type is sufficient as type parameters get mapped to
176+
// `id`, since ObjC lightweight generics use type erasure.
177+
//
178+
// (See also: SE-0057)
179+
auto interfaceType = result->getInterfaceType();
180+
resultClangTy = templateArgument ? convertTemplateArgument(interfaceType)
181+
: convert(interfaceType);
182+
}
178183

179184
if (resultClangTy.isNull())
180185
return nullptr;
@@ -183,7 +188,8 @@ ClangTypeConverter::getFunctionType(ArrayRef<SILParameterInfo> params,
183188
SmallVector<clang::QualType, 4> paramsClangTy;
184189
bool someParamIsConsumed = false;
185190
for (auto &p : params) {
186-
auto pc = convert(p.getInterfaceType());
191+
auto pc = templateArgument ? convertTemplateArgument(p.getInterfaceType())
192+
: convert(p.getInterfaceType());
187193
if (pc.isNull())
188194
return nullptr;
189195
clang::FunctionProtoType::ExtParameterInfo extParamInfo;
@@ -651,7 +657,8 @@ clang::QualType ClangTypeConverter::visitEnumType(EnumType *type) {
651657
return convert(type->getDecl()->getRawType());
652658
}
653659

654-
clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
660+
clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type,
661+
bool templateArgument) {
655662
const clang::Type *clangTy = nullptr;
656663
auto repr = type->getRepresentation();
657664
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
@@ -665,12 +672,15 @@ clang::QualType ClangTypeConverter::visitFunctionType(FunctionType *type) {
665672
auto newRepr = (repr == FunctionTypeRepresentation::Swift
666673
? FunctionTypeRepresentation::Block
667674
: repr);
668-
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr);
675+
clangTy = getFunctionType(type->getParams(), type->getResult(), newRepr,
676+
templateArgument);
669677
}
670678
return clang::QualType(clangTy, 0);
671679
}
672680

673-
clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type) {
681+
clang::QualType
682+
ClangTypeConverter::visitSILFunctionType(SILFunctionType *type,
683+
bool templateArgument) {
674684
const clang::Type *clangTy = nullptr;
675685
auto repr = type->getRepresentation();
676686
bool useClangTypes = type->getASTContext().LangOpts.UseClangFunctionTypes;
@@ -688,7 +698,8 @@ clang::QualType ClangTypeConverter::visitSILFunctionType(SILFunctionType *type)
688698
auto optionalResult = results.empty()
689699
? std::nullopt
690700
: std::optional<SILResultInfo>(results[0]);
691-
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr);
701+
clangTy = getFunctionType(type->getParameters(), optionalResult, newRepr,
702+
templateArgument);
692703
}
693704
return clang::QualType(clangTy, 0);
694705
}
@@ -933,6 +944,13 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
933944
if (auto floatType = type->getAs<BuiltinFloatType>())
934945
return withCache([&]() { return visitBuiltinFloatType(floatType); });
935946

947+
if (auto tupleType = type->getAs<TupleType>()) {
948+
// We do not call visitTupleType() because we cannot yet handle tuples with
949+
// a non-zero number of elements.
950+
if (tupleType->getNumElements() == 0)
951+
return ClangASTContext.VoidTy;
952+
}
953+
936954
if (auto structType = type->getAs<StructType>()) {
937955
// Swift structs are not supported in general, but some foreign types are
938956
// imported as Swift structs. We reverse that mapping here.
@@ -953,8 +971,6 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
953971
return withCache([&]() { return reverseBuiltinTypeMapping(structType); });
954972
}
955973

956-
// TODO: function pointers are not yet supported, but they should be.
957-
958974
if (auto boundGenericType = type->getAs<BoundGenericType>()) {
959975
if (boundGenericType->getGenericArgs().size() != 1)
960976
// Must've got something other than a T?, *Pointer<T>, or SIMD*<T>
@@ -991,6 +1007,18 @@ clang::QualType ClangTypeConverter::convertTemplateArgument(Type type) {
9911007
return clang::QualType();
9921008
}
9931009

1010+
if (auto functionType = type->getAs<FunctionType>()) {
1011+
return withCache([&]() {
1012+
return visitFunctionType(functionType, /*templateArgument=*/true);
1013+
});
1014+
}
1015+
1016+
if (auto functionType = type->getAs<SILFunctionType>()) {
1017+
return withCache([&]() {
1018+
return visitSILFunctionType(functionType, /*templateArgument=*/true);
1019+
});
1020+
}
1021+
9941022
// Most types cannot be used to instantiate C++ function templates; give up.
9951023
return clang::QualType();
9961024
}

lib/AST/ClangTypeConverter.h

+10-6
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,16 @@ class ClangTypeConverter :
7070
/// \returns The appropriate clang type on success, nullptr on failure.
7171
///
7272
/// Precondition: The representation argument must be C-compatible.
73-
const clang::Type *getFunctionType(
74-
ArrayRef<AnyFunctionType::Param> params, Type resultTy,
75-
AnyFunctionType::Representation repr);
73+
const clang::Type *getFunctionType(ArrayRef<AnyFunctionType::Param> params,
74+
Type resultTy,
75+
AnyFunctionType::Representation repr,
76+
bool templateArgument);
7677

7778
/// Compute the C function type for a SIL function type.
7879
const clang::Type *getFunctionType(ArrayRef<SILParameterInfo> params,
7980
std::optional<SILResultInfo> result,
80-
SILFunctionType::Representation repr);
81+
SILFunctionType::Representation repr,
82+
bool templateArgument);
8183

8284
/// Check whether the given Clang declaration is an export of a Swift
8385
/// declaration introduced by this converter, and if so, return the original
@@ -148,15 +150,17 @@ class ClangTypeConverter :
148150
clang::QualType visitBoundGenericClassType(BoundGenericClassType *type);
149151
clang::QualType visitBoundGenericType(BoundGenericType *type);
150152
clang::QualType visitEnumType(EnumType *type);
151-
clang::QualType visitFunctionType(FunctionType *type);
153+
clang::QualType visitFunctionType(FunctionType *type,
154+
bool templateArgument = false);
152155
clang::QualType visitProtocolCompositionType(ProtocolCompositionType *type);
153156
clang::QualType visitExistentialType(ExistentialType *type);
154157
clang::QualType visitBuiltinRawPointerType(BuiltinRawPointerType *type);
155158
clang::QualType visitBuiltinIntegerType(BuiltinIntegerType *type);
156159
clang::QualType visitBuiltinFloatType(BuiltinFloatType *type);
157160
clang::QualType visitArchetypeType(ArchetypeType *type);
158161
clang::QualType visitDependentMemberType(DependentMemberType *type);
159-
clang::QualType visitSILFunctionType(SILFunctionType *type);
162+
clang::QualType visitSILFunctionType(SILFunctionType *type,
163+
bool templateArgument = false);
160164
clang::QualType visitGenericTypeParamType(GenericTypeParamType *type);
161165
clang::QualType visitDynamicSelfType(DynamicSelfType *type);
162166
clang::QualType visitSILBlockStorageType(SILBlockStorageType *type);

test/Interop/Cxx/templates/Inputs/function-templates.h

+36-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,27 @@ template <class T> void expectsConstCharPtr(T str) { takesString(str); }
1515
template <long x> void hasNonTypeTemplateParameter() {}
1616
template <long x = 0> void hasDefaultedNonTypeTemplateParameter() {}
1717

18+
// NOTE: these will cause multi-def linker errors if used in more than one compilation unit
1819
int *intPtr;
19-
int (*functionPtr)(void);
20+
21+
int get42(void) { return 42; }
22+
int (*functionPtrGet42)(void) = &get42;
23+
int (*_Nonnull nonNullFunctionPtrGet42)(void) = &get42;
24+
25+
int tripleInt(int x) { return x * 3; }
26+
int (*functionPtrTripleInt)(int) = &tripleInt;
27+
int (*_Nonnull nonNullFunctionPtrTripleInt)(int) = &tripleInt;
28+
29+
int (^blockReturns111)(void) = ^{ return 111; };
30+
int (^_Nonnull nonNullBlockReturns222)(void) = ^{ return 222; };
31+
32+
int (^blockTripleInt)(int) = ^(int x) { return x * 3; };
33+
int (^_Nonnull nonNullBlockTripleInt)(int) = ^(int x) { return x * 3; };
34+
35+
// These functions construct block literals that capture a local variable, and
36+
// then feed those blocks back to Swift via the given Swift closure (cb).
37+
void getConstantIntBlock(int returnValue, void (^_Nonnull cb)(int (^_Nonnull)(void))) { cb(^{ return returnValue; }); }
38+
int getMultiplyIntBlock(int multiplier, int (^_Nonnull cb)(int (^_Nonnull)(int))) { return cb(^(int x) { return x * multiplier; }); }
2039

2140
// We cannot yet use this in Swift but, make sure we don't crash when parsing
2241
// it.
@@ -59,6 +78,7 @@ struct PlainStruct {
5978
struct CxxClass {
6079
int x;
6180
void method() {}
81+
int getX() const { return x; }
6282
};
6383

6484
struct __attribute__((swift_attr("import_reference")))
@@ -102,6 +122,21 @@ template <class T> void forwardingReference(T &&) {}
102122

103123
template <class T> void PointerTemplateParameter(T*){}
104124

125+
template <typename F> void callFunction(F f) { f(); }
126+
template <typename F, typename T> void callFunctionWithParam(F f, T t) { f(t); }
127+
template <typename F, typename T> T callFunctionWithReturn(F f) { return f(); }
128+
template <typename F, typename T> T callFunctionWithPassthrough(F f, T t) { return f(t); }
129+
130+
static inline void callBlock(void (^_Nonnull callback)(void)) { callback(); }
131+
template <typename F> void indirectlyCallFunction(F f) { callBlock(f); }
132+
template <typename F> void indirectlyCallFunctionTemplate(F f) { callFunction(f); }
133+
134+
static inline void callBlockWith42(void (^_Nonnull callback)(int)) { callback(42); }
135+
template <typename F> void indirectlyCallFunctionWith42(F f) { callBlockWith42(f); }
136+
137+
static inline void callBlockWithCxxClass24(void (^_Nonnull cb)(CxxClass)) { CxxClass c = {24}; cb(c); }
138+
template <typename F> void indirectlyCallFunctionWithCxxClass24(F f) { callBlockWithCxxClass24(f); }
139+
105140
namespace Orbiters {
106141

107142
template<class T>

0 commit comments

Comments
 (0)