diff --git a/dune-project b/dune-project index d03fd93a..39f2389d 100644 --- a/dune-project +++ b/dune-project @@ -22,6 +22,7 @@ base-threads result hmap + base-unix (iostream (>= 0.2)) (ocaml (>= 4.08)) (odoc :with-doc) diff --git a/examples/echo_ws.ml b/examples/echo_ws.ml index f24cf283..37e2635d 100644 --- a/examples/echo_ws.ml +++ b/examples/echo_ws.ml @@ -28,7 +28,7 @@ let handle_ws _client_addr ic oc = let buf = Bytes.create 32 in let continue = ref true in while !continue do - let n = IO.Input.input ic buf 0 (Bytes.length buf) in + let n = IO.Input_with_timeout.input ic buf 0 (Bytes.length buf) in Log.debug (fun k -> k "echo %d bytes from websocket: %S" n (Bytes.sub_string buf 0 n)); diff --git a/src/Tiny_httpd.ml b/src/Tiny_httpd.ml index cbffe69b..036d4f97 100644 --- a/src/Tiny_httpd.ml +++ b/src/Tiny_httpd.ml @@ -10,6 +10,7 @@ module Meth = Tiny_httpd_core.Meth module Pool = Tiny_httpd_core.Pool module Log = Tiny_httpd_core.Log module Server = Tiny_httpd_core.Server +module Time = Time module Util = Tiny_httpd_core.Util include Server module Dir = Tiny_httpd_unix.Dir diff --git a/src/Tiny_httpd.mli b/src/Tiny_httpd.mli index 2490646e..b0664373 100644 --- a/src/Tiny_httpd.mli +++ b/src/Tiny_httpd.mli @@ -85,6 +85,10 @@ module Buf = Buf module IO = Tiny_httpd_core.IO +(** {2 Time} *) + +module Time = Time + (** {2 Logging *) module Log = Tiny_httpd_core.Log diff --git a/src/camlzip/Tiny_httpd_camlzip.ml b/src/camlzip/Tiny_httpd_camlzip.ml index 8be7da00..89d4b73f 100644 --- a/src/camlzip/Tiny_httpd_camlzip.ml +++ b/src/camlzip/Tiny_httpd_camlzip.ml @@ -1,8 +1,9 @@ module W = IO.Writer -let decode_deflate_stream_ ~buf_size (ic : IO.Input.t) : IO.Input.t = +let decode_deflate_stream_ ~buf_size (ic : #IO.Input_with_timeout.t) : + IO.Input_with_timeout.t = Log.debug (fun k -> k "wrap stream with deflate.decode"); - Iostream_camlzip.decompress_in_buf ~buf_size ic + Iostream_camlzip.decompress_in_buf_with_timeout ~now_s:Time.now_s ~buf_size ic let encode_deflate_writer_ ~buf_size (w : W.t) : W.t = Log.debug (fun k -> k "wrap writer with deflate.encode"); @@ -27,8 +28,8 @@ let has_deflate s = try Scanf.sscanf s "deflate, %s" (fun _ -> true) with _ -> false (* decompress [req]'s body if needed *) -let decompress_req_stream_ ~buf_size (req : IO.Input.t Request.t) : _ Request.t - = +let decompress_req_stream_ ~buf_size (req : #IO.Input_with_timeout.t Request.t) + : _ Request.t = match Request.get_header ~f:String.trim req "Transfer-Encoding" with (* TODO | Some "gzip" -> diff --git a/src/core/IO.ml b/src/core/IO.ml index 249da955..8f864012 100644 --- a/src/core/IO.ml +++ b/src/core/IO.ml @@ -11,6 +11,7 @@ open Common_ module Buf = Buf module Slice = Iostream.Slice +module A = Atomic_ (** Output channel (byte sink) *) module Output = struct @@ -44,13 +45,11 @@ module Output = struct done method private close_underlying () = - if not !closed then ( - closed := true; + if not (A.exchange closed true) then if close_noerr then ( try Unix.close fd with _ -> () ) else Unix.close fd - ) end let output_buf (self : t) (buf : Buf.t) : unit = @@ -108,38 +107,28 @@ module Input = struct let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) (fd : Unix.file_descr) : t = let eof = ref false in + let input buf i len : int = + let n = ref 0 in + if not !eof then ( + n := Unix.read fd buf i len; + if !n = 0 then eof := true + ); + !n + in + object inherit Iostream.In_buf.t_from_refill ~bytes:buf.bytes () method private refill (slice : Slice.t) = if not !eof then ( slice.off <- 0; - let continue = ref true in - while !continue do - match Unix.read fd slice.bytes 0 (Bytes.length slice.bytes) with - | n -> - slice.len <- n; - continue := false - | exception - Unix.Unix_error - ( ( Unix.EBADF | Unix.ENOTCONN | Unix.ESHUTDOWN - | Unix.ECONNRESET | Unix.EPIPE ), - _, - _ ) -> - eof := true; - continue := false - | exception - Unix.Unix_error - ((Unix.EWOULDBLOCK | Unix.EAGAIN | Unix.EINTR), _, _) -> - ignore (Unix.select [ fd ] [] [] 1.) - done; + slice.len <- input slice.bytes 0 (Bytes.length slice.bytes); (* Printf.eprintf "read returned %d B\n%!" !n; *) if slice.len = 0 then eof := true ) method close () = - if not !closed then ( - closed := true; + if not (A.exchange closed true) then ( eof := true; if close_noerr then ( try Unix.close fd with _ -> () @@ -148,6 +137,8 @@ module Input = struct ) end + let[@inline] of_string s : t = (of_string s :> t) + let of_slice (slice : Slice.t) : t = object inherit Iostream.In_buf.t_from_refill ~bytes:slice.bytes () @@ -168,7 +159,7 @@ module Input = struct (** Read exactly [len] bytes. @raise End_of_file if the input did not contain enough data. *) - let really_input (self : t) buf i len : unit = + let really_input (self : #t) buf i len : unit = let i = ref i in let len = ref len in while !len > 0 do @@ -178,31 +169,6 @@ module Input = struct len := !len - n done - let append (i1 : #t) (i2 : #t) : t = - let use_i1 = ref true in - let rec input_rec (slice : Slice.t) = - if !use_i1 then ( - slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes); - if slice.len = 0 then ( - use_i1 := false; - input_rec slice - ) - ) else - slice.len <- input i1 slice.bytes 0 (Bytes.length slice.bytes) - in - - object - inherit Iostream.In_buf.t_from_refill () - - method private refill (slice : Slice.t) = - slice.off <- 0; - input_rec slice - - method close () = - close i1; - close i2 - end - let iter_slice (f : Slice.t -> unit) (self : #t) : unit = let continue = ref true in while !continue do @@ -231,11 +197,131 @@ module Input = struct Iostream.Out.output oc slice.bytes slice.off slice.len) self - let read_all_using ~buf (self : #t) : string = + (** Output a stream using chunked encoding *) + let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit = + let oc' = Output.chunk_encoding ?buf oc ~close_rec:false in + match to_chan' oc' self with + | () -> Output.close oc' + | exception e -> + let bt = Printexc.get_raw_backtrace () in + Output.close oc'; + Printexc.raise_with_backtrace e bt + + (** print a stream as a series of chunks *) + let output_chunked ?buf (oc : out_channel) (self : #t) : unit = + output_chunked' ?buf (Output.of_out_channel oc) self +end + +(** Input channel (byte source) with read-with-timeout *) +module Input_with_timeout = struct + include Iostream.In_buf + + class type t = Iostream.In_buf.t_with_timeout + + exception Timeout = Iostream.Timeout + (** Exception for timeouts *) + + exception Timeout_partial_read of int + (** Exception for timeouts with a partial read *) + + (** fill buffer, but stop at the deadline *) + let fill_buf_with_deadline (self : #t) ~(deadline : float) : Slice.t = + let timeout = deadline -. Time.now_s () in + if timeout <= 0. then raise Timeout; + fill_buf_with_timeout self timeout + + (** fill buffer, but stop at the deadline if provided *) + let fill_buf_with_deadline_opt (self : #t) ~(deadline : float option) : + Slice.t = + match deadline with + | None -> fill_buf self + | Some d -> fill_buf_with_deadline self ~deadline:d + + let of_unix_fd ?(close_noerr = false) ~closed ~(buf : Slice.t) + (fd : Unix.file_descr) : t = + let eof = ref false in + + let input_with_timeout t buf i len : int = + let deadline = Time.now_s () +. t in + let n = ref 0 in + while + (not (Atomic.get closed)) + && (not !eof) + && + try + n := Unix.read fd buf i len; + false + with + | Unix.Unix_error ((Unix.EAGAIN | Unix.EWOULDBLOCK), _, _) -> + (* sleep *) + true + | Unix.Unix_error ((Unix.ECONNRESET | Unix.ESHUTDOWN | Unix.EPIPE), _, _) + -> + (* exit *) + false + do + let now = Time.now_s () in + if now >= deadline then raise Timeout; + ignore (Unix.select [ fd ] [] [] (deadline -. now) : _ * _ * _) + done; + !n + in + + object + inherit Iostream.In_buf.t_with_timeout_from_refill ~bytes:buf.bytes () + + method private refill_with_timeout t (slice : Slice.t) = + if not !eof then ( + slice.off <- 0; + slice.len <- + input_with_timeout t slice.bytes 0 (Bytes.length slice.bytes); + (* Printf.eprintf "read returned %d B\n%!" !n; *) + if slice.len = 0 then eof := true + ) + + method close () = + if not (A.exchange closed true) then ( + eof := true; + if close_noerr then ( + try Unix.close fd with _ -> () + ) else + Unix.close fd + ) + end + + let of_slice (slice : Slice.t) : t = + object + inherit Iostream.In_buf.t_with_timeout_from_refill ~bytes:slice.bytes () + + method private refill_with_timeout _t (slice : Slice.t) = + slice.off <- 0; + slice.len <- 0 + + method close () = () + end + + (** Read into the given slice. + @return the number of bytes read, [0] means end of input. *) + let[@inline] input (self : t) buf i len = self#input buf i len + + (** Close the channel. *) + let[@inline] close self : unit = self#close () + + let iter_slice = Input.iter_slice + let iter = Input.iter + let to_chan = Input.to_chan + let to_chan' = Input.to_chan' + + (** Read the whole body + @param deadline a deadline before which the operation must complete + @raise Timeout if deadline expires (leftovers are in [buf] *) + let read_all_using ~buf ~(deadline : float) (self : #t) : string = Buf.clear buf; let continue = ref true in while !continue do - let slice = fill_buf self in + let timeout = deadline -. Time.now_s () in + if timeout <= 0. then raise Timeout; + let slice = fill_buf_with_timeout self timeout in if slice.len = 0 then continue := false else ( @@ -246,12 +332,17 @@ module Input = struct done; Buf.contents_and_clear buf - (** Read [n] bytes from the input into [bytes]. *) - let read_exactly_ ~too_short (self : #t) (bytes : bytes) (n : int) : unit = - assert (Bytes.length bytes >= n); - let offset = ref 0 in + (** Read [n] bytes from the input into [bytes]. + @raise Timeout_partial_read if timeout occurs before it's done *) + let read_exactly_ ?(off = 0) ~too_short ~(deadline : float) (self : #t) + (bytes : bytes) (n : int) : unit = + assert (Bytes.length bytes >= off + n); + let offset = ref off in while !offset < n do - let slice = self#fill_buf () in + let slice = + try fill_buf_with_deadline self ~deadline + with Timeout -> raise (Timeout_partial_read (!offset - off)) + in let n_read = min slice.len (n - !offset) in Bytes.blit slice.bytes slice.off bytes !offset n_read; offset := !offset + n_read; @@ -259,12 +350,16 @@ module Input = struct if n_read = 0 then too_short () done + let[@inline] really_input (self : #t) ~deadline buf i len = + read_exactly_ ~off:i ~deadline self buf len ~too_short:(fun () -> + raise End_of_file) + (** read a line into the buffer, after clearing it. *) - let read_line_into (self : t) ~buf : unit = + let read_line_into (self : #t) ~(deadline : float) ~buf : unit = Buf.clear buf; let continue = ref true in while !continue do - let slice = self#fill_buf () in + let slice = fill_buf_with_deadline self ~deadline in if slice.len = 0 then ( continue := false; if Buf.size buf = 0 then raise End_of_file @@ -286,32 +381,32 @@ module Input = struct ) done - let read_line_using ~buf (self : #t) : string = - read_line_into self ~buf; + let read_line_using ~buf ~deadline (self : #t) : string = + read_line_into self ~deadline ~buf; Buf.contents_and_clear buf - let read_line_using_opt ~buf (self : #t) : string option = - match read_line_into self ~buf with + let read_line_using_opt ~buf ~deadline (self : #t) : string option = + match read_line_into self ~buf ~deadline with | () -> Some (Buf.contents_and_clear buf) | exception End_of_file -> None (* helper for making a new input stream that either contains at most [size] bytes, or contains exactly [size] bytes. *) - let reading_exactly_ ~skip_on_close ~close_rec ~size ~bytes (arg : t) : t = + let reading_exactly_ ~skip_on_close ~close_rec ~size ~bytes (arg : #t) : t = let remaining_size = ref size in object - inherit t_from_refill ~bytes () + inherit t_with_timeout_from_refill ~bytes () method close () = if !remaining_size > 0 && skip_on_close then skip arg !remaining_size; if close_rec then close arg - method private refill (slice : Slice.t) = + method private refill_with_timeout t (slice : Slice.t) = slice.off <- 0; slice.len <- 0; if !remaining_size > 0 then ( - let sub = fill_buf arg in + let sub = fill_buf_with_timeout arg t in let n = min !remaining_size (min sub.len (Bytes.length slice.bytes)) in @@ -324,7 +419,7 @@ module Input = struct (** new stream with maximum size [max_size]. @param close_rec if true, closing this will also close the input stream *) - let limit_size_to ~close_rec ~max_size ~bytes (arg : t) : t = + let limit_size_to ~close_rec ~max_size ~bytes (arg : #t) : t = reading_exactly_ ~size:max_size ~skip_on_close:false ~bytes ~close_rec arg (** New stream that consumes exactly [size] bytes from the input. @@ -339,15 +434,15 @@ module Input = struct (* small buffer to read the chunk sizes *) let line_buf = Buf.create ~size:32 () in - let read_next_chunk_len () : int = + let read_next_chunk_len ~deadline () : int = if !first then first := false else ( - let line = read_line_using ~buf:line_buf ic in + let line = read_line_using ~buf:line_buf ~deadline ic in if String.trim line <> "" then raise (fail "expected crlf between chunks") ); - let line = read_line_using ~buf:line_buf ic in + let line = read_line_using ~buf:line_buf ~deadline ic in (* parse chunk length, ignore extensions *) let chunk_size = if String.trim line = "" then @@ -367,11 +462,12 @@ module Input = struct let chunk_size = ref 0 in object - inherit t_from_refill ~bytes () + inherit t_with_timeout_from_refill ~bytes () - method private refill (slice : Slice.t) : unit = + method private refill_with_timeout t (slice : Slice.t) : unit = + let deadline = Time.now_s () +. t in if !chunk_size = 0 && not !eof then ( - chunk_size := read_next_chunk_len (); + chunk_size := read_next_chunk_len ~deadline (); if !chunk_size = 0 then eof := true (* stream is finished *) ); slice.off <- 0; @@ -379,7 +475,7 @@ module Input = struct if !chunk_size > 0 then ( (* read the whole chunk, or [Bytes.length bytes] of it *) let to_read = min !chunk_size (Bytes.length slice.bytes) in - read_exactly_ + read_exactly_ ~deadline ~too_short:(fun () -> raise (fail "chunk is too short")) ic slice.bytes to_read; slice.len <- to_read; @@ -389,19 +485,8 @@ module Input = struct method close () = eof := true (* do not close underlying stream *) end - (** Output a stream using chunked encoding *) - let output_chunked' ?buf (oc : #Iostream.Out_buf.t) (self : #t) : unit = - let oc' = Output.chunk_encoding ?buf oc ~close_rec:false in - match to_chan' oc' self with - | () -> Output.close oc' - | exception e -> - let bt = Printexc.get_raw_backtrace () in - Output.close oc'; - Printexc.raise_with_backtrace e bt - - (** print a stream as a series of chunks *) - let output_chunked ?buf (oc : out_channel) (self : #t) : unit = - output_chunked' ?buf (Output.of_out_channel oc) self + let output_chunked = Input.output_chunked + let output_chunked' = Input.output_chunked' end (** A writer abstraction. *) @@ -441,7 +526,8 @@ end (** A TCP server abstraction. *) module TCP_server = struct type conn_handler = { - handle: client_addr:Unix.sockaddr -> Input.t -> Output.t -> unit; + handle: + client_addr:Unix.sockaddr -> Input_with_timeout.t -> Output.t -> unit; (** Handle client connection *) } diff --git a/src/core/dune b/src/core/dune index a04707ef..a8a7fb0c 100644 --- a/src/core/dune +++ b/src/core/dune @@ -4,6 +4,9 @@ (public_name tiny_httpd.core) (private_modules parse_ common_) (libraries threads seq hmap iostream + (select time.ml from + (mtime mtime.clock.os -> time.mtime.ml) + (unix -> time.default.ml)) (select log.ml from (logs -> log.logs.ml) (-> log.default.ml)))) diff --git a/src/core/headers.ml b/src/core/headers.ml index 89c4d8d2..3b137516 100644 --- a/src/core/headers.ml +++ b/src/core/headers.ml @@ -46,9 +46,9 @@ let for_all pred s = true with Exit -> false -let parse_ ~(buf : Buf.t) (bs : IO.Input.t) : t = +let parse_ ~(buf : Buf.t) ~deadline (bs : #IO.Input_with_timeout.t) : t = let rec loop acc = - match IO.Input.read_line_using_opt ~buf bs with + match IO.Input_with_timeout.read_line_using_opt ~buf ~deadline bs with | None -> raise End_of_file | Some "\r" -> acc | Some line -> diff --git a/src/core/headers.mli b/src/core/headers.mli index b46b5d54..62add34e 100644 --- a/src/core/headers.mli +++ b/src/core/headers.mli @@ -32,4 +32,4 @@ val contains : string -> t -> bool val pp : Format.formatter -> t -> unit (** Pretty print the headers. *) -val parse_ : buf:Buf.t -> IO.Input.t -> t +val parse_ : buf:Buf.t -> deadline:float -> #IO.Input_with_timeout.t -> t diff --git a/src/core/request.ml b/src/core/request.ml index 1a3275df..8bbbef58 100644 --- a/src/core/request.ml +++ b/src/core/request.ml @@ -88,29 +88,33 @@ let pp out self : unit = pp_with ~pp_body () out self (* decode a "chunked" stream into a normal stream *) -let read_stream_chunked_ ~bytes (bs : #IO.Input.t) : IO.Input.t = +let read_stream_chunked_ ~bytes (bs : #IO.Input_with_timeout.t) : + IO.Input_with_timeout.t = Log.debug (fun k -> k "body: start reading chunked stream..."); - IO.Input.read_chunked ~bytes ~fail:(fun s -> Bad_req (400, s)) bs + IO.Input_with_timeout.read_chunked ~bytes ~fail:(fun s -> Bad_req (400, s)) bs -let limit_body_size_ ~max_size ~bytes (bs : #IO.Input.t) : IO.Input.t = +let limit_body_size_ ~max_size ~bytes (bs : #IO.Input_with_timeout.t) : + IO.Input_with_timeout.t = Log.debug (fun k -> k "limit size of body to max-size=%d" max_size); - IO.Input.limit_size_to ~max_size ~close_rec:false ~bytes bs + IO.Input_with_timeout.limit_size_to ~max_size ~close_rec:false ~bytes bs -let limit_body_size ~max_size ~bytes (req : IO.Input.t t) : IO.Input.t t = +let limit_body_size ~max_size ~bytes (req : #IO.Input_with_timeout.t t) : + IO.Input_with_timeout.t t = { req with body = limit_body_size_ ~max_size ~bytes req.body } (** read exactly [size] bytes from the stream *) -let read_exactly ~size ~bytes (bs : #IO.Input.t) : IO.Input.t = +let read_exactly ~size ~bytes (bs : #IO.Input_with_timeout.t) : + IO.Input_with_timeout.t = Log.debug (fun k -> k "body: must read exactly %d bytes" size); - IO.Input.reading_exactly bs ~close_rec:false ~bytes ~size + IO.Input_with_timeout.reading_exactly bs ~close_rec:false ~bytes ~size (* parse request, but not body (yet) *) -let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) : - unit t option resp_result = +let parse_req_start ~client_addr ~(deadline : float) ~buf + (bs : #IO.Input_with_timeout.t) : unit t option resp_result = try - let line = IO.Input.read_line_using ~buf bs in + let line = IO.Input_with_timeout.read_line_using ~buf ~deadline bs in Log.debug (fun k -> k "parse request line: %s" line); - let start_time = get_time_s () in + let start_time = Time.now_s () in let meth, path, version = try let off = ref 0 in @@ -134,7 +138,7 @@ let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) : in let meth = Meth.of_string meth in Log.debug (fun k -> k "got meth: %s, path %S" (Meth.to_string meth) path); - let headers = Headers.parse_ ~buf bs in + let headers = Headers.parse_ ~buf ~deadline bs in let host = match Headers.get "Host" headers with | None -> bad_reqf 400 "No 'Host' header in request" @@ -170,8 +174,8 @@ let parse_req_start ~client_addr ~get_time_s ~buf (bs : IO.Input.t) : (* parse body, given the headers. @param tr_stream a transformation of the input stream. *) -let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) : - IO.Input.t t resp_result = +let parse_body_ ~tr_stream ~bytes (req : #IO.Input_with_timeout.t t) : + IO.Input_with_timeout.t t resp_result = try let size, has_size = match Headers.get_exn "Content-Length" req.headers |> int_of_string with @@ -186,7 +190,7 @@ let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) : bad_reqf 400 "specifying both transfer-encoding and content-length" | Some "chunked" -> (* body sent by chunks *) - let bs : IO.Input.t = + let bs : IO.Input_with_timeout.t = read_stream_chunked_ ~bytes @@ tr_stream req.body in if size > 0 then ( @@ -203,14 +207,15 @@ let parse_body_ ~tr_stream ~bytes (req : IO.Input.t t) : | Bad_req (c, s) -> Error (c, s) | e -> Error (400, Printexc.to_string e) -let read_body_full ?bytes ?buf_size (self : IO.Input.t t) : string t = +let read_body_full ?bytes ?buf_size ~deadline + (self : #IO.Input_with_timeout.t t) : string t = try let buf = match bytes with | Some b -> Buf.of_bytes b | None -> Buf.create ?size:buf_size () in - let body = IO.Input.read_all_using ~buf self.body in + let body = IO.Input_with_timeout.read_all_using ~buf ~deadline self.body in { self with body } with | Bad_req _ as e -> raise e @@ -220,11 +225,13 @@ module Private_ = struct let close_after_req = close_after_req let parse_req_start = parse_req_start - let parse_req_start_exn ?(buf = Buf.create ()) ~client_addr ~get_time_s bs = - parse_req_start ~client_addr ~get_time_s ~buf bs |> unwrap_resp_result + let parse_req_start_exn ?(buf = Buf.create ()) ~client_addr ~deadline bs = + parse_req_start ~client_addr ~deadline ~buf bs |> unwrap_resp_result let parse_body ?(bytes = Bytes.create 4096) req bs : _ t = - parse_body_ ~tr_stream:(fun s -> s) ~bytes { req with body = bs } + parse_body_ + ~tr_stream:(fun s -> (s :> IO.Input_with_timeout.t)) + ~bytes { req with body = bs } |> unwrap_resp_result let[@inline] set_body body self = { self with body } diff --git a/src/core/request.mli b/src/core/request.mli index e4242bcf..931e8531 100644 --- a/src/core/request.mli +++ b/src/core/request.mli @@ -129,17 +129,26 @@ val start_time : _ t -> float @since 0.11 *) val limit_body_size : - max_size:int -> bytes:bytes -> IO.Input.t t -> IO.Input.t t + max_size:int -> + bytes:bytes -> + #IO.Input_with_timeout.t t -> + IO.Input_with_timeout.t t (** Limit the body size to [max_size] bytes, or return a [413] error. @since 0.3 *) -val read_body_full : ?bytes:bytes -> ?buf_size:int -> IO.Input.t t -> string t +val read_body_full : + ?bytes:bytes -> + ?buf_size:int -> + deadline:float -> + #IO.Input_with_timeout.t t -> + string t (** Read the whole body into a string. Potentially blocking. @param buf_size initial size of underlying buffer (since 0.11) @param bytes the initial buffer (since 0.14) + @param deadline time after which this should fail with [Timeout] (since NEXT_RELEASE) *) (**/**) @@ -148,20 +157,26 @@ val read_body_full : ?bytes:bytes -> ?buf_size:int -> IO.Input.t t -> string t module Private_ : sig val parse_req_start : client_addr:Unix.sockaddr -> - get_time_s:(unit -> float) -> + deadline:float -> buf:Buf.t -> - IO.Input.t -> + IO.Input_with_timeout.t -> unit t option resp_result val parse_req_start_exn : ?buf:Buf.t -> client_addr:Unix.sockaddr -> - get_time_s:(unit -> float) -> - IO.Input.t -> + deadline:float -> + #IO.Input_with_timeout.t -> unit t option val close_after_req : _ t -> bool - val parse_body : ?bytes:bytes -> unit t -> IO.Input.t -> IO.Input.t t + + val parse_body : + ?bytes:bytes -> + unit t -> + #IO.Input_with_timeout.t -> + IO.Input_with_timeout.t t + val set_body : 'a -> _ t -> 'a t end diff --git a/src/core/server.ml b/src/core/server.ml index 1eb1715b..83d4810b 100644 --- a/src/core/server.ml +++ b/src/core/server.ml @@ -3,7 +3,9 @@ open Common_ type resp_error = Response_code.t * string module Middleware = struct - type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit + type handler = + IO.Input_with_timeout.t Request.t -> resp:(Response.t -> unit) -> unit + type t = handler -> handler (** Apply a list of middlewares to [h] *) @@ -40,7 +42,11 @@ module type UPGRADE_HANDLER = sig code is [101] alongside these headers. *) val handle_connection : - Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit + Unix.sockaddr -> + handshake_state -> + IO.Input_with_timeout.t -> + IO.Output.t -> + unit (** Take control of the connection and take it from there *) end @@ -52,9 +58,6 @@ module type IO_BACKEND = sig val init_addr : unit -> string val init_port : unit -> int - val get_time_s : unit -> float - (** obtain the current timestamp in seconds. *) - val tcp_server : unit -> IO.TCP_server.builder (** Server that can listen on a port and handle clients. *) end @@ -72,13 +75,14 @@ let unwrap_handler_result req = function type t = { backend: (module IO_BACKEND); mutable tcp_server: IO.TCP_server.t option; - mutable handler: IO.Input.t Request.t -> Response.t; + mutable handler: IO.Input_with_timeout.t Request.t -> Response.t; (** toplevel handler, if any *) mutable middlewares: (int * Middleware.t) list; (** Global middlewares *) mutable middlewares_sorted: (int * Middleware.t) list lazy_t; (** sorted version of {!middlewares} *) mutable path_handlers: (unit Request.t -> handler_result option) list; (** path handlers *) + request_timeout_s: float; (** Timeout for parsing requests *) bytes_pool: bytes Pool.t; } @@ -169,7 +173,8 @@ let add_route_handler (type a) ?accept ?middlewares ?meth self let tr_req _oc req ~resp f = let req = Pool.with_resource self.bytes_pool @@ fun bytes -> - Request.read_body_full ~bytes req + let deadline = Time.now_s () +. self.request_timeout_s in + Request.read_body_full ~bytes ~deadline req in resp (f req) in @@ -190,7 +195,8 @@ let add_route_server_sent_handler ?accept self route f = let tr_req (oc : IO.Output.t) req ~resp f = let req = Pool.with_resource self.bytes_pool @@ fun bytes -> - Request.read_body_full ~bytes req + let deadline = Time.now_s () +. self.request_timeout_s in + Request.read_body_full ~bytes ~deadline req in let headers = ref Headers.(empty |> set "content-type" "text/event-stream") @@ -257,7 +263,11 @@ let add_upgrade_handler ?(accept = fun _ -> Ok ()) (self : t) route f : unit = let clear_bytes_ bs = Bytes.fill bs 0 (Bytes.length bs) '\x00' -let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = +(* client has at most 10s to send the request, unless it's a streaming request *) +let default_req_timeout_s_ = 30. + +let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) + ?(request_timeout_s = default_req_timeout_s_) ~backend () : t = let handler _req = Response.fail ~code:404 "no top handler" in let self = { @@ -267,6 +277,7 @@ let create_from ?(buf_size = 16 * 1_024) ?(middlewares = []) ~backend () : t = path_handlers = []; middlewares = []; middlewares_sorted = lazy []; + request_timeout_s; bytes_pool = Pool.create ~clear:clear_bytes_ ~mk_item:(fun () -> Bytes.create buf_size) @@ -304,13 +315,11 @@ let string_as_list_contains_ (s : string) (sub : string) : bool = let client_handle_for (self : t) ~client_addr ic oc : unit = Pool.with_resource self.bytes_pool @@ fun bytes_req -> Pool.with_resource self.bytes_pool @@ fun bytes_res -> - let (module B) = self.backend in - (* how to log the response to this query *) let log_response (req : _ Request.t) (resp : Response.t) = if not Log.dummy then ( let msgf k = - let elapsed = B.get_time_s () -. req.start_time in + let elapsed = Time.now_s () -. req.start_time in k ("response to=%s code=%d time=%.3fs meth=%s path=%S" : _ format4) (Util.show_sockaddr client_addr) @@ -387,10 +396,10 @@ let client_handle_for (self : t) ~client_addr ic oc : unit = let continue = ref true in let handle_one_req () = + let deadline = Time.now_s () +. self.request_timeout_s in match let buf = Buf.of_bytes bytes_req in - Request.Private_.parse_req_start ~client_addr ~get_time_s:B.get_time_s - ~buf ic + Request.Private_.parse_req_start ~client_addr ~deadline ~buf ic with | Ok None -> continue := false (* client is done *) | Error (c, s) -> diff --git a/src/core/server.mli b/src/core/server.mli index e856c7e4..87be36ac 100644 --- a/src/core/server.mli +++ b/src/core/server.mli @@ -17,7 +17,8 @@ *) module Middleware : sig - type handler = IO.Input.t Request.t -> resp:(Response.t -> unit) -> unit + type handler = + IO.Input_with_timeout.t Request.t -> resp:(Response.t -> unit) -> unit (** Handlers are functions returning a response to a request. The response can be delayed, hence the use of a continuation as the [resp] parameter. *) @@ -52,9 +53,6 @@ module type IO_BACKEND = sig val init_port : unit -> int (** Initial port *) - val get_time_s : unit -> float - (** Obtain the current timestamp in seconds. *) - val tcp_server : unit -> IO.TCP_server.builder (** TCP server builder, to create servers that can listen on a port and handle clients. *) @@ -63,6 +61,7 @@ end val create_from : ?buf_size:int -> ?middlewares:([ `Encoding | `Stage of int ] * Middleware.t) list -> + ?request_timeout_s:float -> backend:(module IO_BACKEND) -> unit -> t @@ -74,6 +73,7 @@ val create_from : @param buf_size size for buffers (since 0.11) @param middlewares see {!add_middleware} for more details. + @param request_timeout_s default timeout for requests (headers+body) (since NEXT_RELEASE) @since 0.14 *) @@ -95,7 +95,8 @@ val active_connections : t -> int val add_decode_request_cb : t -> - (unit Request.t -> (unit Request.t * (IO.Input.t -> IO.Input.t)) option) -> + (unit Request.t -> + (unit Request.t * (IO.Input_with_timeout.t -> IO.Input_with_timeout.t)) option) -> unit [@@deprecated "use add_middleware"] (** Add a callback for every request. @@ -130,7 +131,8 @@ val add_middleware : (** {2 Request handlers} *) -val set_top_handler : t -> (IO.Input.t Request.t -> Response.t) -> unit +val set_top_handler : + t -> (IO.Input_with_timeout.t Request.t -> Response.t) -> unit (** Setup a handler called by default. This handler is called with any request not accepted by any handler @@ -174,7 +176,7 @@ val add_route_handler_stream : ?middlewares:Middleware.t list -> ?meth:Meth.t -> t -> - ('a, IO.Input.t Request.t -> Response.t) Route.t -> + ('a, IO.Input_with_timeout.t Request.t -> Response.t) Route.t -> 'a -> unit (** Similar to {!add_route_handler}, but where the body of the request @@ -257,7 +259,11 @@ module type UPGRADE_HANDLER = sig The connection is closed without further ado. *) val handle_connection : - Unix.sockaddr -> handshake_state -> IO.Input.t -> IO.Output.t -> unit + Unix.sockaddr -> + handshake_state -> + IO.Input_with_timeout.t -> + IO.Output.t -> + unit (** Take control of the connection and take it from ther.e *) end diff --git a/src/prometheus/time_.default.ml b/src/core/time.default.ml similarity index 71% rename from src/prometheus/time_.default.ml rename to src/core/time.default.ml index 86dd302c..7c0959fd 100644 --- a/src/prometheus/time_.default.ml +++ b/src/core/time.default.ml @@ -1,3 +1,5 @@ +let now_s = Unix.gettimeofday + let[@inline] now_us () = let t = Unix.gettimeofday () in t *. 1e6 |> ceil diff --git a/src/core/time.mli b/src/core/time.mli new file mode 100644 index 00000000..18558346 --- /dev/null +++ b/src/core/time.mli @@ -0,0 +1,10 @@ +(** Basic time measurement. + + This provides a basic clock, monotonic if [mtime] is installed, + or based on [Unix.gettimeofday] otherwise *) + +val now_us : unit -> float +(** Current time in microseconds. The precision should be at least below the millisecond. *) + +val now_s : unit -> float +(** Current time in seconds. The precision should be at least below the millisecond. *) diff --git a/src/core/time.mtime.ml b/src/core/time.mtime.ml new file mode 100644 index 00000000..2c3202da --- /dev/null +++ b/src/core/time.mtime.ml @@ -0,0 +1,7 @@ +let[@inline] now_s () = + let t = Mtime_clock.now_ns () in + Int64.(div t 1_000_000_000L |> to_float) + +let[@inline] now_us () = + let t = Mtime_clock.now_ns () in + Int64.(div t 1000L |> to_float) diff --git a/src/prometheus/common_p_.ml b/src/prometheus/common_p_.ml index 812670ab..e610f671 100644 --- a/src/prometheus/common_p_.ml +++ b/src/prometheus/common_p_.ml @@ -1,3 +1,4 @@ module A = Tiny_httpd_core.Atomic_ +module Time = Tiny_httpd_core.Time let spf = Printf.sprintf diff --git a/src/prometheus/dune b/src/prometheus/dune index 3439a474..b415ed3e 100644 --- a/src/prometheus/dune +++ b/src/prometheus/dune @@ -4,10 +4,7 @@ (name tiny_httpd_prometheus) (public_name tiny_httpd.prometheus) (synopsis "Metrics using prometheus") - (private_modules common_p_ time_) + (private_modules common_p_) (flags :standard -open Tiny_httpd_core) (libraries - tiny_httpd.core unix - (select time_.ml from - (mtime mtime.clock.os -> time_.mtime.ml) - (-> time_.default.ml)))) + tiny_httpd.core unix)) diff --git a/src/prometheus/time_.mli b/src/prometheus/time_.mli deleted file mode 100644 index d6824fba..00000000 --- a/src/prometheus/time_.mli +++ /dev/null @@ -1 +0,0 @@ -val now_us : unit -> float diff --git a/src/prometheus/time_.mtime.ml b/src/prometheus/time_.mtime.ml deleted file mode 100644 index 65e2ec73..00000000 --- a/src/prometheus/time_.mtime.ml +++ /dev/null @@ -1,3 +0,0 @@ -let[@inline] now_us () = - let t = Mtime_clock.now_ns () in - Int64.(div t 1000L |> to_float) diff --git a/src/prometheus/tiny_httpd_prometheus.ml b/src/prometheus/tiny_httpd_prometheus.ml index b3ec4e39..b3529320 100644 --- a/src/prometheus/tiny_httpd_prometheus.ml +++ b/src/prometheus/tiny_httpd_prometheus.ml @@ -189,12 +189,12 @@ let http_middleware (reg : Registry.t) : Server.Middleware.t = fun h : Server.Middleware.handler -> fun req ~resp : unit -> - let start = Time_.now_us () in + let start = Time.now_us () in Counter.incr c_req; h req ~resp:(fun (response : Response.t) -> let code = response.code in - let elapsed_us = Time_.now_us () -. start in + let elapsed_us = Time.now_us () -. start in let elapsed_s = elapsed_us /. 1e6 in Histogram.add h_latency elapsed_s; diff --git a/src/unix/dir.ml b/src/unix/dir.ml index 0035849c..ea693057 100644 --- a/src/unix/dir.ml +++ b/src/unix/dir.ml @@ -93,12 +93,12 @@ let vfs_of_dir (top : string) : vfs = let contains f = Sys.file_exists (top // f) let list_dir f = Sys.readdir (top // f) - let read_file_content f = + let read_file_content f : IO.Input.t = let fpath = top // f in match Unix.stat fpath with | { st_kind = Unix.S_REG; _ } -> let ic = Unix.(openfile fpath [ O_RDONLY ] 0) in - let closed = ref false in + let closed = Atomic_.make false in let buf = IO.Slice.create 4096 in IO.Input.of_unix_fd ~buf ~close_noerr:true ~closed ic | _ -> failwith (Printf.sprintf "not a regular file: %S" f) diff --git a/src/unix/tiny_httpd_unix.ml b/src/unix/tiny_httpd_unix.ml index f1de3936..f8d564e6 100644 --- a/src/unix/tiny_httpd_unix.ml +++ b/src/unix/tiny_httpd_unix.ml @@ -92,15 +92,15 @@ module Unix_tcp_server_ = struct Pool.with_resource self.slice_pool @@ fun ic_buf -> Pool.with_resource self.slice_pool @@ fun oc_buf -> - let closed = ref false in + let closed = Atomic_.make false in let oc = new IO.Output.of_unix_fd ~close_noerr:true ~closed ~buf:oc_buf client_sock in let ic = - IO.Input.of_unix_fd ~close_noerr:true ~closed ~buf:ic_buf - client_sock + IO.Input_with_timeout.of_unix_fd ~close_noerr:true ~closed + ~buf:ic_buf client_sock in handle.handle ~client_addr ic oc in diff --git a/src/ws/tiny_httpd_ws.ml b/src/ws/tiny_httpd_ws.ml index bf12ef2f..88784f59 100644 --- a/src/ws/tiny_httpd_ws.ml +++ b/src/ws/tiny_httpd_ws.ml @@ -1,6 +1,6 @@ open Common_ws_ -type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit +type handler = Unix.sockaddr -> IO.Input_with_timeout.t -> IO.Output.t -> unit module Frame_type = struct type t = int @@ -196,7 +196,7 @@ module Reader = struct | Close type t = { - ic: IO.Input.t; + ic: IO.Input_with_timeout.t; writer: Writer.t; (** Writer, to send "pong" *) header_buf: bytes; (** small buffer to read frame headers *) small_buf: bytes; (** Used for control frames *) @@ -220,52 +220,65 @@ module Reader = struct let max_fragment_size = 1 lsl 30 (** Read next frame header into [self.header] *) - let read_frame_header (self : t) : unit = - (* read header *) - IO.Input.really_input self.ic self.header_buf 0 2; - - let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in - let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in + let read_frame_header (self : t) ~deadline : unit = + try + (* read header *) + IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0 2; - self.header.fin <- b0 land 1 == 1; - let ext = (b0 lsr 4) land 0b0111 in - if ext <> 0 then ( - Log.error (fun k -> k "websocket: unknown extension %d, closing" ext); - raise Close_connection - ); + let b0 = Bytes.unsafe_get self.header_buf 0 |> Char.code in + let b1 = Bytes.unsafe_get self.header_buf 1 |> Char.code in - self.header.ty <- b0 land 0b0000_1111; - self.header.mask <- b1 land 0b1000_0000 != 0; - - let payload_len : int = - let len = b1 land 0b0111_1111 in - if len = 126 then ( - IO.Input.really_input self.ic self.header_buf 0 2; - Bytes.get_int16_be self.header_buf 0 - ) else if len = 127 then ( - IO.Input.really_input self.ic self.header_buf 0 8; - let len64 = Bytes.get_int64_be self.header_buf 0 in - if compare len64 (Int64.of_int max_fragment_size) > 0 then ( - Log.error (fun k -> - k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64 - max_fragment_size); - raise Close_connection - ); + self.header.fin <- b0 land 1 == 1; + let ext = (b0 lsr 4) land 0b0111 in + if ext <> 0 then ( + Log.error (fun k -> k "websocket: unknown extension %d, closing" ext); + raise Close_connection + ); - Int64.to_int len64 - ) else - len - in - self.header.payload_len <- payload_len; + self.header.ty <- b0 land 0b0000_1111; + self.header.mask <- b1 land 0b1000_0000 != 0; + + let payload_len : int = + let len = b1 land 0b0111_1111 in + if len = 126 then ( + IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0 + 2; + Bytes.get_int16_be self.header_buf 0 + ) else if len = 127 then ( + IO.Input_with_timeout.really_input self.ic ~deadline self.header_buf 0 + 8; + let len64 = Bytes.get_int64_be self.header_buf 0 in + if compare len64 (Int64.of_int max_fragment_size) > 0 then ( + Log.error (fun k -> + k "websocket: maximum frame fragment exceeded (%Ld > %d)" len64 + max_fragment_size); + raise Close_connection + ); + + Int64.to_int len64 + ) else + len + in + self.header.payload_len <- payload_len; - if self.header.mask then - IO.Input.really_input self.ic self.header.mask_key 0 4; + if self.header.mask then + IO.Input_with_timeout.really_input self.ic ~deadline + self.header.mask_key 0 4; - (*Log.debug (fun k -> - k "websocket: read frame header type=%s payload_len=%d mask=%b" - (Frame_type.show self.header.ty) - self.header.payload_len self.header.mask);*) - () + (*Log.debug (fun k -> + k "websocket: read frame header type=%s payload_len=%d mask=%b" + (Frame_type.show self.header.ty) + self.header.payload_len self.header.mask);*) + () + with + | IO.Input_with_timeout.Timeout_partial_read _ + | IO.Input_with_timeout.Timeout + -> + (* NOTE: this is not optimal but it's the easiest solution, for now, + to the problem of a partially read frame header with + a timeout in the middle (we would have to save *) + Log.error (fun k -> k "websocket: timeout while reading frame header"); + raise Close_connection external apply_masking_ : bytes -> bytes -> int -> int -> unit = "tiny_httpd_ws_apply_masking" @@ -276,30 +289,45 @@ module Reader = struct assert (off >= 0 && off + len <= Bytes.length buf); apply_masking_ mask_key buf off len - let read_body_to_string (self : t) : string = + let read_body_to_string (self : t) ~deadline : string = let len = self.header.payload_len in let buf = Bytes.create len in - IO.Input.really_input self.ic buf 0 len; + (try IO.Input_with_timeout.really_input self.ic ~deadline buf 0 len + with + | IO.Input_with_timeout.Timeout_partial_read _ + | IO.Input_with_timeout.Timeout + -> + raise Close_connection); if self.header.mask then apply_masking ~mask_key:self.header.mask_key buf 0 len; Bytes.unsafe_to_string buf (** Skip bytes of the body *) - let skip_body (self : t) : unit = + let skip_body (self : t) ~deadline : unit = let len = ref self.header.payload_len in while !len > 0 do let n = min !len (Bytes.length self.small_buf) in - IO.Input.really_input self.ic self.small_buf 0 n; + (try + IO.Input_with_timeout.really_input self.ic ~deadline self.small_buf 0 n + with + | IO.Input_with_timeout.Timeout_partial_read _ + | IO.Input_with_timeout.Timeout + -> + raise Close_connection); len := !len - n done (** State machine that reads [len] bytes into [buf] *) - let rec read_rec (self : t) buf i len : int = + let rec read_rec (self : t) ~deadline buf i len : int = match self.state with | Close -> 0 | Reading_frame r -> let len = min len r.remaining_bytes in - let n = IO.Input.input self.ic buf i len in + let timeout = Time.now_s () -. deadline in + if timeout <= 0. then raise IO.Input_with_timeout.Timeout; + let n = + IO.Input_with_timeout.input_with_timeout self.ic timeout buf i len + in (* update state *) r.remaining_bytes <- r.remaining_bytes - n; @@ -313,7 +341,7 @@ module Reader = struct ); n | Begin -> - read_frame_header self; + read_frame_header self ~deadline; (*Log.debug (fun k -> k "websocket: read frame of type=%s payload_len=%d" (Frame_type.show self.header.ty) @@ -330,19 +358,19 @@ module Reader = struct (Frame_type.show self.last_ty)); raise Close_connection ); - read_rec self buf i len + read_rec self ~deadline buf i len | 1 -> self.state <- Reading_frame { remaining_bytes = self.header.payload_len }; - read_rec self buf i len + read_rec self ~deadline buf i len | 2 -> self.state <- Reading_frame { remaining_bytes = self.header.payload_len }; - read_rec self buf i len + read_rec self ~deadline buf i len | 8 -> (* close frame *) self.state <- Close; - let body = read_body_to_string self in + let body = read_body_to_string self ~deadline in if String.length body >= 2 then ( let errcode = Bytes.get_int16_be (Bytes.unsafe_of_string body) 0 in Log.info (fun k -> @@ -352,19 +380,19 @@ module Reader = struct 0 | 9 -> (* pong, just ignore *) - skip_body self; + skip_body self ~deadline; Writer.send_pong self.writer; - read_rec self buf i len + read_rec self ~deadline buf i len | 10 -> (* pong, just ignore *) - skip_body self; - read_rec self buf i len + skip_body self ~deadline; + read_rec self ~deadline buf i len | ty -> Log.error (fun k -> k "unknown frame type: %xd" ty); raise Close_connection) - let read self buf i len = - try read_rec self buf i len + let read self ~deadline buf i len = + try read_rec self ~deadline buf i len with Close_connection -> self.state <- Close; 0 @@ -376,16 +404,26 @@ module Reader = struct ) end -let upgrade ic oc : _ * _ = +(* 30 min *) +let default_timeout_s = 60. *. 30. + +let upgrade ?(timeout_s = default_timeout_s) ic oc : _ * _ = let writer = Writer.create ~oc () in let reader = Reader.create ~ic ~writer () in - let ws_ic : IO.Input.t = - object - inherit IO.Input.t_from_refill ~bytes:(Bytes.create 4_096) () - - method private refill (slice : IO.Slice.t) = + let ws_ic : IO.Input_with_timeout.t = + object (self) + inherit + IO.Input_with_timeout.t_with_timeout_from_refill + ~bytes:(Bytes.create 4_096) () as super + + method private refill_with_timeout t (slice : IO.Slice.t) = + let deadline = Time.now_s () +. t in slice.off <- 0; - slice.len <- Reader.read reader slice.bytes 0 (Bytes.length slice.bytes) + slice.len <- + Reader.read reader ~deadline slice.bytes 0 (Bytes.length slice.bytes) + + method! fill_buf () = + IO.Input_with_timeout.fill_buf_with_timeout self timeout_s method close () = Reader.close reader end @@ -404,6 +442,7 @@ let upgrade ic oc : _ * _ = module Make_upgrade_handler (X : sig val accept_ws_protocol : string -> bool val handler : handler + val timeout_s : float end) : Server.UPGRADE_HANDLER = struct type handshake_state = unit @@ -446,7 +485,7 @@ end) : Server.UPGRADE_HANDLER = struct try Ok (handshake_ req) with Bad_req s -> Error s let handle_connection addr () ic oc = - let ws_ic, ws_oc = upgrade ic oc in + let ws_ic, ws_oc = upgrade ~timeout_s:X.timeout_s ic oc in try X.handler addr ws_ic ws_oc with Close_connection -> Log.debug (fun k -> k "websocket: requested to close the connection"); @@ -454,10 +493,12 @@ end) : Server.UPGRADE_HANDLER = struct end let add_route_handler ?accept ?(accept_ws_protocol = fun _ -> true) - (server : Server.t) route (f : handler) : unit = + ?(timeout_s = default_timeout_s) (server : Server.t) route (f : handler) : + unit = let module M = Make_upgrade_handler (struct let handler = f let accept_ws_protocol = accept_ws_protocol + let timeout_s = timeout_s end) in let up : Server.upgrade_handler = (module M) in Server.add_upgrade_handler ?accept server route up diff --git a/src/ws/tiny_httpd_ws.mli b/src/ws/tiny_httpd_ws.mli index 2bd30f70..f9d84d0a 100644 --- a/src/ws/tiny_httpd_ws.mli +++ b/src/ws/tiny_httpd_ws.mli @@ -4,15 +4,20 @@ for a websocket server. It has no additional dependencies. *) -type handler = Unix.sockaddr -> IO.Input.t -> IO.Output.t -> unit +type handler = Unix.sockaddr -> IO.Input_with_timeout.t -> IO.Output.t -> unit (** Websocket handler *) -val upgrade : IO.Input.t -> IO.Output.t -> IO.Input.t * IO.Output.t +val upgrade : + ?timeout_s:float -> + IO.Input_with_timeout.t -> + IO.Output.t -> + IO.Input_with_timeout.t * IO.Output.t (** Upgrade a byte stream to the websocket framing protocol. *) val add_route_handler : ?accept:(unit Request.t -> (unit, int * string) result) -> ?accept_ws_protocol:(string -> bool) -> + ?timeout_s:float -> Server.t -> (Server.upgrade_handler, Server.upgrade_handler) Route.t -> handler -> diff --git a/tests/unit/t_server.ml b/tests/unit/t_server.ml index 01b82eac..d25829e4 100644 --- a/tests/unit/t_server.ml +++ b/tests/unit/t_server.ml @@ -9,12 +9,13 @@ let () = \r\n\ salutationsSOMEJUNK" in - let str = IO.Input.of_string q in + let str = IO.Input_with_timeout.of_string q in let client_addr = Unix.(ADDR_INET (inet_addr_loopback, 1024)) in + + let deadline = Time.now_s () +. 10. in let r = Request.Private_.parse_req_start_exn ~client_addr ~buf:(Buf.create ()) - ~get_time_s:(fun _ -> 0.) - str + ~deadline str in match r with | None -> failwith "should parse" @@ -23,6 +24,8 @@ let () = assert_eq (Some "coucou") (Headers.get "host" req.headers); assert_eq (Some "11") (Headers.get "content-length" req.headers); assert_eq "hello" req.path; - let req = Request.Private_.parse_body req str |> Request.read_body_full in + let req = + Request.Private_.parse_body req str |> Request.read_body_full ~deadline + in assert_eq ~to_string:(fun s -> s) "salutations" req.body; () diff --git a/tiny_httpd.opam b/tiny_httpd.opam index c144b511..980a938d 100644 --- a/tiny_httpd.opam +++ b/tiny_httpd.opam @@ -16,6 +16,7 @@ depends: [ "base-threads" "result" "hmap" + "base-unix" "iostream" {>= "0.2"} "ocaml" {>= "4.08"} "odoc" {with-doc}