diff --git a/.gitignore b/.gitignore index 9ac1073d7..d2c412189 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ /test/websockets/reports/* .idea/* .vscode +!test/fixtures/http2.key +!test/fixtures/http2.crt diff --git a/Project.toml b/Project.toml index 6d1b5d973..70f876cda 100644 --- a/Project.toml +++ b/Project.toml @@ -8,29 +8,27 @@ Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -LibAwsCommon = "c6e421ba-b5f8-4792-a1c4-42948de3ed9d" -LibAwsHTTPFork = "d3f1d20b-921e-4930-8491-471e0be3121a" -LibAwsIO = "a5388770-19df-4151-b103-3d71de896ddf" +AwsHTTP = "d4eb1443-154a-48c0-b55a-2f1d1087a5c5" +Reseau = "802f3686-a58f-41ce-bb0c-3c43c75bba36" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" [compat] CodecZlib = "0.7" JSON = "0.21.4, 1" -LibAwsCommon = "1.3" -LibAwsHTTPFork = "1.0.2" -LibAwsIO = "1.2.0" +AwsHTTP = "0.1" +Reseau = "1.1" PrecompileTools = "1.2.1" URIs = "1" julia = "1.10" [extras] JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["JSON", "Test"] +test = ["JSON", "Sockets", "Test"] diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index a29b0404f..62a05dd0a 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -35,10 +35,14 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **chunkedbody**: To send a request body in chunks, an iterable must be provided where each element is one of the valid types of request bodies mentioned above. - **modifier**: A function of the form `f(request, body) -> newbody`, i.e. that takes the HTTP request object and proposed request body, and can optionally return a new request body. If the modifer only modifies the request object, it should return `nothing`, which will ensure the original request body is sent unmodified. -- Response options: - - **response_body**: By default, response bodies are returned as `Vector{UInt8}`. Alternatively, a preallocated `AbstractVector{UInt8}` or any `IO` object can be provided for the response body to be written into. + - **response_body**: By default, response bodies are returned as `Vector{UInt8}`. Alternatively, a preallocated `AbstractVector{UInt8}` or any `IO` object can be provided for the response body to be written into. `response_stream` is a compatible alias. - **decompress**: If `true`, the response body will be decompressed if it is compressed. By default, response bodies with the `Content-Encoding: gzip` header are decompressed. - **status_exception**: Default `true`. If `true`, an exception will be thrown if the response status code is not in the 200-299 range. - **readtimeout**: The maximum time in seconds to wait for a response from the server. Only valid for HTTP/1.1 connections. +-- Retry options (per request): + - **retry_non_idempotent**: Default `false`. If `true`, non-idempotent requests may be retried. + - **retry_check**: Optional function to override retry decisions. It is called as `retry_check(delay, err, req, resp, resp_body)` when a retry is being considered. + - **retry_delays**: Custom retry delays. Provide a number (seconds) or any iterator of delays; defaults to exponential backoff. -- Redirect options: - **redirect**: Default `true`. If `true`, the client will follow redirects. - **redirect_limit**: The maximum number of redirects to follow. Default is 3. @@ -77,8 +81,11 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **proxy_ssl_cacert**: The path to the CA certificate file for the proxy. - **proxy_ssl_insecure**: Default `false`. If `true`, SSL certificate verification will be disabled for the proxy. - **proxy_ssl_alpn_list**: A list of ALPN protocols to use for the proxy connection. Default is `"h2;http/1.1"`. + - **proxy_auth**: Optional. Set to `:basic` to enable basic auth on proxy requests. + - **proxy_username**: Username for proxy basic auth (requires explicit `proxy_host`/`proxy_port`). + - **proxy_password**: Password for proxy basic auth (requires explicit `proxy_host`/`proxy_port`). -- Retry options: - - **max_retries**: The maximum number of times to retry a request. Default is 10. + - **max_retries**: The maximum number of times to retry a request. Default is 4. - **retry_partition**: Requests utilizing the same retry partition (an arbitrary string) will coordinate retries against each other to not overwhelm a temporarily unresponsive server. - **backoff_scale_factor_ms**: The factor by which to scale the backoff time between retries. Default is 25. - **max_backoff_secs**: The maximum time in seconds to wait between retries. Default is 20. @@ -88,6 +95,25 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po -- Connection pool options: - **max_connections**: The maximum number of connections to keep open in the connection pool. Default is 512. - **max_connection_idle_in_milliseconds**: The maximum time in milliseconds to keep a connection open in the pool. Default is 60000. + - **connection_acquisition_timeout_ms**: The maximum time in milliseconds to wait for a connection from the pool. Default is 0 (no timeout). + - **max_pending_connection_acquisitions**: Maximum number of pending connection acquisitions. Default is 0 (no limit). + - **enable_read_back_pressure**: If `true`, enable back pressure on reads to limit buffered data. Default `false`. + - **response_first_byte_timeout_ms**: Maximum time in milliseconds to wait for the first response byte. Default is 0 (disabled). For per-request control, use `readtimeout`. + -- Monitoring options: + - **monitoring_minimum_throughput_bytes_per_second**: Minimum throughput to consider the connection healthy. Default is 0 (disabled). + - **monitoring_allowable_throughput_failure_interval_seconds**: Seconds of below-minimum throughput before the connection is closed. Default is 0 (disabled). + - **monitoring_statistics_observer**: Optional callback `(connection_nonce, stats) -> nothing` invoked with connection stats samples. + -- HTTP/2 options: + - **http2_prior_knowledge**: Default `false`. If `true`, assume HTTP/2 without ALPN negotiation. + - **http2_stream_manager**: Default `false`. If `true`, enable the HTTP/2 stream manager for multiplexed requests. + - **http2_close_connection_on_server_error**: Default `false`. If `true`, close HTTP/2 connections when a 5xx response is received. + - **http2_connection_manual_window_management**: Default `false`. If `true`, use manual connection-level flow control (call `HTTP.http2_update_window`). + - **http2_connection_ping_period_ms**: Default `0` (disabled). Period in milliseconds for sending HTTP/2 PING frames. + - **http2_connection_ping_timeout_ms**: Default `0` (AWS default). Timeout in milliseconds for PING responses. + - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. + - **http2_max_concurrent_streams_per_connection**: Default `0` (no explicit limit). Upper bound for streams per connection. + - **http2_max_closed_streams**: Default `0` (AWS default). Max closed streams to remember before ignoring late frames. + - **http2_initial_window_size**: Default `65535`. Initial flow-control window size for HTTP/2 (must be <= 2^31-1). -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. @@ -158,6 +184,67 @@ response = HTTP.get("https://api.example.com/data"; client = custom_client) println(String(response.body)) \`\`\` +## HTTP/2 Features (Advanced) + +### Stream Manager + +For high-concurrency HTTP/2 workloads, you can enable the HTTP/2 stream manager. This allows multiple in-flight +streams to share pooled HTTP/2 connections. + +\`\`\`julia +using HTTP + +client = HTTP.Client("https", "example.com", 443; http2_stream_manager=true) +resp = HTTP.get("https://example.com/resource"; client=client) +\`\`\` + +### Connection Control Helpers + +When a connection negotiates HTTP/2, you can use the following helpers: + +- `HTTP.http2_ping(client; data=nothing)` -> returns round-trip time in nanoseconds. +- `HTTP.http2_change_settings(client, settings)` where `settings` is a vector of pairs or `aws_http2_setting`. +- `HTTP.http2_local_settings(client)` / `HTTP.http2_remote_settings(client)` -> returns current settings. +- `HTTP.http2_send_goaway(client, error_code; allow_more_streams=true, debug_data=nothing)` +- `HTTP.http2_get_sent_goaway(client)` / `HTTP.http2_get_received_goaway(client)` -> returns `nothing` if no GOAWAY. +- `HTTP.http2_update_window(client_or_stream, increment)` -> increases the HTTP/2 connection flow-control window. +- `HTTP.update_window(stream, increment)` -> increases the stream flow-control window (useful with `enable_read_back_pressure=true`). + +HTTP/2-specific helpers require an HTTP/2 connection and will throw an `ArgumentError` if the connection is HTTP/1.1. + +## Trailing Headers + +Trailing headers are available on responses as `resp.trailers` after the response completes. For streaming requests, +you can attach trailers before closing the write side of the stream. + +\`\`\`julia +using HTTP + +resp = HTTP.open("POST", "https://example.com/upload") do stream + write(stream, "chunk-1") + write(stream, "chunk-2") + HTTP.addtrailer(stream, "x-checksum" => "abc123") +end + +resp.trailers === nothing || println(HTTP.header(resp.trailers, "x-server-checksum")) +\`\`\` + +## Metrics and Observability + +Each response includes a `metrics` field: + +- `response.metrics.request_body_length` +- `response.metrics.response_body_length` +- `response.metrics.nretries` +- `response.metrics.stream_metrics` (AWS CRT `aws_http_stream_metrics`) + +For connection-level metrics, use `HTTP.manager_metrics(client)`, which returns `aws_http_manager_metrics` with +`available_concurrency`, `pending_concurrency_acquires`, and `leased_concurrency`. + +To enable periodic connection statistics callbacks, pass `monitoring_statistics_observer` in `ClientSettings`. The +callback receives a `connection_nonce` and a vector of stats entries. Each entry has a `category` field +(`:http1_channel` or `:http2_channel`) and category-specific fields like `pending_outgoing_stream_ms`. + ## Under the Hood (Advanced) When you call `HTTP.request`, the following advanced steps occur: @@ -179,4 +266,3 @@ When you call `HTTP.request`, the following advanced steps occur: 6. **Response Processing:** The response is parsed, and if errors occur (as dictated by your settings), an exception is raised. - diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index f2bcfc824..acffd9bdf 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -37,12 +37,14 @@ r = HTTP.request("GET", "http://example.com") # Accessing fields status = r.status body_text = String(r.body) # Similar to v1.x -header_value = HTTP.header(r.headers, "Content-Type") +header_value = HTTP.header(r, "Content-Type") +# or HTTP.header(r.headers, "Content-Type") ``` Key differences: -- Headers access has changed slightly to operate on the `headers` field -- The `.body` field can now be any type, not just `Vector{UInt8}` +- Headers access works directly on `Request`/`Response`, or on the `headers` field if you already have it +- Request/response *body arguments* accept strings, byte vectors, Dict/NamedTuple (form-encoded), `HTTP.Form`, or IO +- `Response.body` remains a `Vector{UInt8}` for buffered responses (or `nothing` when streaming into `response_body`) - Context dictionary access is now through `.context` rather than request-specific fields ### Making Requests @@ -51,8 +53,10 @@ While the basic request syntax remains similar, there are some changes to keywor #### Changes to Keyword Arguments -- The `response_stream` keyword argument is still supported, but HTTP.jl no longer automatically closes this stream when done - you need to handle this yourself +- The preferred keyword for streaming responses is `response_body`. `response_stream` is still supported for compatibility +- Response streams are not automatically closed; you need to handle this yourself - `retry` behavior has been overhauled with more consistent rules for what is retryable +- The default `max_retries` is now 4 (was 10 in v1.x) - Some connection-related options have new defaults (e.g., TLS is now OpenSSL-based by default rather than MbedTLS) Example: @@ -84,7 +88,7 @@ end # After (v2.0) HTTP.open("GET", "http://example.com") do http - # Start reading must be explicitly called + # Optional: call startread to access headers before reading the body startread(http) while !eof(http) data = readavailable(http) @@ -133,7 +137,7 @@ end ``` Key differences: -- Most server functionality has been standardized around `serve`/`serve!` rather than `listen`/`listen!` +- `serve`/`serve!` are the primary entry points; `listen`/`listen!` remain available for stream handlers and WebSockets - The handler typically works with `Request`/`Response` objects rather than `Stream` objects - The lifecycle management for servers has improved with clearer semantics for `isopen`, `close`, and `wait` @@ -166,6 +170,7 @@ end ``` Note the addition of the `stream=true` keyword argument to indicate you want to work with a stream handler. +You can also use `HTTP.listen`/`HTTP.listen!` as shorthand for `stream=true`. ### Router and Middleware @@ -270,7 +275,7 @@ close(server) using HTTP.WebSockets # Non-blocking server -server = WebSockets.serve!("127.0.0.1", 8081) do ws +server = WebSockets.listen!("127.0.0.1", 8081) do ws for msg in ws # Echo back any received message send(ws, msg) @@ -281,7 +286,8 @@ end close(server) ``` -Note the change from `listen!` to `serve!` to maintain consistency with the HTTP server API. +`serve`/`serve!` are the primary request/response handlers. `listen`/`listen!` are stream-handler shorthands +equivalent to `serve(...; stream=true)`. ## Error Handling @@ -302,7 +308,7 @@ catch e if e isa HTTP.ConnectError println("Connection failed: $(e.error)") elseif e isa HTTP.TimeoutError - println("Request timed out after $(e.timeout) seconds") + println("Request timed out after $(e.readtimeout) seconds") elseif e isa HTTP.StatusError println("Server returned error status: $(e.status)") elseif e isa HTTP.RequestError @@ -325,16 +331,21 @@ jar = HTTP.CookieJar() response = HTTP.get("https://example.com", cookiejar=jar) # Checking cookies -cookies = HTTP.getcookies(jar, "example.com") +cookies = HTTP.Cookies.getcookies!(jar, "https", "example.com", "/") ``` ## Other Notable Changes - **URI Handling**: URIs are now handled by the separate URIs.jl package (this change actually occurred in v1.0) - **Default Headers**: Headers like `Accept: */*` are now included by default in requests -- **TLS Implementation**: OpenSSL is now the default TLS provider instead of MbedTLS +- **TLS Implementation**: TLS is handled by AWS CRT (s2n-tls) instead of MbedTLS - **Multithreading**: Improved thread safety throughout the codebase - **Performance**: Significant performance improvements, especially for high-throughput servers +- **Parser APIs**: Low-level parser APIs from v1.x have been removed in v2.0 +- **Trailing Headers**: Trailing headers are now captured on `Request.trailers` and `Response.trailers`, and can be sent with `HTTP.addtrailer` when streaming. +- **HTTP/2 Controls**: HTTP/2 helpers (ping, settings, GOAWAY) are available for advanced connection management. +- **Metrics**: Responses include `response.metrics` and clients expose `HTTP.manager_metrics` for connection manager stats. +- **Monitoring**: Optional connection monitoring callbacks are available via `monitoring_statistics_observer`. ## Transitioning Tips @@ -350,4 +361,4 @@ cookies = HTTP.getcookies(jar, "example.com") - Custom client-side layers from v1.x are not compatible with v2.0 and will need to be reimplemented - WebSocket handling code may need adjustments even though the API is similar -For more detailed information on specific topics, consult the full HTTP.jl v2.0.0 documentation. \ No newline at end of file +For more detailed information on specific topics, consult the full HTTP.jl v2.0.0 documentation. diff --git a/docs/src/manual/server.md b/docs/src/manual/server.md index 9dfc75b1e..b98850916 100644 --- a/docs/src/manual/server.md +++ b/docs/src/manual/server.md @@ -117,6 +117,53 @@ server = HTTP.serve!(handle_request, "127.0.0.1", 8080; println("Server started on http://127.0.0.1:8080") ``` +## Stream Handlers (Advanced) + +If you need direct access to the request and response streams, pass `stream=true` to `serve!` (or `serve`) and use +an `HTTP.Stream` handler. Reads will automatically start when you read from the stream, and writes will automatically +start when you write. + +```julia +using HTTP + +server = HTTP.serve!("127.0.0.1", 8080; stream=true) do stream::HTTP.Stream + req = HTTP.startread(stream) + data = read(stream) + + HTTP.setstatus(stream, 200) + HTTP.setheader(stream, "Content-Type" => "text/plain") + write(stream, "received $(length(data)) bytes") + HTTP.addtrailer(stream, "x-request-id" => HTTP.header(req, "x-request-id", "unknown")) +end +``` + +Trailing headers sent by the client are available after the request completes via `stream.request.trailers`. + +## HTTP/2 Server Push (Advanced) + +When handling an HTTP/2 request, you can send a push promise to the client and stream a pushed response. +The `HTTP.push_promise` function returns a new `HTTP.Stream` to write the pushed response. + +```julia +using HTTP + +HTTP.serve!("127.0.0.1", 8443; stream=true, ssl_cert="server.crt", ssl_key="server.key", ssl_alpn_list="h2") do stream + req = HTTP.startread(stream) + if stream.http2 + push = HTTP.push_promise(stream, "GET", "/assets/app.js"; scheme="https", authority="127.0.0.1:8443") + HTTP.setstatus(push, 200) + HTTP.setheader(push, "Content-Type" => "application/javascript") + write(push, "console.log(\"pushed\")") + HTTP.closewrite(push) + end + HTTP.setstatus(stream, 200) + write(stream, "ok") +end +``` + +The push request must include `:scheme` and `:authority`. If these are not present on the original request, +pass them explicitly via the keyword arguments. Clients that do not accept server push will reject the promise. + ## Handlers and Middleware ### Handler Functions diff --git a/docs/src/manual/websockets.md b/docs/src/manual/websockets.md index 259650e0c..33c08cdc7 100644 --- a/docs/src/manual/websockets.md +++ b/docs/src/manual/websockets.md @@ -82,6 +82,40 @@ WebSockets.open("wss://echo.websocket.org") do ws end ``` +## Server-side WebSockets + +For a dedicated WebSocket server, use `WebSockets.listen` or `WebSockets.listen!`: + +```julia +using HTTP +using HTTP.WebSockets + +server = WebSockets.listen!("127.0.0.1", 8080) do ws + for msg in ws + send(ws, msg) + end +end +``` + +To mix WebSockets and normal HTTP on the same port, use a stream handler with `HTTP.listen!` and +upgrade the connection when requested: + +```julia +using HTTP +using HTTP.WebSockets + +server = HTTP.listen!("127.0.0.1", 8080) do stream + if WebSockets.isupgrade(stream) + WebSockets.upgrade(stream) do ws + send(ws, "hello") + end + else + HTTP.setstatus(stream, 200) + write(stream, "ok") + end +end +``` + ## Connection Lifecycle and Error Handling You can check whether a WebSocket is open using `WebSockets.isclosed(ws)` and close it with `close(ws)`. The API is designed to raise exceptions for connection issues or protocol errors, allowing you to handle errors using try‑catch blocks. diff --git a/src/HTTP.jl b/src/HTTP.jl index 8f8c50a0d..9d9f4e0dd 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -1,123 +1,100 @@ module HTTP -using CodecZlib, URIs, Mmap, Base64, Dates, Sockets -using LibAwsCommon, LibAwsIO, LibAwsHTTPFork -import LibAwsCommon: Future, FieldRef +const _HTTP_TRIM_MODE = get(ENV, "HTTP_TRIM", "0") == "1" -export @logfmt_str, common_logfmt, combined_logfmt -export WebSockets +@static if _HTTP_TRIM_MODE + # --------------------------------------------------------------------- + # Trim mode: minimal, synchronous HTTP/1.1 client surface. + # --------------------------------------------------------------------- + using Reseau + include("trim/trim.jl") +else + using CodecZlib, URIs, Mmap, Base64, Dates + using Reseau, AwsHTTP -include("utils.jl") -include("access_log.jl") -include("sniff.jl"); using .Sniff -include("forms.jl"); using .Forms -include("requestresponse.jl") -include("cookies.jl"); using .Cookies -include("client/redirects.jl") -include("client/client.jl") -include("client/retry.jl") -include("client/connection.jl") -include("client/request.jl") -include("client/stream.jl") -include("client/makerequest.jl") -include("websockets.jl"); using .WebSockets -include("server.jl") -include("handlers.jl"); using .Handlers -include("statuses.jl") + export HTTPVersion + export startwrite, startread, closewrite, closeread + export @logfmt_str, common_logfmt, combined_logfmt + export WebSockets + export Pool, default_connection_limit, set_default_connection_limit!, closeall -struct StatusError <: Exception - request_method::String - request_uri::aws_uri - response::Response -end + const nobody = UInt8[] -function Base.showerror(io::IO, e::StatusError) - println(io, "HTTP.StatusError:") - println(io, " Request method: $(e.request_method)") - println(io, " Request URI: $(makeuri(e.request_uri))") - println(io, " response:") - print_response(io, e.response) - return -end + Base.@deprecate escape escapeuri -# backwards compatibility -function Base.getproperty(e::StatusError, s::Symbol) - if s == :status - return e.response.status - elseif s == :method - return e.request_method - elseif s == :target - return makeuri(e.request_uri) - else - return getfield(e, s) + include("utils.jl") + include("statistics.jl") + include("access_log.jl") + include("sniff.jl"); using .Sniff + include("forms.jl"); using .Forms + include("requestresponse.jl") + include("exceptions.jl"); using .Exceptions + struct StatusError <: HTTPError + request_method::String + request_uri::URI + response::Response end -end -#NOTE: this is global process logging in the aws-crt libraries; not appropriate for request-level -# logging, but more for debugging the library itself -mutable struct AwsLogger - ptr::Ptr{aws_logger} - file_ref::Libc.FILE - options::aws_logger_standard_options - function AwsLogger(level::Integer, allocator::Ptr{aws_allocator}) - fr = Libc.FILE(Libc.RawFD(1), "w") - opts = aws_logger_standard_options(aws_log_level(0), C_NULL, Ptr{Libc.FILE}(fr.ptr)) - x = new(Ptr{aws_logger}(aws_mem_acquire(allocator, 64)), fr, opts) - aws_logger_init_standard(x.ptr, allocator, FieldRef(x, :options)) != 0 && aws_throw_error() - aws_logger_set(x.ptr) - return finalizer(x) do x - aws_logger_clean_up(x.ptr) - aws_mem_release(allocator, x.ptr) - end + function Base.showerror(io::IO, e::StatusError) + println(io, "HTTP.StatusError:") + println(io, " Request method: $(e.request_method)") + println(io, " Request URI: $(e.request_uri)") + println(io, " response:") + print_response(io, e.response) + return end -end -const LOGGER = Ref{AwsLogger}() + # backwards compatibility + function Base.getproperty(e::StatusError, s::Symbol) + if s == :status + return e.response.status + elseif s == :method + return e.request_method + elseif s == :target + return e.request_uri + else + return getfield(e, s) + end + end + include("cookies.jl"); using .Cookies + include("client/redirects.jl") + include("client/client.jl") + include("client/retry.jl") + include("client/connection.jl") + include("client/request.jl") + include("client/stream.jl") + include("client/makerequest.jl") + include("client/open.jl") + include("download.jl") + include("websockets.jl"); using .WebSockets + include("server.jl") + include("handlers.jl"); using .Handlers + include("statuses.jl") -function set_log_level!(level::Integer, allocator::Ptr{aws_allocator}=default_aws_allocator()) - @assert 0 <= level <= 7 "log level must be between 0 and 7" - LOGGER[] = AwsLogger(level, allocator) - @assert aws_logger_set_log_level(LOGGER[].ptr, aws_log_level(level)) == 0 - return -end + #NOTE: this is process-level logging; not appropriate for request-level + # logging, but more for debugging the library itself + function set_log_level!(level::Integer) + @assert 0 <= level <= 7 "log level must be between 0 and 7" + Reseau.set_log_level!(Reseau.logger_get(), Reseau.LogLevel.T(level)) + return + end -function __init__() - allocator = default_aws_allocator() - LibAwsHTTPFork.init(allocator) - # intialize c functions - on_acquired[] = @cfunction(c_on_acquired, Cvoid, (Ptr{Cvoid}, Cint, Ptr{aws_retry_token}, Ptr{Cvoid})) - # on_shutdown[] = @cfunction(c_on_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_setup[] = @cfunction(c_on_setup, Cvoid, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_stream_write_on_complete[] = @cfunction(c_on_stream_write_on_complete, Cvoid, (Ptr{aws_http_stream}, Cint, Ptr{Cvoid})) - on_response_headers[] = @cfunction(c_on_response_headers, Cint, (Ptr{Cvoid}, Cint, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) - on_response_header_block_done[] = @cfunction(c_on_response_header_block_done, Cint, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_response_body[] = @cfunction(c_on_response_body, Cint, (Ptr{Cvoid}, Ptr{aws_byte_cursor}, Ptr{Cvoid})) - on_metrics[] = @cfunction(c_on_metrics, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_stream_metrics}, Ptr{Cvoid})) - on_complete[] = @cfunction(c_on_complete, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_destroy[] = @cfunction(c_on_destroy, Cvoid, (Ptr{Cvoid},)) - retry_ready[] = @cfunction(c_retry_ready, Cvoid, (Ptr{aws_retry_token}, Cint, Ptr{Cvoid})) - on_incoming_connection[] = @cfunction(c_on_incoming_connection, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_incoming_request[] = @cfunction(c_on_incoming_request, Ptr{aws_http_stream}, (Ptr{aws_http_connection}, Ptr{Cvoid})) - on_request_headers[] = @cfunction(c_on_request_headers, Cint, (Ptr{aws_http_stream}, Ptr{aws_http_header_block}, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) - on_request_header_block_done[] = @cfunction(c_on_request_header_block_done, Cint, (Ptr{aws_http_stream}, Ptr{aws_http_header_block}, Ptr{Cvoid})) - on_request_body[] = @cfunction(c_on_request_body, Cint, (Ptr{aws_http_stream}, Ptr{aws_byte_cursor}, Ptr{Cvoid})) - on_request_done[] = @cfunction(c_on_request_done, Cint, (Ptr{aws_http_stream}, Ptr{Cvoid})) - on_server_stream_complete[] = @cfunction(c_on_server_stream_complete, Cint, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_destroy_complete[] = @cfunction(c_on_destroy_complete, Cvoid, (Ptr{Cvoid},)) - return -end + function __init__() + AwsHTTP.http_library_init() + return + end -# only run if precompiling -if VERSION >= v"1.9.0-0" && ccall(:jl_generating_output, Cint, ()) == 1 - do_precompile = true - try - Sockets.getalladdrinfo("localhost") - catch ex - @debug "Skipping precompilation workload because localhost cannot be resolved. Check firewall settings" exception=(ex,catch_backtrace()) - do_precompile = false + # only run if precompiling + if VERSION >= v"1.9.0-0" && ccall(:jl_generating_output, Cint, ()) == 1 + do_precompile = true + try + isempty(Reseau.getalladdrinfo("localhost")) && error("localhost cannot be resolved") + catch ex + @debug "Skipping precompilation workload because localhost cannot be resolved. Check firewall settings" exception=(ex,catch_backtrace()) + do_precompile = false + end + # do_precompile && include("precompile.jl") end - # do_precompile && include("precompile.jl") end end diff --git a/src/access_log.jl b/src/access_log.jl index bbef3b9c5..6f6e4415d 100644 --- a/src/access_log.jl +++ b/src/access_log.jl @@ -68,7 +68,7 @@ function symbol_mapping(s::Symbol) elseif s === :remote_port :(HTTP.remote_port(http.connection)) elseif s === :remote_user - :("-") # TODO: find from Basic auth... + :(HTTP._remote_user(http)) elseif s === :time_iso8601 if !Sys.iswindows() :(Libc.strftime("%FT%T%z", time())) @@ -105,6 +105,30 @@ function symbol_mapping(s::Symbol) end end +function _remote_user(http) + auth = HTTP.getheader(http.request.headers, "authorization") + auth === nothing && return "-" + parts = split(auth; limit=2) + length(parts) < 2 && return "-" + ascii_lc_isequal(parts[1], "basic") || return "-" + token = parts[2] + isempty(token) && return "-" + decoded = try + Base64.base64decode(token) + catch + return "-" + end + userpass = try + String(decoded) + catch + return "-" + end + colon = findfirst(==(':'), userpass) + user = colon === nothing ? userpass : userpass[1:prevind(userpass, colon)] + isempty(user) && return "-" + return user +end + """ common_logfmt(io::IO, http::HTTP.Stream) @@ -118,4 +142,3 @@ const common_logfmt = logfmt"$remote_addr - $remote_user [$time_local] \"$reques Format a log message in the Combined Log Format and write to `io`. """ const combined_logfmt = logfmt"$remote_addr - $remote_user [$time_local] \"$request\" $status $body_bytes_sent \"$http_referer\" \"$http_user_agent\"" - diff --git a/src/client/client.jl b/src/client/client.jl index 51d3ef16b..65400f324 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,13 +1,82 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 -const DEFAULT_MAX_RETRIES = 10 +const DEFAULT_MAX_RETRIES = 0 +const default_connection_limit = Ref{Int}(max(16, Threads.nthreads() * 4)) + +# ─── Shared infrastructure ─── +# Lazily initialized resources that replace the old C library globals +# (default_aws_allocator, default_aws_event_loop_group, default_aws_client_bootstrap). +const _RESOURCES_LOCK = ReentrantLock() +const _EVENT_LOOP_GROUP = Ref{Any}(nothing) +const _HOST_RESOLVER = Ref{Any}(nothing) +const _CLIENT_BOOTSTRAP = Ref{Any}(nothing) + +function _ensure_resources!() + _CLIENT_BOOTSTRAP[] !== nothing && return + Base.@lock _RESOURCES_LOCK begin + _CLIENT_BOOTSTRAP[] !== nothing && return + elg_opts = Reseau.EventLoops.EventLoopGroupOptions() + elg = Reseau.EventLoops.EventLoopGroup(elg_opts) + _EVENT_LOOP_GROUP[] = elg + resolver = Reseau.Sockets.HostResolver(elg) + _HOST_RESOLVER[] = resolver + bootstrap = Reseau.Sockets.ClientBootstrap(Reseau.Sockets.ClientBootstrapOptions( + event_loop_group=elg, + host_resolver=resolver, + )) + _CLIENT_BOOTSTRAP[] = bootstrap + end +end + +function _task_sleep_s(seconds::Real)::Nothing + seconds <= 0 && return nothing + elg = _EVENT_LOOP_GROUP[] + if elg === nothing + Reseau.thread_sleep_s(seconds) + return nothing + end + el = Reseau.EventLoops.event_loop_group_get_next_loop(elg) + if el === nothing + Reseau.thread_sleep_s(seconds) + return nothing + end + Reseau.EventLoops.task_sleep_s(el, seconds) + return nothing +end + +# ─── TLS helper ─── + +function _make_tls_options(host::String; ssl_cert, ssl_key, ssl_capath, ssl_cacert, ssl_insecure, ssl_alpn_list) + alpn_list = _normalize_alpn_list(ssl_alpn_list) + if ssl_cert !== nothing && ssl_key !== nothing + # Mutual TLS: client certificate + key (file paths) + opts = Reseau.Sockets.tls_ctx_options_init_client_mtls_from_path(ssl_cert, ssl_key) + Reseau.Sockets.tls_ctx_options_set_verify_peer!(opts, !ssl_insecure) + if alpn_list !== nothing && !isempty(alpn_list) + Reseau.Sockets.tls_ctx_options_set_alpn_list!(opts, alpn_list) + end + if ssl_cacert !== nothing || ssl_capath !== nothing + Reseau.Sockets.tls_ctx_options_override_default_trust_store_from_path!(opts; + ca_path=ssl_capath, ca_file=ssl_cacert) + end + ctx = Reseau.Sockets.tls_context_new(opts) + else + # Standard client TLS (no client cert) + ctx = Reseau.Sockets.tls_context_new_client(; + verify_peer=!ssl_insecure, + ca_file=ssl_cacert, + ca_path=ssl_capath, + alpn_list=alpn_list, + ) + end + return Reseau.Sockets.TlsConnectionOptions(ctx; server_name=host) +end + +# ─── Settings ─── Base.@kwdef struct ClientSettings scheme::String host::String port::UInt32 - allocator::Ptr{aws_allocator} = default_aws_allocator() - bootstrap::Ptr{aws_client_bootstrap} = default_aws_client_bootstrap() - event_loop_group::Ptr{aws_event_loop_group} = default_aws_event_loop_group() socket_domain::Symbol = :ipv4 connect_timeout_ms::Int = DEFAULT_CONNECT_TIMEOUT keep_alive_interval_sec::Int = 0 @@ -31,11 +100,14 @@ Base.@kwdef struct ClientSettings proxy_ssl_cacert::Union{Nothing, String} = nothing proxy_ssl_insecure::Bool = false proxy_ssl_alpn_list::String = "h2;http/1.1" + proxy_auth::Union{Nothing, Symbol} = nothing + proxy_username::Union{Nothing, String} = nothing + proxy_password::Union{Nothing, String} = nothing retry_partition::Union{Nothing, String} = nothing max_retries::Int = DEFAULT_MAX_RETRIES backoff_scale_factor_ms::Int = 25 max_backoff_secs::Int = 20 - jitter_mode::aws_exponential_backoff_jitter_mode = AWS_EXPONENTIAL_BACKOFF_JITTER_DEFAULT + jitter_mode::Symbol = :default retry_timeout_ms::Int = 60000 initial_bucket_capacity::Int = 500 max_connections::Int = 512 @@ -43,7 +115,20 @@ Base.@kwdef struct ClientSettings connection_acquisition_timeout_ms::Int = 0 max_pending_connection_acquisitions::Int = 0 enable_read_back_pressure::Bool = false + monitoring_minimum_throughput_bytes_per_second::UInt64 = 0 + monitoring_allowable_throughput_failure_interval_seconds::UInt32 = 0 + monitoring_statistics_observer::Union{Nothing, Function} = nothing http2_prior_knowledge::Bool = false + http2_stream_manager::Bool = false + http2_close_connection_on_server_error::Bool = false + http2_connection_manual_window_management::Bool = false + http2_connection_ping_period_ms::Int = 0 + http2_connection_ping_timeout_ms::Int = 0 + http2_ideal_concurrent_streams_per_connection::Int = 0 + http2_max_concurrent_streams_per_connection::Int = 0 + http2_max_closed_streams::Int = 0 + http2_initial_window_size::Int = HTTP2_DEFAULT_WINDOW_SIZE + http2_initial_settings::Union{Nothing, AbstractVector} = nothing end ClientSettings( @@ -58,15 +143,35 @@ ClientSettings( max_retries::Integer=DEFAULT_MAX_RETRIES, require_ssl_verification::Bool=true, ssl_insecure::Bool=false, - kw...) = + kw...) = begin + kw_nt = (; kw...) + connection_limit = Base.get(() -> nothing, kw_nt, :connection_limit) + if connection_limit !== nothing + connection_limit_warning(connection_limit) + kw_nt = Base.structdiff(kw_nt, (; connection_limit=nothing)) + end + max_connections = Base.get(() -> default_connection_limit[], kw_nt, :max_connections) + if haskey(kw_nt, :max_connections) + kw_nt = Base.structdiff(kw_nt, (; max_connections=nothing)) + end + http2_initial_settings = Base.get(() -> nothing, kw_nt, :http2_initial_settings) + if http2_initial_settings !== nothing && !(http2_initial_settings isa AbstractVector) + throw(ArgumentError("http2_initial_settings must be a vector of pairs or AwsHTTP.Http2Setting")) + end + if haskey(kw_nt, :http2_initial_settings) + kw_nt = Base.structdiff(kw_nt, (; http2_initial_settings=nothing)) + end ClientSettings(; scheme=String(scheme), host=String(host), port=port, connect_timeout_ms=(connect_timeout !== nothing ? connect_timeout * 1000 : connect_timeout_ms), max_retries=(retry ? (retries != DEFAULT_MAX_RETRIES ? retries : max_retries) : 0), + max_connections=max_connections, ssl_insecure=(!require_ssl_verification || ssl_insecure), - kw...) + http2_initial_settings=http2_initial_settings, + kw_nt...) +end # make a new ClientSettings object from an existing one w/ just different url values @generated function ClientSettings(cs::ClientSettings, scheme::AbstractString, host::AbstractString, port::Integer) @@ -77,17 +182,49 @@ ClientSettings( return ex end +# ─── Compat option mirrors for tests ─── + +struct ConnManagerOptsCompat + http2_conn_manual_window_management::Bool + max_closed_streams::Csize_t + initial_window_size::Csize_t + num_initial_settings::Csize_t + initial_settings_array::Ptr{AwsHTTP.Http2Setting} + _initial_settings_storage::Vector{AwsHTTP.Http2Setting} +end + +struct Http2StreamManagerOptsCompat + close_connection_on_server_error::Bool + conn_manual_window_management::Bool + connection_ping_period_ms::Csize_t + connection_ping_timeout_ms::Csize_t + ideal_concurrent_streams_per_connection::Csize_t + max_concurrent_streams_per_connection::Csize_t + initial_window_size::Csize_t + max_closed_streams::Csize_t + num_initial_settings::Csize_t + initial_settings_array::Ptr{AwsHTTP.Http2Setting} + _initial_settings_storage::Vector{AwsHTTP.Http2Setting} +end + +# ─── Client ─── + mutable struct Client settings::ClientSettings - socket_options::aws_socket_options - tls_options::Union{Nothing, aws_tls_connection_options} + socket_options::Reseau.Sockets.SocketOptions + tls_options::Union{Nothing, Reseau.Sockets.TlsConnectionOptions} # only 1 of proxy_options or proxy_env_settings is set - proxy_options::Union{Nothing, aws_http_proxy_options} - proxy_env_settings::Union{Nothing, proxy_env_var_settings} - retry_options::aws_standard_retry_options - retry_strategy::Ptr{aws_retry_strategy} - conn_manager_opts::aws_http_connection_manager_options - connection_manager::Ptr{aws_http_connection_manager} + proxy_options::Union{Nothing, AwsHTTP.HttpProxyOptions} + proxy_env_settings::Union{Nothing, AwsHTTP.ProxyEnvVarSettings} + proxy_strategy::Union{Nothing, AwsHTTP.HttpProxyStrategy} + monitoring_options::Union{Nothing, AwsHTTP.HttpConnectionMonitoringOptions} + monitoring_observer::Union{Nothing, Function} + retry_strategy::Reseau.Sockets.StandardRetryStrategy + connection_manager::AwsHTTP.HttpConnectionManager + http2_stream_manager::Union{Nothing, AwsHTTP.Http2StreamManager} + http2_initial_settings::Union{Nothing, Vector{AwsHTTP.Http2Setting}} + conn_manager_opts::ConnManagerOptsCompat + http2_stream_manager_opts::Union{Nothing, Http2StreamManagerOptsCompat} Client() = new() end @@ -95,156 +232,292 @@ end Client(scheme::AbstractString, host::AbstractString, port::Integer; kw...) = Client(ClientSettings(scheme, host, port % UInt32; kw...)) function Client(cs::ClientSettings) + _ensure_resources!() client = Client() client.settings = cs + if cs.http2_initial_window_size < 0 || cs.http2_initial_window_size > HTTP2_MAX_WINDOW_SIZE + throw(ArgumentError("http2_initial_window_size must be between 0 and $(HTTP2_MAX_WINDOW_SIZE)")) + end # socket options - client.socket_options = aws_socket_options( - AWS_SOCKET_STREAM, # socket type - cs.socket_domain == :ipv4 ? AWS_SOCKET_IPV4 : AWS_SOCKET_IPV6, # socket domain - AWS_SOCKET_IMPL_PLATFORM_DEFAULT, # aws_socket_impl_type - cs.connect_timeout_ms, - cs.keep_alive_interval_sec, - cs.keep_alive_timeout_sec, - cs.keep_alive_max_failed_probes, - cs.keepalive, - ntuple(x -> Cchar(0), 16) # network_interface_name + client.socket_options = Reseau.Sockets.SocketOptions(; + type=Reseau.Sockets.SocketType.STREAM, + domain=cs.socket_domain == :ipv4 ? Reseau.Sockets.SocketDomain.IPV4 : Reseau.Sockets.SocketDomain.IPV6, + connect_timeout_ms=cs.connect_timeout_ms, + keepalive=cs.keepalive, + keep_alive_interval_sec=cs.keep_alive_interval_sec, + keep_alive_timeout_sec=cs.keep_alive_timeout_sec, + keep_alive_max_failed_probes=cs.keep_alive_max_failed_probes, ) # tls options if cs.scheme == "https" || cs.scheme == "wss" - client.tls_options = LibAwsIO.tlsoptions(cs.host; - cs.ssl_cert, - cs.ssl_key, - cs.ssl_capath, - cs.ssl_cacert, - cs.ssl_insecure, - cs.ssl_alpn_list - ) + client.tls_options = _make_tls_options(cs.host; + cs.ssl_cert, cs.ssl_key, cs.ssl_capath, cs.ssl_cacert, + cs.ssl_insecure, cs.ssl_alpn_list) else client.tls_options = nothing end # proxy options + client.proxy_options = nothing + client.proxy_env_settings = nothing + client.proxy_strategy = nothing + proxy_connection_type = cs.proxy_connection_type == :forward ? + AwsHTTP.HttpProxyConnectionType.HTTP_FORWARD : + AwsHTTP.HttpProxyConnectionType.HTTP_TUNNEL if cs.proxy_host !== nothing && cs.proxy_port !== nothing - client.proxy_options = aws_http_proxy_options( - cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL, - aws_byte_cursor_from_c_str(cs.proxy_host), - cs.proxy_port % UInt32, - cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; - cs.proxy_ssl_cert, - cs.proxy_ssl_key, - cs.proxy_ssl_capath, - cs.proxy_ssl_cacert, - cs.proxy_ssl_insecure, - cs.proxy_ssl_alpn_list - ), - #TODO: support proxy_strategy - C_NULL, # proxy_strategy::Ptr{aws_http_proxy_strategy} - 0, # auth_type::aws_http_proxy_authentication_type - aws_byte_cursor_from_c_str(""), # auth_username::aws_byte_cursor - aws_byte_cursor_from_c_str(""), # auth_password::aws_byte_cursor + proxy_auth = cs.proxy_auth + if proxy_auth === nothing && (cs.proxy_username !== nothing || cs.proxy_password !== nothing) + proxy_auth = :basic + end + if proxy_auth !== nothing + proxy_auth == :basic || throw(ArgumentError("unsupported proxy_auth: $proxy_auth")) + cs.proxy_username === nothing && throw(ArgumentError("proxy_username required for basic proxy auth")) + cs.proxy_password === nothing && throw(ArgumentError("proxy_password required for basic proxy auth")) + client.proxy_strategy = AwsHTTP.http_proxy_strategy_new_basic_auth( + AwsHTTP.HttpProxyStrategyBasicAuthOptions( + proxy_connection_type, + cs.proxy_username, + cs.proxy_password, + ) + ) + end + client.proxy_options = AwsHTTP.HttpProxyOptions(; + connection_type=proxy_connection_type, + host=cs.proxy_host, + port=cs.proxy_port % UInt32, + proxy_strategy=client.proxy_strategy, ) elseif cs.proxy_allow_env_var - client.proxy_env_settings = proxy_env_var_settings( - AWS_HPEV_ENABLE, - cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL, - cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; - cs.proxy_ssl_cert, - cs.proxy_ssl_key, - cs.proxy_ssl_capath, - cs.proxy_ssl_cacert, - cs.proxy_ssl_insecure, - cs.proxy_ssl_alpn_list - ) + if cs.proxy_auth !== nothing || cs.proxy_username !== nothing || cs.proxy_password !== nothing + throw(ArgumentError("proxy auth requires explicit proxy_host/proxy_port")) + end + client.proxy_env_settings = AwsHTTP.ProxyEnvVarSettings(; + env_var_type=AwsHTTP.HttpProxyEnvVarType.ENABLE, + connection_type=proxy_connection_type, + ) + end + # connection monitoring options + if cs.monitoring_minimum_throughput_bytes_per_second != 0 || + cs.monitoring_allowable_throughput_failure_interval_seconds != 0 + client.monitoring_options = AwsHTTP.HttpConnectionMonitoringOptions( + UInt64(cs.monitoring_minimum_throughput_bytes_per_second), + UInt32(cs.monitoring_allowable_throughput_failure_interval_seconds), ) else - client.proxy_options = nothing + client.monitoring_options = nothing end + client.monitoring_observer = cs.monitoring_statistics_observer # retry strategy - exp_back_opts = aws_exponential_backoff_retry_options( - cs.event_loop_group, - cs.max_retries, - cs.backoff_scale_factor_ms, - cs.max_backoff_secs, - cs.jitter_mode, - C_NULL, # generate_random - C_NULL, # generate_random_impl - C_NULL, # generate_random_user_data - C_NULL, # shutdown_options::Ptr{aws_shutdown_callback_options} + backoff_config = Reseau.Sockets.ExponentialBackoffConfig(; + backoff_scale_factor_ms=cs.backoff_scale_factor_ms, + max_backoff_secs=cs.max_backoff_secs, + max_retries=cs.max_retries, + jitter_mode=cs.jitter_mode, ) - client.retry_options = aws_standard_retry_options( - exp_back_opts, - cs.initial_bucket_capacity + retry_config = Reseau.Sockets.StandardRetryConfig(; + initial_bucket_capacity=cs.initial_bucket_capacity, + backoff_config=backoff_config, ) - client.retry_strategy = aws_retry_strategy_new_standard(cs.allocator, FieldRef(client, :retry_options)) - client.retry_strategy == C_NULL && aws_throw_error() - client.conn_manager_opts = aws_http_connection_manager_options( - cs.bootstrap, - typemax(Csize_t), # initial_window_size::Csize_t - pointer(FieldRef(client, :socket_options)), - cs.response_first_byte_timeout_ms, - (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, - cs.http2_prior_knowledge, - C_NULL, # monitoring_options::Ptr{aws_http_connection_monitoring_options} - aws_byte_cursor_from_c_str(cs.host), - cs.port % UInt32, - C_NULL, # initial_settings_array::Ptr{aws_http2_setting} - 0, # num_initial_settings::Csize_t - 0, # max_closed_streams::Csize_t - false, # http2_conn_manual_window_management::Bool - client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), # proxy_options::Ptr{aws_http_proxy_options} - client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), # proxy_env_settings::Ptr{proxy_env_var_settings} - cs.max_connections, # max_connections::Csize_t, 512 - C_NULL, # shutdown_complete_user_data::Ptr{Cvoid} - C_NULL, # shutdown_complete_callback::Ptr{aws_http_connection_manager_shutdown_complete_fn} - cs.enable_read_back_pressure, # enable_read_back_pressure::Bool - cs.max_connection_idle_in_milliseconds, - cs.connection_acquisition_timeout_ms, - cs.max_pending_connection_acquisitions, - C_NULL, # network_interface_names_array - 0, # num_network_interface_names - ) - client.connection_manager = aws_http_connection_manager_new(cs.allocator, FieldRef(client, :conn_manager_opts)) - client.connection_manager == C_NULL && aws_throw_error() - - finalizer(client) do x - if x.connection_manager != C_NULL - aws_http_connection_manager_release(x.connection_manager) - x.connection_manager = C_NULL - end - if x.retry_strategy != C_NULL - aws_retry_strategy_release(x.retry_strategy) - x.retry_strategy = C_NULL + strategy = Reseau.Sockets.StandardRetryStrategy(_EVENT_LOOP_GROUP[], retry_config) + client.retry_strategy = strategy + # http2 initial settings + settings_input = cs.http2_initial_settings + if settings_input === nothing + client.http2_initial_settings = nothing + elseif settings_input isa AbstractVector{AwsHTTP.Http2Setting} + client.http2_initial_settings = collect(settings_input) + elseif settings_input isa AbstractVector{<:Pair} + client.http2_initial_settings = _settings_from_pairs(settings_input) + else + throw(ArgumentError("http2_initial_settings must be a vector of pairs or AwsHTTP.Http2Setting")) + end + settings_storage = client.http2_initial_settings === nothing ? AwsHTTP.Http2Setting[] : client.http2_initial_settings + settings_ptr = isempty(settings_storage) ? Ptr{AwsHTTP.Http2Setting}(C_NULL) : pointer(settings_storage) + settings_count = Csize_t(length(settings_storage)) + # connection factory: creates connections for the pool managers. + # Calls AwsHTTP.http_client_connect (async) and blocks until setup completes. + conn_factory = let socket_opts=client.socket_options, tls_opts=client.tls_options, + host=cs.host, port=cs.port, + prior_knowledge=cs.http2_prior_knowledge, + manual_wm=cs.http2_connection_manual_window_management, + initial_ws=cs.http2_initial_window_size, + rfbt_ms=cs.response_first_byte_timeout_ms + function(_manager_opts) + result_ch = Base.Channel{Any}(1) + AwsHTTP.http_client_connect(AwsHTTP.HttpClientConnectionOptions( + bootstrap=_CLIENT_BOOTSTRAP[], + host_name=host, + port=port, + socket_options=socket_opts, + tls_connection_options=tls_opts, + prior_knowledge_http2=prior_knowledge, + manual_window_management=manual_wm, + initial_window_size=Csize_t(initial_ws), + response_first_byte_timeout_ms=UInt64(rfbt_ms), + on_setup=(conn, err, ud) -> put!(result_ch, err == Reseau.OP_SUCCESS ? conn : nothing), + )) + return take!(result_ch) end end + # connection manager + client.connection_manager = AwsHTTP.http_connection_manager_new( + AwsHTTP.HttpConnectionManagerOptions(; + host=cs.host, + port=cs.port, + max_connections=cs.max_connections, + initial_window_size=Csize_t(cs.http2_initial_window_size), + manual_window_management=cs.http2_connection_manual_window_management, + http2_prior_knowledge=cs.http2_prior_knowledge, + enable_read_back_pressure=cs.enable_read_back_pressure, + max_connection_idle_in_milliseconds=UInt64(cs.max_connection_idle_in_milliseconds), + connection_acquisition_timeout_ms=UInt64(cs.connection_acquisition_timeout_ms), + max_pending_connection_acquisitions=cs.max_pending_connection_acquisitions, + response_first_byte_timeout_ms=UInt64(cs.response_first_byte_timeout_ms), + max_closed_streams=cs.http2_max_closed_streams, + http2_conn_manual_window_management=cs.http2_connection_manual_window_management, + on_connection_setup=conn_factory, + ) + ) + client.conn_manager_opts = ConnManagerOptsCompat( + cs.http2_connection_manual_window_management, + Csize_t(cs.http2_max_closed_streams), + Csize_t(cs.http2_initial_window_size), + settings_count, + settings_ptr, + settings_storage, + ) + # http2 stream manager (optional) + client.http2_stream_manager = nothing + client.http2_stream_manager_opts = nothing + if cs.http2_stream_manager + client.http2_stream_manager = AwsHTTP.http2_stream_manager_new( + AwsHTTP.Http2StreamManagerOptions(; + host=cs.host, + port=cs.port, + max_connections=cs.max_connections, + ideal_concurrent_streams_per_connection=cs.http2_ideal_concurrent_streams_per_connection, + max_concurrent_streams_per_connection=cs.http2_max_concurrent_streams_per_connection, + close_connection_on_server_error=cs.http2_close_connection_on_server_error, + connection_ping_period_ms=UInt64(cs.http2_connection_ping_period_ms), + connection_ping_timeout_ms=UInt64(cs.http2_connection_ping_timeout_ms), + initial_window_size=Csize_t(cs.http2_initial_window_size), + manual_window_management=cs.http2_connection_manual_window_management, + http2_prior_knowledge=cs.http2_prior_knowledge, + enable_read_back_pressure=cs.enable_read_back_pressure, + max_closed_streams=cs.http2_max_closed_streams, + on_connection_setup=conn_factory, + ) + ) + client.http2_stream_manager_opts = Http2StreamManagerOptsCompat( + cs.http2_close_connection_on_server_error, + cs.http2_connection_manual_window_management, + Csize_t(cs.http2_connection_ping_period_ms), + Csize_t(cs.http2_connection_ping_timeout_ms), + Csize_t(cs.http2_ideal_concurrent_streams_per_connection), + Csize_t(cs.http2_max_concurrent_streams_per_connection), + Csize_t(cs.http2_initial_window_size), + Csize_t(cs.http2_max_closed_streams), + settings_count, + settings_ptr, + settings_storage, + ) + end return client end -#TODO: this should probably be a LRU cache to not grow indefinitely +# ─── Client cache ─── + +const _CLIENT_CACHE_MAX = let val = get(ENV, "HTTP_CLIENT_CACHE_MAX", "64") + parsed = tryparse(Int, val) + parsed === nothing || parsed < 1 ? 64 : parsed +end + struct Clients lock::ReentrantLock clients::Dict{ClientSettings, Client} + order::Vector{ClientSettings} + max_clients::Int +end + +Clients(max_clients::Int=_CLIENT_CACHE_MAX) = + Clients(ReentrantLock(), Dict{ClientSettings, Client}(), ClientSettings[], max_clients) + +struct Pool + clients::Clients + max_connections::Union{Nothing, Int} end -Clients() = Clients(ReentrantLock(), Dict{ClientSettings, Client}()) +Pool() = Pool(default_connection_limit[]) +Pool(max_connections::Union{Int, Nothing}) = Pool(Clients(), max_connections) + +function Base.getproperty(pool::Pool, name::Symbol) + if name === :max + return getfield(pool, :max_connections) + end + return getfield(pool, name) +end const CLIENTS = Clients() function getclient(key::ClientSettings, clients::Clients=CLIENTS) Base.@lock clients.lock begin if haskey(clients.clients, key) + idx = findfirst(==(key), clients.order) + idx !== nothing && deleteat!(clients.order, idx) + push!(clients.order, key) return clients.clients[key] - else - client = Client(key) - clients.clients[key] = client - return client end + client = Client(key) + clients.clients[key] = client + push!(clients.order, key) + if length(clients.order) > clients.max_clients + evict = popfirst!(clients.order) + delete!(clients.clients, evict) + end + return client end end +function manager_metrics(client::Client) + if client.http2_stream_manager !== nothing + return AwsHTTP.http2_stream_manager_fetch_metrics(client.http2_stream_manager) + else + return AwsHTTP.http_connection_manager_fetch_metrics(client.connection_manager) + end +end + +getclient(key::ClientSettings, pool::Pool) = getclient(key, pool.clients) + function close_all_clients!(clients::Clients=CLIENTS) Base.@lock clients.lock begin for client in values(clients.clients) - finalize(client) + close(client.connection_manager) + if client.http2_stream_manager !== nothing + close(client.http2_stream_manager) + end end empty!(clients.clients) + empty!(clients.order) end end + +close_all_clients!(pool::Pool) = close_all_clients!(pool.clients) + +function set_default_connection_limit!(n::Integer) + default_connection_limit[] = Int(n) + return +end + +function closeall(pool::Union{Nothing, Pool}=nothing) + if pool === nothing + close_all_clients!(CLIENTS) + else + close_all_clients!(pool.clients) + end + return +end + +@noinline function connection_limit_warning(cl) + cl === nothing && return + @warn "connection_limit no longer supported as a keyword argument; use `HTTP.set_default_connection_limit!($cl)` before any requests are made or construct a shared pool via `POOL = HTTP.Pool($cl)` and pass to each request like `pool=POOL` instead." + return +end diff --git a/src/client/connection.jl b/src/client/connection.jl index 17071a053..a326df05f 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -1,26 +1,163 @@ -const on_setup = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_setup(conn, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code == AWS_IO_DNS_INVALID_NAME# || error_code == AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - elseif error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, conn) - end - return +function _client_url(client::Client) + host = client.settings.host + port = client.settings.port + return string(client.settings.scheme, "://", host, ":", port) end -function with_connection(f::Function, client::Client) - fut = Future{Ptr{aws_http_connection}}() - GC.@preserve fut begin - aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) - connection = wait(fut) +function with_connection(f::Function, client::Client; context=nothing) + start_time = context !== nothing ? time() : 0.0 + ch = Channel{Any}(1) + AwsHTTP.http_connection_manager_acquire_connection( + client.connection_manager; + callback = (conn, error_code, _) -> begin + if error_code != AwsHTTP.OP_SUCCESS + ec = Reseau.last_error() + put!(ch, CapturedException(aws_error(ec), Base.backtrace())) + else + put!(ch, conn) + end + end + ) + result = take!(ch) + connection = if result isa Exception + throw(ConnectError(_client_url(client), result)) + else + result end try return f(connection) finally - aws_http_connection_manager_release_connection(client.connection_manager, connection) + AwsHTTP.http_connection_manager_release_connection(client.connection_manager, connection) + context !== nothing && _record_layer!(context, :connectionlayer, start_time) + end +end + +function _ensure_http2_connection(conn) + conn === nothing && throw(ArgumentError("HTTP/2 connection is null")) + AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 || throw(ArgumentError("HTTP/2 connection required")) + return conn +end + +function _with_http2_connection(f::Function, client::Client) + return with_connection(client) do conn + f(_ensure_http2_connection(conn)) + end +end + +function http2_ping(conn; data=nothing) + _ensure_http2_connection(conn) + fut = Future{UInt64}() + opaque_data = if data !== nothing + bytes = data isa AbstractString ? Vector{UInt8}(codeunits(data)) : Vector{UInt8}(data) + length(bytes) == AwsHTTP.H2_PING_DATA_SIZE || throw(ArgumentError("PING data must be $(AwsHTTP.H2_PING_DATA_SIZE) bytes")) + bytes + else + zeros(UInt8, AwsHTTP.H2_PING_DATA_SIZE) + end + AwsHTTP.h2_connection_send_ping!(conn, opaque_data; + on_completed = (rtt_ns, error_code, _) -> begin + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, rtt_ns) + end + end + ) + return wait(fut) +end + +http2_ping(client::Client; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), client) + +function _settings_from_pairs(settings::AbstractVector{<:Pair}) + out = Vector{AwsHTTP.Http2Setting}(undef, length(settings)) + for (i, (k, v)) in enumerate(settings) + id = k isa AwsHTTP.Http2SettingsId.T ? k : AwsHTTP.Http2SettingsId.T(k) + out[i] = AwsHTTP.Http2Setting(id, UInt32(v)) + end + return out +end + +function http2_change_settings(conn, settings::Vector{AwsHTTP.Http2Setting}) + _ensure_http2_connection(conn) + fut = Future{Nothing}() + AwsHTTP.h2_connection_change_settings!(conn, settings; + on_completed = (error_code, _) -> begin + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, nothing) + end + end + ) + wait(fut) + return +end + +http2_change_settings(conn, settings::AbstractVector{<:Pair}) = + http2_change_settings(conn, _settings_from_pairs(settings)) + +http2_change_settings(client::Client, settings) = + _with_http2_connection(conn -> http2_change_settings(conn, settings), client) + + +function http2_local_settings(conn) + _ensure_http2_connection(conn) + return AwsHTTP.h2_connection_get_local_settings(conn) +end + +http2_local_settings(client::Client) = _with_http2_connection(http2_local_settings, client) + +function http2_remote_settings(conn) + _ensure_http2_connection(conn) + return AwsHTTP.h2_connection_get_remote_settings(conn) +end + +http2_remote_settings(client::Client) = _with_http2_connection(http2_remote_settings, client) + +function http2_send_goaway(conn, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) + _ensure_http2_connection(conn) + dd = if debug_data !== nothing + bytes = debug_data isa AbstractString ? Vector{UInt8}(codeunits(debug_data)) : Vector{UInt8}(debug_data) + length(bytes) <= 16 * 1024 || throw(ArgumentError("debug_data must be <= 16KB")) + bytes + else + UInt8[] end + AwsHTTP.h2_connection_send_goaway!(conn; + allow_more_streams=allow_more_streams, + error_code=UInt32(http2_error), + debug_data=dd + ) + return end + +http2_send_goaway(client::Client, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = + _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), client) + +function http2_update_window(conn, increment::Integer) + _ensure_http2_connection(conn) + increment < 0 && throw(ArgumentError("increment must be >= 0")) + increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) + AwsHTTP.h2_connection_update_window!(conn, UInt32(increment)) + return +end + +http2_update_window(client::Client, increment::Integer) = + _with_http2_connection(conn -> http2_update_window(conn, increment), client) + + +function _get_goaway(get_fn, conn) + _ensure_http2_connection(conn) + sent_or_received, last_stream_id, error_code = get_fn(conn) + if sent_or_received + return (http2_error=error_code, last_stream_id=last_stream_id) + else + return nothing + end +end + +http2_get_sent_goaway(conn) = _get_goaway(AwsHTTP.h2_connection_get_sent_goaway, conn) +http2_get_received_goaway(conn) = _get_goaway(AwsHTTP.h2_connection_get_received_goaway, conn) + +http2_get_sent_goaway(client::Client) = _with_http2_connection(http2_get_sent_goaway, client) +http2_get_received_goaway(client::Client) = _with_http2_connection(http2_get_received_goaway, client) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 298363568..79a982155 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -7,21 +7,56 @@ head(a...; kw...) = request("HEAD", a...; kw...) options(a...; kw...) = request("OPTIONS", a...; kw...) const COOKIEJAR = CookieJar() +const DEFAULT_PROXY = Symbol("__HTTP_DEFAULT_PROXY__") _something(x, y) = x === nothing ? y : x +# proxy keyword handling +function proxy_kwargs(proxy, req_scheme) + if proxy === DEFAULT_PROXY + return NamedTuple() + elseif proxy === nothing || proxy === false + return (proxy_allow_env_var=false,) + elseif proxy isa AbstractString || proxy isa URI + p = proxy isa URI ? proxy : URI(String(proxy)) + isempty(p.host) && throw(ArgumentError("proxy URL must include a host")) + port = isempty(p.port) ? (p.scheme == "https" ? 443 : 80) : parse(Int, p.port) + conn_type = req_scheme in ("https", "wss") ? :tunnel : :forward + if isempty(p.userinfo) + return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type) + end + parts = split(p.userinfo, ":"; limit=2) + proxy_user = unescapeuri(parts[1]) + proxy_pass = length(parts) == 2 ? unescapeuri(parts[2]) : "" + return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type, + proxy_auth=:basic, proxy_username=proxy_user, proxy_password=proxy_pass) + else + throw(ArgumentError("proxy must be a URL String, URI, nothing, or false")) + end +end + # main entrypoint for making an HTTP request # can provide method, url, headers, body, along with various keyword arguments -function request(method, url, h=Header[], b::RequestBodyTypes=nothing; - allocator=default_aws_allocator(), +function request(method, url, h=Header[], b=nothing; headers=h, - body::RequestBodyTypes=b, + body=b, chunkedbody=nothing, + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, username=nothing, password=nothing, bearer=nothing, query=nothing, client::Union{Nothing, Client}=nothing, + basicauth::Bool=true, + proxy=DEFAULT_PROXY, + pool=nothing, + logerrors::Bool=false, + logtag=nothing, + observelayers::Bool=false, + retry_check=nothing, + retry_delays=nothing, # redirect options redirect=true, redirect_limit=3, @@ -35,39 +70,109 @@ function request(method, url, h=Header[], b::RequestBodyTypes=nothing; response_body=response_stream, decompress::Union{Nothing, Bool}=nothing, status_exception::Bool=true, - readtimeout::Int=0, # only supported for HTTP 1.1, not HTTP 2 (current aws limitation) + readtimeout::Int=0, retry_non_idempotent::Bool=false, modifier=nothing, verbose=0, # only client keywords in catch-all kw...) - uri = parseuri(url, query, allocator) - return with_redirect(allocator, method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body - reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); allocator=allocator, kw...)))::Client - with_retry_token(reqclient) do - resp = with_connection(reqclient) do conn - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - path = resource(uri) - with_request(reqclient, method, path, headers, body, chunkedbody, decompress, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, http2, cookies, cookiejar, verbose) do req - if response_body isa AbstractVector{UInt8} - ref = Ref(1) - GC.@preserve ref begin - on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + context = observelayers ? Dict{Symbol, Any}() : nothing + context === nothing || _init_observations!(context) + if chunkedbody === nothing && body isa IO && !(body isa IOStream) && !(body isa Form) + chunkedbody = IOChunkedBody(body) + body = nothing + end + if chunkedbody === nothing && body !== nothing && !(body isa RequestBodyTypes) && Base.isiterable(typeof(body)) + chunkedbody = body + body = nothing + end + retryable_body = chunkedbody === nothing && ( + body === nothing || + body isa AbstractString || + body isa AbstractVector{UInt8} || + body isa AbstractDict || + body isa NamedTuple || + body isa Form + ) + headers = mkreqheaders(headers, copyheaders) + uri = parseuri(url, query) + proxy_kw = proxy_kwargs(proxy, scheme(uri)) + client_kw = (; kw...) + if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) + client_kw = merge(client_kw, (; max_connections=pool.max_connections)) + end + if !isempty(proxy_kw) + client_kw = merge(client_kw, proxy_kw) + end + authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) + apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth + # `client_kw`/`chunkedbody` are reassigned above; freeze them in a fresh binding so + # the request pipeline closures capture a concrete value rather than a `Core.Box`. + let client_kw = client_kw, chunkedbody = chunkedbody + return with_redirect(method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body + reqclient = @something( + client, + pool === nothing ? + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) + )::Client + req_ref = Ref{Any}(nothing) + with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, + retry_check=retry_check, retry_delays=retry_delays, + retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref, context=context) do + resp = if reqclient.http2_stream_manager !== nothing + with_request(reqclient, method, resource(uri), headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + observelayers=observelayers, + context=context, + ) do req + req_ref[] = req + if response_body isa AbstractVector{UInt8} + on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) + elseif response_body isa IO + on_stream_response_body = IOOnResponseBody(response_body) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) + else + with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout; context=context) + end end - elseif response_body isa IO - on_stream_response_body = IOOnResponseBody(response_body) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) else - with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator) + with_connection(reqclient; context=context) do conn + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + with_request(reqclient, method, resource(uri), headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + observelayers=observelayers, + context=context, + ) do req + req_ref[] = req + if response_body isa AbstractVector{UInt8} + on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + elseif response_body isa IO + on_stream_response_body = IOOnResponseBody(response_body) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + else + with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout; context=context) + end + end end end + # status error check + if status_exception && iserror(resp) + if logerrors + @error "HTTP StatusError" method=method url=makeuri(uri) status=resp.status logtag=logtag + end + throw(StatusError(method, uri, resp)) + end + return resp end - # status error check - if status_exception && iserror(resp) - throw(StatusError(method, uri, resp)) - end - return resp end end end diff --git a/src/client/open.jl b/src/client/open.jl new file mode 100644 index 000000000..8a47a6d73 --- /dev/null +++ b/src/client/open.jl @@ -0,0 +1,157 @@ +function _open_stream(conn, req::Request, decompress, readtimeout) + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + stream = Stream{typeof(conn)}(decompress, http2, false) + stream.readtimeout = readtimeout + stream.bufferstream = Base.BufferStream() + stream.connection = conn + stream.request = req + stream.response = resp = Response(0, nothing, nothing, http2) + resp.request = req + opts = _make_request_options(stream, req; readtimeout=readtimeout) + aws_stream = AwsHTTP.http_connection_make_request(conn, opts) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + _activate_stream!(stream) + return stream +end + +function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; + headers=h, + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, + username=nothing, + password=nothing, + bearer=nothing, + query=nothing, + client::Union{Nothing, Client}=nothing, + basicauth::Bool=true, + proxy=DEFAULT_PROXY, + pool=nothing, + logerrors::Bool=false, + logtag=nothing, + observelayers::Bool=false, + retry_check=nothing, + retry_delays=nothing, + # redirect options + redirect=true, + redirect_limit=3, + redirect_method=nothing, + forwardheaders=true, + # cookie options + cookies=true, + cookiejar::CookieJar=COOKIEJAR, + # response options + decompress::Union{Nothing, Bool}=nothing, + status_exception::Bool=true, + readtimeout::Int=0, + modifier=nothing, + verbose=0, + # only client keywords in catch-all + kw...) + method_str = string(method) + headers = mkreqheaders(headers, copyheaders) + uri = parseuri(url, query) + context = observelayers ? Dict{Symbol, Any}() : nothing + context === nothing || _init_observations!(context) + count = 0 + while true + redirect_start = context === nothing ? 0.0 : time() + redirect_url = nothing + resp = nothing + proxy_kw = proxy_kwargs(proxy, scheme(uri)) + client_kw = (; kw...) + if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) + client_kw = merge(client_kw, (; max_connections=pool.max_connections)) + end + if !isempty(proxy_kw) + client_kw = merge(client_kw, proxy_kw) + end + authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) + apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth + reqclient = @something( + client, + pool === nothing ? + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) + )::Client + resp = with_connection(reqclient; context=context) do conn + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + path = resource(uri) + with_request(reqclient, method_str, path, headers, nothing, nothing, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + observelayers=observelayers, + context=context, + ) do req + if !http2 && + method_str in ("POST", "PUT", "PATCH") && + !hasheader(req.headers, "content-length") && + !hasheader(req.headers, "transfer-encoding") && + !hasheader(req.headers, "upgrade") + setheader(req.headers, "transfer-encoding", "chunked") + end + stream = _open_stream(conn, req, decompress, readtimeout) + stream_start = context === nothing ? 0.0 : time() + try + if redirect && issafe(method_str) + resp = startread(stream) + if (count < redirect_limit && isredirect(resp) && (location = getheader(resp.headers, "Location")) != "") + redirect_url = location + closeread(stream) + return resp + end + end + err = nothing + try + f(stream) + catch e + err = e + finally + closewrite(stream) + end + resp = closeread(stream) + err === nothing || throw(err) + return resp + finally + context === nothing || _record_layer!(context, :streamlayer, stream_start) + end + end + end + context === nothing || _record_layer!(context, :redirectlayer, redirect_start) + if redirect_url === nothing + if status_exception && iserror(resp) + if logerrors + @error "HTTP StatusError" method=method_str url=makeuri(uri) status=resp.status logtag=logtag + end + throw(StatusError(method_str, uri, resp)) + end + return resp + end + if count == redirect_limit + return resp + end + olduri = uri + newuri = resolvereference(makeuri(uri), redirect_url) + uri = parseuri(newuri, nothing) + method_str = newmethod(method_str, resp.status, redirect_method) + if forwardheaders + headers = filter(headers) do (header, _) + if headereq(String(header), "host") + return false + elseif any(x -> headereq(x, String(header)), SENSITIVE_HEADERS) && !isdomainorsubdomain(host(uri), host(olduri)) + return false + elseif method_str == "GET" && (headereq(String(header), "content-type") || headereq(String(header), "content-length")) + return false + else + return true + end + end + else + headers = Header[] + end + count += 1 + end +end diff --git a/src/client/redirects.jl b/src/client/redirects.jl index e260939ec..00015d225 100644 --- a/src/client/redirects.jl +++ b/src/client/redirects.jl @@ -38,14 +38,23 @@ function newmethod(request_method, response_status, redirect_method) return "GET" end -function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true) +function with_redirect(f, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true; context=nothing) if !redirect || redirect_limit == 0 # no redirecting return f(method, uri, headers, body) end count = 0 while true - ret = f(method, uri, headers, body) + if context === nothing + ret = f(method, uri, headers, body) + else + start_time = time() + try + ret = f(method, uri, headers, body) + finally + _record_layer!(context, :redirectlayer, start_time) + end + end resp = getresponse(ret) if (count == redirect_limit || !isredirect(resp) || (location = getheader(resp.headers, "Location")) == "") return ret @@ -54,7 +63,7 @@ function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, # follow redirect olduri = uri newuri = resolvereference(makeuri(uri), location) - uri = parseuri(newuri, nothing, allocator) + uri = parseuri(newuri, nothing) method = newmethod(method, resp.status, redirect_method) body = method == "GET" ? nothing : body if forwardheaders @@ -76,4 +85,7 @@ function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, count += 1 end @assert false "Unreachable!" -end \ No newline at end of file +end + +# compatibility: old callers that pass allocator as first arg +with_redirect(f, _allocator, method, uri, args...; kw...) = with_redirect(f, method, uri, args...; kw...) diff --git a/src/client/request.jl b/src/client/request.jl index 1fc573f94..9abfa7633 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -1,4 +1,7 @@ -const USER_AGENT = Ref{Union{String, Nothing}}("HTTP.jl/$VERSION") +const DEFAULT_USER_AGENT = let v = try Base.pkgversion(@__MODULE__) catch; nothing end + v === nothing ? "HTTP.jl/dev" : "HTTP.jl/$(v)" +end +const USER_AGENT = Ref{Union{String, Nothing}}(DEFAULT_USER_AGENT) """ setuseragent!(x::Union{String, Nothing}) @@ -12,23 +15,74 @@ function setuseragent!(x::Union{String, Nothing}) return end -function with_request(f::Function, client::Client, method, path, headers=nothing, body=nothing, chunkedbody=nothing, decompress::Union{Nothing, Bool}=nothing, userinfo=nothing, bearer=nothing, modifier=nothing, http2::Bool=false, cookies=true, cookiejar=COOKIEJAR, verbose=false) +function _default_host_header(settings::ClientSettings) + scheme = lowercase(settings.scheme) + default_port = (scheme == "https" || scheme == "wss") ? UInt32(443) : UInt32(80) + if settings.port == default_port + return settings.host + end + return string(settings.host, ":", settings.port) +end + +function with_request( + f::Function, + client::Client, + method, + path, + headers=nothing, + body=nothing, + chunkedbody=nothing, + decompress::Union{Nothing, Bool}=nothing, + userinfo=nothing, + bearer=nothing, + modifier=nothing, + http2::Bool=false, + cookies=true, + cookiejar=COOKIEJAR, + verbose=false; + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, + basicauth::Bool=true, + observelayers::Bool=false, + context=nothing, +) + if chunkedbody === nothing && body isa IO && !(body isa IOStream) && !(body isa Form) + chunkedbody = IOChunkedBody(body) + body = nothing + end + if chunkedbody === nothing && body !== nothing && !(body isa RequestBodyTypes) && Base.isiterable(typeof(body)) + chunkedbody = body + body = nothing + end # create request - req = Request(method, path, headers, nothing, http2, client.settings.allocator) + mutable_headers = (headers isa AbstractVector{<:Pair} && !copyheaders) ? headers : nothing + req_headers = mkreqheaders(headers, copyheaders) + req = Request(method, path, req_headers, nothing, http2; context=context) # add headers to request h = req.headers if http2 + authority = AwsHTTP.http_headers_get(h.hdrs, ":authority") + if authority === nothing + authority = AwsHTTP.http_headers_get(h.hdrs, "host") + end + if authority === nothing + authority = _default_host_header(client.settings) + end + AwsHTTP.http_headers_has(h.hdrs, "host") || setheader(h, "host", authority) setscheme(h, client.settings.scheme) - setauthority(h, client.settings.host) + setauthority(h, authority) else - setheader(h, "host", client.settings.host) + setheaderifabsent(h, "host", _default_host_header(client.settings)) end setheaderifabsent(h, "accept", "*/*") - setheaderifabsent(h, "user-agent", something(USER_AGENT[], "-")) + if USER_AGENT[] !== nothing + setheaderifabsent(h, "user-agent", USER_AGENT[]) + end if decompress === nothing || decompress setheaderifabsent(h, "accept-encoding", "gzip") end - if userinfo !== nothing + if basicauth && userinfo !== nothing && !isempty(userinfo) setheaderifabsent(h, "authorization", "Basic $(base64encode(unescapeuri(userinfo)))") end if bearer !== nothing @@ -37,10 +91,8 @@ function with_request(f::Function, client::Client, method, path, headers=nothing if !http2 && chunkedbody !== nothing setheaderifabsent(h, "transfer-encoding", "chunked") end - if headers !== nothing - for (k, v) in headers - addheader(h, k, v) - end + if detect_content_type && !hasheader(h, "content-type") && !(body isa Form) && isbytes(body) + setheader(h, "content-type", sniff(body)) end if cookies === true || (cookies isa AbstractDict && !isempty(cookies)) cookiestosend = Cookies.getcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path) @@ -62,17 +114,29 @@ function with_request(f::Function, client::Client, method, path, headers=nothing setinputstream!(req, body) end elseif body !== nothing - try - setinputstream!(req, body) - catch e - @error "Failed to set input stream" exception=(e, catch_backtrace()) - end + setinputstream!(req, body) + end + if canonicalize_headers && !http2 + canonicalizeheaders!(h) + end + if mutable_headers !== nothing + sync_headers!(mutable_headers, h) end # call user function verbose > 0 && print_request(stdout, req) - ret = f(req) - resp = getresponse(ret) - verbose > 0 && print_response(stdout, resp) - cookies === false || Cookies.setcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path, resp.headers) - return ret + start_time = time() + ret = nothing + try + ret = f(req) + resp = getresponse(ret) + if canonicalize_headers + canonicalizeheaders!(resp.headers) + end + verbose > 0 && print_response(stdout, resp) + cookies === false || Cookies.setcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path, resp.headers) + return ret + finally + req.context[:total_request_duration_ms] = (time() - start_time) * 1000 + observelayers && _record_layer!(req.context, :messagelayer, start_time) + end end diff --git a/src/client/retry.jl b/src/client/retry.jl index 94b89e7bd..ba4efad97 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -1,5 +1,5 @@ -struct DontRetry{T} <: Exception - error::T +struct DontRetry <: Exception + error::Exception end Base.showerror(io::IO, e::DontRetry) = print(io, e.error) @@ -11,79 +11,106 @@ end Base.showerror(io::IO, e::StreamError) = print(io, e.error) -const on_acquired = Ref{Ptr{Cvoid}}(C_NULL) +retryable_status(status::Integer) = status in (403, 408, 409, 429, 500, 502, 503, 504, 599) -function c_on_acquired(retry_strategy, error_code, retry_token, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - else - notify(fut, retry_token) +function isrecoverable(ex::Exception)::Bool + if ex isa StatusError + return retryable_status(ex.status) + elseif ex isa ConnectError + return isrecoverable(ex.error) + elseif ex isa TimeoutError + return true + elseif ex isa RequestError + return isrecoverable(ex.error) + elseif ex isa Base.EOFError || ex isa Base.IOError + return true + elseif ex isa ArgumentError + return ex.msg == "stream is closed or unusable" + elseif ex isa CompositeException + for child in ex.exceptions + child isa Exception || return false + isrecoverable(child) || return false + end + return true + elseif ex isa AWSError + return true end - return + return false end -const retry_ready = Ref{Ptr{Cvoid}}(C_NULL) +@inline function _default_retryable(method, err::Exception, retryable_body::Bool, retry_non_idempotent::Bool)::Bool + retryable_body || return false + method === nothing && return false + method_str = string(method) + if !(isidempotent(method_str) || retry_non_idempotent) + return false + end + if err isa StatusError + return retryable_status(err.status) + end + return isrecoverable(err) +end -function c_retry_ready(token, error_code::Cint, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) +function _normalize_retry_delays(retry_delays, max_retries::Int) + if retry_delays === nothing + return Base.ExponentialBackOff(n=max_retries, factor=3.0) + elseif retry_delays isa Number + return Iterators.repeated(retry_delays, max_retries) else - notify(fut, token) + return retry_delays end - return end -function with_retry_token(f::Function, client::Client) - # If max_retries is 0, we don't need to bother with any retrying - client.settings.max_retries == 0 && return f() - retry_partition = client.settings.retry_partition === nothing ? C_NULL : aws_byte_cursor_from_c_str(client.settings.retry_partition) - fut = Future{Ptr{aws_retry_token}}() - GC.@preserve fut begin - if aws_retry_strategy_acquire_retry_token(client.retry_strategy, retry_partition, on_acquired[], pointer_from_objref(fut), client.settings.retry_timeout_ms) != 0 - aws_throw_error() +function _retry_error_type(err) + if err isa StatusError + status = err.status + if status == 429 + return Reseau.Sockets.RetryErrorType.THROTTLING + elseif 500 <= status < 600 + return Reseau.Sockets.RetryErrorType.SERVER_ERROR + elseif 400 <= status < 500 + return Reseau.Sockets.RetryErrorType.CLIENT_ERROR + else + return Reseau.Sockets.RetryErrorType.TRANSIENT end - token = wait(fut) end + return Reseau.Sockets.RetryErrorType.TRANSIENT +end + +function with_retry_token( + f, + client::Client; + logerrors::Bool=false, + logtag=nothing, + method=nothing, + uri=nothing, + retry_check=nothing, + retry_delays=nothing, + retry_non_idempotent::Bool=false, + retryable_body::Bool=true, + req_ref=nothing, + context=nothing, +) + start_time = context !== nothing ? time() : 0.0 try - while true - try - ret = f() - aws_retry_token_record_success(token) - return ret - catch e - stream = nothing - if e isa StreamError - stream = e.stream - e = e.error - end - if e isa DontRetry - if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing - # for error responses, we need to commit the temporary body buffer - stream.response.body = readavailable(stream.bufferstream) - end - throw(e.error) - end - # note we assume any error that wasn't wrapped in DontRetry is retryable - retryReady = Future{Ptr{aws_retry_token}}() - GC.@preserve retryReady begin - if aws_retry_strategy_schedule_retry( - token, - #TODO: use different error types? - AWS_RETRY_ERROR_TYPE_TRANSIENT, - retry_ready[], - pointer_from_objref(retryReady) - ) != 0 - #TODO: do we need to commit a previous error body to the response here? - aws_throw_error() - end - #TODO: should we wrap this in try-catch to commit a previous stream bufferstream to the response body? - token = wait(retryReady) + return f() + catch e + err = e isa StreamError ? (e::StreamError).error : e + if err isa DontRetry + if e isa StreamError + stream = (e::StreamError).stream::Stream + if iserror(stream.response.status) && stream.bufferstream !== nothing + # For error responses, we need to commit the temporary body buffer. + stream.response.body = readavailable(stream.bufferstream) end end + throw(err.error) + end + if logerrors + @error "HTTP request error" exception=(err, catch_backtrace()) method=method url=uri logtag=logtag end + rethrow() finally - aws_retry_token_release(token) + context !== nothing && _record_layer!(context, :retrylayer, start_time) end -end \ No newline at end of file +end diff --git a/src/client/stream.jl b/src/client/stream.jl index 90dfe5e96..857136bf9 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -1,222 +1,927 @@ -const on_response_headers = Ref{Ptr{Cvoid}}(C_NULL) +export Stream, closebody, isaborted, readall!, setstatus -function c_on_response_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - headers = stream.response.headers - addheaders(headers, header_array, num_headers) - return Cint(0) +writebuf(body, maxsize=length(body) == 0 ? typemax(Int64) : length(body)) = Base.GenericIOBuffer{AbstractVector{UInt8}}(body, true, true, true, false, maxsize) + +mutable struct Stream{T} <: IO + decompress::Union{Nothing, Bool} + http2::Bool + server_side::Bool + status::Int + fut::Future{Nothing} + chunk::Union{Nothing, InputStream} + final_chunk_written::Bool + bufferstream::Union{Nothing, Base.BufferStream} + gzipstream::Union{Nothing, CodecZlib.GzipDecompressorStream} + responsebuf::Union{Nothing, IOBuffer} + headers_ready::Threads.Event + activated::Bool + write_started::Bool + read_started::Bool + response_started::Bool + handler_started::Bool + ignore_writes::Bool + readtimeout::Int + on_complete::Union{Nothing, Function} + released::Bool + # remaining fields are initially undefined + aws_stream::Any # H1Stream or H2Stream from AwsHTTP + connection::T + response::Response + request::Request + Stream{T}(decompress, http2, server_side::Bool=false) where {T} = new{T}( + decompress, + http2, + server_side, + 0, + Future{Nothing}(), + nothing, + false, + nothing, + nothing, + nothing, + Threads.Event(), + false, + false, + false, + false, + false, + false, + 0, + nothing, + false, + ) end -writebuf(body, maxsize=length(body) == 0 ? typemax(Int64) : length(body)) = Base.GenericIOBuffer{AbstractVector{UInt8}}(body, true, true, true, false, maxsize) +# compatibility: 4-arg version for callers that still pass allocator +Stream{T}(allocator, decompress, http2, server_side::Bool=false) where {T} = + Stream{T}(decompress, http2, server_side) + +Base.hash(s::Stream, h::UInt) = hash(objectid(s), h) + +getrequest(s::Stream) = s.request + +function _with_http2_connection(f::Function, stream::Stream) + !isdefined(stream, :aws_stream) && throw(ArgumentError("HTTP stream is not initialized")) + conn = stream.aws_stream.owning_connection + return f(conn) +end -const on_response_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) +http2_ping(stream::Stream; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), stream) +http2_change_settings(stream::Stream, settings) = _with_http2_connection(conn -> http2_change_settings(conn, settings), stream) +http2_local_settings(stream::Stream) = _with_http2_connection(http2_local_settings, stream) +http2_remote_settings(stream::Stream) = _with_http2_connection(http2_remote_settings, stream) +http2_send_goaway(stream::Stream, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = + _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), stream) +http2_get_sent_goaway(stream::Stream) = _with_http2_connection(http2_get_sent_goaway, stream) +http2_get_received_goaway(stream::Stream) = _with_http2_connection(http2_get_received_goaway, stream) +http2_update_window(stream::Stream, increment::Integer) = + _with_http2_connection(conn -> http2_update_window(conn, increment), stream) -function c_on_response_header_block_done(aws_stream_ptr, header_block, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if aws_http_stream_get_incoming_response_status(aws_stream_ptr, FieldRef(stream, :status)) != 0 - return aws_raise_error(aws_last_error()) +function update_window(stream::Stream, increment::Integer) + !isdefined(stream, :aws_stream) && throw(ArgumentError("HTTP stream is not initialized")) + increment < 0 && throw(ArgumentError("increment must be >= 0")) + if stream.http2 + increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) + AwsHTTP.h2_stream_update_window!(stream.aws_stream, UInt32(increment)) + else + increment > typemax(UInt64) && throw(ArgumentError("increment too large")) + AwsHTTP.http_stream_update_window(stream.aws_stream, UInt64(increment)) end - stream.response.status = stream.status - # if this is the end of the main header block, prepare our response body to be written to, otherwise return - if header_block != AWS_HTTP_HEADER_BLOCK_MAIN - return Cint(0) + return +end + +function writechunk(s::Stream, chunk::RequestBodyTypes) + if !s.server_side && !(chunk isa AbstractString && isempty(chunk)) + @assert (isdefined(s, :response) && + isdefined(s.response, :request) && + s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" end - if stream.decompress !== false - val = getheader(stream.response.headers, "content-encoding") - stream.decompress = val !== nothing && val == "gzip" + s.chunk = InputStream() + is = s.chunk + if chunk isa AbstractVector{UInt8} + is.bodyref = chunk + is.bodylen = length(chunk) + elseif chunk isa AbstractString + is.bodyref = chunk + is.bodylen = sizeof(chunk) + else + is.bodyref = chunk + is.bodylen = nbytes(chunk) === nothing ? 0 : nbytes(chunk) end - return Cint(0) + fut = Future{Nothing}() + if s.http2 + data = if chunk isa AbstractString + Vector{UInt8}(codeunits(chunk)) + elseif chunk isa AbstractVector{UInt8} + chunk + else + UInt8[] + end + is_final = isempty(data) + AwsHTTP.h2_stream_write_data!(s.aws_stream, data; + end_stream=is_final, + on_complete=(err, ud) -> begin + if err != 0 + notify(fut, CapturedException(aws_error(err), Base.backtrace())) + else + notify(fut, nothing) + end + end, + ) != 0 && aws_throw_error() + else + data = if chunk isa AbstractString + IOBuffer(codeunits(chunk)) + elseif chunk isa AbstractVector{UInt8} + IOBuffer(chunk) + else + IOBuffer(UInt8[]) + end + h1chunk = AwsHTTP.h1_chunk_new(data, is.bodylen; + on_complete=(stream, err, ud) -> begin + if err != 0 + notify(fut, CapturedException(aws_error(err), Base.backtrace())) + else + notify(fut, nothing) + end + end, + ) + AwsHTTP.h1_stream_write_chunk!(s.aws_stream, h1chunk) != 0 && aws_throw_error() + _h1_flush_outgoing!(s) + end + wait(fut) + if isdefined(s, :response) && s.response !== nothing + if s.server_side + s.response.metrics.response_body_length += is.bodylen + else + s.response.metrics.request_body_length += is.bodylen + end + end + return is.bodylen end -const on_response_body = Ref{Ptr{Cvoid}}(C_NULL) +function _ensure_response!(s::Stream) + if !isdefined(s, :response) || s.response === nothing + s.response = Response(200, nothing, nothing, s.http2) + end + return s.response +end -function c_on_response_body(aws_stream_ptr, data::Ptr{aws_byte_cursor}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - bc = unsafe_load(data) - stream.response.metrics.response_body_length += bc.len - if stream.decompress - if stream.gzipstream === nothing - stream.bufferstream = b = Base.BufferStream() - stream.gzipstream = g = CodecZlib.GzipDecompressorStream(b) - unsafe_write(g, bc.ptr, bc.len) - else - unsafe_write(stream.gzipstream, bc.ptr, bc.len) +# Drive H1 outgoing encoder and send encoded bytes through the channel pipeline. +# Must be called after operations that produce outgoing data (send_response, write_chunk, activate). +# For H1 only; H2 encoding is handled differently. +function _h1_flush_outgoing!(s::Stream) + !isdefined(s, :aws_stream) && return + h1conn = s.aws_stream.owning_connection + slot = h1conn.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + if !Reseau.Sockets.channel_thread_is_callers_thread(channel) + fut = Future{Nothing}() + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return notify(fut, nothing) + try + _h1_flush_outgoing!(s) + notify(fut, nothing) + catch e + notify(fut, CapturedException(e, catch_backtrace())) + end + return nothing + end, nothing, "http_h1_flush_outgoing") + Reseau.Sockets.channel_schedule_task_now!(channel, task) + wait(fut) + return + end + while true + status, encoded = AwsHTTP.h1_connection_encode_outgoing!(h1conn) + status != AwsHTTP.OP_SUCCESS && throw(AWSError("H1 encoding failed")) + isempty(encoded) && break + msg = Reseau.Sockets.IoMessage(length(encoded)) + buf = msg.message_data + @inbounds for i in 1:length(encoded) + buf.mem[i] = encoded[i] end + buf.len = Csize_t(length(encoded)) + try + Reseau.Sockets.channel_slot_send_message(slot, msg, Reseau.Sockets.ChannelDirection.WRITE) + catch e + e isa Reseau.ReseauError || rethrow() + throw(AWSError("channel slot send failed")) + end + end + return +end + +function _send_response!(s::Stream) + if s.response_started + return s.response + end + resp = _ensure_response!(s) + msg = getfield(resp, :msg) + if s.http2 && AwsHTTP.http_message_get_protocol_version(msg) != AwsHTTP.HttpVersion.HTTP_2 + converted = AwsHTTP.http2_message_new_from_http1(msg) + converted === nothing && aws_throw_error() + setfield!(resp, :msg, converted) + msg = converted + end + if s.http2 + # H2 sends response via H2Stream API + conn = s.aws_stream.owning_connection + AwsHTTP.h2_stream_send_response!(s.aws_stream, conn, msg) != 0 && aws_throw_error() else - if stream.bufferstream === nothing - stream.bufferstream = b = Base.BufferStream() - unsafe_write(b, bc.ptr, bc.len) - else - unsafe_write(stream.bufferstream, bc.ptr, bc.len) + AwsHTTP.h1_stream_send_response!(s.aws_stream, msg) != 0 && aws_throw_error() + _h1_flush_outgoing!(s) + end + s.response_started = true + return resp +end + +function _server_startwrite(s::Stream) + if s.write_started + return + end + resp = _ensure_response!(s) + if s.request.method == "HEAD" + s.ignore_writes = true + _head_response!(resp) + end + if s.http2 + if !s.ignore_writes && resp.inputstream === nothing && hasheader(resp.headers, "content-length") + removeheader(resp.headers, "content-length") end + if !s.response_started + _send_response!(s) + end + s.write_started = true + return + end + if !s.ignore_writes && + !hasheader(resp.headers, "transfer-encoding") && + !hasheader(resp.headers, "upgrade") + hasheader(resp.headers, "content-length") && removeheader(resp.headers, "content-length") + setheader(resp.headers, "transfer-encoding", "chunked") end - return Cint(0) + _send_response!(s) + s.write_started = true + return end -const on_metrics = Ref{Ptr{Cvoid}}(C_NULL) +function _server_closewrite(s::Stream) + if s.final_chunk_written + return + end + resp = _ensure_response!(s) + if s.http2 + if !s.response_started + if s.ignore_writes + setinputstream!(resp, nothing) + end + _send_response!(s) + end + if s.ignore_writes + s.final_chunk_written = true + return + end + if resp.trailers !== nothing + AwsHTTP.h2_stream_add_trailing_headers!(s.aws_stream, resp.trailers.hdrs) != 0 && aws_throw_error() + end + writechunk(s, "") + s.final_chunk_written = true + return + end + if !s.response_started + if !s.ignore_writes && + !hasheader(resp.headers, "transfer-encoding") && + !hasheader(resp.headers, "upgrade") + hasheader(resp.headers, "content-length") && removeheader(resp.headers, "content-length") + setheader(resp.headers, "transfer-encoding", "chunked") + end + _send_response!(s) + end + if s.ignore_writes + s.final_chunk_written = true + return + end + if hasheader(resp.headers, "upgrade") + s.final_chunk_written = true + return + end + if resp.trailers !== nothing + AwsHTTP.h1_stream_add_chunked_trailer!(s.aws_stream, resp.trailers.hdrs) != 0 && aws_throw_error() + end + writechunk(s, "") + s.final_chunk_written = true + return +end -function c_on_metrics(aws_stream_ptr, metrics::Ptr{aws_http_stream_metrics}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - m = unsafe_load(metrics) - if m.send_start_timestamp_ns != -1 - stream.response.metrics.stream_metrics = m +function _activate_stream!(s::Stream) + if s.server_side + s.activated = true + return + end + if !s.activated + if s.http2 + conn = s.aws_stream.owning_connection + status, _ = AwsHTTP.h2_stream_activate!(s.aws_stream, conn) + status != 0 && aws_throw_error() + s.activated = true + AwsHTTP._h2_connection_flush_outgoing!(conn) + else + AwsHTTP.h1_stream_activate!(s.aws_stream) != 0 && aws_throw_error() + s.activated = true + _h1_flush_outgoing!(s) + end end return end -const on_complete = Ref{Ptr{Cvoid}}(C_NULL) +function startwrite(s::Stream) + if s.server_side + return _server_startwrite(s) + end + if s.write_started + return + end + if !s.http2 && + !hasheader(s.request.headers, "content-length") && + !hasheader(s.request.headers, "transfer-encoding") && + !hasheader(s.request.headers, "upgrade") + setheader(s.request.headers, "transfer-encoding", "chunked") + end + _activate_stream!(s) + s.write_started = true + return +end -function c_on_complete(aws_stream_ptr, error_code, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if stream.gzipstream !== nothing - close(stream.gzipstream) +function closewrite(s::Stream) + if s.server_side + return _server_closewrite(s) + end + if s.final_chunk_written + return end - if stream.bufferstream !== nothing - close(stream.bufferstream) + if s.http2 + _activate_stream!(s) + writechunk(s, "") + s.final_chunk_written = true + return end - if error_code != 0 - notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) + if s.write_started + writechunk(s, "") + s.final_chunk_written = true + elseif hasheader(s.request.headers, "transfer-encoding") + _activate_stream!(s) + writechunk(s, "") + s.final_chunk_written = true else - notify(stream.fut, nothing) + _activate_stream!(s) end return end -const on_destroy = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_destroy(stream) +function closebody(s::Stream) + closewrite(s) return end -if !@isdefined aws_websocket_server_upgrade_options - const aws_websocket_server_upgrade_options = Ptr{Cvoid} +function readall!(s::Stream, buf::Base.GenericIOBuffer=PipeBuffer()) + total = 0 + while !eof(s) + bytes = readavailable(s) + total += length(bytes) + write(buf, bytes) + end + return total end -mutable struct Stream{T} - allocator::Ptr{aws_allocator} - decompress::Union{Nothing, Bool} - http2::Bool - status::Cint # used as a ref - fut::Future{Nothing} - chunk::Union{Nothing, InputStream} - final_chunk_written::Bool - bufferstream::Union{Nothing, Base.BufferStream} - gzipstream::Union{Nothing, CodecZlib.GzipDecompressorStream} - # remaining fields are initially undefined - ptr::Ptr{aws_http_stream} - connection::T # Connection{F, S} (in servers.jl) - request_options::aws_http_make_request_options - response::Response - method::aws_byte_cursor - path::aws_byte_cursor - request_handler_options::aws_http_request_handler_options - request::Request - http2_stream_write_data_options::aws_http2_stream_write_data_options - chunk_options::aws_http1_chunk_options - websocket_options::aws_websocket_server_upgrade_options - Stream{T}(allocator, decompress, http2) where {T} = new{T}(allocator, decompress, http2, 0, Future{Nothing}(), nothing, false, nothing, nothing) +function isaborted(s::Stream) + s.server_side && return false + if !isdefined(s, :response) || s.response === nothing + return false + end + resp = s.response + return iserror(resp) && hasheader(resp, "Connection", "close") +end + +function startread(s::Stream) + if s.server_side + if s.read_started + return s.request + end + wait(s.headers_ready) + s.read_started = true + return s.request + end + if s.read_started + return s.response + end + _activate_stream!(s) + s.http2 && !s.final_chunk_written && closewrite(s) + wait(s.headers_ready) + s.read_started = true + return s.response +end + +function Base.readavailable(s::Stream, n::Int=typemax(Int)) + startread(s) + if s.bufferstream === nothing + return UInt8[] + end + return _readavailable(s.bufferstream) end -Base.hash(s::Stream, h::UInt) = hash(s.ptr, h) +function Base.read(s::Stream, n::Integer) + startread(s) + s.bufferstream === nothing && return UInt8[] + return read(s.bufferstream, n) +end -const on_stream_write_on_complete = Ref{Ptr{Cvoid}}(C_NULL) +function Base.read(s::Stream) + startread(s) + s.bufferstream === nothing && return UInt8[] + return read(s.bufferstream) +end -function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, nothing) +function Base.read(s::Stream, ::Type{UInt8}) + data = Base.read(s, 1) + isempty(data) && throw(EOFError()) + return data[1] +end + +function Base.eof(s::Stream) + startread(s) + s.bufferstream === nothing && return true + return eof(s.bufferstream) +end + +function Base.unsafe_write(s::Stream, p::Ptr{UInt8}, n::UInt) + n == 0 && return 0 + buf = Vector{UInt8}(undef, n) + GC.@preserve buf unsafe_copyto!(pointer(buf), p, n) + Base.write(s, buf) + return n +end + +function Base.write(s::Stream, data::AbstractVector{UInt8}) + startwrite(s) + if s.server_side + if s.ignore_writes + return length(data) + elseif s.http2 + writechunk(s, data) + return length(data) + end + end + writechunk(s, data) + return length(data) +end + +function Base.write(s::Stream, data::StridedVector{UInt8}) + startwrite(s) + if s.server_side + if s.ignore_writes + return length(data) + elseif s.http2 + writechunk(s, data) + return length(data) + end + end + writechunk(s, data) + return length(data) +end + +function Base.write(s::Stream, data::Union{String, SubString{String}}) + startwrite(s) + if s.server_side + if s.ignore_writes + return sizeof(data) + elseif s.http2 + writechunk(s, data) + return sizeof(data) + end + end + writechunk(s, data) + return sizeof(data) +end + +function Base.write(s::Stream, data::AbstractString) + return Base.write(s, String(data)) +end + +function Base.write(s::Stream, b::UInt8) + startwrite(s) + if s.server_side + if s.ignore_writes + return 1 + elseif s.http2 + writechunk(s, UInt8[b]) + return 1 + end + end + writechunk(s, UInt8[b]) + return 1 +end + +function closeread(s::Stream) + startread(s) + try + try + wait(s.fut) + catch e + e isa HTTPError && rethrow() + if !s.server_side && isdefined(s, :request) && s.request !== nothing + throw(RequestError(s.request, e)) + end + rethrow() + end + finally + s.released = true + end + return s.response +end + +function Base.close(s::Stream) + try + closewrite(s) + finally + closeread(s) end return end -function writechunk(s::Stream, chunk::RequestBodyTypes) - @assert (isdefined(s, :response) && - isdefined(s.response, :request) && - s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" - s.chunk = InputStream(s.allocator, chunk) - fut = Future{Nothing}() +function setstatus(s::Stream, status::Integer) + s.server_side || error("setstatus is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + resp.status = status + return +end + +function setheader(s::Stream, v) + s.server_side || error("setheader is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + setheader(resp, v) + return +end + +function setheader(s::Stream, k, v) + return setheader(s, k => v) +end + +function setheaderifabsent(s::Stream, k, v) + s.server_side || error("setheaderifabsent is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + setheaderifabsent(resp.headers, k, v) + return +end + +function addtrailer(s::Stream, headers::Headers) + !isdefined(s, :aws_stream) && error("stream is not initialized") + if s.server_side + resp = _ensure_response!(s) + if resp.trailers === nothing + resp.trailers = headers + elseif resp.trailers !== headers + for h in headers + addheader(resp.trailers, h) + end + end + return + end if s.http2 - s.http2_stream_write_data_options = aws_http2_stream_write_data_options( - s.chunk.ptr, - chunk == "", - on_stream_write_on_complete[], - pointer_from_objref(fut) - ) - aws_http2_stream_write_data(s.ptr, FieldRef(s, :http2_stream_write_data_options)) != 0 && aws_throw_error() + AwsHTTP.h2_stream_add_trailing_headers!(s.aws_stream, headers.hdrs) != 0 && aws_throw_error() else - s.chunk_options = aws_http1_chunk_options( - s.chunk.ptr, - s.chunk.bodylen, - C_NULL, - 0, - on_stream_write_on_complete[], - pointer_from_objref(fut) - ) - aws_http1_stream_write_chunk(s.ptr, FieldRef(s, :chunk_options)) != 0 && aws_throw_error() + AwsHTTP.h1_stream_add_chunked_trailer!(s.aws_stream, headers.hdrs) != 0 && aws_throw_error() end - wait(fut) - return s.chunk.bodylen + return +end + +function addtrailer(s::Stream, h::Pair) + trailers = Headers() + addheader(trailers, String(h.first), String(h.second)) + return addtrailer(s, trailers) +end + +function addtrailer(s::Stream, h::AbstractVector{<:Pair}) + trailers = Headers() + for (k, v) in h + addheader(trailers, String(k), String(v)) + end + return addtrailer(s, trailers) +end + +# ─── Callback builders ─── +# These create the closure callbacks for AwsHTTP stream options. +# Each closure captures the HTTP.Stream and manipulates it directly +# when the AwsHTTP library fires the callback. + +function _on_response_headers(stream::Stream) + return (aws_stream, header_block, headers_vec, user_data) -> begin + if header_block == AwsHTTP.HttpHeaderBlock.TRAILING + trailers = stream.response.trailers + if trailers === nothing + trailers = Headers() + stream.response.trailers = trailers + end + for h in headers_vec + addheader(trailers, h) + end + else + hdrs = stream.response.headers + for h in headers_vec + addheader(hdrs, h) + end + end + return AwsHTTP.OP_SUCCESS + end +end + +function _on_response_header_block_done(stream::Stream) + return (aws_stream, header_block, user_data) -> begin + stream.status = aws_stream.response_status + stream.response.status = stream.status + if header_block != AwsHTTP.HttpHeaderBlock.MAIN + return AwsHTTP.OP_SUCCESS + end + if stream.decompress !== false + val = getheader(stream.response.headers, "content-encoding") + stream.decompress = val !== nothing && val == "gzip" + end + notify(stream.headers_ready) + return AwsHTTP.OP_SUCCESS + end +end + +function _on_response_body(stream::Stream) + return (aws_stream, data::AbstractVector{UInt8}, user_data) -> begin + stream.response.metrics.response_body_length += length(data) + if stream.decompress + if stream.gzipstream === nothing + stream.bufferstream = b = Base.BufferStream() + stream.gzipstream = g = CodecZlib.GzipDecompressorStream(b) + write(g, data) + else + write(stream.gzipstream, data) + end + else + if stream.bufferstream === nothing + stream.bufferstream = b = Base.BufferStream() + write(b, data) + else + write(stream.bufferstream, data) + end + end + return AwsHTTP.OP_SUCCESS + end +end + +function _on_metrics(stream::Stream) + return (aws_stream, metrics, user_data) -> begin + if metrics.send_start_timestamp_ns != -1 + stream.response.metrics.stream_metrics = metrics + end + return nothing + end +end + +function _on_complete(stream::Stream) + return (aws_stream, error_code, user_data) -> begin + if stream.gzipstream !== nothing + close(stream.gzipstream) + end + if stream.bufferstream !== nothing + close(stream.bufferstream) + end + if error_code != 0 && !stream.http2 && isdefined(stream, :connection) && stream.connection !== nothing + AwsHTTP.http_connection_close(stream.connection) + end + if error_code != 0 + if error_code == AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT && stream.readtimeout > 0 + notify(stream.fut, TimeoutError(stream.readtimeout)) + else + notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) + end + else + notify(stream.fut, nothing) + end + notify(stream.headers_ready) + stream.released = true + return nothing + end +end + +function _make_request_options(stream::Stream, req::Request; chunkedbody=nothing, readtimeout=0) + msg = getfield(req, :msg) + return AwsHTTP.HttpMakeRequestOptions(; + request=msg, + on_response_headers=_on_response_headers(stream), + on_response_header_block_done=_on_response_header_block_done(stream), + on_response_body=_on_response_body(stream), + on_metrics=_on_metrics(stream), + on_complete=_on_complete(stream), + http2_use_manual_data_writes=(chunkedbody !== nothing), + response_first_byte_timeout_ms=UInt64(readtimeout * 1000), + ) end -function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) - stream = Stream{Nothing}(allocator, decompress, http2) +# ─── with_stream_manager: H2 stream manager path ─── + +function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout; context=nothing) + start_time = context !== nothing ? time() : 0.0 + stream = Stream{Nothing}(decompress, true, false) + stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end - GC.@preserve stream begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - http2 && chunkedbody !== nothing, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) - stream_ptr == C_NULL && aws_throw_error() - stream.ptr = stream_ptr - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - stream.response = resp = Response(0, nothing, nothing, http2, allocator) - resp.metrics = RequestMetrics() - resp.request = req - try - aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() - # write chunked body if provided - if chunkedbody !== nothing - foreach(chunk -> writechunk(stream, chunk), chunkedbody) - # write final chunk - writechunk(stream, "") - end - if on_stream_response_body !== nothing + stream.response = resp = Response(0, nothing, nothing, true) + resp.metrics = RequestMetrics() + resp.request = req + resp.metrics.request_body_length = bodylen(req) + request_options = _make_request_options(stream, req; chunkedbody=chunkedbody, readtimeout=readtimeout) + + # Acquire a connection from the H2 stream manager + acquire_ch = Base.Channel{Any}(1) + AwsHTTP.http2_stream_manager_acquire_stream(client.http2_stream_manager; + callback=(conn_or_nothing, error_code, ud) -> begin + if error_code != 0 || conn_or_nothing === nothing + put!(acquire_ch, error_code != 0 ? error_code : AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED) + else + put!(acquire_ch, conn_or_nothing) + end + end, + ) + acquired = take!(acquire_ch) + if acquired isa Integer + throw(CapturedException(aws_error(acquired), Base.backtrace())) + end + connection = acquired + + # Create stream on the acquired connection + aws_stream = AwsHTTP.http_connection_make_request(connection, request_options) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + timeout_task = nothing + if readtimeout > 0 + timeout_task = errormonitor(Threads.@spawn begin + _task_sleep_s(readtimeout) + (@atomic stream.fut.set) != 0 && return + notify(stream.fut, TimeoutError(readtimeout)) + if isdefined(stream, :aws_stream) + AwsHTTP.h2_stream_cancel!(stream.aws_stream) + end + end) + end + + # Activate stream + _activate_stream!(stream) + + try + # Write chunked body if provided + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) + end try - while !eof(stream.bufferstream) - on_stream_response_body(resp, _readavailable(stream.bufferstream)) - end + wait(stream.fut) catch e - rethrow(DontRetry(e)) + e isa HTTPError && rethrow() + throw(RequestError(req, e)) end - else + catch e + rethrow(DontRetry(e)) + end + else + try wait(stream.fut) - if stream.bufferstream !== nothing - resp.body = _readavailable(stream.bufferstream) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] + end + end + return resp + finally + timeout_task = nothing + stream.released = true + AwsHTTP.http2_stream_manager_release_stream(client.http2_stream_manager, connection) + if context !== nothing + _record_layer!(context, :streamlayer, start_time) + end + end +end + +# compatibility: 8-arg version for callers that still pass allocator +with_stream_manager(client, req, chunkedbody, on_stream_response_body, decompress, readtimeout, _allocator; context=nothing) = + with_stream_manager(client, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) + +# ─── with_stream: connection manager path ─── + +function with_stream(conn, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=nothing) + start_time = context !== nothing ? time() : 0.0 + stream = Stream{typeof(conn)}(decompress, http2, false) + stream.readtimeout = readtimeout + if on_stream_response_body !== nothing + stream.bufferstream = Base.BufferStream() + end + stream.connection = conn + + request_options = _make_request_options(stream, req; + chunkedbody=(http2 ? chunkedbody : nothing), + readtimeout=readtimeout) + + aws_stream = AwsHTTP.http_connection_make_request(conn, request_options) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + # Check actual connection version (may have been upgraded) + actual_http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + stream.http2 = actual_http2 + stream.response = resp = Response(0, nothing, nothing, actual_http2) + resp.metrics = RequestMetrics() + resp.request = req + resp.metrics.request_body_length = bodylen(req) + timeout_task = nothing + if readtimeout > 0 + timeout_task = errormonitor(Threads.@spawn begin + _task_sleep_s(readtimeout) + (@atomic stream.fut.set) != 0 && return + if !stream.http2 && isdefined(stream, :connection) + conn = stream.connection + if conn !== nothing && conn.slot !== nothing && conn.slot.channel !== nothing + Reseau.Sockets.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) + elseif conn !== nothing + AwsHTTP.http_connection_close(conn) + end + end + notify(stream.fut, TimeoutError(readtimeout)) + if isdefined(stream, :aws_stream) + if stream.http2 + AwsHTTP.h2_stream_cancel!(stream.aws_stream) else - resp.body = UInt8[] + AwsHTTP.http_stream_cancel(stream.aws_stream) + end + end + end) + end + + try + _activate_stream!(stream) + # Write chunked body if provided + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) end + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end + catch e + rethrow(DontRetry(e)) + end + else + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] end - return resp - finally - aws_http_stream_release(stream_ptr) end - end # GC.@preserve + return resp + finally + timeout_task = nothing + stream.released = true + if context !== nothing + _record_layer!(context, :streamlayer, start_time) + end + end end +# compatibility: 9-arg version for callers that still pass allocator +with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, _allocator; context=nothing) = + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + # can be removed once https://github.com/JuliaLang/julia/pull/57211 is fully released function _readavailable(this::Base.BufferStream) bytes = lock(this.cond) do @@ -226,4 +931,4 @@ function _readavailable(this::Base.BufferStream) take!(buf) end return bytes -end \ No newline at end of file +end diff --git a/src/cookies.jl b/src/cookies.jl index 7cc09e6cf..c2fc01f73 100644 --- a/src/cookies.jl +++ b/src/cookies.jl @@ -33,8 +33,8 @@ module Cookies export Cookie, CookieJar, cookies, stringify, getcookies!, setcookies!, addcookie! import Base: == -using Dates, Sockets -import ..addheader, ..headereq, ..Headers, ..Request, ..Response +using Dates +import ..addheader, ..headereq, ..Headers, ..Request, ..Response, .._header_name, .._header_value @enum SameSite SameSiteDefaultMode=1 SameSiteLaxMode SameSiteStrictMode SameSiteNoneMode @@ -211,8 +211,9 @@ const RFC1123GMTFormat = gmtformat(Dates.RFC1123Format) function readsetcookies(headers::Headers) result = Cookie[] for h in headers - headereq(h.name, "Set-Cookie") || continue - line = h.value + name = _header_name(h) + headereq(name, "Set-Cookie") || continue + line = _header_value(h) if length(line) == 0 continue end @@ -300,14 +301,18 @@ function readsetcookies(headers::Headers) return result end -function isIP(host) - try - Base.parse(IPAddr, host) - return true - catch e - isa(e, ArgumentError) && return false - rethrow(e) +function isIP(host::AbstractString)::Bool + # Minimal IPv4 literal check (avoid `Sockets` stdlib dependency). + parts = split(host, '.'; keepempty = false) + length(parts) == 4 || return false + for p in parts + isempty(p) && return false + all(isdigit, p) || return false + v = tryparse(Int, p) + v === nothing && return false + (0 <= v <= 255) || return false end + return true end """ @@ -326,8 +331,9 @@ cookies(r::Request) = readcookies(r.headers, "") function readcookies(headers::Headers, filter::String="") result = Cookie[] for h in headers - headereq(h.name, "Cookie") || continue - line = h.value + name = _header_name(h) + headereq(name, "Cookie") || continue + line = _header_value(h) for part in split(strip(line), ';'; keepempty=false) part = strip(part) length(part) <= 1 && continue diff --git a/src/download.jl b/src/download.jl new file mode 100644 index 000000000..841dc38c0 --- /dev/null +++ b/src/download.jl @@ -0,0 +1,122 @@ +function safer_joinpath(basepart, parts...) + explain = "Possible directory traversal attack detected." + for part in parts + occursin("..", part) && throw(DomainError(part, "contains \"..\". $explain")) + startswith(part, '/') && throw(DomainError(part, "begins with \"/\". $explain")) + end + return joinpath(basepart, parts...) +end + +function try_get_filename_from_headers(hdrs) + for content_disp in hdrs + filename_part = match(r"filename\\s*=\\s*(.*)", content_disp) + if filename_part !== nothing + filename = filename_part[1] + quoted_filename = match(r"\\\"(.*)\\\"", filename) + if quoted_filename !== nothing + filename = unescape_string(quoted_filename[1]) + end + return filename == "" ? nothing : filename + end + end + return nothing +end + +function try_get_filename_from_request(req::Request) + function file_from_target(t) + (t == "" || t == "/") && return nothing + f = basename(URI(t).path) + return f == "" ? file_from_target(dirname(t)) : f + end + return file_from_target(req.path) +end + +determine_file(::Nothing, resp, hdrs) = determine_file(tempdir(), resp, hdrs) + +function determine_file(path, resp, hdrs) + if isdir(path) + filename = something( + try_get_filename_from_headers(hdrs), + resp.request === nothing ? nothing : try_get_filename_from_request(resp.request), + basename(tempname(; cleanup = false)) + ) + return safer_joinpath(path, filename) + end + return path +end + +""" + download(url, [local_path], [headers]; update_period=1, kw...) + +Download a URL to a local file, returning the filename. If `local_path` is not +provided, the file is saved in a temporary directory. If `local_path` is a +directory, the filename is determined from response headers or request target. + +`update_period` controls progress reporting in seconds (set to `Inf` to disable). +Additional keyword arguments are forwarded to `HTTP.open`. +""" +function download(url::AbstractString, local_path=nothing, headers=Header[]; update_period=1, kw...) + format_progress(x) = round(x, digits=4) + format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(round(Int, max(x, 0))) + format_seconds(x) = "$(round(x; digits=2)) s" + format_bytes_per_second(x) = format_bytes(x) * "/s" + + @debug "downloading $url" + local file + hdrs = String[] + HTTP.open("GET", url, headers; kw...) do stream + resp = startread(stream) + content_disp = header(resp, "Content-Disposition") + !isempty(content_disp) && push!(hdrs, content_disp) + eof(stream) && return + + file = determine_file(local_path, resp, hdrs) + total_bytes = parse(Float64, header(resp, "Content-Length", "NaN")) + downloaded_bytes = 0 + start_time = now() + prev_time = now() + + if header(resp, "Content-Encoding") == "gzip" + total_bytes = NaN + end + + function report_callback() + prev_time = now() + taken_time = (prev_time - start_time).value / 1000 + average_speed = taken_time > 0 ? downloaded_bytes / taken_time : NaN + remaining_bytes = total_bytes - downloaded_bytes + remaining_bytes = isfinite(remaining_bytes) && remaining_bytes < 0 ? 0 : remaining_bytes + remaining_time = average_speed > 0 ? remaining_bytes / average_speed : NaN + completion_progress = isfinite(total_bytes) && total_bytes > 0 ? downloaded_bytes / total_bytes : NaN + completion_progress = isfinite(completion_progress) ? clamp(completion_progress, 0, 1) : completion_progress + @info("Downloading", + source=url, + dest=file, + progress=completion_progress |> format_progress, + time_taken=taken_time |> format_seconds, + time_remaining=remaining_time |> format_seconds, + average_speed=average_speed |> format_bytes_per_second, + downloaded=downloaded_bytes |> format_bytes, + remaining=remaining_bytes |> format_bytes, + total=total_bytes |> format_bytes, + ) + end + + Base.open(file, "w") do io + while !eof(stream) + buf = readavailable(stream) + wrote = write(io, buf) + downloaded_bytes += wrote + if !isinf(update_period) + if now() - prev_time > Millisecond(round(1000update_period)) + report_callback() + end + end + end + end + if !isinf(update_period) + report_callback() + end + end + return file +end diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 000000000..b11d82618 --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,91 @@ +module Exceptions + +export @try, HTTPError, ConnectError, TimeoutError, RequestError, current_exceptions_to_string + +# Pull parent-module bindings needed for concrete field types. +import ..Request + +@eval begin +""" + @try PermittedErrorTypes expr + +Convenience macro for wrapping an expression in a try/catch block where thrown +exceptions are ignored if they match one of the permitted types. +""" +macro $(:try)(exes...) + errs = Any[exes...] + ex = pop!(errs) + isempty(errs) && error("no permitted errors") + quote + try $(esc(ex)) + catch e + e isa InterruptException && rethrow(e) + |($([:(e isa $(esc(err))) for err in errs]...)) || rethrow(e) + end + end +end +end + +abstract type HTTPError <: Exception end + +""" + HTTP.ConnectError + +Raised when an error occurs while trying to establish a request connection to +the remote server. The underlying error is stored in `error`. +""" +struct ConnectError <: HTTPError + url::String + error::Exception +end + +function Base.showerror(io::IO, e::ConnectError) + print(io, "HTTP.ConnectError for url = `$(e.url)`: ") + Base.showerror(io, e.error) +end + +""" + HTTP.TimeoutError + +Raised when a request times out according to `readtimeout` keyword argument provided. +""" +struct TimeoutError <: HTTPError + readtimeout::Int +end + +Base.showerror(io::IO, e::TimeoutError) = + print(io, "TimeoutError: Connection closed after $(e.readtimeout) seconds") + +""" + HTTP.RequestError + +Raised when an error occurs while physically sending a request to the remote server +or reading the response back. The underlying error is stored in `error`. +""" +struct RequestError <: HTTPError + request::Request + error::Exception +end + +function Base.showerror(io::IO, e::RequestError) + println(io, "HTTP.RequestError:") + println(io, "HTTP.Request:") + Base.show(io, e.request) + println(io, "Underlying error:") + Base.showerror(io, e.error) +end + +function current_exceptions_to_string() + buf = IOBuffer() + println(buf) + println(buf, "\n===========================\nHTTP Error message:\n") + exc = @static if VERSION >= v"1.8.0-" + Base.current_exceptions() + else + Base.catch_stack() + end + Base.display_error(buf, exc) + return String(take!(buf)) +end + +end # module Exceptions diff --git a/src/handlers.jl b/src/handlers.jl index 9c06b20f2..86f7367f6 100644 --- a/src/handlers.jl +++ b/src/handlers.jl @@ -1,8 +1,9 @@ module Handlers -export Handler, Middleware, serve, serve!, Router, register!, getroute, getparams, getparam, getcookies +export Handler, Middleware, serve, serve!, Router, register!, getroute, getparams, getparam, getcookies, streamhandler -import ..Request, ..Cookies +import ..Request, ..Response, ..Stream, ..Cookies, ..getbody, .._header_name, .._header_value +import ..startread, ..setstatus, ..setheader, ..addtrailer, ..closewrite, ..closeread """ Handler @@ -39,6 +40,34 @@ then an input to the `auth_middlware`, which further enhances/modifies the handl """ abstract type Middleware end +""" + streamhandler(request_handler) -> stream handler + +Middleware that takes a request handler and returns a stream handler. +""" +function streamhandler(handler) + return function(stream::Stream) + req = startread(stream) + req.body = read(stream) + resp = Base.invokelatest(handler, req)::Response + resp.request = req + setstatus(stream, resp.status) + for h in resp.headers + setheader(stream, _header_name(h), _header_value(h)) + end + body = getbody(resp) + if body isa IO + write(stream, read(body)) + elseif body !== nothing + write(stream, body) + end + resp.trailers === nothing || addtrailer(stream, resp.trailers) + closewrite(stream) + closeread(stream) + return + end +end + # tree-based router handler mutable struct Variable name::String @@ -396,4 +425,4 @@ middleware. """ getcookies(req) = req.cookies === nothing ? Cookies.Cookie[] : req.cookies::Vector{Cookies.Cookie} -end # module Handlers \ No newline at end of file +end # module Handlers diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 8b97edefe..a725c5195 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -1,17 +1,22 @@ +export Header, Headers, Message, Request, Response, + header, headers, hasheader, headercontains, + setheader, setheaderifabsent, setheaders!, defaultheader!, appendheader, removeheader, + canonicalizeheaders, canonicalizeheaders!, mkheaders + # working with headers -headereq(a::String, b::String) = GC.@preserve a b aws_http_header_name_eq(aws_byte_cursor_from_c_str(a), aws_byte_cursor_from_c_str(b)) +headereq(a::String, b::String) = ascii_lc_isequal(a, b) -mutable struct Header - header::aws_http_header - Header() = new() - Header(header::aws_http_header) = new(header) +struct Header + header::AwsHTTP.HttpHeader end +Header() = Header(AwsHTTP.HttpHeader("", "")) +Header(name::AbstractString, value::AbstractString) = Header(AwsHTTP.HttpHeader(String(name), String(value))) function Base.getproperty(x::Header, s::Symbol) if s == :name - return str(getfield(x, :header).name) + return getfield(x, :header).name elseif s == :value - return str(getfield(x, :header).value) + return getfield(x, :header).value else return getfield(x, s) end @@ -19,204 +24,475 @@ end Base.show(io::IO, h::Header) = print_header(io, h) -mutable struct Headers <: AbstractVector{Header} - const ptr::Ptr{aws_http_headers} - function Headers(allocator=default_aws_allocator()) - x = new(aws_http_headers_new(allocator)) - x.ptr == C_NULL && aws_throw_error() - return finalizer(_ -> aws_http_headers_release(x.ptr), x) +@inline _header_name(h::Header) = h.name +@inline _header_name(h::Pair) = String(first(h)) +@inline _header_name(h::AwsHTTP.HttpHeader) = h.name +@inline _header_value(h::Header) = h.value +@inline _header_value(h::Pair) = String(last(h)) +@inline _header_value(h::AwsHTTP.HttpHeader) = h.value +@inline _header_pair(h) = _header_name(h) => _header_value(h) + +Base.first(h::Header) = h.name +Base.last(h::Header) = h.value + +mutable struct Headers <: AbstractVector{Pair{String, String}} + const hdrs::AwsHTTP.HttpHeaders + function Headers() + return new(AwsHTTP.http_headers_new()) + end + Headers(hdrs::AwsHTTP.HttpHeaders) = new(hdrs) + function Headers(h::AbstractVector{<:Pair}) + hdrs = AwsHTTP.http_headers_new() + for (k, v) in h + AwsHTTP.http_headers_add(hdrs, String(k), String(v)) != 0 && aws_throw_error() + end + return new(hdrs) end - # no finalizer in this constructor because whoever called aws_http_headers_new needs to do that - Headers(ptr::Ptr{aws_http_headers}) = new(ptr) end -Base.size(h::Headers) = (Int(aws_http_headers_count(h.ptr)),) +abstract type Message end + +Base.eltype(::Type{Headers}) = Pair{String, String} +Base.IndexStyle(::Type{Headers}) = IndexLinear() +Base.size(h::Headers) = (AwsHTTP.http_headers_count(h.hdrs),) +Base.length(h::Headers) = AwsHTTP.http_headers_count(h.hdrs) function Base.getindex(h::Headers, i::Int) - header = Header() - aws_http_headers_get_index(h.ptr, i - 1, FieldRef(header, :header)) != 0 && aws_throw_error() - return header + hdr = AwsHTTP.http_headers_get_index(h.hdrs, i - 1) + hdr === nothing && throw(BoundsError(h, i)) + return _header_pair(hdr) end -Base.Dict(h::Headers) = Dict(((h.name, h.value) for h in h)) +function Base.setindex!(h::Headers, v::Pair, i::Int) + len = length(h) + (i < 1 || i > len) && throw(BoundsError(h, i)) + items = collect(h) + items[i] = String(v.first) => String(v.second) + empty!(h) + addheaders(h, items) + return h +end + +function Base.insert!(h::Headers, i::Int, v::Pair) + len = length(h) + (i < 1 || i > len + 1) && throw(BoundsError(h, i)) + items = collect(h) + insert!(items, i, String(v.first) => String(v.second)) + empty!(h) + addheaders(h, items) + return h +end -addheader(headers::Headers, h::Header) = aws_http_headers_add_header(headers.ptr, FieldRef(h, :header)) != 0 && aws_throw_error() -addheader(headers::Headers, k, v) = GC.@preserve k v aws_http_headers_add(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() -addheaders(headers::Headers, h::Vector{aws_http_header}) = GC.@preserve h aws_http_headers_add_array(headers.ptr, pointer(h), length(h)) != 0 && aws_throw_error() -addheaders(headers::Headers, h::Ptr{aws_http_header}, count::Integer) = aws_http_headers_add_array(headers.ptr, h, count) != 0 && aws_throw_error() +function Base.push!(h::Headers, v::Pair) + addheader(h, v) + return h +end -function addheaders(headers::Headers, h::Vector{Pair{String, String}}) +function Base.push!(h::Headers, v::Header) + addheader(h, v) + return h +end + +Base.Dict(h::Headers) = Dict((_header_pair(h2) for h2 in h)) +Base.copy(h::Headers) = mkheaders(h) +Base.convert(::Type{Vector{Pair{String, String}}}, h::Headers) = mkheaders(h) + +addheader(headers::Headers, h::Header) = AwsHTTP.http_headers_add_header(headers.hdrs, h.header) != 0 && aws_throw_error() +addheader(headers::Headers, h::AwsHTTP.HttpHeader) = AwsHTTP.http_headers_add_header(headers.hdrs, h) != 0 && aws_throw_error() +addheader(headers::Headers, h::Pair) = AwsHTTP.http_headers_add(headers.hdrs, String(h.first), String(h.second)) != 0 && aws_throw_error() +addheader(headers::Headers, k, v) = AwsHTTP.http_headers_add(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() +addheaders(headers::Headers, h::Vector{AwsHTTP.HttpHeader}) = AwsHTTP.http_headers_add_array(headers.hdrs, h) != 0 && aws_throw_error() + +function addheaders(headers::Headers, h::AbstractVector{<:Pair}) for (k, v) in h addheader(headers, k, v) end end -setheader(headers::Headers, k, v) = GC.@preserve k v aws_http_headers_set(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() -setscheme(headers::Headers, scheme) = GC.@preserve scheme aws_http2_headers_set_request_scheme(headers.ptr, aws_byte_cursor_from_c_str(scheme)) != 0 && aws_throw_error() -setauthority(headers::Headers, authority) = GC.@preserve authority aws_http2_headers_set_request_authority(headers.ptr, aws_byte_cursor_from_c_str(authority)) != 0 && aws_throw_error() +setheader(headers::Headers, k, v) = AwsHTTP.http_headers_set(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() +setscheme(headers::Headers, scheme) = AwsHTTP.http2_headers_set_request_scheme(headers.hdrs, String(scheme)) != 0 && aws_throw_error() +setauthority(headers::Headers, authority) = AwsHTTP.http2_headers_set_request_authority(headers.hdrs, String(authority)) != 0 && aws_throw_error() -#TODO: struct aws_string *aws_http_headers_get_all(const struct aws_http_headers *headers, struct aws_byte_cursor name); function getheader(headers::Headers, k) - out = Ref{aws_byte_cursor}() - GC.@preserve k out begin - aws_http_headers_get(headers.ptr, aws_byte_cursor_from_c_str(k), out) != 0 && return nothing - return str(out[]) + name = String(k) + val = AwsHTTP.http_headers_get(headers.hdrs, name) + if val === nothing && field_name_isequal(name, "host") + val = AwsHTTP.http_headers_get(headers.hdrs, ":authority") end + return val end -hasheader(headers::Headers, k) = - GC.@preserve k aws_http_headers_has(headers.ptr, aws_byte_cursor_from_c_str(k)) +function hasheader(headers::Headers, k) + name = String(k) + has = AwsHTTP.http_headers_has(headers.hdrs, name) + if !has && field_name_isequal(name, "host") + has = AwsHTTP.http_headers_has(headers.hdrs, ":authority") + end + return has +end -removeheader(headers::Headers, k) = - GC.@preserve k aws_http_headers_erase(headers.ptr, aws_byte_cursor_from_c_str(k)) != 0 && aws_throw_error() +removeheader(headers::Headers, k) = AwsHTTP.http_headers_erase(headers.hdrs, String(k)) != 0 && aws_throw_error() +removeheader(headers::Headers, k, v) = AwsHTTP.http_headers_erase_value(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() -removeheader(headers::Headers, k, v) = - GC.@preserve k v aws_http_headers_erase_value(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() +function Base.deleteat!(h::Headers, i::Int) + AwsHTTP.http_headers_erase_index(h.hdrs, i - 1) != 0 && aws_throw_error() + return h +end -Base.deleteat!(h::Headers, i::Int) = aws_http_headers_erase_index(h.ptr, i - 1) != 0 && aws_throw_error() -Base.empty!(h::Headers) = aws_http_headers_clear(h.ptr) != 0 && aws_throw_error() +function Base.empty!(h::Headers) + AwsHTTP.http_headers_clear(h.hdrs) + return h +end setheaderifabsent(headers, k, v) = !hasheader(headers, k) && setheader(headers, k, v) +setheaderifabsent(m::Message, k, v) = setheaderifabsent(m.headers, k, v) + +function setheaders!(headers::Headers, newheaders) + newheaders === headers && return headers + if newheaders === nothing + empty!(headers) + return headers + end + if newheaders isa Headers + newheaders.hdrs === headers.hdrs && return headers + items = collect(newheaders) + elseif newheaders isa AwsHTTP.HttpHeaders + items = collect(Headers(newheaders)) + else + items = mkheaders(newheaders) + end + empty!(headers) + addheaders(headers, items) + return headers +end + +setheaders!(m::Message, newheaders) = setheaders!(m.headers, newheaders) + +field_name_isequal(a, b) = headereq(String(a), String(b)) + +Base.getindex(m::Message, k) = header(m, k) + +""" + HTTP.header(::Message, key [, default=""]) -> String + +Get header value for `key` (case-insensitive). +""" +header(m::Message, k, d="") = header(m.headers, k, d) +header(h::Headers, k, d="") = (v = getheader(h, String(k)); v === nothing ? d : v) +header(h::AbstractVector{<:Pair}, k, d="") = begin + for (name, value) in h + if field_name_isequal(name, k) + return String(value) + end + end + return d +end + +""" + HTTP.headers(m::Message, key) -> Vector{String} + +Get all headers with key `k` or empty if none. +""" +function headers(h::Headers, k) + vals = [_header_value(h2) for h2 in h if field_name_isequal(_header_name(h2), k)] + if isempty(vals) && field_name_isequal(k, "host") + authority = AwsHTTP.http_headers_get(h.hdrs, ":authority") + authority === nothing && return String[] + return [authority] + end + return vals +end +headers(h::AbstractVector{<:Pair}, k) = [String(v) for (name, v) in h if field_name_isequal(name, k)] +headers(m::Message, k) = headers(m.headers, k) + +""" + HTTP.hasheader(::Message, key) -> Bool + +Does header value for `key` exist (case-insensitive)? +""" +hasheader(m::Message, k) = header(m, k) != "" +hasheader(m::Message, k, v) = field_name_isequal(header(m, k), v) + +""" + HTTP.headercontains(::Message, key, value) -> Bool + +Does the header for `key` (interpreted as comma-separated list) contain `value` (case-insensitive)? +""" +headercontains(m::Message, k, v) = any(field_name_isequal.(strip.(split(header(m, k), ",")), v)) +headercontains(h::Headers, k, v) = any(field_name_isequal.(strip.(split(header(h, k), ",")), v)) +headercontains(h::AbstractVector{<:Pair}, k, v) = any(field_name_isequal.(strip.(split(header(h, k), ",")), v)) + +""" + HTTP.setheader(::Message, key => value) + +Set header `value` for `key` (case-insensitive). +""" +setheader(m::Message, v) = setheader(m.headers, v) +setheader(h::Headers, v::Header) = setheader(h, v.name, v.value) +setheader(h::Headers, v::Pair) = setheader(h, String(v.first), String(v.second)) +function setheader(h::AbstractVector{<:Pair}, v::Pair) + key = String(v.first) + value = String(v.second) + for i in eachindex(h) + if field_name_isequal(h[i].first, key) + h[i] = key => value + return h + end + end + push!(h, key => value) + return h +end + +appendheader(m::Message, v) = appendheader(m.headers, v) +appendheader(h::Headers, v::Header) = addheader(h, v) +appendheader(h::Headers, v::Pair) = addheader(h, String(v.first), String(v.second)) +function appendheader(h::AbstractVector{<:Pair}, v::Pair) + push!(h, String(v.first) => String(v.second)) + return h +end + +removeheader(m::Message, k) = removeheader(m.headers, k) +removeheader(m::Message, k, v) = removeheader(m.headers, k, v) +function removeheader(h::AbstractVector{<:Pair}, k) + key = String(k) + filter!(kv -> !field_name_isequal(kv.first, key), h) + return h +end +function removeheader(h::AbstractVector{<:Pair}, k, v) + key = String(k) + val = String(v) + filter!(kv -> !(field_name_isequal(kv.first, key) && field_name_isequal(kv.second, val)), h) + return h +end + +""" + defaultheader!(::Message, key => value) + +Set header `value` in message for `key` if it is not already set. +""" +function defaultheader!(m, v::Pair) + if header(m, first(v), nothing) === nothing + setheader(m, v) + end + return +end + +function canonicalizeheaders!(h::Headers) + items = [(_header_name(h2), _header_value(h2)) for h2 in h] + for i in length(h):-1:1 + deleteat!(h, i) + end + for (k, v) in items + addheader(h, tocameldash(k), v) + end + return h +end + +canonicalizeheaders(h::AbstractVector{<:Pair}) = + [tocameldash(String(k)) => String(v) for (k, v) in h] + +mkheaders(::Nothing) = Pair{String, String}[] +mkheaders(h::Headers) = [_header_pair(h2) for h2 in h] +mkheaders(h::AbstractVector{Header}) = begin + headers = Pair{String, String}[] + for head in h + push!(headers, _header_pair(head)) + end + return headers +end +mkheaders(h::AbstractVector{<:Pair}) = [String(k) => String(v) for (k, v) in h] +function mkheaders(h) + headers = Pair{String, String}[] + for (k, v) in h + push!(headers, String(k) => String(v)) + end + return headers +end + +function mkreqheaders(h, copyheaders::Bool) + if h === nothing + return Pair{String, String}[] + elseif h isa AbstractVector{<:Pair} && !copyheaders + return h + else + return mkheaders(h) + end +end + +function sync_headers!(dest::AbstractVector{<:Pair}, src::Headers) + empty!(dest) + for h in src + push!(dest, _header_pair(h)) + end + return dest +end # request/response -abstract type Message end mutable struct InputStream - ptr::Ptr{aws_input_stream} bodyref::Any bodylen::Int64 - bodycursor::aws_byte_cursor - InputStream() = new() + InputStream() = new(nothing, 0) end -ischunked(is::InputStream) = is.ptr == C_NULL && is.bodyref !== nothing +ischunked(is::InputStream) = is.bodylen < 0 && is.bodyref !== nothing -const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Nothing} +const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Form, Nothing} +const DEFAULT_IO_CHUNK_SIZE = 64 * 1024 -function InputStream(allocator::Ptr{aws_allocator}, body::RequestBodyTypes) - is = InputStream() - if body !== nothing - if body isa RequestBodyTypes - if (body isa AbstractVector{UInt8}) || (body isa AbstractString) - is.bodyref = body - is.bodycursor = aws_byte_cursor(sizeof(body), pointer(body)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa Union{AbstractDict, NamedTuple} - # hold a reference to the request body in order to gc-preserve it - is.bodyref = URIs.escapeuri(body) - is.bodycursor = aws_byte_cursor_from_c_str(is.bodyref) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa IOStream - is.bodyref = body - is.ptr = aws_input_stream_new_from_open_file(allocator, Libc.FILE(body)) - elseif body isa Form - # we set the request.body to the Form bytes in order to gc-preserve them - is.bodyref = read(body) - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa IO - # we set the request.body to the IO bytes in order to gc-preserve them - bytes = readavailable(body) - while !eof(body) - append!(bytes, readavailable(body)) - end - is.bodyref = bytes - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - else - throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, or IO")) - end - aws_input_stream_get_length(is.ptr, FieldRef(is, :bodylen)) != 0 && aws_throw_error() - if !(is.bodylen > 0) - aws_input_stream_release(is.ptr) - is.ptr = C_NULL - end - else - # assume a chunked request body; any kind of iterable where elements are RequestBodyTypes - @assert Base.isiterable(typeof(body)) "chunked request body must be an iterable" - is.bodyref = body - end +struct IOChunkedBody{T<:IO} + io::T + chunk_size::Int + buf::Vector{UInt8} +end + +IOChunkedBody(io::IO; chunk_size::Int=DEFAULT_IO_CHUNK_SIZE) = + IOChunkedBody{typeof(io)}(io, chunk_size, Vector{UInt8}(undef, chunk_size)) + +function Base.iterate(it::IOChunkedBody, state=nothing) + eof(it.io) && return nothing + n = readbytes!(it.io, it.buf, it.chunk_size) + n == 0 && return nothing + return view(it.buf, 1:n), nothing +end + +const OBSERVELAYER_NAMES = (:messagelayer, :redirectlayer, :retrylayer, :connectionlayer, :streamlayer) + +function _init_observations!(context::Dict{Symbol, Any}) + for name in OBSERVELAYER_NAMES + context[Symbol(name, "_count")] = 0 + context[Symbol(name, "_duration_ms")] = 0.0 end - return finalizer(x -> aws_input_stream_release(x.ptr), is) + return context end -function setinputstream!(msg::Message, body) - aws_http_message_set_body_stream(msg.ptr, C_NULL) - msg.inputstream = nothing +function _record_layer!(context::Dict{Symbol, Any}, name::Symbol, started::Float64) + cntkey = Symbol(name, "_count") + durkey = Symbol(name, "_duration_ms") + context[cntkey] = Base.get(() -> 0, context, cntkey) + 1 + context[durkey] = Base.get(() -> 0.0, context, durkey) + (time() - started) * 1000 + return +end + +function setinputstream!(m::Message, body) + AwsHTTP.http_message_set_body_stream(getfield(m, :msg), nothing) + m.inputstream = nothing body === nothing && return - input_stream = InputStream(msg.allocator, body) - setfield!(msg, :inputstream, input_stream) - if input_stream.ptr != C_NULL - aws_http_message_set_body_stream(msg.ptr, input_stream.ptr) + is = InputStream() + if (body isa AbstractVector{UInt8}) || (body isa AbstractString) + is.bodyref = body + is.bodylen = sizeof(body) + elseif body isa Union{AbstractDict, NamedTuple} + is.bodyref = URIs.escapeuri(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa IOStream + isopen(body) || throw(ArgumentError("request body IOStream is closed")) + is.bodyref = read(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa Form + is.bodyref = read(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa IO + bytes = readavailable(body) + while !eof(body) + append!(bytes, readavailable(body)) + end + is.bodyref = bytes + is.bodylen = sizeof(is.bodyref) + elseif Base.isiterable(typeof(body)) + # chunked request body; any kind of iterable where elements are RequestBodyTypes + is.bodyref = body + is.bodylen = -1 + else + throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, IO, or an iterable of those")) + end + setfield!(m, :inputstream, is) + if is.bodylen > 0 + # Wrap body in IOBuffer so the H1 encoder's readbytes! works + AwsHTTP.http_message_set_body_stream(getfield(m, :msg), IOBuffer(is.bodyref)) if body isa Union{AbstractDict, NamedTuple} - setheaderifabsent(msg.headers, "content-type", "application/x-www-form-urlencoded") + setheaderifabsent(m.headers, "content-type", "application/x-www-form-urlencoded") elseif body isa Form - setheaderifabsent(msg.headers, "content-type", content_type(body)) + setheaderifabsent(m.headers, "content-type", content_type(body)) end - setheader(msg.headers, "content-length", string(input_stream.bodylen)) + setheader(m.headers, "content-length", string(is.bodylen)) end return end + mutable struct Request <: Message - allocator::Ptr{aws_allocator} - ptr::Ptr{aws_http_message} + msg::AwsHTTP.HttpMessage inputstream::Union{Nothing, InputStream} # used for outgoing request body # only set in server-side request handlers body::Union{Nothing, Vector{UInt8}} + trailers::Union{Nothing, Headers} + context::Dict{Symbol, Any} route::Union{Nothing, String} params::Union{Nothing, Dict{String, String}} cookies::Any # actually Union{Nothing, Vector{Cookie}} - function Request(method, path, headers=nothing, body=nothing, http2::Bool=false, allocator=default_aws_allocator()) - ptr = http2 ? - aws_http2_message_new_request(allocator) : - aws_http_message_new_request(allocator) - ptr == C_NULL && aws_throw_error() - try - GC.@preserve method aws_http_message_set_request_method(ptr, aws_byte_cursor_from_c_str(method)) != 0 && aws_throw_error() - GC.@preserve path aws_http_message_set_request_path(ptr, aws_byte_cursor_from_c_str(path)) != 0 && aws_throw_error() - request_headers = Headers(aws_http_message_get_headers(ptr)) - if headers !== nothing - for (k, v) in headers - addheader(request_headers, k, v) - end + function Request(method, path, headers=nothing, body=nothing, http2::Bool=false; context=nothing) + msg = http2 ? + AwsHTTP.http2_message_new_request() : + AwsHTTP.http_message_new_request() + AwsHTTP.http_message_set_request_method(msg, String(method)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_path(msg, String(path)) != 0 && aws_throw_error() + msg_headers = AwsHTTP.http_message_get_headers(msg) + if headers !== nothing + src_headers = headers isa AbstractVector{<:Pair} ? headers : mkheaders(headers) + for (k, v) in src_headers + AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end - req = new(allocator, ptr) - req.body = nothing - req.inputstream = nothing - req.route = nothing - req.params = nothing - req.cookies = nothing - body !== nothing && setinputstream!(req, body) - return finalizer(_ -> aws_http_message_release(ptr), req) - catch - aws_http_message_release(ptr) - rethrow() end + req = new(msg) + req.body = nothing + req.inputstream = nothing + req.trailers = nothing + req.context = context === nothing ? Dict{Symbol, Any}() : context + req.route = nothing + req.params = nothing + req.cookies = nothing + body !== nothing && setinputstream!(req, body) + return req end end -ptr(x) = getfield(x, :ptr) +# compatibility: 6-arg version for callers that still pass allocator +Request(method, path, headers, body, http2::Bool, _allocator; context=nothing) = + Request(method, path, headers, body, http2; context=context) + +getrequest(req::Request) = req + +function observelayer(f) + function observation(req_or_stream; kw...) + req = getrequest(req_or_stream) + nm = nameof(f) + start_time = time() + ctx = req.context + ctx[Symbol(nm, "_count")] = Base.get(() -> 0, ctx, Symbol(nm, "_count")) + 1 + try + return f(req_or_stream; kw...) + finally + ctx[Symbol(nm, "_duration_ms")] = + Base.get(() -> 0.0, ctx, Symbol(nm, "_duration_ms")) + (time() - start_time) * 1000 + end + end +end function Base.getproperty(x::Request, s::Symbol) if s == :method - out = Ref{aws_byte_cursor}() - GC.@preserve out begin - aws_http_message_get_request_method(ptr(x), out) != 0 && return nothing - return str(out[]) - end - elseif s == :path || s == :target || s == :uri - out = Ref{aws_byte_cursor}() - GC.@preserve out begin - aws_http_message_get_request_path(ptr(x), out) != 0 && return nothing - path = str(out[]) - return s == :uri ? URI(path) : path - end + return AwsHTTP.http_message_get_request_method(getfield(x, :msg)) + elseif s == :path || s == :target + return AwsHTTP.http_message_get_request_path(getfield(x, :msg)) + elseif s == :uri + path = AwsHTTP.http_message_get_request_path(getfield(x, :msg)) + return path === nothing ? URI("/") : URI(path) elseif s == :headers - return Headers(aws_http_message_get_headers(ptr(x))) + return Headers(AwsHTTP.http_message_get_headers(getfield(x, :msg))) elseif s == :version - return aws_http_message_get_protocol_version(ptr(x)) == AWS_HTTP_VERSION_2 ? "2" : "1.1" + v = AwsHTTP.http_message_get_protocol_version(getfield(x, :msg)) + return v == AwsHTTP.HttpVersion.HTTP_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -224,19 +500,19 @@ end function Base.setproperty!(x::Request, s::Symbol, v) if s == :method - GC.@preserve v aws_http_message_set_request_method(x.ptr, aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_method(getfield(x, :msg), String(v)) != 0 && aws_throw_error() elseif s == :path - GC.@preserve v aws_http_message_set_request_path(x.ptr, aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_path(getfield(x, :msg), String(v)) != 0 && aws_throw_error() elseif s == :headers - addheaders(x.headers, v) + setheaders!(x, v) else setfield!(x, s, v) end end function print_header(io, h) - key = h.name - val = h.value + key = _header_name(h) + val = _header_value(h) if headereq(key, "authorization") write(io, string(key, ": ", "******", "\r\n")) return @@ -255,7 +531,13 @@ end function print_request(io, method, version, path, headers, body) write(io, "\"\"\"\n") - write(io, string(method, " ", path, " HTTP/$version\r\n")) + write(io, string(method, " ", path, " ")) + if version isa HTTPVersion + write(io, version) + else + write(io, "HTTP/", string(version)) + end + write(io, "\r\n") for h in headers print_header(io, h) end @@ -265,7 +547,12 @@ function print_request(io, method, version, path, headers, body) return end -getbody(r::Message) = isdefined(r, :inputstream) ? r.inputstream.bodyref : r.body +function getbody(r::Message) + if isdefined(r, :inputstream) && r.inputstream !== nothing + return r.inputstream.bodyref + end + return r.body +end print_request(io::IO, r::Request) = print_request(io, r.method, r.version, r.path, r.headers, getbody(r)) @@ -279,69 +566,86 @@ target(r::Request) = r.path headers(r::Request) = r.headers body(r::Request) = r.body -resource(uri::URI) = string(isempty(uri.path) ? "/" : uri.path, - !isempty(uri.query) ? "?" : "", uri.query, - !isempty(uri.fragment) ? "#" : "", uri.fragment) - mutable struct RequestMetrics request_body_length::Int response_body_length::Int nretries::Int - stream_metrics::Union{Nothing, aws_http_stream_metrics} + stream_metrics::Union{Nothing, AwsHTTP.HttpStreamMetrics} end RequestMetrics() = RequestMetrics(0, 0, 0, nothing) mutable struct Response <: Message - allocator::Ptr{aws_allocator} - ptr::Ptr{aws_http_message} + msg::AwsHTTP.HttpMessage inputstream::Union{Nothing, InputStream} body::Union{Nothing, Vector{UInt8}} # only set for client-side response body when no user-provided response_body + trailers::Union{Nothing, Headers} metrics::RequestMetrics - request::Request - - function Response(status::Integer, headers, body, http2::Bool=false, allocator=default_aws_allocator()) - ptr = http2 ? - aws_http2_message_new_response(allocator) : - aws_http_message_new_response(allocator) - ptr == C_NULL && aws_throw_error() - try - GC.@preserve status aws_http_message_set_response_status(ptr, status) != 0 && aws_throw_error() - response_headers = Headers(aws_http_message_get_headers(ptr)) - if headers !== nothing - for (k, v) in headers - addheader(response_headers, k, v) - end + request::Union{Request, Nothing} + + function Response(status::Integer, headers, body, http2::Bool=false) + msg = http2 ? + AwsHTTP.http2_message_new_response() : + AwsHTTP.http_message_new_response() + AwsHTTP.http_message_set_response_status(msg, Int(status)) != 0 && aws_throw_error() + msg_headers = AwsHTTP.http_message_get_headers(msg) + if headers !== nothing + src_headers = headers isa AbstractVector{<:Pair} ? headers : mkheaders(headers) + for (k, v) in src_headers + AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end - resp = new(allocator, ptr) - resp.body = nothing - resp.inputstream = nothing - body !== nothing && setinputstream!(resp, body) - return finalizer(_ -> aws_http_message_release(ptr), resp) - catch - aws_http_message_release(ptr) - rethrow() end + resp = new(msg) + resp.body = nothing + resp.inputstream = nothing + resp.trailers = nothing + resp.metrics = RequestMetrics() + resp.request = nothing + if body !== nothing + setinputstream!(resp, body) + else + if !hasheader(resp.headers, "content-length") && !hasheader(resp.headers, "transfer-encoding") + setheader(resp.headers, "content-length" => "0") + end + end + return resp end - Response() = new(C_NULL, C_NULL, nothing, nothing) + Response() = new(AwsHTTP.http_message_new_response(), nothing, nothing, nothing, RequestMetrics(), nothing) end +# compatibility: 5-arg version for callers that still pass allocator +Response(status::Integer, headers, body, http2::Bool, _allocator) = + Response(status, headers, body, http2) + Response(status::Integer, body) = Response(status, nothing, Vector{UInt8}(string(body))) Response(status::Integer) = Response(status, nothing, nothing) getresponse(r::Response) = r +function _head_response!(resp::Response) + setinputstream!(resp, nothing) + hasheader(resp.headers, "transfer-encoding") && removeheader(resp.headers, "transfer-encoding") + setheader(resp.headers, "content-length" => "0") + return +end + bodylen(m::Message) = isdefined(m, :inputstream) && m.inputstream !== nothing ? m.inputstream.bodylen : 0 +function bodylen(r::Response) + if isdefined(r, :inputstream) && r.inputstream !== nothing + return r.inputstream.bodylen + end + return r.metrics.response_body_length +end + function Base.getproperty(x::Response, s::Symbol) if s == :status - ref = Ref{Cint}() - aws_http_message_get_response_status(x.ptr, ref) != 0 && return nothing - return Int(ref[]) + return AwsHTTP.http_message_get_response_status(getfield(x, :msg)) elseif s == :headers - return Headers(aws_http_message_get_headers(x.ptr)) + return Headers(AwsHTTP.http_message_get_headers(getfield(x, :msg))) elseif s == :version - return aws_http_message_get_protocol_version(x.ptr) == AWS_HTTP_VERSION_2 ? "2" : "1.1" + v = AwsHTTP.http_message_get_protocol_version(getfield(x, :msg)) + return v == AwsHTTP.HttpVersion.HTTP_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -349,9 +653,9 @@ end function Base.setproperty!(x::Response, s::Symbol, v) if s == :status - GC.@preserve v aws_http_message_set_response_status(x.ptr, v) != 0 && aws_throw_error() + AwsHTTP.http_message_set_response_status(getfield(x, :msg), Int(v)) != 0 && aws_throw_error() elseif s == :headers - addheaders(x.headers, v) + setheaders!(x, v) else setfield!(x, s, v) end @@ -359,7 +663,12 @@ end function print_response(io, status, version, headers, body) write(io, "\"\"\"\n") - write(io, string("HTTP/$version ", status, "\r\n")) + if version isa HTTPVersion + write(io, version) + else + write(io, "HTTP/", string(version)) + end + write(io, " ", string(status), "\r\n") for h in headers print_header(io, h) end @@ -413,4 +722,4 @@ Does this `Response` have a redirect status? isredirect(r::Response) = isredirect(r.status) isredirect(status::Integer) = status in (301, 302, 303, 307, 308) -Forms.parse_multipart_form(m::Message) = parse_multipart_form(getheader(m.headers, "content-type"), m.body) \ No newline at end of file +Forms.parse_multipart_form(m::Message) = parse_multipart_form(getheader(m.headers, "content-type"), m.body) diff --git a/src/server.jl b/src/server.jl index 98e1a1318..e2d947487 100644 --- a/src/server.jl +++ b/src/server.jl @@ -1,97 +1,344 @@ -socket_endpoint(host, port) = aws_socket_endpoint( - ntuple(i -> i > sizeof(host) ? 0x00 : codeunit(host, i), Base._counttuple(fieldtype(aws_socket_endpoint, :address))), - port % UInt32 -) +function server_tlsoptions(; + ssl_cert=nothing, + ssl_key=nothing, + ssl_capath=nothing, + ssl_cacert=nothing, + ssl_insecure=false, + ssl_alpn_list="h2;http/1.1", + ) + alpn_list = _normalize_alpn_list(ssl_alpn_list) + if ssl_cert !== nothing && ssl_key !== nothing + ctx_opts = Reseau.Sockets.tls_ctx_options_init_default_server_from_path(ssl_cert, ssl_key; alpn_list=alpn_list) + elseif Sys.iswindows() && ssl_cert !== nothing && ssl_key === nothing + ctx_opts = Reseau.Sockets.tls_ctx_options_init_default_server_from_system_path(ssl_cert) + else + throw(ArgumentError("ssl_cert and ssl_key are required for TLS server")) + end + if ssl_capath !== nothing || ssl_cacert !== nothing + Reseau.Sockets.tls_ctx_options_override_default_trust_store_from_path!(ctx_opts; + ca_path=ssl_capath, ca_file=ssl_cacert) + end + if ssl_insecure + Reseau.Sockets.tls_ctx_options_set_verify_peer!(ctx_opts, false) + end + ctx = Reseau.Sockets.tls_server_ctx_new(ctx_opts) + return Reseau.Sockets.TlsConnectionOptions(ctx; alpn_list=alpn_list) +end + +const _BACKLOG_DEFAULT = 511 mutable struct Connection{S} const server::S # Server{F, C} - const allocator::Ptr{aws_allocator} - const connection::Ptr{aws_http_connection} + const h1conn::Any # AwsHTTP.H1Connection or AwsHTTP.H2Connection + const channel::Any # Reseau.Channel const streams_lock::ReentrantLock const streams::Set{Stream} - connection_options::aws_http_server_connection_options + const remote_addr::String + const remote_port_num::Int - Connection( - server::S, - allocator::Ptr{aws_allocator}, - connection::Ptr{aws_http_connection}, - ) where {S} = new{S}(server, allocator, connection, ReentrantLock(), Set{Stream}()) + Connection(server::S, h1conn, channel, remote_addr::String, remote_port_num::Int) where {S} = + new{S}(server, h1conn, channel, ReentrantLock(), Set{Stream}(), remote_addr, remote_port_num) end -Base.hash(c::Connection, h::UInt) = hash(c.connection, h) - -function remote_address(c::Connection) - socket_ptr = aws_http_connection_get_remote_endpoint(c.connection) - addr = unsafe_load(socket_ptr).address - bytes = Vector{UInt8}(undef, length(addr)) - nul_i = 0 - for i in eachindex(bytes) - b = addr[i] - @inbounds bytes[i] = b - if b == 0x00 - nul_i = i - break - end - end - resize!(bytes, nul_i == 0 ? length(addr) : nul_i - 1) - return String(bytes) -end -remote_port(c::Connection) = Int(unsafe_load(aws_http_connection_get_remote_endpoint(c.connection)).port) +Base.hash(c::Connection, h::UInt) = hash(objectid(c), h) + +remote_address(c::Connection) = c.remote_addr +remote_port(c::Connection) = c.remote_port_num function http_version(c::Connection) - v = aws_http_connection_get_version(c.connection) - return v == AWS_HTTP_VERSION_2 ? "HTTP/2" : "HTTP/1.1" + v = AwsHTTP.http_connection_get_version(c.h1conn) + return v == AwsHTTP.HttpVersion.HTTP_2 ? "HTTP/2" : "HTTP/1.1" end -getinet(host::String, port::Integer) = Sockets.InetAddr(parse(IPAddr, host), port) -getinet(host::IPAddr, port::Integer) = Sockets.InetAddr(host, port) - mutable struct Server{F, C} const f::F const on_stream_complete::C + const on_shutdown::Any const fut::Future{Symbol} - const allocator::Ptr{aws_allocator} - const endpoint::aws_socket_endpoint - const socket_options::aws_socket_options - const tls_options::Union{aws_tls_connection_options, Nothing} const connections_lock::ReentrantLock const connections::Set{Connection} const closed::Threads.Event const access_log::Union{Nothing, Function} + const stream::Bool const logstate::Base.CoreLogging.LogState @atomic state::Symbol # :initializing, :running, :closed - server::Ptr{aws_http_server} - server_options::aws_http_server_options + bootstrap::Any # Reseau.ServerBootstrap + bound_port::Int Server{F, C}( f::F, on_stream_complete::C, + on_shutdown::Any, fut::Future{Symbol}, - allocator::Ptr{aws_allocator}, - endpoint::aws_socket_endpoint, - socket_options::aws_socket_options, - tls_options::Union{aws_tls_connection_options, Nothing}, connections_lock::ReentrantLock, connections::Set{Connection}, closed::Threads.Event, access_log::Union{Nothing, Function}, + stream::Bool, logstate::Base.CoreLogging.LogState, state::Symbol, - ) where {F, C} = new{F, C}(f, on_stream_complete, fut, allocator, endpoint, socket_options, tls_options, connections_lock, connections, closed, access_log, logstate, state) + ) where {F, C} = new{F, C}(f, on_stream_complete, on_shutdown, fut, connections_lock, connections, closed, access_log, stream, logstate, state) end Base.wait(s::Server) = wait(s.closed) ftype(::Server{F}) where {F} = F -port(s::Server) = Int(s.endpoint.port) +port(s::Server) = s.bound_port + +shutdown(fns::Vector{<:Function}) = foreach(shutdown, fns) +shutdown(::Nothing) = nothing +function shutdown(fn::Function) + try + fn() + catch e + @error "shutdown function failed" exception=(e, catch_backtrace()) + end + return +end + +function _future_done(f::Future) + return (@atomic f.set) != 0 +end + +function _should_log_stream_error(error_code::Integer)::Bool + error_code == 0 && return false + error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_STREAM_CANCELLED && return false + error_code == AwsHTTP.ERROR_HTTP_SERVER_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_SWITCHED_PROTOCOLS && return false + error_code == AwsHTTP.ERROR_HTTP_GOAWAY_RECEIVED && return false + error_code == AwsHTTP.ERROR_HTTP_RST_STREAM_RECEIVED && return false + error_code == Reseau.EventLoops.ERROR_IO_SOCKET_CLOSED && return false + error_code == Reseau.EventLoops.ERROR_IO_BROKEN_PIPE && return false + error_code == Reseau.EventLoops.ERROR_IO_OPERATION_CANCELLED && return false + return true +end + +function _should_log_channel_shutdown_error(error_code::Integer)::Bool + error_code == 0 && return false + error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_SERVER_CLOSED && return false + error_code == Reseau.EventLoops.ERROR_IO_SOCKET_CLOSED && return false + error_code == Reseau.EventLoops.ERROR_IO_BROKEN_PIPE && return false + error_code == Reseau.EventLoops.ERROR_IO_OPERATION_CANCELLED && return false + return true +end + +function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) + server = conn.server + http_conn = aws_conn + stream = Stream{typeof(conn)}(nothing, http2, true) + stream.connection = conn + stream.request = Request("", "", nothing, nothing, http2) + + on_request_headers = (aws_stream, header_block, headers_vec, user_data) -> begin + if header_block == AwsHTTP.HttpHeaderBlock.TRAILING + trailers = stream.request.trailers + if trailers === nothing + trailers = Headers() + stream.request.trailers = trailers + end + for h in headers_vec + addheader(trailers, h.name, h.value) + end + else + hdrs = stream.request.headers + for h in headers_vec + if stream.http2 && !isempty(h.name) && h.name[1] == ':' + if h.name == ":scheme" || h.name == ":authority" || h.name == ":protocol" + addheader(hdrs, h.name, h.value) + if h.name == ":authority" && !hasheader(hdrs, "host") + addheader(hdrs, "host", h.value) + end + end + else + addheader(hdrs, h.name, h.value) + end + end + end + return AwsHTTP.OP_SUCCESS + end + + on_request_header_block_done = (aws_stream, header_block, user_data) -> begin + if header_block != AwsHTTP.HttpHeaderBlock.MAIN + return AwsHTTP.OP_SUCCESS + end + method = AwsHTTP.http_stream_get_incoming_request_method(aws_stream) + path = AwsHTTP.http_stream_get_incoming_request_uri(aws_stream) + method === nothing && (method = "") + path === nothing && (path = "") + stream.request.method = method + stream.request.path = path + notify(stream.headers_ready) + if server.stream && !stream.handler_started + stream.handler_started = true + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + Threads.@spawn begin + Base.CoreLogging.with_logstate(server.logstate) do + try + Base.invokelatest(server.f, stream) + catch e + @error "Request handler error; sending 500" exception=(e, catch_backtrace()) + if !stream.response_started + try setstatus(stream, 500) catch; end + end + finally + try closewrite(stream) catch; end + end + end + end + end + return AwsHTTP.OP_SUCCESS + end + + on_request_body = (aws_stream, data, user_data) -> begin + if server.stream + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + write(stream.bufferstream, data) + return AwsHTTP.OP_SUCCESS + end + body = stream.request.body + if body === nothing + stream.request.body = copy(data) + else + append!(body, data) + end + return AwsHTTP.OP_SUCCESS + end + + on_request_done = (aws_stream, user_data) -> begin + if server.stream + Base.CoreLogging.with_logstate(server.logstate) do + stream.bufferstream !== nothing && close(stream.bufferstream) + end + return + end + errormonitor(Threads.@spawn begin + Base.CoreLogging.with_logstate(server.logstate) do + try + stream.response = Base.invokelatest(server.f, stream.request)::Response + if stream.request.method == "HEAD" + _head_response!(stream.response) + end + catch e + @error "Request handler error; sending 500" exception=(e, catch_backtrace()) + stream.response = Response(500) + end + _send_response!(stream) + end + end) + return + end + + on_complete = (aws_stream, error_code, user_data) -> begin + stream.released && return + stream.released = true + Base.CoreLogging.with_logstate(server.logstate) do + if _should_log_stream_error(error_code) + @error "server stream complete error" error_code + end + if server.on_stream_complete !== nothing + try + Base.invokelatest(server.on_stream_complete, stream) + catch e + @error "on_stream_complete error" exception=(e, catch_backtrace()) + end + end + if stream.on_complete !== nothing + try + Base.invokelatest(stream.on_complete, stream) + catch e + @error "stream on_complete error" exception=(e, catch_backtrace()) + end + stream.on_complete = nothing + end + if server.access_log !== nothing + try + if isdefined(stream, :request) && isdefined(stream, :response) + @info sprint(server.access_log, stream) _group=:access + end + catch e + @error "access log error" exception=(e, catch_backtrace()) + end + end + shutdown_channel = false + @lock conn.streams_lock begin + delete!(conn.streams, stream) + if @atomic(server.state) == :closing && isempty(conn.streams) + shutdown_channel = true + end + end + if shutdown_channel + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) + @lock server.connections_lock begin + delete!(server.connections, conn) + end + end + # HTTP pipelining: create next request handler if connection allows + if !stream.http2 && AwsHTTP.http_connection_new_requests_allowed(http_conn) + _create_request_handler!(conn, http_conn; http2=false) + end + end + return + end + + on_destroy = (user_data) -> nothing + + opts = AwsHTTP.HttpRequestHandlerOptions( + http_conn, + nothing, + on_request_headers, + on_request_header_block_done, + on_request_body, + on_request_done, + on_complete, + on_destroy, + ) + if http2 + h2stream = AwsHTTP.h2_stream_new_request_handler(http_conn, opts; manual_write=server.stream) + stream.aws_stream = h2stream + @lock conn.streams_lock begin + push!(conn.streams, stream) + end + return h2stream + end + h1stream = AwsHTTP.http_connection_new_request_handler(http_conn, opts) + if h1stream === nothing + @error "failed to create request handler stream" + return + end + stream.aws_stream = h1stream + AwsHTTP.h1_stream_activate!(h1stream) + @lock conn.streams_lock begin + push!(conn.streams, stream) + end + return +end + +function _warn_unsupported_server_options(; reuseaddr::Bool, backlog::Integer) + reuseaddr && @warn "reuseaddr is not supported by the Reseau server; ignoring" + backlog != _BACKLOG_DEFAULT && @warn "backlog is not supported by the Reseau server; ignoring" + return +end + +function _stop_new_requests!(conn::Connection) + AwsHTTP.http_connection_stop_new_requests(conn.h1conn) + if AwsHTTP.http_connection_get_version(conn.h1conn) == AwsHTTP.HttpVersion.HTTP_2 + try + AwsHTTP.h2_connection_send_goaway!(conn.h1conn; allow_more_streams=false) + catch + end + end + return +end function serve!(f, host="127.0.0.1", port=8080; - allocator=default_aws_allocator(), - bootstrap::Ptr{aws_server_bootstrap}=default_aws_server_bootstrap(), - endpoint=nothing, - listenany::Bool=false, on_stream_complete=nothing, + on_shutdown=nothing, access_log::Union{Nothing, Function}=nothing, + stream::Bool=false, + listenany::Bool=false, + reuseaddr::Bool=false, + backlog::Integer=_BACKLOG_DEFAULT, # socket options - socket_options=nothing, socket_domain=:ipv4, connect_timeout_ms::Integer=3000, keep_alive_interval_sec::Integer=0, @@ -108,260 +355,338 @@ function serve!(f, host="127.0.0.1", port=8080; ssl_alpn_list="h2;http/1.1", initial_window_size=typemax(UInt64), ) - addr = getinet(host, port) - if listenany - port, sock = Sockets.listenany(addr.host, addr.port) - close(sock) + _ensure_resources!() + _warn_unsupported_server_options(; reuseaddr=reuseaddr, backlog=backlog) + host_str = string(host) + # `listenany=true` should pick an ephemeral port (port=0), avoiding collisions + # with any existing process bound to the default `port` (e.g. 8080). + port_int = listenany ? 0 : Int(port) + tls_conn_opts = if tls_options !== nothing + tls_options + elseif any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) + server_tlsoptions(; + ssl_cert, ssl_key, ssl_capath, ssl_cacert, ssl_insecure, ssl_alpn_list + ) + else + nothing end server = Server{typeof(f), typeof(on_stream_complete)}( - f, # RequestHandler + f, on_stream_complete, + on_shutdown, Future{Symbol}(), - allocator, - endpoint !== nothing ? endpoint : socket_endpoint(host, port), - socket_options !== nothing ? socket_options : aws_socket_options( - AWS_SOCKET_STREAM, # socket type - socket_domain == :ipv4 ? AWS_SOCKET_IPV4 : AWS_SOCKET_IPV6, # socket domain - AWS_SOCKET_IMPL_PLATFORM_DEFAULT, # aws_socket_impl_type - connect_timeout_ms, - keep_alive_interval_sec, - keep_alive_timeout_sec, - keep_alive_max_failed_probes, - keepalive, - ntuple(x -> Cchar(0), 16) # network_interface_name - ), - tls_options !== nothing ? tls_options : - any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) ? LibAwsIO.tlsoptions(host; - ssl_cert, - ssl_key, - ssl_capath, - ssl_cacert, - ssl_insecure, - ssl_alpn_list - ) : nothing, - ReentrantLock(), # connections_lock - Set{Connection}(), # connections - Threads.Event(), # closed + ReentrantLock(), + Set{Connection}(), + Threads.Event(), access_log, + stream, Base.CoreLogging.current_logstate(), - :initializing, # state + :initializing, ) - server.server_options = aws_http_server_options( - 1, - allocator, - bootstrap, - pointer(FieldRef(server, :endpoint)), - pointer(FieldRef(server, :socket_options)), - server.tls_options === nothing ? C_NULL : pointer(FieldRef(server, :tls_options)), - initial_window_size, - pointer_from_objref(server), - on_incoming_connection[], - on_destroy_complete[], - false # manual_window_management + server.bound_port = port_int + listener_ready = Threads.Event() + socket_opts = Reseau.Sockets.SocketOptions(; + domain = socket_domain == :ipv4 ? Reseau.Sockets.SocketDomain.IPV4 : Reseau.Sockets.SocketDomain.IPV6, + connect_timeout_ms = connect_timeout_ms, + keep_alive_interval_sec = keep_alive_interval_sec, + keep_alive_timeout_sec = keep_alive_timeout_sec, + keep_alive_max_failed_probes = keep_alive_max_failed_probes, + keepalive = keepalive, ) - server.server = aws_http_server_new(FieldRef(server, :server_options)) - @assert server.server != C_NULL "failed to create server" - @atomic server.state = :running - return server -end - -const on_incoming_connection = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_connection(aws_server, aws_conn, error_code, server_ptr) - server = unsafe_pointer_to_objref(server_ptr) - Base.CoreLogging.with_logstate(server.logstate) do - if error_code != 0 - @error "incoming connection error" exception=(aws_error(error_code), Base.backtrace()) - return - end - conn = Connection( - server, - server.allocator, - aws_conn, - ) - conn.connection_options = aws_http_server_connection_options( - 1, - pointer_from_objref(conn), - on_incoming_request[], - on_connection_shutdown[] - ) - if aws_http_connection_configure_server( - aws_conn, - FieldRef(conn, :connection_options) - ) != 0 - @error "failed to configure connection" exception=(aws_error(), Base.backtrace()) - return - end - @lock server.connections_lock begin - push!(server.connections, conn) + alpn_list = _tls_alpn_list(tls_conn_opts) + initial_window = Csize_t(min(UInt64(initial_window_size), UInt64(typemax(Csize_t)))) + on_incoming_channel_setup = (bootstrap, error_code, channel, user_data) -> begin + Base.CoreLogging.with_logstate(server.logstate) do + if error_code != 0 + @error "incoming channel setup error" error_code + return + end + st = @atomic(server.state) + if st == :closing || st == :closed + Reseau.Sockets.channel_shutdown!(channel; shutdown_immediately=true) + return + end + slot = Reseau.Sockets.channel_slot_new!(channel) + Reseau.Sockets.channel_slot_insert_end!(channel, slot) + version = AwsHTTP.HttpVersion.HTTP_1_1 + if tls_conn_opts !== nothing + tls_slot = slot.adj_left + if tls_slot === nothing || tls_slot.handler === nothing || !(tls_slot.handler isa Reseau.Sockets.TlsChannelHandler) + @error "incoming channel setup error" error_code=Reseau.ERROR_INVALID_STATE + Reseau.Sockets.channel_shutdown!(channel, Reseau.ERROR_INVALID_STATE) + return + end + protocol = Reseau.Sockets.tls_handler_protocol(tls_slot.handler) + if protocol.len > 0 + protocol_str = Reseau.byte_buffer_as_string(protocol) + if protocol_str == "h2" + version = AwsHTTP.HttpVersion.HTTP_2 + elseif protocol_str == "http/1.1" + version = AwsHTTP.HttpVersion.HTTP_1_1 + end + end + end + http_conn = AwsHTTP.http_connection_new_channel_handler(; + is_server=true, + version=version, + initial_window_size=initial_window, + ) + http_conn === nothing && return + Reseau.Sockets.channel_slot_set_handler!(slot, http_conn) + http_conn.slot = slot + # Extract remote endpoint from the socket handler (first slot in pipeline) + remote_addr = "0.0.0.0" + remote_port_num = 0 + try + socket_handler = channel.first.handler + ep = socket_handler.socket.remote_endpoint + remote_addr = Reseau.Sockets.get_address(ep) + remote_port_num = Int(ep.port) + catch + end + http_conn.remote_endpoint = "$remote_addr:$remote_port_num" + conn = Connection(server, http_conn, channel, remote_addr, remote_port_num) + @lock server.connections_lock begin + push!(server.connections, conn) + end + if AwsHTTP.http_connection_get_version(http_conn) == AwsHTTP.HttpVersion.HTTP_2 + opts = AwsHTTP.HttpServerConnectionOptions( + connection_user_data = conn, + on_incoming_request = (h2conn, ud) -> begin + try + return _create_request_handler!(ud, h2conn; http2=true) + catch e + @error "failed to create HTTP/2 request handler" exception=(e, catch_backtrace()) + return nothing + end + end, + on_shutdown = (h2conn, err, ud) -> nothing, + ) + status = AwsHTTP.http_connection_configure_server(http_conn, opts) + if status != AwsHTTP.OP_SUCCESS + @error "failed to configure HTTP/2 server connection" error_code=status + return + end + else + _create_request_handler!(conn, http_conn; http2=false) + end + if Reseau.Sockets.channel_thread_is_callers_thread(channel) + Reseau.Sockets.channel_trigger_read(channel) + else + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return nothing + Reseau.Sockets.channel_trigger_read(ctx.channel) + return nothing + end, (channel = channel,), "http_server_trigger_read") + Reseau.Sockets.channel_schedule_task_now!(channel, task) + end end return end -end - -const on_connection_shutdown = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_connection_shutdown(aws_conn, error_code, conn_ptr) - conn = unsafe_pointer_to_objref(conn_ptr) - Base.CoreLogging.with_logstate(conn.server.logstate) do - if error_code != 0 - @error "connection shutdown error" exception=(aws_error(error_code), Base.backtrace()) - end - @lock conn.server.connections_lock begin - delete!(conn.server.connections, conn) + on_incoming_channel_shutdown = (bootstrap, error_code, channel, user_data) -> begin + Base.CoreLogging.with_logstate(server.logstate) do + if _should_log_channel_shutdown_error(error_code) + @error "incoming channel shutdown error" error_code + end + @lock server.connections_lock begin + filter!(c -> c.channel !== channel, server.connections) + end end return end -end - -const on_incoming_request = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_request(aws_conn, conn_ptr) - conn = unsafe_pointer_to_objref(conn_ptr) - Base.CoreLogging.with_logstate(conn.server.logstate) do - stream = Stream{typeof(conn)}( - conn.allocator, - false, # decompress - aws_http_connection_get_version(aws_conn) == AWS_HTTP_VERSION_2 # http2 - ) - stream.connection = conn - stream.request_handler_options = aws_http_request_handler_options( - 1, - aws_conn, - pointer_from_objref(stream), - on_request_headers[], - on_request_header_block_done[], - on_request_body[], - on_request_done[], - on_server_stream_complete[], - on_destroy[] - ) - stream.request = Request("", "") - stream.ptr = aws_http_stream_new_server_request_handler( - FieldRef(stream, :request_handler_options) - ) - if stream.ptr == C_NULL - @error "failed to create stream" exception=(aws_error(), Base.backtrace()) - else - @lock conn.streams_lock begin - push!(conn.streams, stream) - end - end - return stream.ptr + on_listener_destroy = (bootstrap, user_data) -> begin + notify(server.fut, :destroyed) + return end + bootstrap_opts = Reseau.Sockets.ServerBootstrapOptions(; + event_loop_group = _EVENT_LOOP_GROUP[], + socket_options = socket_opts, + host = host_str, + port = UInt32(port_int), + tls_connection_options = tls_conn_opts, + on_protocol_negotiated = nothing, + on_listener_setup = (bootstrap, error_code, user_data) -> begin + if error_code == 0 && bootstrap.listener_socket !== nothing + server.bound_port = try + ep = Reseau.Sockets.socket_get_bound_address(bootstrap.listener_socket) + Int(ep.port) + catch + port_int + end + else + server.bound_port = port_int + end + notify(listener_ready) + return nothing + end, + on_incoming_channel_setup = on_incoming_channel_setup, + on_incoming_channel_shutdown = on_incoming_channel_shutdown, + on_listener_destroy = on_listener_destroy, + user_data = server, + enable_read_back_pressure = false, + ) + bs = Reseau.Sockets.ServerBootstrap(bootstrap_opts) + server.bootstrap = bs + # Wait until the listener is ready so `port(server)` is accurate immediately. + wait(listener_ready) + @atomic server.state = :running + return server end -const on_request_headers = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - headers = stream.request.headers - addheaders(headers, header_array, num_headers) - return Cint(0) +function serve(f, host="127.0.0.1", port=8080; stream::Bool=false, kw...) + server = serve!(f, host, port; stream=stream, kw...) + wait(server) + return server end -const on_request_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_header_block_done(aws_stream_ptr, header_block, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - ret = aws_http_stream_get_incoming_request_method(aws_stream_ptr, FieldRef(stream, :method)) - ret != 0 && return ret - aws_http_message_set_request_method(stream.request.ptr, stream.method) - ret = aws_http_stream_get_incoming_request_uri(aws_stream_ptr, FieldRef(stream, :path)) - ret != 0 && return ret - aws_http_message_set_request_path(stream.request.ptr, stream.path) - return Cint(0) -end +listen!(f, host="127.0.0.1", port=8080; kw...) = serve!(f, host, port; stream=true, kw...) +listen(f, host="127.0.0.1", port=8080; kw...) = serve(f, host, port; stream=true, kw...) -const on_request_body = Ref{Ptr{Cvoid}}(C_NULL) - -#TODO: how could we allow for streaming request bodies? -function c_on_request_body(aws_stream_ptr, data::Ptr{aws_byte_cursor}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - bc = unsafe_load(data) - body = stream.request.body - if body === nothing - body = Vector{UInt8}(undef, bc.len) - GC.@preserve body unsafe_copyto!(pointer(body), bc.ptr, bc.len) - stream.request.body = body - else - newlen = length(body) + bc.len - resize!(body, newlen) - GC.@preserve body unsafe_copyto!(pointer(body, length(body) - bc.len + 1), bc.ptr, bc.len) +function _push_promise_headers!(req::Request, parent::Stream; scheme=nothing, authority=nothing) + if !hasheader(req.headers, ":scheme") + scheme_val = scheme === nothing ? header(parent.request, ":scheme", "") : String(scheme) + isempty(scheme_val) && throw(ArgumentError("push promise requires :scheme")) + addheader(req.headers, ":scheme", scheme_val) + end + if !hasheader(req.headers, ":authority") + authority_val = authority === nothing ? header(parent.request, ":authority", header(parent.request, "host", "")) : String(authority) + isempty(authority_val) && throw(ArgumentError("push promise requires :authority")) + addheader(req.headers, ":authority", authority_val) end - return Cint(0) + return end -const on_request_done = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_done(aws_stream_ptr, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - Base.CoreLogging.with_logstate(stream.connection.server.logstate) do - try - stream.response = Base.invokelatest(stream.connection.server.f, stream.request)::Response - if stream.request.method == "HEAD" - setinputstream!(stream.response, nothing) +function push_promise(parent::Stream, req::Request; pad_length::Integer=0, scheme=nothing, authority=nothing) + parent.server_side || error("push_promise is only supported for server streams") + parent.http2 || throw(ArgumentError("HTTP/2 stream required for push promise")) + pad_length < 0 && throw(ArgumentError("pad_length must be >= 0")) + pad_length > typemax(UInt8) && throw(ArgumentError("pad_length must be <= $(typemax(UInt8))")) + isdefined(parent, :aws_stream) || throw(ArgumentError("HTTP stream is not initialized")) + _push_promise_headers!(req, parent; scheme=scheme, authority=authority) + msg = getfield(req, :msg) + if AwsHTTP.http_message_get_protocol_version(msg) != AwsHTTP.HttpVersion.HTTP_2 + converted = AwsHTTP.http2_message_new_from_http1(msg) + converted === nothing && throw(AWSError("Failed to convert push promise request to HTTP/2")) + setfield!(req, :msg, converted) + msg = converted + end + h2conn = parent.aws_stream.owning_connection + h2conn === nothing && throw(ArgumentError("HTTP/2 connection is not initialized")) + promised_id = h2conn.next_stream_id + promised_id > AwsHTTP.H2_STREAM_ID_MAX && throw(AWSError("HTTP/2 stream IDs exhausted")) + h2conn.next_stream_id += UInt32(2) + h2stream = _create_request_handler!(parent.connection, h2conn; http2=true) + h2stream === nothing && throw(AWSError("Failed to create push promise stream")) + push_stream = nothing + @lock parent.connection.streams_lock begin + for s in parent.connection.streams + if s.aws_stream === h2stream + push_stream = s + break end - #TODO: is it possible to stream the response body? - #TODO: support transfer-encoding: gzip - catch e - @error "Request handler error; sending 500" exception=(e, catch_backtrace()) - stream.response = Response(500) end - ret = aws_http_stream_send_response(aws_stream_ptr, stream.response.ptr) - if ret != 0 - @error "failed to send response" exception=(aws_error(ret), Base.backtrace()) - return Cint(AWS_ERROR_HTTP_UNKNOWN) + end + push_stream === nothing && throw(AWSError("Failed to locate push promise stream")) + push_stream.request = req + notify(push_stream.headers_ready) + method_val = req.method + path_val = req.path + method_val === nothing && (method_val = "") + path_val === nothing && (path_val = "") + h2stream.id = promised_id + AwsHTTP.h2_stream_init_window_sizes!(h2stream, h2conn) + h2stream.metrics = AwsHTTP.HttpStreamMetrics( + h2stream.metrics.send_start_timestamp_ns, + h2stream.metrics.send_end_timestamp_ns, + h2stream.metrics.sending_duration_ns, + h2stream.metrics.receive_start_timestamp_ns, + h2stream.metrics.receive_end_timestamp_ns, + h2stream.metrics.receiving_duration_ns, + promised_id, + ) + h2stream.state = AwsHTTP.H2StreamState.RESERVED_LOCAL + h2stream.request_method = AwsHTTP.http_str_to_method(String(method_val)) + h2stream.request_method_str = String(method_val) + h2stream.request_path = String(path_val) + h2conn.active_streams[promised_id] = h2stream + headers = AwsHTTP.http_message_get_headers(msg) + status = AwsHTTP.h2_stream_send_push_promise!(parent.aws_stream, h2conn, promised_id, headers; + pad_length=UInt8(pad_length)) + if status != AwsHTTP.OP_SUCCESS + delete!(h2conn.active_streams, promised_id) + @lock parent.connection.streams_lock begin + delete!(parent.connection.streams, push_stream) end - return Cint(0) + throw(AWSError("Failed to send push promise")) end + return push_stream end -const on_server_stream_complete = Ref{Ptr{Cvoid}}(C_NULL) +function push_promise(parent::Stream, method::Union{String, Symbol}, path; headers=Header[], pad_length::Integer=0, scheme=nothing, authority=nothing) + return push_promise(parent, Request(String(method), String(path), headers, nothing, true); pad_length=pad_length, scheme=scheme, authority=authority) +end -function c_on_server_stream_complete(aws_stream_ptr, error_code, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - Base.CoreLogging.with_logstate(stream.connection.server.logstate) do - if error_code != 0 - @error "server complete error" exception=(aws_error(error_code), Base.backtrace()) - end - if stream.connection.server.on_stream_complete !== nothing - try - Base.invokelatest(stream.connection.server.on_stream_complete, stream) - catch e - @error "on_stream_complete error" exception=(e, catch_backtrace()) +function _forceclose!(server::Server; skip_shutdown::Bool=false) + skip_shutdown || shutdown(server.on_shutdown) + Reseau.Sockets.server_bootstrap_shutdown!(server.bootstrap) + conns = Connection[] + @lock server.connections_lock begin + append!(conns, server.connections) + end + for conn in conns + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) + end + @atomic server.state = :closed + notify(server.closed) + return +end + +function Base.close(server::Server) + state = @atomicswap server.state = :closing + if state == :closed + return + elseif state == :closing + wait(server.closed) + return + end + shutdown(server.on_shutdown) + Reseau.Sockets.server_bootstrap_shutdown!(server.bootstrap) + conns = Connection[] + @lock server.connections_lock begin + append!(conns, server.connections) + end + for conn in conns + _stop_new_requests!(conn) + @lock conn.streams_lock begin + if isempty(conn.streams) + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) + @lock server.connections_lock begin + delete!(server.connections, conn) + end end end - if stream.connection.server.access_log !== nothing - try - @info sprint(stream.connection.server.access_log, stream) _group=:access - catch e - @error "access log error" exception=(e, catch_backtrace()) - end + end + deadline = time() + 0.5 + while time() < deadline + empty = @lock server.connections_lock begin + isempty(server.connections) end - @lock stream.connection.streams_lock begin - delete!(stream.connection.streams, stream) + if empty + @atomic server.state = :closed + notify(server.closed) + return end - return Cint(0) + _task_sleep_s(0.05) end -end - -const on_destroy_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_destroy_complete(server_ptr) - server = unsafe_pointer_to_objref(server_ptr) - notify(server.fut, :destroyed) + _forceclose!(server; skip_shutdown=true) return end -function Base.close(server::Server) +function forceclose(server::Server) state = @atomicswap server.state = :closed - if state == :running - aws_http_server_release(server.server) - @assert wait(server.fut) == :destroyed - notify(server.closed) - end + state == :closed && return + _forceclose!(server; skip_shutdown = state == :closing) return end -Base.isopen(server::Server) = @atomic(server.state) != :closed +Base.isopen(server::Server) = @atomic(server.state) == :running diff --git a/src/statistics.jl b/src/statistics.jl new file mode 100644 index 000000000..5513ed955 --- /dev/null +++ b/src/statistics.jl @@ -0,0 +1,74 @@ +export AWSCRT_STAT_CAT_HTTP1_CHANNEL, AWSCRT_STAT_CAT_HTTP2_CHANNEL +export aws_crt_statistics_http1_channel, aws_crt_statistics_http2_channel +export _decode_statistics, _call_statistics_observer + +const AWSCRT_STAT_CAT_HTTP1_CHANNEL = :http1_channel +const AWSCRT_STAT_CAT_HTTP2_CHANNEL = :http2_channel + +struct aws_crt_statistics_http1_channel + category::Symbol + pending_outgoing_stream_ms::UInt64 + pending_incoming_stream_ms::UInt64 + current_outgoing_stream_id::UInt32 + current_incoming_stream_id::UInt32 +end + +struct aws_crt_statistics_http2_channel + category::Symbol + pending_outgoing_stream_ms::UInt64 + pending_incoming_stream_ms::UInt64 + was_inactive::Bool +end + +function _normalize_stat_category(category) + if category === AWSCRT_STAT_CAT_HTTP1_CHANNEL || category === AWSCRT_STAT_CAT_HTTP2_CHANNEL + return category + end + category isa Symbol && return category + return Symbol(category) +end + +function _normalize_stat(stat::aws_crt_statistics_http1_channel) + cat = _normalize_stat_category(stat.category) + cat === stat.category && return stat + return aws_crt_statistics_http1_channel( + cat, + stat.pending_outgoing_stream_ms, + stat.pending_incoming_stream_ms, + stat.current_outgoing_stream_id, + stat.current_incoming_stream_id, + ) +end + +function _normalize_stat(stat::aws_crt_statistics_http2_channel) + cat = _normalize_stat_category(stat.category) + cat === stat.category && return stat + return aws_crt_statistics_http2_channel( + cat, + stat.pending_outgoing_stream_ms, + stat.pending_incoming_stream_ms, + stat.was_inactive, + ) +end + +_normalize_stat(stat) = stat + +function _decode_statistics(stats_list) + list = stats_list isa Base.RefValue ? stats_list[] : stats_list + out = Any[] + if list isa AbstractVector + for item in list + item = item isa Base.RefValue ? item[] : item + push!(out, _normalize_stat(item)) + end + return out + end + throw(ArgumentError("stats_list must be an AbstractVector")) +end + +function _call_statistics_observer(observer, nonce, stats_list) + observer === nothing && return nothing + stats = _decode_statistics(stats_list) + Base.invokelatest(observer, nonce, stats) + return nothing +end diff --git a/src/utils.jl b/src/utils.jl index 125ad21d8..d282d18e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,127 @@ +export bytes, isbytes, nbytes, nobytes, + escapehtml, tocameldash, iso8859_1_to_utf8, ascii_lc_isequal + +const HTTP2_DEFAULT_WINDOW_SIZE = 65535 +const HTTP2_MAX_WINDOW_SIZE = 0x7fffffff +const AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = AwsHTTP.Http2SettingsId.MAX_CONCURRENT_STREAMS +const AWS_HTTP2_SETTINGS_INITIAL_WINDOW_SIZE = AwsHTTP.Http2SettingsId.INITIAL_WINDOW_SIZE +const AWS_HTTP2_SETTINGS_COUNT = Int(AwsHTTP.HTTP2_SETTINGS_END_RANGE - AwsHTTP.HTTP2_SETTINGS_BEGIN_RANGE) +const _H2_CHANNEL_SUPPORTED = AwsHTTP.H2Connection <: Reseau.Sockets.AbstractChannelHandler + +function _normalize_alpn_list(alpn_list::Union{String, Nothing}) + alpn_list === nothing && return nothing + isempty(alpn_list) && return alpn_list + _H2_CHANNEL_SUPPORTED && return alpn_list + parts = split(alpn_list, ';'; keepempty = false) + filtered = [p for p in parts if lowercase(p) != "h2"] + isempty(filtered) && return "http/1.1" + return join(filtered, ';') +end + +@inline function _alpn_includes_h2(alpn_list::Union{String, Nothing})::Bool + alpn_list === nothing && return false + for part in split(alpn_list, ';'; keepempty = false) + lowercase(part) == "h2" && return true + end + return false +end + +function _should_use_nw_tls(alpn_list::Union{String, Nothing})::Bool + @static if Sys.isapple() + return _use_nw_sockets() && _alpn_includes_h2(alpn_list) + else + return false + end +end + +function _use_nw_sockets()::Bool + @static if Sys.isapple() + Reseau.Sockets._tls_set_use_secitem_from_env() + return Reseau.Sockets.is_using_secitem() + else + return false + end +end + +function _tls_alpn_list(tls_opts) + tls_opts === nothing && return nothing + if hasproperty(tls_opts, :alpn_list) && tls_opts.alpn_list !== nothing + return tls_opts.alpn_list + end + if hasproperty(tls_opts, :ctx) && tls_opts.ctx !== nothing + return tls_opts.ctx.options.alpn_list + end + return nothing +end + +""" + HTTPVersion(major, minor) + +The HTTP version number consists of two digits separated by a ".". The first +digit (`major`) indicates the HTTP messaging syntax, whereas the second digit +(`minor`) indicates the highest minor version within that major version to which +the sender is conformant and able to understand for future communication. +""" +struct HTTPVersion + major::UInt8 + minor::UInt8 +end + +HTTPVersion(major::Integer) = HTTPVersion(major, 0x00) +HTTPVersion(v::AbstractString) = parse(HTTPVersion, v) +HTTPVersion(v::VersionNumber) = convert(HTTPVersion, v) +Base.convert(::Type{HTTPVersion}, v::VersionNumber) = HTTPVersion(v.major, v.minor) +Base.VersionNumber(v::HTTPVersion) = VersionNumber(v.major, v.minor) + +Base.show(io::IO, v::HTTPVersion) = print(io, "HTTPVersion(\"", string(v.major), ".", string(v.minor), "\")") +Base.write(io::IO, v::HTTPVersion) = write(io, "HTTP/", string(v.major), ".", string(v.minor)) + +Base.:(==)(va::VersionNumber, vb::HTTPVersion) = va == VersionNumber(vb) +Base.:(==)(va::HTTPVersion, vb::VersionNumber) = VersionNumber(va) == vb +Base.isless(va::VersionNumber, vb::HTTPVersion) = isless(va, VersionNumber(vb)) +Base.isless(va::HTTPVersion, vb::VersionNumber) = isless(VersionNumber(va), vb) +function Base.isless(va::HTTPVersion, vb::HTTPVersion) + va.major < vb.major && return true + va.major > vb.major && return false + va.minor < vb.minor && return true + return false +end + +function Base.parse(::Type{HTTPVersion}, v::AbstractString) + ver = tryparse(HTTPVersion, v) + ver === nothing && throw(ArgumentError("invalid HTTP version string: $(repr(v))")) + return ver +end + +# We only support single-digits for major and minor versions. +function Base.tryparse(::Type{HTTPVersion}, v::AbstractString) + isempty(v) && return nothing + len = ncodeunits(v) + + i = firstindex(v) + d1 = v[i] + if isdigit(d1) + major = parse(UInt8, d1) + else + return nothing + end + + i = nextind(v, i) + i > len && return HTTPVersion(major) + dot = v[i] + dot == '.' || return nothing + + i = nextind(v, i) + i > len && return HTTPVersion(major) + d2 = v[i] + if isdigit(d2) + minor = parse(UInt8, d2) + else + return nothing + end + return HTTPVersion(major, minor) +end + """ escapehtml(i::String) @@ -58,8 +182,24 @@ tocameldash(s::AbstractString) = tocameldash(String(s)) @inline isupper(b::UInt8) = UInt8('A') <= b <= UInt8('Z') @inline lower(c::UInt8) = c | 0x20 -function parseuri(url, query, allocator) - uri_ref = Ref{aws_uri}() +""" + ascii_lc_isequal(a, b) + +Case insensitive ASCII string comparison. +""" +function ascii_lc_isequal(a, b) + acu = codeunits(a) + bcu = codeunits(b) + len = length(acu) + len != length(bcu) && return false + for i = 1:len + @inbounds (acu[i] in UInt8('A'):UInt8('Z') ? acu[i] + 0x20 : acu[i]) == + (bcu[i] in UInt8('A'):UInt8('Z') ? bcu[i] + 0x20 : bcu[i]) || return false + end + return true +end + +function parseuri(url, query) if url isa AbstractString url_str = String(url) * (query === nothing ? "" : ("?" * URIs.escapeuri(query))) elseif url isa URI @@ -67,83 +207,157 @@ function parseuri(url, query, allocator) else throw(ArgumentError("url must be an AbstractString or URI")) end - GC.@preserve url_str begin - url_ref = Ref(aws_byte_cursor(sizeof(url_str), pointer(url_str))) - aws_uri_init_parse(uri_ref, allocator, url_ref) - end - return uri_ref[] + return URIs.URI(url_str) end +# compatibility: 3-arg version for callers that still pass allocator +parseuri(url, query, _allocator) = parseuri(url, query) + +""" + bytes(x) + +If `x` is "castable" to an `AbstractVector{UInt8}`, then an +`AbstractVector{UInt8}` is returned; otherwise `x` is returned. +""" +function bytes end +bytes(s::AbstractVector{UInt8}) = s +bytes(s::AbstractString) = codeunits(s) +bytes(x) = x + isbytes(x) = x isa AbstractVector{UInt8} || x isa AbstractString -str(bc::aws_byte_cursor) = bc.ptr == C_NULL || bc.len == 0 ? "" : unsafe_string(bc.ptr, bc.len) +""" + nbytes(x) -> Int -function print_uri(io, uri::aws_uri) - print(io, "scheme: ", str(uri.scheme), "\n") - print(io, "userinfo: ", str(uri.userinfo), "\n") - print(io, "host_name: ", str(uri.host_name), "\n") - print(io, "port: ", Int(uri.port), "\n") - print(io, "path: ", str(uri.path), "\n") - print(io, "query: ", str(uri.query_string), "\n") - return -end +Length in bytes of `x` if `x` is `isbytes(x)`. +""" +function nbytes end +nbytes(x) = nothing +nbytes(x::AbstractVector{UInt8}) = length(x) +nbytes(x::AbstractString) = sizeof(x) +nbytes(x::Vector{T}) where T <: AbstractString = sum(sizeof, x) +nbytes(x::Vector{T}) where T <: AbstractVector{UInt8} = sum(length, x) +nbytes(x::IOBuffer) = bytesavailable(x) +nbytes(x::Vector{IOBuffer}) = sum(bytesavailable, x) -scheme(uri::aws_uri) = str(uri.scheme) -userinfo(uri::aws_uri) = str(uri.userinfo) -host(uri::aws_uri) = str(uri.host_name) -port(uri::aws_uri) = uri.port -path(uri::aws_uri) = str(uri.path) -query(uri::aws_uri) = str(uri.query_string) +const nobytes = view(UInt8[], 1:0) -function resource(uri::aws_uri) - ref = Ref(uri) - GC.@preserve ref begin - bc = aws_uri_path_and_query(ref) - path = str(unsafe_load(bc)) - return isempty(path) ? "/" : path +# URI accessor helpers that work on URIs.URI +scheme(uri::URI) = uri.scheme +userinfo(uri::URI) = uri.userinfo +host(uri::URI) = uri.host +function port(uri::URI) + p = uri.port + if p === nothing || isempty(p) + return UInt32(0) end + return UInt32(parse(Int, p)) +end +path(uri::URI) = uri.path +query(uri::URI) = uri.query + +function resource(uri::URI) + p = uri.path + q = uri.query + r = isempty(p) ? "/" : p + return isempty(q) ? r : string(r, "?", q) end const URI_SCHEME_HTTPS = "https" const URI_SCHEME_WSS = "wss" -ishttps(sch) = aws_byte_cursor_eq_c_str_ignore_case(sch, URI_SCHEME_HTTPS) -iswss(sch) = aws_byte_cursor_eq_c_str_ignore_case(sch, URI_SCHEME_WSS) -function getport(uri::aws_uri) - sch = Ref(uri.scheme) - GC.@preserve sch begin - return UInt32(uri.port != 0 ? uri.port : (ishttps(sch) || iswss(sch)) ? 443 : 80) - end +ishttps(sch::AbstractString) = lowercase(sch) == URI_SCHEME_HTTPS +iswss(sch::AbstractString) = lowercase(sch) == URI_SCHEME_WSS +function getport(uri::URI) + p = port(uri) + return p != 0 ? p : (ishttps(scheme(uri)) || iswss(scheme(uri))) ? UInt32(443) : UInt32(80) end -function makeuri(u::aws_uri) - return URIs.URI( - scheme=str(u.scheme), - userinfo=isempty(str(u.userinfo)) ? URIs.absent : str(u.userinfo), - host=str(u.host_name), - port=u.port == 0 ? URIs.absent : u.port, - path=isempty(str(u.path)) ? URIs.absent : str(u.path), - query=isempty(str(u.query_string)) ? URIs.absent : str(u.query_string), - ) -end +makeuri(u::URI) = u struct AWSError <: Exception msg::String end -aws_error() = AWSError(unsafe_string(aws_error_debug_str(aws_last_error()))) -aws_error(error_code) = AWSError(unsafe_string(aws_error_str(error_code))) +function _resolve_error_str(error_code::Integer) + ec = Int(error_code) + # AwsHTTP has its own String-based error table for HTTP-range codes + if ec >= AwsHTTP.ERROR_HTTP_UNKNOWN && ec <= AwsHTTP.ERROR_HTTP_END_RANGE + return AwsHTTP.http_error_str(ec) + end + return Reseau.error_str(ec) +end + +aws_error() = AWSError(_resolve_error_str(Reseau.last_error())) +aws_error(error_code) = AWSError(_resolve_error_str(error_code)) aws_throw_error() = throw(aws_error()) +# Simple Future type for async callback coordination. +# Replaces LibAwsCommon.Future. Supports notify/wait pattern: +# notify(f, value::T) -> success +# notify(f, err::Exception) -> error +# wait(f) -> returns T or throws Exception +mutable struct Future{T} + const notify_cond::Threads.Condition + @atomic set::Int8 # 0=pending, 1=success, 2=error + result::Union{Exception, T} + Future{T}() where {T} = new{T}(Threads.Condition(), 0) +end + +Future() = Future{Nothing}() + +function Base.wait(f::Future{T}) where {T} + set = @atomic f.set + set == 1 && return f.result::T + set == 2 && throw(f.result::Exception) + lock(f.notify_cond) + try + set = f.set + set == 1 && return f.result::T + set == 2 && throw(f.result::Exception) + wait(f.notify_cond) + finally + unlock(f.notify_cond) + end + f.set == 1 && return f.result::T + throw(f.result::Exception) +end + +function Base.notify(f::Future{T}, result::T) where {T} + lock(f.notify_cond) + try + f.set != 0 && return + f.result = result + @atomic f.set = 1 + notify(f.notify_cond) + finally + unlock(f.notify_cond) + end + return +end + +function Base.notify(f::Future, err::Exception) + lock(f.notify_cond) + try + f.set != 0 && return + f.result = err + @atomic f.set = 2 + notify(f.notify_cond) + finally + unlock(f.notify_cond) + end + return +end + struct BufferOnResponseBody{T <: AbstractVector{UInt8}} buffer::T - pos::Ptr{Int} + pos::Ref{Int} end function (f::BufferOnResponseBody)(resp, buf) len = length(buf) - pos = unsafe_load(f.pos) + pos = f.pos[] copyto!(f.buffer, pos, buf, 1, len) - unsafe_store!(f.pos, pos + len) + f.pos[] = pos + len return len end diff --git a/src/websockets.jl b/src/websockets.jl index 82e58ea25..1aca9eb4d 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -1,279 +1,456 @@ module WebSockets -using Base64, Random, LibAwsHTTPFork, LibAwsCommon, LibAwsIO - -import ..FieldRef, ..iswss, ..getport, ..makeuri, ..aws_throw_error, ..resource, ..Headers, ..Header, ..str, ..aws_error, ..aws_throw_error, ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient, ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..Client, ..Request, ..Response, ..setinputstream!, ..getresponse, ..CookieJar, ..COOKIEJAR, ..addheaders, ..Stream, ..HTTP, ..getheader +using Base64, Random, AwsHTTP, Reseau + +import ..Headers, ..Header, ..Request, ..Response, ..Message, ..Stream +import ..setinputstream!, ..getresponse, ..getheader, ..hasheader, ..header +import ..addheader, ..setheader, ..removeheader +import ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient +import ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..resource +import ..Client, ..CookieJar, ..COOKIEJAR, ..with_connection, .._open_stream +import ..aws_throw_error, ..aws_error, ..AWSError +import .._h1_flush_outgoing!, ..iswss, ..makeuri, ..HTTP, ..mkreqheaders +import ..startread, ..closeread, ..startwrite, ..closewrite, .._task_sleep_s export WebSocket, send, receive, ping, pong +# ─── Types ─── + @enum OpCode::UInt8 CONTINUATION=0x00 TEXT=0x01 BINARY=0x02 CLOSE=0x08 PING=0x09 PONG=0x0A -mutable struct WebSocket - id::String - host::String - path::String - not::Future{Nothing} - readchannel::Channel{Union{String, Vector{UInt8}}} - writebuffer::Vector{UInt8} - writepos::Int - writeclosed::Bool - closelock::ReentrantLock - handshake_request::Request - options::aws_websocket_client_connection_options - websocket_pointer::Ptr{aws_websocket} - handshake_response::Response - websocket_send_frame_options::aws_websocket_send_frame_options +const DEFAULT_MAX_FRAG = 1024 - WebSocket(host::AbstractString, path::AbstractString) = new(string(rand(UInt32); base=58), String(host), String(path), Future{Nothing}(), Channel{Union{String, Vector{UInt8}}}(Inf), UInt8[], 0, false, ReentrantLock()) +struct CloseFrameBody + code::Int + reason::String end -getresponse(ws::WebSocket) = ws.handshake_response +struct WebSocketError <: Exception + message::CloseFrameBody +end -const on_connection_setup = Ref{Ptr{Cvoid}}(C_NULL) +isok(e::WebSocketError) = e.message.code in (1000, 1001, 1005) +isok(::Any) = false -function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_connection_setup_data}, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - data = unsafe_load(connection_setup_data) - try - if data.error_code != 0 - notify(ws.not, CapturedException(aws_error(data.error_code), Base.backtrace())) - else - ws.websocket_pointer = data.websocket - ws.handshake_response.status = unsafe_load(data.handshake_response_status) - addheaders(ws.handshake_response.headers, data.handshake_response_header_array, data.num_handshake_response_headers) - if data.handshake_response_body != C_NULL - handshake_response_body = unsafe_load(data.handshake_response_body) - response_body = str(handshake_response_body) - else - response_body = nothing +function isupgrade(r::Message) + ((r isa Request && r.method == "GET") || + (r isa Response && r.status == 101)) && + (hasheader(r, "Connection", "upgrade") || + hasheader(r, "Connection", "keep-alive, upgrade")) && + hasheader(r, "Upgrade", "websocket") +end + +isupgrade(s::Stream) = isupgrade(s.request) + +Base.@deprecate is_upgrade isupgrade + +# ─── WsChannelHandler ─── +# Bridges the Reseau channel pipeline with the AwsHTTP WebSocket codec. +# Installed into the H1Connection's channel slot after HTTP 101 upgrade. + +mutable struct WsChannelHandler <: Reseau.Sockets.AbstractChannelHandler + slot::Union{Reseau.Sockets.ChannelSlot, Nothing} + aws_ws::Any # AwsHTTP.WebSocket + wslock::ReentrantLock # protects outgoing_frames access + ws::Any +end + +WsChannelHandler(aws_ws, ws) = WsChannelHandler(nothing, aws_ws, ReentrantLock(), ws) + +function Reseau.Sockets.setchannelslot!(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot)::Nothing + handler.slot = slot + return nothing +end + +function Reseau.Sockets.handler_process_read_message(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, message::Reseau.Sockets.IoMessage)::Nothing + data = Reseau.byte_buffer_as_vector(message.message_data) + isempty(data) && return nothing + @lock handler.wslock begin + status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, data) + if status != AwsHTTP.OP_SUCCESS + ws = handler.ws + if ws !== nothing && !ws.readclosed + close_body = status == AwsHTTP.ERROR_HTTP_WEBSOCKET_PROTOCOL_ERROR ? + CloseFrameBody(1002, "WebSocket protocol error") : + CloseFrameBody(1011, "WebSocket error") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) end - setinputstream!(ws.handshake_response, response_body) - notify(ws.not, nothing) + Reseau.throw_error(status) end - catch e - notify(ws.not, CapturedException(e, Base.backtrace())) + # Flush auto-responses (PONG, CLOSE echo) generated by ws_on_incoming_data! + _ws_channel_flush!(handler) end - return + return nothing end -const on_connection_shutdown = Ref{Ptr{Cvoid}}(C_NULL) +function Reseau.Sockets.handler_process_write_message(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, message::Reseau.Sockets.IoMessage)::Nothing + # Pass through to lower pipeline (socket) + Reseau.Sockets.channel_slot_send_message(slot, message, Reseau.Sockets.ChannelDirection.WRITE) + return nothing +end -function c_on_connection_shutdown(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - if error_code != 0 - @error "$(ws.id): connection shutdown error" exception=(aws_error(error_code), Base.backtrace()) +function Reseau.Sockets.handler_increment_read_window(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, size::Csize_t)::Nothing + Reseau.Sockets.channel_slot_increment_read_window!(slot, size) + return nothing +end + +function Reseau.Sockets.handler_shutdown( + handler::WsChannelHandler, + slot::Reseau.Sockets.ChannelSlot, + direction::Reseau.Sockets.ChannelDirection.T, + error_code::Int, + free_scarce_resources_immediately::Bool, + )::Nothing + ws = handler.ws + if ws !== nothing && !ws.readclosed + _queue_close!(ws, CloseFrameBody(1006, "")) + end + Reseau.Sockets.channel_slot_on_handler_shutdown_complete!(slot, direction, error_code, free_scarce_resources_immediately) + return nothing +end + +Reseau.Sockets.handler_initial_window_size(::WsChannelHandler)::Csize_t = Csize_t(typemax(UInt64)) +Reseau.Sockets.handler_message_overhead(::WsChannelHandler)::Csize_t = Csize_t(0) + +# Collect outgoing frames from the AwsHTTP WebSocket codec and send them +# through the channel pipeline. Caller MUST hold handler.wslock. +function _ws_channel_flush!(handler::WsChannelHandler) + outdata = AwsHTTP.ws_get_outgoing_data!(handler.aws_ws) + isempty(outdata) && return + slot = handler.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + msg = Reseau.Sockets.IoMessage(length(outdata)) + buf = msg.message_data + @inbounds for i in 1:length(outdata) + buf.mem[i] = outdata[i] end - close(ws) + buf.len = Csize_t(length(outdata)) + if Reseau.Sockets.channel_thread_is_callers_thread(channel) + Reseau.Sockets.channel_slot_send_message(slot, msg, Reseau.Sockets.ChannelDirection.WRITE) + return + end + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return nothing + Reseau.Sockets.channel_slot_send_message(ctx.slot, ctx.msg, Reseau.Sockets.ChannelDirection.WRITE) + return nothing + end, (slot=slot, msg=msg), "http_ws_flush") + Reseau.Sockets.channel_schedule_task_now!(channel, task) return end -const on_incoming_frame_begin = Ref{Ptr{Cvoid}}(C_NULL) +# ─── WebSocket struct ─── -function c_on_incoming_frame_begin(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, ws_ptr) - # ws = unsafe_pointer_to_objref(ws_ptr) - # fr = unsafe_load(frame) - return true +mutable struct WebSocket + id::String + host::String + path::String + maxframesize::Int + maxfragmentation::Int + is_client::Bool + readchannel::Channel{Union{String, Vector{UInt8}, WebSocketError}} + readclosed::Bool + writeclosed::Bool + closelock::ReentrantLock + sendlock::ReentrantLock + handshake_request::Union{Nothing, Request} + handshake_response::Union{Nothing, Response} + # AwsHTTP WebSocket codec + channel handler + aws_ws::Any # AwsHTTP.WebSocket + handler::Any # WsChannelHandler + # Fragment tracking + incoming_opcode::UInt8 + incoming_fin::Bool + incoming_payload::Vector{UInt8} + fragment_opcode::Union{Nothing, UInt8} + fragment_payload::Vector{UInt8} + fragment_count::Int + drop_incoming::Bool + closebody::Union{Nothing, CloseFrameBody} + + WebSocket(host::AbstractString, path::AbstractString; maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, is_client::Bool=true) = new( + string(rand(UInt32); base=58), + String(host), + String(path), + Int(maxframesize), + Int(maxfragmentation), + is_client, + Channel{Union{String, Vector{UInt8}, WebSocketError}}(Inf), + false, + false, + ReentrantLock(), + ReentrantLock(), + nothing, + nothing, + nothing, # aws_ws + nothing, # handler + 0x00, + false, + UInt8[], + nothing, + UInt8[], + 0, + false, + nothing, + ) end -const on_incoming_frame_payload = Ref{Ptr{Cvoid}}(C_NULL) +getresponse(ws::WebSocket) = ws.handshake_response -function c_on_incoming_frame_payload(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, data::aws_byte_cursor, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - fr = unsafe_load(frame) - try - if fr.opcode == UInt8(TEXT) - put!(ws.readchannel, unsafe_string(data.ptr, data.len)) - else - rec = Vector{UInt8}(undef, data.len) - Base.unsafe_copyto!(pointer(rec), data.ptr, data.len) - put!(ws.readchannel, rec) +# ─── Internal helpers ─── + +function _queue_close!(ws::WebSocket, body::CloseFrameBody) + ws.closebody = body + ws.readclosed = true + if isopen(ws.readchannel) + try + put!(ws.readchannel, WebSocketError(body)) + catch end - catch e - @error "$(ws.id): incoming frame payload error" exception=(e, catch_backtrace()) + Base.close(ws.readchannel) end - return true + return end -const on_incoming_frame_complete = Ref{Ptr{Cvoid}}(C_NULL) +function _close_channel!(ws::WebSocket) + isopen(ws.readchannel) && Base.close(ws.readchannel) + return +end -function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, error_code::Cint, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - fr = unsafe_load(frame) - if error_code != 0 - @error "$(ws.id): incoming frame complete error" exception=(aws_error(error_code), Base.backtrace()) - end - return true +function _shutdown_ws_channel!(handler::WsChannelHandler) + slot = handler.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + Reseau.Sockets.channel_shutdown!(channel; shutdown_immediately=true) + return end -function open(f::Function, url; - headers=[], - allocator::Ptr{aws_allocator}=default_aws_allocator(), - username=nothing, - password=nothing, - bearer=nothing, - query=nothing, - client::Union{Nothing, Client}=nothing, - # redirect options - redirect=true, - redirect_limit=3, - redirect_method=nothing, - forwardheaders=true, - # cookie options - cookies=true, - cookiejar::CookieJar=COOKIEJAR, - modifier=nothing, - verbose=0, - # client keywords - kw... - ) - key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) - uri = parseuri(url, query, allocator) - # add required websocket headers - append!(headers, [ - "upgrade" => "websocket", - "connection" => "upgrade", - "sec-websocket-key" => key, - "sec-websocket-version" => "13" - ]) - ws = with_redirect(allocator, "GET", uri, headers, nothing, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body - reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); allocator=allocator, ssl_alpn_list="http/1.1", kw...)))::Client - path = resource(uri) - with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req - host = str(uri.host_name) - ws = WebSocket(host, path) - ws.handshake_request = req - ws.handshake_response = Response(0, nothing, nothing, false, allocator) - ws.options = aws_websocket_client_connection_options( - allocator, - reqclient.settings.bootstrap, - pointer(FieldRef(reqclient, :socket_options)), - reqclient.tls_options === nothing ? C_NULL : pointer(FieldRef(reqclient, :tls_options)), - reqclient.proxy_options === nothing ? C_NULL : pointer(FieldRef(reqclient, :proxy_options)), - uri.host_name, - uri.port, - ws.handshake_request.ptr, - 0, # initial_window_size - Ptr{Cvoid}(pointer_from_objref(ws)), # user_data - on_connection_setup[], - on_connection_shutdown[], - on_incoming_frame_begin[], - on_incoming_frame_payload[], - on_incoming_frame_complete[], - false, # manual_window_management - C_NULL, # requested_event_loop - C_NULL, # host_resolution_config - ) - if aws_websocket_client_connect(FieldRef(ws, :options)) != 0 - aws_throw_error() - end - # wait until connected - wait(ws.not) - return ws +function _enqueue_message!(ws::WebSocket, msg) + if isopen(ws.readchannel) + try + put!(ws.readchannel, msg) + catch end end - verbose > 0 && @info "$(ws.id): WebSocket opened" - try - f(ws) - catch e - # if !isok(e) - # suppress_close_error || @error "$(ws.id): error" (e, catch_backtrace()) - # end - # if !isclosed(ws) - # if e isa WebSocketError && e.message isa CloseFrameBody - # close(ws, e.message) - # else - # close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) - # end - # end - # if !isok(e) - rethrow() - # end - finally - # if !isclosed(ws) - close(ws) - # end - end + return end -function Base.close(ws::WebSocket) - @lock ws.closelock begin - if ws.websocket_pointer != C_NULL - aws_websocket_close(ws.websocket_pointer, false) - ws.websocket_pointer = C_NULL - ws.writeclosed = true - end - end +function _valid_close_status(code::Int)::Bool + code < 0 && return false + code > typemax(UInt16) && return false + return AwsHTTP.ws_is_valid_close_status(UInt16(code)) +end + +function _queue_protocol_error!(ws::WebSocket, reason::String) + close_body = CloseFrameBody(1002, reason) + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) return end -""" - WebSockets.isclosed(ws) -> Bool +function _queue_invalid_payload!(ws::WebSocket, reason::String) + close_body = CloseFrameBody(1007, reason) + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return +end -Check whether a `WebSocket` has sent and received CLOSE frames -""" -isclosed(ws::WebSocket) = !isopen(ws.readchannel) && ws.writeclosed +function close_payload(body::CloseFrameBody) + reason_bytes = collect(codeunits(body.reason)) + payload = Vector{UInt8}(undef, 2 + length(reason_bytes)) + payload[1] = UInt8((body.code >> 8) & 0xff) + payload[2] = UInt8(body.code & 0xff) + if !isempty(reason_bytes) + copyto!(payload, 3, reason_bytes, 1, length(reason_bytes)) + end + return payload +end isbinary(x) = x isa AbstractVector{UInt8} istext(x) = x isa AbstractString opcode(x) = isbinary(x) ? BINARY : TEXT -function payload(ws, x) - pload = isbinary(x) ? x : codeunits(string(x)) - len = length(pload) - resize!(ws.writebuffer, len) - copyto!(ws.writebuffer, pload) - ws.writepos = 1 - return ws.writebuffer +_to_bytes(x::AbstractVector{UInt8}) = x +_to_bytes(x) = Vector{UInt8}(codeunits(string(x))) + +# ─── AwsHTTP WebSocket callback builders ─── +# These create the closure callbacks passed to AwsHTTP.ws_new(). +# Each closure captures the HTTP.WebSocket and manipulates it directly. + +function _on_incoming_frame_begin(ws::WebSocket) + return (aws_ws, frame_info, user_data) -> begin + ws.incoming_opcode = frame_info.opcode + ws.incoming_fin = frame_info.fin + empty!(ws.incoming_payload) + ws.drop_incoming = false + if frame_info.payload_length > ws.maxframesize + close_body = CloseFrameBody(1009, "frame too large") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + ws.drop_incoming = true + end + frame_info.payload_length > 0 && sizehint!(ws.incoming_payload, Int(frame_info.payload_length)) + return true + end end -const stream_outgoing_payload = Ref{Ptr{Cvoid}}(C_NULL) - -function c_stream_outgoing_payload(websocket::Ptr{aws_websocket}, out_buf::Ptr{aws_byte_buf}, ws_ptr::Ptr{Cvoid}) - ws = unsafe_pointer_to_objref(ws_ptr) - out = unsafe_load(out_buf) - try - space_available = out.capacity - out.len - amount_to_send = min(space_available, sizeof(ws.writebuffer) - ws.writepos + 1) - cursor = aws_byte_cursor(amount_to_send, pointer(ws.writebuffer, ws.writepos)) - @assert aws_byte_buf_write_from_whole_cursor(out_buf, cursor) - ws.writepos += amount_to_send - catch e - @error "$(ws.id): error" (e, catch_backtrace()) - return false +function _on_incoming_frame_payload(ws::WebSocket) + return (aws_ws, frame_info, data, user_data) -> begin + ws.drop_incoming && return true + try + n = length(data) + n == 0 && return true + append!(ws.incoming_payload, data) + catch e + @error "$(ws.id): incoming frame payload error" exception=(e, catch_backtrace()) + end + return true end - return true end -const on_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_complete(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr::Ptr{Cvoid}) - ws = unsafe_pointer_to_objref(ws_ptr) - if error_code != 0 - notify(ws.not, CapturedException(aws_error(error_code), Base.backtrace())) +function _on_incoming_frame_complete(ws::WebSocket) + return (aws_ws, frame_info, error_code, user_data) -> begin + if error_code != 0 + @error "$(ws.id): incoming frame complete error" error_code + close_body = CloseFrameBody(1006, "") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return true + end + if ws.drop_incoming + ws.drop_incoming = false + return true + end + op = frame_info.opcode + fin = frame_info.fin + payload = ws.incoming_payload + # PING/PONG: AwsHTTP handles auto-PONG, nothing to do here + if op == UInt8(PING) || op == UInt8(PONG) + return true + end + # CLOSE: AwsHTTP auto-echoes, we just track state + if op == UInt8(CLOSE) + body = payload + if length(body) == 1 + _queue_protocol_error!(ws, "invalid close payload length") + return true + end + close_body = if length(body) >= 2 + code = (Int(body[1]) << 8) | Int(body[2]) + _valid_close_status(code) || (_queue_protocol_error!(ws, "invalid close status code"); return true) + reason_bytes = length(body) > 2 ? body[3:end] : UInt8[] + if !isempty(reason_bytes) && !isvalid(String, reason_bytes) + _queue_invalid_payload!(ws, "invalid close reason") + return true + end + reason = length(body) > 2 ? String(copy(reason_bytes)) : "" + CloseFrameBody(code, reason) + else + CloseFrameBody(1005, "") + end + # AwsHTTP will echo the CLOSE and set close_sent=true, is_open=false + ws.writeclosed = true + _queue_close!(ws, close_body) + return true + end + # Data frames: TEXT, BINARY, CONTINUATION + if op == UInt8(CONTINUATION) + if ws.fragment_opcode === nothing + _queue_protocol_error!(ws, "unexpected continuation") + return true + end + ws.fragment_count += 1 + if ws.fragment_count > ws.maxfragmentation + close_body = CloseFrameBody(1009, "message too large") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return true + end + append!(ws.fragment_payload, payload) + if fin + msg_opcode = ws.fragment_opcode + data = ws.fragment_payload + ws.fragment_opcode = nothing + ws.fragment_payload = UInt8[] + ws.fragment_count = 0 + if msg_opcode == UInt8(TEXT) + if !isvalid(String, data) + _queue_invalid_payload!(ws, "invalid UTF-8") + return true + end + _enqueue_message!(ws, String(copy(data))) + else + _enqueue_message!(ws, copy(data)) + end + end + return true + end + if op == UInt8(TEXT) || op == UInt8(BINARY) + if ws.fragment_opcode !== nothing + _queue_protocol_error!(ws, "unexpected new data frame") + return true + end + if fin + if op == UInt8(TEXT) + if !isvalid(String, payload) + _queue_invalid_payload!(ws, "invalid UTF-8") + return true + end + _enqueue_message!(ws, String(copy(payload))) + else + _enqueue_message!(ws, copy(payload)) + end + ws.fragment_count = 0 + else + ws.fragment_opcode = op + ws.fragment_payload = copy(payload) + ws.fragment_count = 1 + if ws.fragment_count > ws.maxfragmentation + close_body = CloseFrameBody(1009, "message too large") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return true + end + end + end + return true end - notify(ws.not, nothing) - return end -function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload) - n = sizeof(payload) - ws.websocket_send_frame_options = aws_websocket_send_frame_options( - n % UInt64, - Ptr{Cvoid}(pointer_from_objref(ws)), # user_data - stream_outgoing_payload[], - on_complete[], - UInt8(opcode), - fin +# Create an AwsHTTP WebSocket and WsChannelHandler, then install +# the handler into the H1Connection's channel slot. +function _create_ws_handler!(ws::WebSocket, slot::Reseau.Sockets.ChannelSlot, is_client::Bool) + aws_ws = AwsHTTP.ws_new(; + is_client=is_client, + on_incoming_frame_begin=_on_incoming_frame_begin(ws), + on_incoming_frame_payload=_on_incoming_frame_payload(ws), + on_incoming_frame_complete=_on_incoming_frame_complete(ws), ) - opts = pointer(FieldRef(ws, :websocket_send_frame_options)) - if aws_websocket_send_frame(ws.websocket_pointer, opts) != 0 - aws_throw_error() + handler = WsChannelHandler(aws_ws, ws) + ws.aws_ws = aws_ws + ws.handler = handler + Reseau.Sockets.channel_slot_set_handler!(slot, handler) + return +end + +# ─── writeframe ─── + +function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload::AbstractVector{UInt8}) + handler = ws.handler + handler === nothing && throw(WebSocketError(CloseFrameBody(1006, "WebSocket not connected"))) + @lock handler.wslock begin + ret = AwsHTTP.ws_send_frame!(handler.aws_ws, UInt8(opcode), payload; fin=fin) + ret != AwsHTTP.OP_SUCCESS && throw(AWSError("ws_send_frame! failed")) + _ws_channel_flush!(handler) end - # wait until frame sent - wait(ws.not) - return n + return length(payload) end +# ─── Public API ─── + """ send(ws::WebSocket, msg) @@ -288,38 +465,34 @@ or `close(ws[, body::WebSockets.CloseFrameBody])`. Calling `close` will initiate the close sequence and close the underlying connection. """ function send(ws::WebSocket, x) - @assert !ws.writeclosed "WebSocket is closed" - if !isbinary(x) && !istext(x) - # if x is not single binary or text, then assume it's an iterable of binary or text - # and we'll send fragmented message - first = true - n = 0 - state = iterate(x) - if state === nothing - # x was not binary or text, but is an empty iterable, send single empty frame - x = "" - @goto write_single_frame - end - @debug "$(ws.id): Writing fragmented message" - item, st = state - # we prefetch next state so we know if we're on the last item or not - # so we can appropriately set the FIN bit for the last fragmented frame - nextstate = iterate(x, st) - while true - n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, payload(ws, item)) - first = false - nextstate === nothing && break - item, st = nextstate + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + if !isbinary(x) && !istext(x) + # iterable of binary or text → fragmented message + first = true + n = 0 + state = iterate(x) + if state === nothing + x = "" + @goto write_single_frame + end + item, st = state nextstate = iterate(x, st) - end - else - # single binary or text frame for message + while true + n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, _to_bytes(item)) + first = false + nextstate === nothing && break + item, st = nextstate + nextstate = iterate(x, st) + end + return n + else @label write_single_frame - return writeframe(ws, true, opcode(x), payload(ws, x)) + return writeframe(ws, true, opcode(x), _to_bytes(x)) + end end end -# control frames """ ping(ws, data=[]) @@ -328,8 +501,10 @@ body to send with the message. PONG messages are automatically responded to when a PING message is received by a websocket connection. """ function ping(ws::WebSocket, data=UInt8[]) - @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws.io, true, PING, payload(ws, data)) + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + return writeframe(ws, true, PING, _to_bytes(data)) + end end """ @@ -342,8 +517,10 @@ PONG message, but in certain cases, a unidirectional PONG message can be used as a one-way heartbeat. """ function pong(ws::WebSocket, data=UInt8[]) - @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws.io, true, PONG, payload(ws, data)) + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + return writeframe(ws, true, PONG, _to_bytes(data)) + end end """ @@ -363,8 +540,23 @@ returned by `receive`. Note that `WebSocket` objects can be iterated, where each iteration yields a message until the connection is closed. """ function receive(ws::WebSocket) - @assert isopen(ws.readchannel) "WebSocket is closed" - return take!(ws.readchannel) + # If a CLOSE arrives after application data, `_queue_close!` marks `readclosed=true` + # and closes `readchannel`. We still need to deliver any already-queued messages + # before throwing the close error. + if isready(ws.readchannel) + msg = take!(ws.readchannel) + msg isa WebSocketError && throw(msg) + return msg + end + + if ws.readclosed || !isopen(ws.readchannel) + close_body = ws.closebody === nothing ? CloseFrameBody(1006, "") : ws.closebody + throw(WebSocketError(close_body)) + end + + msg = take!(ws.readchannel) + msg isa WebSocketError && throw(msg) + return msg end """ @@ -389,58 +581,288 @@ function Base.iterate(ws::WebSocket, st=nothing) end end -# given a WebSocket request, return the 101 response +""" + WebSockets.isclosed(ws) -> Bool + +Check whether a `WebSocket` has sent and received CLOSE frames +""" +isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed + +function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) + handler = ws.handler + @lock ws.closelock begin + if ws.writeclosed + _close_channel!(ws) + return + end + ws.writeclosed = true + if handler !== nothing + close_body = body === nothing ? CloseFrameBody(1000, "") : body + code = UInt16(close_body.code) + reason = Vector{UInt8}(codeunits(close_body.reason)) + @lock handler.wslock begin + try + AwsHTTP.ws_close!(handler.aws_ws; status_code=code, reason=reason) + _ws_channel_flush!(handler) + catch + # ignore errors while closing + end + end + end + end + if !ws.readclosed + deadline = time() + 5.0 + while time() < deadline + ws.readclosed && break + _task_sleep_s(0.05) + end + ws.readclosed = true + end + if !ws.is_client && handler !== nothing + _shutdown_ws_channel!(handler) + end + _close_channel!(ws) + return +end + +@noinline handshakeerror() = throw(WebSocketError(CloseFrameBody(1002, "Websocket handshake failed"))) + +# ─── Client-side open ─── + +function open(f::Function, url; + suppress_close_error::Bool=false, + headers=[], + maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, + username=nothing, + password=nothing, + bearer=nothing, + query=nothing, + client::Union{Nothing, Client}=nothing, + # redirect options + redirect=true, + redirect_limit=3, + redirect_method=nothing, + forwardheaders=true, + # cookie options + cookies=true, + cookiejar::CookieJar=COOKIEJAR, + modifier=nothing, + verbose=0, + # client keywords + kw... + ) + key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) + expected_accept = AwsHTTP.ws_compute_accept_key(key) + uri = parseuri(url, query) + # add required websocket headers + headers = collect(headers) + append!(headers, [ + "upgrade" => "websocket", + "connection" => "upgrade", + "sec-websocket-key" => key, + "sec-websocket-version" => "13" + ]) + ws = with_redirect("GET", uri, headers, nothing, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body + reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); ssl_alpn_list="http/1.1", kw...)))::Client + path = resource(uri) + with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req + ws_host = header(req, "host", String(host(uri))) + ws = WebSocket(ws_host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation, is_client=true) + ws.handshake_request = req + with_connection(reqclient) do conn + stream = _open_stream(conn, req, false, 0) + resp = startread(stream) + ws.handshake_response = resp + if resp.status != 101 + # Not a WebSocket upgrade; finish the stream and return ws + # so with_redirect can follow 3xx redirects via getresponse(ws) + try; closeread(stream); catch; end + return ws + end + if !isupgrade(resp) + try; closeread(stream); catch; end + AwsHTTP.http_connection_close(conn) + handshakeerror() + end + accept = getheader(resp.headers, "sec-websocket-accept") + if accept === nothing || accept != expected_accept + try; closeread(stream); catch; end + AwsHTTP.http_connection_close(conn) + handshakeerror() + end + # Wait for the H1 stream to complete (fires with ERROR_HTTP_SWITCHED_PROTOCOLS) + # + # Note: servers are allowed to start sending websocket frames immediately after the 101 + # response. Depending on timing, the HTTP layer may have already read some post-upgrade + # bytes before we swap in the websocket handler. After the stream completes, drain any + # such buffered bytes and feed them to the websocket decoder so they aren't lost. + try + wait(stream.fut) + catch + # Expected: ERROR_HTTP_SWITCHED_PROTOCOLS + end + buffered = UInt8[] + if stream.bufferstream !== nothing + # BufferStream is closed in the stream on_complete callback, so read() should not block here. + buffered = try + read(stream.bufferstream) + catch + UInt8[] + end + end + # Swap the H1Connection handler with our WsChannelHandler + h1conn = stream.aws_stream.owning_connection + slot = h1conn.slot + _create_ws_handler!(ws, slot, true) + if !isempty(buffered) + handler = ws.handler + handler !== nothing || throw(WebSocketError(CloseFrameBody(1006, "WebSocket not connected"))) + @lock handler.wslock begin + status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, buffered) + status != AwsHTTP.OP_SUCCESS && throw(AWSError("ws_on_incoming_data! failed")) + _ws_channel_flush!(handler) + end + end + verbose > 0 && @info "$(ws.id): WebSocket opened" + # Run WebSocket session + try + f(ws) + catch e + if !isok(e) + suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + else + close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) + end + end + if !isok(e) + rethrow() + end + finally + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end + AwsHTTP.http_connection_close(conn) + end + end + return ws + end + end + # After redirect loop: verify we actually got a 101 upgrade + resp = ws.handshake_response + if resp === nothing || resp.status != 101 + throw(WebSocketError(CloseFrameBody(1002, "Websocket handshake failed: status $(resp === nothing ? 0 : resp.status)"))) + end + return ws +end + +# ─── Server-side upgrade ─── + +# Given a WebSocket request, return the 101 response function websocket_upgrade_handler(req::Request) + if !isupgrade(req) + return Response(400, ["content-type" => "text/plain"], "websocket upgrade required") + end + if !hasheader(req, "Sec-WebSocket-Version", "13") + return Response(400, ["content-type" => "text/plain"], "unsupported websocket version") + end key = getheader(req.headers, "sec-websocket-key") - resp_ptr = aws_http_message_new_websocket_handshake_response(req.allocator, aws_byte_cursor_from_c_str(key)) - resp_ptr == C_NULL && aws_throw_error() - resp = Response() - resp.allocator = req.allocator - resp.ptr = resp_ptr - resp.request = req + if key === nothing || isempty(key) + return Response(400, ["content-type" => "text/plain"], "missing websocket key") + end + accept_key = AwsHTTP.ws_compute_accept_key(key) + resp = Response(101, [ + "upgrade" => "websocket", + "connection" => "upgrade", + "sec-websocket-accept" => accept_key, + ], nothing) return resp end -function websocket_upgrade_function(f) - #TODO: return WebSocketUpgradeArgs - # then schedule a task to do the actual upgrade +function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, done=nothing) function websocket_upgrade(stream::Stream) - #TODO: get host/path from stream? - ws = WebSocket("", "") - stream.websocket_options = aws_websocket_server_upgrade_options( - 0, - Ptr{Cvoid}(pointer_from_objref(ws)), - on_incoming_frame_begin[], - on_incoming_frame_payload[], - on_incoming_frame_complete[], - false # manual_window_management - ) - ws_ptr = aws_websocket_upgrade(stream.allocator, stream.ptr, FieldRef(stream, :websocket_options)) - ws_ptr == C_NULL && aws_throw_error() - ws.websocket_pointer = ws_ptr + resp = isdefined(stream, :response) ? stream.response : nothing + if resp === nothing || resp.status != 101 + done !== nothing && notify(done, CapturedException(ArgumentError("websocket upgrade not accepted"), Base.backtrace())) + return + end + req = stream.request + ws = WebSocket(header(req, "host", ""), req.path; maxframesize=maxframesize, maxfragmentation=maxfragmentation, is_client=false) + ws.handshake_request = req + ws.handshake_response = resp + # Get the H1Connection's channel slot and swap in our WebSocket handler + h1conn = stream.aws_stream.owning_connection + slot = h1conn.slot + _create_ws_handler!(ws, slot, false) errormonitor(Threads.@spawn begin + err = nothing try f(ws) + catch e + if !isok(e) + err = e + suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + elseif isok(e) + close(ws, CloseFrameBody(1000, "")) + else + close(ws, CloseFrameBody(1011, "Unexpected server websocket error")) + end + end + if err !== nothing + done !== nothing && notify(done, CapturedException(e, catch_backtrace())) + end + if !isok(e) + rethrow() + end finally - aws_websocket_release(ws_ptr) + if err === nothing + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end + done !== nothing && notify(done, nothing) + end + AwsHTTP.http_connection_close(h1conn) end end) return end end -serve!(f, host="127.0.0.1", port=8080; kw...) = - HTTP.serve!(websocket_upgrade_handler, host, port; on_stream_complete=websocket_upgrade_function(f), kw...) - -function __init__() - on_connection_setup[] = @cfunction(c_on_connection_setup, Cvoid, (Ptr{aws_websocket_on_connection_setup_data}, Ptr{Cvoid})) - on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{aws_websocket}, Cint, Ptr{Cvoid})) - on_incoming_frame_begin[] = @cfunction(c_on_incoming_frame_begin, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, Ptr{Cvoid})) - on_incoming_frame_payload[] = @cfunction(c_on_incoming_frame_payload, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, aws_byte_cursor, Ptr{Cvoid})) - on_incoming_frame_complete[] = @cfunction(c_on_incoming_frame_complete, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, Cint, Ptr{Cvoid})) - stream_outgoing_payload[] = @cfunction(c_stream_outgoing_payload, Bool, (Ptr{aws_websocket}, Ptr{aws_byte_buf}, Ptr{Cvoid})) - on_complete[] = @cfunction(c_on_complete, Cvoid, (Ptr{aws_websocket}, Cint, Ptr{Cvoid})) +function _upgrade(f::Function, stream::Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) + isupgrade(stream) || handshakeerror() + hasheader(stream.request, "Sec-WebSocket-Version", "13") || handshakeerror() + key = getheader(stream.request.headers, "sec-websocket-key") + (key === nothing || isempty(key)) && handshakeerror() + stream.response_started && error("response already started") + done = Future{Nothing}() + stream.on_complete = websocket_upgrade_function(f; suppress_close_error=suppress_close_error, maxframesize=maxframesize, maxfragmentation=maxfragmentation, done=done) + stream.response = websocket_upgrade_handler(stream.request) + startwrite(stream) + closewrite(stream) + wait(done) return end -end # module \ No newline at end of file +upgrade(f::Function, stream::Stream; kw...) = _upgrade(f, stream; kw...) +upgrade(stream::Stream, f::Function; kw...) = _upgrade(f, stream; kw...) + +serve!(f, host="127.0.0.1", port=8080; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) = + HTTP.serve!(websocket_upgrade_handler, host, port; on_stream_complete=websocket_upgrade_function(f; suppress_close_error=suppress_close_error, maxframesize=maxframesize, maxfragmentation=maxfragmentation), kw...) + +listen!(f, host="127.0.0.1", port=8080; kw...) = serve!(f, host, port; kw...) + +function listen(f, host="127.0.0.1", port=8080; kw...) + server = listen!(f, host, port; kw...) + wait(server) + return server +end + +end # module diff --git a/test/client.jl b/test/client.jl index d230212e8..dd76840cc 100644 --- a/test/client.jl +++ b/test/client.jl @@ -1,4 +1,22 @@ @testset "Client.jl" begin + @testset "Connection pool compatibility" begin + original_limit = HTTP.default_connection_limit[] + HTTP.set_default_connection_limit!(13) + pool = HTTP.Pool() + @test pool.max_connections == 13 + @test pool.max == 13 + cs = HTTP.ClientSettings("http", "example.com", UInt32(80)) + @test cs.max_connections == 13 + @test_logs (:warn, r"connection_limit no longer supported") begin + cs_warn = HTTP.ClientSettings("http", "example.com", UInt32(80); connection_limit=7) + @test cs_warn.max_connections == 13 + end + HTTP.set_default_connection_limit!(original_limit) + pool_default = HTTP.Pool() + @test pool_default.max_connections == original_limit + @test pool_default.max == original_limit + end + if HAVE_HTTPBIN @testset "GET, HEAD, POST, PUT, DELETE, PATCH: $scheme" for scheme in ["http", "https"] @test isok(HTTP.get("$scheme://$httpbin/ip")) @test isok(HTTP.head("$scheme://$httpbin/ip")) @@ -134,6 +152,9 @@ @test isok(resp) # x = JSONBase.materialize(resp.body) # @test x["form"] == Dict("name" => ["value with spaces"]) + resp = HTTP.post("https://$httpbin/post"; body=["hey", " there ", "sailor"]) + @test isok(resp) + @test occursin("\"data\":\"hey there sailor\"", String(resp.body)) end @testset "ASync Client Request Body" begin @@ -155,7 +176,8 @@ @test HTTP.request(read_method, "https://$httpbin/redirect/6", status_exception=false).status == 302 #over max number of redirects @test isok(HTTP.request(read_method, "https://$httpbin/relative-redirect/1")) @test isok(HTTP.request(read_method, "https://$httpbin/absolute-redirect/1")) - @test isok(HTTP.request(read_method, "https://$httpbin/redirect-to?url=http%3A%2F%2Fgoogle.com")) + redirect_target = URIs.escapeuri("http://$httpbin/get") + @test isok(HTTP.request(read_method, "https://$httpbin/redirect-to?url=$redirect_target")) end @testset "Client Basic Auth" begin @@ -189,41 +211,592 @@ @test r.request.method == "GET" @test length(r.body) > 0 end + else + @info "Skipping HTTPBin-dependent client tests" + end + + @testset "Header insertion" begin + server = HTTP.serve!(req -> begin + accept_count = length(HTTP.headers(req, "accept")) + host_count = length(HTTP.headers(req, "host")) + return HTTP.Response(200, "$accept_count,$host_count") + end; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port"; headers=["Accept" => "*/*", "Host" => "example.com"]) + @test String(resp.body) == "1,1" + resp = HTTP.get("http://127.0.0.1:$port") + @test String(resp.body) == "1,1" + finally + close(server) + end + end + + @testset "Host header includes port for non-default" begin + server = HTTP.serve!(req -> HTTP.Response(200, HTTP.header(req, "host")); listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test String(resp.body) == "127.0.0.1:$port" + finally + close(server) + end + end @testset "readtimeout" begin - @test_throws CapturedException HTTP.get("http://$httpbin/delay/5"; readtimeout=1, max_retries=0) - @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) - end - - @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin - headers = Dict("User-Agent" => "HTTP.jl") - query = Dict("hello" => "world") - body = UInt8[1, 2, 3] - for uri in ("https://$httpbin/anything", HTTP.URI("https://$httpbin/anything")) - # HTTP.request - @test isok(HTTP.request("GET", uri; headers=headers, body=body, query=query)) - @test isok(HTTP.request("GET", uri, headers; body=body, query=query)) - @test isok(HTTP.request("GET", uri, headers, body; query=query)) - # HTTP.get - @test isok(HTTP.get(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.get(uri, headers; body=body, query=query)) - @test isok(HTTP.get(uri, headers, body; query=query)) - # HTTP.put - @test isok(HTTP.put(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.put(uri, headers; body=body, query=query)) - @test isok(HTTP.put(uri, headers, body; query=query)) - # HTTP.post - @test isok(HTTP.post(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.post(uri, headers; body=body, query=query)) - @test isok(HTTP.post(uri, headers, body; query=query)) - # HTTP.patch - @test isok(HTTP.patch(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.patch(uri, headers; body=body, query=query)) - @test isok(HTTP.patch(uri, headers, body; query=query)) - # HTTP.delete - @test isok(HTTP.delete(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.delete(uri, headers; body=body, query=query)) - @test isok(HTTP.delete(uri, headers, body; query=query)) - end - end -end \ No newline at end of file + server = HTTP.serve!("127.0.0.1", 0; listenany=true) do req + if req.target == "/delay/5" + sleep(5) + elseif req.target == "/delay/1" + sleep(1) + end + return HTTP.Response(200, "ok") + end + try + port = HTTP.port(server) + @test_throws HTTP.TimeoutError HTTP.get("http://127.0.0.1:$port/delay/5"; readtimeout=1, max_retries=0) + @test isok(HTTP.get("http://127.0.0.1:$port/delay/1"; readtimeout=5, max_retries=0)) + finally + close(server) + end + end + + @testset "Retry semantics" begin + attempts = Ref(0) + failures = Ref(1) + attempt_lock = ReentrantLock() + next_attempt() = Base.@lock attempt_lock begin + attempts[] += 1 + return attempts[] + end + reset_attempts!(nfail) = Base.@lock attempt_lock begin + attempts[] = 0 + failures[] = nfail + return + end + + server = HTTP.serve!("127.0.0.1", 0; listenany=true) do req + n = next_attempt() + if n <= failures[] + return HTTP.Response(503, "fail") + end + return HTTP.Response(200, "ok") + end + port = HTTP.port(server) + try + reset_attempts!(1) + resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(1) + err = nothing + try + HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, retry_delays=[0.0]) + catch e + err = e + end + @test err isa HTTP.StatusError + @test err.response.metrics.nretries == 0 + @test attempts[] == 1 + + reset_attempts!(1) + resp = HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, + retry_non_idempotent=true, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(1) + resp = HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, + retry_check=(s, ex, req, resp, resp_body) -> true, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(2) + err = nothing + try + HTTP.get("http://127.0.0.1:$port/"; retries=3, retry_delays=[0.0]) + catch e + err = e + end + @test err isa HTTP.StatusError + @test attempts[] == 2 + + reset_attempts!(1) + resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_partition="test") + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + finally + close(server) + end + end + + @testset "Request metrics" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + resp = HTTP.post("http://127.0.0.1:$port/"; body="hello") + @test resp.metrics.request_body_length == 5 + @test resp.metrics.response_body_length == 5 + + resp = HTTP.post("http://127.0.0.1:$port/"; body=IOBuffer("chunked")) + @test resp.metrics.request_body_length == 7 + @test resp.metrics.response_body_length == 7 + finally + close(server) + end + end + + if HAVE_HTTPBIN + @testset "Request Options Parity" begin + headers = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) + @test headers == ["X-Test" => "1"] + + headers2 = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers2, copyheaders=false) + @test any(h -> lowercase(String(h.first)) == "accept", headers2) + @test any(h -> lowercase(String(h.first)) == "x-test", headers2) + + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=false) + @test HTTP.getheader(resp.request.headers, "authorization") === nothing + + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=true) + auth = HTTP.getheader(resp.request.headers, "authorization") + @test auth !== nothing && startswith(auth, "Basic ") + + resp = HTTP.post("https://$httpbin/anything"; body="hello", detect_content_type=true) + @test HTTP.getheader(resp.request.headers, "content-type") == "text/plain; charset=utf-8" + + orig_agent = HTTP.USER_AGENT[] + try + HTTP.setuseragent!(nothing) + resp = HTTP.get("https://$httpbin/headers") + @test HTTP.getheader(resp.request.headers, "user-agent") === nothing + finally + HTTP.setuseragent!(orig_agent) + end + + pool = HTTP.Pool(1) + @test isempty(pool.clients.clients) + HTTP.get("https://$httpbin/ip"; pool=pool) + @test !isempty(pool.clients.clients) + end + else + @info "Skipping HTTPBin-dependent Request Options Parity tests" + end + + @testset "observelayers" begin + server = HTTP.serve!(req -> begin + if req.target == "/redirect" + return HTTP.Response(302, ["Location" => "/ok"], nothing) + end + return HTTP.Response(200, "ok") + end; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port/redirect"; observelayers=true, retries=0) + ctx = resp.request.context + @test ctx[:messagelayer_count] >= 1 + @test ctx[:redirectlayer_count] >= 1 + @test ctx[:retrylayer_count] >= 1 + @test ctx[:connectionlayer_count] >= 1 + @test ctx[:streamlayer_count] >= 1 + @test ctx[:total_request_duration_ms] > 0 + finally + close(server) + end + end + + @testset "IO request body streaming" begin + mutable struct ChunkedTestIO <: IO + chunks::Vector{Vector{UInt8}} + readbytes_calls::Int + readavailable_calls::Int + end + ChunkedTestIO(chunks) = ChunkedTestIO(chunks, 0, 0) + Base.eof(io::ChunkedTestIO) = isempty(io.chunks) + function Base.readbytes!(io::ChunkedTestIO, buf::Vector{UInt8}, n::Integer) + io.readbytes_calls += 1 + isempty(io.chunks) && return 0 + chunk = popfirst!(io.chunks) + ncopy = min(n, length(chunk)) + copyto!(buf, 1, chunk, 1, ncopy) + return ncopy + end + function Base.readavailable(io::ChunkedTestIO) + io.readavailable_calls += 1 + error("readavailable should not be used for chunked IO") + end + + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + io = ChunkedTestIO([Vector{UInt8}("hello"), Vector{UInt8}(" "), Vector{UInt8}("world")]) + resp = HTTP.post("http://127.0.0.1:$port/"; body=io) + @test String(resp.body) == "hello world" + @test io.readbytes_calls == 3 + @test io.readavailable_calls == 0 + finally + close(server) + end + end + + @testset "closed IOStream body errors" begin + path = tempname() + io = open(path, "w") + close(io) + @test_throws ArgumentError HTTP.request("POST", "http://example.com"; body=io, retry=false, status_exception=false) + end + + @testset "Iterable request body streaming" begin + mutable struct ChunkedIterable + chunks::Vector{Vector{UInt8}} + iter_calls::Int + end + ChunkedIterable(chunks) = ChunkedIterable(chunks, 0) + function Base.iterate(it::ChunkedIterable, state::Int=1) + state > length(it.chunks) && return nothing + it.iter_calls += 1 + return (it.chunks[state], state + 1) + end + + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + chunks = ChunkedIterable([Vector{UInt8}("hello"), Vector{UInt8}(" "), Vector{UInt8}("world")]) + resp = HTTP.post("http://127.0.0.1:$port/"; body=chunks) + @test String(resp.body) == "hello world" + @test chunks.iter_calls == 3 + finally + close(server) + end + end + + @testset "stream helpers" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + if http.request.method == "POST" + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, body) + else + HTTP.setstatus(http, 500) + HTTP.setheader(http, "Connection" => "close") + HTTP.startwrite(http) + write(http, "error") + end + end + try + port = HTTP.port(server) + resp = HTTP.open("GET", "http://127.0.0.1:$port"; status_exception=false) do io + r = HTTP.startread(io) + @test r.status == 500 + @test HTTP.isaborted(io) + buf = IOBuffer() + n = HTTP.readall!(io, buf) + @test n > 0 + @test String(take!(buf)) == "error" + end + @test resp.status == 500 + + resp = HTTP.open("POST", "http://127.0.0.1:$port") do io + write(io, "hello") + HTTP.closebody(io) + r = HTTP.startread(io) + @test r.status == 200 + @test String(read(io)) == "hello" + end + @test resp.status == 200 + finally + close(server) + end + end + + if HAVE_HTTPBIN + @testset "HTTP.open streaming" begin + resp = HTTP.open("GET", "https://$httpbin/stream/5") do io + r = HTTP.startread(io) + @test r.status == 200 + data = String(read(io)) + @test length(split(chomp(data), '\n')) == 5 + end + @test resp.status == 200 + + resp = HTTP.open("POST", "https://$httpbin/anything") do io + write(io, "hello") + HTTP.closewrite(io) + r = HTTP.startread(io) + data = String(read(io)) + @test occursin("\"data\":\"hello\"", data) + end + @test resp.status == 200 + end + else + @info "Skipping HTTPBin-dependent HTTP.open streaming tests" + end + + @testset "HTTP/2 stream manager smoke" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true) + client = HTTP.Client(cs) + @test client.http2_stream_manager != C_NULL + finalize(client) + end + + @testset "HTTP/2 stream manager options" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_close_connection_on_server_error=true, + http2_connection_manual_window_management=true, + http2_connection_ping_period_ms=1234, + http2_connection_ping_timeout_ms=2345, + http2_ideal_concurrent_streams_per_connection=7, + http2_max_concurrent_streams_per_connection=9, + http2_initial_window_size=12345, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.close_connection_on_server_error == true + @test opts.conn_manual_window_management == true + @test opts.connection_ping_period_ms == Csize_t(1234) + @test opts.connection_ping_timeout_ms == Csize_t(2345) + @test opts.ideal_concurrent_streams_per_connection == Csize_t(7) + @test opts.max_concurrent_streams_per_connection == Csize_t(9) + @test opts.initial_window_size == Csize_t(12345) + @test client.conn_manager_opts.http2_conn_manual_window_management == true + finalize(client) + end + + @testset "HTTP/2 max closed streams option" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_max_closed_streams=7, + http2_initial_window_size=54321, + ) + client = HTTP.Client(cs) + @test client.conn_manager_opts.max_closed_streams == Csize_t(7) + @test client.conn_manager_opts.initial_window_size == Csize_t(54321) + finalize(client) + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_max_closed_streams=9, + http2_initial_window_size=65432, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.max_closed_streams == Csize_t(9) + @test opts.initial_window_size == Csize_t(65432) + finalize(client) + end + + @testset "HTTP/2 initial settings options" begin + settings = [ + HTTP.AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS => 10, + HTTP.AWS_HTTP2_SETTINGS_INITIAL_WINDOW_SIZE => 65535, + ] + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_initial_settings=settings) + client = HTTP.Client(cs) + @test client.http2_initial_settings !== nothing + @test length(client.http2_initial_settings) == 2 + @test client.conn_manager_opts.num_initial_settings == Csize_t(2) + @test client.conn_manager_opts.initial_settings_array != C_NULL + finalize(client) + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_initial_settings=settings, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.num_initial_settings == Csize_t(2) + @test opts.initial_settings_array != C_NULL + finalize(client) + + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_settings=1, + )) + end + + @testset "HTTP/2 window size validation" begin + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_window_size=-1, + )) + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_window_size=HTTP.HTTP2_MAX_WINDOW_SIZE + 1, + )) + end + + @testset "HTTP manager metrics" begin + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443))) + metrics = HTTP.manager_metrics(client) + @test metrics.available_concurrency >= 0 + @test metrics.pending_concurrency_acquires >= 0 + @test metrics.leased_concurrency >= 0 + finalize(client) + + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true)) + metrics = HTTP.manager_metrics(client) + @test metrics.available_concurrency >= 0 + @test metrics.pending_concurrency_acquires >= 0 + @test metrics.leased_concurrency >= 0 + finalize(client) + end + + @testset "HTTP connection monitoring stats" begin + list = HTTP.aws_crt_statistics_http1_channel[] + stat1 = HTTP.aws_crt_statistics_http1_channel( + HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, + UInt64(10), + UInt64(20), + UInt32(1), + UInt32(2), + ) + push!(list, stat1) + decoded = HTTP._decode_statistics(list) + @test length(decoded) == 1 + @test decoded[1].category == :http1_channel + @test decoded[1].pending_outgoing_stream_ms == 10 + @test decoded[1].pending_incoming_stream_ms == 20 + @test decoded[1].current_outgoing_stream_id == 1 + @test decoded[1].current_incoming_stream_id == 2 + + list = HTTP.aws_crt_statistics_http2_channel[] + stat2 = HTTP.aws_crt_statistics_http2_channel( + HTTP.AWSCRT_STAT_CAT_HTTP2_CHANNEL, + UInt64(5), + UInt64(6), + true, + ) + push!(list, stat2) + decoded = HTTP._decode_statistics(list) + @test length(decoded) == 1 + @test decoded[1].category == :http2_channel + @test decoded[1].pending_outgoing_stream_ms == 5 + @test decoded[1].pending_incoming_stream_ms == 6 + @test decoded[1].was_inactive == true + + called = Ref(false) + cb = (nonce, stats) -> (called[] = true) + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); monitoring_statistics_observer=cb)) + list = HTTP.aws_crt_statistics_http1_channel[] + stat3 = HTTP.aws_crt_statistics_http1_channel( + HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, + UInt64(1), + UInt64(1), + UInt32(1), + UInt32(1), + ) + push!(list, stat3) + HTTP._call_statistics_observer(client.monitoring_observer, Csize_t(0), list) + @test called[] + finalize(client) + end + + @testset "Proxy basic auth strategy" begin + opts = HTTP.proxy_kwargs("http://user:pass@proxy.local:3128", "http") + @test opts.proxy_auth == :basic + @test opts.proxy_username == "user" + @test opts.proxy_password == "pass" + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + proxy_host="proxy.local", + proxy_port=UInt32(3128), + proxy_connection_type=:forward, + proxy_auth=:basic, + proxy_username="user", + proxy_password="pass", + ) + client = HTTP.Client(cs) + @test client.proxy_options !== nothing + @test client.proxy_strategy != C_NULL + @test client.proxy_options.proxy_strategy == client.proxy_strategy + finalize(client) + + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + proxy_host="proxy.local", + proxy_port=UInt32(3128), + proxy_auth=:basic, + proxy_username="user", + )) + end + + if HAVE_HTTPBIN + @testset "HTTP/2 control APIs" begin + resp = HTTP.get("https://$httpbin/ip") + if resp.version == HTTP.HTTPVersion(2, 0) + HTTP.open("GET", "https://$httpbin/ip") do io + r = HTTP.startread(io) + @test r.status == 200 + rtt = HTTP.http2_ping(io) + @test rtt isa UInt64 + HTTP.http2_change_settings(io, Pair{Int, Int}[]) + @test length(HTTP.http2_local_settings(io)) == HTTP.AWS_HTTP2_SETTINGS_COUNT + @test HTTP.http2_get_sent_goaway(io) === nothing + @test HTTP.http2_get_received_goaway(io) === nothing + @test_nowarn HTTP.http2_update_window(io, 1024) + @test_nowarn HTTP.update_window(io, 1024) + end + else + @info "HTTP/2 not available for $httpbin" + end + end + + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin + headers = Dict("User-Agent" => "HTTP.jl") + query = Dict("hello" => "world") + body = UInt8[1, 2, 3] + for uri in ("https://$httpbin/anything", HTTP.URI("https://$httpbin/anything")) + # HTTP.request + @test isok(HTTP.request("GET", uri; headers=headers, body=body, query=query)) + @test isok(HTTP.request("GET", uri, headers; body=body, query=query)) + @test isok(HTTP.request("GET", uri, headers, body; query=query)) + # HTTP.get + @test isok(HTTP.get(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.get(uri, headers; body=body, query=query)) + @test isok(HTTP.get(uri, headers, body; query=query)) + # HTTP.put + @test isok(HTTP.put(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.put(uri, headers; body=body, query=query)) + @test isok(HTTP.put(uri, headers, body; query=query)) + # HTTP.post + @test isok(HTTP.post(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.post(uri, headers; body=body, query=query)) + @test isok(HTTP.post(uri, headers, body; query=query)) + # HTTP.patch + @test isok(HTTP.patch(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.patch(uri, headers; body=body, query=query)) + @test isok(HTTP.patch(uri, headers, body; query=query)) + # HTTP.delete + @test isok(HTTP.delete(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.delete(uri, headers; body=body, query=query)) + @test isok(HTTP.delete(uri, headers, body; query=query)) + end + end + else + @info "Skipping HTTPBin-dependent HTTP/2 control and request entry point tests" + end +end diff --git a/test/fixtures/http2.crt b/test/fixtures/http2.crt new file mode 100644 index 000000000..0e60882b5 --- /dev/null +++ b/test/fixtures/http2.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUS6toYHFUbN6xEG51OuiEhDeHiZwwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDEyNTA3MjgxMFoXDTI3MDEy +NTA3MjgxMFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAiZAY+NfcrqOSKAusYaFd9DSvYF3sRUTcTtyx35cawH1Y +196hTh6vyRIS7pHnX2kKwzc8G1C8aQXs4tBHuF0o+D06z2QU8fyAfJHrb1AhGpFm +y76HzeuT7S9Mmr58Kw4Yks1NuCt4LyBeC0AFEo6FKknlo+GiQvqZLKPVCcabca5r +mh28zvAhWCe00CX5HeXw+BJFDeD618QAiFF4Mr6imoL8TLPyLYz1t+mav6cLvxQK +vrr0FNSMVdmQ+VYF8tBOftCpblmyiswjeCTzKr95RPh6lQdfdn8Oq9Y29SCwvbKr +qlwElrU/RTCukjITk8Gzo/kxJtuRJDduyckN0s0P/QIDAQABo1MwUTAdBgNVHQ4E +FgQUGV6pqqYSC7wPdjKU30QOZl6/AP8wHwYDVR0jBBgwFoAUGV6pqqYSC7wPdjKU +30QOZl6/AP8wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAV/i6 +DjhMOWupCv+X6Re0uR2lasXHRfimFuzanWtNj+A226PHvk+XsYg+CNq+vTPPtMj1 +yQuz8bH+zHcPaXEfXlUzZHZUqMXb0yCxJLPP64nsvYkeoJtorgYzoa20eY9PRdZx +xo/Bvi5IVRaVh3UxaqlUMRmbNc8/OS7Q+CQXXJtfuKnCvwE68b2/M9szPobl1zgY +L3NflRuPHBdtX8/OhDNHygexujyLtI/e/3DzhokkchXjEiPiBpgzrybp6L52XwMb +MBv3houfL6+/Bd+pmnT2NPy5Qqe7ifGZAlz/kNN/6fNGp/FRTa83rt3+rYXzca4A +uOyQ8so7PeuSkuwEEw== +-----END CERTIFICATE----- diff --git a/test/fixtures/http2.key b/test/fixtures/http2.key new file mode 100644 index 000000000..59cc30315 --- /dev/null +++ b/test/fixtures/http2.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCJkBj419yuo5Io +C6xhoV30NK9gXexFRNxO3LHflxrAfVjX3qFOHq/JEhLukedfaQrDNzwbULxpBezi +0Ee4XSj4PTrPZBTx/IB8ketvUCEakWbLvofN65PtL0yavnwrDhiSzU24K3gvIF4L +QAUSjoUqSeWj4aJC+pkso9UJxptxrmuaHbzO8CFYJ7TQJfkd5fD4EkUN4PrXxACI +UXgyvqKagvxMs/ItjPW36Zq/pwu/FAq+uvQU1IxV2ZD5VgXy0E5+0KluWbKKzCN4 +JPMqv3lE+HqVB192fw6r1jb1ILC9squqXASWtT9FMK6SMhOTwbOj+TEm25EkN27J +yQ3SzQ/9AgMBAAECggEABJ9fO3NzgrnT5j0YLpKv184ya3xUfXJmLc88Oe13tGqx +4tUkGf4tfYX6NWKZQgtDVZKEVk2k+yl8D5Ycpt0YjJjwIjp0ero3rhfwL6YjaqNi +ryuIoDqPlruNaTDH9uFrIXm9KBhr2jeN+XZOBVTdNDHWAecJ3xLRNV8PAFxYal47 +uD2UhiFp4GMbQDcjLix4FCAUQTw5Y11a6wmvJu/PmB1u9VLf1HOcLjBtS9YYaIh+ +RuxmqQG/02lVYf6TQKzuT+njJ2Tl4DDH1lvdSyYo2nlW8ybpvux23RyOefSl6PZT +MDutLFYvnic8xIxnwKtchWNWp6Em2ryKHpgwBFVQQQKBgQDCWPt3t+AoUYG6vFtT +BZvWz85E5kNSGT/cC9h4e6xCt2C9qzn8nQ8dsCAhWmWZTMvY/yG0feQ9CLt0rNGF +fsHaGKaau1TpWPOB/pM9m2XMgJo9ZDxWF3qoN5vDQ1zZlcm1dYFoHqcgrmrh5nXM +JXr9pJw0g8/fNaeV4Q9pO92FwQKBgQC1M51XorzbYK6QHKOhuoVMrwh32a0a+TEG +ElRNNQ5rK+3H21tPlk+QKAhOYiihIf9Ut/4yfmR8njV4eM/MtbaiGfFSd2ENiU1d +xnEOSD42NgEFBCvLDcu5/wE4YQwweieohOd3lrsntCucKd616aghhiLjeT/GSabo +2RzYSXpxPQKBgQCLzokvySXGu0OQurkTk0BVGm5vIBojsChBOoBBw+3anKJKLyfq +sm1SVQX4GFhoHFe0RWzQs5OB2ItJVpzu5I29P+hx/PsLVkLuK91t/yEPKSBLs5S3 +9fH1mvNBV28u01MkZ2BtL0fY+b/HvArXjcrZNhZsrLnX/3gMGLgGYttrwQKBgASg +3uIALCbGX28a7CsTYphE2EiHbN6FgvUOvsyCEG44Xwh91+U+h6W9AAlQhI0pGyaE +1J9hjxuHxwHexCAMfC/DzeA3YGlCGpHModKlkcE8u+Xu51d2cL+9fcB86hzK4fxx ++J+bYAhxl7OTdjbbUwoYLQf2buSXuQW1lgEIT3JZAoGAWtsY3mftmxg8hchjsK8V +SZdJcU6Lx+gPUBHi3YPPXrVJJjO2UIBI/w09xSz4MCslHl+HXfiA0uCW3tF8DVGy +OXB/6EVo/wQmNoHZYO5IVqf16Q/cBCiPQ9HR/UJj/D5QfyqseUJ28V/n08oz/Wka +VQXUh0zfszlDyPndAMuXwd8= +-----END PRIVATE KEY----- diff --git a/test/headers.jl b/test/headers.jl new file mode 100644 index 000000000..9263de22c --- /dev/null +++ b/test/headers.jl @@ -0,0 +1,36 @@ +@testset "Headers helpers" begin + h = HTTP.Headers() + HTTP.addheader(h, "x-test-header", "abc") + HTTP.addheader(h, "content-type", "text/plain") + + @test HTTP.header(h, "X-Test-Header") == "abc" + @test HTTP.header(h, "missing", "default") == "default" + @test HTTP.hasheader(h, "x-test-header") + @test HTTP.headercontains(h, "x-test-header", "abc") + + HTTP.canonicalizeheaders!(h) + @test any(x -> first(x) == "X-Test-Header", h) + @test any(x -> first(x) == "Content-Type", h) +end + +@testset "Headers vector compatibility" begin + h = HTTP.Headers() + push!(h, "a" => "1") + push!(h, "b" => "2") + @test h[1] == ("a" => "1") + h[2] = "b" => "3" + @test h[2] == ("b" => "3") + insert!(h, 2, "c" => "4") + @test h[2] == ("c" => "4") + + req = HTTP.Request("GET", "/") + req.headers = ["x" => "1", "y" => "2"] + @test length(req.headers) == 2 + req.headers = ["z" => "3"] + @test length(req.headers) == 1 + @test HTTP.header(req, "z") == "3" + + hdrs = req.headers + push!(hdrs, "w" => "4") + @test HTTP.header(req, "w") == "4" +end diff --git a/test/multipart.jl b/test/multipart.jl index 1defa2c09..e1c704462 100644 --- a/test/multipart.jl +++ b/test/multipart.jl @@ -11,11 +11,15 @@ end headers = Dict("User-Agent" => "HTTP.jl") body = HTTP.Form(Dict()) mark(body) - @testset "Setting of Content-Type" begin - test_multipart(HTTP.request("POST", "https://$httpbin/post", headers, body), body) - test_multipart(HTTP.post("https://$httpbin/post", headers, body), body) - test_multipart(HTTP.request("PUT", "https://$httpbin/put", headers, body), body) - test_multipart(HTTP.put("https://$httpbin/put", headers, body), body) + if HAVE_HTTPBIN + @testset "Setting of Content-Type" begin + test_multipart(HTTP.request("POST", "https://$httpbin/post", headers, body), body) + test_multipart(HTTP.post("https://$httpbin/post", headers, body), body) + test_multipart(HTTP.request("PUT", "https://$httpbin/put", headers, body), body) + test_multipart(HTTP.put("https://$httpbin/put", headers, body), body) + end + else + @info "Skipping HTTPBin-dependent multipart tests" end @testset "HTTP.Multipart ensure show() works correctly" begin # testing that there is no error in printing when nothing is set for filename diff --git a/test/runtests.jl b/test/runtests.jl index 6e0adf105..0ca425e5d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,22 @@ -using Test, HTTP, URIs, JSON +using Test, HTTP, URIs, JSON, Reseau const httpbin = get(ENV, "JULIA_TEST_HTTPBINGO_SERVER", "httpbingo.julialang.org") isok(r) = r.status == 200 +const HAVE_HTTPBIN = let + try + resp = HTTP.get("http://$httpbin/ip"; readtimeout=2, max_retries=0) + isok(resp) + catch + false + end +end include("utils.jl") +include("headers.jl") +include("httpversion.jl") include("sniff.jl") include("multipart.jl") include("client.jl") include("handlers.jl") include("server.jl") +include("websockets_basic.jl") diff --git a/test/server.jl b/test/server.jl index 535c314bd..abcf71818 100644 --- a/test/server.jl +++ b/test/server.jl @@ -1,10 +1,265 @@ -using Test, HTTP, Logging +using Test, HTTP, Logging, Base64, Reseau +import Sockets + +# Find a subsequence in a byte vector (naive; good enough for tests). +function _find_subseq(hay::Vector{UInt8}, needle::Vector{UInt8}, from::Int = 1)::Int + n = length(needle) + m = length(hay) + (n == 0 || from > m) && return 0 + last_i = m - n + 1 + last_i < from && return 0 + @inbounds for i in from:last_i + if hay[i] == needle[1] + ok = true + for j in 2:n + if hay[i + j - 1] != needle[j] + ok = false + break + end + end + ok && return i + end + end + return 0 +end + +const _HTTP1_CHUNKED_TERMINATOR = Vector{UInt8}(codeunits("\r\n0\r\n")) +const _HTTP1_HEADERS_TERMINATOR = Vector{UInt8}(codeunits("\r\n\r\n")) + +function _http1_chunked_with_trailers_done(buf::Vector{UInt8})::Bool + idx0 = _find_subseq(buf, _HTTP1_CHUNKED_TERMINATOR) + idx0 == 0 && return false + idx_end = _find_subseq(buf, _HTTP1_HEADERS_TERMINATOR, idx0 + length(_HTTP1_CHUNKED_TERMINATOR)) + return idx_end != 0 +end + +# Minimal raw TCP client for tests. Avoids the `Sockets` stdlib. +@static if Sys.iswindows() + const _WS2_32 = "ws2_32" + const _AF_INET = Cint(2) + const _SOCK_STREAM = Cint(1) + const _IPPROTO_TCP = Cint(6) + const _INVALID_SOCKET = typemax(UInt) + const _SOCKET_ERROR = Cint(-1) + const _POLLIN = Int16(0x0001) + const _RAW_READ_TIMEOUT_MS = Cint(5_000) + + struct _sockaddr_in + sin_family::UInt16 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + + struct _WSAPOLLFD + fd::UInt + events::Int16 + revents::Int16 + end + + function _raw_tcp_connect_readall( + host::AbstractString, + port::Integer, + request::AbstractString, + ; + stop_pred = nothing, + )::Vector{UInt8} + wsadata = Vector{UInt8}(undef, 512) + ret = ccall((:WSAStartup, _WS2_32), Cint, (UInt16, Ptr{UInt8}), UInt16(0x0202), wsadata) + ret == 0 || error("WSAStartup() failed: $ret") + sock = ccall((:socket, _WS2_32), UInt, (Cint, Cint, Cint), _AF_INET, _SOCK_STREAM, _IPPROTO_TCP) + if sock == _INVALID_SOCKET + err = ccall((:WSAGetLastError, _WS2_32), Cint, ()) + _ = ccall((:WSACleanup, _WS2_32), Cint, ()) + error("socket() failed: $err") + end + + try + addr = ccall((:inet_addr, _WS2_32), UInt32, (Cstring,), host) + addr == 0xffffffff && error("inet_addr() failed for host=$host") + port_be = ccall((:htons, _WS2_32), UInt16, (UInt16,), UInt16(port)) + sin = _sockaddr_in(UInt16(_AF_INET), port_be, addr, ntuple(_ -> UInt8(0), 8)) + + cres = ccall((:connect, _WS2_32), Cint, (UInt, Ref{_sockaddr_in}, Cint), sock, sin, Cint(sizeof(_sockaddr_in))) + cres == 0 || error("connect() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + + bytes = Vector{UInt8}(codeunits(request)) + sent = GC.@preserve bytes ccall( + (:send, _WS2_32), + Cint, + (UInt, Ptr{UInt8}, Cint, Cint), + sock, + pointer(bytes), + Cint(length(bytes)), + Cint(0), + ) + sent == length(bytes) || error("send() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + # Some server stacks treat EOF as end-of-request; half-close our send side. + _ = ccall((:shutdown, _WS2_32), Cint, (UInt, Cint), sock, Cint(1)) # SD_SEND = 1 + + buf = UInt8[] + tmp = Vector{UInt8}(undef, 4096) + while true + pollfd = Ref(_WSAPOLLFD(sock, _POLLIN, Int16(0))) + pres = ccall( + (:WSAPoll, _WS2_32), + Cint, + (Ptr{_WSAPOLLFD}, UInt32, Cint), + pollfd, + UInt32(1), + _RAW_READ_TIMEOUT_MS, + ) + pres == 0 && error("WSAPoll() timeout") + pres == _SOCKET_ERROR && error("WSAPoll() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + + r = GC.@preserve tmp ccall( + (:recv, _WS2_32), + Cint, + (UInt, Ptr{UInt8}, Cint, Cint), + sock, + pointer(tmp), + Cint(length(tmp)), + Cint(0), + ) + r == _SOCKET_ERROR && error("recv() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + r == 0 && break + append!(buf, view(tmp, 1:r)) + stop_pred !== nothing && stop_pred(buf) && return buf + end + return buf + finally + _ = ccall((:closesocket, _WS2_32), Cint, (UInt,), sock) + _ = ccall((:WSACleanup, _WS2_32), Cint, ()) + end + end +else + const _AF_INET = Cint(2) + const _SOCK_STREAM = Cint(1) + const _POLLIN = Int16(0x0001) + const _RAW_READ_TIMEOUT_MS = Cint(5_000) + + @static if Sys.isapple() || Sys.isbsd() + struct _sockaddr_in + sin_len::UInt8 + sin_family::UInt8 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + else + struct _sockaddr_in + sin_family::UInt16 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + end + + struct _pollfd + fd::Cint + events::Int16 + revents::Int16 + end + + const _poll_nfds_t = @static (Sys.isapple() || Sys.isbsd()) ? Cuint : Culong + + function _raw_tcp_connect_readall( + host::AbstractString, + port::Integer, + request::AbstractString, + ; + stop_pred = nothing, + )::Vector{UInt8} + fd = ccall(:socket, Cint, (Cint, Cint, Cint), _AF_INET, _SOCK_STREAM, 0) + fd < 0 && error("socket() failed: errno=$(Libc.errno())") + try + addr = Ref{UInt32}(0) + ok = ccall(:inet_pton, Cint, (Cint, Cstring, Ptr{Cvoid}), _AF_INET, host, addr) + ok == 1 || error("inet_pton() failed for host=$host") + + port16 = UInt16(port) + port_be = ccall(:htons, UInt16, (UInt16,), port16) + + sin = @static if Sys.isapple() || Sys.isbsd() + _sockaddr_in( + UInt8(sizeof(_sockaddr_in)), + UInt8(_AF_INET), + port_be, + addr[], + ntuple(_ -> UInt8(0), 8), + ) + else + _sockaddr_in( + UInt16(_AF_INET), + port_be, + addr[], + ntuple(_ -> UInt8(0), 8), + ) + end + + ret = ccall(:connect, Cint, (Cint, Ref{_sockaddr_in}, Cuint), fd, sin, Cuint(sizeof(_sockaddr_in))) + ret == 0 || error("connect() failed: errno=$(Libc.errno())") + + bytes = Vector{UInt8}(codeunits(request)) + n = GC.@preserve bytes ccall( + :write, + Cssize_t, + (Cint, Ptr{UInt8}, Csize_t), + fd, + pointer(bytes), + Csize_t(length(bytes)), + ) + n == length(bytes) || error("write() failed: errno=$(Libc.errno())") + # Some server stacks treat EOF as end-of-request; half-close our send side. + _ = ccall(:shutdown, Cint, (Cint, Cint), fd, Cint(1)) # SHUT_WR = 1 + + buf = UInt8[] + tmp = Vector{UInt8}(undef, 4096) + while true + pollfd = Ref(_pollfd(fd, _POLLIN, Int16(0))) + pres = ccall( + :poll, + Cint, + (Ptr{_pollfd}, _poll_nfds_t, Cint), + pollfd, + _poll_nfds_t(1), + _RAW_READ_TIMEOUT_MS, + ) + if pres == 0 + error("poll() timeout") + end + if pres < 0 + err = Libc.errno() + err == Libc.EINTR && continue + error("poll() failed: errno=$err") + end + + r = GC.@preserve tmp ccall( + :read, + Cssize_t, + (Cint, Ptr{UInt8}, Csize_t), + fd, + pointer(tmp), + Csize_t(length(tmp)), + ) + r < 0 && error("read() failed: errno=$(Libc.errno())") + r == 0 && break + append!(buf, view(tmp, 1:Int(r))) + stop_pred !== nothing && stop_pred(buf) && return buf + end + return buf + finally + _ = ccall(:close, Cint, (Cint,), fd) + end + end +end @testset "HTTP.serve" begin server = HTTP.serve!(req -> HTTP.Response(200, "Hello, World!"); listenany=true) try @test server.state == :running - resp = HTTP.get("http://127.0.0.1:8080") + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") @test resp.status == 200 @test String(resp.body) == "Hello, World!" finally @@ -12,6 +267,273 @@ using Test, HTTP, Logging end end +@testset "server shutdown hooks" begin + closed = Threads.Atomic{Int}(0) + server = HTTP.serve!(req -> HTTP.Response(200, "ok"); listenany=true, on_shutdown=() -> (closed[] += 1)) + try + port = HTTP.port(server) + HTTP.get("http://127.0.0.1:$port") + finally + close(server) + end + @test closed[] == 1 + + forced = Threads.Atomic{Int}(0) + server2 = HTTP.serve!(req -> HTTP.Response(200, "ok"); listenany=true, + on_shutdown=[() -> (forced[] += 1), () -> (forced[] += 1)]) + try + port2 = HTTP.port(server2) + HTTP.get("http://127.0.0.1:$port2") + finally + HTTP.forceclose(server2) + end + @test forced[] == 2 +end + +@testset "access logging stream handler" begin + logger = Test.TestLogger() + with_logger(logger) do + server = HTTP.listen!("127.0.0.1", 0; listenany=true, access_log=common_logfmt) do http + read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, "hello") + end + port = HTTP.port(server) + try + HTTP.post("http://127.0.0.1:$port"; body="x") + sleep(1) + finally + close(server) + end + end + logs = filter!(x -> x.group == :access, logger.logs) + @test length(logs) == 1 + @test occursin(r" 200 5$", logs[1].message) +end + +@testset "HTTP.listen stream handler" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, isempty(body) ? "ping" : body) + end + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test resp.status == 200 + @test String(resp.body) == "ping" + + resp = HTTP.post("http://127.0.0.1:$port"; body="echo") + @test resp.status == 200 + @test String(resp.body) == "echo" + finally + close(server) + end +end + +@testset "HTTP.streamhandler" begin + handler = req -> begin + body = req.body === nothing ? UInt8[] : req.body + if isempty(body) + return HTTP.Response(200, ["Content-Type" => "text/plain"], "ping") + end + return HTTP.Response(200, ["Content-Type" => "text/plain"], String(body)) + end + server = HTTP.listen!(HTTP.streamhandler(handler), "127.0.0.1", 0; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test resp.status == 200 + @test String(resp.body) == "ping" + + resp = HTTP.post("http://127.0.0.1:$port"; body="echo") + @test resp.status == 200 + @test String(resp.body) == "echo" + finally + close(server) + end +end + +@testset "HTTP response trailers" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, "hello") + HTTP.addtrailer(http, "X-Trailer" => "ok") + HTTP.closewrite(http) + end + try + port = HTTP.port(server) + sock = Sockets.connect("127.0.0.1", port) + write(sock, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n") + flush(sock) + raw = String(read(sock)) + close(sock) + lower_raw = lowercase(raw) + @test occursin("transfer-encoding: chunked", lower_raw) + @test occursin("hello", raw) + @test occursin("\r\n0\r\nx-trailer: ok\r\n\r\n", lower_raw) + finally + close(server) + end +end + +@testset "HTTP/2 TLS support" begin + if !Reseau.Sockets.tls_is_alpn_available() + @info "Skipping HTTP/2 TLS tests; ALPN not available" + @test true + else + @testset "HTTP/2 stream handler writes" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + saw_http2 = Threads.Atomic{Bool}(false) + buffered = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + stream.http2 && (saw_http2[] = true) + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, "hello") + if stream.http2 && stream.responsebuf !== nothing + buffered[] = true + end + HTTP.closewrite(stream) + end + try + port = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test saw_http2[] + @test !buffered[] + @test String(resp.body) == "hello" + else + @info "HTTP/2 not negotiated for stream handler test" + @test true + end + finally + close(server) + end + end + @testset "HTTP/2 host and authority mapping" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + host_ref = Ref{String}("") + authority_ref = Ref{String}("") + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + if stream.http2 + host_ref[] = HTTP.header(stream.request, "host") + authority_ref[] = HTTP.header(stream.request, ":authority") + end + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, "ok") + HTTP.closewrite(stream) + end + try + port = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2", headers=["Host" => "example.com"]) + if resp.version == HTTP.HTTPVersion(2, 0) + @test host_ref[] == "example.com" + @test authority_ref[] == "example.com" + else + @info "HTTP/2 not negotiated for host/authority mapping test" + @test true + end + finally + close(server) + end + end + @testset "HTTP/2 server push promise" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + port_ref = Ref{Int}(0) + push_called = Threads.Atomic{Bool}(false) + push_http2 = Threads.Atomic{Bool}(false) + push_server_side = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + if stream.http2 + authority = "127.0.0.1:$(port_ref[])" + push = HTTP.push_promise(stream, "GET", "/pushed"; scheme="https", authority=authority) + push_called[] = true + push_http2[] = push.http2 + push_server_side[] = push.server_side + HTTP.setstatus(push, 200) + HTTP.setheader(push, "Content-Type" => "text/plain") + write(push, "pushed") + HTTP.closewrite(push) + end + HTTP.setstatus(stream, 200) + write(stream, "ok") + end + try + port_ref[] = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port_ref[])"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test push_called[] + @test push_http2[] + @test push_server_side[] + @test String(resp.body) == "ok" + else + @info "HTTP/2 not negotiated for push promise test" + @test true + end + finally + close(server) + end + end + @testset "HTTP/2 readtimeout keeps connection open" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + seen_lock = ReentrantLock() + seen_conns = Set{UInt}() + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + @lock seen_lock push!(seen_conns, objectid(stream.connection)) + if stream.request.path == "/slow" + sleep(2) + else + sleep(1.5) + end + try + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, stream.request.path == "/slow" ? "slow" : "fast") + HTTP.closewrite(stream) + catch + nothing + end + end + try + port = HTTP.port(server) + cs = HTTP.ClientSettings("https", "127.0.0.1", UInt32(port); ssl_insecure=true, ssl_alpn_list="h2", max_connections=1) + client = HTTP.Client(cs) + slow_err = try + HTTP.get("https://127.0.0.1:$(port)/slow"; client=client, readtimeout=1, retry=false) + catch e + e + end + fast_resp = HTTP.get("https://127.0.0.1:$(port)/fast"; client=client, retry=false) + if fast_resp.version == HTTP.HTTPVersion(2, 0) + @test slow_err isa HTTP.TimeoutError + @test String(fast_resp.body) == "fast" + @test length(seen_conns) == 1 + else + @info "HTTP/2 not negotiated for readtimeout connection test" + @test true + end + finally + close(server) + end + end + end +end + @testset "access logging" begin local handler = (req) -> begin if req.target == "/internal-error" @@ -82,11 +604,14 @@ end HTTP.get("http://127.0.0.1:$port/index.html") HTTP.get("http://127.0.0.1:$port/index.html?a=b") HTTP.head("http://127.0.0.1:$port") + auth = Base64.base64encode("alice:secret") + HTTP.get("http://127.0.0.1:$port/auth", ["Authorization" => "Basic $auth"]) end - @test length(logs) == 4 + @test length(logs) == 5 @test all(x -> x.group === :access, logs) @test occursin(r"^application/json text/plain GET / HTTP/1\.1 GET / 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[1].message) @test occursin(r"^\*/\* text/plain GET /index\.html HTTP/1\.1 GET /index\.html 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[2].message) @test occursin(r"^\*/\* text/plain GET /index\.html\?a=b HTTP/1\.1 GET /index\.html\?a=b 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[3].message) @test occursin(r"^\*/\* text/plain HEAD / HTTP/1\.1 HEAD / 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 0$", logs[4].message) -end \ No newline at end of file + @test occursin(r"^\*/\* text/plain GET /auth HTTP/1\.1 GET /auth 127\.0\.0\.1 \d+ alice HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[5].message) +end diff --git a/test/utils.jl b/test/utils.jl index e02a73a33..d4422a9cc 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -29,4 +29,96 @@ ) @test HTTP.iso8859_1_to_utf8(bytes) == utf8 end -end # testset \ No newline at end of file + + buf = UInt8[0x01, 0x02] + @test HTTP.bytes(buf) === buf + @test collect(HTTP.bytes("hi")) == collect(codeunits("hi")) + @test HTTP.nbytes("hi") == 2 + @test HTTP.nbytes(buf) == 2 + @test HTTP.nbytes([buf, UInt8[0x03]]) == 3 + @test HTTP.nbytes(["a", "bc"]) == 3 + @test HTTP.nbytes(IOBuffer("abc")) == 3 + @test HTTP.nobytes isa AbstractVector{UInt8} + @test isempty(HTTP.nobytes) + @test HTTP.ascii_lc_isequal("AbC", "aBc") + @test !HTTP.ascii_lc_isequal("abc", "abd") + + err = HTTP.ConnectError("http://example.com", ErrorException("boom")) + @test err isa HTTP.ConnectError + @test err.url == "http://example.com" + @test err.error isa ErrorException + @test HTTP.TimeoutError(5).readtimeout == 5 + req = HTTP.Request("GET", "/") + req_err = HTTP.RequestError(req, ErrorException("nope")) + @test req_err.request === req + @test req_err.error isa ErrorException + msg = "" + try + error("boom") + catch + msg = HTTP.current_exceptions_to_string() + end + @test occursin("boom", msg) + got = false + HTTP.@try ArgumentError begin + got = true + throw(ArgumentError("boom")) + end + @test got + + # parseuri now delegates to URIs.jl (no C-level allocator needed) + @test HTTP.parseuri("http://example.com/path", nothing) isa URIs.URI + + exported = names(HTTP, all=false) + @test :startwrite in exported + @test :startread in exported + @test :closewrite in exported + @test :closeread in exported + @test :Stream in exported + @test :Request in exported + @test :Response in exported + @test :Message in exported + @test :Header in exported + @test :Headers in exported + @test :header in exported + @test :headers in exported + @test :hasheader in exported + @test :setheader in exported + @test :appendheader in exported + @test :removeheader in exported + @test :bytes in exported + @test :nbytes in exported + @test :nobytes in exported + @test :escapehtml in exported + @test :tocameldash in exported + @test :iso8859_1_to_utf8 in exported + @test :ascii_lc_isequal in exported + + req = HTTP.Request("GET", "/") + HTTP.setheader(req, "X-Test" => "a") + @test HTTP.header(req, "X-Test") == "a" + HTTP.appendheader(req, "X-Test" => "b") + @test HTTP.headers(req, "X-Test") == ["a", "b"] + HTTP.removeheader(req, "X-Test") + @test HTTP.header(req, "X-Test") == "" + @test HTTP.nobody isa Vector{UInt8} + @test isempty(HTTP.nobody) + @test isdefined(HTTP, :streamhandler) + + @test_deprecated HTTP.escape("a b") == "a%20b" + + @testset "download" begin + server = HTTP.serve!(req -> HTTP.Response(200, ["Content-Disposition" => "attachment; filename=\"hello.txt\""], "hello"); listenany=true) + try + port = HTTP.port(server) + mktempdir() do dir + file = HTTP.download("http://127.0.0.1:$port/hello.txt", dir) + @test isfile(file) + @test basename(file) == "hello.txt" + @test String(read(file)) == "hello" + end + finally + close(server) + end + end +end # testset diff --git a/test/websockets/autobahn.jl b/test/websockets/autobahn.jl index 5b373c7ca..7a62cba32 100644 --- a/test/websockets/autobahn.jl +++ b/test/websockets/autobahn.jl @@ -1,4 +1,4 @@ -using Test, Sockets, HTTP, HTTP.WebSockets, JSON +using Test, HTTP, HTTP.WebSockets, JSON const DIR = abspath(joinpath(dirname(pathof(HTTP)), "../test/websockets")) @@ -96,4 +96,4 @@ end end # @testset "WebSockets" -end # 64-bit only \ No newline at end of file +end # 64-bit only diff --git a/test/websockets/deno_client/server.jl b/test/websockets/deno_client/server.jl index 00ef3c98d..d5ee5ec35 100644 --- a/test/websockets/deno_client/server.jl +++ b/test/websockets/deno_client/server.jl @@ -1,4 +1,4 @@ -using Test, Sockets, Deno_jll, HTTP +using Test, Deno_jll, HTTP # Not all architectures have a Deno_jll hasproperty(Deno_jll, :deno) && @testset "WebSocket server" begin @@ -25,4 +25,4 @@ hasproperty(Deno_jll, :deno) && @testset "WebSocket server" begin finally close(server) end -end \ No newline at end of file +end diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl new file mode 100644 index 000000000..3dcb0a1ac --- /dev/null +++ b/test/websockets_basic.jl @@ -0,0 +1,199 @@ +using Test +using HTTP +using HTTP.WebSockets + +@testset "WebSockets ping/pong" begin + server = HTTP.WebSockets.serve!("127.0.0.1", 0; listenany=true) do ws + msg = receive(ws) + send(ws, msg) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test_nowarn WebSockets.ping(ws) + @test_nowarn WebSockets.pong(ws) + send(ws, "ok") + @test receive(ws) == "ok" + end + finally + close(server) + end +end + +@testset "WebSockets fragmentation and close" begin + server = HTTP.WebSockets.serve!("127.0.0.1", 0; listenany=true) do ws + msg = receive(ws) + send(ws, msg) + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + send(ws, ["hel", "lo"]) + @test receive(ws) == "hello" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + @test e.message.code == 1000 + end + end + finally + close(server) + end +end + +@testset "WebSockets listen!" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, "hi") + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test receive(ws) == "hi" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + end + end + finally + close(server) + end +end + +@testset "WebSockets upgrade via HTTP.listen!" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do stream + if WebSockets.isupgrade(stream) + WebSockets.upgrade(stream) do ws + send(ws, "pong") + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + else + HTTP.setstatus(stream, 404) + HTTP.startwrite(stream) + write(stream, "nope") + end + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test receive(ws) == "pong" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + @test e.message.code == 1000 + end + end + finally + close(server) + end +end + +@testset "WebSockets max frame size" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, "0123456789") + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; maxframesize=5, suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1009 +end + +@testset "WebSockets max fragmentation" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, ["a", "b", "c"]) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; maxfragmentation=2, suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1009 +end + +@testset "WebSockets handshake accept validation" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do stream + HTTP.startread(stream) + HTTP.setstatus(stream, 101) + HTTP.setheader(stream, "Upgrade" => "websocket") + HTTP.setheader(stream, "Connection" => "Upgrade") + HTTP.setheader(stream, "Sec-WebSocket-Accept" => "invalid") + HTTP.startwrite(stream) + HTTP.closewrite(stream) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1002 +end + +@testset "WebSockets invalid close status" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + WebSockets.close(ws, WebSockets.CloseFrameBody(1005, "bad")) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1002 +end + +@testset "WebSockets invalid UTF-8 text" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + WebSockets.writeframe(ws, true, WebSockets.TEXT, UInt8[0xff]) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1007 +end