@@ -56,12 +56,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
5656 return ex;
5757 }
5858 // Desugar `ExprForLoop`
59- // from: `[opt_ident]: for <pat> in <head > <body>`
59+ // from: `[opt_ident]: for await? <pat> in <iter > <body>`
6060 //
6161 // This also needs special handling because the HirId of the returned `hir::Expr` will not
6262 // correspond to the `e.id`, so `lower_expr_for` handles attribute lowering itself.
63- ExprKind :: ForLoop ( pat, head , body, opt_label ) => {
64- return self . lower_expr_for ( e, pat, head , body, * opt_label ) ;
63+ ExprKind :: ForLoop { pat, iter , body, label , kind } => {
64+ return self . lower_expr_for ( e, pat, iter , body, * label , * kind ) ;
6565 }
6666 _ => ( ) ,
6767 }
@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337337 ) ,
338338 ExprKind :: Try ( sub_expr) => self . lower_expr_try ( e. span , sub_expr) ,
339339
340- ExprKind :: Paren ( _) | ExprKind :: ForLoop ( .. ) => {
340+ ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
341341 unreachable ! ( "already handled" )
342342 }
343343
@@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
874874 /// }
875875 /// ```
876876 fn lower_expr_await ( & mut self , await_kw_span : Span , expr : & Expr ) -> hir:: ExprKind < ' hir > {
877+ let expr = self . arena . alloc ( self . lower_expr_mut ( expr) ) ;
878+ self . make_lowered_await ( await_kw_span, expr, FutureKind :: Future )
879+ }
880+
881+ /// Takes an expr that has already been lowered and generates a desugared await loop around it
882+ fn make_lowered_await (
883+ & mut self ,
884+ await_kw_span : Span ,
885+ expr : & ' hir hir:: Expr < ' hir > ,
886+ await_kind : FutureKind ,
887+ ) -> hir:: ExprKind < ' hir > {
877888 let full_span = expr. span . to ( await_kw_span) ;
878889
879890 let is_async_gen = match self . coroutine_kind {
@@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
887898 }
888899 } ;
889900
890- let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, None ) ;
901+ let features = match await_kind {
902+ FutureKind :: Future => None ,
903+ FutureKind :: AsyncIterator => Some ( self . allow_for_await . clone ( ) ) ,
904+ } ;
905+ let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, features) ;
891906 let gen_future_span = self . mark_span_with_reason (
892907 DesugaringKind :: Await ,
893908 full_span,
894909 Some ( self . allow_gen_future . clone ( ) ) ,
895910 ) ;
896- let expr = self . lower_expr_mut ( expr) ;
897911 let expr_hir_id = expr. hir_id ;
898912
899913 // Note that the name of this binding must not be changed to something else because
@@ -934,11 +948,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
934948 hir:: LangItem :: GetContext ,
935949 arena_vec ! [ self ; task_context] ,
936950 ) ;
937- let call = self . expr_call_lang_item_fn (
938- span,
939- hir:: LangItem :: FuturePoll ,
940- arena_vec ! [ self ; new_unchecked, get_context] ,
941- ) ;
951+ let call = match await_kind {
952+ FutureKind :: Future => self . expr_call_lang_item_fn (
953+ span,
954+ hir:: LangItem :: FuturePoll ,
955+ arena_vec ! [ self ; new_unchecked, get_context] ,
956+ ) ,
957+ FutureKind :: AsyncIterator => self . expr_call_lang_item_fn (
958+ span,
959+ hir:: LangItem :: AsyncIteratorPollNext ,
960+ arena_vec ! [ self ; new_unchecked, get_context] ,
961+ ) ,
962+ } ;
942963 self . arena . alloc ( self . expr_unsafe ( call) )
943964 } ;
944965
@@ -1020,11 +1041,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
10201041 let awaitee_arm = self . arm ( awaitee_pat, loop_expr) ;
10211042
10221043 // `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1023- let into_future_expr = self . expr_call_lang_item_fn (
1024- span,
1025- hir:: LangItem :: IntoFutureIntoFuture ,
1026- arena_vec ! [ self ; expr] ,
1027- ) ;
1044+ let into_future_expr = match await_kind {
1045+ FutureKind :: Future => self . expr_call_lang_item_fn (
1046+ span,
1047+ hir:: LangItem :: IntoFutureIntoFuture ,
1048+ arena_vec ! [ self ; * expr] ,
1049+ ) ,
1050+ // Not needed for `for await` because we expect to have already called
1051+ // `IntoAsyncIterator::into_async_iter` on it.
1052+ FutureKind :: AsyncIterator => expr,
1053+ } ;
10281054
10291055 // match <into_future_expr> {
10301056 // mut __awaitee => loop { .. }
@@ -1685,6 +1711,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16851711 head : & Expr ,
16861712 body : & Block ,
16871713 opt_label : Option < Label > ,
1714+ loop_kind : ForLoopKind ,
16881715 ) -> hir:: Expr < ' hir > {
16891716 let head = self . lower_expr_mut ( head) ;
16901717 let pat = self . lower_pat ( pat) ;
@@ -1713,17 +1740,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
17131740 let ( iter_pat, iter_pat_nid) =
17141741 self . pat_ident_binding_mode ( head_span, iter, hir:: BindingAnnotation :: MUT ) ;
17151742
1716- // `match Iterator::next(&mut iter) { ... }`
17171743 let match_expr = {
17181744 let iter = self . expr_ident ( head_span, iter, iter_pat_nid) ;
1719- let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1720- let next_expr = self . expr_call_lang_item_fn (
1721- head_span,
1722- hir:: LangItem :: IteratorNext ,
1723- arena_vec ! [ self ; ref_mut_iter] ,
1724- ) ;
1745+ let next_expr = match loop_kind {
1746+ ForLoopKind :: For => {
1747+ // `Iterator::next(&mut iter)`
1748+ let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1749+ self . expr_call_lang_item_fn (
1750+ head_span,
1751+ hir:: LangItem :: IteratorNext ,
1752+ arena_vec ! [ self ; ref_mut_iter] ,
1753+ )
1754+ }
1755+ ForLoopKind :: ForAwait => {
1756+ // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1757+ // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1758+ // calls to `poll_next`. In user code, this would probably be a call to
1759+ // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1760+
1761+ // `&mut iter`
1762+ let iter = self . expr_mut_addr_of ( head_span, iter) ;
1763+ // `Pin::new_unchecked(...)`
1764+ let iter = self . arena . alloc ( self . expr_call_lang_item_fn_mut (
1765+ head_span,
1766+ hir:: LangItem :: PinNewUnchecked ,
1767+ arena_vec ! [ self ; iter] ,
1768+ ) ) ;
1769+ // `unsafe { ... }`
1770+ let iter = self . arena . alloc ( self . expr_unsafe ( iter) ) ;
1771+ let kind = self . make_lowered_await ( head_span, iter, FutureKind :: AsyncIterator ) ;
1772+ self . arena . alloc ( hir:: Expr { hir_id : self . next_id ( ) , kind, span : head_span } )
1773+ }
1774+ } ;
17251775 let arms = arena_vec ! [ self ; none_arm, some_arm] ;
17261776
1777+ // `match $next_expr { ... }`
17271778 self . expr_match ( head_span, next_expr, arms, hir:: MatchSource :: ForLoopDesugar )
17281779 } ;
17291780 let match_stmt = self . stmt_expr ( for_span, match_expr) ;
@@ -1743,13 +1794,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
17431794 // `mut iter => { ... }`
17441795 let iter_arm = self . arm ( iter_pat, loop_expr) ;
17451796
1746- // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1747- let into_iter_expr = {
1748- self . expr_call_lang_item_fn (
1749- head_span,
1750- hir:: LangItem :: IntoIterIntoIter ,
1751- arena_vec ! [ self ; head] ,
1752- )
1797+ let into_iter_expr = match loop_kind {
1798+ ForLoopKind :: For => {
1799+ // `::std::iter::IntoIterator::into_iter(<head>)`
1800+ self . expr_call_lang_item_fn (
1801+ head_span,
1802+ hir:: LangItem :: IntoIterIntoIter ,
1803+ arena_vec ! [ self ; head] ,
1804+ )
1805+ }
1806+ ForLoopKind :: ForAwait => self . arena . alloc ( head) ,
17531807 } ;
17541808
17551809 let match_expr = self . arena . alloc ( self . expr_match (
@@ -2152,3 +2206,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
21522206 }
21532207 }
21542208}
2209+
2210+ /// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2211+ /// of future we are awaiting.
2212+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
2213+ enum FutureKind {
2214+ /// We are awaiting a normal future
2215+ Future ,
2216+ /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2217+ /// a `for await` loop)
2218+ AsyncIterator ,
2219+ }
0 commit comments