11use std:: ops:: ControlFlow ;
22
3- use rustc_data_structures:: fx:: FxIndexMap ;
3+ use rustc_data_structures:: fx:: { FxIndexMap , FxIndexSet } ;
44use rustc_hir as hir;
55use rustc_hir:: def_id:: DefId ;
6+ use rustc_infer:: infer:: outlives:: env:: OutlivesEnvironment ;
67use rustc_infer:: infer:: TyCtxtInferExt ;
78use rustc_infer:: traits:: Obligation ;
8- use rustc_middle:: traits:: ObligationCause ;
9+ use rustc_middle:: traits:: { ObligationCause , Reveal } ;
910use rustc_middle:: ty:: {
1011 self , Ty , TyCtxt , TypeFolder , TypeSuperFoldable , TypeSuperVisitable , TypeVisitable , TypeVisitor ,
1112} ;
1213use rustc_span:: ErrorGuaranteed ;
1314use rustc_span:: { sym, Span } ;
14- use rustc_trait_selection:: traits:: ObligationCtxt ;
15+ use rustc_trait_selection:: traits:: outlives_bounds:: InferCtxtExt ;
16+ use rustc_trait_selection:: traits:: { normalize_param_env_or_error, ObligationCtxt } ;
1517use rustc_type_ir:: fold:: TypeFoldable ;
1618
1719/// Check that an implementation does not refine an RPITIT from a trait method signature.
@@ -30,24 +32,48 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
3032 let hidden_tys = tcx. collect_return_position_impl_trait_in_trait_tys ( impl_m. def_id ) ?;
3133
3234 let impl_def_id = impl_m. container_id ( tcx) ;
33- // let trait_def_id = trait_m.container_id(tcx);
34- let trait_m_to_impl_m_substs = ty:: InternalSubsts :: identity_for_item ( tcx, impl_m. def_id )
35- . rebase_onto ( tcx, impl_def_id, impl_trait_ref. substs ) ;
35+ let trait_def_id = trait_m. container_id ( tcx) ;
36+ let trait_m_to_impl_m_args = ty:: GenericArgs :: identity_for_item ( tcx, impl_m. def_id )
37+ . rebase_onto ( tcx, impl_def_id, impl_trait_ref. args ) ;
3638
37- let bound_trait_m_sig = tcx. fn_sig ( trait_m. def_id ) . subst ( tcx, trait_m_to_impl_m_substs) ;
38- let trait_m_sig = tcx. liberate_late_bound_regions ( impl_m. def_id , bound_trait_m_sig) ;
39+ let infcx = tcx. infer_ctxt ( ) . build ( ) ;
40+ let ocx = ObligationCtxt :: new ( & infcx) ;
41+
42+ let mut hybrid_preds = tcx. predicates_of ( impl_def_id) . instantiate_identity ( tcx) . predicates ;
43+ hybrid_preds. extend (
44+ tcx. predicates_of ( trait_m. def_id )
45+ . instantiate_own ( tcx, trait_m_to_impl_m_args)
46+ . map ( |( pred, _) | pred) ,
47+ ) ;
48+ let normalize_cause =
49+ ObligationCause :: misc ( tcx. def_span ( impl_m. def_id ) , impl_m. def_id . expect_local ( ) ) ;
50+ let unnormalized_param_env = ty:: ParamEnv :: new (
51+ tcx. mk_clauses ( & hybrid_preds) ,
52+ Reveal :: HideReturnPositionImplTraitInTrait ,
53+ ) ;
54+ let param_env = normalize_param_env_or_error ( tcx, unnormalized_param_env, normalize_cause) ;
55+
56+ let bound_trait_m_sig = tcx. fn_sig ( trait_m. def_id ) . instantiate ( tcx, trait_m_to_impl_m_args) ;
57+ let unnormalized_trait_m_sig =
58+ tcx. liberate_late_bound_regions ( impl_m. def_id , bound_trait_m_sig) ;
59+ let trait_m_sig = ocx. normalize ( & ObligationCause :: dummy ( ) , param_env, unnormalized_trait_m_sig) ;
3960
4061 let mut visitor = ImplTraitInTraitCollector { tcx, types : FxIndexMap :: default ( ) } ;
4162 trait_m_sig. visit_with ( & mut visitor) ;
4263
4364 let mut reverse_mapping = FxIndexMap :: default ( ) ;
4465 let mut bounds_to_prove = vec ! [ ] ;
45- for ( rpitit_def_id, rpitit_substs) in visitor. types {
46- let hidden_ty = hidden_tys
47- . get ( & rpitit_def_id)
48- . expect ( "expected hidden type for RPITIT" )
49- . subst_identity ( ) ;
50- reverse_mapping. insert ( hidden_ty, tcx. mk_projection ( rpitit_def_id, rpitit_substs) ) ;
66+ for ( rpitit_def_id, rpitit_args) in visitor. types {
67+ let hidden_ty =
68+ hidden_tys. get ( & rpitit_def_id) . expect ( "expected hidden type for RPITIT" ) . instantiate (
69+ tcx,
70+ rpitit_args. rebase_onto (
71+ tcx,
72+ trait_def_id,
73+ ty:: GenericArgs :: identity_for_item ( tcx, impl_def_id) ,
74+ ) ,
75+ ) ;
76+ reverse_mapping. insert ( hidden_ty, Ty :: new_projection ( tcx, rpitit_def_id, rpitit_args) ) ;
5177
5278 let ty:: Alias ( ty:: Opaque , opaque_ty) = * hidden_ty. kind ( ) else {
5379 return Err ( report_mismatched_rpitit_signature (
@@ -82,11 +108,6 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
82108 ) ;
83109 }
84110
85- let infcx = tcx. infer_ctxt ( ) . build ( ) ;
86- let ocx = ObligationCtxt :: new ( & infcx) ;
87- let param_env =
88- tcx. param_env ( impl_m. def_id ) . with_hidden_return_position_impl_trait_in_trait_tys ( ) ;
89-
90111 ocx. register_obligations (
91112 bounds_to_prove. fold_with ( & mut ReverseMapper { tcx, reverse_mapping } ) . into_iter ( ) . map (
92113 |( pred, span) | {
@@ -107,6 +128,24 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
107128 ) ) ;
108129 }
109130
131+ let mut wf_tys = FxIndexSet :: default ( ) ;
132+ wf_tys. extend ( unnormalized_trait_m_sig. inputs_and_output ) ;
133+ wf_tys. extend ( trait_m_sig. inputs_and_output ) ;
134+ let outlives_env = OutlivesEnvironment :: with_bounds (
135+ param_env,
136+ ocx. infcx . implied_bounds_tys ( param_env, impl_m. def_id . expect_local ( ) , wf_tys. clone ( ) ) ,
137+ ) ;
138+ let errors = ocx. infcx . resolve_regions ( & outlives_env) ;
139+ if !errors. is_empty ( ) {
140+ return Err ( report_mismatched_rpitit_signature (
141+ tcx,
142+ trait_m_sig,
143+ trait_m. def_id ,
144+ impl_m. def_id ,
145+ None ,
146+ ) ) ;
147+ }
148+
110149 Ok ( ( ) )
111150}
112151
0 commit comments