@@ -190,7 +190,12 @@ pub trait InferCtxtExt<'tcx> {
190190 trait_ref : & ty:: PolyTraitRef < ' tcx > ,
191191 ) ;
192192
193- fn suggest_derive ( & self , err : & mut Diagnostic , trait_pred : ty:: PolyTraitPredicate < ' tcx > ) ;
193+ fn suggest_derive (
194+ & self ,
195+ obligation : & PredicateObligation < ' tcx > ,
196+ err : & mut Diagnostic ,
197+ trait_pred : ty:: PolyTraitPredicate < ' tcx > ,
198+ ) ;
194199}
195200
196201fn predicate_constraint ( generics : & hir:: Generics < ' _ > , pred : String ) -> ( Span , String ) {
@@ -2592,33 +2597,60 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
25922597 }
25932598 }
25942599
2595- fn suggest_derive ( & self , err : & mut Diagnostic , trait_pred : ty:: PolyTraitPredicate < ' tcx > ) {
2600+ fn suggest_derive (
2601+ & self ,
2602+ obligation : & PredicateObligation < ' tcx > ,
2603+ err : & mut Diagnostic ,
2604+ trait_pred : ty:: PolyTraitPredicate < ' tcx > ,
2605+ ) {
25962606 let Some ( diagnostic_name) = self . tcx . get_diagnostic_name ( trait_pred. def_id ( ) ) else {
25972607 return ;
25982608 } ;
2599- let Some ( self_ty) = trait_pred. self_ty ( ) . no_bound_vars ( ) else {
2600- return ;
2601- } ;
2602-
2603- let adt = match self_ty. ty_adt_def ( ) {
2604- Some ( adt) if adt. did ( ) . is_local ( ) => adt,
2609+ let ( adt, substs) = match trait_pred. skip_binder ( ) . self_ty ( ) . kind ( ) {
2610+ ty:: Adt ( adt, substs) if adt. did ( ) . is_local ( ) => ( adt, substs) ,
26052611 _ => return ,
26062612 } ;
2607- let can_derive = match diagnostic_name {
2608- sym:: Default => !adt. is_enum ( ) ,
2609- sym:: PartialEq | sym:: PartialOrd => {
2610- let rhs_ty = trait_pred. skip_binder ( ) . trait_ref . substs . type_at ( 1 ) ;
2611- self_ty == rhs_ty
2612- }
2613- sym:: Eq | sym:: Ord | sym:: Clone | sym:: Copy | sym:: Hash | sym:: Debug => true ,
2614- _ => false ,
2613+ let can_derive = {
2614+ let is_derivable_trait = match diagnostic_name {
2615+ sym:: Default => !adt. is_enum ( ) ,
2616+ sym:: PartialEq | sym:: PartialOrd => {
2617+ let rhs_ty = trait_pred. skip_binder ( ) . trait_ref . substs . type_at ( 1 ) ;
2618+ trait_pred. skip_binder ( ) . self_ty ( ) == rhs_ty
2619+ }
2620+ sym:: Eq | sym:: Ord | sym:: Clone | sym:: Copy | sym:: Hash | sym:: Debug => true ,
2621+ _ => false ,
2622+ } ;
2623+ is_derivable_trait &&
2624+ // Ensure all fields impl the trait.
2625+ adt. all_fields ( ) . all ( |field| {
2626+ let field_ty = field. ty ( self . tcx , substs) ;
2627+ let trait_substs = match diagnostic_name {
2628+ sym:: PartialEq | sym:: PartialOrd => {
2629+ self . tcx . mk_substs_trait ( field_ty, & [ field_ty. into ( ) ] )
2630+ }
2631+ _ => self . tcx . mk_substs_trait ( field_ty, & [ ] ) ,
2632+ } ;
2633+ let trait_pred = trait_pred. map_bound_ref ( |tr| ty:: TraitPredicate {
2634+ trait_ref : ty:: TraitRef {
2635+ substs : trait_substs,
2636+ ..trait_pred. skip_binder ( ) . trait_ref
2637+ } ,
2638+ ..* tr
2639+ } ) ;
2640+ let field_obl = Obligation :: new (
2641+ obligation. cause . clone ( ) ,
2642+ obligation. param_env ,
2643+ trait_pred. to_predicate ( self . tcx ) ,
2644+ ) ;
2645+ self . predicate_must_hold_modulo_regions ( & field_obl)
2646+ } )
26152647 } ;
26162648 if can_derive {
26172649 err. span_suggestion_verbose (
26182650 self . tcx . def_span ( adt. did ( ) ) . shrink_to_lo ( ) ,
26192651 & format ! (
26202652 "consider annotating `{}` with `#[derive({})]`" ,
2621- trait_pred. skip_binder( ) . self_ty( ) . to_string ( ) ,
2653+ trait_pred. skip_binder( ) . self_ty( ) ,
26222654 diagnostic_name. to_string( ) ,
26232655 ) ,
26242656 format ! ( "#[derive({})]\n " , diagnostic_name. to_string( ) ) ,
0 commit comments