@@ -48,13 +48,14 @@ where
4848 let this = self . project ( ) ;
4949
5050 match this. rx_frame . poll_recv ( cx) {
51- Poll :: Ready ( frame) => return Poll :: Ready ( frame. map ( Ok ) ) ,
52- Poll :: Pending => { }
51+ Poll :: Ready ( frame @ Some ( _ ) ) => return Poll :: Ready ( frame. map ( Ok ) ) ,
52+ Poll :: Ready ( None ) | Poll :: Pending => { }
5353 }
5454
5555 use core:: future:: Future ;
5656 match this. rx_error . poll ( cx) {
57- Poll :: Ready ( err) => return Poll :: Ready ( err. ok ( ) . map ( Err ) ) ,
57+ Poll :: Ready ( Ok ( error) ) => return Poll :: Ready ( Some ( Err ( error) ) ) ,
58+ Poll :: Ready ( Err ( _) ) => return Poll :: Ready ( None ) ,
5859 Poll :: Pending => { }
5960 }
6061
@@ -131,13 +132,54 @@ mod tests {
131132 use super :: * ;
132133
133134 #[ tokio:: test]
134- async fn works ( ) {
135+ async fn empty ( ) {
136+ let ( tx, body) = Channel :: < Bytes > :: new ( 1024 ) ;
137+ drop ( tx) ;
138+
139+ let collected = body. collect ( ) . await . unwrap ( ) ;
140+ assert ! ( collected. trailers( ) . is_none( ) ) ;
141+ assert ! ( collected. to_bytes( ) . is_empty( ) ) ;
142+ }
143+
144+ #[ tokio:: test]
145+ async fn can_send_data ( ) {
135146 let ( mut tx, body) = Channel :: < Bytes > :: new ( 1024 ) ;
136147
137148 tokio:: spawn ( async move {
138149 tx. send_data ( Bytes :: from ( "Hel" ) ) . await . unwrap ( ) ;
139150 tx. send_data ( Bytes :: from ( "lo!" ) ) . await . unwrap ( ) ;
151+ } ) ;
152+
153+ let collected = body. collect ( ) . await . unwrap ( ) ;
154+ assert ! ( collected. trailers( ) . is_none( ) ) ;
155+ assert_eq ! ( collected. to_bytes( ) , "Hello!" ) ;
156+ }
157+
158+ #[ tokio:: test]
159+ async fn can_send_trailers ( ) {
160+ let ( mut tx, body) = Channel :: < Bytes > :: new ( 1024 ) ;
161+
162+ tokio:: spawn ( async move {
163+ let mut trailers = HeaderMap :: new ( ) ;
164+ trailers. insert (
165+ HeaderName :: from_static ( "foo" ) ,
166+ HeaderValue :: from_static ( "bar" ) ,
167+ ) ;
168+ tx. send_trailers ( trailers) . await . unwrap ( ) ;
169+ } ) ;
170+
171+ let collected = body. collect ( ) . await . unwrap ( ) ;
172+ assert_eq ! ( collected. trailers( ) . unwrap( ) [ "foo" ] , "bar" ) ;
173+ assert ! ( collected. to_bytes( ) . is_empty( ) ) ;
174+ }
175+
176+ #[ tokio:: test]
177+ async fn can_send_both_data_and_trailers ( ) {
178+ let ( mut tx, body) = Channel :: < Bytes > :: new ( 1024 ) ;
140179
180+ tokio:: spawn ( async move {
181+ tx. send_data ( Bytes :: from ( "Hel" ) ) . await . unwrap ( ) ;
182+ tx. send_data ( Bytes :: from ( "lo!" ) ) . await . unwrap ( ) ;
141183 let mut trailers = HeaderMap :: new ( ) ;
142184 trailers. insert (
143185 HeaderName :: from_static ( "foo" ) ,
@@ -150,4 +192,43 @@ mod tests {
150192 assert_eq ! ( collected. trailers( ) . unwrap( ) [ "foo" ] , "bar" ) ;
151193 assert_eq ! ( collected. to_bytes( ) , "Hello!" ) ;
152194 }
195+
196+ /// A stand-in for an error type, for unit tests.
197+ type Error = & ' static str ;
198+ /// An example error message.
199+ const MSG : Error = "oh no" ;
200+
201+ #[ tokio:: test]
202+ async fn aborts_before_trailers ( ) {
203+ let ( mut tx, body) = Channel :: < Bytes , Error > :: new ( 1024 ) ;
204+
205+ tokio:: spawn ( async move {
206+ tx. send_data ( Bytes :: from ( "Hel" ) ) . await . unwrap ( ) ;
207+ tx. send_data ( Bytes :: from ( "lo!" ) ) . await . unwrap ( ) ;
208+ tx. abort ( MSG ) ;
209+ } ) ;
210+
211+ let err = body. collect ( ) . await . unwrap_err ( ) ;
212+ assert_eq ! ( err, MSG ) ;
213+ }
214+
215+ #[ tokio:: test]
216+ async fn aborts_after_trailers ( ) {
217+ let ( mut tx, body) = Channel :: < Bytes , Error > :: new ( 1024 ) ;
218+
219+ tokio:: spawn ( async move {
220+ tx. send_data ( Bytes :: from ( "Hel" ) ) . await . unwrap ( ) ;
221+ tx. send_data ( Bytes :: from ( "lo!" ) ) . await . unwrap ( ) ;
222+ let mut trailers = HeaderMap :: new ( ) ;
223+ trailers. insert (
224+ HeaderName :: from_static ( "foo" ) ,
225+ HeaderValue :: from_static ( "bar" ) ,
226+ ) ;
227+ tx. send_trailers ( trailers) . await . unwrap ( ) ;
228+ tx. abort ( MSG ) ;
229+ } ) ;
230+
231+ let err = body. collect ( ) . await . unwrap_err ( ) ;
232+ assert_eq ! ( err, MSG ) ;
233+ }
153234}
0 commit comments