diff --git a/include/swift/Sema/CSBindings.h b/include/swift/Sema/CSBindings.h index 6fada3ed98053..d10ecdada17e8 100644 --- a/include/swift/Sema/CSBindings.h +++ b/include/swift/Sema/CSBindings.h @@ -152,6 +152,8 @@ struct PotentialBinding { }; struct LiteralRequirement { + /// The literal protocol. + ProtocolDecl *Protocol; /// The source of the literal requirement. Constraint *Source; /// The default type associated with this literal (if any). @@ -164,17 +166,23 @@ struct LiteralRequirement { /// this points to the source of the binding. mutable Constraint *CoveredBy = nullptr; - LiteralRequirement(Constraint *source, Type defaultTy, bool isDirect) - : Source(source), DefaultType(defaultTy), IsDirectRequirement(isDirect) {} + LiteralRequirement(ProtocolDecl *protocol, Constraint *source, + Type defaultTy, bool isDirect) + : Protocol(protocol), Source(source), DefaultType(defaultTy), + IsDirectRequirement(isDirect) {} Constraint *getSource() const { return Source; } - ProtocolDecl *getProtocol() const { return Source->getProtocol(); } + ProtocolDecl *getProtocol() const { return Protocol; } bool isCovered() const { return bool(CoveredBy); } bool isDirectRequirement() const { return IsDirectRequirement; } + void setDirectRequirement(bool isDirectRequirement) { + IsDirectRequirement = isDirectRequirement; + } + bool hasDefaultType() const { return bool(DefaultType); } Type getDefaultType() const { @@ -241,6 +249,16 @@ struct PotentialBindings { llvm::SmallVector, 4> SupertypeOf; llvm::SmallVector, 4> EquivalentTo; + /// The set of protocol conformance requirements imposed on this type variable. + llvm::SmallVector Protocols; + + /// The set of unique literal protocol requirements placed on this + /// type variable. + llvm::SmallVector Literals; + + /// The set of fallback constraints imposed on this type variable. + llvm::SmallVector Defaults; + ASTNode AssociatedCodeCompletionToken = ASTNode(); /// Add a potential binding to the list of bindings, @@ -256,7 +274,14 @@ struct PotentialBindings { }); } -private: + ArrayRef getConformanceRequirements() const { + return Protocols; + } + + void inferFromLiteral(ConstraintSystem &CS, + TypeVariableType *TypeVar, + Constraint *literal); + /// Attempt to infer a new binding and other useful information /// (i.e. whether bindings should be delayed) from the given /// relational constraint. @@ -265,7 +290,6 @@ struct PotentialBindings { TypeVariableType *TypeVar, Constraint *constraint); -public: void infer(ConstraintSystem &CS, TypeVariableType *TypeVar, Constraint *constraint); @@ -365,18 +389,15 @@ class BindingSet { public: swift::SmallSetVector Bindings; - /// The set of protocol conformance requirements placed on this type variable. - llvm::SmallVector Protocols; - /// The set of unique literal protocol requirements placed on this /// type variable or inferred transitively through subtype chains. /// /// Note that ordering is important when it comes to bindings, we'd /// like to add any "direct" default types first to attempt them /// before transitive ones. - llvm::SmallMapVector Literals; + llvm::SmallVector Literals; - llvm::SmallDenseMap Defaults; + llvm::SmallVector Defaults; /// The set of transitive protocol requirements inferred through /// subtype/conversion/equivalence relations with other type variables. @@ -389,8 +410,6 @@ class BindingSet { BindingSet(const BindingSet &other) = delete; - ConstraintSystem &getConstraintSystem() const { return CS; } - TypeVariableType *getTypeVariable() const { return TypeVar; } /// Check whether this binding set belongs to a type variable @@ -457,7 +476,7 @@ class BindingSet { // Literal requirements always result in a subtype/supertype // relationship to a concrete type. if (llvm::any_of(Literals, [](const auto &literal) { - return literal.second.viableAsBinding(); + return literal.viableAsBinding(); })) return false; @@ -494,10 +513,6 @@ class BindingSet { return hasViableBindings() || isDirectHole(); } - ArrayRef getConformanceRequirements() const { - return Protocols; - } - unsigned getNumViableLiteralBindings() const; unsigned getNumViableDefaultableBindings() const { @@ -505,8 +520,8 @@ class BindingSet { return 1; auto numDefaultable = llvm::count_if( - Defaults, [](const std::pair &entry) { - return entry.second->getKind() == ConstraintKind::Defaultable; + Defaults, [](Constraint *constraint) { + return constraint->getKind() == ConstraintKind::Defaultable; }); // Short-circuit unviable checks if there are no defaultable bindings. @@ -518,10 +533,12 @@ class BindingSet { auto unviable = llvm::count_if(Bindings, [&](const PotentialBinding &binding) { auto type = binding.BindingType->getCanonicalType(); - auto def = Defaults.find(type); - return def != Defaults.end() - ? def->second->getKind() == ConstraintKind::Defaultable - : false; + for (auto *constraint : Defaults) { + if (constraint->getSecondType()->isEqual(type)) { + return constraint->getKind() == ConstraintKind::Defaultable; + } + } + return false; }); assert(numDefaultable >= unviable); @@ -572,6 +589,11 @@ class BindingSet { /// requirements down the subtype or equivalence chain. void inferTransitiveProtocolRequirements(); + /// Try to coalesce integer and floating point literal protocols + /// if they appear together because the only possible default type that + /// could satisfy both requirements is `Double`. + void coalesceIntegerAndFloatLiteralRequirements(); + /// Check whether the given binding set covers any of the literal protocols /// associated with this type variable. The idea is that if a type variable /// has a binding like Int and also it has a conformance requirement to @@ -607,8 +629,6 @@ class BindingSet { /// checking. void addBinding(PotentialBinding binding, bool isTransitive); - void addLiteralRequirement(Constraint *literal); - void addDefault(Constraint *constraint); StringRef getLiteralBindingKind(LiteralBindingKind K) const { diff --git a/include/swift/Sema/CSTrail.def b/include/swift/Sema/CSTrail.def index fa16ad1baf5e5..a501e7b5fca84 100644 --- a/include/swift/Sema/CSTrail.def +++ b/include/swift/Sema/CSTrail.def @@ -76,7 +76,10 @@ GRAPH_NODE_CHANGE(AddedConstraint) GRAPH_NODE_CHANGE(RemovedConstraint) GRAPH_NODE_CHANGE(InferredBindings) GRAPH_NODE_CHANGE(RetractedBindings) +GRAPH_NODE_CHANGE(RetractedLiteral) GRAPH_NODE_CHANGE(RetractedDelayedBy) +GRAPH_NODE_CHANGE(RetractedProtocol) +GRAPH_NODE_CHANGE(RetractedDefault) BINDING_RELATION_CHANGE(RetractedAdjacentVar) BINDING_RELATION_CHANGE(RetractedSubtypeOf) diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index a3a8c3b70aff5..d2dc0bb82618a 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -6196,7 +6196,9 @@ class TypeVarBindingProducer : public BindingProducer { public: using Element = TypeVariableBinding; - TypeVarBindingProducer(const BindingSet &bindings); + TypeVarBindingProducer(ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingSet &bindings); /// Retrieve a set of bindings available in the current state. ArrayRef getCurrentBindings() const { return Bindings; } diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 8e900fbb25ccf..2946181cf41f1 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -38,6 +38,16 @@ void ConstraintGraphNode::initBindingSet() { Set.emplace(CG.getConstraintSystem(), TypeVar, Potential); } +static bool isDirectRequirement(ConstraintSystem &cs, + TypeVariableType *typeVar, + Constraint *constraint) { + if (auto *other = constraint->getFirstType()->getAs()) { + return typeVar == cs.getRepresentative(other); + } + + return false; +} + BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, const PotentialBindings &info) : CS(CS), TypeVar(TypeVar), Info(info) { @@ -45,30 +55,13 @@ BindingSet::BindingSet(ConstraintSystem &CS, TypeVariableType *TypeVar, for (const auto &binding : info.Bindings) addBinding(binding, /*isTransitive=*/false); - for (auto *constraint : info.Constraints) { - switch (constraint->getKind()) { - case ConstraintKind::NonisolatedConformsTo: - case ConstraintKind::ConformsTo: - if (constraint->getSecondType()->is()) - Protocols.push_back(constraint); - break; + for (const auto &literal : info.Literals) + Literals.push_back(literal); - case ConstraintKind::LiteralConformsTo: - addLiteralRequirement(constraint); - break; - - case ConstraintKind::Defaultable: - case ConstraintKind::FallbackType: - // Do these in a separate pass. - if (CS.getFixedTypeRecursive(constraint->getFirstType(), true) - ->getAs() == TypeVar) { - addDefault(constraint); - } - break; - - default: - break; - } + for (auto *constraint : info.Defaults) { + // Do these in a separate pass. + if (isDirectRequirement(CS, TypeVar, constraint)) + addDefault(constraint); } for (auto &entry : info.AdjacentVars) @@ -84,9 +77,12 @@ bool BindingSet::forGenericParameter() const { } bool BindingSet::canBeNil() const { - auto &ctx = CS.getASTContext(); - return Literals.count( - ctx.getProtocol(KnownProtocolKind::ExpressibleByNilLiteral)); + for (const auto &literal : Literals) { + if (literal.getProtocol()->isSpecificProtocol( + KnownProtocolKind::ExpressibleByNilLiteral)) + return true; + } + return false; } bool BindingSet::isDirectHole() const { @@ -435,6 +431,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { } auto &bindings = node.getBindingSet(); + auto conformanceReqs = + node.getPotentialBindings().getConformanceRequirements(); // If current variable already has transitive protocol // conformances inferred, there is no need to look deeper @@ -443,8 +441,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { TypeVariableType *parent = nullptr; std::tie(parent, currentVar) = workList.pop_back_val(); assert(parent); - propagateProtocolsTo(parent, bindings.getConformanceRequirements(), - *bindings.TransitiveProtocols); + propagateProtocolsTo(parent, conformanceReqs, + *bindings.TransitiveProtocols); continue; } @@ -485,14 +483,16 @@ void BindingSet::inferTransitiveProtocolRequirements() { if (!node.hasBindingSet()) continue; - const auto &bindings = node.getBindingSet(); + auto conformanceReqs = + node.getPotentialBindings().getConformanceRequirements(); llvm::SmallPtrSet placeholder; // Add any direct protocols from members of the // equivalence class, so they could be propagated // to all of the members. - propagateProtocolsTo(currentVar, bindings.getConformanceRequirements(), - placeholder); + propagateProtocolsTo(currentVar, conformanceReqs, placeholder); + + const auto &bindings = node.getBindingSet(); // Since type variables are equal, current type variable // becomes a subtype to any supertype found in the current @@ -512,8 +512,7 @@ void BindingSet::inferTransitiveProtocolRequirements() { // are transitive to its parent, propagate them down the subtype/equivalence // chain. if (parent) { - propagateProtocolsTo(parent, bindings.getConformanceRequirements(), - protocols[currentVar]); + propagateProtocolsTo(parent, conformanceReqs, protocols[currentVar]); } auto &inferredProtocols = protocols[currentVar]; @@ -526,9 +525,8 @@ void BindingSet::inferTransitiveProtocolRequirements() { // - all of the transitive protocols inferred through // the members of the equivalence class. { - auto directRequirements = bindings.getConformanceRequirements(); - protocolsForEquivalence.insert(directRequirements.begin(), - directRequirements.end()); + protocolsForEquivalence.insert(conformanceReqs.begin(), + conformanceReqs.end()); protocolsForEquivalence.insert(inferredProtocols.begin(), inferredProtocols.end()); @@ -618,6 +616,18 @@ void BindingSet::inferTransitiveKeyPathBindings() { } void BindingSet::inferTransitiveSupertypeBindings() { + llvm::SmallDenseSet seenLiterals; + for (const auto &literal : Literals) { + bool inserted = seenLiterals.insert(literal.getProtocol()).second; + ASSERT(inserted); + } + + llvm::SmallDenseSet seenDefaults; + for (auto *constraint : Defaults) { + bool inserted = seenDefaults.insert(constraint).second; + ASSERT(inserted); + } + for (const auto &entry : Info.SupertypeOf) { auto &node = CS.getConstraintGraph()[entry.first]; if (!node.hasBindingSet()) @@ -647,15 +657,25 @@ void BindingSet::inferTransitiveSupertypeBindings() { // If one of the literal arguments doesn't propagate its // `ExpressibleByStringLiteral` conformance, we'd end up picking // `T` with only one type `Any?` which is incorrect. - for (const auto &literal : bindings.Literals) - addLiteralRequirement(literal.second.getSource()); + for (auto literal : bindings.Literals) { + auto *protocol = literal.getProtocol(); + + if (!seenLiterals.insert(protocol).second) + continue; + + literal.setDirectRequirement(false); + Literals.push_back(literal); + } // Infer transitive defaults. - for (const auto &def : bindings.Defaults) { - if (def.getSecond()->getKind() == ConstraintKind::FallbackType) + for (auto *def : bindings.Defaults) { + if (def->getKind() == ConstraintKind::FallbackType) continue; - addDefault(def.second); + if (!seenDefaults.insert(def).second) + continue; + + addDefault(def); } // TODO: We shouldn't need this in the future. @@ -842,11 +862,13 @@ bool BindingSet::finalizeKeyPathBindings() { (keyPath->getParsedRoot() || (fixedRootTy && !fixedRootTy->isTypeVariableOrMember()))) { auto fallback = llvm::find_if(Defaults, [](const auto &entry) { - return entry.second->getKind() == ConstraintKind::FallbackType; + return entry->getKind() == ConstraintKind::FallbackType; }); assert(fallback != Defaults.end()); updatedBindings.insert( - {fallback->first, AllowedBindingKind::Exact, fallback->second}); + {(*fallback)->getSecondType(), + AllowedBindingKind::Exact, + *fallback}); } else { updatedBindings.insert(PotentialBinding::forHole( TypeVar, CS.getConstraintLocator( @@ -996,9 +1018,7 @@ void BindingSet::determineLiteralCoverage() { bool allowsNil = canBeNil(); - for (auto &entry : Literals) { - auto &literal = entry.second; - + for (auto &literal : Literals) { if (!literal.viableAsBinding()) continue; @@ -1025,42 +1045,41 @@ void BindingSet::determineLiteralCoverage() { } } -void BindingSet::addLiteralRequirement(Constraint *constraint) { - auto *protocol = constraint->getProtocol(); +void BindingSet::coalesceIntegerAndFloatLiteralRequirements() { + decltype(Literals)::iterator intLiteral = Literals.end(); + decltype(Literals)::iterator floatLiteral = Literals.end(); + + for (auto iter = Literals.begin(); iter != Literals.end(); ++iter) { + auto *protocol = iter->getProtocol(); - // Let's try to coalesce integer and floating point literal protocols - // if they appear together because the only possible default type that - // could satisfy both requirements is `Double`. - { if (protocol->isSpecificProtocol( KnownProtocolKind::ExpressibleByIntegerLiteral)) { - auto *floatLiteral = CS.getASTContext().getProtocol( - KnownProtocolKind::ExpressibleByFloatLiteral); - if (Literals.count(floatLiteral)) - return; + intLiteral = iter; } if (protocol->isSpecificProtocol( KnownProtocolKind::ExpressibleByFloatLiteral)) { - auto *intLiteral = CS.getASTContext().getProtocol( - KnownProtocolKind::ExpressibleByIntegerLiteral); - Literals.erase(intLiteral); + floatLiteral = iter; } } - if (Literals.count(protocol) > 0) - return; + if (intLiteral != Literals.end() && + floatLiteral != Literals.end()) { + Literals.erase(intLiteral); + } +} - auto isDirectRequirement = [&](Constraint *constraint) -> bool { - if (auto *typeVar = constraint->getFirstType()->getAs()) { - auto *repr = CS.getRepresentative(typeVar); - return repr == TypeVar; - } +void PotentialBindings::inferFromLiteral(ConstraintSystem &CS, + TypeVariableType *TypeVar, + Constraint *constraint) { + auto *protocol = constraint->getProtocol(); - return false; - }; + for (const auto &literal : Literals) { + if (literal.getProtocol() == protocol) + return; + } - bool isDirect = isDirectRequirement(constraint); + bool isDirect = isDirectRequirement(CS, TypeVar, constraint); Type defaultType; // `ExpressibleByNilLiteral` doesn't have a default type. @@ -1069,8 +1088,7 @@ void BindingSet::addLiteralRequirement(Constraint *constraint) { defaultType = TypeChecker::getDefaultType(protocol, CS.DC); } - LiteralRequirement literal(constraint, defaultType, isDirect); - Literals.insert({protocol, std::move(literal)}); + Literals.emplace_back(protocol, constraint, defaultType, isDirect); } bool BindingSet::operator==(const BindingSet &other) { @@ -1080,7 +1098,7 @@ bool BindingSet::operator==(const BindingSet &other) { if (Bindings.size() != other.Bindings.size()) return false; - for (auto i : indices(Bindings)) { + for (unsigned i : indices(Bindings)) { const auto &x = Bindings[i]; const auto &y = other.Bindings[i]; @@ -1092,13 +1110,9 @@ bool BindingSet::operator==(const BindingSet &other) { if (Literals.size() != other.Literals.size()) return false; - for (auto pair : Literals) { - auto found = other.Literals.find(pair.first); - if (found == other.Literals.end()) - return false; - - const auto &x = pair.second; - const auto &y = found->second; + for (unsigned i : indices(Literals)) { + auto &x = Literals[i]; + auto &y = other.Literals[i]; if (x.Source != y.Source || x.DefaultType.getPointer() != y.DefaultType.getPointer() || @@ -1110,10 +1124,10 @@ bool BindingSet::operator==(const BindingSet &other) { if (Defaults.size() != other.Defaults.size()) return false; - for (auto pair : Defaults) { - auto found = other.Defaults.find(pair.first); - if (found == other.Defaults.end() || - pair.second != found->second) + for (auto i : indices(Defaults)) { + auto *x = Defaults[i]; + auto *y = other.Defaults[i]; + if (x != y) return false; } @@ -1238,6 +1252,9 @@ const BindingSet *ConstraintSystem::determineBestBindings( bestBindings = &bindings; } + if (bestBindings) + bestBindings->coalesceIntegerAndFloatLiteralRequirements(); + return bestBindings; } @@ -1275,8 +1292,12 @@ findInferableTypeVars(Type type, } void BindingSet::addDefault(Constraint *constraint) { - auto defaultTy = constraint->getSecondType(); - Defaults.insert({defaultTy->getCanonicalType(), constraint}); + if (CONDITIONAL_ASSERT_enabled()) { + for (auto *other : Defaults) { + ASSERT(other != constraint); + } + } + Defaults.push_back(constraint); } bool LiteralRequirement::isCoveredBy(Type type, ConstraintSystem &CS) const { @@ -1671,6 +1692,7 @@ BindingSet ConstraintSystem::getBindingsFor(TypeVariableType *typeVar) { (void) bindings.finalizeKeyPathBindings(); bindings.finalizeUnresolvedMemberChainResult(); bindings.determineLiteralCoverage(); + bindings.coalesceIntegerAndFloatLiteralRequirements(); return bindings; } @@ -2063,6 +2085,12 @@ void PotentialBindings::infer(ConstraintSystem &CS, break; } + case ConstraintKind::NonisolatedConformsTo: + case ConstraintKind::ConformsTo: + if (constraint->getSecondType()->is()) + Protocols.push_back(constraint); + break; + case ConstraintKind::BridgingConversion: case ConstraintKind::CheckedCast: case ConstraintKind::EscapableFunctionOf: @@ -2076,12 +2104,16 @@ void PotentialBindings::infer(ConstraintSystem &CS, case ConstraintKind::PackElementOf: case ConstraintKind::SameShape: case ConstraintKind::MaterializePackExpansion: - case ConstraintKind::NonisolatedConformsTo: - case ConstraintKind::ConformsTo: + // Constraints from which we can't do anything. + break; + case ConstraintKind::LiteralConformsTo: + inferFromLiteral(CS, TypeVar, constraint); + break; + case ConstraintKind::Defaultable: case ConstraintKind::FallbackType: - // Constraints from which we can't do anything. + Defaults.push_back(constraint); break; // For now let's avoid inferring protocol requirements from @@ -2206,21 +2238,45 @@ void PotentialBindings::retract(ConstraintSystem &CS, }), Bindings.end()); - DelayedBy.erase( - llvm::remove_if(DelayedBy, - [&](Constraint *existing) { - if (existing == constraint) { + Literals.erase( + llvm::remove_if(Literals, + [&](const LiteralRequirement &literal) { + if (literal.getSource() == constraint) { if (recordingChanges) { - CS.recordChange(SolverTrail::Change::RetractedDelayedBy( + CS.recordChange(SolverTrail::Change::RetractedLiteral( TypeVar, constraint)); } return true; } return false; }), - DelayedBy.end()); + Literals.end()); #define CALLBACK(ChangeKind) \ + [&](Constraint *other) { \ + if (other == constraint) { \ + if (recordingChanges) { \ + CS.recordChange(SolverTrail::Change::ChangeKind( \ + TypeVar, constraint)); \ + } \ + return true; \ + } \ + return false; \ + } + + DelayedBy.erase( + llvm::remove_if(DelayedBy, CALLBACK(RetractedDelayedBy)), + DelayedBy.end()); + + Protocols.erase( + llvm::remove_if(Protocols, CALLBACK(RetractedProtocol)), + Protocols.end()); + + Defaults.erase( + llvm::remove_if(Defaults, CALLBACK(RetractedDefault)), + Defaults.end()); + +#define PAIR_CALLBACK(ChangeKind) \ [&](std::pair pair) { \ if (pair.second == constraint) { \ if (recordingChanges) { \ @@ -2233,19 +2289,19 @@ void PotentialBindings::retract(ConstraintSystem &CS, } AdjacentVars.erase( - llvm::remove_if(AdjacentVars, CALLBACK(RetractedAdjacentVar)), + llvm::remove_if(AdjacentVars, PAIR_CALLBACK(RetractedAdjacentVar)), AdjacentVars.end()); SubtypeOf.erase( - llvm::remove_if(SubtypeOf, CALLBACK(RetractedSubtypeOf)), + llvm::remove_if(SubtypeOf, PAIR_CALLBACK(RetractedSubtypeOf)), SubtypeOf.end()); SupertypeOf.erase( - llvm::remove_if(SupertypeOf, CALLBACK(RetractedSupertypeOf)), + llvm::remove_if(SupertypeOf, PAIR_CALLBACK(RetractedSupertypeOf)), SupertypeOf.end()); EquivalentTo.erase( - llvm::remove_if(EquivalentTo, CALLBACK(RetractedEquivalentTo)), + llvm::remove_if(EquivalentTo, PAIR_CALLBACK(RetractedEquivalentTo)), EquivalentTo.end()); #undef CALLBACK @@ -2305,15 +2361,12 @@ void PotentialBindings::dump(ConstraintSystem &cs, TypeVariableType *typeVar, void BindingSet::forEachLiteralRequirement( llvm::function_ref callback) const { - for (const auto &literal : Literals) { - auto *protocol = literal.first; - const auto &info = literal.second; - + for (const auto &info : Literals) { // Only uncovered defaultable literal protocols participate. if (!info.viableAsBinding()) continue; - if (auto protocolKind = protocol->getKnownProtocolKind()) + if (auto protocolKind = info.getProtocol()->getKnownProtocolKind()) callback(*protocolKind); } } @@ -2344,7 +2397,7 @@ LiteralBindingKind BindingSet::getLiteralForScore() const { unsigned BindingSet::getNumViableLiteralBindings() const { return llvm::count_if(Literals, [&](const auto &literal) { - return literal.second.viableAsBinding(); + return literal.viableAsBinding(); }); } @@ -2482,10 +2535,10 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { } for (const auto &literal : Literals) { potentialBindings.push_back(PrintableBinding::literalDefaultType( - literal.second.hasDefaultType() - ? literal.second.getDefaultType() + literal.hasDefaultType() + ? literal.getDefaultType() : Type(), - literal.second.viableAsBinding())); + literal.viableAsBinding())); } if (potentialBindings.empty()) { out << ""; @@ -2501,8 +2554,7 @@ void BindingSet::dump(llvm::raw_ostream &out, unsigned indent) const { out << " [defaults: "; interleave( Defaults, - [&](const auto &entry) { - auto *constraint = entry.second; + [&](Constraint *constraint) { auto defaultBinding = PrintableBinding::exact(constraint->getSecondType()); defaultBinding.print(out, PO); diff --git a/lib/Sema/CSOptimizer.cpp b/lib/Sema/CSOptimizer.cpp index c822f51f5457a..ef1c81c0753b3 100644 --- a/lib/Sema/CSOptimizer.cpp +++ b/lib/Sema/CSOptimizer.cpp @@ -1048,9 +1048,9 @@ static void determineBestChoicesInContext( } for (const auto &literal : bindingSet.Literals) { - if (literal.second.hasDefaultType()) { + if (literal.hasDefaultType()) { // Add primary default type - auto type = restoreOptionality(literal.second.getDefaultType(), + auto type = restoreOptionality(literal.getDefaultType(), optionals.size()); types.push_back({type, /*fromLiteral=*/true}); diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 7815762a82a8b..a9921e8c83124 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -2566,8 +2566,8 @@ void DisjunctionChoice::propagateConversionInfo(ConstraintSystem &cs) const { conversionType = bindings.Bindings[0].BindingType; } else { for (const auto &literal : bindings.Literals) { - if (literal.second.viableAsBinding()) { - conversionType = literal.second.getDefaultType(); + if (literal.viableAsBinding()) { + conversionType = literal.getDefaultType(); break; } } diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index e4f8e5e128062..6fdd801b6df7e 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -335,7 +335,8 @@ StepResult ComponentStep::take(bool prevFailed) { switch (*step) { case StepKind::Binding: return suspend( - std::make_unique(*bestBindings, Solutions)); + std::make_unique(CS, bestBindings->getTypeVariable(), + *bestBindings, Solutions)); case StepKind::Disjunction: { CS.retireConstraint(disjunction->first); return suspend( diff --git a/lib/Sema/CSStep.h b/lib/Sema/CSStep.h index 7c40af249933b..f09d3827a2f87 100644 --- a/lib/Sema/CSStep.h +++ b/lib/Sema/CSStep.h @@ -540,10 +540,12 @@ class TypeVariableStep final : public BindingStep { bool SawFirstLiteralConstraint = false; public: - TypeVariableStep(const BindingContainer &bindings, + TypeVariableStep(ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingContainer &bindings, SmallVectorImpl &solutions) - : BindingStep(bindings.getConstraintSystem(), {bindings}, solutions), - TypeVar(bindings.getTypeVariable()) {} + : BindingStep(cs, {cs, typeVar, bindings}, solutions), + TypeVar(typeVar) {} void setup() override; diff --git a/lib/Sema/CSTrail.cpp b/lib/Sema/CSTrail.cpp index 61b1a77e1e0f8..28913ed48fbb3 100644 --- a/lib/Sema/CSTrail.cpp +++ b/lib/Sema/CSTrail.cpp @@ -538,6 +538,22 @@ void SolverTrail::Change::undo(ConstraintSystem &cs) const { .DelayedBy.push_back(TheConstraint.Constraint); break; + case ChangeKind::RetractedProtocol: + cg[TheConstraint.TypeVar].getPotentialBindings() + .Protocols.push_back(TheConstraint.Constraint); + break; + + case ChangeKind::RetractedDefault: + cg[TheConstraint.TypeVar].getPotentialBindings() + .Defaults.push_back(TheConstraint.Constraint); + break; + + case ChangeKind::RetractedLiteral: + cg[TheConstraint.TypeVar].getPotentialBindings() + .inferFromLiteral(cs, TheConstraint.TypeVar, + TheConstraint.Constraint); + break; + case ChangeKind::RetractedAdjacentVar: cg[BindingRelation.TypeVar].getPotentialBindings() .AdjacentVars.emplace_back(BindingRelation.OtherTypeVar, diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index b0d0659f09127..cbddf562b79ea 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -5323,10 +5323,12 @@ ConstraintSystem::inferKeyPathLiteralCapability(KeyPathExpr *keyPath) { return success(mutability, isSendable); } -TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) - : BindingProducer(bindings.getConstraintSystem(), - bindings.getTypeVariable()->getImpl().getLocator()), - TypeVar(bindings.getTypeVariable()), CanBeNil(bindings.canBeNil()) { +TypeVarBindingProducer::TypeVarBindingProducer( + ConstraintSystem &cs, + TypeVariableType *typeVar, + const BindingSet &bindings) + : BindingProducer(cs, typeVar->getImpl().getLocator()), + TypeVar(typeVar), CanBeNil(bindings.canBeNil()) { if (bindings.isDirectHole()) { auto *locator = getLocator(); // If this type variable is associated with a code completion token @@ -5384,8 +5386,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) if (viableBindings.size() == 1) { addBinding(viableBindings.front()); } else { - for (const auto &entry : bindings.Defaults) { - auto *constraint = entry.second; + for (auto *constraint : bindings.Defaults) { Bindings.push_back(getDefaultBinding(constraint)); } } @@ -5398,9 +5399,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) } // Infer defaults based on "uncovered" literal protocol requirements. - for (const auto &info : bindings.Literals) { - const auto &literal = info.second; - + for (const auto &literal : bindings.Literals) { if (!literal.viableAsBinding()) continue; @@ -5423,8 +5422,7 @@ TypeVarBindingProducer::TypeVarBindingProducer(const BindingSet &bindings) { bool noBindings = Bindings.empty(); - for (const auto &entry : bindings.Defaults) { - auto *constraint = entry.second; + for (auto *constraint : bindings.Defaults) { if (noBindings) { // If there are no direct or transitive bindings to attempt // let's add defaults to the list right away. diff --git a/test/Constraints/closures.swift b/test/Constraints/closures.swift index 7f2e6020001cf..f9a8362fe741a 100644 --- a/test/Constraints/closures.swift +++ b/test/Constraints/closures.swift @@ -1415,3 +1415,17 @@ func test_implicit_result_conversions() { return // Ok } } + +// Random example reduced from swift-build which tripped an assert not +// previously covered by our test suite +do { + struct S { + var x: [Int: [String]] = [:] + } + + let s = [S]() + + let _: [Int: Set] = s.map(\.x) + .reduce([:], { x, y in x.merging(y, uniquingKeysWith: +) }) + .mapValues { Set($0) } +} diff --git a/unittests/Sema/BindingInferenceTests.cpp b/unittests/Sema/BindingInferenceTests.cpp index 3791eef8bccd2..244eb4d7cf840 100644 --- a/unittests/Sema/BindingInferenceTests.cpp +++ b/unittests/Sema/BindingInferenceTests.cpp @@ -42,7 +42,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_EQ(bindings.Literals.size(), (unsigned)1); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.hasDefaultType()); ASSERT_TRUE(literal.getDefaultType()->isEqual(intTy)); @@ -65,7 +65,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_TRUE(bindings.Bindings[0].BindingType->isEqual(intTy)); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_TRUE(literal.isDirectRequirement()); ASSERT_TRUE(literal.getDefaultType()->isEqual(intTy)); @@ -99,7 +99,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { ASSERT_TRUE(bindings.Bindings[0].BindingType->isEqual(floatTy)); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_TRUE(literal.isDirectRequirement()); ASSERT_FALSE(literal.getDefaultType()->isEqual(floatTy)); @@ -140,7 +140,7 @@ TEST_F(SemaTest, TestIntLiteralBindingInference) { // Inferred literal requirement through `$T_float` as well. ASSERT_EQ(bindings.Literals.size(), (unsigned)1); - const auto &literal = bindings.Literals.front().second; + const auto &literal = bindings.Literals.front(); ASSERT_TRUE(literal.isCovered()); ASSERT_FALSE(literal.isDirectRequirement()); @@ -197,7 +197,9 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) { CTP_Initialization))); auto &bindings = inferBindings(cs, typeVar); - ASSERT_TRUE(bindings.getConformanceRequirements().empty()); + ASSERT_TRUE(cs.getConstraintGraph()[typeVar] + .getPotentialBindings().getConformanceRequirements().empty()); + ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1}); @@ -218,8 +220,10 @@ TEST_F(SemaTest, TestTransitiveProtocolInference) { cs.addConstraint(ConstraintKind::Conversion, typeVar, GPT1, cs.getConstraintLocator({})); + ASSERT_TRUE(cs.getConstraintGraph()[typeVar] + .getPotentialBindings().getConformanceRequirements().empty()); + auto &bindings = inferBindings(cs, typeVar); - ASSERT_TRUE(bindings.getConformanceRequirements().empty()); ASSERT_TRUE(bool(bindings.TransitiveProtocols)); verifyProtocolInferenceResults(*bindings.TransitiveProtocols, {protocolTy1, protocolTy2}); @@ -348,7 +352,7 @@ TEST_F(SemaTest, TestNoDoubleVoidClosureResultInference) { auto verifyInference = [&](TypeVariableType *typeVar, unsigned numExpected) { auto bindings = cs.getBindingsFor(typeVar); - TypeVarBindingProducer producer(bindings); + TypeVarBindingProducer producer(cs, typeVar, bindings); llvm::SmallPtrSet inferredTypes; @@ -421,7 +425,7 @@ TEST_F(SemaTest, TestSupertypeInferenceWithDefaults) { cs.getConstraintLocator({})); auto bindings = cs.getBindingsFor(genericArg); - TypeVarBindingProducer producer(bindings); + TypeVarBindingProducer producer(cs, genericArg, bindings); llvm::SmallVector inferredTypes; while (auto binding = producer()) {