@@ -2,17 +2,21 @@ use std::error::Error as StdError;
22#[ cfg( feature = "runtime" ) ]
33use std:: time:: Duration ;
44
5+ use bytes:: Bytes ;
56use futures_channel:: { mpsc, oneshot} ;
67use futures_util:: future:: { self , Either , FutureExt as _, TryFutureExt as _} ;
78use futures_util:: stream:: StreamExt as _;
89use h2:: client:: { Builder , SendRequest } ;
10+ use http:: { Method , StatusCode } ;
911use tokio:: io:: { AsyncRead , AsyncWrite } ;
1012
11- use super :: { decode_content_length , ping , PipeToSendStream , SendBuf } ;
13+ use super :: { ping , H2Upgraded , PipeToSendStream , SendBuf } ;
1214use crate :: body:: HttpBody ;
1315use crate :: common:: { exec:: Exec , task, Future , Never , Pin , Poll } ;
1416use crate :: headers;
17+ use crate :: proto:: h2:: UpgradedSendStream ;
1518use crate :: proto:: Dispatched ;
19+ use crate :: upgrade:: Upgraded ;
1620use crate :: { Body , Request , Response } ;
1721
1822type ClientRx < B > = crate :: client:: dispatch:: Receiver < Request < B > , Response < Body > > ;
@@ -233,8 +237,25 @@ where
233237 headers:: set_content_length_if_missing ( req. headers_mut ( ) , len) ;
234238 }
235239 }
240+
241+ let is_connect = req. method ( ) == Method :: CONNECT ;
236242 let eos = body. is_end_stream ( ) ;
237- let ( fut, body_tx) = match self . h2_tx . send_request ( req, eos) {
243+ let ping = self . ping . clone ( ) ;
244+
245+ if is_connect {
246+ if headers:: content_length_parse_all ( req. headers ( ) )
247+ . map_or ( false , |len| len != 0 )
248+ {
249+ warn ! ( "h2 connect request with non-zero body not supported" ) ;
250+ cb. send ( Err ( (
251+ crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) ,
252+ None ,
253+ ) ) ) ;
254+ continue ;
255+ }
256+ }
257+
258+ let ( fut, body_tx) = match self . h2_tx . send_request ( req, !is_connect && eos) {
238259 Ok ( ok) => ok,
239260 Err ( err) => {
240261 debug ! ( "client send request error: {}" , err) ;
@@ -243,45 +264,81 @@ where
243264 }
244265 } ;
245266
246- let ping = self . ping . clone ( ) ;
247- if !eos {
248- let mut pipe = Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
249- if let Err ( e) = res {
250- debug ! ( "client request body error: {}" , e) ;
251- }
252- } ) ;
253-
254- // eagerly see if the body pipe is ready and
255- // can thus skip allocating in the executor
256- match Pin :: new ( & mut pipe) . poll ( cx) {
257- Poll :: Ready ( _) => ( ) ,
258- Poll :: Pending => {
259- let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
260- // keep the ping recorder's knowledge of an
261- // "open stream" alive while this body is
262- // still sending...
263- let ping = ping. clone ( ) ;
264- let pipe = pipe. map ( move |x| {
265- drop ( conn_drop_ref) ;
266- drop ( ping) ;
267- x
267+ let send_stream = if !is_connect {
268+ if !eos {
269+ let mut pipe =
270+ Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
271+ if let Err ( e) = res {
272+ debug ! ( "client request body error: {}" , e) ;
273+ }
268274 } ) ;
269- self . executor . execute ( pipe) ;
275+
276+ // eagerly see if the body pipe is ready and
277+ // can thus skip allocating in the executor
278+ match Pin :: new ( & mut pipe) . poll ( cx) {
279+ Poll :: Ready ( _) => ( ) ,
280+ Poll :: Pending => {
281+ let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
282+ // keep the ping recorder's knowledge of an
283+ // "open stream" alive while this body is
284+ // still sending...
285+ let ping = ping. clone ( ) ;
286+ let pipe = pipe. map ( move |x| {
287+ drop ( conn_drop_ref) ;
288+ drop ( ping) ;
289+ x
290+ } ) ;
291+ self . executor . execute ( pipe) ;
292+ }
270293 }
271294 }
272- }
295+
296+ None
297+ } else {
298+ Some ( body_tx)
299+ } ;
273300
274301 let fut = fut. map ( move |result| match result {
275302 Ok ( res) => {
276303 // record that we got the response headers
277304 ping. record_non_data ( ) ;
278305
279- let content_length = decode_content_length ( res. headers ( ) ) ;
280- let res = res. map ( |stream| {
281- let ping = ping. for_stream ( & stream) ;
282- crate :: Body :: h2 ( stream, content_length, ping)
283- } ) ;
284- Ok ( res)
306+ let content_length = headers:: content_length_parse_all ( res. headers ( ) ) ;
307+ if let ( Some ( mut send_stream) , StatusCode :: OK ) =
308+ ( send_stream, res. status ( ) )
309+ {
310+ if content_length. map_or ( false , |len| len != 0 ) {
311+ warn ! ( "h2 connect response with non-zero body not supported" ) ;
312+
313+ send_stream. send_reset ( h2:: Reason :: INTERNAL_ERROR ) ;
314+ return Err ( (
315+ crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) ,
316+ None ,
317+ ) ) ;
318+ }
319+ let ( parts, recv_stream) = res. into_parts ( ) ;
320+ let mut res = Response :: from_parts ( parts, Body :: empty ( ) ) ;
321+
322+ let ( pending, on_upgrade) = crate :: upgrade:: pending ( ) ;
323+ let io = H2Upgraded {
324+ ping,
325+ send_stream : unsafe { UpgradedSendStream :: new ( send_stream) } ,
326+ recv_stream,
327+ buf : Bytes :: new ( ) ,
328+ } ;
329+ let upgraded = Upgraded :: new ( io, Bytes :: new ( ) ) ;
330+
331+ pending. fulfill ( upgraded) ;
332+ res. extensions_mut ( ) . insert ( on_upgrade) ;
333+
334+ Ok ( res)
335+ } else {
336+ let res = res. map ( |stream| {
337+ let ping = ping. for_stream ( & stream) ;
338+ crate :: Body :: h2 ( stream, content_length. into ( ) , ping)
339+ } ) ;
340+ Ok ( res)
341+ }
285342 }
286343 Err ( err) => {
287344 ping. ensure_not_timed_out ( ) . map_err ( |e| ( e, None ) ) ?;
0 commit comments