Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions include/swift/Sema/CSBindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -592,16 +592,27 @@ class BindingSet {
void dump(llvm::raw_ostream &out, unsigned indent) const;

private:
/// Add a new binding to the set.
/// Introduce a new binding to the set. The binding might not
/// actually be added due to subtyping or other rule like
/// CGFloat/Double implicit conversion. This method should be
/// be preferred over \c addBinding when adding new bindings.
///
/// \param binding The binding to add.
/// \param isTransitive Indicates whether this binding has been
/// acquired through transitive inference and requires validity
/// checking.
void addBinding(PotentialBinding binding, bool isTransitive);
void introduceBinding(PotentialBinding binding, bool isTransitive);

void addLiteralRequirement(Constraint *literal);

/// Insert the given binding into \c Bindings.
///
/// This method is going to compute referenced variables before
/// forwarding to the other overload.
void addBinding(const PotentialBinding &&binding);
void addBinding(const PotentialBinding &&binding,
llvm::SmallPtrSetImpl<TypeVariableType *> &referencedVars);

void addDefault(Constraint *constraint);

StringRef getLiteralBindingKind(LiteralBindingKind K) const {
Expand Down
85 changes: 74 additions & 11 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar,
: CS(CS), TypeVar(TypeVar), Info(info) {

for (const auto &binding : info.Bindings)
addBinding(binding, /*isTransitive=*/false);
introduceBinding(binding, /*isTransitive=*/false);

for (auto *constraint : info.Constraints) {
switch (constraint->getKind()) {
Expand Down Expand Up @@ -596,7 +596,7 @@ void BindingSet::inferTransitiveKeyPathBindings() {

// Copy the bindings over to the root.
for (const auto &binding : bindings.Bindings)
addBinding(binding, /*isTransitive=*/true);
introduceBinding(binding, /*isTransitive=*/true);

// Make a note that the key path root is transitively adjacent
// to contextual root type variable and all of its variables.
Expand All @@ -606,7 +606,7 @@ void BindingSet::inferTransitiveKeyPathBindings() {
bindings.AdjacentVars.end());
}
} else {
addBinding(
introduceBinding(
binding.withSameSource(inferredRootTy, AllowedBindingKind::Exact),
/*isTransitive=*/true);
}
Expand Down Expand Up @@ -679,7 +679,7 @@ void BindingSet::inferTransitiveSupertypeBindings() {
if (ConstraintSystem::typeVarOccursInType(TypeVar, type))
continue;

addBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
introduceBinding(binding.withSameSource(type, AllowedBindingKind::Supertypes),
/*isTransitive=*/true);
}
}
Expand Down Expand Up @@ -713,7 +713,7 @@ void BindingSet::inferTransitiveUnresolvedMemberRefBindings() {
continue;
}

addBinding({protocolTy, AllowedBindingKind::Exact, constraint},
introduceBinding({protocolTy, AllowedBindingKind::Exact, constraint},
/*isTransitive=*/false);
}
}
Expand Down Expand Up @@ -889,7 +889,22 @@ void BindingSet::finalizeUnresolvedMemberChainResult() {
}
}

void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
void BindingSet::addBinding(const PotentialBinding &&binding) {
SmallPtrSet<TypeVariableType *, 4> referencedVars;
binding.BindingType->getTypeVariables(referencedVars);

addBinding(std::move(binding), referencedVars);
}

void BindingSet::addBinding(const PotentialBinding &&binding,
SmallPtrSetImpl<TypeVariableType *> &referencedVars) {
for (auto *adjacentVar : referencedVars)
AdjacentVars.insert(adjacentVar);

(void)Bindings.insert(binding);
}

void BindingSet::introduceBinding(PotentialBinding binding, bool isTransitive) {
if (Bindings.count(binding))
return;

Expand Down Expand Up @@ -944,6 +959,57 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
}
}

// If the type variable prefers subtypes, diasambiguate a situation
// when this type variable is simultaneously a supertype of `@Sendable`
// function type and a subtype of a non-Sendable one by using a supertype
// binding because it constitutes a "subtype" in this case.
//
// For example:
//
// @Sendable () -> Void conv $T
// $T argument conv () -> Void
//
// Either of the types could also be wrapped in a number of optionals. Even if
// there is an optionality mismatch, let's still prefer a supertype binding
// because that would be easier to diagnose.
//
// In particular, this is helpful with ternary operators where the context is
// non-Sendable, but one or both sides are.
if (TypeVar->getImpl().prefersSubtypeBinding()) {
if (auto *funcType = binding.BindingType->lookThroughAllOptionalTypes()
->getAs<FunctionType>()) {
if (binding.Kind == AllowedBindingKind::Supertypes &&
funcType->isSendable()) {
// Note that we are removing the bindings but leaving AdjacentVars
// intact to make sure that this doesn't affect assessment of the
// binding set i.e. \c involvesTypeVariables.
Bindings.remove_if([](const PotentialBinding &existing) {
if (existing.Kind != AllowedBindingKind::Subtypes)
return false;

auto *existingFn = existing.BindingType->lookThroughAllOptionalTypes()
->getAs<FunctionType>();
return existingFn && !existingFn->isSendable();
});
}

// If there are existing `@Sendable` supertype bindings, we can skip this
// one.
if (binding.Kind == AllowedBindingKind::Subtypes &&
!funcType->isSendable()) {
if (llvm::any_of(Bindings, [](const PotentialBinding &existing) {
if (existing.Kind != AllowedBindingKind::Supertypes)
return false;
auto *existingFn =
existing.BindingType->lookThroughAllOptionalTypes()
->getAs<FunctionType>();
return existingFn && existingFn->isSendable();
}))
return;
}
}
}

// If this is a non-defaulted supertype binding,
// check whether we can combine it with another
// supertype binding by computing the 'join' of the types.
Expand Down Expand Up @@ -976,18 +1042,15 @@ void BindingSet::addBinding(PotentialBinding binding, bool isTransitive) {
}

for (const auto &binding : joined)
(void)Bindings.insert(binding);
addBinding(std::move(binding));

// If new binding has been joined with at least one of existing
// bindings, there is no reason to include it into the set.
if (!joined.empty())
return;
}

