Skip to content

Commit 7fc815e

Browse files
committed
[cxx-interop] Allow initializing std::function from Swift capturing closures
This introduces support for converting a Swift closure that captures variables from its surrounding context into an instance of `std::function`, which is useful for working with C++ APIs that use callbacks. Each instantiation of `std::function` gets a synthesized Swift constructor that takes a Swift closure. Unlike the previous implementation, the closure is _not_ marked as `@convention(c)`. The body of the constructor is created lazily. Under the hood, the closure is bitcast to a pair of a function pointer and a context pointer, which are then wrapped in a C++ object, `__SwiftFunctionWrapper`, that manages the lifetime of the context object via calls to `swift_retain`/`swift_release` from the copy constructor and the destructor. The `__SwiftFunctionWrapper` class is templated, and is instantiated by ClangImporter. rdar://133777029
1 parent 11d7c1e commit 7fc815e

23 files changed

+912
-103
lines changed

include/swift/AST/ClangModuleLoader.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,24 @@ class ClangModuleLoader : public ModuleLoader {
298298

299299
virtual FuncDecl *getDefaultArgGenerator(const clang::ParmVarDecl *param) = 0;
300300

301+
/// Determine whether this is a functional C++ type, e.g. std::function, for
302+
/// which Swift provides a synthesized constructor that takes a Swift closure
303+
/// as the single parameter.
304+
virtual bool
305+
needsClosureConstructor(const clang::CXXRecordDecl *recordDecl) const = 0;
306+
307+
/// Determine whether this is an instantiation of the __SwiftFunctionWrapper
308+
/// type, which wraps around a Swift closure along with its context.
309+
virtual bool isSwiftFunctionWrapper(const clang::RecordDecl *decl) const = 0;
310+
virtual bool isDeconstructedSwiftClosure(const clang::Type* type) const = 0;
311+
312+
/// Given a functional C++ type, e.g. std::function, determine the
313+
/// corresponding C++ closure type.
314+
///
315+
/// \see needsClosureConstructor
316+
virtual const clang::FunctionType *extractCXXFunctionType(
317+
const clang::CXXRecordDecl *functionalTypeDecl) const = 0;
318+
301319
virtual FuncDecl *
302320
getAvailabilityDomainPredicate(const clang::VarDecl *var) = 0;
303321

include/swift/AST/Types.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ namespace llvm {
5454
struct fltSemantics;
5555
} // namespace llvm
5656

57+
namespace clang {
58+
class FunctionType;
59+
} // namespace clang
60+
5761
namespace swift {
5862

5963
enum class AllocationArena;
@@ -4101,6 +4105,10 @@ class FunctionType final
41014105
return getLifetimeDependenceFor(getNumParams());
41024106
}
41034107

4108+
uint16_t
4109+
getPointerAuthDiscriminator(ModuleDecl &m,
4110+
const clang::FunctionType *clangType = nullptr);
4111+
41044112
void Profile(llvm::FoldingSetNodeID &ID) {
41054113
std::optional<ExtInfo> info = std::nullopt;
41064114
if (hasExtInfo())

include/swift/ClangImporter/ClangImporter.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,15 @@ class ClangImporter final : public ClangModuleLoader {
620620

621621
FuncDecl *getDefaultArgGenerator(const clang::ParmVarDecl *param) override;
622622

623+
bool needsClosureConstructor(
624+
const clang::CXXRecordDecl *recordDecl) const override;
625+
626+
bool isSwiftFunctionWrapper(const clang::RecordDecl *decl) const override;
627+
bool isDeconstructedSwiftClosure(const clang::Type *type) const override;
628+
629+
const clang::FunctionType *extractCXXFunctionType(
630+
const clang::CXXRecordDecl *functionalTypeDecl) const override;
631+
623632
FuncDecl *getAvailabilityDomainPredicate(const clang::VarDecl *var) override;
624633

625634
bool isAnnotatedWith(const clang::CXXMethodDecl *method, StringRef attr);

include/swift/SIL/AbstractionPattern.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace llvm {
2828

2929
namespace clang {
3030
class CXXMethodDecl;
31+
class CXXRecordDecl;
3132
class ObjCMethodDecl;
3233
class Type;
3334
class ValueDecl;
@@ -198,6 +199,10 @@ class AbstractionPattern {
198199
/// non-static member function. OrigType is valid and is a function type.
199200
/// CXXMethod is valid.
200201
PartialCurriedCXXMethodType,
202+
/// The type of a constructor that initializes a C++ functional type, e.g.
203+
/// std::function, with a Swift closure. This constructor is synthesized by
204+
/// ClangImporter. ClangType is valid.
205+
CXXFunctionalConstructorType,
201206
/// A Swift function whose parameters and results are opaque. This is
202207
/// like `AP::Type<T>((T) -> T)`, except that the number of parameters is
203208
/// unspecified.
@@ -462,6 +467,7 @@ class AbstractionPattern {
462467
case Kind::CFunctionAsMethodType:
463468
case Kind::CurriedCFunctionAsMethodType:
464469
case Kind::PartialCurriedCFunctionAsMethodType:
470+
case Kind::CXXFunctionalConstructorType:
465471
case Kind::ObjCCompletionHandlerArgumentsType:
466472
return true;
467473

@@ -632,6 +638,7 @@ class AbstractionPattern {
632638
case Kind::CXXMethodType:
633639
case Kind::CurriedCXXMethodType:
634640
case Kind::PartialCurriedCXXMethodType:
641+
case Kind::CXXFunctionalConstructorType:
635642
case Kind::ObjCCompletionHandlerArgumentsType:
636643
return true;
637644
case Kind::Invalid:
@@ -766,6 +773,13 @@ class AbstractionPattern {
766773
return pattern;
767774
}
768775

776+
/// Return an abstraction pattern for a constructor of a functional C++ type,
777+
/// e.g. std::function, which takes a Swift closure as a single parameter.
778+
/// This constructor was synthesized by ClangImporter.
779+
static AbstractionPattern
780+
getCXXFunctionalConstructor(CanType origType,
781+
const clang::CXXRecordDecl *functionalTypeDecl);
782+
769783
/// For a C-function-as-method pattern,
770784
/// get the index of the C function parameter that was imported as the
771785
/// `self` parameter of the imported method, or None if this is a static
@@ -1048,6 +1062,7 @@ class AbstractionPattern {
10481062
case Kind::CXXMethodType:
10491063
case Kind::CurriedCXXMethodType:
10501064
case Kind::PartialCurriedCXXMethodType:
1065+
case Kind::CXXFunctionalConstructorType:
10511066
case Kind::Type:
10521067
case Kind::Discard:
10531068
return OrigType;
@@ -1084,6 +1099,7 @@ class AbstractionPattern {
10841099
case Kind::CXXMethodType:
10851100
case Kind::CurriedCXXMethodType:
10861101
case Kind::PartialCurriedCXXMethodType:
1102+
case Kind::CXXFunctionalConstructorType:
10871103
case Kind::Type:
10881104
case Kind::Discard:
10891105
case Kind::ObjCCompletionHandlerArgumentsType:
@@ -1131,6 +1147,7 @@ class AbstractionPattern {
11311147
case Kind::CFunctionAsMethodType:
11321148
case Kind::CurriedCFunctionAsMethodType:
11331149
case Kind::PartialCurriedCFunctionAsMethodType:
1150+
case Kind::CXXFunctionalConstructorType:
11341151
case Kind::CXXMethodType:
11351152
case Kind::CurriedCXXMethodType:
11361153
case Kind::PartialCurriedCXXMethodType:
@@ -1148,7 +1165,8 @@ class AbstractionPattern {
11481165
/// Return whether this abstraction pattern represents a Clang type.
11491166
/// If so, it is legal to return getClangType().
11501167
bool isClangType() const {
1151-
return (getKind() == Kind::ClangType);
1168+
return getKind() == Kind::ClangType ||
1169+
getKind() == Kind::CXXFunctionalConstructorType;
11521170
}
11531171

11541172
const clang::Type *getClangType() const {
@@ -1211,6 +1229,7 @@ class AbstractionPattern {
12111229
case Kind::CXXMethodType:
12121230
case Kind::CurriedCXXMethodType:
12131231
case Kind::PartialCurriedCXXMethodType:
1232+
case Kind::CXXFunctionalConstructorType:
12141233
case Kind::OpaqueFunction:
12151234
case Kind::OpaqueDerivativeFunction:
12161235
case Kind::ObjCCompletionHandlerArgumentsType:
@@ -1243,6 +1262,7 @@ class AbstractionPattern {
12431262
case Kind::CFunctionAsMethodType:
12441263
case Kind::CurriedCFunctionAsMethodType:
12451264
case Kind::PartialCurriedCFunctionAsMethodType:
1265+
case Kind::CXXFunctionalConstructorType:
12461266
case Kind::CXXMethodType:
12471267
case Kind::CurriedCXXMethodType:
12481268
case Kind::PartialCurriedCXXMethodType:
@@ -1275,6 +1295,7 @@ class AbstractionPattern {
12751295
case Kind::CXXMethodType:
12761296
case Kind::CurriedCXXMethodType:
12771297
case Kind::PartialCurriedCXXMethodType:
1298+
case Kind::CXXFunctionalConstructorType:
12781299
case Kind::OpaqueFunction:
12791300
case Kind::OpaqueDerivativeFunction:
12801301
case Kind::ObjCCompletionHandlerArgumentsType:
@@ -1306,6 +1327,7 @@ class AbstractionPattern {
13061327
case Kind::CXXMethodType:
13071328
case Kind::CurriedCXXMethodType:
13081329
case Kind::PartialCurriedCXXMethodType:
1330+
case Kind::CXXFunctionalConstructorType:
13091331
case Kind::OpaqueFunction:
13101332
case Kind::OpaqueDerivativeFunction:
13111333
return false;
@@ -1334,6 +1356,7 @@ class AbstractionPattern {
13341356
case Kind::CXXMethodType:
13351357
case Kind::CurriedCXXMethodType:
13361358
case Kind::PartialCurriedCXXMethodType:
1359+
case Kind::CXXFunctionalConstructorType:
13371360
case Kind::OpaqueFunction:
13381361
case Kind::OpaqueDerivativeFunction:
13391362
llvm_unreachable("pattern is not a tuple");
@@ -1414,6 +1437,7 @@ class AbstractionPattern {
14141437
case Kind::CXXMethodType:
14151438
case Kind::CurriedCXXMethodType:
14161439
case Kind::PartialCurriedCXXMethodType:
1440+
case Kind::CXXFunctionalConstructorType:
14171441
case Kind::OpaqueFunction:
14181442
case Kind::OpaqueDerivativeFunction:
14191443
case Kind::ObjCCompletionHandlerArgumentsType:
@@ -1441,6 +1465,7 @@ class AbstractionPattern {
14411465
case Kind::CXXMethodType:
14421466
case Kind::CurriedCXXMethodType:
14431467
case Kind::PartialCurriedCXXMethodType:
1468+
case Kind::CXXFunctionalConstructorType:
14441469
case Kind::OpaqueFunction:
14451470
case Kind::OpaqueDerivativeFunction:
14461471
case Kind::ObjCCompletionHandlerArgumentsType:

include/swift/Strings.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ constexpr static const StringLiteral SWIFT_STRING_PROCESSING_NAME = "_StringProc
3737
constexpr static const StringLiteral SWIFT_SHIMS_NAME = "SwiftShims";
3838
/// The name of the CxxShim module, which contains a cxx casting utility.
3939
constexpr static const StringLiteral CXX_SHIM_NAME = "CxxShim";
40+
/// The name of the CxxStdlibShim module, which contains utilities for the C++ stdlib overlay.
41+
constexpr static const StringLiteral CXX_STDLIB_SHIM_NAME = "CxxStdlibShim";
4042
/// The name of the Cxx module, which contains C++ interop helper protocols.
4143
constexpr static const StringLiteral CXX_MODULE_NAME = "Cxx";
4244
/// The name of the Builtin module, which contains Builtin functions.

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,94 +1180,6 @@ void swift::conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl,
11801180
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxVector});
11811181
}
11821182

1183-
void swift::conformToCxxFunctionIfNeeded(
1184-
ClangImporter::Implementation &impl, NominalTypeDecl *decl,
1185-
const clang::CXXRecordDecl *clangDecl) {
1186-
PrettyStackTraceDecl trace("conforming to CxxFunction", decl);
1187-
1188-
assert(decl);
1189-
assert(clangDecl);
1190-
ASTContext &ctx = decl->getASTContext();
1191-
clang::ASTContext &clangCtx = impl.getClangASTContext();
1192-
clang::Sema &clangSema = impl.getClangSema();
1193-
1194-
// Only auto-conform types from the C++ standard library. Custom user types
1195-
// might have a similar interface but different semantics.
1196-
if (!isStdDecl(clangDecl, {"function"}))
1197-
return;
1198-
1199-
// There is no typealias for the argument types on the C++ side, so to
1200-
// retrieve the argument types we look at the overload of `operator()` that
1201-
// got imported into Swift.
1202-
1203-
auto callAsFunctionDecl = lookupDirectSingleWithoutExtensions<FuncDecl>(
1204-
decl, ctx.getIdentifier("callAsFunction"));
1205-
if (!callAsFunctionDecl)
1206-
return;
1207-
1208-
auto operatorCallDecl = dyn_cast_or_null<clang::CXXMethodDecl>(
1209-
callAsFunctionDecl->getClangDecl());
1210-
if (!operatorCallDecl)
1211-
return;
1212-
1213-
std::vector<clang::QualType> operatorCallParamTypes;
1214-
llvm::transform(
1215-
operatorCallDecl->parameters(),
1216-
std::back_inserter(operatorCallParamTypes),
1217-
[](const clang::ParmVarDecl *paramDecl) { return paramDecl->getType(); });
1218-
1219-
auto funcPointerType = clangCtx.getPointerType(clangCtx.getFunctionType(
1220-
operatorCallDecl->getReturnType(), operatorCallParamTypes,
1221-
clang::FunctionProtoType::ExtProtoInfo())).withConst();
1222-
1223-
// Create a fake variable with a function type that matches the type of
1224-
// `operator()`.
1225-
auto fakeFuncPointerVarDecl = clang::VarDecl::Create(
1226-
clangCtx, /*DC*/ clangCtx.getTranslationUnitDecl(),
1227-
clang::SourceLocation(), clang::SourceLocation(), /*Id*/ nullptr,
1228-
funcPointerType, clangCtx.getTrivialTypeSourceInfo(funcPointerType),
1229-
clang::StorageClass::SC_None);
1230-
auto fakeFuncPointerRefExpr = new (clangCtx) clang::DeclRefExpr(
1231-
clangCtx, fakeFuncPointerVarDecl, false, funcPointerType,
1232-
clang::ExprValueKind::VK_LValue, clang::SourceLocation());
1233-
1234-
auto clangDeclTyInfo = clangCtx.getTrivialTypeSourceInfo(
1235-
clang::QualType(clangDecl->getTypeForDecl(), 0));
1236-
SmallVector<clang::Expr *, 1> constructExprArgs = {fakeFuncPointerRefExpr};
1237-
1238-
// Instantiate the templated constructor that would accept this fake variable.
1239-
auto constructExprResult = clangSema.BuildCXXTypeConstructExpr(
1240-
clangDeclTyInfo, clangDecl->getLocation(), constructExprArgs,
1241-
clangDecl->getLocation(), /*ListInitialization*/ false);
1242-
if (!constructExprResult.isUsable())
1243-
return;
1244-
1245-
auto castExpr = dyn_cast_or_null<clang::CastExpr>(constructExprResult.get());
1246-
if (!castExpr)
1247-
return;
1248-
1249-
auto bindTempExpr =
1250-
dyn_cast_or_null<clang::CXXBindTemporaryExpr>(castExpr->getSubExpr());
1251-
if (!bindTempExpr)
1252-
return;
1253-
1254-
auto constructExpr =
1255-
dyn_cast_or_null<clang::CXXConstructExpr>(bindTempExpr->getSubExpr());
1256-
if (!constructExpr)
1257-
return;
1258-
1259-
auto constructorDecl = constructExpr->getConstructor();
1260-
1261-
auto importedConstructor =
1262-
impl.importDecl(constructorDecl, impl.CurrentVersion);
1263-
if (!importedConstructor)
1264-
return;
1265-
decl->addMember(importedConstructor);
1266-
1267-
// TODO: actually conform to some form of CxxFunction protocol
1268-
1269-
}
1270-
12711183
void swift::conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,
12721184
NominalTypeDecl *decl,
12731185
const clang::CXXRecordDecl *clangDecl) {

lib/ClangImporter/ClangDerivedConformances.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,6 @@ void conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl,
7171
NominalTypeDecl *decl,
7272
const clang::CXXRecordDecl *clangDecl);
7373

74-
/// If the decl is an instantiation of C++ `std::function`, synthesize a
75-
/// conformance to CxxFunction, which is defined in the Cxx module.
76-
void conformToCxxFunctionIfNeeded(ClangImporter::Implementation &impl,
77-
NominalTypeDecl *decl,
78-
const clang::CXXRecordDecl *clangDecl);
79-
8074
/// If the decl is an instantiation of C++ `std::span`, synthesize a
8175
/// conformance to CxxSpan, which is defined in the Cxx module.
8276
void conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,

lib/ClangImporter/ClangImporter.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7929,6 +7929,65 @@ ClangImporter::getDefaultArgGenerator(const clang::ParmVarDecl *param) {
79297929
return nullptr;
79307930
}
79317931

7932+
bool ClangImporter::needsClosureConstructor(
7933+
const clang::CXXRecordDecl *recordDecl) const {
7934+
return Impl.needsClosureConstructor(recordDecl);
7935+
}
7936+
7937+
bool ClangImporter::Implementation::needsClosureConstructor(
7938+
const clang::CXXRecordDecl *recordDecl) const {
7939+
// In the future, this should probably be configurable via a Clang attribute.
7940+
return recordDecl->isInStdNamespace() && recordDecl->getIdentifier() &&
7941+
recordDecl->getName() == "function";
7942+
}
7943+
7944+
bool ClangImporter::isSwiftFunctionWrapper(
7945+
const clang::RecordDecl *decl) const {
7946+
return Impl.isSwiftFunctionWrapper(decl);
7947+
}
7948+
7949+
bool ClangImporter::Implementation::isSwiftFunctionWrapper(
7950+
const clang::RecordDecl *decl) const {
7951+
return decl->getIdentifier() && decl->getName() == "__SwiftFunctionWrapper";
7952+
}
7953+
7954+
bool ClangImporter::isDeconstructedSwiftClosure(const clang::Type *type) const {
7955+
auto recordDecl = type->getAsCXXRecordDecl();
7956+
return recordDecl && recordDecl->getIdentifier() &&
7957+
recordDecl->getName() == "__swift_interop_closure";
7958+
}
7959+
7960+
const clang::FunctionType *ClangImporter::extractCXXFunctionType(
7961+
const clang::CXXRecordDecl *functionalTypeDecl) const {
7962+
auto &clangCtx = functionalTypeDecl->getASTContext();
7963+
auto operatorCallName = clangCtx.DeclarationNames.getCXXOperatorName(
7964+
clang::OverloadedOperatorKind::OO_Call);
7965+
auto lookupResult = functionalTypeDecl->lookup(operatorCallName);
7966+
const clang::CXXMethodDecl *operatorCall = nullptr;
7967+
// If an overload of operator() was found, assert that it's a single one.
7968+
if (!lookupResult.empty()) {
7969+
ASSERT(lookupResult.isSingleResult() &&
7970+
"expected single operator() in a functional type");
7971+
operatorCall = cast<clang::CXXMethodDecl>(lookupResult.front());
7972+
} else {
7973+
// If no overload if operator() was found, there could be a viable overload
7974+
// in one of the base types. For instance, std::function in Microsoft STL
7975+
// exposes an operator() from a base type.
7976+
functionalTypeDecl->forallBases([&](const clang::CXXRecordDecl *base) {
7977+
auto baseResult = base->lookup(operatorCallName);
7978+
if (!baseResult.empty()) {
7979+
ASSERT(!operatorCall && "expected single operator() across base types");
7980+
ASSERT(baseResult.isSingleResult() &&
7981+
"expected single operator() in a base type");
7982+
operatorCall = cast<clang::CXXMethodDecl>(baseResult.front());
7983+
}
7984+
return true;
7985+
});
7986+
ASSERT(operatorCall && "expected operator() in one of the bases");
7987+
}
7988+
return operatorCall->getType()->getAs<clang::FunctionType>();
7989+
}
7990+
79327991
FuncDecl *
79337992
ClangImporter::getAvailabilityDomainPredicate(const clang::VarDecl *var) {
79347993
auto it = Impl.availabilityDomainPredicates.find(var);

0 commit comments

Comments
 (0)