From 01ca51903bc5bb46f28ce674e440e2865c525b31 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 10:14:57 -0500 Subject: [PATCH 01/10] Sema: Sink Protocols down from BindingSet into PotentialBindings --- include/swift/Sema/CSBindings.h | 14 ++--- include/swift/Sema/CSTrail.def | 1 + lib/Sema/CSBindings.cpp | 74 +++++++++++++----------- lib/Sema/CSTrail.cpp | 5 ++ unittests/Sema/BindingInferenceTests.cpp | 8 ++- 5 files changed, 59 insertions(+), 43 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 6fada3ed98053..95fd17d2d1d51 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -241,6 +241,9 @@ struct PotentialBindings { llvm::SmallVector, 4> SupertypeOf; llvm::SmallVector, 4> EquivalentTo; + /// The set of protocol conformance requirements imposed on this type variable. + llvm::SmallVector Protocols; + ASTNode AssociatedCodeCompletionToken = ASTNode(); /// Add a potential binding to the list of bindings, @@ -256,6 +259,10 @@ struct PotentialBindings { }); } + ArrayRef getConformanceRequirements() const { + return Protocols; + } + private: /// Attempt to infer a new binding and other useful information /// (i.e. whether bindings should be delayed) from the given @@ -365,9 +372,6 @@ class BindingSet { public: swift::SmallSetVector Bindings; - /// The set of protocol conformance requirements placed on this type variable. - llvm::SmallVector Protocols; - /// The set of unique literal protocol requirements placed on this /// type variable or inferred transitively through subtype chains. /// @@ -494,10 +498,6 @@ class BindingSet { return hasViableBindings() || isDirectHole(); } - ArrayRef getConformanceRequirements() const { - return Protocols; - } - unsigned getNumViableLiteralBindings() const; unsigned getNumViableDefaultableBindings() const { diff --git a/include/swift/Sema/CSTrail.def b/include/swift/Sema/CSTrail.def index fa16ad1baf5e5..fa8bb1698f98c 100644 --- a/include/swift/Sema/CSTrail.def +++ b/include/swift/Sema/CSTrail.def @@ -77,6 +77,7 @@ GRAPH_NODE_CHANGE(RemovedConstraint) GRAPH_NODE_CHANGE(InferredBindings) GRAPH_NODE_CHANGE(RetractedBindings) GRAPH_NODE_CHANGE(RetractedDelayedBy) +GRAPH_NODE_CHANGE(RetractedProtocol) BINDING_RELATION_CHANGE(RetractedAdjacentVar) BINDING_RELATION_CHANGE(RetractedSubtypeOf) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 8e900fbb25ccf..7b32070d5ac4b 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -47,12 +47,6 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, for (auto *constraint : info.Constraints) { switch (constraint->getKind()) { - case ConstraintKind::NonisolatedConformsTo: - case ConstraintKind::ConformsTo: - if (constraint->getSecondType()->is()) - Protocols.push_back(constraint); - break; - case ConstraintKind::LiteralConformsTo: addLiteralRequirement(constraint); break; @@ -435,6 +429,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { } auto &bindings = node.getBindingSet(); + auto conformanceReqs = + node.getPotentialBindings().getConformanceRequirements(); // If current variable already has transitive protocol // conformances inferred, there is no need to look deeper @@ -443,8 +439,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { TypeVariableType *parent = nullptr; std::tie(parent, currentVar) = workList.pop_back_val(); assert(parent); - propagateProtocolsTo(parent, bindings.getConformanceRequirements(), - *bindings.TransitiveProtocols); + propagateProtocolsTo(parent, conformanceReqs, + *bindings.TransitiveProtocols); continue; } @@ -485,14 +481,16 @@ void BindingSet::inferTransitiveProtocolRequirements() { if (!node.hasBindingSet()) continue; - const auto &bindings = node.getBindingSet(); + auto conformanceReqs = + node.getPotentialBindings().getConformanceRequirements(); llvm::SmallPtrSet placeholder; // Add any direct protocols from members of the // equivalence class, so they could be propagated // to all of the members. - propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(), - placeholder); + propagateProtocolsTo(currentVar, conformanceReqs, placeholder); + + const auto &bindings = node.getBindingSet(); // Since type variables are equal, current type variable // becomes a subtype to any supertype found in the current @@ -512,8 +510,7 @@ void BindingSet::inferTransitiveProtocolRequirements() { // are transitive to its parent, propagate them down the subtype/equivalence // chain. if (parent) { - propagateProtocolsTo(parent, bindings.getConformanceRequirements(), - protocols[currentVar]); + propagateProtocolsTo(parent, conformanceReqs, protocols[currentVar]); } auto &inferredProtocols = protocols[currentVar]; @@ -526,9 +523,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { // - all of the transitive protocols inferred through // the members of the equivalence class. { - auto directRequirements = bindings.getConformanceRequirements(); - protocolsForEquivalence.insert(directRequirements.begin(), - directRequirements.end()); + protocolsForEquivalence.insert(conformanceReqs.begin(), + conformanceReqs.end()); protocolsForEquivalence.insert(inferredProtocols.begin(), inferredProtocols.end()); @@ -2063,6 +2059,12 @@ void PotentialBindings::infer(ConstraintSystem &CS, break; } + case ConstraintKind::NonisolatedConformsTo: + case ConstraintKind::ConformsTo: + if (constraint->getSecondType()->is()) + Protocols.push_back(constraint); + break; + case ConstraintKind::BridgingConversion: case ConstraintKind::CheckedCast: case ConstraintKind::EscapableFunctionOf: @@ -2076,8 +2078,6 @@ void PotentialBindings::infer(ConstraintSystem &CS, case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: - case ConstraintKind::NonisolatedConformsTo: - case ConstraintKind::ConformsTo: case ConstraintKind::LiteralConformsTo: case ConstraintKind::Defaultable: case ConstraintKind::FallbackType: @@ -2206,21 +2206,27 @@ void PotentialBindings::retract(ConstraintSystem &CS, }), Bindings.end()); +#define CALLBACK(ChangeKind) \ + [&](Constraint *other) { \ + if (other == constraint) { \ + if (recordingChanges) { \ + CS.recordChange(SolverTrail::Change::ChangeKind( \ + TypeVar, constraint)); \ + } \ + return true; \ + } \ + return false; \ + } + DelayedBy.erase( - llvm::remove_if(DelayedBy, - [&](Constraint *existing) { - if (existing == constraint) { - if (recordingChanges) { - CS.recordChange(SolverTrail::Change::RetractedDelayedBy( - TypeVar, constraint)); - } - return true; - } - return false; - }), + llvm::remove_if(DelayedBy, CALLBACK(RetractedDelayedBy)), DelayedBy.end()); -#define CALLBACK(ChangeKind) \ + Protocols.erase( + llvm::remove_if(Protocols, CALLBACK(RetractedProtocol)), + Protocols.end()); + +#define PAIR_CALLBACK(ChangeKind) \ [&](std::pair pair) { \ if (pair.second == constraint) { \ if (recordingChanges) { \ @@ -2233,19 +2239,19 @@ void PotentialBindings::retract(ConstraintSystem &CS, } AdjacentVars.erase( - llvm::remove_if(AdjacentVars, CALLBACK(RetractedAdjacentVar)), + llvm::remove_if(AdjacentVars, PAIR_CALLBACK(RetractedAdjacentVar)), AdjacentVars.end()); SubtypeOf.erase( - llvm::remove_if(SubtypeOf, CALLBACK(RetractedSubtypeOf)), + llvm::remove_if(SubtypeOf, PAIR_CALLBACK(RetractedSubtypeOf)), SubtypeOf.end()); SupertypeOf.erase( - llvm::remove_if(SupertypeOf, CALLBACK(RetractedSupertypeOf)), + llvm::remove_if(SupertypeOf, PAIR_CALLBACK(RetractedSupertypeOf)), SupertypeOf.end()); EquivalentTo.erase( - llvm::remove_if(EquivalentTo, CALLBACK(RetractedEquivalentTo)), + llvm::remove_if(EquivalentTo, PAIR_CALLBACK(RetractedEquivalentTo)), EquivalentTo.end()); #undef CALLBACK diff --git a/lib/Sema/CSTrail.cpp b/lib/Sema/CSTrail.cpp index 61b1a77e1e0f8..688afdb8eb303 100644 --- a/lib/Sema/CSTrail.cpp +++ b/lib/Sema/CSTrail.cpp @@ -538,6 +538,11 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { .DelayedBy.push_back(TheConstraint.Constraint); break; + case ChangeKind::RetractedProtocol: + cg[TheConstraint.TypeVar].getPotentialBindings() + .Protocols.push_back(TheConstraint.Constraint); + break; + case ChangeKind::RetractedAdjacentVar: cg[BindingRelation.TypeVar].getPotentialBindings() .AdjacentVars.emplace_back(BindingRelation.OtherTypeVar, diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index 3791eef8bccd2..2dcbb710fd7d0 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -197,7 +197,9 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) { CTP_Initialization))); auto &bindings = inferBindings(cs, typeVar); - ASSERT_TRUE(bindings.getConformanceRequirements().empty()); + ASSERT_TRUE(cs.getConstraintGraph()[typeVar] + .getPotentialBindings().getConformanceRequirements().empty()); + ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1}); @@ -218,8 +220,10 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) { cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1, cs.getConstraintLocator({})); + ASSERT_TRUE(cs.getConstraintGraph()[typeVar] + .getPotentialBindings().getConformanceRequirements().empty()); + auto &bindings = inferBindings(cs, typeVar); - ASSERT_TRUE(bindings.getConformanceRequirements().empty()); ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1, protocolTy2}); From 6384c27e8e2d61835a7e0dbeb4d80a463de95cfa Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 11:37:06 -0500 Subject: [PATCH 02/10] Sema: BindingSet::Defaults can just be a vector and not a map The only place where we did a lookup, we also iterated over it anyway, and all remaining usages are simplified by downgrading it to a vector. --- include/swift/Sema/CSBindings.h | 16 +++++++++------- lib/Sema/CSBindings.cpp | 31 ++++++++++++++++++------------- lib/Sema/ConstraintSystem.cpp | 6 ++---- 3 files changed, 29 insertions(+), 24 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 95fd17d2d1d51..254b3f1029ac5 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -380,7 +380,7 @@ class BindingSet { /// before transitive ones. llvm::SmallMapVector Literals; - llvm::SmallDenseMap Defaults; + llvm::SmallVector Defaults; /// The set of transitive protocol requirements inferred through /// subtype/conversion/equivalence relations with other type variables. @@ -505,8 +505,8 @@ class BindingSet { return 1; auto numDefaultable = llvm::count_if( - Defaults, [](const std::pair &entry) { - return entry.second->getKind() == ConstraintKind::Defaultable; + Defaults, [](Constraint *constraint) { + return constraint->getKind() == ConstraintKind::Defaultable; }); // Short-circuit unviable checks if there are no defaultable bindings. @@ -518,10 +518,12 @@ class BindingSet { auto unviable = llvm::count_if(Bindings, [&](const PotentialBinding &binding) { auto type = binding.BindingType->getCanonicalType(); - auto def = Defaults.find(type); - return def != Defaults.end() - ? def->second->getKind() == ConstraintKind::Defaultable - : false; + for (auto *constraint : Defaults) { + if (constraint->getSecondType()->isEqual(type)) { + return constraint->getKind() == ConstraintKind::Defaultable; + } + } + return false; }); assert(numDefaultable >= unviable); diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 7b32070d5ac4b..cc7933a91550d 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -647,11 +647,11 @@ void BindingSet::inferTransitiveSupertypeBindings() { addLiteralRequirement(literal.second.getSource()); // Infer transitive defaults. - for (const auto &def : bindings.Defaults) { - if (def.getSecond()->getKind() == ConstraintKind::FallbackType) + for (auto *def : bindings.Defaults) { + if (def->getKind() == ConstraintKind::FallbackType) continue; - addDefault(def.second); + addDefault(def); } // TODO: We shouldn't need this in the future. @@ -838,11 +838,13 @@ bool BindingSet::finalizeKeyPathBindings() { (keyPath->getParsedRoot() || (fixedRootTy && !fixedRootTy->isTypeVariableOrMember()))) { auto fallback = llvm::find_if(Defaults, [](const auto &entry) { - return entry.second->getKind() == ConstraintKind::FallbackType; + return entry->getKind() == ConstraintKind::FallbackType; }); assert(fallback != Defaults.end()); updatedBindings.insert( - {fallback->first, AllowedBindingKind::Exact, fallback->second}); + {(*fallback)->getSecondType(), + AllowedBindingKind::Exact, + *fallback}); } else { updatedBindings.insert(PotentialBinding::forHole( TypeVar, CS.getConstraintLocator( @@ -1106,10 +1108,10 @@ bool BindingSet::operator==(const BindingSet &other) { if (Defaults.size() != other.Defaults.size()) return false; - for (auto pair : Defaults) { - auto found = other.Defaults.find(pair.first); - if (found == other.Defaults.end() || - pair.second != found->second) + for (auto i : indices(Defaults)) { + auto *x = Defaults[i]; + auto *y = other.Defaults[i]; + if (x != y) return false; } @@ -1271,8 +1273,12 @@ findInferableTypeVars(Type type, } void BindingSet::addDefault(Constraint *constraint) { - auto defaultTy = constraint->getSecondType(); - Defaults.insert({defaultTy->getCanonicalType(), constraint}); + if (CONDITIONAL_ASSERT_enabled()) { + for (auto *other : Defaults) { + ASSERT(other != constraint); + } + } + Defaults.push_back(constraint); } bool LiteralRequirement::isCoveredBy(Type type, ConstraintSystem &CS) const { @@ -2507,8 +2513,7 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { out << " [defaults: "; interleave( Defaults, - [&](const auto &entry) { - auto *constraint = entry.second; + [&](Constraint *constraint) { auto defaultBinding = PrintableBinding::exact(constraint->getSecondType()); defaultBinding.print(out, PO); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index b0d0659f09127..63c743e478e80 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -5384,8 +5384,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) if (viableBindings.size() == 1) { addBinding(viableBindings.front()); } else { - for (const auto &entry : bindings.Defaults) { - auto *constraint = entry.second; + for (auto *constraint : bindings.Defaults) { Bindings.push_back(getDefaultBinding(constraint)); } } @@ -5423,8 +5422,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) { bool noBindings = Bindings.empty(); - for (const auto &entry : bindings.Defaults) { - auto *constraint = entry.second; + for (auto *constraint : bindings.Defaults) { if (noBindings) { // If there are no direct or transitive bindings to attempt // let's add defaults to the list right away. From 36c7263d261725f54c2aa736db6cd543ac82f304 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 11:46:16 -0500 Subject: [PATCH 03/10] Sema: Factor out isDirectRequirement() from two places that check this --- lib/Sema/CSBindings.cpp | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index cc7933a91550d..39e6814472fcc 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -38,6 +38,16 @@ void ConstraintGraphNode::initBindingSet() { Set.emplace(CG.getConstraintSystem(), TypeVar, Potential); } +static bool isDirectRequirement(ConstraintSystem &cs, + TypeVariableType *typeVar, + Constraint *constraint) { + if (auto *other = constraint->getFirstType()->getAs()) { + return typeVar == cs.getRepresentative(other); + } + + return false; +} + BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, const PotentialBindings &info) : CS(CS), TypeVar(TypeVar), Info(info) { @@ -54,10 +64,8 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, case ConstraintKind::Defaultable: case ConstraintKind::FallbackType: // Do these in a separate pass. - if (CS.getFixedTypeRecursive(constraint->getFirstType(), true) - ->getAs() == TypeVar) { + if (isDirectRequirement(CS, TypeVar, constraint)) addDefault(constraint); - } break; default: @@ -1049,16 +1057,7 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { if (Literals.count(protocol) > 0) return; - auto isDirectRequirement = [&](Constraint *constraint) -> bool { - if (auto *typeVar = constraint->getFirstType()->getAs()) { - auto *repr = CS.getRepresentative(typeVar); - return repr == TypeVar; - } - - return false; - }; - - bool isDirect = isDirectRequirement(constraint); + bool isDirect = isDirectRequirement(CS, TypeVar, constraint); Type defaultType; // `ExpressibleByNilLiteral` doesn't have a default type. From 88d04b5baa1a78f14c92b6aa70c65780f8c69654 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 13:08:25 -0500 Subject: [PATCH 04/10] Sema: Extract coalesceIntegerAndFloatLiteralRequirements() from BindingSet::BindingSet, and do it at the end --- include/swift/Sema/CSBindings.h | 5 +++++ lib/Sema/CSBindings.cpp | 22 +++++++++++++++------- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 254b3f1029ac5..03797b994585d 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -574,6 +574,11 @@ class BindingSet { /// requirements down the subtype or equivalence chain. void inferTransitiveProtocolRequirements(); + /// Try to coalesce integer and floating point literal protocols + /// if they appear together because the only possible default type that + /// could satisfy both requirements is `Double`. + void coalesceIntegerAndFloatLiteralRequirements(); + /// Check whether the given binding set covers any of the literal protocols /// associated with this type variable. The idea is that if a type variable /// has a binding like Int and also it has a conformance requirement to diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 39e6814472fcc..bd8abb3bb9796 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1031,19 +1031,18 @@ void BindingSet::determineLiteralCoverage() { } } -void BindingSet::addLiteralRequirement(Constraint *constraint) { - auto *protocol = constraint->getProtocol(); +void BindingSet::coalesceIntegerAndFloatLiteralRequirements() { + for (const auto &pair : Literals) { + auto *protocol = pair.first; - // Let's try to coalesce integer and floating point literal protocols - // if they appear together because the only possible default type that - // could satisfy both requirements is `Double`. - { if (protocol->isSpecificProtocol( KnownProtocolKind::ExpressibleByIntegerLiteral)) { auto *floatLiteral = CS.getASTContext().getProtocol( KnownProtocolKind::ExpressibleByFloatLiteral); - if (Literals.count(floatLiteral)) + if (Literals.count(floatLiteral)) { + Literals.erase(protocol); return; + } } if (protocol->isSpecificProtocol( @@ -1051,8 +1050,13 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { auto *intLiteral = CS.getASTContext().getProtocol( KnownProtocolKind::ExpressibleByIntegerLiteral); Literals.erase(intLiteral); + return; } } +} + +void BindingSet::addLiteralRequirement(Constraint *constraint) { + auto *protocol = constraint->getProtocol(); if (Literals.count(protocol) > 0) return; @@ -1235,6 +1239,9 @@ const BindingSet *ConstraintSystem::determineBestBindings( bestBindings = &bindings; } + if (bestBindings) + bestBindings->coalesceIntegerAndFloatLiteralRequirements(); + return bestBindings; } @@ -1672,6 +1679,7 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) { (void) bindings.finalizeKeyPathBindings(); bindings.finalizeUnresolvedMemberChainResult(); bindings.determineLiteralCoverage(); + bindings.coalesceIntegerAndFloatLiteralRequirements(); return bindings; } From 6c1593a01506f5a725ee983295f75e434f602ef9 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 13:15:01 -0500 Subject: [PATCH 05/10] Sema: Stash the ProtocolDecl inside the LiteralRequirement --- include/swift/Sema/CSBindings.h | 10 +++++++--- lib/Sema/CSBindings.cpp | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 03797b994585d..3a38993f18558 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -152,6 +152,8 @@ struct PotentialBinding { }; struct LiteralRequirement { + /// The literal protocol. + ProtocolDecl *Protocol; /// The source of the literal requirement. Constraint *Source; /// The default type associated with this literal (if any). @@ -164,12 +166,14 @@ struct LiteralRequirement { /// this points to the source of the binding. mutable Constraint *CoveredBy = nullptr; - LiteralRequirement(Constraint *source, Type defaultTy, bool isDirect) - : Source(source), DefaultType(defaultTy), IsDirectRequirement(isDirect) {} + LiteralRequirement(ProtocolDecl *protocol, Constraint *source, + Type defaultTy, bool isDirect) + : Protocol(protocol), Source(source), DefaultType(defaultTy), + IsDirectRequirement(isDirect) {} Constraint *getSource() const { return Source; } - ProtocolDecl *getProtocol() const { return Source->getProtocol(); } + ProtocolDecl *getProtocol() const { return Protocol; } bool isCovered() const { return bool(CoveredBy); } diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index bd8abb3bb9796..ff69e631d74db 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1070,7 +1070,7 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { defaultType = TypeChecker::getDefaultType(protocol, CS.DC); } - LiteralRequirement literal(constraint, defaultType, isDirect); + LiteralRequirement literal(protocol, constraint, defaultType, isDirect); Literals.insert({protocol, std::move(literal)}); } From 74fa1e7f5e1c02517e74bdb4b09585ec3348aa7d Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 13:27:12 -0500 Subject: [PATCH 06/10] Sema: BindingSet::Literals can just be a vector and not a map --- include/swift/Sema/CSBindings.h | 4 +- lib/Sema/CSBindings.cpp | 75 +++++++++++------------- lib/Sema/CSOptimizer.cpp | 4 +- lib/Sema/CSSolver.cpp | 4 +- lib/Sema/ConstraintSystem.cpp | 4 +- unittests/Sema/BindingInferenceTests.cpp | 8 +-- 6 files changed, 46 insertions(+), 53 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 3a38993f18558..9f31b2c979352 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -382,7 +382,7 @@ class BindingSet { /// Note that ordering is important when it comes to bindings, we'd /// like to add any "direct" default types first to attempt them /// before transitive ones. - llvm::SmallMapVector Literals; + llvm::SmallVector Literals; llvm::SmallVector Defaults; @@ -465,7 +465,7 @@ class BindingSet { // Literal requirements always result in a subtype/supertype // relationship to a concrete type. if (llvm::any_of(Literals, [](const auto &literal) { - return literal.second.viableAsBinding(); + return literal.viableAsBinding(); })) return false; diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index ff69e631d74db..1ebc1305dcae4 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -86,9 +86,12 @@ bool BindingSet::forGenericParameter() const { } bool BindingSet::canBeNil() const { - auto &ctx = CS.getASTContext(); - return Literals.count( - ctx.getProtocol(KnownProtocolKind::ExpressibleByNilLiteral)); + for (const auto &literal : Literals) { + if (literal.getProtocol()->isSpecificProtocol( + KnownProtocolKind::ExpressibleByNilLiteral)) + return true; + } + return false; } bool BindingSet::isDirectHole() const { @@ -652,7 +655,7 @@ void BindingSet::inferTransitiveSupertypeBindings() { // `ExpressibleByStringLiteral` conformance, we'd end up picking // `T` with only one type `Any?` which is incorrect. for (const auto &literal : bindings.Literals) - addLiteralRequirement(literal.second.getSource()); + addLiteralRequirement(literal.getSource()); // Infer transitive defaults. for (auto *def : bindings.Defaults) { @@ -1002,9 +1005,7 @@ void BindingSet::determineLiteralCoverage() { bool allowsNil = canBeNil(); - for (auto &entry : Literals) { - auto &literal = entry.second; - + for (auto &literal : Literals) { if (!literal.viableAsBinding()) continue; @@ -1032,34 +1033,36 @@ void BindingSet::determineLiteralCoverage() { } void BindingSet::coalesceIntegerAndFloatLiteralRequirements() { - for (const auto &pair : Literals) { - auto *protocol = pair.first; + decltype(Literals)::iterator intLiteral = Literals.end(); + decltype(Literals)::iterator floatLiteral = Literals.end(); + + for (auto iter = Literals.begin(); iter != Literals.end(); ++iter) { + auto *protocol = iter->getProtocol(); if (protocol->isSpecificProtocol( KnownProtocolKind::ExpressibleByIntegerLiteral)) { - auto *floatLiteral = CS.getASTContext().getProtocol( - KnownProtocolKind::ExpressibleByFloatLiteral); - if (Literals.count(floatLiteral)) { - Literals.erase(protocol); - return; - } + intLiteral = iter; } if (protocol->isSpecificProtocol( KnownProtocolKind::ExpressibleByFloatLiteral)) { - auto *intLiteral = CS.getASTContext().getProtocol( - KnownProtocolKind::ExpressibleByIntegerLiteral); - Literals.erase(intLiteral); - return; + floatLiteral = iter; } } + + if (intLiteral != Literals.end() && + floatLiteral != Literals.end()) { + Literals.erase(intLiteral); + } } void BindingSet::addLiteralRequirement(Constraint *constraint) { auto *protocol = constraint->getProtocol(); - if (Literals.count(protocol) > 0) - return; + for (const auto &literal : Literals) { + if (literal.getProtocol() == protocol) + return; + } bool isDirect = isDirectRequirement(CS, TypeVar, constraint); @@ -1070,8 +1073,7 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { defaultType = TypeChecker::getDefaultType(protocol, CS.DC); } - LiteralRequirement literal(protocol, constraint, defaultType, isDirect); - Literals.insert({protocol, std::move(literal)}); + Literals.emplace_back(protocol, constraint, defaultType, isDirect); } bool BindingSet::operator==(const BindingSet &other) { @@ -1081,7 +1083,7 @@ bool BindingSet::operator==(const BindingSet &other) { if (Bindings.size() != other.Bindings.size()) return false; - for (auto i : indices(Bindings)) { + for (unsigned i : indices(Bindings)) { const auto &x = Bindings[i]; const auto &y = other.Bindings[i]; @@ -1093,13 +1095,9 @@ bool BindingSet::operator==(const BindingSet &other) { if (Literals.size() != other.Literals.size()) return false; - for (auto pair : Literals) { - auto found = other.Literals.find(pair.first); - if (found == other.Literals.end()) - return false; - - const auto &x = pair.second; - const auto &y = found->second; + for (unsigned i : indices(Literals)) { + auto &x = Literals[i]; + auto &y = other.Literals[i]; if (x.Source != y.Source || x.DefaultType.getPointer() != y.DefaultType.getPointer() || @@ -2324,15 +2322,12 @@ void PotentialBindings::dump(ConstraintSystem &cs, TypeVariableType *typeVar, void BindingSet::forEachLiteralRequirement( llvm::function_ref callback) const { - for (const auto &literal : Literals) { - auto *protocol = literal.first; - const auto &info = literal.second; - + for (const auto &info : Literals) { // Only uncovered defaultable literal protocols participate. if (!info.viableAsBinding()) continue; - if (auto protocolKind = protocol->getKnownProtocolKind()) + if (auto protocolKind = info.getProtocol()->getKnownProtocolKind()) callback(*protocolKind); } } @@ -2363,7 +2358,7 @@ LiteralBindingKind BindingSet::getLiteralForScore() const { unsigned BindingSet::getNumViableLiteralBindings() const { return llvm::count_if(Literals, [&](const auto &literal) { - return literal.second.viableAsBinding(); + return literal.viableAsBinding(); }); } @@ -2501,10 +2496,10 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { } for (const auto &literal : Literals) { potentialBindings.push_back(PrintableBinding::literalDefaultType( - literal.second.hasDefaultType() - ? literal.second.getDefaultType() + literal.hasDefaultType() + ? literal.getDefaultType() : Type(), - literal.second.viableAsBinding())); + literal.viableAsBinding())); } if (potentialBindings.empty()) { out << ""; diff --git a/lib/Sema/CSOptimizer.cpp b/lib/Sema/CSOptimizer.cpp index c822f51f5457a..ef1c81c0753b3 100644 --- a/lib/Sema/CSOptimizer.cpp +++ b/lib/Sema/CSOptimizer.cpp @@ -1048,9 +1048,9 @@ static void determineBestChoicesInContext( } for (const auto &literal : bindingSet.Literals) { - if (literal.second.hasDefaultType()) { + if (literal.hasDefaultType()) { // Add primary default type - auto type = restoreOptionality(literal.second.getDefaultType(), + auto type = restoreOptionality(literal.getDefaultType(), optionals.size()); types.push_back({type, /*fromLiteral=*/true}); diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 7815762a82a8b..a9921e8c83124 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -2566,8 +2566,8 @@ void DisjunctionChoice::propagateConversionInfo(ConstraintSystem &cs) const { conversionType = bindings.Bindings[0].BindingType; } else { for (const auto &literal : bindings.Literals) { - if (literal.second.viableAsBinding()) { - conversionType = literal.second.getDefaultType(); + if (literal.viableAsBinding()) { + conversionType = literal.getDefaultType(); break; } } diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 63c743e478e80..0bc30bbc7d64c 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -5397,9 +5397,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) } // Infer defaults based on "uncovered" literal protocol requirements. - for (const auto &info : bindings.Literals) { - const auto &literal = info.second; - + for (const auto &literal : bindings.Literals) { if (!literal.viableAsBinding()) continue; diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index 2dcbb710fd7d0..faf73b68c4133 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -42,7 +42,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_EQ(bindings.Literals.size(), (unsigned)1); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.hasDefaultType()); ASSERT_TRUE(literal.getDefaultType()->isEqual(intTy)); @@ -65,7 +65,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_TRUE(bindings.Bindings[0].BindingType->isEqual(intTy)); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_TRUE(literal.isDirectRequirement()); ASSERT_TRUE(literal.getDefaultType()->isEqual(intTy)); @@ -99,7 +99,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_TRUE(bindings.Bindings[0].BindingType->isEqual(floatTy)); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_TRUE(literal.isDirectRequirement()); ASSERT_FALSE(literal.getDefaultType()->isEqual(floatTy)); @@ -140,7 +140,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { // Inferred literal requirement through `$T_float` as well. ASSERT_EQ(bindings.Literals.size(), (unsigned)1); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_FALSE(literal.isDirectRequirement()); From 88f347a000d536ec8383efa00b1099401192cb41 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 13:41:01 -0500 Subject: [PATCH 07/10] Sema: Collect defaults in PotentialBindings::infer() --- include/swift/Sema/CSBindings.h | 3 +++ include/swift/Sema/CSTrail.def | 1 + lib/Sema/CSBindings.cpp | 22 ++++++++++++++-------- lib/Sema/CSTrail.cpp | 5 +++++ 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 9f31b2c979352..1637dad2b1877 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -248,6 +248,9 @@ struct PotentialBindings { /// The set of protocol conformance requirements imposed on this type variable. llvm::SmallVector Protocols; + /// The set of fallback constraints imposed on this type variable. + llvm::SmallVector Defaults; + ASTNode AssociatedCodeCompletionToken = ASTNode(); /// Add a potential binding to the list of bindings, diff --git a/include/swift/Sema/CSTrail.def b/include/swift/Sema/CSTrail.def index fa8bb1698f98c..b28606bdfa3f6 100644 --- a/include/swift/Sema/CSTrail.def +++ b/include/swift/Sema/CSTrail.def @@ -78,6 +78,7 @@ GRAPH_NODE_CHANGE(InferredBindings) GRAPH_NODE_CHANGE(RetractedBindings) GRAPH_NODE_CHANGE(RetractedDelayedBy) GRAPH_NODE_CHANGE(RetractedProtocol) +GRAPH_NODE_CHANGE(RetractedDefault) BINDING_RELATION_CHANGE(RetractedAdjacentVar) BINDING_RELATION_CHANGE(RetractedSubtypeOf) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 1ebc1305dcae4..28d43ef362c13 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -61,18 +61,17 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, addLiteralRequirement(constraint); break; - case ConstraintKind::Defaultable: - case ConstraintKind::FallbackType: - // Do these in a separate pass. - if (isDirectRequirement(CS, TypeVar, constraint)) - addDefault(constraint); - break; - default: break; } } + for (auto *constraint : info.Defaults) { + // Do these in a separate pass. + if (isDirectRequirement(CS, TypeVar, constraint)) + addDefault(constraint); + } + for (auto &entry : info.AdjacentVars) AdjacentVars.insert(entry.first); } @@ -2090,9 +2089,12 @@ void PotentialBindings::infer(ConstraintSystem &CS, case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: case ConstraintKind::LiteralConformsTo: + // Constraints from which we can't do anything. + break; + case ConstraintKind::Defaultable: case ConstraintKind::FallbackType: - // Constraints from which we can't do anything. + Defaults.push_back(constraint); break; // For now let's avoid inferring protocol requirements from @@ -2237,6 +2239,10 @@ void PotentialBindings::retract(ConstraintSystem &CS, llvm::remove_if(Protocols, CALLBACK(RetractedProtocol)), Protocols.end()); + Defaults.erase( + llvm::remove_if(Defaults, CALLBACK(RetractedDefault)), + Defaults.end()); + #define PAIR_CALLBACK(ChangeKind) \ [&](std::pair pair) { \ if (pair.second == constraint) { \ diff --git a/lib/Sema/CSTrail.cpp b/lib/Sema/CSTrail.cpp index 688afdb8eb303..38883b7215ea8 100644 --- a/lib/Sema/CSTrail.cpp +++ b/lib/Sema/CSTrail.cpp @@ -543,6 +543,11 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { .Protocols.push_back(TheConstraint.Constraint); break; + case ChangeKind::RetractedDefault: + cg[TheConstraint.TypeVar].getPotentialBindings() + .Defaults.push_back(TheConstraint.Constraint); + break; + case ChangeKind::RetractedAdjacentVar: cg[BindingRelation.TypeVar].getPotentialBindings() .AdjacentVars.emplace_back(BindingRelation.OtherTypeVar, From 77ee3dab7d764a04fc9c0f464446fdd869d7128e Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 13:53:30 -0500 Subject: [PATCH 08/10] Sema: Collect LiteralRequirements in PotentialBindings::infer() --- include/swift/Sema/CSBindings.h | 16 ++++++++--- include/swift/Sema/CSTrail.def | 1 + lib/Sema/CSBindings.cpp | 50 ++++++++++++++++++++++++--------- lib/Sema/CSTrail.cpp | 6 ++++ 4 files changed, 55 insertions(+), 18 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 1637dad2b1877..c8237b5889b0a 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -179,6 +179,10 @@ struct LiteralRequirement { bool isDirectRequirement() const { return IsDirectRequirement; } + void setDirectRequirement(bool isDirectRequirement) { + IsDirectRequirement = isDirectRequirement; + } + bool hasDefaultType() const { return bool(DefaultType); } Type getDefaultType() const { @@ -248,6 +252,10 @@ struct PotentialBindings { /// The set of protocol conformance requirements imposed on this type variable. llvm::SmallVector Protocols; + /// The set of unique literal protocol requirements placed on this + /// type variable. + llvm::SmallVector Literals; + /// The set of fallback constraints imposed on this type variable. llvm::SmallVector Defaults; @@ -270,7 +278,10 @@ struct PotentialBindings { return Protocols; } -private: + void inferFromLiteral(ConstraintSystem &CS, + TypeVariableType *TypeVar, + Constraint *literal); + /// Attempt to infer a new binding and other useful information /// (i.e. whether bindings should be delayed) from the given /// relational constraint. @@ -279,7 +290,6 @@ struct PotentialBindings { TypeVariableType *TypeVar, Constraint *constraint); -public: void infer(ConstraintSystem &CS, TypeVariableType *TypeVar, Constraint *constraint); @@ -621,8 +631,6 @@ class BindingSet { /// checking. void addBinding(PotentialBinding binding, bool isTransitive); - void addLiteralRequirement(Constraint *literal); - void addDefault(Constraint *constraint); StringRef getLiteralBindingKind(LiteralBindingKind K) const { diff --git a/include/swift/Sema/CSTrail.def b/include/swift/Sema/CSTrail.def index b28606bdfa3f6..a501e7b5fca84 100644 --- a/include/swift/Sema/CSTrail.def +++ b/include/swift/Sema/CSTrail.def @@ -76,6 +76,7 @@ GRAPH_NODE_CHANGE(AddedConstraint) GRAPH_NODE_CHANGE(RemovedConstraint) GRAPH_NODE_CHANGE(InferredBindings) GRAPH_NODE_CHANGE(RetractedBindings) +GRAPH_NODE_CHANGE(RetractedLiteral) GRAPH_NODE_CHANGE(RetractedDelayedBy) GRAPH_NODE_CHANGE(RetractedProtocol) GRAPH_NODE_CHANGE(RetractedDefault) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 28d43ef362c13..b3a5660d9001b 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -55,16 +55,8 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, for (const auto &binding : info.Bindings) addBinding(binding, /*isTransitive=*/false); - for (auto *constraint : info.Constraints) { - switch (constraint->getKind()) { - case ConstraintKind::LiteralConformsTo: - addLiteralRequirement(constraint); - break; - - default: - break; - } - } + for (const auto &literal : info.Literals) + Literals.push_back(literal); for (auto *constraint : info.Defaults) { // Do these in a separate pass. @@ -653,8 +645,19 @@ void BindingSet::inferTransitiveSupertypeBindings() { // If one of the literal arguments doesn't propagate its // `ExpressibleByStringLiteral` conformance, we'd end up picking // `T` with only one type `Any?` which is incorrect. - for (const auto &literal : bindings.Literals) - addLiteralRequirement(literal.getSource()); + for (auto literal : bindings.Literals) { + auto *protocol = literal.getProtocol(); + + bool found = llvm::any_of(Literals, + [&](const auto &literal) -> bool { + return literal.getProtocol() == protocol; + }); + if (found) + continue; + + literal.setDirectRequirement(false); + Literals.push_back(literal); + } // Infer transitive defaults. for (auto *def : bindings.Defaults) { @@ -1055,7 +1058,9 @@ void BindingSet::coalesceIntegerAndFloatLiteralRequirements() { } } -void BindingSet::addLiteralRequirement(Constraint *constraint) { +void PotentialBindings::inferFromLiteral(ConstraintSystem &CS, + TypeVariableType *TypeVar, + Constraint *constraint) { auto *protocol = constraint->getProtocol(); for (const auto &literal : Literals) { @@ -2088,10 +2093,13 @@ void PotentialBindings::infer(ConstraintSystem &CS, case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: - case ConstraintKind::LiteralConformsTo: // Constraints from which we can't do anything. break; + case ConstraintKind::LiteralConformsTo: + inferFromLiteral(CS, TypeVar, constraint); + break; + case ConstraintKind::Defaultable: case ConstraintKind::FallbackType: Defaults.push_back(constraint); @@ -2219,6 +2227,20 @@ void PotentialBindings::retract(ConstraintSystem &CS, }), Bindings.end()); + Literals.erase( + llvm::remove_if(Literals, + [&](const LiteralRequirement &literal) { + if (literal.getSource() == constraint) { + if (recordingChanges) { + CS.recordChange(SolverTrail::Change::RetractedLiteral( + TypeVar, constraint)); + } + return true; + } + return false; + }), + Literals.end()); + #define CALLBACK(ChangeKind) \ [&](Constraint *other) { \ if (other == constraint) { \ diff --git a/lib/Sema/CSTrail.cpp b/lib/Sema/CSTrail.cpp index 38883b7215ea8..28913ed48fbb3 100644 --- a/lib/Sema/CSTrail.cpp +++ b/lib/Sema/CSTrail.cpp @@ -548,6 +548,12 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { .Defaults.push_back(TheConstraint.Constraint); break; + case ChangeKind::RetractedLiteral: + cg[TheConstraint.TypeVar].getPotentialBindings() + .inferFromLiteral(cs, TheConstraint.TypeVar, + TheConstraint.Constraint); + break; + case ChangeKind::RetractedAdjacentVar: cg[BindingRelation.TypeVar].getPotentialBindings() .AdjacentVars.emplace_back(BindingRelation.OtherTypeVar, From d795c185b523585806aed38e280afa43f44c928c Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sat, 13 Dec 2025 14:49:30 -0500 Subject: [PATCH 09/10] Sema: Remove BindingSet::getConstraintSystem() We shouldn't store a pointer to the ConstraintSystem inside every BindingSet, but there are some annoying things to untangle before we can do that. As a starting point toward that, remove the getConstraintSystem() getter so that at least we can't reach up to the ConstraintSystem from the outside. --- include/swift/Sema/CSBindings.h | 2 -- include/swift/Sema/ConstraintSystem.h | 4 +++- lib/Sema/CSStep.cpp | 3 ++- lib/Sema/CSStep.h | 8 +++++--- lib/Sema/ConstraintSystem.cpp | 10 ++++++---- unittests/Sema/BindingInferenceTests.cpp | 4 ++-- 6 files changed, 18 insertions(+), 13 deletions(-) diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index c8237b5889b0a..d10ecdada17e8 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -410,8 +410,6 @@ class BindingSet { BindingSet(const BindingSet &other) = delete; - ConstraintSystem &getConstraintSystem() const { return CS; } - TypeVariableType *getTypeVariable() const { return TypeVar; } /// Check whether this binding set belongs to a type variable diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index a3a8c3b70aff5..d2dc0bb82618a 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -6196,7 +6196,9 @@ class TypeVarBindingProducer : public BindingProducer { public: using Element = TypeVariableBinding; - TypeVarBindingProducer(const BindingSet &bindings); + TypeVarBindingProducer(ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingSet &bindings); /// Retrieve a set of bindings available in the current state. ArrayRef getCurrentBindings() const { return Bindings; } diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index e4f8e5e128062..6fdd801b6df7e 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -335,7 +335,8 @@ StepResult ComponentStep::take(bool prevFailed) { switch (*step) { case StepKind::Binding: return suspend( - std::make_unique(*bestBindings, Solutions)); + std::make_unique(CS, bestBindings->getTypeVariable(), + *bestBindings, Solutions)); case StepKind::Disjunction: { CS.retireConstraint(disjunction->first); return suspend( diff --git a/lib/Sema/CSStep.h b/lib/Sema/CSStep.h index 7c40af249933b..f09d3827a2f87 100644 --- a/lib/Sema/CSStep.h +++ b/lib/Sema/CSStep.h @@ -540,10 +540,12 @@ class TypeVariableStep final : public BindingStep { bool SawFirstLiteralConstraint = false; public: - TypeVariableStep(const BindingContainer &bindings, + TypeVariableStep(ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingContainer &bindings, SmallVectorImpl &solutions) - : BindingStep(bindings.getConstraintSystem(), {bindings}, solutions), - TypeVar(bindings.getTypeVariable()) {} + : BindingStep(cs, {cs, typeVar, bindings}, solutions), + TypeVar(typeVar) {} void setup() override; diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 0bc30bbc7d64c..cbddf562b79ea 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -5323,10 +5323,12 @@ ConstraintSystem::inferKeyPathLiteralCapability(KeyPathExpr *keyPath) { return success(mutability, isSendable); } -TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) - : BindingProducer(bindings.getConstraintSystem(), - bindings.getTypeVariable()->getImpl().getLocator()), - TypeVar(bindings.getTypeVariable()), CanBeNil(bindings.canBeNil()) { +TypeVarBindingProducer::TypeVarBindingProducer( + ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingSet &bindings) + : BindingProducer(cs, typeVar->getImpl().getLocator()), + TypeVar(typeVar), CanBeNil(bindings.canBeNil()) { if (bindings.isDirectHole()) { auto *locator = getLocator(); // If this type variable is associated with a code completion token diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index faf73b68c4133..244eb4d7cf840 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -352,7 +352,7 @@ TEST_F(SemaTest, TestNoDoubleVoidClosureResultInference) { auto verifyInference = [&](TypeVariableType *typeVar, unsigned numExpected) { auto bindings = cs.getBindingsFor(typeVar); - TypeVarBindingProducer producer(bindings); + TypeVarBindingProducer producer(cs, typeVar, bindings); llvm::SmallPtrSet inferredTypes; @@ -425,7 +425,7 @@ TEST_F(SemaTest, TestSupertypeInferenceWithDefaults) { cs.getConstraintLocator({})); auto bindings = cs.getBindingsFor(genericArg); - TypeVarBindingProducer producer(bindings); + TypeVarBindingProducer producer(cs, genericArg, bindings); llvm::SmallVector inferredTypes; while (auto binding = producer()) { From 9a68d4aabe7bc57cd8d095c831bbba160d85b2ef Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Sun, 14 Dec 2025 11:44:42 -0500 Subject: [PATCH 10/10] Sema: De-duplicate defaults in inferTransitiveSupertypeBindings() This is necessary now that BindingSet::Defaults is a vector and not a set. --- lib/Sema/CSBindings.cpp | 21 ++++++++++++++++----- test/Constraints/closures.swift | 14 ++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index b3a5660d9001b..2946181cf41f1 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -616,6 +616,18 @@ void BindingSet::inferTransitiveKeyPathBindings() { } void BindingSet::inferTransitiveSupertypeBindings() { + llvm::SmallDenseSet seenLiterals; + for (const auto &literal : Literals) { + bool inserted = seenLiterals.insert(literal.getProtocol()).second; + ASSERT(inserted); + } + + llvm::SmallDenseSet seenDefaults; + for (auto *constraint : Defaults) { + bool inserted = seenDefaults.insert(constraint).second; + ASSERT(inserted); + } + for (const auto &entry : Info.SupertypeOf) { auto &node = CS.getConstraintGraph()[entry.first]; if (!node.hasBindingSet()) @@ -648,11 +660,7 @@ void BindingSet::inferTransitiveSupertypeBindings() { for (auto literal : bindings.Literals) { auto *protocol = literal.getProtocol(); - bool found = llvm::any_of(Literals, - [&](const auto &literal) -> bool { - return literal.getProtocol() == protocol; - }); - if (found) + if (!seenLiterals.insert(protocol).second) continue; literal.setDirectRequirement(false); @@ -664,6 +672,9 @@ void BindingSet::inferTransitiveSupertypeBindings() { if (def->getKind() == ConstraintKind::FallbackType) continue; + if (!seenDefaults.insert(def).second) + continue; + addDefault(def); } diff --git a/test/Constraints/closures.swift b/test/Constraints/closures.swift index 7f2e6020001cf..f9a8362fe741a 100644 --- a/test/Constraints/closures.swift +++ b/test/Constraints/closures.swift @@ -1415,3 +1415,17 @@ func test_implicit_result_conversions() { return // Ok } } + +// Random example reduced from swift-build which tripped an assert not +// previously covered by our test suite +do { + struct S { + var x: [Int: [String]] = [:] + } + + let s = [S]() + + let _: [Int: Set] = s.map(\.x) + .reduce([:], { x, y in x.merging(y, uniquingKeysWith: +) }) + .mapValues { Set($0) } +}