for (auto *adjacentVar : referencedTypeVars)
AdjacentVars.insert(adjacentVar);

(void)Bindings.insert(std::move(binding));
addBinding(std::move(binding), referencedTypeVars);
}

void BindingSet::determineLiteralCoverage() {
Expand Down
11 changes: 7 additions & 4 deletions test/Concurrency/sendable_keypaths.swift
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,19 @@ do {
static func otherFn() {}
}

// TODO(rdar://125948508): This shouldn't be ambiguous (@Sendable version should be preferred)
func fnRet(cond: Bool) -> () -> Void {
cond ? Test.fn : Test.otherFn // expected-error {{failed to produce diagnostic for expression}}
cond ? Test.fn : Test.otherFn // Ok
}

func forward<T>(_: T) -> T {
}

// TODO(rdar://125948508): This shouldn't be ambiguous (@Sendable version should be preferred)
let _: () -> Void = forward(Test.fn) // expected-error {{conflicting arguments to generic parameter 'T' ('@Sendable () -> ()' vs. '() -> Void')}}
let _: () -> Void = forward(Test.fn) // Ok

func test(fn1: (@Sendable () -> Void)?, fn2: @escaping () -> Void) {
let _: () -> Void = true ? fn1 : fn2
// expected-error@-1 {{cannot convert value of type '(@Sendable () -> Void)?' to specified type '() -> Void'}}
}
}

// https://github.com/swiftlang/swift/issues/77105
Expand Down
19 changes: 18 additions & 1 deletion test/Concurrency/sendable_methods.swift
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,21 @@ do {

static func ff() {}
}
}
}

// Ambiguity between `@Sendable` method and non-Sendable context injected into an Optional.
do {
struct Test {
func action() -> Void {}

func onAction(_: (() -> Void)?) {}

func test() {
onAction(true ? action : nil) // Ok
}

func test(fn1: (@Sendable () -> Void)?, fn2: @escaping () -> Void) {
let _: () -> Void = fn1 ?? fn2 // Ok
}
}
}