@@ -38,8 +38,8 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError};
3838use rustc_middle:: ty:: relate:: { self , Relate , RelateResult , TypeRelation } ;
3939use rustc_middle:: ty:: subst:: SubstsRef ;
4040use rustc_middle:: ty:: {
41- self , FallibleTypeFolder , InferConst , Ty , TyCtxt , TypeFoldable , TypeSuperFoldable ,
42- TypeVisitable ,
41+ self , AliasKind , FallibleTypeFolder , InferConst , ToPredicate , Ty , TyCtxt , TypeFoldable ,
42+ TypeSuperFoldable , TypeVisitable ,
4343} ;
4444use rustc_middle:: ty:: { IntType , UintType } ;
4545use rustc_span:: { Span , DUMMY_SP } ;
@@ -74,7 +74,7 @@ impl<'tcx> InferCtxt<'tcx> {
7474 b : Ty < ' tcx > ,
7575 ) -> RelateResult < ' tcx , Ty < ' tcx > >
7676 where
77- R : TypeRelation < ' tcx > ,
77+ R : ObligationEmittingRelation < ' tcx > ,
7878 {
7979 let a_is_expected = relation. a_is_expected ( ) ;
8080
@@ -122,6 +122,15 @@ impl<'tcx> InferCtxt<'tcx> {
122122 Err ( TypeError :: Sorts ( ty:: relate:: expected_found ( relation, a, b) ) )
123123 }
124124
125+ ( ty:: Alias ( AliasKind :: Projection , _) , _) if self . tcx . trait_solver_next ( ) => {
126+ relation. register_type_equate_obligation ( a. into ( ) , b. into ( ) ) ;
127+ Ok ( b)
128+ }
129+ ( _, ty:: Alias ( AliasKind :: Projection , _) ) if self . tcx . trait_solver_next ( ) => {
130+ relation. register_type_equate_obligation ( b. into ( ) , a. into ( ) ) ;
131+ Ok ( a)
132+ }
133+
125134 _ => ty:: relate:: super_relate_tys ( relation, a, b) ,
126135 }
127136 }
@@ -133,7 +142,7 @@ impl<'tcx> InferCtxt<'tcx> {
133142 b : ty:: Const < ' tcx > ,
134143 ) -> RelateResult < ' tcx , ty:: Const < ' tcx > >
135144 where
136- R : ConstEquateRelation < ' tcx > ,
145+ R : ObligationEmittingRelation < ' tcx > ,
137146 {
138147 debug ! ( "{}.consts({:?}, {:?})" , relation. tag( ) , a, b) ;
139148 if a == b {
@@ -169,15 +178,15 @@ impl<'tcx> InferCtxt<'tcx> {
169178 // FIXME(#59490): Need to remove the leak check to accommodate
170179 // escaping bound variables here.
171180 if !a. has_escaping_bound_vars ( ) && !b. has_escaping_bound_vars ( ) {
172- relation. const_equate_obligation ( a, b) ;
181+ relation. register_const_equate_obligation ( a, b) ;
173182 }
174183 return Ok ( b) ;
175184 }
176185 ( _, ty:: ConstKind :: Unevaluated ( ..) ) if self . tcx . lazy_normalization ( ) => {
177186 // FIXME(#59490): Need to remove the leak check to accommodate
178187 // escaping bound variables here.
179188 if !a. has_escaping_bound_vars ( ) && !b. has_escaping_bound_vars ( ) {
180- relation. const_equate_obligation ( a, b) ;
189+ relation. register_const_equate_obligation ( a, b) ;
181190 }
182191 return Ok ( a) ;
183192 }
@@ -435,32 +444,21 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
435444 Ok ( Generalization { ty, needs_wf } )
436445 }
437446
438- pub fn add_const_equate_obligation (
447+ pub fn register_obligations ( & mut self , obligations : PredicateObligations < ' tcx > ) {
448+ self . obligations . extend ( obligations. into_iter ( ) ) ;
449+ }
450+
451+ pub fn register_predicates (
439452 & mut self ,
440- a_is_expected : bool ,
441- a : ty:: Const < ' tcx > ,
442- b : ty:: Const < ' tcx > ,
453+ obligations : impl IntoIterator < Item = impl ToPredicate < ' tcx > > ,
443454 ) {
444- let predicate = if a_is_expected {
445- ty:: PredicateKind :: ConstEquate ( a, b)
446- } else {
447- ty:: PredicateKind :: ConstEquate ( b, a)
448- } ;
449- self . obligations . push ( Obligation :: new (
450- self . tcx ( ) ,
451- self . trace . cause . clone ( ) ,
452- self . param_env ,
453- ty:: Binder :: dummy ( predicate) ,
454- ) ) ;
455+ self . obligations . extend ( obligations. into_iter ( ) . map ( |to_pred| {
456+ Obligation :: new ( self . infcx . tcx , self . trace . cause . clone ( ) , self . param_env , to_pred)
457+ } ) )
455458 }
456459
457460 pub fn mark_ambiguous ( & mut self ) {
458- self . obligations . push ( Obligation :: new (
459- self . tcx ( ) ,
460- self . trace . cause . clone ( ) ,
461- self . param_env ,
462- ty:: Binder :: dummy ( ty:: PredicateKind :: Ambiguous ) ,
463- ) ) ;
461+ self . register_predicates ( [ ty:: Binder :: dummy ( ty:: PredicateKind :: Ambiguous ) ] ) ;
464462 }
465463}
466464
@@ -779,11 +777,42 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
779777 }
780778}
781779
782- pub trait ConstEquateRelation < ' tcx > : TypeRelation < ' tcx > {
780+ pub trait ObligationEmittingRelation < ' tcx > : TypeRelation < ' tcx > {
781+ /// Register obligations that must hold in order for this relation to hold
782+ fn register_obligations ( & mut self , obligations : PredicateObligations < ' tcx > ) ;
783+
784+ /// Register predicates that must hold in order for this relation to hold. Uses
785+ /// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
786+ /// be used if control over the obligaton causes is required.
787+ fn register_predicates (
788+ & mut self ,
789+ obligations : impl IntoIterator < Item = impl ToPredicate < ' tcx > > ,
790+ ) ;
791+
783792 /// Register an obligation that both constants must be equal to each other.
784793 ///
785794 /// If they aren't equal then the relation doesn't hold.
786- fn const_equate_obligation ( & mut self , a : ty:: Const < ' tcx > , b : ty:: Const < ' tcx > ) ;
795+ fn register_const_equate_obligation ( & mut self , a : ty:: Const < ' tcx > , b : ty:: Const < ' tcx > ) {
796+ let ( a, b) = if self . a_is_expected ( ) { ( a, b) } else { ( b, a) } ;
797+
798+ self . register_predicates ( [ ty:: Binder :: dummy ( if self . tcx ( ) . trait_solver_next ( ) {
799+ ty:: PredicateKind :: AliasEq ( a. into ( ) , b. into ( ) )
800+ } else {
801+ ty:: PredicateKind :: ConstEquate ( a, b)
802+ } ) ] ) ;
803+ }
804+
805+ /// Register an obligation that both types must be equal to each other.
806+ ///
807+ /// If they aren't equal then the relation doesn't hold.
808+ fn register_type_equate_obligation ( & mut self , a : Ty < ' tcx > , b : Ty < ' tcx > ) {
809+ let ( a, b) = if self . a_is_expected ( ) { ( a, b) } else { ( b, a) } ;
810+
811+ self . register_predicates ( [ ty:: Binder :: dummy ( ty:: PredicateKind :: AliasEq (
812+ a. into ( ) ,
813+ b. into ( ) ,
814+ ) ) ] ) ;
815+ }
787816}
788817
789818fn int_unification_error < ' tcx > (
0 commit comments