From 066fd5360c9b356a115925cddc0df7cc49c8b7d6 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 21:18:37 -0700 Subject: [PATCH 01/56] feat(core): restore client parity baseline Add HTTPVersion and header helper APIs. Implement request options: copyheaders, basicauth, canonicalize_headers, detect_content_type, proxy/pool, logerrors/logtag. Fix request/body handling, stream error propagation, websocket control frames, and header duplication. Add tests for headers, httpversion, client options, and websocket ping/pong. --- src/HTTP.jl | 1 + src/client/client.jl | 11 ++ src/client/makerequest.jl | 67 ++++++++++- src/client/request.jl | 58 +++++++--- src/client/retry.jl | 30 ++++- src/client/stream.jl | 3 +- src/requestresponse.jl | 230 ++++++++++++++++++++++++++++++-------- src/utils.jl | 70 +++++++++++- src/websockets.jl | 61 +++++----- test/client.jl | 40 ++++++- test/headers.jl | 14 +++ test/runtests.jl | 3 + test/utils.jl | 4 +- test/websockets_basic.jl | 21 ++++ 14 files changed, 517 insertions(+), 96 deletions(-) create mode 100644 test/headers.jl create mode 100644 test/websockets_basic.jl diff --git a/src/HTTP.jl b/src/HTTP.jl index 8f8c50a0..57244f57 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -4,6 +4,7 @@ using CodecZlib, URIs, Mmap, Base64, Dates, Sockets using LibAwsCommon, LibAwsIO, LibAwsHTTPFork import LibAwsCommon: Future, FieldRef +export HTTPVersion export @logfmt_str, common_logfmt, combined_logfmt export WebSockets diff --git a/src/client/client.jl b/src/client/client.jl index 51d3ef16..257002fd 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -226,6 +226,13 @@ end Clients() = Clients(ReentrantLock(), Dict{ClientSettings, Client}()) +struct Pool + clients::Clients + max_connections::Union{Nothing, Int} +end + +Pool(max_connections::Union{Int, Nothing}=nothing) = Pool(Clients(), max_connections) + const CLIENTS = Clients() function getclient(key::ClientSettings, clients::Clients=CLIENTS) @@ -240,6 +247,8 @@ function getclient(key::ClientSettings, clients::Clients=CLIENTS) end end +getclient(key::ClientSettings, pool::Pool) = getclient(key, pool.clients) + function close_all_clients!(clients::Clients=CLIENTS) Base.@lock clients.lock begin for client in values(clients.clients) @@ -248,3 +257,5 @@ function close_all_clients!(clients::Clients=CLIENTS) empty!(clients.clients) end end + +close_all_clients!(pool::Pool) = close_all_clients!(pool.clients) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 29836356..e1ce21ef 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -7,21 +7,50 @@ head(a...; kw...) = request("HEAD", a...; kw...) options(a...; kw...) = request("OPTIONS", a...; kw...) const COOKIEJAR = CookieJar() +const DEFAULT_PROXY = Symbol("__HTTP_DEFAULT_PROXY__") _something(x, y) = x === nothing ? y : x +# proxy keyword handling +function proxy_kwargs(proxy, req_scheme) + if proxy === DEFAULT_PROXY + return NamedTuple() + elseif proxy === nothing || proxy === false + return (proxy_allow_env_var=false,) + elseif proxy isa AbstractString || proxy isa URI + p = proxy isa URI ? proxy : URI(String(proxy)) + isempty(p.host) && throw(ArgumentError("proxy URL must include a host")) + port = isempty(p.port) ? (p.scheme == "https" ? 443 : 80) : parse(Int, p.port) + conn_type = req_scheme in ("https", "wss") ? :tunnel : :forward + return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type) + else + throw(ArgumentError("proxy must be a URL String, URI, nothing, or false")) + end +end + # main entrypoint for making an HTTP request # can provide method, url, headers, body, along with various keyword arguments -function request(method, url, h=Header[], b::RequestBodyTypes=nothing; +function request(method, url, h=Header[], b=nothing; allocator=default_aws_allocator(), headers=h, - body::RequestBodyTypes=b, + body=b, chunkedbody=nothing, + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, username=nothing, password=nothing, bearer=nothing, query=nothing, client::Union{Nothing, Client}=nothing, + basicauth::Bool=true, + proxy=DEFAULT_PROXY, + pool=nothing, + logerrors::Bool=false, + logtag=nothing, + observelayers::Bool=false, + retry_check=nothing, + retry_delays=nothing, # redirect options redirect=true, redirect_limit=3, @@ -41,14 +70,39 @@ function request(method, url, h=Header[], b::RequestBodyTypes=nothing; verbose=0, # only client keywords in catch-all kw...) + if chunkedbody === nothing && body !== nothing && !(body isa RequestBodyTypes) && Base.isiterable(typeof(body)) + chunkedbody = body + body = nothing + end + headers = mkreqheaders(headers, copyheaders) uri = parseuri(url, query, allocator) + proxy_kw = proxy_kwargs(proxy, scheme(uri)) + client_kw = (; allocator=allocator, kw...) + if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) + client_kw = merge(client_kw, (; max_connections=pool.max_connections)) + end + if !isempty(proxy_kw) + client_kw = merge(client_kw, proxy_kw) + end + authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) + apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth return with_redirect(allocator, method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body - reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); allocator=allocator, kw...)))::Client - with_retry_token(reqclient) do + reqclient = @something( + client, + pool === nothing ? + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) + )::Client + with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, retry_check=retry_check, retry_delays=retry_delays) do resp = with_connection(reqclient) do conn http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 path = resource(uri) - with_request(reqclient, method, path, headers, body, chunkedbody, decompress, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, http2, cookies, cookiejar, verbose) do req + with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + ) do req if response_body isa AbstractVector{UInt8} ref = Ref(1) GC.@preserve ref begin @@ -65,6 +119,9 @@ function request(method, url, h=Header[], b::RequestBodyTypes=nothing; end # status error check if status_exception && iserror(resp) + if logerrors + @error "HTTP StatusError" method=method url=makeuri(uri) status=resp.status logtag=logtag + end throw(StatusError(method, uri, resp)) end return resp diff --git a/src/client/request.jl b/src/client/request.jl index 1fc573f9..1a97f224 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -1,4 +1,7 @@ -const USER_AGENT = Ref{Union{String, Nothing}}("HTTP.jl/$VERSION") +const DEFAULT_USER_AGENT = let v = try Base.pkgversion(@__MODULE__) catch; nothing end + v === nothing ? "HTTP.jl/dev" : "HTTP.jl/$(v)" +end +const USER_AGENT = Ref{Union{String, Nothing}}(DEFAULT_USER_AGENT) """ setuseragent!(x::Union{String, Nothing}) @@ -12,9 +15,31 @@ function setuseragent!(x::Union{String, Nothing}) return end -function with_request(f::Function, client::Client, method, path, headers=nothing, body=nothing, chunkedbody=nothing, decompress::Union{Nothing, Bool}=nothing, userinfo=nothing, bearer=nothing, modifier=nothing, http2::Bool=false, cookies=true, cookiejar=COOKIEJAR, verbose=false) +function with_request( + f::Function, + client::Client, + method, + path, + headers=nothing, + body=nothing, + chunkedbody=nothing, + decompress::Union{Nothing, Bool}=nothing, + userinfo=nothing, + bearer=nothing, + modifier=nothing, + http2::Bool=false, + cookies=true, + cookiejar=COOKIEJAR, + verbose=false; + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, + basicauth::Bool=true, +) # create request - req = Request(method, path, headers, nothing, http2, client.settings.allocator) + mutable_headers = (headers isa AbstractVector{<:Pair} && !copyheaders) ? headers : nothing + req_headers = mkreqheaders(headers, copyheaders) + req = Request(method, path, req_headers, nothing, http2, client.settings.allocator) # add headers to request h = req.headers if http2 @@ -24,11 +49,13 @@ function with_request(f::Function, client::Client, method, path, headers=nothing setheader(h, "host", client.settings.host) end setheaderifabsent(h, "accept", "*/*") - setheaderifabsent(h, "user-agent", something(USER_AGENT[], "-")) + if USER_AGENT[] !== nothing + setheaderifabsent(h, "user-agent", USER_AGENT[]) + end if decompress === nothing || decompress setheaderifabsent(h, "accept-encoding", "gzip") end - if userinfo !== nothing + if basicauth && userinfo !== nothing && !isempty(userinfo) setheaderifabsent(h, "authorization", "Basic $(base64encode(unescapeuri(userinfo)))") end if bearer !== nothing @@ -37,10 +64,8 @@ function with_request(f::Function, client::Client, method, path, headers=nothing if !http2 && chunkedbody !== nothing setheaderifabsent(h, "transfer-encoding", "chunked") end - if headers !== nothing - for (k, v) in headers - addheader(h, k, v) - end + if detect_content_type && !hasheader(h, "content-type") && !(body isa Form) && isbytes(body) + setheader(h, "content-type", sniff(body)) end if cookies === true || (cookies isa AbstractDict && !isempty(cookies)) cookiestosend = Cookies.getcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path) @@ -62,16 +87,21 @@ function with_request(f::Function, client::Client, method, path, headers=nothing setinputstream!(req, body) end elseif body !== nothing - try - setinputstream!(req, body) - catch e - @error "Failed to set input stream" exception=(e, catch_backtrace()) - end + setinputstream!(req, body) + end + if canonicalize_headers && !http2 + canonicalizeheaders!(h) + end + if mutable_headers !== nothing + sync_headers!(mutable_headers, h) end # call user function verbose > 0 && print_request(stdout, req) ret = f(req) resp = getresponse(ret) + if canonicalize_headers + canonicalizeheaders!(resp.headers) + end verbose > 0 && print_response(stdout, resp) cookies === false || Cookies.setcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path, resp.headers) return ret diff --git a/src/client/retry.jl b/src/client/retry.jl index 94b89e7b..93292f2a 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -35,9 +35,28 @@ function c_retry_ready(token, error_code::Cint, fut_ptr) return end -function with_retry_token(f::Function, client::Client) +function with_retry_token( + f::Function, + client::Client; + logerrors::Bool=false, + logtag=nothing, + method=nothing, + uri=nothing, + retry_check=nothing, + retry_delays=nothing, +) # If max_retries is 0, we don't need to bother with any retrying - client.settings.max_retries == 0 && return f() + if client.settings.max_retries == 0 + try + return f() + catch e + if logerrors + url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) + @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=url logtag=logtag + end + rethrow() + end + end retry_partition = client.settings.retry_partition === nothing ? C_NULL : aws_byte_cursor_from_c_str(client.settings.retry_partition) fut = Future{Ptr{aws_retry_token}}() GC.@preserve fut begin @@ -58,6 +77,11 @@ function with_retry_token(f::Function, client::Client) stream = e.stream e = e.error end + if logerrors + log_err = e isa DontRetry ? e.error : e + url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) + @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=url logtag=logtag + end if e isa DontRetry if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing # for error responses, we need to commit the temporary body buffer @@ -86,4 +110,4 @@ function with_retry_token(f::Function, client::Client) finally aws_retry_token_release(token) end -end \ No newline at end of file +end diff --git a/src/client/stream.jl b/src/client/stream.jl index 90dfe5e9..24219953 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -199,6 +199,7 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, while !eof(stream.bufferstream) on_stream_response_body(resp, _readavailable(stream.bufferstream)) end + wait(stream.fut) catch e rethrow(DontRetry(e)) end @@ -226,4 +227,4 @@ function _readavailable(this::Base.BufferStream) take!(buf) end return bytes -end \ No newline at end of file +end diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 8b97edef..3a29be96 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -30,6 +30,8 @@ mutable struct Headers <: AbstractVector{Header} Headers(ptr::Ptr{aws_http_headers}) = new(ptr) end +abstract type Message end + Base.size(h::Headers) = (Int(aws_http_headers_count(h.ptr)),) function Base.getindex(h::Headers, i::Int) @@ -78,8 +80,136 @@ Base.empty!(h::Headers) = aws_http_headers_clear(h.ptr) != 0 && aws_throw_error( setheaderifabsent(headers, k, v) = !hasheader(headers, k) && setheader(headers, k, v) +field_name_isequal(a, b) = headereq(String(a), String(b)) + +Base.getindex(m::Message, k) = header(m, k) + +""" + HTTP.header(::Message, key [, default=""]) -> String + +Get header value for `key` (case-insensitive). +""" +header(m::Message, k, d="") = header(m.headers, k, d) +header(h::Headers, k, d="") = (v = getheader(h, String(k)); v === nothing ? d : v) +header(h::AbstractVector{<:Pair}, k, d="") = begin + for (name, value) in h + if field_name_isequal(name, k) + return String(value) + end + end + return d +end + +""" + HTTP.headers(m::Message, key) -> Vector{String} + +Get all headers with key `k` or empty if none. +""" +headers(h::Headers, k) = [h2.value for h2 in h if field_name_isequal(h2.name, k)] +headers(h::AbstractVector{<:Pair}, k) = [String(v) for (name, v) in h if field_name_isequal(name, k)] +headers(m::Message, k) = headers(m.headers, k) + +""" + HTTP.hasheader(::Message, key) -> Bool + +Does header value for `key` exist (case-insensitive)? +""" +hasheader(m::Message, k) = header(m, k) != "" +hasheader(m::Message, k, v) = field_name_isequal(header(m, k), v) + +""" + HTTP.headercontains(::Message, key, value) -> Bool + +Does the header for `key` (interpreted as comma-separated list) contain `value` (case-insensitive)? +""" +headercontains(m::Message, k, v) = any(field_name_isequal.(strip.(split(header(m, k), ",")), v)) +headercontains(h::Headers, k, v) = any(field_name_isequal.(strip.(split(header(h, k), ",")), v)) +headercontains(h::AbstractVector{<:Pair}, k, v) = any(field_name_isequal.(strip.(split(header(h, k), ",")), v)) + +""" + HTTP.setheader(::Message, key => value) + +Set header `value` for `key` (case-insensitive). +""" +setheader(m::Message, v) = setheader(m.headers, v) +setheader(h::Headers, v::Header) = setheader(h, v.name, v.value) +setheader(h::Headers, v::Pair) = setheader(h, String(v.first), String(v.second)) +function setheader(h::AbstractVector{<:Pair}, v::Pair) + key = String(v.first) + value = String(v.second) + for i in eachindex(h) + if field_name_isequal(h[i].first, key) + h[i] = key => value + return h + end + end + push!(h, key => value) + return h +end + +""" + defaultheader!(::Message, key => value) + +Set header `value` in message for `key` if it is not already set. +""" +function defaultheader!(m, v::Pair) + if header(m, first(v), nothing) === nothing + setheader(m, v) + end + return +end + +function canonicalizeheaders!(h::Headers) + items = [(h2.name, h2.value) for h2 in h] + for i in length(h):-1:1 + deleteat!(h, i) + end + for (k, v) in items + addheader(h, tocameldash(k), v) + end + return h +end + +canonicalizeheaders(h::AbstractVector{<:Pair}) = + [tocameldash(String(k)) => String(v) for (k, v) in h] + +mkheaders(::Nothing) = Pair{String, String}[] +mkheaders(h::Headers) = [h2.name => h2.value for h2 in h] +mkheaders(h::AbstractVector{Header}) = begin + headers = Pair{String, String}[] + for head in h + push!(headers, String(head.name) => String(head.value)) + end + return headers +end +mkheaders(h::AbstractVector{<:Pair}) = [String(k) => String(v) for (k, v) in h] +function mkheaders(h) + headers = Pair{String, String}[] + for (k, v) in h + push!(headers, String(k) => String(v)) + end + return headers +end + +function mkreqheaders(h, copyheaders::Bool) + if h === nothing + return Pair{String, String}[] + elseif h isa AbstractVector{<:Pair} && !copyheaders + return h + else + return mkheaders(h) + end +end + +function sync_headers!(dest::AbstractVector{<:Pair}, src::Headers) + empty!(dest) + for h in src + push!(dest, String(h.name) => String(h.value)) + end + return dest +end + # request/response -abstract type Message end mutable struct InputStream ptr::Ptr{aws_input_stream} @@ -91,53 +221,52 @@ end ischunked(is::InputStream) = is.ptr == C_NULL && is.bodyref !== nothing -const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Nothing} +const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Form, Nothing} -function InputStream(allocator::Ptr{aws_allocator}, body::RequestBodyTypes) +function InputStream(allocator::Ptr{aws_allocator}, body) is = InputStream() if body !== nothing - if body isa RequestBodyTypes - if (body isa AbstractVector{UInt8}) || (body isa AbstractString) - is.bodyref = body - is.bodycursor = aws_byte_cursor(sizeof(body), pointer(body)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa Union{AbstractDict, NamedTuple} - # hold a reference to the request body in order to gc-preserve it - is.bodyref = URIs.escapeuri(body) - is.bodycursor = aws_byte_cursor_from_c_str(is.bodyref) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa IOStream - is.bodyref = body - is.ptr = aws_input_stream_new_from_open_file(allocator, Libc.FILE(body)) - elseif body isa Form - # we set the request.body to the Form bytes in order to gc-preserve them - is.bodyref = read(body) - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - elseif body isa IO - # we set the request.body to the IO bytes in order to gc-preserve them - bytes = readavailable(body) - while !eof(body) - append!(bytes, readavailable(body)) - end - is.bodyref = bytes - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - else - throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, or IO")) + if (body isa AbstractVector{UInt8}) || (body isa AbstractString) + is.bodyref = body + is.bodycursor = aws_byte_cursor(sizeof(body), pointer(body)) + is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + elseif body isa Union{AbstractDict, NamedTuple} + # hold a reference to the request body in order to gc-preserve it + is.bodyref = URIs.escapeuri(body) + is.bodycursor = aws_byte_cursor_from_c_str(is.bodyref) + is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + elseif body isa IOStream + is.bodyref = body + is.ptr = aws_input_stream_new_from_open_file(allocator, Libc.FILE(body)) + elseif body isa Form + # we set the request.body to the Form bytes in order to gc-preserve them + is.bodyref = read(body) + is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) + is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + elseif body isa IO + # we set the request.body to the IO bytes in order to gc-preserve them + bytes = readavailable(body) + while !eof(body) + append!(bytes, readavailable(body)) end + is.bodyref = bytes + is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) + is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + elseif Base.isiterable(typeof(body)) + # assume a chunked request body; any kind of iterable where elements are RequestBodyTypes + is.bodyref = body + else + throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, IO, or an iterable of those")) + end + if is.ptr != C_NULL aws_input_stream_get_length(is.ptr, FieldRef(is, :bodylen)) != 0 && aws_throw_error() if !(is.bodylen > 0) aws_input_stream_release(is.ptr) is.ptr = C_NULL end - else - # assume a chunked request body; any kind of iterable where elements are RequestBodyTypes - @assert Base.isiterable(typeof(body)) "chunked request body must be an iterable" - is.bodyref = body end end - return finalizer(x -> aws_input_stream_release(x.ptr), is) + return finalizer(x -> x.ptr != C_NULL && aws_input_stream_release(x.ptr), is) end function setinputstream!(msg::Message, body) @@ -216,7 +345,7 @@ function Base.getproperty(x::Request, s::Symbol) elseif s == :headers return Headers(aws_http_message_get_headers(ptr(x))) elseif s == :version - return aws_http_message_get_protocol_version(ptr(x)) == AWS_HTTP_VERSION_2 ? "2" : "1.1" + return aws_http_message_get_protocol_version(ptr(x)) == AWS_HTTP_VERSION_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -255,7 +384,13 @@ end function print_request(io, method, version, path, headers, body) write(io, "\"\"\"\n") - write(io, string(method, " ", path, " HTTP/$version\r\n")) + write(io, string(method, " ", path, " ")) + if version isa HTTPVersion + write(io, version) + else + write(io, "HTTP/", string(version)) + end + write(io, "\r\n") for h in headers print_header(io, h) end @@ -298,7 +433,7 @@ mutable struct Response <: Message inputstream::Union{Nothing, InputStream} body::Union{Nothing, Vector{UInt8}} # only set for client-side response body when no user-provided response_body metrics::RequestMetrics - request::Request + request::Union{Request, Nothing} function Response(status::Integer, headers, body, http2::Bool=false, allocator=default_aws_allocator()) ptr = http2 ? @@ -316,6 +451,8 @@ mutable struct Response <: Message resp = new(allocator, ptr) resp.body = nothing resp.inputstream = nothing + resp.metrics = RequestMetrics() + resp.request = nothing body !== nothing && setinputstream!(resp, body) return finalizer(_ -> aws_http_message_release(ptr), resp) catch @@ -323,7 +460,7 @@ mutable struct Response <: Message rethrow() end end - Response() = new(C_NULL, C_NULL, nothing, nothing) + Response() = new(C_NULL, C_NULL, nothing, nothing, RequestMetrics(), nothing) end Response(status::Integer, body) = Response(status, nothing, Vector{UInt8}(string(body))) @@ -341,7 +478,7 @@ function Base.getproperty(x::Response, s::Symbol) elseif s == :headers return Headers(aws_http_message_get_headers(x.ptr)) elseif s == :version - return aws_http_message_get_protocol_version(x.ptr) == AWS_HTTP_VERSION_2 ? "2" : "1.1" + return aws_http_message_get_protocol_version(x.ptr) == AWS_HTTP_VERSION_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -359,7 +496,12 @@ end function print_response(io, status, version, headers, body) write(io, "\"\"\"\n") - write(io, string("HTTP/$version ", status, "\r\n")) + if version isa HTTPVersion + write(io, version) + else + write(io, "HTTP/", string(version)) + end + write(io, " ", string(status), "\r\n") for h in headers print_header(io, h) end @@ -413,4 +555,4 @@ Does this `Response` have a redirect status? isredirect(r::Response) = isredirect(r.status) isredirect(status::Integer) = status in (301, 302, 303, 307, 308) -Forms.parse_multipart_form(m::Message) = parse_multipart_form(getheader(m.headers, "content-type"), m.body) \ No newline at end of file +Forms.parse_multipart_form(m::Message) = parse_multipart_form(getheader(m.headers, "content-type"), m.body) diff --git a/src/utils.jl b/src/utils.jl index 125ad21d..22e94a23 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,71 @@ +""" + HTTPVersion(major, minor) + +The HTTP version number consists of two digits separated by a ".". The first +digit (`major`) indicates the HTTP messaging syntax, whereas the second digit +(`minor`) indicates the highest minor version within that major version to which +the sender is conformant and able to understand for future communication. +""" +struct HTTPVersion + major::UInt8 + minor::UInt8 +end + +HTTPVersion(major::Integer) = HTTPVersion(major, 0x00) +HTTPVersion(v::AbstractString) = parse(HTTPVersion, v) +HTTPVersion(v::VersionNumber) = convert(HTTPVersion, v) +Base.convert(::Type{HTTPVersion}, v::VersionNumber) = HTTPVersion(v.major, v.minor) +Base.VersionNumber(v::HTTPVersion) = VersionNumber(v.major, v.minor) + +Base.show(io::IO, v::HTTPVersion) = print(io, "HTTPVersion(\"", string(v.major), ".", string(v.minor), "\")") +Base.write(io::IO, v::HTTPVersion) = write(io, "HTTP/", string(v.major), ".", string(v.minor)) + +Base.:(==)(va::VersionNumber, vb::HTTPVersion) = va == VersionNumber(vb) +Base.:(==)(va::HTTPVersion, vb::VersionNumber) = VersionNumber(va) == vb +Base.isless(va::VersionNumber, vb::HTTPVersion) = isless(va, VersionNumber(vb)) +Base.isless(va::HTTPVersion, vb::VersionNumber) = isless(VersionNumber(va), vb) +function Base.isless(va::HTTPVersion, vb::HTTPVersion) + va.major < vb.major && return true + va.major > vb.major && return false + va.minor < vb.minor && return true + return false +end + +function Base.parse(::Type{HTTPVersion}, v::AbstractString) + ver = tryparse(HTTPVersion, v) + ver === nothing && throw(ArgumentError("invalid HTTP version string: $(repr(v))")) + return ver +end + +# We only support single-digits for major and minor versions. +function Base.tryparse(::Type{HTTPVersion}, v::AbstractString) + isempty(v) && return nothing + len = ncodeunits(v) + + i = firstindex(v) + d1 = v[i] + if isdigit(d1) + major = parse(UInt8, d1) + else + return nothing + end + + i = nextind(v, i) + i > len && return HTTPVersion(major) + dot = v[i] + dot == '.' || return nothing + + i = nextind(v, i) + i > len && return HTTPVersion(major) + d2 = v[i] + if isdigit(d2) + minor = parse(UInt8, d2) + else + return nothing + end + return HTTPVersion(major, minor) +end + """ escapehtml(i::String) @@ -69,7 +137,7 @@ function parseuri(url, query, allocator) end GC.@preserve url_str begin url_ref = Ref(aws_byte_cursor(sizeof(url_str), pointer(url_str))) - aws_uri_init_parse(uri_ref, allocator, url_ref) + aws_uri_init_parse(uri_ref, allocator, url_ref) != 0 && aws_throw_error() end return uri_ref[] end diff --git a/src/websockets.jl b/src/websockets.jl index 82e58ea2..ce1323a9 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -12,7 +12,7 @@ mutable struct WebSocket id::String host::String path::String - not::Future{Nothing} + connect_fut::Future{Nothing} readchannel::Channel{Union{String, Vector{UInt8}}} writebuffer::Vector{UInt8} writepos::Int @@ -29,6 +29,11 @@ end getresponse(ws::WebSocket) = ws.handshake_response +mutable struct SendState + ws::WebSocket + fut::Future{Nothing} +end + const on_connection_setup = Ref{Ptr{Cvoid}}(C_NULL) function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_connection_setup_data}, ws_ptr) @@ -36,7 +41,7 @@ function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_conne data = unsafe_load(connection_setup_data) try if data.error_code != 0 - notify(ws.not, CapturedException(aws_error(data.error_code), Base.backtrace())) + notify(ws.connect_fut, CapturedException(aws_error(data.error_code), Base.backtrace())) else ws.websocket_pointer = data.websocket ws.handshake_response.status = unsafe_load(data.handshake_response_status) @@ -48,10 +53,10 @@ function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_conne response_body = nothing end setinputstream!(ws.handshake_response, response_body) - notify(ws.not, nothing) + notify(ws.connect_fut, nothing) end catch e - notify(ws.not, CapturedException(e, Base.backtrace())) + notify(ws.connect_fut, CapturedException(e, Base.backtrace())) end return end @@ -167,7 +172,7 @@ function open(f::Function, url; aws_throw_error() end # wait until connected - wait(ws.not) + wait(ws.connect_fut) return ws end end @@ -229,7 +234,8 @@ end const stream_outgoing_payload = Ref{Ptr{Cvoid}}(C_NULL) function c_stream_outgoing_payload(websocket::Ptr{aws_websocket}, out_buf::Ptr{aws_byte_buf}, ws_ptr::Ptr{Cvoid}) - ws = unsafe_pointer_to_objref(ws_ptr) + state = unsafe_pointer_to_objref(ws_ptr) + ws = state.ws out = unsafe_load(out_buf) try space_available = out.capacity - out.len @@ -247,31 +253,34 @@ end const on_complete = Ref{Ptr{Cvoid}}(C_NULL) function c_on_complete(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr::Ptr{Cvoid}) - ws = unsafe_pointer_to_objref(ws_ptr) + state = unsafe_pointer_to_objref(ws_ptr) if error_code != 0 - notify(ws.not, CapturedException(aws_error(error_code), Base.backtrace())) + notify(state.fut, CapturedException(aws_error(error_code), Base.backtrace())) end - notify(ws.not, nothing) + notify(state.fut, nothing) return end function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload) n = sizeof(payload) - ws.websocket_send_frame_options = aws_websocket_send_frame_options( - n % UInt64, - Ptr{Cvoid}(pointer_from_objref(ws)), # user_data - stream_outgoing_payload[], - on_complete[], - UInt8(opcode), - fin - ) - opts = pointer(FieldRef(ws, :websocket_send_frame_options)) - if aws_websocket_send_frame(ws.websocket_pointer, opts) != 0 - aws_throw_error() + state = SendState(ws, Future{Nothing}()) + GC.@preserve state begin + ws.websocket_send_frame_options = aws_websocket_send_frame_options( + n % UInt64, + Ptr{Cvoid}(pointer_from_objref(state)), # user_data + stream_outgoing_payload[], + on_complete[], + UInt8(opcode), + fin + ) + opts = pointer(FieldRef(ws, :websocket_send_frame_options)) + if aws_websocket_send_frame(ws.websocket_pointer, opts) != 0 + aws_throw_error() + end + # wait until frame sent + wait(state.fut) + return n end - # wait until frame sent - wait(ws.not) - return n end """ @@ -329,7 +338,7 @@ to when a PING message is received by a websocket connection. """ function ping(ws::WebSocket, data=UInt8[]) @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws.io, true, PING, payload(ws, data)) + return writeframe(ws, true, PING, payload(ws, data)) end """ @@ -343,7 +352,7 @@ used as a one-way heartbeat. """ function pong(ws::WebSocket, data=UInt8[]) @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws.io, true, PONG, payload(ws, data)) + return writeframe(ws, true, PONG, payload(ws, data)) end """ @@ -443,4 +452,4 @@ function __init__() return end -end # module \ No newline at end of file +end # module diff --git a/test/client.jl b/test/client.jl index d230212e..9f42fed3 100644 --- a/test/client.jl +++ b/test/client.jl @@ -134,6 +134,9 @@ @test isok(resp) # x = JSONBase.materialize(resp.body) # @test x["form"] == Dict("name" => ["value with spaces"]) + resp = HTTP.post("https://$httpbin/post"; body=["hey", " there ", "sailor"]) + @test isok(resp) + @test occursin("\"data\":\"hey there sailor\"", String(resp.body)) end @testset "ASync Client Request Body" begin @@ -195,6 +198,41 @@ @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) end + @testset "Request Options Parity" begin + headers = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) + @test headers == ["X-Test" => "1"] + + headers2 = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers2, copyheaders=false) + @test any(h -> lowercase(String(h.first)) == "accept", headers2) + @test any(h -> lowercase(String(h.first)) == "x-test", headers2) + + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=false) + @test HTTP.getheader(resp.request.headers, "authorization") === nothing + + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=true) + auth = HTTP.getheader(resp.request.headers, "authorization") + @test auth !== nothing && startswith(auth, "Basic ") + + resp = HTTP.post("https://$httpbin/anything"; body="hello", detect_content_type=true) + @test HTTP.getheader(resp.request.headers, "content-type") == "text/plain; charset=utf-8" + + orig_agent = HTTP.USER_AGENT[] + try + HTTP.setuseragent!(nothing) + resp = HTTP.get("https://$httpbin/headers") + @test HTTP.getheader(resp.request.headers, "user-agent") === nothing + finally + HTTP.setuseragent!(orig_agent) + end + + pool = HTTP.Pool(1) + @test isempty(pool.clients.clients) + HTTP.get("https://$httpbin/ip"; pool=pool) + @test !isempty(pool.clients.clients) + end + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin headers = Dict("User-Agent" => "HTTP.jl") query = Dict("hello" => "world") @@ -226,4 +264,4 @@ @test isok(HTTP.delete(uri, headers, body; query=query)) end end -end \ No newline at end of file +end diff --git a/test/headers.jl b/test/headers.jl new file mode 100644 index 00000000..c9f47653 --- /dev/null +++ b/test/headers.jl @@ -0,0 +1,14 @@ +@testset "Headers helpers" begin + h = HTTP.Headers() + HTTP.addheader(h, "x-test-header", "abc") + HTTP.addheader(h, "content-type", "text/plain") + + @test HTTP.header(h, "X-Test-Header") == "abc" + @test HTTP.header(h, "missing", "default") == "default" + @test HTTP.hasheader(h, "x-test-header") + @test HTTP.headercontains(h, "x-test-header", "abc") + + HTTP.canonicalizeheaders!(h) + @test any(x -> x.name == "X-Test-Header", h) + @test any(x -> x.name == "Content-Type", h) +end diff --git a/test/runtests.jl b/test/runtests.jl index 6e0adf10..97d34c38 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,11 @@ const httpbin = get(ENV, "JULIA_TEST_HTTPBINGO_SERVER", "httpbingo.julialang.org isok(r) = r.status == 200 include("utils.jl") +include("headers.jl") +include("httpversion.jl") include("sniff.jl") include("multipart.jl") include("client.jl") include("handlers.jl") include("server.jl") +include("websockets_basic.jl") diff --git a/test/utils.jl b/test/utils.jl index e02a73a3..5fd4c47d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -29,4 +29,6 @@ ) @test HTTP.iso8859_1_to_utf8(bytes) == utf8 end -end # testset \ No newline at end of file + + @test_throws HTTP.AWSError HTTP.parseuri("http://example.com:abc", nothing, HTTP.default_aws_allocator()) +end # testset diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl new file mode 100644 index 00000000..0ff355f1 --- /dev/null +++ b/test/websockets_basic.jl @@ -0,0 +1,21 @@ +using Test +using HTTP +using HTTP.WebSockets + +@testset "WebSockets ping/pong" begin + server = HTTP.WebSockets.serve!("127.0.0.1", 0; listenany=true) do ws + msg = receive(ws) + send(ws, msg) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test_nowarn WebSockets.ping(ws) + @test_nowarn WebSockets.pong(ws) + send(ws, "ok") + @test receive(ws) == "ok" + end + finally + close(server) + end +end From 2b3e11eca9d8b0099c09d8a609ea27cf265b76fe Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 21:55:46 -0700 Subject: [PATCH 02/56] feat(client): add HTTP.open streaming support Implement HTTP.open with streaming IO for requests/responses. - Add stream IO methods, lifecycle control, and GC rooting - Prepare HTTP/1 chunked requests before stream creation - Add client streaming tests for GET and POST --- src/HTTP.jl | 1 + src/client/open.jl | 160 +++++++++++++++++++++++++++++++++ src/client/stream.jl | 207 +++++++++++++++++++++++++++++++++++++++++-- test/client.jl | 19 ++++ 4 files changed, 382 insertions(+), 5 deletions(-) create mode 100644 src/client/open.jl diff --git a/src/HTTP.jl b/src/HTTP.jl index 57244f57..ebcd5563 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -21,6 +21,7 @@ include("client/connection.jl") include("client/request.jl") include("client/stream.jl") include("client/makerequest.jl") +include("client/open.jl") include("websockets.jl"); using .WebSockets include("server.jl") include("handlers.jl"); using .Handlers diff --git a/src/client/open.jl b/src/client/open.jl new file mode 100644 index 00000000..02399971 --- /dev/null +++ b/src/client/open.jl @@ -0,0 +1,160 @@ +function _open_stream(conn::Ptr{aws_http_connection}, req::Request, decompress, readtimeout, allocator) + http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + stream = Stream{Ptr{aws_http_connection}}(allocator, decompress, http2) + stream.bufferstream = Base.BufferStream() + stream.connection = conn + stream.request = req + stream.response = resp = Response(0, nothing, nothing, http2, allocator) + resp.request = req + GC.@preserve stream begin + stream.request_options = aws_http_make_request_options( + 1, + req.ptr, + pointer_from_objref(stream), + on_response_headers[], + on_response_header_block_done[], + on_response_body[], + on_metrics[], + on_complete[], + on_destroy[], + http2, # http2_use_manual_data_writes + readtimeout * 1000 # response_first_byte_timeout_ms + ) + stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) + stream_ptr == C_NULL && aws_throw_error() + stream.ptr = stream_ptr + end + retain_stream!(stream) + return stream +end + +function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; + allocator=default_aws_allocator(), + headers=h, + copyheaders::Bool=true, + canonicalize_headers::Bool=false, + detect_content_type::Bool=false, + username=nothing, + password=nothing, + bearer=nothing, + query=nothing, + client::Union{Nothing, Client}=nothing, + basicauth::Bool=true, + proxy=DEFAULT_PROXY, + pool=nothing, + logerrors::Bool=false, + logtag=nothing, + observelayers::Bool=false, + retry_check=nothing, + retry_delays=nothing, + # redirect options + redirect=true, + redirect_limit=3, + redirect_method=nothing, + forwardheaders=true, + # cookie options + cookies=true, + cookiejar::CookieJar=COOKIEJAR, + # response options + decompress::Union{Nothing, Bool}=nothing, + status_exception::Bool=true, + readtimeout::Int=0, + modifier=nothing, + verbose=0, + # only client keywords in catch-all + kw...) + method_str = string(method) + headers = mkreqheaders(headers, copyheaders) + uri = parseuri(url, query, allocator) + count = 0 + while true + redirect_url = nothing + resp = nothing + proxy_kw = proxy_kwargs(proxy, scheme(uri)) + client_kw = (; allocator=allocator, kw...) + if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) + client_kw = merge(client_kw, (; max_connections=pool.max_connections)) + end + if !isempty(proxy_kw) + client_kw = merge(client_kw, proxy_kw) + end + authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) + apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth + reqclient = @something( + client, + pool === nothing ? + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) + )::Client + resp = with_connection(reqclient) do conn + http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + path = resource(uri) + with_request(reqclient, method_str, path, headers, nothing, nothing, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + ) do req + if !http2 && + method_str in ("POST", "PUT", "PATCH") && + !hasheader(req.headers, "content-length") && + !hasheader(req.headers, "transfer-encoding") && + !hasheader(req.headers, "upgrade") + setheader(req.headers, "transfer-encoding", "chunked") + end + stream = _open_stream(conn, req, decompress, readtimeout, allocator) + if redirect && issafe(method_str) + resp = startread(stream) + if (count < redirect_limit && isredirect(resp) && (location = getheader(resp.headers, "Location")) != "") + redirect_url = location + closeread(stream) + return resp + end + end + err = nothing + try + f(stream) + catch e + err = e + finally + closewrite(stream) + end + resp = closeread(stream) + err === nothing || throw(err) + return resp + end + end + if redirect_url === nothing + if status_exception && iserror(resp) + if logerrors + @error "HTTP StatusError" method=method_str url=makeuri(uri) status=resp.status logtag=logtag + end + throw(StatusError(method_str, uri, resp)) + end + return resp + end + if count == redirect_limit + return resp + end + olduri = uri + newuri = resolvereference(makeuri(uri), redirect_url) + uri = parseuri(newuri, nothing, allocator) + method_str = newmethod(method_str, resp.status, redirect_method) + if forwardheaders + headers = filter(headers) do (header, _) + if headereq(String(header), "host") + return false + elseif any(x -> headereq(x, String(header)), SENSITIVE_HEADERS) && !isdomainorsubdomain(host(uri), host(olduri)) + return false + elseif method_str == "GET" && (headereq(String(header), "content-type") || headereq(String(header), "content-length")) + return false + else + return true + end + end + else + headers = Header[] + end + count += 1 + end +end diff --git a/src/client/stream.jl b/src/client/stream.jl index 24219953..940d80c9 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -25,6 +25,7 @@ function c_on_response_header_block_done(aws_stream_ptr, header_block, stream_pt val = getheader(stream.response.headers, "content-encoding") stream.decompress = val !== nothing && val == "gzip" end + notify(stream.headers_ready) return Cint(0) end @@ -79,6 +80,8 @@ function c_on_complete(aws_stream_ptr, error_code, stream_ptr) else notify(stream.fut, nothing) end + notify(stream.headers_ready) + release_stream!(stream) return end @@ -92,7 +95,7 @@ if !@isdefined aws_websocket_server_upgrade_options const aws_websocket_server_upgrade_options = Ptr{Cvoid} end -mutable struct Stream{T} +mutable struct Stream{T} <: IO allocator::Ptr{aws_allocator} decompress::Union{Nothing, Bool} http2::Bool @@ -102,6 +105,11 @@ mutable struct Stream{T} final_chunk_written::Bool bufferstream::Union{Nothing, Base.BufferStream} gzipstream::Union{Nothing, CodecZlib.GzipDecompressorStream} + headers_ready::Threads.Event + activated::Bool + write_started::Bool + read_started::Bool + released::Bool # remaining fields are initially undefined ptr::Ptr{aws_http_stream} connection::T # Connection{F, S} (in servers.jl) @@ -114,11 +122,58 @@ mutable struct Stream{T} http2_stream_write_data_options::aws_http2_stream_write_data_options chunk_options::aws_http1_chunk_options websocket_options::aws_websocket_server_upgrade_options - Stream{T}(allocator, decompress, http2) where {T} = new{T}(allocator, decompress, http2, 0, Future{Nothing}(), nothing, false, nothing, nothing) + Stream{T}(allocator, decompress, http2) where {T} = new{T}( + allocator, + decompress, + http2, + 0, + Future{Nothing}(), + nothing, + false, + nothing, + nothing, + Threads.Event(), + false, + false, + false, + false, + ) end Base.hash(s::Stream, h::UInt) = hash(s.ptr, h) +const ACTIVE_STREAMS_LOCK = ReentrantLock() +const ACTIVE_STREAMS = IdDict{Stream, Bool}() + +function retain_stream!(s::Stream) + lock(ACTIVE_STREAMS_LOCK) + try + ACTIVE_STREAMS[s] = true + finally + unlock(ACTIVE_STREAMS_LOCK) + end + return +end + +function release_stream!(s::Stream) + lock(ACTIVE_STREAMS_LOCK) + try + pop!(ACTIVE_STREAMS, s, nothing) + finally + unlock(ACTIVE_STREAMS_LOCK) + end + return +end + +function release_stream_ptr!(s::Stream) + if isdefined(s, :ptr) && s.ptr != C_NULL && !s.released + aws_http_stream_release(s.ptr) + s.released = true + s.ptr = Ptr{aws_http_stream}(C_NULL) + end + return +end + const on_stream_write_on_complete = Ref{Ptr{Cvoid}}(C_NULL) function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) @@ -132,9 +187,11 @@ function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) end function writechunk(s::Stream, chunk::RequestBodyTypes) - @assert (isdefined(s, :response) && - isdefined(s.response, :request) && - s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" + if !(chunk isa AbstractString && isempty(chunk)) + @assert (isdefined(s, :response) && + isdefined(s.response, :request) && + s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" + end s.chunk = InputStream(s.allocator, chunk) fut = Future{Nothing}() if s.http2 @@ -160,6 +217,144 @@ function writechunk(s::Stream, chunk::RequestBodyTypes) return s.chunk.bodylen end +function _activate_stream!(s::Stream) + if !s.activated + aws_http_stream_activate(s.ptr) != 0 && aws_throw_error() + s.activated = true + end + return +end + +function startwrite(s::Stream) + if s.write_started + return + end + if !s.http2 && + !hasheader(s.request.headers, "content-length") && + !hasheader(s.request.headers, "transfer-encoding") && + !hasheader(s.request.headers, "upgrade") + setheader(s.request.headers, "transfer-encoding", "chunked") + end + _activate_stream!(s) + s.write_started = true + return +end + +function closewrite(s::Stream) + if s.final_chunk_written + return + end + if s.http2 + _activate_stream!(s) + writechunk(s, "") + s.final_chunk_written = true + return + end + if s.write_started + writechunk(s, "") + s.final_chunk_written = true + elseif hasheader(s.request.headers, "transfer-encoding") + _activate_stream!(s) + writechunk(s, "") + s.final_chunk_written = true + else + _activate_stream!(s) + end + return +end + +function startread(s::Stream) + if s.read_started + return s.response + end + _activate_stream!(s) + s.http2 && !s.final_chunk_written && closewrite(s) + wait(s.headers_ready) + s.read_started = true + return s.response +end + +function Base.readavailable(s::Stream, n::Int=typemax(Int)) + startread(s) + if s.bufferstream === nothing + return UInt8[] + end + return _readavailable(s.bufferstream) +end + +function Base.read(s::Stream, n::Integer) + startread(s) + s.bufferstream === nothing && return UInt8[] + return read(s.bufferstream, n) +end + +function Base.read(s::Stream) + startread(s) + s.bufferstream === nothing && return UInt8[] + return read(s.bufferstream) +end + +function Base.read(s::Stream, ::Type{UInt8}) + data = Base.read(s, 1) + isempty(data) && throw(EOFError()) + return data[1] +end + +function Base.eof(s::Stream) + startread(s) + s.bufferstream === nothing && return true + return eof(s.bufferstream) +end + +function Base.unsafe_write(s::Stream, p::Ptr{UInt8}, n::UInt) + n == 0 && return 0 + buf = Vector{UInt8}(undef, n) + GC.@preserve buf unsafe_copyto!(pointer(buf), p, n) + Base.write(s, buf) + return n +end + +function Base.write(s::Stream, data::AbstractVector{UInt8}) + startwrite(s) + writechunk(s, data) + return length(data) +end + +function Base.write(s::Stream, data::Union{String, SubString{String}}) + startwrite(s) + writechunk(s, data) + return sizeof(data) +end + +function Base.write(s::Stream, data::AbstractString) + return Base.write(s, String(data)) +end + +function Base.write(s::Stream, b::UInt8) + startwrite(s) + writechunk(s, UInt8[b]) + return 1 +end + +function closeread(s::Stream) + startread(s) + try + wait(s.fut) + finally + release_stream_ptr!(s) + end + return s.response +end + +function Base.close(s::Stream) + try + closewrite(s) + finally + closeread(s) + end + return +end + function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) stream = Stream{Nothing}(allocator, decompress, http2) if on_stream_response_body !== nothing @@ -214,6 +409,8 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, return resp finally aws_http_stream_release(stream_ptr) + stream.released = true + stream.ptr = Ptr{aws_http_stream}(C_NULL) end end # GC.@preserve end diff --git a/test/client.jl b/test/client.jl index 9f42fed3..0867b909 100644 --- a/test/client.jl +++ b/test/client.jl @@ -233,6 +233,25 @@ @test !isempty(pool.clients.clients) end + @testset "HTTP.open streaming" begin + resp = HTTP.open("GET", "https://$httpbin/stream/5") do io + r = HTTP.startread(io) + @test r.status == 200 + data = String(read(io)) + @test length(split(chomp(data), '\n')) == 5 + end + @test resp.status == 200 + + resp = HTTP.open("POST", "https://$httpbin/anything") do io + write(io, "hello") + HTTP.closewrite(io) + r = HTTP.startread(io) + data = String(read(io)) + @test occursin("\"data\":\"hello\"", data) + end + @test resp.status == 200 + end + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin headers = Dict("User-Agent" => "HTTP.jl") query = Dict("hello" => "world") From b7464e980e9b8803a1f9ae9c723d7abcbbb8b5c1 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 22:18:26 -0700 Subject: [PATCH 03/56] feat(server): add listen and stream handling Add listen/listen! wrappers and blocking serve. - Enable stream handlers with request-body streaming and chunked responses - Add server stream helpers (setstatus/setheader) and tests --- src/client/open.jl | 2 +- src/client/stream.jl | 166 ++++++++++++++++++++++++++++++++++++++++++- src/server.jl | 55 +++++++++++++- test/server.jl | 24 ++++++- 4 files changed, 240 insertions(+), 7 deletions(-) diff --git a/src/client/open.jl b/src/client/open.jl index 02399971..d26194c7 100644 --- a/src/client/open.jl +++ b/src/client/open.jl @@ -1,6 +1,6 @@ function _open_stream(conn::Ptr{aws_http_connection}, req::Request, decompress, readtimeout, allocator) http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - stream = Stream{Ptr{aws_http_connection}}(allocator, decompress, http2) + stream = Stream{Ptr{aws_http_connection}}(allocator, decompress, http2, false) stream.bufferstream = Base.BufferStream() stream.connection = conn stream.request = req diff --git a/src/client/stream.jl b/src/client/stream.jl index 940d80c9..2e69c6ce 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -99,16 +99,21 @@ mutable struct Stream{T} <: IO allocator::Ptr{aws_allocator} decompress::Union{Nothing, Bool} http2::Bool + server_side::Bool status::Cint # used as a ref fut::Future{Nothing} chunk::Union{Nothing, InputStream} final_chunk_written::Bool bufferstream::Union{Nothing, Base.BufferStream} gzipstream::Union{Nothing, CodecZlib.GzipDecompressorStream} + responsebuf::Union{Nothing, IOBuffer} headers_ready::Threads.Event activated::Bool write_started::Bool read_started::Bool + response_started::Bool + handler_started::Bool + ignore_writes::Bool released::Bool # remaining fields are initially undefined ptr::Ptr{aws_http_stream} @@ -122,21 +127,26 @@ mutable struct Stream{T} <: IO http2_stream_write_data_options::aws_http2_stream_write_data_options chunk_options::aws_http1_chunk_options websocket_options::aws_websocket_server_upgrade_options - Stream{T}(allocator, decompress, http2) where {T} = new{T}( + Stream{T}(allocator, decompress, http2, server_side::Bool=false) where {T} = new{T}( allocator, decompress, http2, + server_side, 0, Future{Nothing}(), nothing, false, nothing, nothing, + nothing, Threads.Event(), false, false, false, false, + false, + false, + false, ) end @@ -187,7 +197,7 @@ function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) end function writechunk(s::Stream, chunk::RequestBodyTypes) - if !(chunk isa AbstractString && isempty(chunk)) + if !s.server_side && !(chunk isa AbstractString && isempty(chunk)) @assert (isdefined(s, :response) && isdefined(s.response, :request) && s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" @@ -217,7 +227,88 @@ function writechunk(s::Stream, chunk::RequestBodyTypes) return s.chunk.bodylen end +function _ensure_response!(s::Stream) + if !isdefined(s, :response) || s.response === nothing + s.response = Response(200, nothing, nothing, s.http2, s.allocator) + end + return s.response +end + +function _send_response!(s::Stream) + if s.response_started + return s.response + end + resp = _ensure_response!(s) + aws_http_stream_send_response(s.ptr, resp.ptr) != 0 && aws_throw_error() + s.response_started = true + return resp +end + +function _server_startwrite(s::Stream) + if s.write_started + return + end + resp = _ensure_response!(s) + if s.request.method == "HEAD" + s.ignore_writes = true + setinputstream!(resp, nothing) + end + if s.http2 + s.write_started = true + return + end + if !s.ignore_writes && + !hasheader(resp.headers, "transfer-encoding") && + !hasheader(resp.headers, "upgrade") + hasheader(resp.headers, "content-length") && removeheader(resp.headers, "content-length") + setheader(resp.headers, "transfer-encoding", "chunked") + end + _send_response!(s) + s.write_started = true + return +end + +function _server_closewrite(s::Stream) + if s.final_chunk_written + return + end + resp = _ensure_response!(s) + if s.http2 + if !s.response_started + if s.ignore_writes + setinputstream!(resp, nothing) + else + body = s.responsebuf === nothing ? UInt8[] : take!(s.responsebuf) + setinputstream!(resp, body) + end + _send_response!(s) + end + s.final_chunk_written = true + return + end + if !s.response_started + if !s.ignore_writes && + !hasheader(resp.headers, "transfer-encoding") && + !hasheader(resp.headers, "upgrade") + hasheader(resp.headers, "content-length") && removeheader(resp.headers, "content-length") + setheader(resp.headers, "transfer-encoding", "chunked") + end + _send_response!(s) + end + if s.ignore_writes + s.final_chunk_written = true + return + end + writechunk(s, "") + s.final_chunk_written = true + return +end + function _activate_stream!(s::Stream) + if s.server_side + s.activated = true + return + end if !s.activated aws_http_stream_activate(s.ptr) != 0 && aws_throw_error() s.activated = true @@ -226,6 +317,9 @@ function _activate_stream!(s::Stream) end function startwrite(s::Stream) + if s.server_side + return _server_startwrite(s) + end if s.write_started return end @@ -241,6 +335,9 @@ function startwrite(s::Stream) end function closewrite(s::Stream) + if s.server_side + return _server_closewrite(s) + end if s.final_chunk_written return end @@ -264,6 +361,14 @@ function closewrite(s::Stream) end function startread(s::Stream) + if s.server_side + if s.read_started + return s.request + end + wait(s.headers_ready) + s.read_started = true + return s.request + end if s.read_started return s.response end @@ -316,12 +421,30 @@ end function Base.write(s::Stream, data::AbstractVector{UInt8}) startwrite(s) + if s.server_side + if s.ignore_writes + return length(data) + elseif s.http2 + s.responsebuf === nothing && (s.responsebuf = IOBuffer()) + write(s.responsebuf, data) + return length(data) + end + end writechunk(s, data) return length(data) end function Base.write(s::Stream, data::Union{String, SubString{String}}) startwrite(s) + if s.server_side + if s.ignore_writes + return sizeof(data) + elseif s.http2 + s.responsebuf === nothing && (s.responsebuf = IOBuffer()) + write(s.responsebuf, data) + return sizeof(data) + end + end writechunk(s, data) return sizeof(data) end @@ -332,6 +455,15 @@ end function Base.write(s::Stream, b::UInt8) startwrite(s) + if s.server_side + if s.ignore_writes + return 1 + elseif s.http2 + s.responsebuf === nothing && (s.responsebuf = IOBuffer()) + write(s.responsebuf, b) + return 1 + end + end writechunk(s, UInt8[b]) return 1 end @@ -355,8 +487,36 @@ function Base.close(s::Stream) return end +function setstatus(s::Stream, status::Integer) + s.server_side || error("setstatus is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + resp.status = status + return +end + +function setheader(s::Stream, v) + s.server_side || error("setheader is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + setheader(resp, v) + return +end + +function setheader(s::Stream, k, v) + return setheader(s, k => v) +end + +function setheaderifabsent(s::Stream, k, v) + s.server_side || error("setheaderifabsent is only supported for server streams") + s.response_started && error("response already started") + resp = _ensure_response!(s) + setheaderifabsent(resp.headers, k, v) + return +end + function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) - stream = Stream{Nothing}(allocator, decompress, http2) + stream = Stream{Nothing}(allocator, decompress, http2, false) if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end diff --git a/src/server.jl b/src/server.jl index 98e1a131..e3027fbb 100644 --- a/src/server.jl +++ b/src/server.jl @@ -57,6 +57,7 @@ mutable struct Server{F, C} const connections::Set{Connection} const closed::Threads.Event const access_log::Union{Nothing, Function} + const stream::Bool const logstate::Base.CoreLogging.LogState @atomic state::Symbol # :initializing, :running, :closed server::Ptr{aws_http_server} @@ -74,9 +75,10 @@ mutable struct Server{F, C} connections::Set{Connection}, closed::Threads.Event, access_log::Union{Nothing, Function}, + stream::Bool, logstate::Base.CoreLogging.LogState, state::Symbol, - ) where {F, C} = new{F, C}(f, on_stream_complete, fut, allocator, endpoint, socket_options, tls_options, connections_lock, connections, closed, access_log, logstate, state) + ) where {F, C} = new{F, C}(f, on_stream_complete, fut, allocator, endpoint, socket_options, tls_options, connections_lock, connections, closed, access_log, stream, logstate, state) end Base.wait(s::Server) = wait(s.closed) @@ -90,6 +92,7 @@ function serve!(f, host="127.0.0.1", port=8080; listenany::Bool=false, on_stream_complete=nothing, access_log::Union{Nothing, Function}=nothing, + stream::Bool=false, # socket options socket_options=nothing, socket_domain=:ipv4, @@ -143,6 +146,7 @@ function serve!(f, host="127.0.0.1", port=8080; Set{Connection}(), # connections Threads.Event(), # closed access_log, + stream, Base.CoreLogging.current_logstate(), :initializing, # state ) @@ -165,6 +169,15 @@ function serve!(f, host="127.0.0.1", port=8080; return server end +function serve(f, host="127.0.0.1", port=8080; stream::Bool=false, kw...) + server = serve!(f, host, port; stream=stream, kw...) + wait(server) + return server +end + +listen!(f, host="127.0.0.1", port=8080; kw...) = serve!(f, host, port; stream=true, kw...) +listen(f, host="127.0.0.1", port=8080; kw...) = serve(f, host, port; stream=true, kw...) + const on_incoming_connection = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_connection(aws_server, aws_conn, error_code, server_ptr) @@ -222,7 +235,8 @@ function c_on_incoming_request(aws_conn, conn_ptr) stream = Stream{typeof(conn)}( conn.allocator, false, # decompress - aws_http_connection_get_version(aws_conn) == AWS_HTTP_VERSION_2 # http2 + aws_http_connection_get_version(aws_conn) == AWS_HTTP_VERSION_2, # http2 + true, ) stream.connection = conn stream.request_handler_options = aws_http_request_handler_options( @@ -270,6 +284,33 @@ function c_on_request_header_block_done(aws_stream_ptr, header_block, stream_ptr ret = aws_http_stream_get_incoming_request_uri(aws_stream_ptr, FieldRef(stream, :path)) ret != 0 && return ret aws_http_message_set_request_path(stream.request.ptr, stream.path) + notify(stream.headers_ready) + if stream.connection.server.stream && !stream.handler_started + stream.handler_started = true + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + Threads.@spawn begin + Base.CoreLogging.with_logstate(stream.connection.server.logstate) do + try + Base.invokelatest(stream.connection.server.f, stream) + catch e + @error "Request handler error; sending 500" exception=(e, catch_backtrace()) + if !stream.response_started + try + setstatus(stream, 500) + catch err + @error "failed to set 500 status" exception=(err, catch_backtrace()) + end + end + finally + try + closewrite(stream) + catch err + @error "failed to close response stream" exception=(err, catch_backtrace()) + end + end + end + end + end return Cint(0) end @@ -279,6 +320,11 @@ const on_request_body = Ref{Ptr{Cvoid}}(C_NULL) function c_on_request_body(aws_stream_ptr, data::Ptr{aws_byte_cursor}, stream_ptr) stream = unsafe_pointer_to_objref(stream_ptr) bc = unsafe_load(data) + if stream.connection.server.stream + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + unsafe_write(stream.bufferstream, bc.ptr, bc.len) + return Cint(0) + end body = stream.request.body if body === nothing body = Vector{UInt8}(undef, bc.len) @@ -297,6 +343,10 @@ const on_request_done = Ref{Ptr{Cvoid}}(C_NULL) function c_on_request_done(aws_stream_ptr, stream_ptr) stream = unsafe_pointer_to_objref(stream_ptr) Base.CoreLogging.with_logstate(stream.connection.server.logstate) do + if stream.connection.server.stream + stream.bufferstream !== nothing && close(stream.bufferstream) + return Cint(0) + end try stream.response = Base.invokelatest(stream.connection.server.f, stream.request)::Response if stream.request.method == "HEAD" @@ -342,6 +392,7 @@ function c_on_server_stream_complete(aws_stream_ptr, error_code, stream_ptr) @lock stream.connection.streams_lock begin delete!(stream.connection.streams, stream) end + release_stream_ptr!(stream) return Cint(0) end end diff --git a/test/server.jl b/test/server.jl index 535c314b..b022a76c 100644 --- a/test/server.jl +++ b/test/server.jl @@ -12,6 +12,28 @@ using Test, HTTP, Logging end end +@testset "HTTP.listen stream handler" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, isempty(body) ? "ping" : body) + end + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test resp.status == 200 + @test String(resp.body) == "ping" + + resp = HTTP.post("http://127.0.0.1:$port"; body="echo") + @test resp.status == 200 + @test String(resp.body) == "echo" + finally + close(server) + end +end + @testset "access logging" begin local handler = (req) -> begin if req.target == "/internal-error" @@ -89,4 +111,4 @@ end @test occursin(r"^\*/\* text/plain GET /index\.html HTTP/1\.1 GET /index\.html 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[2].message) @test occursin(r"^\*/\* text/plain GET /index\.html\?a=b HTTP/1\.1 GET /index\.html\?a=b 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[3].message) @test occursin(r"^\*/\* text/plain HEAD / HTTP/1\.1 HEAD / 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 0$", logs[4].message) -end \ No newline at end of file +end From 42bd2f2136764914ec4e7f0613aedc6cf84d6dd0 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 22:44:48 -0700 Subject: [PATCH 04/56] feat(websockets): handle close and fragmentation - track control frames and fragmented messages - add close handling with error propagation - cover fragmentation/close in tests --- src/websockets.jl | 327 ++++++++++++++++++++++++++++++--------- test/websockets_basic.jl | 25 +++ 2 files changed, 282 insertions(+), 70 deletions(-) diff --git a/src/websockets.jl b/src/websockets.jl index ce1323a9..8a4292c0 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -8,27 +8,101 @@ export WebSocket, send, receive, ping, pong @enum OpCode::UInt8 CONTINUATION=0x00 TEXT=0x01 BINARY=0x02 CLOSE=0x08 PING=0x09 PONG=0x0A +struct CloseFrameBody + code::Int + reason::String +end + +struct WebSocketError <: Exception + message::CloseFrameBody +end + +isok(e::WebSocketError) = e.message.code in (1000, 1001, 1005) + mutable struct WebSocket id::String host::String path::String connect_fut::Future{Nothing} - readchannel::Channel{Union{String, Vector{UInt8}}} + readchannel::Channel{Union{String, Vector{UInt8}, WebSocketError}} writebuffer::Vector{UInt8} writepos::Int writeclosed::Bool closelock::ReentrantLock - handshake_request::Request - options::aws_websocket_client_connection_options + sendlock::ReentrantLock + handshake_request::Union{Nothing, Request} websocket_pointer::Ptr{aws_websocket} - handshake_response::Response - websocket_send_frame_options::aws_websocket_send_frame_options - - WebSocket(host::AbstractString, path::AbstractString) = new(string(rand(UInt32); base=58), String(host), String(path), Future{Nothing}(), Channel{Union{String, Vector{UInt8}}}(Inf), UInt8[], 0, false, ReentrantLock()) + handshake_response::Union{Nothing, Response} + incoming_opcode::UInt8 + incoming_fin::Bool + incoming_payload::Vector{UInt8} + fragment_opcode::Union{Nothing, UInt8} + fragment_payload::Vector{UInt8} + closebody::Union{Nothing, CloseFrameBody} + + WebSocket(host::AbstractString, path::AbstractString) = new( + string(rand(UInt32); base=58), + String(host), + String(path), + Future{Nothing}(), + Channel{Union{String, Vector{UInt8}, WebSocketError}}(Inf), + UInt8[], + 0, + false, + ReentrantLock(), + ReentrantLock(), + nothing, + C_NULL, + nothing, + 0x00, + false, + UInt8[], + nothing, + UInt8[], + nothing, + ) end getresponse(ws::WebSocket) = ws.handshake_response +function _queue_close!(ws::WebSocket, body::CloseFrameBody) + ws.closebody = body + if isopen(ws.readchannel) + try + put!(ws.readchannel, WebSocketError(body)) + catch + end + Base.close(ws.readchannel) + end + return +end + +function _close_channel!(ws::WebSocket) + isopen(ws.readchannel) && Base.close(ws.readchannel) + return +end + +function _enqueue_message!(ws::WebSocket, msg) + if isopen(ws.readchannel) + try + put!(ws.readchannel, msg) + catch + end + end + return +end + +function close_payload(body::CloseFrameBody) + reason_bytes = collect(codeunits(body.reason)) + payload = Vector{UInt8}(undef, 2 + length(reason_bytes)) + payload[1] = UInt8((body.code >> 8) & 0xff) + payload[2] = UInt8(body.code & 0xff) + if !isempty(reason_bytes) + copyto!(payload, 3, reason_bytes, 1, length(reason_bytes)) + end + return payload +end + mutable struct SendState ws::WebSocket fut::Future{Nothing} @@ -44,15 +118,17 @@ function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_conne notify(ws.connect_fut, CapturedException(aws_error(data.error_code), Base.backtrace())) else ws.websocket_pointer = data.websocket - ws.handshake_response.status = unsafe_load(data.handshake_response_status) - addheaders(ws.handshake_response.headers, data.handshake_response_header_array, data.num_handshake_response_headers) + resp = ws.handshake_response + @assert resp !== nothing + resp.status = unsafe_load(data.handshake_response_status) + addheaders(resp.headers, data.handshake_response_header_array, data.num_handshake_response_headers) if data.handshake_response_body != C_NULL handshake_response_body = unsafe_load(data.handshake_response_body) response_body = str(handshake_response_body) else response_body = nothing end - setinputstream!(ws.handshake_response, response_body) + setinputstream!(resp, response_body) notify(ws.connect_fut, nothing) end catch e @@ -67,16 +143,26 @@ function c_on_connection_shutdown(websocket::Ptr{aws_websocket}, error_code::Cin ws = unsafe_pointer_to_objref(ws_ptr) if error_code != 0 @error "$(ws.id): connection shutdown error" exception=(aws_error(error_code), Base.backtrace()) + if ws.closebody === nothing + _queue_close!(ws, CloseFrameBody(1006, "")) + end + else + _close_channel!(ws) end - close(ws) + ws.websocket_pointer = C_NULL + ws.writeclosed = true return end const on_incoming_frame_begin = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_frame_begin(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, ws_ptr) - # ws = unsafe_pointer_to_objref(ws_ptr) - # fr = unsafe_load(frame) + ws = unsafe_pointer_to_objref(ws_ptr) + fr = unsafe_load(frame) + ws.incoming_opcode = fr.opcode + ws.incoming_fin = fr.fin + empty!(ws.incoming_payload) + fr.payload_length > 0 && sizehint!(ws.incoming_payload, Int(fr.payload_length)) return true end @@ -84,15 +170,13 @@ const on_incoming_frame_payload = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_frame_payload(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, data::aws_byte_cursor, ws_ptr) ws = unsafe_pointer_to_objref(ws_ptr) - fr = unsafe_load(frame) try - if fr.opcode == UInt8(TEXT) - put!(ws.readchannel, unsafe_string(data.ptr, data.len)) - else - rec = Vector{UInt8}(undef, data.len) - Base.unsafe_copyto!(pointer(rec), data.ptr, data.len) - put!(ws.readchannel, rec) - end + n = Int(data.len) + n == 0 && return true + payload = ws.incoming_payload + start = length(payload) + 1 + resize!(payload, length(payload) + n) + Base.unsafe_copyto!(pointer(payload, start), data.ptr, n) catch e @error "$(ws.id): incoming frame payload error" exception=(e, catch_backtrace()) end @@ -103,9 +187,86 @@ const on_incoming_frame_complete = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, error_code::Cint, ws_ptr) ws = unsafe_pointer_to_objref(ws_ptr) - fr = unsafe_load(frame) if error_code != 0 @error "$(ws.id): incoming frame complete error" exception=(aws_error(error_code), Base.backtrace()) + close_body = CloseFrameBody(1006, "") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true + end + fr = unsafe_load(frame) + opcode = fr.opcode + fin = fr.fin + payload = ws.incoming_payload + if opcode == UInt8(PING) + payload_copy = copy(payload) + Threads.@spawn begin + try + pong(ws, payload_copy) + catch e + @error "$(ws.id): failed to send pong" exception=(e, catch_backtrace()) + end + end + return true + elseif opcode == UInt8(PONG) + return true + elseif opcode == UInt8(CLOSE) + body = payload + close_body = if length(body) >= 2 + code = (Int(body[1]) << 8) | Int(body[2]) + reason = length(body) > 2 ? String(copy(body[3:end])) : "" + CloseFrameBody(code, reason) + else + CloseFrameBody(1005, "") + end + Threads.@spawn begin + try + ws.writeclosed || close(ws, close_body) + catch e + @error "$(ws.id): failed to close websocket" exception=(e, catch_backtrace()) + end + end + _queue_close!(ws, close_body) + return true + end + if opcode == UInt8(CONTINUATION) + if ws.fragment_opcode === nothing + close_body = CloseFrameBody(1002, "unexpected continuation") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true + end + append!(ws.fragment_payload, payload) + if fin + msg_opcode = ws.fragment_opcode + data = ws.fragment_payload + ws.fragment_opcode = nothing + ws.fragment_payload = UInt8[] + if msg_opcode == UInt8(TEXT) + _enqueue_message!(ws, String(copy(data))) + else + _enqueue_message!(ws, copy(data)) + end + end + return true + end + if opcode == UInt8(TEXT) || opcode == UInt8(BINARY) + if ws.fragment_opcode !== nothing + close_body = CloseFrameBody(1002, "unexpected new data frame") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true + end + if fin + if opcode == UInt8(TEXT) + _enqueue_message!(ws, String(copy(payload))) + else + _enqueue_message!(ws, copy(payload)) + end + else + ws.fragment_opcode = opcode + ws.fragment_payload = copy(payload) + end end return true end @@ -148,7 +309,7 @@ function open(f::Function, url; ws = WebSocket(host, path) ws.handshake_request = req ws.handshake_response = Response(0, nothing, nothing, false, allocator) - ws.options = aws_websocket_client_connection_options( + options = aws_websocket_client_connection_options( allocator, reqclient.settings.bootstrap, pointer(FieldRef(reqclient, :socket_options)), @@ -168,7 +329,7 @@ function open(f::Function, url; C_NULL, # requested_event_loop C_NULL, # host_resolution_config ) - if aws_websocket_client_connect(FieldRef(ws, :options)) != 0 + if aws_websocket_client_connect(Ref(options)) != 0 aws_throw_error() end # wait until connected @@ -195,18 +356,33 @@ function open(f::Function, url; # end finally # if !isclosed(ws) - close(ws) + ws.closebody === nothing && close(ws) # end end end -function Base.close(ws::WebSocket) +function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) @lock ws.closelock begin + if ws.writeclosed + _close_channel!(ws) + return + end + ws.writeclosed = true if ws.websocket_pointer != C_NULL + if body !== nothing + payload_bytes = close_payload(body) + @lock ws.sendlock begin + try + writeframe(ws, true, CLOSE, payload(ws, payload_bytes)) + catch + # ignore errors while closing + end + end + end aws_websocket_close(ws.websocket_pointer, false) ws.websocket_pointer = C_NULL - ws.writeclosed = true end + _close_channel!(ws) end return end @@ -256,25 +432,25 @@ function c_on_complete(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr:: state = unsafe_pointer_to_objref(ws_ptr) if error_code != 0 notify(state.fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(state.fut, nothing) end - notify(state.fut, nothing) return end function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload) n = sizeof(payload) state = SendState(ws, Future{Nothing}()) - GC.@preserve state begin - ws.websocket_send_frame_options = aws_websocket_send_frame_options( - n % UInt64, - Ptr{Cvoid}(pointer_from_objref(state)), # user_data - stream_outgoing_payload[], - on_complete[], - UInt8(opcode), - fin - ) - opts = pointer(FieldRef(ws, :websocket_send_frame_options)) - if aws_websocket_send_frame(ws.websocket_pointer, opts) != 0 + opts = aws_websocket_send_frame_options( + n % UInt64, + Ptr{Cvoid}(pointer_from_objref(state)), # user_data + stream_outgoing_payload[], + on_complete[], + UInt8(opcode), + fin + ) + GC.@preserve state opts begin + if aws_websocket_send_frame(ws.websocket_pointer, Ref(opts)) != 0 aws_throw_error() end # wait until frame sent @@ -297,34 +473,37 @@ or `close(ws[, body::WebSockets.CloseFrameBody])`. Calling `close` will initiate the close sequence and close the underlying connection. """ function send(ws::WebSocket, x) - @assert !ws.writeclosed "WebSocket is closed" - if !isbinary(x) && !istext(x) - # if x is not single binary or text, then assume it's an iterable of binary or text - # and we'll send fragmented message - first = true - n = 0 - state = iterate(x) - if state === nothing - # x was not binary or text, but is an empty iterable, send single empty frame - x = "" - @goto write_single_frame - end - @debug "$(ws.id): Writing fragmented message" - item, st = state - # we prefetch next state so we know if we're on the last item or not - # so we can appropriately set the FIN bit for the last fragmented frame - nextstate = iterate(x, st) - while true - n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, payload(ws, item)) - first = false - nextstate === nothing && break - item, st = nextstate + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + if !isbinary(x) && !istext(x) + # if x is not single binary or text, then assume it's an iterable of binary or text + # and we'll send fragmented message + first = true + n = 0 + state = iterate(x) + if state === nothing + # x was not binary or text, but is an empty iterable, send single empty frame + x = "" + @goto write_single_frame + end + @debug "$(ws.id): Writing fragmented message" + item, st = state + # we prefetch next state so we know if we're on the last item or not + # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) - end - else - # single binary or text frame for message + while true + n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, payload(ws, item)) + first = false + nextstate === nothing && break + item, st = nextstate + nextstate = iterate(x, st) + end + return n + else + # single binary or text frame for message @label write_single_frame - return writeframe(ws, true, opcode(x), payload(ws, x)) + return writeframe(ws, true, opcode(x), payload(ws, x)) + end end end @@ -337,8 +516,10 @@ body to send with the message. PONG messages are automatically responded to when a PING message is received by a websocket connection. """ function ping(ws::WebSocket, data=UInt8[]) - @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws, true, PING, payload(ws, data)) + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + return writeframe(ws, true, PING, payload(ws, data)) + end end """ @@ -351,8 +532,10 @@ PONG message, but in certain cases, a unidirectional PONG message can be used as a one-way heartbeat. """ function pong(ws::WebSocket, data=UInt8[]) - @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws, true, PONG, payload(ws, data)) + @lock ws.sendlock begin + @assert !ws.writeclosed "WebSocket is closed" + return writeframe(ws, true, PONG, payload(ws, data)) + end end """ @@ -373,7 +556,11 @@ where each iteration yields a message until the connection is closed. """ function receive(ws::WebSocket) @assert isopen(ws.readchannel) "WebSocket is closed" - return take!(ws.readchannel) + msg = take!(ws.readchannel) + if msg isa WebSocketError + throw(msg) + end + return msg end """ diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl index 0ff355f1..d3484c87 100644 --- a/test/websockets_basic.jl +++ b/test/websockets_basic.jl @@ -19,3 +19,28 @@ using HTTP.WebSockets close(server) end end + +@testset "WebSockets fragmentation and close" begin + server = HTTP.WebSockets.serve!("127.0.0.1", 0; listenany=true) do ws + msg = receive(ws) + send(ws, msg) + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + send(ws, ["hel", "lo"]) + @test receive(ws) == "hello" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + @test e.message.code == 1000 + end + end + finally + close(server) + end +end From 97ce1bbce36f1ef0c3a9e8c2cadfb5fc97355963 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 22:59:31 -0700 Subject: [PATCH 05/56] feat(retry): align retry semantics with v1 - gate retries by idempotency and body retryability - honor retry_delays and retry_check overrides - track retry counts in response metrics - add retry behavior tests --- src/HTTP.jl | 27 ++++--- src/client/client.jl | 2 +- src/client/makerequest.jl | 14 +++- src/client/retry.jl | 144 +++++++++++++++++++++++++------------- test/client.jl | 68 ++++++++++++++++++ 5 files changed, 189 insertions(+), 66 deletions(-) diff --git a/src/HTTP.jl b/src/HTTP.jl index ebcd5563..fc5bcc4e 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -13,20 +13,6 @@ include("access_log.jl") include("sniff.jl"); using .Sniff include("forms.jl"); using .Forms include("requestresponse.jl") -include("cookies.jl"); using .Cookies -include("client/redirects.jl") -include("client/client.jl") -include("client/retry.jl") -include("client/connection.jl") -include("client/request.jl") -include("client/stream.jl") -include("client/makerequest.jl") -include("client/open.jl") -include("websockets.jl"); using .WebSockets -include("server.jl") -include("handlers.jl"); using .Handlers -include("statuses.jl") - struct StatusError <: Exception request_method::String request_uri::aws_uri @@ -54,6 +40,19 @@ function Base.getproperty(e::StatusError, s::Symbol) return getfield(e, s) end end +include("cookies.jl"); using .Cookies +include("client/redirects.jl") +include("client/client.jl") +include("client/retry.jl") +include("client/connection.jl") +include("client/request.jl") +include("client/stream.jl") +include("client/makerequest.jl") +include("client/open.jl") +include("websockets.jl"); using .WebSockets +include("server.jl") +include("handlers.jl"); using .Handlers +include("statuses.jl") #NOTE: this is global process logging in the aws-crt libraries; not appropriate for request-level # logging, but more for debugging the library itself diff --git a/src/client/client.jl b/src/client/client.jl index 257002fd..e3540715 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,5 +1,5 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 -const DEFAULT_MAX_RETRIES = 10 +const DEFAULT_MAX_RETRIES = 4 Base.@kwdef struct ClientSettings scheme::String diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index e1ce21ef..2a1421c0 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -74,6 +74,14 @@ function request(method, url, h=Header[], b=nothing; chunkedbody = body body = nothing end + retryable_body = chunkedbody === nothing && ( + body === nothing || + body isa AbstractString || + body isa AbstractVector{UInt8} || + body isa AbstractDict || + body isa NamedTuple || + body isa Form + ) headers = mkreqheaders(headers, copyheaders) uri = parseuri(url, query, allocator) proxy_kw = proxy_kwargs(proxy, scheme(uri)) @@ -93,7 +101,10 @@ function request(method, url, h=Header[], b=nothing; getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) )::Client - with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, retry_check=retry_check, retry_delays=retry_delays) do + req_ref = Ref{Any}(nothing) + with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, + retry_check=retry_check, retry_delays=retry_delays, + retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref) do resp = with_connection(reqclient) do conn http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 path = resource(uri) @@ -103,6 +114,7 @@ function request(method, url, h=Header[], b=nothing; detect_content_type=detect_content_type, basicauth=apply_basicauth, ) do req + req_ref[] = req if response_body isa AbstractVector{UInt8} ref = Ref(1) GC.@preserve ref begin diff --git a/src/client/retry.jl b/src/client/retry.jl index 93292f2a..f5c6e7fc 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -11,6 +11,16 @@ end Base.showerror(io::IO, e::StreamError) = print(io, e.error) +retryable_status(status::Integer) = status in (403, 408, 409, 429, 500, 502, 503, 504, 599) + +isrecoverable(ex::StatusError) = retryable_status(ex.status) +isrecoverable(::Union{Base.EOFError, Base.IOError}) = true +isrecoverable(ex::ArgumentError) = ex.msg == "stream is closed or unusable" +isrecoverable(ex::CompositeException) = all(isrecoverable, ex.exceptions) +isrecoverable(ex::Sockets.DNSError) = (ex.code == Base.UV_EAI_AGAIN) +isrecoverable(::AWSError) = true +isrecoverable(::Exception) = false + const on_acquired = Ref{Ptr{Cvoid}}(C_NULL) function c_on_acquired(retry_strategy, error_code, retry_token, fut_ptr) @@ -35,6 +45,40 @@ function c_retry_ready(token, error_code::Cint, fut_ptr) return end +function _default_retryable(method, err, retryable_body::Bool, retry_non_idempotent::Bool) + retryable_body || return false + method === nothing && return false + method_str = string(method) + if !(isidempotent(method_str) || retry_non_idempotent) + return false + end + if err isa StatusError + return retryable_status(err.status) + end + return isrecoverable(err) +end + +function _normalize_retry_delays(retry_delays, max_retries::Int) + if retry_delays === nothing + return Base.ExponentialBackOff(n=max_retries, factor=3.0) + elseif retry_delays isa Number + return Iterators.repeated(retry_delays, max_retries) + else + return retry_delays + end +end + +function _set_nretries!(x, nretries::Int) + if x isa Response + x.metrics.nretries = nretries + elseif x isa StatusError + x.response.metrics.nretries = nretries + elseif x isa StreamError && x.stream !== nothing + x.stream.response !== nothing && (x.stream.response.metrics.nretries = nretries) + end + return +end + function with_retry_token( f::Function, client::Client; @@ -44,9 +88,13 @@ function with_retry_token( uri=nothing, retry_check=nothing, retry_delays=nothing, + retry_non_idempotent::Bool=false, + retryable_body::Bool=true, + req_ref=nothing, ) # If max_retries is 0, we don't need to bother with any retrying - if client.settings.max_retries == 0 + max_retries = client.settings.max_retries + if max_retries == 0 try return f() catch e @@ -57,57 +105,53 @@ function with_retry_token( rethrow() end end - retry_partition = client.settings.retry_partition === nothing ? C_NULL : aws_byte_cursor_from_c_str(client.settings.retry_partition) - fut = Future{Ptr{aws_retry_token}}() - GC.@preserve fut begin - if aws_retry_strategy_acquire_retry_token(client.retry_strategy, retry_partition, on_acquired[], pointer_from_objref(fut), client.settings.retry_timeout_ms) != 0 - aws_throw_error() - end - token = wait(fut) - end - try - while true - try - ret = f() - aws_retry_token_record_success(token) - return ret - catch e - stream = nothing - if e isa StreamError - stream = e.stream - e = e.error - end - if logerrors - log_err = e isa DontRetry ? e.error : e - url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) - @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=url logtag=logtag - end - if e isa DontRetry - if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing - # for error responses, we need to commit the temporary body buffer - stream.response.body = readavailable(stream.bufferstream) - end - throw(e.error) - end - # note we assume any error that wasn't wrapped in DontRetry is retryable - retryReady = Future{Ptr{aws_retry_token}}() - GC.@preserve retryReady begin - if aws_retry_strategy_schedule_retry( - token, - #TODO: use different error types? - AWS_RETRY_ERROR_TYPE_TRANSIENT, - retry_ready[], - pointer_from_objref(retryReady) - ) != 0 - #TODO: do we need to commit a previous error body to the response here? - aws_throw_error() - end - #TODO: should we wrap this in try-catch to commit a previous stream bufferstream to the response body? - token = wait(retryReady) + retry_check_fn = retry_check === nothing ? nothing : retry_check + delays = _normalize_retry_delays(retry_delays, max_retries) + delay_state = nothing + nretries = 0 + while true + try + ret = f() + _set_nretries!(ret, nretries) + return ret + catch e + stream = nothing + err = e + if err isa StreamError + stream = err.stream + err = err.error + end + if logerrors + log_err = err isa DontRetry ? err.error : err + url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) + @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=url logtag=logtag + end + if err isa DontRetry + if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing + # for error responses, we need to commit the temporary body buffer + stream.response.body = readavailable(stream.bufferstream) end + err = err.error + _set_nretries!(err, nretries) + throw(err) + end + nretries >= max_retries && (_set_nretries!(err, nretries); throw(err)) + delay_iter = delay_state === nothing ? iterate(delays) : iterate(delays, delay_state) + delay_iter === nothing && (_set_nretries!(err, nretries); throw(err)) + delay, delay_state = delay_iter + req = req_ref === nothing ? nothing : req_ref[] + resp = err isa StatusError ? err.response : nothing + resp_body = resp === nothing ? nothing : resp.body + retry = _default_retryable(method, err, retryable_body, retry_non_idempotent) + if !retry && retry_check_fn !== nothing && retryable_body + retry = retry_check_fn(delay, err, req, resp, resp_body) + end + if !retry + _set_nretries!(err, nretries) + throw(err) end + nretries += 1 + sleep(delay) end - finally - aws_retry_token_release(token) end end diff --git a/test/client.jl b/test/client.jl index 0867b909..50fb7d1a 100644 --- a/test/client.jl +++ b/test/client.jl @@ -198,6 +198,74 @@ @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) end + @testset "Retry semantics" begin + attempts = Ref(0) + failures = Ref(1) + attempt_lock = ReentrantLock() + next_attempt() = Base.@lock attempt_lock begin + attempts[] += 1 + return attempts[] + end + reset_attempts!(nfail) = Base.@lock attempt_lock begin + attempts[] = 0 + failures[] = nfail + return + end + + server = HTTP.serve!("127.0.0.1", 0; listenany=true) do req + n = next_attempt() + if n <= failures[] + return HTTP.Response(503, "fail") + end + return HTTP.Response(200, "ok") + end + port = HTTP.port(server) + try + reset_attempts!(1) + resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(1) + err = nothing + try + HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, retry_delays=[0.0]) + catch e + err = e + end + @test err isa HTTP.StatusError + @test err.response.metrics.nretries == 0 + @test attempts[] == 1 + + reset_attempts!(1) + resp = HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, + retry_non_idempotent=true, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(1) + resp = HTTP.post("http://127.0.0.1:$port/"; body="x", retries=1, + retry_check=(s, ex, req, resp, resp_body) -> true, retry_delays=[0.0]) + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 + + reset_attempts!(2) + err = nothing + try + HTTP.get("http://127.0.0.1:$port/"; retries=3, retry_delays=[0.0]) + catch e + err = e + end + @test err isa HTTP.StatusError + @test attempts[] == 2 + finally + close(server) + end + end + @testset "Request Options Parity" begin headers = ["X-Test" => "1"] HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) From bd3be6f749a565626f48836daba4ee08d106cad4 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 23:35:43 -0700 Subject: [PATCH 06/56] feat(trailers): support trailing headers --- src/HTTP.jl | 4 ++-- src/client/stream.jl | 42 ++++++++++++++++++++++++++++++++++++++++-- src/requestresponse.jl | 6 +++++- src/server.jl | 16 ++++++++++++++-- test/server.jl | 25 +++++++++++++++++++++++++ 5 files changed, 86 insertions(+), 7 deletions(-) diff --git a/src/HTTP.jl b/src/HTTP.jl index fc5bcc4e..fcd90f03 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -100,8 +100,8 @@ function __init__() on_incoming_connection[] = @cfunction(c_on_incoming_connection, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) on_incoming_request[] = @cfunction(c_on_incoming_request, Ptr{aws_http_stream}, (Ptr{aws_http_connection}, Ptr{Cvoid})) - on_request_headers[] = @cfunction(c_on_request_headers, Cint, (Ptr{aws_http_stream}, Ptr{aws_http_header_block}, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) - on_request_header_block_done[] = @cfunction(c_on_request_header_block_done, Cint, (Ptr{aws_http_stream}, Ptr{aws_http_header_block}, Ptr{Cvoid})) + on_request_headers[] = @cfunction(c_on_request_headers, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) + on_request_header_block_done[] = @cfunction(c_on_request_header_block_done, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{Cvoid})) on_request_body[] = @cfunction(c_on_request_body, Cint, (Ptr{aws_http_stream}, Ptr{aws_byte_cursor}, Ptr{Cvoid})) on_request_done[] = @cfunction(c_on_request_done, Cint, (Ptr{aws_http_stream}, Ptr{Cvoid})) on_server_stream_complete[] = @cfunction(c_on_server_stream_complete, Cint, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) diff --git a/src/client/stream.jl b/src/client/stream.jl index 2e69c6ce..b38687d2 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -2,13 +2,27 @@ const on_response_headers = Ref{Ptr{Cvoid}}(C_NULL) function c_on_response_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) stream = unsafe_pointer_to_objref(stream_ptr) - headers = stream.response.headers - addheaders(headers, header_array, num_headers) + if header_block == AWS_HTTP_HEADER_BLOCK_TRAILING + trailers = stream.response.trailers + if trailers === nothing + trailers = Headers(stream.response.allocator) + stream.response.trailers = trailers + end + addheaders(trailers, header_array, num_headers) + else + headers = stream.response.headers + addheaders(headers, header_array, num_headers) + end return Cint(0) end writebuf(body, maxsize=length(body) == 0 ? typemax(Int64) : length(body)) = Base.GenericIOBuffer{AbstractVector{UInt8}}(body, true, true, true, false, maxsize) +function aws_http2_stream_add_trailing_headers(http2_stream::Ptr{aws_http_stream}, trailing_headers::Ptr{aws_http_headers}) + return ccall((:aws_http2_stream_add_trailing_headers, LibAwsHTTPFork.libaws_c_http_jq), + Cint, (Ptr{aws_http_stream}, Ptr{aws_http_headers}), http2_stream, trailing_headers) +end + const on_response_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) function c_on_response_header_block_done(aws_stream_ptr, header_block, stream_ptr) @@ -515,6 +529,30 @@ function setheaderifabsent(s::Stream, k, v) return end +function addtrailer(s::Stream, headers::Headers) + s.ptr == C_NULL && error("stream is not initialized") + if s.http2 + aws_http2_stream_add_trailing_headers(s.ptr, headers.ptr) != 0 && aws_throw_error() + else + aws_http1_stream_add_chunked_trailer(s.ptr, headers.ptr) != 0 && aws_throw_error() + end + return +end + +function addtrailer(s::Stream, h::Pair) + trailers = Headers(s.allocator) + addheader(trailers, String(h.first), String(h.second)) + return addtrailer(s, trailers) +end + +function addtrailer(s::Stream, h::AbstractVector{<:Pair}) + trailers = Headers(s.allocator) + for (k, v) in h + addheader(trailers, String(k), String(v)) + end + return addtrailer(s, trailers) +end + function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) stream = Stream{Nothing}(allocator, decompress, http2, false) if on_stream_response_body !== nothing diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 3a29be96..f333ae20 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -293,6 +293,7 @@ mutable struct Request <: Message inputstream::Union{Nothing, InputStream} # used for outgoing request body # only set in server-side request handlers body::Union{Nothing, Vector{UInt8}} + trailers::Union{Nothing, Headers} route::Union{Nothing, String} params::Union{Nothing, Dict{String, String}} cookies::Any # actually Union{Nothing, Vector{Cookie}} @@ -314,6 +315,7 @@ mutable struct Request <: Message req = new(allocator, ptr) req.body = nothing req.inputstream = nothing + req.trailers = nothing req.route = nothing req.params = nothing req.cookies = nothing @@ -432,6 +434,7 @@ mutable struct Response <: Message ptr::Ptr{aws_http_message} inputstream::Union{Nothing, InputStream} body::Union{Nothing, Vector{UInt8}} # only set for client-side response body when no user-provided response_body + trailers::Union{Nothing, Headers} metrics::RequestMetrics request::Union{Request, Nothing} @@ -451,6 +454,7 @@ mutable struct Response <: Message resp = new(allocator, ptr) resp.body = nothing resp.inputstream = nothing + resp.trailers = nothing resp.metrics = RequestMetrics() resp.request = nothing body !== nothing && setinputstream!(resp, body) @@ -460,7 +464,7 @@ mutable struct Response <: Message rethrow() end end - Response() = new(C_NULL, C_NULL, nothing, nothing, RequestMetrics(), nothing) + Response() = new(C_NULL, C_NULL, nothing, nothing, nothing, RequestMetrics(), nothing) end Response(status::Integer, body) = Response(status, nothing, Vector{UInt8}(string(body))) diff --git a/src/server.jl b/src/server.jl index e3027fbb..6fcba5ac 100644 --- a/src/server.jl +++ b/src/server.jl @@ -269,8 +269,17 @@ const on_request_headers = Ref{Ptr{Cvoid}}(C_NULL) function c_on_request_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) stream = unsafe_pointer_to_objref(stream_ptr) - headers = stream.request.headers - addheaders(headers, header_array, num_headers) + if header_block == AWS_HTTP_HEADER_BLOCK_TRAILING + trailers = stream.request.trailers + if trailers === nothing + trailers = Headers(stream.request.allocator) + stream.request.trailers = trailers + end + addheaders(trailers, header_array, num_headers) + else + headers = stream.request.headers + addheaders(headers, header_array, num_headers) + end return Cint(0) end @@ -278,6 +287,9 @@ const on_request_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) function c_on_request_header_block_done(aws_stream_ptr, header_block, stream_ptr) stream = unsafe_pointer_to_objref(stream_ptr) + if header_block != AWS_HTTP_HEADER_BLOCK_MAIN + return Cint(0) + end ret = aws_http_stream_get_incoming_request_method(aws_stream_ptr, FieldRef(stream, :method)) ret != 0 && return ret aws_http_message_set_request_method(stream.request.ptr, stream.method) diff --git a/test/server.jl b/test/server.jl index b022a76c..c824883a 100644 --- a/test/server.jl +++ b/test/server.jl @@ -1,4 +1,5 @@ using Test, HTTP, Logging +import Sockets @testset "HTTP.serve" begin server = HTTP.serve!(req -> HTTP.Response(200, "Hello, World!"); listenany=true) @@ -34,6 +35,30 @@ end end end +@testset "HTTP response trailers" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, "hello") + HTTP.addtrailer(http, "X-Trailer" => "ok") + end + try + port = HTTP.port(server) + sock = Sockets.connect("127.0.0.1", port) + write(sock, "GET / HTTP/1.1\r\nHost: 127.0.0.1\r\nConnection: close\r\n\r\n") + flush(sock) + raw = String(read(sock)) + close(sock) + lower_raw = lowercase(raw) + @test occursin("transfer-encoding: chunked", lower_raw) + @test occursin("hello", raw) + @test occursin("\r\n0\r\nx-trailer: ok\r\n\r\n", lower_raw) + finally + close(server) + end +end + @testset "access logging" begin local handler = (req) -> begin if req.target == "/internal-error" From eb3578366c29f787d959233d8c0d5faf293fa47a Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 23:46:55 -0700 Subject: [PATCH 07/56] feat(http2): add stream manager support --- src/HTTP.jl | 1 + src/client/client.jl | 39 ++++++++++++++++++++ src/client/makerequest.jl | 36 +++++++++++++++---- src/client/stream.jl | 75 +++++++++++++++++++++++++++++++++++++++ test/client.jl | 7 ++++ 5 files changed, 152 insertions(+), 6 deletions(-) diff --git a/src/HTTP.jl b/src/HTTP.jl index fcd90f03..3aa9631e 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -96,6 +96,7 @@ function __init__() on_metrics[] = @cfunction(c_on_metrics, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_stream_metrics}, Ptr{Cvoid})) on_complete[] = @cfunction(c_on_complete, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) on_destroy[] = @cfunction(c_on_destroy, Cvoid, (Ptr{Cvoid},)) + on_stream_acquired[] = @cfunction(c_on_stream_acquired, Cvoid, (Ptr{aws_http_stream}, Cint, Ptr{Cvoid})) retry_ready[] = @cfunction(c_retry_ready, Cvoid, (Ptr{aws_retry_token}, Cint, Ptr{Cvoid})) on_incoming_connection[] = @cfunction(c_on_incoming_connection, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) diff --git a/src/client/client.jl b/src/client/client.jl index e3540715..f7735c20 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -44,6 +44,7 @@ Base.@kwdef struct ClientSettings max_pending_connection_acquisitions::Int = 0 enable_read_back_pressure::Bool = false http2_prior_knowledge::Bool = false + http2_stream_manager::Bool = false end ClientSettings( @@ -88,6 +89,8 @@ mutable struct Client retry_strategy::Ptr{aws_retry_strategy} conn_manager_opts::aws_http_connection_manager_options connection_manager::Ptr{aws_http_connection_manager} + http2_stream_manager_opts::Union{Nothing, aws_http2_stream_manager_options} + http2_stream_manager::Ptr{aws_http2_stream_manager} Client() = new() end @@ -204,12 +207,48 @@ function Client(cs::ClientSettings) ) client.connection_manager = aws_http_connection_manager_new(cs.allocator, FieldRef(client, :conn_manager_opts)) client.connection_manager == C_NULL && aws_throw_error() + client.http2_stream_manager_opts = nothing + client.http2_stream_manager = C_NULL + if cs.http2_stream_manager + opts = aws_http2_stream_manager_options( + cs.bootstrap, + pointer(FieldRef(client, :socket_options)), + (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, + cs.http2_prior_knowledge, + aws_byte_cursor_from_c_str(cs.host), + cs.port % UInt32, + C_NULL, # initial_settings_array + 0, # num_initial_settings + 0, # max_closed_streams + false, # conn_manual_window_management + cs.enable_read_back_pressure, + typemax(Csize_t), # initial_window_size + C_NULL, # monitoring_options + client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), + client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), + C_NULL, # shutdown_complete_user_data + C_NULL, # shutdown_complete_callback + false, # close_connection_on_server_error + 0, # connection_ping_period_ms + 0, # connection_ping_timeout_ms + 0, # ideal_concurrent_streams_per_connection + 0, # max_concurrent_streams_per_connection + cs.max_connections, + ) + client.http2_stream_manager_opts = opts + client.http2_stream_manager = aws_http2_stream_manager_new(cs.allocator, Ref(opts)) + client.http2_stream_manager == C_NULL && aws_throw_error() + end finalizer(client) do x if x.connection_manager != C_NULL aws_http_connection_manager_release(x.connection_manager) x.connection_manager = C_NULL end + if x.http2_stream_manager != C_NULL + aws_http2_stream_manager_release(x.http2_stream_manager) + x.http2_stream_manager = C_NULL + end if x.retry_strategy != C_NULL aws_retry_strategy_release(x.retry_strategy) x.retry_strategy = C_NULL diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 2a1421c0..5885f865 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -105,10 +105,9 @@ function request(method, url, h=Header[], b=nothing; with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, retry_check=retry_check, retry_delays=retry_delays, retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref) do - resp = with_connection(reqclient) do conn - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + resp = if reqclient.http2_stream_manager != C_NULL path = resource(uri) - with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; copyheaders=false, canonicalize_headers=canonicalize_headers, detect_content_type=detect_content_type, @@ -119,13 +118,38 @@ function request(method, url, h=Header[], b=nothing; ref = Ref(1) GC.@preserve ref begin on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) end elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) else - with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout, allocator) + end + end + else + with_connection(reqclient) do conn + http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + path = resource(uri) + with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + ) do req + req_ref[] = req + if response_body isa AbstractVector{UInt8} + ref = Ref(1) + GC.@preserve ref begin + on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + end + elseif response_body isa IO + on_stream_response_body = IOOnResponseBody(response_body) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + else + with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator) + end end end end diff --git a/src/client/stream.jl b/src/client/stream.jl index b38687d2..93ec1994 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -105,6 +105,18 @@ function c_on_destroy(stream) return end +const on_stream_acquired = Ref{Ptr{Cvoid}}(C_NULL) + +function c_on_stream_acquired(aws_stream_ptr, error_code, fut_ptr) + fut = unsafe_pointer_to_objref(fut_ptr) + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, aws_stream_ptr) + end + return +end + if !@isdefined aws_websocket_server_upgrade_options const aws_websocket_server_upgrade_options = Ptr{Cvoid} end @@ -553,6 +565,69 @@ function addtrailer(s::Stream, h::AbstractVector{<:Pair}) return addtrailer(s, trailers) end +function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) + stream = Stream{Nothing}(allocator, decompress, true, false) + if on_stream_response_body !== nothing + stream.bufferstream = Base.BufferStream() + end + acquire_fut = Future{Ptr{aws_http_stream}}() + GC.@preserve stream acquire_fut begin + stream.request_options = aws_http_make_request_options( + 1, + req.ptr, + pointer_from_objref(stream), + on_response_headers[], + on_response_header_block_done[], + on_response_body[], + on_metrics[], + on_complete[], + on_destroy[], + chunkedbody !== nothing, # http2_use_manual_data_writes + readtimeout * 1000 # response_first_byte_timeout_ms + ) + stream.response = resp = Response(0, nothing, nothing, true, allocator) + resp.metrics = RequestMetrics() + resp.request = req + acquire_opts = aws_http2_stream_manager_acquire_stream_options( + on_stream_acquired[], + pointer_from_objref(acquire_fut), + FieldRef(stream, :request_options), + ) + aws_http2_stream_manager_acquire_stream(client.http2_stream_manager, Ref(acquire_opts)) + stream_ptr = wait(acquire_fut) + stream.ptr = stream_ptr + stream.activated = true + try + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) + end + wait(stream.fut) + catch e + rethrow(DontRetry(e)) + end + else + wait(stream.fut) + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] + end + end + return resp + finally + aws_http_stream_release(stream_ptr) + stream.released = true + stream.ptr = Ptr{aws_http_stream}(C_NULL) + end + end +end + function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) stream = Stream{Nothing}(allocator, decompress, http2, false) if on_stream_response_body !== nothing diff --git a/test/client.jl b/test/client.jl index 50fb7d1a..03e27f2a 100644 --- a/test/client.jl +++ b/test/client.jl @@ -320,6 +320,13 @@ @test resp.status == 200 end + @testset "HTTP/2 stream manager smoke" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true) + client = HTTP.Client(cs) + @test client.http2_stream_manager != C_NULL + finalize(client) + end + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin headers = Dict("User-Agent" => "HTTP.jl") query = Dict("hello" => "world") From b10590ca46682a728cb4aa6fa499ddaabcec681c Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sat, 24 Jan 2026 23:56:25 -0700 Subject: [PATCH 08/56] feat(http2): add settings/ping/goaway helpers --- src/HTTP.jl | 2 + src/client/connection.jl | 144 +++++++++++++++++++++++++++++++++++++++ src/client/stream.jl | 15 ++++ test/client.jl | 18 +++++ 4 files changed, 179 insertions(+) diff --git a/src/HTTP.jl b/src/HTTP.jl index 3aa9631e..ec05160f 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -89,6 +89,8 @@ function __init__() on_acquired[] = @cfunction(c_on_acquired, Cvoid, (Ptr{Cvoid}, Cint, Ptr{aws_retry_token}, Ptr{Cvoid})) # on_shutdown[] = @cfunction(c_on_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) on_setup[] = @cfunction(c_on_setup, Cvoid, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) + on_change_settings_complete[] = @cfunction(c_on_change_settings_complete, Cvoid, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) + on_ping_complete[] = @cfunction(c_on_ping_complete, Cvoid, (Ptr{aws_http_connection}, UInt64, Cint, Ptr{Cvoid})) on_stream_write_on_complete[] = @cfunction(c_on_stream_write_on_complete, Cvoid, (Ptr{aws_http_stream}, Cint, Ptr{Cvoid})) on_response_headers[] = @cfunction(c_on_response_headers, Cint, (Ptr{Cvoid}, Cint, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) on_response_header_block_done[] = @cfunction(c_on_response_header_block_done, Cint, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) diff --git a/src/client/connection.jl b/src/client/connection.jl index 17071a05..b93530de 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -1,4 +1,6 @@ const on_setup = Ref{Ptr{Cvoid}}(C_NULL) +const on_change_settings_complete = Ref{Ptr{Cvoid}}(C_NULL) +const on_ping_complete = Ref{Ptr{Cvoid}}(C_NULL) function c_on_setup(conn, error_code, fut_ptr) fut = unsafe_pointer_to_objref(fut_ptr) @@ -12,6 +14,26 @@ function c_on_setup(conn, error_code, fut_ptr) return end +function c_on_change_settings_complete(conn, error_code, fut_ptr) + fut = unsafe_pointer_to_objref(fut_ptr) + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, nothing) + end + return +end + +function c_on_ping_complete(conn, round_trip_time_ns, error_code, fut_ptr) + fut = unsafe_pointer_to_objref(fut_ptr) + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, round_trip_time_ns) + end + return +end + function with_connection(f::Function, client::Client) fut = Future{Ptr{aws_http_connection}}() GC.@preserve fut begin @@ -24,3 +46,125 @@ function with_connection(f::Function, client::Client) aws_http_connection_manager_release_connection(client.connection_manager, connection) end end + +function _ensure_http2_connection(conn::Ptr{aws_http_connection}) + conn == C_NULL && throw(ArgumentError("HTTP/2 connection is null")) + aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 || throw(ArgumentError("HTTP/2 connection required")) + return conn +end + +function _with_http2_connection(f::Function, client::Client) + return with_connection(client) do conn + f(_ensure_http2_connection(conn)) + end +end + +function http2_ping(conn::Ptr{aws_http_connection}; data=nothing) + _ensure_http2_connection(conn) + fut = Future{UInt64}() + cursor_ref = Ref{aws_byte_cursor}() + cursor_ptr = C_NULL + bytes = nothing + if data !== nothing + bytes = data isa AbstractString ? Vector{UInt8}(codeunits(data)) : Vector{UInt8}(data) + length(bytes) == AWS_HTTP2_PING_DATA_SIZE || throw(ArgumentError("PING data must be $(AWS_HTTP2_PING_DATA_SIZE) bytes")) + GC.@preserve bytes begin + cursor_ref[] = aws_byte_cursor_from_array(pointer(bytes), length(bytes)) + end + cursor_ptr = cursor_ref + end + GC.@preserve fut cursor_ref bytes begin + aws_http2_connection_ping(conn, cursor_ptr, on_ping_complete[], pointer_from_objref(fut)) != 0 && aws_throw_error() + return wait(fut) + end +end + +http2_ping(client::Client; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), client) + +function _settings_from_pairs(settings::AbstractVector{<:Pair}) + out = Vector{aws_http2_setting}(undef, length(settings)) + for (i, (k, v)) in enumerate(settings) + out[i] = aws_http2_setting(aws_http2_settings_id(k), UInt32(v)) + end + return out +end + +function http2_change_settings(conn::Ptr{aws_http_connection}, settings::AbstractVector{aws_http2_setting}) + _ensure_http2_connection(conn) + fut = Future{Nothing}() + settings_ptr = isempty(settings) ? C_NULL : pointer(settings) + GC.@preserve settings fut begin + aws_http2_connection_change_settings(conn, settings_ptr, length(settings), on_change_settings_complete[], pointer_from_objref(fut)) != 0 && aws_throw_error() + wait(fut) + end + return +end + +http2_change_settings(conn::Ptr{aws_http_connection}, settings::AbstractVector{<:Pair}) = + http2_change_settings(conn, _settings_from_pairs(settings)) + +http2_change_settings(client::Client, settings) = + _with_http2_connection(conn -> http2_change_settings(conn, settings), client) + + +function http2_local_settings(conn::Ptr{aws_http_connection}) + _ensure_http2_connection(conn) + settings = Vector{aws_http2_setting}(undef, AWS_HTTP2_SETTINGS_COUNT) + aws_http2_connection_get_local_settings(conn, pointer(settings)) + return settings +end + +http2_local_settings(client::Client) = _with_http2_connection(http2_local_settings, client) + +function http2_remote_settings(conn::Ptr{aws_http_connection}) + _ensure_http2_connection(conn) + settings = Vector{aws_http2_setting}(undef, AWS_HTTP2_SETTINGS_COUNT) + aws_http2_connection_get_remote_settings(conn, pointer(settings)) + return settings +end + +http2_remote_settings(client::Client) = _with_http2_connection(http2_remote_settings, client) + +function http2_send_goaway(conn::Ptr{aws_http_connection}, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) + _ensure_http2_connection(conn) + cursor_ref = Ref{aws_byte_cursor}() + cursor_ptr = C_NULL + bytes = nothing + if debug_data !== nothing + bytes = debug_data isa AbstractString ? Vector{UInt8}(codeunits(debug_data)) : Vector{UInt8}(debug_data) + length(bytes) <= 16 * 1024 || throw(ArgumentError("debug_data must be <= 16KB")) + GC.@preserve bytes begin + cursor_ref[] = aws_byte_cursor_from_array(pointer(bytes), length(bytes)) + end + cursor_ptr = cursor_ref + end + GC.@preserve bytes cursor_ref begin + aws_http2_connection_send_goaway(conn, UInt32(http2_error), allow_more_streams, cursor_ptr) + end + return +end + +http2_send_goaway(client::Client, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = + _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), client) + + +function _get_goaway(get_fn, conn::Ptr{aws_http_connection}) + _ensure_http2_connection(conn) + http2_error = Ref{UInt32}() + last_stream_id = Ref{UInt32}() + ret = get_fn(conn, http2_error, last_stream_id) + if ret == 0 + return (http2_error=http2_error[], last_stream_id=last_stream_id[]) + elseif ret == AWS_ERROR_HTTP_DATA_NOT_AVAILABLE + return nothing + else + aws_throw_error() + end +end + +http2_get_sent_goaway(conn::Ptr{aws_http_connection}) = _get_goaway(aws_http2_connection_get_sent_goaway, conn) +http2_get_received_goaway(conn::Ptr{aws_http_connection}) = _get_goaway(aws_http2_connection_get_received_goaway, conn) + +http2_get_sent_goaway(client::Client) = _with_http2_connection(http2_get_sent_goaway, client) + +http2_get_received_goaway(client::Client) = _with_http2_connection(http2_get_received_goaway, client) diff --git a/src/client/stream.jl b/src/client/stream.jl index 93ec1994..631432ba 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -210,6 +210,21 @@ function release_stream_ptr!(s::Stream) return end +function _with_http2_connection(f::Function, stream::Stream) + stream.ptr == C_NULL && throw(ArgumentError("HTTP stream is not initialized")) + conn = aws_http_stream_get_connection(stream.ptr) + return f(_ensure_http2_connection(conn)) +end + +http2_ping(stream::Stream; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), stream) +http2_change_settings(stream::Stream, settings) = _with_http2_connection(conn -> http2_change_settings(conn, settings), stream) +http2_local_settings(stream::Stream) = _with_http2_connection(http2_local_settings, stream) +http2_remote_settings(stream::Stream) = _with_http2_connection(http2_remote_settings, stream) +http2_send_goaway(stream::Stream, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = + _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), stream) +http2_get_sent_goaway(stream::Stream) = _with_http2_connection(http2_get_sent_goaway, stream) +http2_get_received_goaway(stream::Stream) = _with_http2_connection(http2_get_received_goaway, stream) + const on_stream_write_on_complete = Ref{Ptr{Cvoid}}(C_NULL) function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) diff --git a/test/client.jl b/test/client.jl index 03e27f2a..ce91e8ed 100644 --- a/test/client.jl +++ b/test/client.jl @@ -327,6 +327,24 @@ finalize(client) end + @testset "HTTP/2 control APIs" begin + resp = HTTP.get("https://$httpbin/ip") + if resp.version == HTTP.HTTPVersion(2, 0) + HTTP.open("GET", "https://$httpbin/ip") do io + r = HTTP.startread(io) + @test r.status == 200 + rtt = HTTP.http2_ping(io) + @test rtt isa UInt64 + HTTP.http2_change_settings(io, Pair{Int, Int}[]) + @test length(HTTP.http2_local_settings(io)) == HTTP.AWS_HTTP2_SETTINGS_COUNT + @test HTTP.http2_get_sent_goaway(io) === nothing + @test HTTP.http2_get_received_goaway(io) === nothing + end + else + @test_skip "HTTP/2 not available for $httpbin" + end + end + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin headers = Dict("User-Agent" => "HTTP.jl") query = Dict("hello" => "world") From 89fff7688b2a389edc11c58973d100b21b1349fd Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:02:25 -0700 Subject: [PATCH 09/56] feat(metrics): expose http manager metrics --- src/client/client.jl | 10 ++++++++++ test/client.jl | 16 ++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/client/client.jl b/src/client/client.jl index f7735c20..c97a3183 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -286,6 +286,16 @@ function getclient(key::ClientSettings, clients::Clients=CLIENTS) end end +function manager_metrics(client::Client) + metrics = Ref{aws_http_manager_metrics}() + if client.http2_stream_manager != C_NULL + aws_http2_stream_manager_fetch_metrics(client.http2_stream_manager, metrics) + else + aws_http_connection_manager_fetch_metrics(client.connection_manager, metrics) + end + return metrics[] +end + getclient(key::ClientSettings, pool::Pool) = getclient(key, pool.clients) function close_all_clients!(clients::Clients=CLIENTS) diff --git a/test/client.jl b/test/client.jl index ce91e8ed..e2bc8489 100644 --- a/test/client.jl +++ b/test/client.jl @@ -327,6 +327,22 @@ finalize(client) end + @testset "HTTP manager metrics" begin + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443))) + metrics = HTTP.manager_metrics(client) + @test metrics.available_concurrency >= 0 + @test metrics.pending_concurrency_acquires >= 0 + @test metrics.leased_concurrency >= 0 + finalize(client) + + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true)) + metrics = HTTP.manager_metrics(client) + @test metrics.available_concurrency >= 0 + @test metrics.pending_concurrency_acquires >= 0 + @test metrics.leased_concurrency >= 0 + finalize(client) + end + @testset "HTTP/2 control APIs" begin resp = HTTP.get("https://$httpbin/ip") if resp.version == HTTP.HTTPVersion(2, 0) From 4f1a32f01fc7290c58091418a5ccc669c908cd29 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:02:38 -0700 Subject: [PATCH 10/56] fix(websockets): return close error on closed receive --- src/websockets.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/websockets.jl b/src/websockets.jl index 8a4292c0..f9f04bdd 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -555,7 +555,10 @@ returned by `receive`. Note that `WebSocket` objects can be iterated, where each iteration yields a message until the connection is closed. """ function receive(ws::WebSocket) - @assert isopen(ws.readchannel) "WebSocket is closed" + if !isopen(ws.readchannel) + close_body = ws.closebody === nothing ? CloseFrameBody(1006, "") : ws.closebody + throw(WebSocketError(close_body)) + end msg = take!(ws.readchannel) if msg isa WebSocketError throw(msg) From f752564d0d210101f1d1717a65e0a9f531715be7 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:17:28 -0700 Subject: [PATCH 11/56] docs: document http2, trailers, and metrics --- docs/src/manual/client.md | 71 ++++++++++++++++++++++++++++++++++++-- docs/src/manual/migrate.md | 6 +++- docs/src/manual/server.md | 22 ++++++++++++ 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index a29b0404..77c1e62f 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -35,10 +35,14 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **chunkedbody**: To send a request body in chunks, an iterable must be provided where each element is one of the valid types of request bodies mentioned above. - **modifier**: A function of the form `f(request, body) -> newbody`, i.e. that takes the HTTP request object and proposed request body, and can optionally return a new request body. If the modifer only modifies the request object, it should return `nothing`, which will ensure the original request body is sent unmodified. -- Response options: - - **response_body**: By default, response bodies are returned as `Vector{UInt8}`. Alternatively, a preallocated `AbstractVector{UInt8}` or any `IO` object can be provided for the response body to be written into. + - **response_body**: By default, response bodies are returned as `Vector{UInt8}`. Alternatively, a preallocated `AbstractVector{UInt8}` or any `IO` object can be provided for the response body to be written into. `response_stream` is a compatible alias. - **decompress**: If `true`, the response body will be decompressed if it is compressed. By default, response bodies with the `Content-Encoding: gzip` header are decompressed. - **status_exception**: Default `true`. If `true`, an exception will be thrown if the response status code is not in the 200-299 range. - **readtimeout**: The maximum time in seconds to wait for a response from the server. Only valid for HTTP/1.1 connections. +-- Retry options (per request): + - **retry_non_idempotent**: Default `false`. If `true`, non-idempotent requests may be retried. + - **retry_check**: Optional function to override retry decisions. It is called as `retry_check(delay, err, req, resp, resp_body)` when a retry is being considered. + - **retry_delays**: Custom retry delays. Provide a number (seconds) or any iterator of delays; defaults to exponential backoff. -- Redirect options: - **redirect**: Default `true`. If `true`, the client will follow redirects. - **redirect_limit**: The maximum number of redirects to follow. Default is 3. @@ -78,7 +82,7 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **proxy_ssl_insecure**: Default `false`. If `true`, SSL certificate verification will be disabled for the proxy. - **proxy_ssl_alpn_list**: A list of ALPN protocols to use for the proxy connection. Default is `"h2;http/1.1"`. -- Retry options: - - **max_retries**: The maximum number of times to retry a request. Default is 10. + - **max_retries**: The maximum number of times to retry a request. Default is 4. - **retry_partition**: Requests utilizing the same retry partition (an arbitrary string) will coordinate retries against each other to not overwhelm a temporarily unresponsive server. - **backoff_scale_factor_ms**: The factor by which to scale the backoff time between retries. Default is 25. - **max_backoff_secs**: The maximum time in seconds to wait between retries. Default is 20. @@ -88,6 +92,13 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po -- Connection pool options: - **max_connections**: The maximum number of connections to keep open in the connection pool. Default is 512. - **max_connection_idle_in_milliseconds**: The maximum time in milliseconds to keep a connection open in the pool. Default is 60000. + - **connection_acquisition_timeout_ms**: The maximum time in milliseconds to wait for a connection from the pool. Default is 0 (no timeout). + - **max_pending_connection_acquisitions**: Maximum number of pending connection acquisitions. Default is 0 (no limit). + - **enable_read_back_pressure**: If `true`, enable back pressure on reads to limit buffered data. Default `false`. + - **response_first_byte_timeout_ms**: Maximum time in milliseconds to wait for the first response byte. Default is 0 (disabled). For per-request control, use `readtimeout`. + -- HTTP/2 options: + - **http2_prior_knowledge**: Default `false`. If `true`, assume HTTP/2 without ALPN negotiation. + - **http2_stream_manager**: Default `false`. If `true`, enable the HTTP/2 stream manager for multiplexed requests. -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. @@ -158,6 +169,61 @@ response = HTTP.get("https://api.example.com/data"; client = custom_client) println(String(response.body)) \`\`\` +## HTTP/2 Features (Advanced) + +### Stream Manager + +For high-concurrency HTTP/2 workloads, you can enable the HTTP/2 stream manager. This allows multiple in-flight +streams to share pooled HTTP/2 connections. + +\`\`\`julia +using HTTP + +client = HTTP.Client("https", "example.com", 443; http2_stream_manager=true) +resp = HTTP.get("https://example.com/resource"; client=client) +\`\`\` + +### Connection Control Helpers + +When a connection negotiates HTTP/2, you can use the following helpers: + +- `HTTP.http2_ping(client; data=nothing)` -> returns round-trip time in nanoseconds. +- `HTTP.http2_change_settings(client, settings)` where `settings` is a vector of pairs or `aws_http2_setting`. +- `HTTP.http2_local_settings(client)` / `HTTP.http2_remote_settings(client)` -> returns current settings. +- `HTTP.http2_send_goaway(client, error_code; allow_more_streams=true, debug_data=nothing)` +- `HTTP.http2_get_sent_goaway(client)` / `HTTP.http2_get_received_goaway(client)` -> returns `nothing` if no GOAWAY. + +These helpers require an HTTP/2 connection and will throw an `ArgumentError` if the connection is HTTP/1.1. + +## Trailing Headers + +Trailing headers are available on responses as `resp.trailers` after the response completes. For streaming requests, +you can attach trailers before closing the write side of the stream. + +\`\`\`julia +using HTTP + +resp = HTTP.open("POST", "https://example.com/upload") do stream + write(stream, "chunk-1") + write(stream, "chunk-2") + HTTP.addtrailer(stream, "x-checksum" => "abc123") +end + +resp.trailers === nothing || println(HTTP.header(resp.trailers, "x-server-checksum")) +\`\`\` + +## Metrics and Observability + +Each response includes a `metrics` field: + +- `response.metrics.request_body_length` +- `response.metrics.response_body_length` +- `response.metrics.nretries` +- `response.metrics.stream_metrics` (AWS CRT `aws_http_stream_metrics`) + +For connection-level metrics, use `HTTP.manager_metrics(client)`, which returns `aws_http_manager_metrics` with +`available_concurrency`, `pending_concurrency_acquires`, and `leased_concurrency`. + ## Under the Hood (Advanced) When you call `HTTP.request`, the following advanced steps occur: @@ -179,4 +245,3 @@ When you call `HTTP.request`, the following advanced steps occur: 6. **Response Processing:** The response is parsed, and if errors occur (as dictated by your settings), an exception is raised. - diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index f2bcfc82..c4ae321c 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -53,6 +53,7 @@ While the basic request syntax remains similar, there are some changes to keywor - The `response_stream` keyword argument is still supported, but HTTP.jl no longer automatically closes this stream when done - you need to handle this yourself - `retry` behavior has been overhauled with more consistent rules for what is retryable +- The default `max_retries` is now 4 (was 10 in v1.x) - Some connection-related options have new defaults (e.g., TLS is now OpenSSL-based by default rather than MbedTLS) Example: @@ -335,6 +336,9 @@ cookies = HTTP.getcookies(jar, "example.com") - **TLS Implementation**: OpenSSL is now the default TLS provider instead of MbedTLS - **Multithreading**: Improved thread safety throughout the codebase - **Performance**: Significant performance improvements, especially for high-throughput servers +- **Trailing Headers**: Trailing headers are now captured on `Request.trailers` and `Response.trailers`, and can be sent with `HTTP.addtrailer` when streaming. +- **HTTP/2 Controls**: HTTP/2 helpers (ping, settings, GOAWAY) are available for advanced connection management. +- **Metrics**: Responses include `response.metrics` and clients expose `HTTP.manager_metrics` for connection manager stats. ## Transitioning Tips @@ -350,4 +354,4 @@ cookies = HTTP.getcookies(jar, "example.com") - Custom client-side layers from v1.x are not compatible with v2.0 and will need to be reimplemented - WebSocket handling code may need adjustments even though the API is similar -For more detailed information on specific topics, consult the full HTTP.jl v2.0.0 documentation. \ No newline at end of file +For more detailed information on specific topics, consult the full HTTP.jl v2.0.0 documentation. diff --git a/docs/src/manual/server.md b/docs/src/manual/server.md index 9dfc75b1..e61cde90 100644 --- a/docs/src/manual/server.md +++ b/docs/src/manual/server.md @@ -117,6 +117,28 @@ server = HTTP.serve!(handle_request, "127.0.0.1", 8080; println("Server started on http://127.0.0.1:8080") ``` +## Stream Handlers (Advanced) + +If you need direct access to the request and response streams, pass `stream=true` to `serve!` (or `serve`) and use +an `HTTP.Stream` handler. Reads will automatically start when you read from the stream, and writes will automatically +start when you write. + +```julia +using HTTP + +server = HTTP.serve!("127.0.0.1", 8080; stream=true) do stream::HTTP.Stream + req = HTTP.startread(stream) + data = read(stream) + + HTTP.setstatus(stream, 200) + HTTP.setheader(stream, "Content-Type" => "text/plain") + write(stream, "received $(length(data)) bytes") + HTTP.addtrailer(stream, "x-request-id" => HTTP.header(req, "x-request-id", "unknown")) +end +``` + +Trailing headers sent by the client are available after the request completes via `stream.request.trailers`. + ## Handlers and Middleware ### Handler Functions From db9b9f6c6ef4535042c5fab77fef86ff8d8f844f Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:33:10 -0700 Subject: [PATCH 12/56] feat(http2): add server push promise support --- .gitignore | 2 ++ docs/src/manual/server.md | 25 ++++++++++++++ src/server.jl | 69 +++++++++++++++++++++++++++++++++++++++ test/fixtures/http2.crt | 19 +++++++++++ test/fixtures/http2.key | 28 ++++++++++++++++ test/server.jl | 40 +++++++++++++++++++++++ 6 files changed, 183 insertions(+) create mode 100644 test/fixtures/http2.crt create mode 100644 test/fixtures/http2.key diff --git a/.gitignore b/.gitignore index 9ac1073d..d2c41218 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,5 @@ /test/websockets/reports/* .idea/* .vscode +!test/fixtures/http2.key +!test/fixtures/http2.crt diff --git a/docs/src/manual/server.md b/docs/src/manual/server.md index e61cde90..b9885091 100644 --- a/docs/src/manual/server.md +++ b/docs/src/manual/server.md @@ -139,6 +139,31 @@ end Trailing headers sent by the client are available after the request completes via `stream.request.trailers`. +## HTTP/2 Server Push (Advanced) + +When handling an HTTP/2 request, you can send a push promise to the client and stream a pushed response. +The `HTTP.push_promise` function returns a new `HTTP.Stream` to write the pushed response. + +```julia +using HTTP + +HTTP.serve!("127.0.0.1", 8443; stream=true, ssl_cert="server.crt", ssl_key="server.key", ssl_alpn_list="h2") do stream + req = HTTP.startread(stream) + if stream.http2 + push = HTTP.push_promise(stream, "GET", "/assets/app.js"; scheme="https", authority="127.0.0.1:8443") + HTTP.setstatus(push, 200) + HTTP.setheader(push, "Content-Type" => "application/javascript") + write(push, "console.log(\"pushed\")") + HTTP.closewrite(push) + end + HTTP.setstatus(stream, 200) + write(stream, "ok") +end +``` + +The push request must include `:scheme` and `:authority`. If these are not present on the original request, +pass them explicitly via the keyword arguments. Clients that do not accept server push will reject the promise. + ## Handlers and Middleware ### Handler Functions diff --git a/src/server.jl b/src/server.jl index 6fcba5ac..a351c88e 100644 --- a/src/server.jl +++ b/src/server.jl @@ -20,6 +20,30 @@ end Base.hash(c::Connection, h::UInt) = hash(c.connection, h) +if !@isdefined aws_http2_send_push_promise_options + struct aws_http2_send_push_promise_options + self_size::Csize_t + user_data::Ptr{Cvoid} + on_complete::Ptr{Cvoid} + on_destroy::Ptr{Cvoid} + pad_length::UInt8 + end +end + +function aws_http2_stream_send_push_promise( + parent_stream::Ptr{aws_http_stream}, + request::Ptr{aws_http_message}, + options::Ref{aws_http2_send_push_promise_options}, +) + return ccall((:aws_http2_stream_send_push_promise, LibAwsHTTPFork.libaws_c_http_jq), + Ptr{aws_http_stream}, + (Ptr{aws_http_stream}, Ptr{aws_http_message}, Ptr{aws_http2_send_push_promise_options}), + parent_stream, + request, + options, + ) +end + function remote_address(c::Connection) socket_ptr = aws_http_connection_get_remote_endpoint(c.connection) addr = unsafe_load(socket_ptr).address @@ -178,6 +202,51 @@ end listen!(f, host="127.0.0.1", port=8080; kw...) = serve!(f, host, port; stream=true, kw...) listen(f, host="127.0.0.1", port=8080; kw...) = serve(f, host, port; stream=true, kw...) +function _push_promise_headers!(req::Request, parent::Stream; scheme=nothing, authority=nothing) + if !hasheader(req.headers, ":scheme") + scheme_val = scheme === nothing ? header(parent.request, ":scheme", "") : String(scheme) + isempty(scheme_val) && throw(ArgumentError("push promise requires :scheme")) + addheader(req.headers, ":scheme", scheme_val) + end + if !hasheader(req.headers, ":authority") + authority_val = authority === nothing ? header(parent.request, ":authority", header(parent.request, "host", "")) : String(authority) + isempty(authority_val) && throw(ArgumentError("push promise requires :authority")) + addheader(req.headers, ":authority", authority_val) + end + return +end + +function push_promise(parent::Stream, req::Request; pad_length::Integer=0, scheme=nothing, authority=nothing) + parent.server_side || error("push_promise is only supported for server streams") + parent.http2 || throw(ArgumentError("HTTP/2 stream required for push promise")) + req.version == HTTPVersion(2, 0) || throw(ArgumentError("push promise request must be HTTP/2")) + 0 <= pad_length <= 0xff || throw(ArgumentError("pad_length must be between 0 and 255")) + _push_promise_headers!(req, parent; scheme=scheme, authority=authority) + push_stream = Stream{typeof(parent.connection)}(parent.allocator, false, true, true) + push_stream.connection = parent.connection + push_stream.request = req + opts = aws_http2_send_push_promise_options( + sizeof(aws_http2_send_push_promise_options), + pointer_from_objref(push_stream), + on_server_stream_complete[], + on_destroy[], + UInt8(pad_length), + ) + stream_ptr = aws_http2_stream_send_push_promise(parent.ptr, req.ptr, Ref(opts)) + stream_ptr == C_NULL && aws_throw_error() + push_stream.ptr = stream_ptr + @lock parent.connection.streams_lock begin + push!(parent.connection.streams, push_stream) + end + retain_stream!(push_stream) + return push_stream +end + +function push_promise(parent::Stream, method::Union{String, Symbol}, path; headers=Header[], pad_length::Integer=0, scheme=nothing, authority=nothing) + req = Request(String(method), String(path), headers, nothing, true, parent.allocator) + return push_promise(parent, req; pad_length=pad_length, scheme=scheme, authority=authority) +end + const on_incoming_connection = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_connection(aws_server, aws_conn, error_code, server_ptr) diff --git a/test/fixtures/http2.crt b/test/fixtures/http2.crt new file mode 100644 index 00000000..0e60882b --- /dev/null +++ b/test/fixtures/http2.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCTCCAfGgAwIBAgIUS6toYHFUbN6xEG51OuiEhDeHiZwwDQYJKoZIhvcNAQEL +BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDEyNTA3MjgxMFoXDTI3MDEy +NTA3MjgxMFowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF +AAOCAQ8AMIIBCgKCAQEAiZAY+NfcrqOSKAusYaFd9DSvYF3sRUTcTtyx35cawH1Y +196hTh6vyRIS7pHnX2kKwzc8G1C8aQXs4tBHuF0o+D06z2QU8fyAfJHrb1AhGpFm +y76HzeuT7S9Mmr58Kw4Yks1NuCt4LyBeC0AFEo6FKknlo+GiQvqZLKPVCcabca5r +mh28zvAhWCe00CX5HeXw+BJFDeD618QAiFF4Mr6imoL8TLPyLYz1t+mav6cLvxQK +vrr0FNSMVdmQ+VYF8tBOftCpblmyiswjeCTzKr95RPh6lQdfdn8Oq9Y29SCwvbKr +qlwElrU/RTCukjITk8Gzo/kxJtuRJDduyckN0s0P/QIDAQABo1MwUTAdBgNVHQ4E +FgQUGV6pqqYSC7wPdjKU30QOZl6/AP8wHwYDVR0jBBgwFoAUGV6pqqYSC7wPdjKU +30QOZl6/AP8wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAV/i6 +DjhMOWupCv+X6Re0uR2lasXHRfimFuzanWtNj+A226PHvk+XsYg+CNq+vTPPtMj1 +yQuz8bH+zHcPaXEfXlUzZHZUqMXb0yCxJLPP64nsvYkeoJtorgYzoa20eY9PRdZx +xo/Bvi5IVRaVh3UxaqlUMRmbNc8/OS7Q+CQXXJtfuKnCvwE68b2/M9szPobl1zgY +L3NflRuPHBdtX8/OhDNHygexujyLtI/e/3DzhokkchXjEiPiBpgzrybp6L52XwMb +MBv3houfL6+/Bd+pmnT2NPy5Qqe7ifGZAlz/kNN/6fNGp/FRTa83rt3+rYXzca4A +uOyQ8so7PeuSkuwEEw== +-----END CERTIFICATE----- diff --git a/test/fixtures/http2.key b/test/fixtures/http2.key new file mode 100644 index 00000000..59cc3031 --- /dev/null +++ b/test/fixtures/http2.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCJkBj419yuo5Io +C6xhoV30NK9gXexFRNxO3LHflxrAfVjX3qFOHq/JEhLukedfaQrDNzwbULxpBezi +0Ee4XSj4PTrPZBTx/IB8ketvUCEakWbLvofN65PtL0yavnwrDhiSzU24K3gvIF4L +QAUSjoUqSeWj4aJC+pkso9UJxptxrmuaHbzO8CFYJ7TQJfkd5fD4EkUN4PrXxACI +UXgyvqKagvxMs/ItjPW36Zq/pwu/FAq+uvQU1IxV2ZD5VgXy0E5+0KluWbKKzCN4 +JPMqv3lE+HqVB192fw6r1jb1ILC9squqXASWtT9FMK6SMhOTwbOj+TEm25EkN27J +yQ3SzQ/9AgMBAAECggEABJ9fO3NzgrnT5j0YLpKv184ya3xUfXJmLc88Oe13tGqx +4tUkGf4tfYX6NWKZQgtDVZKEVk2k+yl8D5Ycpt0YjJjwIjp0ero3rhfwL6YjaqNi +ryuIoDqPlruNaTDH9uFrIXm9KBhr2jeN+XZOBVTdNDHWAecJ3xLRNV8PAFxYal47 +uD2UhiFp4GMbQDcjLix4FCAUQTw5Y11a6wmvJu/PmB1u9VLf1HOcLjBtS9YYaIh+ +RuxmqQG/02lVYf6TQKzuT+njJ2Tl4DDH1lvdSyYo2nlW8ybpvux23RyOefSl6PZT +MDutLFYvnic8xIxnwKtchWNWp6Em2ryKHpgwBFVQQQKBgQDCWPt3t+AoUYG6vFtT +BZvWz85E5kNSGT/cC9h4e6xCt2C9qzn8nQ8dsCAhWmWZTMvY/yG0feQ9CLt0rNGF +fsHaGKaau1TpWPOB/pM9m2XMgJo9ZDxWF3qoN5vDQ1zZlcm1dYFoHqcgrmrh5nXM +JXr9pJw0g8/fNaeV4Q9pO92FwQKBgQC1M51XorzbYK6QHKOhuoVMrwh32a0a+TEG +ElRNNQ5rK+3H21tPlk+QKAhOYiihIf9Ut/4yfmR8njV4eM/MtbaiGfFSd2ENiU1d +xnEOSD42NgEFBCvLDcu5/wE4YQwweieohOd3lrsntCucKd616aghhiLjeT/GSabo +2RzYSXpxPQKBgQCLzokvySXGu0OQurkTk0BVGm5vIBojsChBOoBBw+3anKJKLyfq +sm1SVQX4GFhoHFe0RWzQs5OB2ItJVpzu5I29P+hx/PsLVkLuK91t/yEPKSBLs5S3 +9fH1mvNBV28u01MkZ2BtL0fY+b/HvArXjcrZNhZsrLnX/3gMGLgGYttrwQKBgASg +3uIALCbGX28a7CsTYphE2EiHbN6FgvUOvsyCEG44Xwh91+U+h6W9AAlQhI0pGyaE +1J9hjxuHxwHexCAMfC/DzeA3YGlCGpHModKlkcE8u+Xu51d2cL+9fcB86hzK4fxx ++J+bYAhxl7OTdjbbUwoYLQf2buSXuQW1lgEIT3JZAoGAWtsY3mftmxg8hchjsK8V +SZdJcU6Lx+gPUBHi3YPPXrVJJjO2UIBI/w09xSz4MCslHl+HXfiA0uCW3tF8DVGy +OXB/6EVo/wQmNoHZYO5IVqf16Q/cBCiPQ9HR/UJj/D5QfyqseUJ28V/n08oz/Wka +VQXUh0zfszlDyPndAMuXwd8= +-----END PRIVATE KEY----- diff --git a/test/server.jl b/test/server.jl index c824883a..defed070 100644 --- a/test/server.jl +++ b/test/server.jl @@ -59,6 +59,46 @@ end end end +@testset "HTTP/2 server push promise" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + port_ref = Ref{Int}(0) + push_called = Threads.Atomic{Bool}(false) + push_http2 = Threads.Atomic{Bool}(false) + push_server_side = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + if stream.http2 + authority = "127.0.0.1:$(port_ref[])" + push = HTTP.push_promise(stream, "GET", "/pushed"; scheme="https", authority=authority) + push_called[] = true + push_http2[] = push.http2 + push_server_side[] = push.server_side + HTTP.setstatus(push, 200) + HTTP.setheader(push, "Content-Type" => "text/plain") + write(push, "pushed") + HTTP.closewrite(push) + end + HTTP.setstatus(stream, 200) + write(stream, "ok") + end + try + port_ref[] = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port_ref[])"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test push_called[] + @test push_http2[] + @test push_server_side[] + @test String(resp.body) == "ok" + else + @info "HTTP/2 not negotiated for push promise test" + @test true + end + finally + close(server) + end +end + @testset "access logging" begin local handler = (req) -> begin if req.target == "/internal-error" From ff03d018e757c88c43730d74ae786b920fd5e550 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:42:44 -0700 Subject: [PATCH 13/56] feat(metrics): add connection monitoring hooks --- docs/src/manual/client.md | 8 ++++ docs/src/manual/migrate.md | 1 + src/HTTP.jl | 1 + src/client/client.jl | 87 ++++++++++++++++++++++++++++++++++++-- test/client.jl | 39 +++++++++++++++++ 5 files changed, 133 insertions(+), 3 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index 77c1e62f..9d6db70a 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -96,6 +96,10 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **max_pending_connection_acquisitions**: Maximum number of pending connection acquisitions. Default is 0 (no limit). - **enable_read_back_pressure**: If `true`, enable back pressure on reads to limit buffered data. Default `false`. - **response_first_byte_timeout_ms**: Maximum time in milliseconds to wait for the first response byte. Default is 0 (disabled). For per-request control, use `readtimeout`. + -- Monitoring options: + - **monitoring_minimum_throughput_bytes_per_second**: Minimum throughput to consider the connection healthy. Default is 0 (disabled). + - **monitoring_allowable_throughput_failure_interval_seconds**: Seconds of below-minimum throughput before the connection is closed. Default is 0 (disabled). + - **monitoring_statistics_observer**: Optional callback `(connection_nonce, stats) -> nothing` invoked with connection stats samples. -- HTTP/2 options: - **http2_prior_knowledge**: Default `false`. If `true`, assume HTTP/2 without ALPN negotiation. - **http2_stream_manager**: Default `false`. If `true`, enable the HTTP/2 stream manager for multiplexed requests. @@ -224,6 +228,10 @@ Each response includes a `metrics` field: For connection-level metrics, use `HTTP.manager_metrics(client)`, which returns `aws_http_manager_metrics` with `available_concurrency`, `pending_concurrency_acquires`, and `leased_concurrency`. +To enable periodic connection statistics callbacks, pass `monitoring_statistics_observer` in `ClientSettings`. The +callback receives a `connection_nonce` and a vector of stats entries. Each entry has a `category` field +(`:http1_channel` or `:http2_channel`) and category-specific fields like `pending_outgoing_stream_ms`. + ## Under the Hood (Advanced) When you call `HTTP.request`, the following advanced steps occur: diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index c4ae321c..e0f500b0 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -339,6 +339,7 @@ cookies = HTTP.getcookies(jar, "example.com") - **Trailing Headers**: Trailing headers are now captured on `Request.trailers` and `Response.trailers`, and can be sent with `HTTP.addtrailer` when streaming. - **HTTP/2 Controls**: HTTP/2 helpers (ping, settings, GOAWAY) are available for advanced connection management. - **Metrics**: Responses include `response.metrics` and clients expose `HTTP.manager_metrics` for connection manager stats. +- **Monitoring**: Optional connection monitoring callbacks are available via `monitoring_statistics_observer`. ## Transitioning Tips diff --git a/src/HTTP.jl b/src/HTTP.jl index ec05160f..510db3ff 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -102,6 +102,7 @@ function __init__() retry_ready[] = @cfunction(c_retry_ready, Cvoid, (Ptr{aws_retry_token}, Cint, Ptr{Cvoid})) on_incoming_connection[] = @cfunction(c_on_incoming_connection, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) + on_statistics_observer[] = @cfunction(c_on_statistics_observer, Cvoid, (Csize_t, Ptr{aws_array_list}, Ptr{Cvoid})) on_incoming_request[] = @cfunction(c_on_incoming_request, Ptr{aws_http_stream}, (Ptr{aws_http_connection}, Ptr{Cvoid})) on_request_headers[] = @cfunction(c_on_request_headers, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) on_request_header_block_done[] = @cfunction(c_on_request_header_block_done, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{Cvoid})) diff --git a/src/client/client.jl b/src/client/client.jl index c97a3183..9c784d1e 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,6 +1,64 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 const DEFAULT_MAX_RETRIES = 4 +const on_statistics_observer = Ref{Ptr{Cvoid}}(C_NULL) + +mutable struct StatisticsObserver + cb::Function +end + +function _decode_statistics(stats_list_ptr::Ptr{aws_array_list}) + stats_list_ptr == C_NULL && return Any[] + stats_list = unsafe_load(stats_list_ptr) + len = Int(stats_list.length) + len == 0 && return Any[] + item_size = Int(stats_list.item_size) + data_ptr = Ptr{UInt8}(stats_list.data) + data_ptr == C_NULL && return Any[] + stats = Vector{Any}(undef, len) + for i in 1:len + item_ptr = data_ptr + (i - 1) * item_size + category = unsafe_load(Ptr{UInt32}(item_ptr)) + if category == UInt32(AWSCRT_STAT_CAT_HTTP1_CHANNEL) + entry = unsafe_load(Ptr{aws_crt_statistics_http1_channel}(item_ptr)) + stats[i] = ( + category = :http1_channel, + pending_outgoing_stream_ms = entry.pending_outgoing_stream_ms, + pending_incoming_stream_ms = entry.pending_incoming_stream_ms, + current_outgoing_stream_id = entry.current_outgoing_stream_id, + current_incoming_stream_id = entry.current_incoming_stream_id, + ) + elseif category == UInt32(AWSCRT_STAT_CAT_HTTP2_CHANNEL) + entry = unsafe_load(Ptr{aws_crt_statistics_http2_channel}(item_ptr)) + stats[i] = ( + category = :http2_channel, + pending_outgoing_stream_ms = entry.pending_outgoing_stream_ms, + pending_incoming_stream_ms = entry.pending_incoming_stream_ms, + was_inactive = entry.was_inactive, + ) + else + raw = Vector{UInt8}(undef, item_size) + GC.@preserve raw unsafe_copyto!(pointer(raw), item_ptr, item_size) + stats[i] = (category = :unknown, raw = raw) + end + end + return stats +end + +_decode_statistics(stats_list::Ref{aws_array_list}) = + _decode_statistics(Base.unsafe_convert(Ptr{aws_array_list}, stats_list)) + +function c_on_statistics_observer(connection_nonce::Csize_t, stats_list::Ptr{aws_array_list}, observer_ptr::Ptr{Cvoid}) + observer = unsafe_pointer_to_objref(observer_ptr)::StatisticsObserver + stats = _decode_statistics(stats_list) + try + Base.invokelatest(observer.cb, connection_nonce, stats) + catch e + @error "statistics observer error" exception=(e, catch_backtrace()) + end + return +end + Base.@kwdef struct ClientSettings scheme::String host::String @@ -43,6 +101,9 @@ Base.@kwdef struct ClientSettings connection_acquisition_timeout_ms::Int = 0 max_pending_connection_acquisitions::Int = 0 enable_read_back_pressure::Bool = false + monitoring_minimum_throughput_bytes_per_second::UInt64 = 0 + monitoring_allowable_throughput_failure_interval_seconds::UInt32 = 0 + monitoring_statistics_observer::Union{Nothing, Function} = nothing http2_prior_knowledge::Bool = false http2_stream_manager::Bool = false end @@ -85,6 +146,8 @@ mutable struct Client # only 1 of proxy_options or proxy_env_settings is set proxy_options::Union{Nothing, aws_http_proxy_options} proxy_env_settings::Union{Nothing, proxy_env_var_settings} + monitoring_options::Union{Nothing, aws_http_connection_monitoring_options} + monitoring_observer::Union{Nothing, Any} retry_options::aws_standard_retry_options retry_strategy::Ptr{aws_retry_strategy} conn_manager_opts::aws_http_connection_manager_options @@ -126,6 +189,8 @@ function Client(cs::ClientSettings) client.tls_options = nothing end # proxy options + client.proxy_options = nothing + client.proxy_env_settings = nothing if cs.proxy_host !== nothing && cs.proxy_port !== nothing client.proxy_options = aws_http_proxy_options( cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL, @@ -158,8 +223,24 @@ function Client(cs::ClientSettings) cs.proxy_ssl_alpn_list ) ) + end + # connection monitoring options + monitoring_ptr = C_NULL + if cs.monitoring_statistics_observer !== nothing || + cs.monitoring_minimum_throughput_bytes_per_second != 0 || + cs.monitoring_allowable_throughput_failure_interval_seconds != 0 + observer = cs.monitoring_statistics_observer === nothing ? nothing : StatisticsObserver(cs.monitoring_statistics_observer) + client.monitoring_observer = observer + client.monitoring_options = aws_http_connection_monitoring_options( + UInt64(cs.monitoring_minimum_throughput_bytes_per_second), + UInt32(cs.monitoring_allowable_throughput_failure_interval_seconds), + observer === nothing ? C_NULL : on_statistics_observer[], + observer === nothing ? C_NULL : pointer_from_objref(observer) + ) + monitoring_ptr = pointer(FieldRef(client, :monitoring_options)) else - client.proxy_options = nothing + client.monitoring_options = nothing + client.monitoring_observer = nothing end # retry strategy exp_back_opts = aws_exponential_backoff_retry_options( @@ -186,7 +267,7 @@ function Client(cs::ClientSettings) cs.response_first_byte_timeout_ms, (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, cs.http2_prior_knowledge, - C_NULL, # monitoring_options::Ptr{aws_http_connection_monitoring_options} + monitoring_ptr, # monitoring_options::Ptr{aws_http_connection_monitoring_options} aws_byte_cursor_from_c_str(cs.host), cs.port % UInt32, C_NULL, # initial_settings_array::Ptr{aws_http2_setting} @@ -223,7 +304,7 @@ function Client(cs::ClientSettings) false, # conn_manual_window_management cs.enable_read_back_pressure, typemax(Csize_t), # initial_window_size - C_NULL, # monitoring_options + monitoring_ptr, # monitoring_options client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), C_NULL, # shutdown_complete_user_data diff --git a/test/client.jl b/test/client.jl index e2bc8489..ac1ac7bd 100644 --- a/test/client.jl +++ b/test/client.jl @@ -343,6 +343,45 @@ finalize(client) end + @testset "HTTP connection monitoring stats" begin + list = Ref{HTTP.aws_array_list}() + HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http1_channel)) + stat1 = HTTP.aws_crt_statistics_http1_channel(HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, 10, 20, 1, 2) + HTTP.aws_array_list_push_back(list, Ref(stat1)) + decoded = HTTP._decode_statistics(list) + @test length(decoded) == 1 + @test decoded[1].category == :http1_channel + @test decoded[1].pending_outgoing_stream_ms == 10 + @test decoded[1].pending_incoming_stream_ms == 20 + @test decoded[1].current_outgoing_stream_id == 1 + @test decoded[1].current_incoming_stream_id == 2 + HTTP.aws_array_list_clean_up(list) + + list = Ref{HTTP.aws_array_list}() + HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http2_channel)) + stat2 = HTTP.aws_crt_statistics_http2_channel(HTTP.AWSCRT_STAT_CAT_HTTP2_CHANNEL, 5, 6, true) + HTTP.aws_array_list_push_back(list, Ref(stat2)) + decoded = HTTP._decode_statistics(list) + @test length(decoded) == 1 + @test decoded[1].category == :http2_channel + @test decoded[1].pending_outgoing_stream_ms == 5 + @test decoded[1].pending_incoming_stream_ms == 6 + @test decoded[1].was_inactive == true + HTTP.aws_array_list_clean_up(list) + + called = Ref(false) + cb = (nonce, stats) -> (called[] = true) + client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); monitoring_statistics_observer=cb)) + list = Ref{HTTP.aws_array_list}() + HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http1_channel)) + stat3 = HTTP.aws_crt_statistics_http1_channel(HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, 1, 1, 1, 1) + HTTP.aws_array_list_push_back(list, Ref(stat3)) + HTTP.c_on_statistics_observer(Csize_t(0), Base.unsafe_convert(Ptr{HTTP.aws_array_list}, list), pointer_from_objref(client.monitoring_observer)) + @test called[] + HTTP.aws_array_list_clean_up(list) + finalize(client) + end + @testset "HTTP/2 control APIs" begin resp = HTTP.get("https://$httpbin/ip") if resp.version == HTTP.HTTPVersion(2, 0) From 9b8075a81546c105f371c45fdaac4a828f9d1983 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 00:49:59 -0700 Subject: [PATCH 14/56] fix(server): use server TLS options --- src/server.jl | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/server.jl b/src/server.jl index a351c88e..1b210351 100644 --- a/src/server.jl +++ b/src/server.jl @@ -3,6 +3,48 @@ socket_endpoint(host, port) = aws_socket_endpoint( port % UInt32 ) +function server_tlsoptions(host::String; + allocator=default_aws_allocator(), + ssl_cert=nothing, + ssl_key=nothing, + ssl_capath=nothing, + ssl_cacert=nothing, + ssl_insecure=false, + ssl_alpn_list=nothing, +) + tls_options = aws_tls_connection_options(C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, false, UInt32(0)) + tls_ctx_options = Ptr{aws_tls_ctx_options}(aws_mem_acquire(allocator, sizeof(aws_tls_ctx_options))) + tls_ctx = C_NULL + try + if ssl_cert !== nothing && ssl_key !== nothing + LibAwsIO.aws_tls_ctx_options_init_default_server_from_path(tls_ctx_options, allocator, ssl_cert, ssl_key) != 0 && sockerr("aws_tls_ctx_options_init_default_server_from_path failed") + elseif Sys.iswindows() && ssl_cert !== nothing && ssl_key === nothing + LibAwsIO.aws_tls_ctx_options_init_default_server_from_system_path(tls_ctx_options, allocator, ssl_cert) != 0 && sockerr("aws_tls_ctx_options_init_default_server_from_system_path failed") + else + throw(ArgumentError("ssl_cert and ssl_key are required for TLS server")) + end + if ssl_capath !== nothing && ssl_cacert !== nothing + LibAwsIO.aws_tls_ctx_options_override_default_trust_store_from_path(tls_ctx_options, ssl_capath, ssl_cacert) != 0 && sockerr("aws_tls_ctx_options_override_default_trust_store_from_path failed") + end + if ssl_insecure + LibAwsIO.aws_tls_ctx_options_set_verify_peer(tls_ctx_options, false) + end + if ssl_alpn_list !== nothing + LibAwsIO.aws_tls_ctx_options_set_alpn_list(tls_ctx_options, ssl_alpn_list) != 0 && sockerr("aws_tls_ctx_options_set_alpn_list failed") + end + tls_ctx = LibAwsIO.aws_tls_server_ctx_new(allocator, tls_ctx_options) + tls_ctx == C_NULL && sockerr("") + ref = Ref(tls_options) + LibAwsIO.aws_tls_connection_options_init_from_ctx(ref, tls_ctx) + tls_options = ref[] + finally + LibAwsIO.aws_tls_ctx_options_clean_up(tls_ctx_options) + LibAwsIO.aws_tls_ctx_release(tls_ctx) + aws_mem_release(allocator, tls_ctx_options) + end + return tls_options +end + mutable struct Connection{S} const server::S # Server{F, C} const allocator::Ptr{aws_allocator} @@ -158,7 +200,7 @@ function serve!(f, host="127.0.0.1", port=8080; ntuple(x -> Cchar(0), 16) # network_interface_name ), tls_options !== nothing ? tls_options : - any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) ? LibAwsIO.tlsoptions(host; + any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) ? server_tlsoptions(host; ssl_cert, ssl_key, ssl_capath, From 69a374a87b8b33d16f4014853cbf6e1b11a1ae67 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 01:06:56 -0700 Subject: [PATCH 15/56] feat(client): stream IO request bodies --- src/client/makerequest.jl | 4 ++++ src/requestresponse.jl | 17 +++++++++++++++++ test/client.jl | 40 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 5885f865..ae146fc2 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -70,6 +70,10 @@ function request(method, url, h=Header[], b=nothing; verbose=0, # only client keywords in catch-all kw...) + if chunkedbody === nothing && body isa IO && !(body isa IOStream) && !(body isa Form) + chunkedbody = IOChunkedBody(body) + body = nothing + end if chunkedbody === nothing && body !== nothing && !(body isa RequestBodyTypes) && Base.isiterable(typeof(body)) chunkedbody = body body = nothing diff --git a/src/requestresponse.jl b/src/requestresponse.jl index f333ae20..1633a3e9 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -222,6 +222,23 @@ end ischunked(is::InputStream) = is.ptr == C_NULL && is.bodyref !== nothing const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Form, Nothing} +const DEFAULT_IO_CHUNK_SIZE = 64 * 1024 + +struct IOChunkedBody{T<:IO} + io::T + chunk_size::Int + buf::Vector{UInt8} +end + +IOChunkedBody(io::IO; chunk_size::Int=DEFAULT_IO_CHUNK_SIZE) = + IOChunkedBody{typeof(io)}(io, chunk_size, Vector{UInt8}(undef, chunk_size)) + +function Base.iterate(it::IOChunkedBody, state=nothing) + eof(it.io) && return nothing + n = readbytes!(it.io, it.buf, it.chunk_size) + n == 0 && return nothing + return view(it.buf, 1:n), nothing +end function InputStream(allocator::Ptr{aws_allocator}, body) is = InputStream() diff --git a/test/client.jl b/test/client.jl index ac1ac7bd..68b063f5 100644 --- a/test/client.jl +++ b/test/client.jl @@ -301,6 +301,46 @@ @test !isempty(pool.clients.clients) end + @testset "IO request body streaming" begin + mutable struct ChunkedTestIO <: IO + chunks::Vector{Vector{UInt8}} + readbytes_calls::Int + readavailable_calls::Int + end + ChunkedTestIO(chunks) = ChunkedTestIO(chunks, 0, 0) + Base.eof(io::ChunkedTestIO) = isempty(io.chunks) + function Base.readbytes!(io::ChunkedTestIO, buf::Vector{UInt8}, n::Integer) + io.readbytes_calls += 1 + isempty(io.chunks) && return 0 + chunk = popfirst!(io.chunks) + ncopy = min(n, length(chunk)) + copyto!(buf, 1, chunk, 1, ncopy) + return ncopy + end + function Base.readavailable(io::ChunkedTestIO) + io.readavailable_calls += 1 + error("readavailable should not be used for chunked IO") + end + + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + io = ChunkedTestIO([Vector{UInt8}("hello"), Vector{UInt8}(" "), Vector{UInt8}("world")]) + resp = HTTP.post("http://127.0.0.1:$port/"; body=io) + @test String(resp.body) == "hello world" + @test io.readbytes_calls == 3 + @test io.readavailable_calls == 0 + finally + close(server) + end + end + @testset "HTTP.open streaming" begin resp = HTTP.open("GET", "https://$httpbin/stream/5") do io r = HTTP.startread(io) From 1906ec0f106fd6c32d8520e8c747587aec02c6fb Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 01:18:45 -0700 Subject: [PATCH 16/56] feat(client): add observelayers context --- src/client/connection.jl | 16 ++++- src/client/makerequest.jl | 24 ++++--- src/client/open.jl | 45 ++++++++----- src/client/redirects.jl | 15 ++++- src/client/request.jl | 25 ++++--- src/client/retry.jl | 18 ++++++ src/client/stream.jl | 133 +++++++++++++++++++++++++++++++++++++- src/requestresponse.jl | 40 +++++++++++- test/client.jl | 22 +++++++ 9 files changed, 297 insertions(+), 41 deletions(-) diff --git a/src/client/connection.jl b/src/client/connection.jl index b93530de..65dac2f6 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -34,7 +34,20 @@ function c_on_ping_complete(conn, round_trip_time_ns, error_code, fut_ptr) return end -function with_connection(f::Function, client::Client) +function with_connection(f::Function, client::Client; context=nothing) + if context === nothing + fut = Future{Ptr{aws_http_connection}}() + GC.@preserve fut begin + aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) + connection = wait(fut) + end + try + return f(connection) + finally + aws_http_connection_manager_release_connection(client.connection_manager, connection) + end + end + start_time = time() fut = Future{Ptr{aws_http_connection}}() GC.@preserve fut begin aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) @@ -44,6 +57,7 @@ function with_connection(f::Function, client::Client) return f(connection) finally aws_http_connection_manager_release_connection(client.connection_manager, connection) + _record_layer!(context, :connectionlayer, start_time) end end diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index ae146fc2..2480ba62 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -70,6 +70,8 @@ function request(method, url, h=Header[], b=nothing; verbose=0, # only client keywords in catch-all kw...) + context = observelayers ? Dict{Symbol, Any}() : nothing + context === nothing || _init_observations!(context) if chunkedbody === nothing && body isa IO && !(body isa IOStream) && !(body isa Form) chunkedbody = IOChunkedBody(body) body = nothing @@ -98,7 +100,7 @@ function request(method, url, h=Header[], b=nothing; end authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth - return with_redirect(allocator, method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body + return with_redirect(allocator, method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body reqclient = @something( client, pool === nothing ? @@ -108,7 +110,7 @@ function request(method, url, h=Header[], b=nothing; req_ref = Ref{Any}(nothing) with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, retry_check=retry_check, retry_delays=retry_delays, - retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref) do + retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref, context=context) do resp = if reqclient.http2_stream_manager != C_NULL path = resource(uri) with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; @@ -116,23 +118,25 @@ function request(method, url, h=Header[], b=nothing; canonicalize_headers=canonicalize_headers, detect_content_type=detect_content_type, basicauth=apply_basicauth, + observelayers=observelayers, + context=context, ) do req req_ref[] = req if response_body isa AbstractVector{UInt8} ref = Ref(1) GC.@preserve ref begin on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=context) end elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=context) else - with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout, allocator) + with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout, allocator; context=context) end end else - with_connection(reqclient) do conn + with_connection(reqclient; context=context) do conn http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 path = resource(uri) with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; @@ -140,19 +144,21 @@ function request(method, url, h=Header[], b=nothing; canonicalize_headers=canonicalize_headers, detect_content_type=detect_content_type, basicauth=apply_basicauth, + observelayers=observelayers, + context=context, ) do req req_ref[] = req if response_body isa AbstractVector{UInt8} ref = Ref(1) GC.@preserve ref begin on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=context) end elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=context) else - with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator) + with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator; context=context) end end end diff --git a/src/client/open.jl b/src/client/open.jl index d26194c7..e35c2caf 100644 --- a/src/client/open.jl +++ b/src/client/open.jl @@ -66,8 +66,11 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; method_str = string(method) headers = mkreqheaders(headers, copyheaders) uri = parseuri(url, query, allocator) + context = observelayers ? Dict{Symbol, Any}() : nothing + context === nothing || _init_observations!(context) count = 0 while true + redirect_start = context === nothing ? 0.0 : time() redirect_url = nothing resp = nothing proxy_kw = proxy_kwargs(proxy, scheme(uri)) @@ -86,7 +89,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) )::Client - resp = with_connection(reqclient) do conn + resp = with_connection(reqclient; context=context) do conn http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 path = resource(uri) with_request(reqclient, method_str, path, headers, nothing, nothing, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; @@ -94,6 +97,8 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; canonicalize_headers=canonicalize_headers, detect_content_type=detect_content_type, basicauth=apply_basicauth, + observelayers=observelayers, + context=context, ) do req if !http2 && method_str in ("POST", "PUT", "PATCH") && @@ -103,27 +108,33 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; setheader(req.headers, "transfer-encoding", "chunked") end stream = _open_stream(conn, req, decompress, readtimeout, allocator) - if redirect && issafe(method_str) - resp = startread(stream) - if (count < redirect_limit && isredirect(resp) && (location = getheader(resp.headers, "Location")) != "") - redirect_url = location - closeread(stream) - return resp - end - end - err = nothing + stream_start = context === nothing ? 0.0 : time() try - f(stream) - catch e - err = e + if redirect && issafe(method_str) + resp = startread(stream) + if (count < redirect_limit && isredirect(resp) && (location = getheader(resp.headers, "Location")) != "") + redirect_url = location + closeread(stream) + return resp + end + end + err = nothing + try + f(stream) + catch e + err = e + finally + closewrite(stream) + end + resp = closeread(stream) + err === nothing || throw(err) + return resp finally - closewrite(stream) + context === nothing || _record_layer!(context, :streamlayer, stream_start) end - resp = closeread(stream) - err === nothing || throw(err) - return resp end end + context === nothing || _record_layer!(context, :redirectlayer, redirect_start) if redirect_url === nothing if status_exception && iserror(resp) if logerrors diff --git a/src/client/redirects.jl b/src/client/redirects.jl index e260939e..0cc441c2 100644 --- a/src/client/redirects.jl +++ b/src/client/redirects.jl @@ -38,14 +38,23 @@ function newmethod(request_method, response_status, redirect_method) return "GET" end -function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true) +function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true; context=nothing) if !redirect || redirect_limit == 0 # no redirecting return f(method, uri, headers, body) end count = 0 while true - ret = f(method, uri, headers, body) + if context === nothing + ret = f(method, uri, headers, body) + else + start_time = time() + try + ret = f(method, uri, headers, body) + finally + _record_layer!(context, :redirectlayer, start_time) + end + end resp = getresponse(ret) if (count == redirect_limit || !isredirect(resp) || (location = getheader(resp.headers, "Location")) == "") return ret @@ -76,4 +85,4 @@ function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, count += 1 end @assert false "Unreachable!" -end \ No newline at end of file +end diff --git a/src/client/request.jl b/src/client/request.jl index 1a97f224..45adaccb 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -35,11 +35,13 @@ function with_request( canonicalize_headers::Bool=false, detect_content_type::Bool=false, basicauth::Bool=true, + observelayers::Bool=false, + context=nothing, ) # create request mutable_headers = (headers isa AbstractVector{<:Pair} && !copyheaders) ? headers : nothing req_headers = mkreqheaders(headers, copyheaders) - req = Request(method, path, req_headers, nothing, http2, client.settings.allocator) + req = Request(method, path, req_headers, nothing, http2, client.settings.allocator; context=context) # add headers to request h = req.headers if http2 @@ -97,12 +99,19 @@ function with_request( end # call user function verbose > 0 && print_request(stdout, req) - ret = f(req) - resp = getresponse(ret) - if canonicalize_headers - canonicalizeheaders!(resp.headers) + start_time = time() + ret = nothing + try + ret = f(req) + resp = getresponse(ret) + if canonicalize_headers + canonicalizeheaders!(resp.headers) + end + verbose > 0 && print_response(stdout, resp) + cookies === false || Cookies.setcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path, resp.headers) + return ret + finally + req.context[:total_request_duration_ms] = (time() - start_time) * 1000 + observelayers && _record_layer!(req.context, :messagelayer, start_time) end - verbose > 0 && print_response(stdout, resp) - cookies === false || Cookies.setcookies!(cookiejar, client.settings.scheme, client.settings.host, req.path, resp.headers) - return ret end diff --git a/src/client/retry.jl b/src/client/retry.jl index f5c6e7fc..80773fd6 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -91,10 +91,23 @@ function with_retry_token( retry_non_idempotent::Bool=false, retryable_body::Bool=true, req_ref=nothing, + context=nothing, ) # If max_retries is 0, we don't need to bother with any retrying max_retries = client.settings.max_retries if max_retries == 0 + if context === nothing + try + return f() + catch e + if logerrors + url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) + @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=url logtag=logtag + end + rethrow() + end + end + start_time = time() try return f() catch e @@ -103,6 +116,8 @@ function with_retry_token( @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=url logtag=logtag end rethrow() + finally + _record_layer!(context, :retrylayer, start_time) end end retry_check_fn = retry_check === nothing ? nothing : retry_check @@ -110,11 +125,14 @@ function with_retry_token( delay_state = nothing nretries = 0 while true + attempt_start = context === nothing ? 0.0 : time() try ret = f() + context === nothing || _record_layer!(context, :retrylayer, attempt_start) _set_nretries!(ret, nretries) return ret catch e + context === nothing || _record_layer!(context, :retrylayer, attempt_start) stream = nothing err = e if err isa StreamError diff --git a/src/client/stream.jl b/src/client/stream.jl index 631432ba..4cd21dfb 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -178,6 +178,8 @@ end Base.hash(s::Stream, h::UInt) = hash(s.ptr, h) +getrequest(s::Stream) = s.request + const ACTIVE_STREAMS_LOCK = ReentrantLock() const ACTIVE_STREAMS = IdDict{Stream, Bool}() @@ -580,7 +582,71 @@ function addtrailer(s::Stream, h::AbstractVector{<:Pair}) return addtrailer(s, trailers) end -function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator) +function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=nothing) + if context === nothing + stream = Stream{Nothing}(allocator, decompress, true, false) + if on_stream_response_body !== nothing + stream.bufferstream = Base.BufferStream() + end + acquire_fut = Future{Ptr{aws_http_stream}}() + GC.@preserve stream acquire_fut begin + stream.request_options = aws_http_make_request_options( + 1, + req.ptr, + pointer_from_objref(stream), + on_response_headers[], + on_response_header_block_done[], + on_response_body[], + on_metrics[], + on_complete[], + on_destroy[], + chunkedbody !== nothing, # http2_use_manual_data_writes + readtimeout * 1000 # response_first_byte_timeout_ms + ) + stream.response = resp = Response(0, nothing, nothing, true, allocator) + resp.metrics = RequestMetrics() + resp.request = req + acquire_opts = aws_http2_stream_manager_acquire_stream_options( + on_stream_acquired[], + pointer_from_objref(acquire_fut), + FieldRef(stream, :request_options), + ) + aws_http2_stream_manager_acquire_stream(client.http2_stream_manager, Ref(acquire_opts)) + stream_ptr = wait(acquire_fut) + stream.ptr = stream_ptr + stream.activated = true + try + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) + end + wait(stream.fut) + catch e + rethrow(DontRetry(e)) + end + else + wait(stream.fut) + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] + end + end + return resp + finally + aws_http_stream_release(stream_ptr) + stream.released = true + stream.ptr = Ptr{aws_http_stream}(C_NULL) + end + end + end + + start_time = time() stream = Stream{Nothing}(allocator, decompress, true, false) if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() @@ -639,11 +705,73 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea aws_http_stream_release(stream_ptr) stream.released = true stream.ptr = Ptr{aws_http_stream}(C_NULL) + _record_layer!(context, :streamlayer, start_time) end end end -function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator) +function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=nothing) + if context === nothing + stream = Stream{Nothing}(allocator, decompress, http2, false) + if on_stream_response_body !== nothing + stream.bufferstream = Base.BufferStream() + end + GC.@preserve stream begin + stream.request_options = aws_http_make_request_options( + 1, + req.ptr, + pointer_from_objref(stream), + on_response_headers[], + on_response_header_block_done[], + on_response_body[], + on_metrics[], + on_complete[], + on_destroy[], + http2 && chunkedbody !== nothing, # http2_use_manual_data_writes + readtimeout * 1000 # response_first_byte_timeout_ms + ) + stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) + stream_ptr == C_NULL && aws_throw_error() + stream.ptr = stream_ptr + http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + stream.response = resp = Response(0, nothing, nothing, http2, allocator) + resp.metrics = RequestMetrics() + resp.request = req + try + aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() + # write chunked body if provided + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + # write final chunk + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) + end + wait(stream.fut) + catch e + rethrow(DontRetry(e)) + end + else + wait(stream.fut) + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] + end + end + return resp + finally + aws_http_stream_release(stream_ptr) + stream.released = true + stream.ptr = Ptr{aws_http_stream}(C_NULL) + end + end # GC.@preserve + end + + start_time = time() stream = Stream{Nothing}(allocator, decompress, http2, false) if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() @@ -699,6 +827,7 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, aws_http_stream_release(stream_ptr) stream.released = true stream.ptr = Ptr{aws_http_stream}(C_NULL) + _record_layer!(context, :streamlayer, start_time) end end # GC.@preserve end diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 1633a3e9..3bf80299 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -240,6 +240,24 @@ function Base.iterate(it::IOChunkedBody, state=nothing) return view(it.buf, 1:n), nothing end +const OBSERVELAYER_NAMES = (:messagelayer, :redirectlayer, :retrylayer, :connectionlayer, :streamlayer) + +function _init_observations!(context::Dict{Symbol, Any}) + for name in OBSERVELAYER_NAMES + context[Symbol(name, "_count")] = 0 + context[Symbol(name, "_duration_ms")] = 0.0 + end + return context +end + +function _record_layer!(context::Dict{Symbol, Any}, name::Symbol, started::Float64) + cntkey = Symbol(name, "_count") + durkey = Symbol(name, "_duration_ms") + context[cntkey] = Base.get(() -> 0, context, cntkey) + 1 + context[durkey] = Base.get(() -> 0.0, context, durkey) + (time() - started) * 1000 + return +end + function InputStream(allocator::Ptr{aws_allocator}, body) is = InputStream() if body !== nothing @@ -311,11 +329,12 @@ mutable struct Request <: Message # only set in server-side request handlers body::Union{Nothing, Vector{UInt8}} trailers::Union{Nothing, Headers} + context::Dict{Symbol, Any} route::Union{Nothing, String} params::Union{Nothing, Dict{String, String}} cookies::Any # actually Union{Nothing, Vector{Cookie}} - function Request(method, path, headers=nothing, body=nothing, http2::Bool=false, allocator=default_aws_allocator()) + function Request(method, path, headers=nothing, body=nothing, http2::Bool=false, allocator=default_aws_allocator(); context=nothing) ptr = http2 ? aws_http2_message_new_request(allocator) : aws_http_message_new_request(allocator) @@ -333,6 +352,7 @@ mutable struct Request <: Message req.body = nothing req.inputstream = nothing req.trailers = nothing + req.context = context === nothing ? Dict{Symbol, Any}() : context req.route = nothing req.params = nothing req.cookies = nothing @@ -345,6 +365,24 @@ mutable struct Request <: Message end end +getrequest(req::Request) = req + +function observelayer(f) + function observation(req_or_stream; kw...) + req = getrequest(req_or_stream) + nm = nameof(f) + start_time = time() + ctx = req.context + ctx[Symbol(nm, "_count")] = Base.get(() -> 0, ctx, Symbol(nm, "_count")) + 1 + try + return f(req_or_stream; kw...) + finally + ctx[Symbol(nm, "_duration_ms")] = + Base.get(() -> 0.0, ctx, Symbol(nm, "_duration_ms")) + (time() - start_time) * 1000 + end + end +end + ptr(x) = getfield(x, :ptr) function Base.getproperty(x::Request, s::Symbol) diff --git a/test/client.jl b/test/client.jl index 68b063f5..ff20fb19 100644 --- a/test/client.jl +++ b/test/client.jl @@ -301,6 +301,28 @@ @test !isempty(pool.clients.clients) end + @testset "observelayers" begin + server = HTTP.serve!(req -> begin + if req.target == "/redirect" + return HTTP.Response(302, ["Location" => "/ok"], nothing) + end + return HTTP.Response(200, "ok") + end; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port/redirect"; observelayers=true, retries=0) + ctx = resp.request.context + @test ctx[:messagelayer_count] >= 1 + @test ctx[:redirectlayer_count] >= 1 + @test ctx[:retrylayer_count] >= 1 + @test ctx[:connectionlayer_count] >= 1 + @test ctx[:streamlayer_count] >= 1 + @test ctx[:total_request_duration_ms] > 0 + finally + close(server) + end + end + @testset "IO request body streaming" begin mutable struct ChunkedTestIO <: IO chunks::Vector{Vector{UInt8}} From a515f4eb3ca53eedb5ddd861dae663b12a09144f Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 01:21:00 -0700 Subject: [PATCH 17/56] feat(api): export stream helpers --- src/HTTP.jl | 1 + test/utils.jl | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/HTTP.jl b/src/HTTP.jl index 510db3ff..c934bf5a 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -5,6 +5,7 @@ using LibAwsCommon, LibAwsIO, LibAwsHTTPFork import LibAwsCommon: Future, FieldRef export HTTPVersion +export startwrite, startread, closewrite, closeread export @logfmt_str, common_logfmt, combined_logfmt export WebSockets diff --git a/test/utils.jl b/test/utils.jl index 5fd4c47d..84f52c46 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -31,4 +31,10 @@ end @test_throws HTTP.AWSError HTTP.parseuri("http://example.com:abc", nothing, HTTP.default_aws_allocator()) + + exported = names(HTTP, all=false) + @test :startwrite in exported + @test :startread in exported + @test :closewrite in exported + @test :closeread in exported end # testset From 76b298bdaae9879f0a9081476b86f3a5c687af15 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:02:11 -0700 Subject: [PATCH 18/56] feat(stream): add readall, closebody, isaborted --- src/client/stream.jl | 24 ++++++++++++++++++++++++ test/client.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/src/client/stream.jl b/src/client/stream.jl index 4cd21dfb..29ab9927 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -403,6 +403,30 @@ function closewrite(s::Stream) return end +function closebody(s::Stream) + closewrite(s) + return +end + +function readall!(s::Stream, buf::Base.GenericIOBuffer=PipeBuffer()) + total = 0 + while !eof(s) + bytes = readavailable(s) + total += length(bytes) + write(buf, bytes) + end + return total +end + +function isaborted(s::Stream) + s.server_side && return false + if !isdefined(s, :response) || s.response === nothing + return false + end + resp = s.response + return iserror(resp) && hasheader(resp, "Connection", "close") +end + function startread(s::Stream) if s.server_side if s.read_started diff --git a/test/client.jl b/test/client.jl index ff20fb19..e7604ecf 100644 --- a/test/client.jl +++ b/test/client.jl @@ -363,6 +363,46 @@ end end + @testset "stream helpers" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + if http.request.method == "POST" + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, body) + else + HTTP.setstatus(http, 500) + HTTP.setheader(http, "Connection" => "close") + HTTP.startwrite(http) + write(http, "error") + end + end + try + port = HTTP.port(server) + resp = HTTP.open("GET", "http://127.0.0.1:$port"; status_exception=false) do io + r = HTTP.startread(io) + @test r.status == 500 + @test HTTP.isaborted(io) + buf = IOBuffer() + n = HTTP.readall!(io, buf) + @test n > 0 + @test String(take!(buf)) == "error" + end + @test resp.status == 500 + + resp = HTTP.open("POST", "http://127.0.0.1:$port") do io + write(io, "hello") + HTTP.closebody(io) + r = HTTP.startread(io) + @test r.status == 200 + @test String(read(io)) == "hello" + end + @test resp.status == 200 + finally + close(server) + end + end + @testset "HTTP.open streaming" begin resp = HTTP.open("GET", "https://$httpbin/stream/5") do io r = HTTP.startread(io) From ae1d6119773f6720c95d2cf0b79329883c6b14de Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:06:05 -0700 Subject: [PATCH 19/56] feat(util): add download helper --- src/HTTP.jl | 1 + src/download.jl | 118 ++++++++++++++++++++++++++++++++++++++++++++++++ test/utils.jl | 15 ++++++ 3 files changed, 134 insertions(+) create mode 100644 src/download.jl diff --git a/src/HTTP.jl b/src/HTTP.jl index c934bf5a..d20918a9 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -50,6 +50,7 @@ include("client/request.jl") include("client/stream.jl") include("client/makerequest.jl") include("client/open.jl") +include("download.jl") include("websockets.jl"); using .WebSockets include("server.jl") include("handlers.jl"); using .Handlers diff --git a/src/download.jl b/src/download.jl new file mode 100644 index 00000000..bdd85d8c --- /dev/null +++ b/src/download.jl @@ -0,0 +1,118 @@ +function safer_joinpath(basepart, parts...) + explain = "Possible directory traversal attack detected." + for part in parts + occursin("..", part) && throw(DomainError(part, "contains \"..\". $explain")) + startswith(part, '/') && throw(DomainError(part, "begins with \"/\". $explain")) + end + return joinpath(basepart, parts...) +end + +function try_get_filename_from_headers(hdrs) + for content_disp in hdrs + filename_part = match(r"filename\\s*=\\s*(.*)", content_disp) + if filename_part !== nothing + filename = filename_part[1] + quoted_filename = match(r"\\\"(.*)\\\"", filename) + if quoted_filename !== nothing + filename = unescape_string(quoted_filename[1]) + end + return filename == "" ? nothing : filename + end + end + return nothing +end + +function try_get_filename_from_request(req::Request) + function file_from_target(t) + (t == "" || t == "/") && return nothing + f = basename(URI(t).path) + return f == "" ? file_from_target(dirname(t)) : f + end + return file_from_target(req.path) +end + +determine_file(::Nothing, resp, hdrs) = determine_file(tempdir(), resp, hdrs) + +function determine_file(path, resp, hdrs) + if isdir(path) + filename = something( + try_get_filename_from_headers(hdrs), + resp.request === nothing ? nothing : try_get_filename_from_request(resp.request), + basename(tempname()) + ) + return safer_joinpath(path, filename) + end + return path +end + +""" + download(url, [local_path], [headers]; update_period=1, kw...) + +Download a URL to a local file, returning the filename. If `local_path` is not +provided, the file is saved in a temporary directory. If `local_path` is a +directory, the filename is determined from response headers or request target. + +`update_period` controls progress reporting in seconds (set to `Inf` to disable). +Additional keyword arguments are forwarded to `HTTP.open`. +""" +function download(url::AbstractString, local_path=nothing, headers=Header[]; update_period=1, kw...) + format_progress(x) = round(x, digits=4) + format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(round(Int, x)) + format_seconds(x) = "$(round(x; digits=2)) s" + format_bytes_per_second(x) = format_bytes(x) * "/s" + + @debug "downloading $url" + local file + hdrs = String[] + HTTP.open("GET", url, headers; kw...) do stream + resp = startread(stream) + content_disp = header(resp, "Content-Disposition") + !isempty(content_disp) && push!(hdrs, content_disp) + eof(stream) && return + + file = determine_file(local_path, resp, hdrs) + total_bytes = parse(Float64, header(resp, "Content-Length", "NaN")) + downloaded_bytes = 0 + start_time = now() + prev_time = now() + + if header(resp, "Content-Encoding") == "gzip" + total_bytes = NaN + end + + function report_callback() + prev_time = now() + taken_time = (prev_time - start_time).value / 1000 + average_speed = downloaded_bytes / taken_time + remaining_bytes = total_bytes - downloaded_bytes + remaining_time = remaining_bytes / average_speed + completion_progress = downloaded_bytes / total_bytes + @info("Downloading", + source=url, + dest=file, + progress=completion_progress |> format_progress, + time_taken=taken_time |> format_seconds, + time_remaining=remaining_time |> format_seconds, + average_speed=average_speed |> format_bytes_per_second, + downloaded=downloaded_bytes |> format_bytes, + remaining=remaining_bytes |> format_bytes, + total=total_bytes |> format_bytes, + ) + end + + Base.open(file, "w") do fh + while !eof(stream) + downloaded_bytes += write(fh, readavailable(stream)) + if !isinf(update_period) + if now() - prev_time > Millisecond(round(1000update_period)) + report_callback() + end + end + end + end + if !isinf(update_period) + report_callback() + end + end + return file +end diff --git a/test/utils.jl b/test/utils.jl index 84f52c46..7793f84e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -37,4 +37,19 @@ @test :startread in exported @test :closewrite in exported @test :closeread in exported + + @testset "download" begin + server = HTTP.serve!(req -> HTTP.Response(200, ["Content-Disposition" => "attachment; filename=\"hello.txt\""], "hello"); listenany=true) + try + port = HTTP.port(server) + mktempdir() do dir + file = HTTP.download("http://127.0.0.1:$port/hello.txt", dir) + @test isfile(file) + @test basename(file) == "hello.txt" + @test String(read(file)) == "hello" + end + finally + close(server) + end + end end # testset From 90da92787a46460c06283728e036b0cc8cb30d0e Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:08:56 -0700 Subject: [PATCH 20/56] feat(api): add nobody constant --- src/HTTP.jl | 2 ++ test/utils.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/HTTP.jl b/src/HTTP.jl index d20918a9..d3230d95 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -9,6 +9,8 @@ export startwrite, startread, closewrite, closeread export @logfmt_str, common_logfmt, combined_logfmt export WebSockets +const nobody = UInt8[] + include("utils.jl") include("access_log.jl") include("sniff.jl"); using .Sniff diff --git a/test/utils.jl b/test/utils.jl index 7793f84e..a9df0265 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -37,6 +37,8 @@ @test :startread in exported @test :closewrite in exported @test :closeread in exported + @test HTTP.nobody isa Vector{UInt8} + @test isempty(HTTP.nobody) @testset "download" begin server = HTTP.serve!(req -> HTTP.Response(200, ["Content-Disposition" => "attachment; filename=\"hello.txt\""], "hello"); listenany=true) From 27735c1db8a4c8f785d51a500a69b5f4c648e78b Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:27:30 -0700 Subject: [PATCH 21/56] fix(api): add escape deprecation alias --- src/HTTP.jl | 2 ++ test/utils.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/HTTP.jl b/src/HTTP.jl index d3230d95..9f36dc61 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -11,6 +11,8 @@ export WebSockets const nobody = UInt8[] +Base.@deprecate escape escapeuri + include("utils.jl") include("access_log.jl") include("sniff.jl"); using .Sniff diff --git a/test/utils.jl b/test/utils.jl index a9df0265..85d3dd70 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -40,6 +40,8 @@ @test HTTP.nobody isa Vector{UInt8} @test isempty(HTTP.nobody) + @test_deprecated HTTP.escape("a b") == "a%20b" + @testset "download" begin server = HTTP.serve!(req -> HTTP.Response(200, ["Content-Disposition" => "attachment; filename=\"hello.txt\""], "hello"); listenany=true) try From 88a6b8c78912c9b497a6a0b2054931ab428955d1 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:37:59 -0700 Subject: [PATCH 22/56] feat(websockets): add upgrade/listen parity --- src/client/stream.jl | 6 ++ src/server.jl | 8 +++ src/websockets.jl | 139 ++++++++++++++++++++++++++++++++------- test/websockets_basic.jl | 53 +++++++++++++++ 4 files changed, 183 insertions(+), 23 deletions(-) diff --git a/src/client/stream.jl b/src/client/stream.jl index 29ab9927..cf27e050 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -140,6 +140,7 @@ mutable struct Stream{T} <: IO response_started::Bool handler_started::Bool ignore_writes::Bool + on_complete::Union{Nothing, Function} released::Bool # remaining fields are initially undefined ptr::Ptr{aws_http_stream} @@ -172,6 +173,7 @@ mutable struct Stream{T} <: IO false, false, false, + nothing, false, ) end @@ -342,6 +344,10 @@ function _server_closewrite(s::Stream) s.final_chunk_written = true return end + if hasheader(resp.headers, "upgrade") + s.final_chunk_written = true + return + end writechunk(s, "") s.final_chunk_written = true return diff --git a/src/server.jl b/src/server.jl index 1b210351..5df45312 100644 --- a/src/server.jl +++ b/src/server.jl @@ -505,6 +505,14 @@ function c_on_server_stream_complete(aws_stream_ptr, error_code, stream_ptr) @error "on_stream_complete error" exception=(e, catch_backtrace()) end end + if stream.on_complete !== nothing + try + Base.invokelatest(stream.on_complete, stream) + catch e + @error "stream on_complete error" exception=(e, catch_backtrace()) + end + stream.on_complete = nothing + end if stream.connection.server.access_log !== nothing try @info sprint(stream.connection.server.access_log, stream) _group=:access diff --git a/src/websockets.jl b/src/websockets.jl index f9f04bdd..d9a28f6d 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -2,12 +2,14 @@ module WebSockets using Base64, Random, LibAwsHTTPFork, LibAwsCommon, LibAwsIO -import ..FieldRef, ..iswss, ..getport, ..makeuri, ..aws_throw_error, ..resource, ..Headers, ..Header, ..str, ..aws_error, ..aws_throw_error, ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient, ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..Client, ..Request, ..Response, ..setinputstream!, ..getresponse, ..CookieJar, ..COOKIEJAR, ..addheaders, ..Stream, ..HTTP, ..getheader +import ..FieldRef, ..iswss, ..getport, ..makeuri, ..aws_throw_error, ..resource, ..Headers, ..Header, ..str, ..aws_error, ..aws_throw_error, ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient, ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..Client, ..Request, ..Response, ..Message, ..setinputstream!, ..getresponse, ..CookieJar, ..COOKIEJAR, ..addheaders, ..Stream, ..HTTP, ..getheader, ..hasheader, ..header export WebSocket, send, receive, ping, pong @enum OpCode::UInt8 CONTINUATION=0x00 TEXT=0x01 BINARY=0x02 CLOSE=0x08 PING=0x09 PONG=0x0A +const DEFAULT_MAX_FRAG = 1024 + struct CloseFrameBody code::Int reason::String @@ -18,11 +20,26 @@ struct WebSocketError <: Exception end isok(e::WebSocketError) = e.message.code in (1000, 1001, 1005) +isok(::Any) = false + +function isupgrade(r::Message) + ((r isa Request && r.method == "GET") || + (r isa Response && r.status == 101)) && + (hasheader(r, "Connection", "upgrade") || + hasheader(r, "Connection", "keep-alive, upgrade")) && + hasheader(r, "Upgrade", "websocket") +end + +isupgrade(s::Stream) = isupgrade(s.request) + +Base.@deprecate is_upgrade isupgrade mutable struct WebSocket id::String host::String path::String + maxframesize::Int + maxfragmentation::Int connect_fut::Future{Nothing} readchannel::Channel{Union{String, Vector{UInt8}, WebSocketError}} writebuffer::Vector{UInt8} @@ -40,10 +57,12 @@ mutable struct WebSocket fragment_payload::Vector{UInt8} closebody::Union{Nothing, CloseFrameBody} - WebSocket(host::AbstractString, path::AbstractString) = new( + WebSocket(host::AbstractString, path::AbstractString; maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = new( string(rand(UInt32); base=58), String(host), String(path), + Int(maxframesize), + Int(maxfragmentation), Future{Nothing}(), Channel{Union{String, Vector{UInt8}, WebSocketError}}(Inf), UInt8[], @@ -272,7 +291,10 @@ function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{ end function open(f::Function, url; + suppress_close_error::Bool=false, headers=[], + maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, allocator::Ptr{aws_allocator}=default_aws_allocator(), username=nothing, password=nothing, @@ -306,7 +328,7 @@ function open(f::Function, url; path = resource(uri) with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req host = str(uri.host_name) - ws = WebSocket(host, path) + ws = WebSocket(host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) ws.handshake_request = req ws.handshake_response = Response(0, nothing, nothing, false, allocator) options = aws_websocket_client_connection_options( @@ -341,23 +363,23 @@ function open(f::Function, url; try f(ws) catch e - # if !isok(e) - # suppress_close_error || @error "$(ws.id): error" (e, catch_backtrace()) - # end - # if !isclosed(ws) - # if e isa WebSocketError && e.message isa CloseFrameBody - # close(ws, e.message) - # else - # close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) - # end - # end - # if !isok(e) + if !isok(e) + suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + else + close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) + end + end + if !isok(e) rethrow() - # end + end finally - # if !isclosed(ws) - ws.closebody === nothing && close(ws) - # end + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end end end @@ -588,9 +610,20 @@ function Base.iterate(ws::WebSocket, st=nothing) end end +@noinline handshakeerror() = throw(WebSocketError(CloseFrameBody(1002, "Websocket handshake failed"))) + # given a WebSocket request, return the 101 response function websocket_upgrade_handler(req::Request) + if !isupgrade(req) + return Response(400, ["content-type" => "text/plain"], "websocket upgrade required", false, req.allocator) + end + if !hasheader(req, "Sec-WebSocket-Version", "13") + return Response(400, ["content-type" => "text/plain"], "unsupported websocket version", false, req.allocator) + end key = getheader(req.headers, "sec-websocket-key") + if key === nothing || isempty(key) + return Response(400, ["content-type" => "text/plain"], "missing websocket key", false, req.allocator) + end resp_ptr = aws_http_message_new_websocket_handshake_response(req.allocator, aws_byte_cursor_from_c_str(key)) resp_ptr == C_NULL && aws_throw_error() resp = Response() @@ -600,12 +633,19 @@ function websocket_upgrade_handler(req::Request) return resp end -function websocket_upgrade_function(f) +function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, done=nothing) #TODO: return WebSocketUpgradeArgs # then schedule a task to do the actual upgrade function websocket_upgrade(stream::Stream) - #TODO: get host/path from stream? - ws = WebSocket("", "") + resp = isdefined(stream, :response) ? stream.response : nothing + if resp === nothing || resp.status != 101 + done !== nothing && notify(done, CapturedException(ArgumentError("websocket upgrade not accepted"), Base.backtrace())) + return + end + req = stream.request + ws = WebSocket(header(req, "host", ""), req.path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) + ws.handshake_request = req + ws.handshake_response = resp stream.websocket_options = aws_websocket_server_upgrade_options( 0, Ptr{Cvoid}(pointer_from_objref(ws)), @@ -618,9 +658,36 @@ function websocket_upgrade_function(f) ws_ptr == C_NULL && aws_throw_error() ws.websocket_pointer = ws_ptr errormonitor(Threads.@spawn begin + err = nothing try f(ws) + catch e + if !isok(e) + err = e + suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + elseif isok(e) + close(ws, CloseFrameBody(1000, "")) + else + close(ws, CloseFrameBody(1011, "Unexpected server websocket error")) + end + end + if err !== nothing + done !== nothing && notify(done, CapturedException(e, catch_backtrace())) + end + if !isok(e) + rethrow() + end finally + if err === nothing + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end + done !== nothing && notify(done, nothing) + end aws_websocket_release(ws_ptr) end end) @@ -628,8 +695,34 @@ function websocket_upgrade_function(f) end end -serve!(f, host="127.0.0.1", port=8080; kw...) = - HTTP.serve!(websocket_upgrade_handler, host, port; on_stream_complete=websocket_upgrade_function(f), kw...) +function _upgrade(f::Function, stream::Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) + isupgrade(stream) || handshakeerror() + hasheader(stream.request, "Sec-WebSocket-Version", "13") || handshakeerror() + key = getheader(stream.request.headers, "sec-websocket-key") + (key === nothing || isempty(key)) && handshakeerror() + stream.response_started && error("response already started") + done = Future{Nothing}() + stream.on_complete = websocket_upgrade_function(f; suppress_close_error=suppress_close_error, maxframesize=maxframesize, maxfragmentation=maxfragmentation, done=done) + stream.response = websocket_upgrade_handler(stream.request) + HTTP.startwrite(stream) + HTTP.closewrite(stream) + wait(done) + return +end + +upgrade(f::Function, stream::Stream; kw...) = _upgrade(f, stream; kw...) +upgrade(stream::Stream, f::Function; kw...) = _upgrade(f, stream; kw...) + +serve!(f, host="127.0.0.1", port=8080; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) = + HTTP.serve!(websocket_upgrade_handler, host, port; on_stream_complete=websocket_upgrade_function(f; suppress_close_error=suppress_close_error, maxframesize=maxframesize, maxfragmentation=maxfragmentation), kw...) + +listen!(f, host="127.0.0.1", port=8080; kw...) = serve!(f, host, port; kw...) + +function listen(f, host="127.0.0.1", port=8080; kw...) + server = listen!(f, host, port; kw...) + wait(server) + return server +end function __init__() on_connection_setup[] = @cfunction(c_on_connection_setup, Cvoid, (Ptr{aws_websocket_on_connection_setup_data}, Ptr{Cvoid})) diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl index d3484c87..3ac81ee6 100644 --- a/test/websockets_basic.jl +++ b/test/websockets_basic.jl @@ -44,3 +44,56 @@ end close(server) end end + +@testset "WebSockets listen!" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, "hi") + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test receive(ws) == "hi" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + end + end + finally + close(server) + end +end + +@testset "WebSockets upgrade via HTTP.listen!" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do stream + if WebSockets.isupgrade(stream) + WebSockets.upgrade(stream) do ws + send(ws, "pong") + WebSockets.close(ws, WebSockets.CloseFrameBody(1000, "bye")) + end + else + HTTP.setstatus(stream, 404) + HTTP.startwrite(stream) + write(stream, "nope") + end + end + port = HTTP.port(server) + try + WebSockets.open("ws://127.0.0.1:$port") do ws + @test receive(ws) == "pong" + try + receive(ws) + @test false + catch e + @test e isa WebSockets.WebSocketError + @test WebSockets.isok(e) + @test e.message.code == 1000 + end + end + finally + close(server) + end +end From 24d30841351688658cfb7df5b7f23ca8f5cdebdf Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:42:16 -0700 Subject: [PATCH 23/56] docs: update websockets and migration guides --- docs/src/manual/migrate.md | 14 ++++++++------ docs/src/manual/websockets.md | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index e0f500b0..ef9008e4 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -37,11 +37,12 @@ r = HTTP.request("GET", "http://example.com") # Accessing fields status = r.status body_text = String(r.body) # Similar to v1.x -header_value = HTTP.header(r.headers, "Content-Type") +header_value = HTTP.header(r, "Content-Type") +# or HTTP.header(r.headers, "Content-Type") ``` Key differences: -- Headers access has changed slightly to operate on the `headers` field +- Headers access works directly on `Request`/`Response`, or on the `headers` field if you already have it - The `.body` field can now be any type, not just `Vector{UInt8}` - Context dictionary access is now through `.context` rather than request-specific fields @@ -85,7 +86,7 @@ end # After (v2.0) HTTP.open("GET", "http://example.com") do http - # Start reading must be explicitly called + # Optional: call startread to access headers before reading the body startread(http) while !eof(http) data = readavailable(http) @@ -134,7 +135,7 @@ end ``` Key differences: -- Most server functionality has been standardized around `serve`/`serve!` rather than `listen`/`listen!` +- `serve`/`serve!` are the primary entry points; `listen`/`listen!` remain available for stream handlers and WebSockets - The handler typically works with `Request`/`Response` objects rather than `Stream` objects - The lifecycle management for servers has improved with clearer semantics for `isopen`, `close`, and `wait` @@ -167,6 +168,7 @@ end ``` Note the addition of the `stream=true` keyword argument to indicate you want to work with a stream handler. +You can also use `HTTP.listen`/`HTTP.listen!` as shorthand for `stream=true`. ### Router and Middleware @@ -271,7 +273,7 @@ close(server) using HTTP.WebSockets # Non-blocking server -server = WebSockets.serve!("127.0.0.1", 8081) do ws +server = WebSockets.listen!("127.0.0.1", 8081) do ws for msg in ws # Echo back any received message send(ws, msg) @@ -282,7 +284,7 @@ end close(server) ``` -Note the change from `listen!` to `serve!` to maintain consistency with the HTTP server API. +`listen!` and `serve!` are both supported in v2.0; `serve!` is an alias that matches the HTTP server naming. ## Error Handling diff --git a/docs/src/manual/websockets.md b/docs/src/manual/websockets.md index 259650e0..33c08cdc 100644 --- a/docs/src/manual/websockets.md +++ b/docs/src/manual/websockets.md @@ -82,6 +82,40 @@ WebSockets.open("wss://echo.websocket.org") do ws end ``` +## Server-side WebSockets + +For a dedicated WebSocket server, use `WebSockets.listen` or `WebSockets.listen!`: + +```julia +using HTTP +using HTTP.WebSockets + +server = WebSockets.listen!("127.0.0.1", 8080) do ws + for msg in ws + send(ws, msg) + end +end +``` + +To mix WebSockets and normal HTTP on the same port, use a stream handler with `HTTP.listen!` and +upgrade the connection when requested: + +```julia +using HTTP +using HTTP.WebSockets + +server = HTTP.listen!("127.0.0.1", 8080) do stream + if WebSockets.isupgrade(stream) + WebSockets.upgrade(stream) do ws + send(ws, "hello") + end + else + HTTP.setstatus(stream, 200) + write(stream, "ok") + end +end +``` + ## Connection Lifecycle and Error Handling You can check whether a WebSocket is open using `WebSockets.isclosed(ws)` and close it with `close(ws)`. The API is designed to raise exceptions for connection issues or protocol errors, allowing you to handle errors using try‑catch blocks. From 23b43fec5ada1b46d75eb8c9b2888ffb6e7538e0 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 02:56:18 -0700 Subject: [PATCH 24/56] feat(api): add utility parity helpers --- src/client/stream.jl | 2 ++ src/handlers.jl | 35 ++++++++++++++++++++++++++++--- src/requestresponse.jl | 2 ++ src/utils.jl | 47 ++++++++++++++++++++++++++++++++++++++++++ test/server.jl | 23 +++++++++++++++++++++ test/utils.jl | 27 ++++++++++++++++++++++++ 6 files changed, 133 insertions(+), 3 deletions(-) diff --git a/src/client/stream.jl b/src/client/stream.jl index cf27e050..ebc51413 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -1,3 +1,5 @@ +export Stream, closebody, isaborted, readall!, setstatus + const on_response_headers = Ref{Ptr{Cvoid}}(C_NULL) function c_on_response_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) diff --git a/src/handlers.jl b/src/handlers.jl index 9c06b20f..22ae9768 100644 --- a/src/handlers.jl +++ b/src/handlers.jl @@ -1,8 +1,9 @@ module Handlers -export Handler, Middleware, serve, serve!, Router, register!, getroute, getparams, getparam, getcookies +export Handler, Middleware, serve, serve!, Router, register!, getroute, getparams, getparam, getcookies, streamhandler -import ..Request, ..Cookies +import ..Request, ..Response, ..Stream, ..Cookies, ..getbody +import ..startread, ..setstatus, ..setheader, ..addtrailer, ..closewrite, ..closeread """ Handler @@ -39,6 +40,34 @@ then an input to the `auth_middlware`, which further enhances/modifies the handl """ abstract type Middleware end +""" + streamhandler(request_handler) -> stream handler + +Middleware that takes a request handler and returns a stream handler. +""" +function streamhandler(handler) + return function(stream::Stream) + req = startread(stream) + req.body = read(stream) + resp = Base.invokelatest(handler, req)::Response + resp.request = req + setstatus(stream, resp.status) + for h in resp.headers + setheader(stream, h.name, h.value) + end + body = getbody(resp) + if body isa IO + write(stream, read(body)) + elseif body !== nothing + write(stream, body) + end + resp.trailers === nothing || addtrailer(stream, resp.trailers) + closewrite(stream) + closeread(stream) + return + end +end + # tree-based router handler mutable struct Variable name::String @@ -396,4 +425,4 @@ middleware. """ getcookies(req) = req.cookies === nothing ? Cookies.Cookie[] : req.cookies::Vector{Cookies.Cookie} -end # module Handlers \ No newline at end of file +end # module Handlers diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 3bf80299..7d59d3d7 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -1,3 +1,5 @@ +export Header, Headers, Message, Request, Response + # working with headers headereq(a::String, b::String) = GC.@preserve a b aws_http_header_name_eq(aws_byte_cursor_from_c_str(a), aws_byte_cursor_from_c_str(b)) diff --git a/src/utils.jl b/src/utils.jl index 22e94a23..564d6dae 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,6 @@ +export bytes, isbytes, nbytes, nobytes, + escapehtml, tocameldash, iso8859_1_to_utf8, ascii_lc_isequal + """ HTTPVersion(major, minor) @@ -126,6 +129,23 @@ tocameldash(s::AbstractString) = tocameldash(String(s)) @inline isupper(b::UInt8) = UInt8('A') <= b <= UInt8('Z') @inline lower(c::UInt8) = c | 0x20 +""" + ascii_lc_isequal(a, b) + +Case insensitive ASCII string comparison. +""" +function ascii_lc_isequal(a, b) + acu = codeunits(a) + bcu = codeunits(b) + len = length(acu) + len != length(bcu) && return false + for i = 1:len + @inbounds (acu[i] in UInt8('A'):UInt8('Z') ? acu[i] + 0x20 : acu[i]) == + (bcu[i] in UInt8('A'):UInt8('Z') ? bcu[i] + 0x20 : bcu[i]) || return false + end + return true +end + function parseuri(url, query, allocator) uri_ref = Ref{aws_uri}() if url isa AbstractString @@ -142,8 +162,35 @@ function parseuri(url, query, allocator) return uri_ref[] end +""" + bytes(x) + +If `x` is "castable" to an `AbstractVector{UInt8}`, then an +`AbstractVector{UInt8}` is returned; otherwise `x` is returned. +""" +function bytes end +bytes(s::AbstractVector{UInt8}) = s +bytes(s::AbstractString) = codeunits(s) +bytes(x) = x + isbytes(x) = x isa AbstractVector{UInt8} || x isa AbstractString +""" + nbytes(x) -> Int + +Length in bytes of `x` if `x` is `isbytes(x)`. +""" +function nbytes end +nbytes(x) = nothing +nbytes(x::AbstractVector{UInt8}) = length(x) +nbytes(x::AbstractString) = sizeof(x) +nbytes(x::Vector{T}) where T <: AbstractString = sum(sizeof, x) +nbytes(x::Vector{T}) where T <: AbstractVector{UInt8} = sum(length, x) +nbytes(x::IOBuffer) = bytesavailable(x) +nbytes(x::Vector{IOBuffer}) = sum(bytesavailable, x) + +const nobytes = view(UInt8[], 1:0) + str(bc::aws_byte_cursor) = bc.ptr == C_NULL || bc.len == 0 ? "" : unsafe_string(bc.ptr, bc.len) function print_uri(io, uri::aws_uri) diff --git a/test/server.jl b/test/server.jl index defed070..7dc6143e 100644 --- a/test/server.jl +++ b/test/server.jl @@ -35,6 +35,29 @@ end end end +@testset "HTTP.streamhandler" begin + handler = req -> begin + body = req.body === nothing ? UInt8[] : req.body + if isempty(body) + return HTTP.Response(200, ["Content-Type" => "text/plain"], "ping") + end + return HTTP.Response(200, ["Content-Type" => "text/plain"], String(body)) + end + server = HTTP.listen!(HTTP.streamhandler(handler), "127.0.0.1", 0; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test resp.status == 200 + @test String(resp.body) == "ping" + + resp = HTTP.post("http://127.0.0.1:$port"; body="echo") + @test resp.status == 200 + @test String(resp.body) == "echo" + finally + close(server) + end +end + @testset "HTTP response trailers" begin server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http read(http) diff --git a/test/utils.jl b/test/utils.jl index 85d3dd70..1212840d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -30,6 +30,19 @@ @test HTTP.iso8859_1_to_utf8(bytes) == utf8 end + buf = UInt8[0x01, 0x02] + @test HTTP.bytes(buf) === buf + @test collect(HTTP.bytes("hi")) == collect(codeunits("hi")) + @test HTTP.nbytes("hi") == 2 + @test HTTP.nbytes(buf) == 2 + @test HTTP.nbytes([buf, UInt8[0x03]]) == 3 + @test HTTP.nbytes(["a", "bc"]) == 3 + @test HTTP.nbytes(IOBuffer("abc")) == 3 + @test HTTP.nobytes isa AbstractVector{UInt8} + @test isempty(HTTP.nobytes) + @test HTTP.ascii_lc_isequal("AbC", "aBc") + @test !HTTP.ascii_lc_isequal("abc", "abd") + @test_throws HTTP.AWSError HTTP.parseuri("http://example.com:abc", nothing, HTTP.default_aws_allocator()) exported = names(HTTP, all=false) @@ -37,8 +50,22 @@ @test :startread in exported @test :closewrite in exported @test :closeread in exported + @test :Stream in exported + @test :Request in exported + @test :Response in exported + @test :Message in exported + @test :Header in exported + @test :Headers in exported + @test :bytes in exported + @test :nbytes in exported + @test :nobytes in exported + @test :escapehtml in exported + @test :tocameldash in exported + @test :iso8859_1_to_utf8 in exported + @test :ascii_lc_isequal in exported @test HTTP.nobody isa Vector{UInt8} @test isempty(HTTP.nobody) + @test isdefined(HTTP, :streamhandler) @test_deprecated HTTP.escape("a b") == "a%20b" From 00023025e0ff9da013741f2e7dbac6467fadc182 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:23:25 -0700 Subject: [PATCH 25/56] feat(api): add HTTP error types and trailers --- src/HTTP.jl | 3 +- src/client/connection.jl | 18 +++++++- src/client/retry.jl | 5 +++ src/client/stream.jl | 92 +++++++++++++++++++++++++++++++++++----- src/exceptions.jl | 88 ++++++++++++++++++++++++++++++++++++++ src/requestresponse.jl | 7 ++- test/client.jl | 2 +- test/utils.jl | 23 ++++++++++ 8 files changed, 223 insertions(+), 15 deletions(-) create mode 100644 src/exceptions.jl diff --git a/src/HTTP.jl b/src/HTTP.jl index 9f36dc61..bc59b11d 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -18,7 +18,8 @@ include("access_log.jl") include("sniff.jl"); using .Sniff include("forms.jl"); using .Forms include("requestresponse.jl") -struct StatusError <: Exception +include("exceptions.jl"); using .Exceptions +struct StatusError <: HTTPError request_method::String request_uri::aws_uri response::Response diff --git a/src/client/connection.jl b/src/client/connection.jl index 65dac2f6..00063ae5 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -34,12 +34,22 @@ function c_on_ping_complete(conn, round_trip_time_ns, error_code, fut_ptr) return end +function _client_url(client::Client) + host = client.settings.host + port = client.settings.port + return string(client.settings.scheme, "://", host, ":", port) +end + function with_connection(f::Function, client::Client; context=nothing) if context === nothing fut = Future{Ptr{aws_http_connection}}() GC.@preserve fut begin aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) - connection = wait(fut) + connection = try + wait(fut) + catch e + throw(ConnectError(_client_url(client), e)) + end end try return f(connection) @@ -51,7 +61,11 @@ function with_connection(f::Function, client::Client; context=nothing) fut = Future{Ptr{aws_http_connection}}() GC.@preserve fut begin aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) - connection = wait(fut) + connection = try + wait(fut) + catch e + throw(ConnectError(_client_url(client), e)) + end end try return f(connection) diff --git a/src/client/retry.jl b/src/client/retry.jl index 80773fd6..6743d4ea 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -14,6 +14,9 @@ Base.showerror(io::IO, e::StreamError) = print(io, e.error) retryable_status(status::Integer) = status in (403, 408, 409, 429, 500, 502, 503, 504, 599) isrecoverable(ex::StatusError) = retryable_status(ex.status) +isrecoverable(ex::ConnectError) = isrecoverable(ex.error) +isrecoverable(ex::TimeoutError) = true +isrecoverable(ex::RequestError) = isrecoverable(ex.error) isrecoverable(::Union{Base.EOFError, Base.IOError}) = true isrecoverable(ex::ArgumentError) = ex.msg == "stream is closed or unusable" isrecoverable(ex::CompositeException) = all(isrecoverable, ex.exceptions) @@ -73,6 +76,8 @@ function _set_nretries!(x, nretries::Int) x.metrics.nretries = nretries elseif x isa StatusError x.response.metrics.nretries = nretries + elseif x isa RequestError + _set_nretries!(x.error, nretries) elseif x isa StreamError && x.stream !== nothing x.stream.response !== nothing && (x.stream.response.metrics.nretries = nretries) end diff --git a/src/client/stream.jl b/src/client/stream.jl index ebc51413..abb68fca 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -92,7 +92,11 @@ function c_on_complete(aws_stream_ptr, error_code, stream_ptr) close(stream.bufferstream) end if error_code != 0 - notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) + if error_code == AWS_ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT && stream.readtimeout > 0 + notify(stream.fut, TimeoutError(stream.readtimeout)) + else + notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) + end else notify(stream.fut, nothing) end @@ -142,6 +146,7 @@ mutable struct Stream{T} <: IO response_started::Bool handler_started::Bool ignore_writes::Bool + readtimeout::Int on_complete::Union{Nothing, Function} released::Bool # remaining fields are initially undefined @@ -175,6 +180,7 @@ mutable struct Stream{T} <: IO false, false, false, + 0, nothing, false, ) @@ -350,6 +356,9 @@ function _server_closewrite(s::Stream) s.final_chunk_written = true return end + if resp.trailers !== nothing + aws_http1_stream_add_chunked_trailer(s.ptr, resp.trailers.ptr) != 0 && aws_throw_error() + end writechunk(s, "") s.final_chunk_written = true return @@ -546,7 +555,15 @@ end function closeread(s::Stream) startread(s) try - wait(s.fut) + try + wait(s.fut) + catch e + e isa HTTPError && rethrow() + if !s.server_side && isdefined(s, :request) && s.request !== nothing + throw(RequestError(s.request, e)) + end + rethrow() + end finally release_stream_ptr!(s) end @@ -592,6 +609,17 @@ end function addtrailer(s::Stream, headers::Headers) s.ptr == C_NULL && error("stream is not initialized") + if s.server_side + resp = _ensure_response!(s) + if resp.trailers === nothing + resp.trailers = headers + elseif resp.trailers !== headers + for h in headers + addheader(resp.trailers, h) + end + end + return + end if s.http2 aws_http2_stream_add_trailing_headers(s.ptr, headers.ptr) != 0 && aws_throw_error() else @@ -617,6 +645,7 @@ end function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=nothing) if context === nothing stream = Stream{Nothing}(allocator, decompress, true, false) + stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end @@ -657,12 +686,22 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea while !eof(stream.bufferstream) on_stream_response_body(resp, _readavailable(stream.bufferstream)) end - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end catch e rethrow(DontRetry(e)) end else - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end if stream.bufferstream !== nothing resp.body = _readavailable(stream.bufferstream) else @@ -680,6 +719,7 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea start_time = time() stream = Stream{Nothing}(allocator, decompress, true, false) + stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end @@ -720,12 +760,22 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea while !eof(stream.bufferstream) on_stream_response_body(resp, _readavailable(stream.bufferstream)) end - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end catch e rethrow(DontRetry(e)) end else - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end if stream.bufferstream !== nothing resp.body = _readavailable(stream.bufferstream) else @@ -745,6 +795,7 @@ end function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=nothing) if context === nothing stream = Stream{Nothing}(allocator, decompress, http2, false) + stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end @@ -782,12 +833,22 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, while !eof(stream.bufferstream) on_stream_response_body(resp, _readavailable(stream.bufferstream)) end - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end catch e rethrow(DontRetry(e)) end else - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end if stream.bufferstream !== nothing resp.body = _readavailable(stream.bufferstream) else @@ -805,6 +866,7 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, start_time = time() stream = Stream{Nothing}(allocator, decompress, http2, false) + stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end @@ -842,12 +904,22 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, while !eof(stream.bufferstream) on_stream_response_body(resp, _readavailable(stream.bufferstream)) end - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end catch e rethrow(DontRetry(e)) end else - wait(stream.fut) + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end if stream.bufferstream !== nothing resp.body = _readavailable(stream.bufferstream) else diff --git a/src/exceptions.jl b/src/exceptions.jl new file mode 100644 index 00000000..c5e2e84d --- /dev/null +++ b/src/exceptions.jl @@ -0,0 +1,88 @@ +module Exceptions + +export @try, HTTPError, ConnectError, TimeoutError, RequestError, current_exceptions_to_string + +@eval begin +""" + @try PermittedErrorTypes expr + +Convenience macro for wrapping an expression in a try/catch block where thrown +exceptions are ignored if they match one of the permitted types. +""" +macro $(:try)(exes...) + errs = Any[exes...] + ex = pop!(errs) + isempty(errs) && error("no permitted errors") + quote + try $(esc(ex)) + catch e + e isa InterruptException && rethrow(e) + |($([:(e isa $(esc(err))) for err in errs]...)) || rethrow(e) + end + end +end +end + +abstract type HTTPError <: Exception end + +""" + HTTP.ConnectError + +Raised when an error occurs while trying to establish a request connection to +the remote server. The underlying error is stored in `error`. +""" +struct ConnectError <: HTTPError + url::String + error::Any +end + +function Base.showerror(io::IO, e::ConnectError) + print(io, "HTTP.ConnectError for url = `$(e.url)`: ") + Base.showerror(io, e.error) +end + +""" + HTTP.TimeoutError + +Raised when a request times out according to `readtimeout` keyword argument provided. +""" +struct TimeoutError <: HTTPError + readtimeout::Int +end + +Base.showerror(io::IO, e::TimeoutError) = + print(io, "TimeoutError: Connection closed after $(e.readtimeout) seconds") + +""" + HTTP.RequestError + +Raised when an error occurs while physically sending a request to the remote server +or reading the response back. The underlying error is stored in `error`. +""" +struct RequestError <: HTTPError + request::Any + error::Any +end + +function Base.showerror(io::IO, e::RequestError) + println(io, "HTTP.RequestError:") + println(io, "HTTP.Request:") + Base.show(io, e.request) + println(io, "Underlying error:") + Base.showerror(io, e.error) +end + +function current_exceptions_to_string() + buf = IOBuffer() + println(buf) + println(buf, "\n===========================\nHTTP Error message:\n") + exc = @static if VERSION >= v"1.8.0-" + Base.current_exceptions() + else + Base.catch_stack() + end + Base.display_error(buf, exc) + return String(take!(buf)) +end + +end # module Exceptions diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 7d59d3d7..f38f2fea 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -459,7 +459,12 @@ function print_request(io, method, version, path, headers, body) return end -getbody(r::Message) = isdefined(r, :inputstream) ? r.inputstream.bodyref : r.body +function getbody(r::Message) + if isdefined(r, :inputstream) && r.inputstream !== nothing + return r.inputstream.bodyref + end + return r.body +end print_request(io::IO, r::Request) = print_request(io, r.method, r.version, r.path, r.headers, getbody(r)) diff --git a/test/client.jl b/test/client.jl index e7604ecf..db3f2ccc 100644 --- a/test/client.jl +++ b/test/client.jl @@ -194,7 +194,7 @@ end @testset "readtimeout" begin - @test_throws CapturedException HTTP.get("http://$httpbin/delay/5"; readtimeout=1, max_retries=0) + @test_throws HTTP.TimeoutError HTTP.get("http://$httpbin/delay/5"; readtimeout=1, max_retries=0) @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) end diff --git a/test/utils.jl b/test/utils.jl index 1212840d..b381afb2 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -43,6 +43,29 @@ @test HTTP.ascii_lc_isequal("AbC", "aBc") @test !HTTP.ascii_lc_isequal("abc", "abd") + err = HTTP.ConnectError("http://example.com", ErrorException("boom")) + @test err isa HTTP.ConnectError + @test err.url == "http://example.com" + @test err.error isa ErrorException + @test HTTP.TimeoutError(5).readtimeout == 5 + req = HTTP.Request("GET", "/") + req_err = HTTP.RequestError(req, ErrorException("nope")) + @test req_err.request === req + @test req_err.error isa ErrorException + msg = "" + try + error("boom") + catch + msg = HTTP.current_exceptions_to_string() + end + @test occursin("boom", msg) + got = false + HTTP.@try ArgumentError begin + got = true + throw(ArgumentError("boom")) + end + @test got + @test_throws HTTP.AWSError HTTP.parseuri("http://example.com:abc", nothing, HTTP.default_aws_allocator()) exported = names(HTTP, all=false) From ae2ee9d8273b1e23bd7473f734142b84037f58c1 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:30:15 -0700 Subject: [PATCH 26/56] feat(client): handle iterable chunked bodies --- src/client/request.jl | 8 ++++++++ test/client.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/src/client/request.jl b/src/client/request.jl index 45adaccb..1c1aa7a9 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -38,6 +38,14 @@ function with_request( observelayers::Bool=false, context=nothing, ) + if chunkedbody === nothing && body isa IO && !(body isa IOStream) && !(body isa Form) + chunkedbody = IOChunkedBody(body) + body = nothing + end + if chunkedbody === nothing && body !== nothing && !(body isa RequestBodyTypes) && Base.isiterable(typeof(body)) + chunkedbody = body + body = nothing + end # create request mutable_headers = (headers isa AbstractVector{<:Pair} && !copyheaders) ? headers : nothing req_headers = mkreqheaders(headers, copyheaders) diff --git a/test/client.jl b/test/client.jl index db3f2ccc..032fa8a3 100644 --- a/test/client.jl +++ b/test/client.jl @@ -363,6 +363,36 @@ end end + @testset "Iterable request body streaming" begin + mutable struct ChunkedIterable + chunks::Vector{Vector{UInt8}} + iter_calls::Int + end + ChunkedIterable(chunks) = ChunkedIterable(chunks, 0) + function Base.iterate(it::ChunkedIterable, state::Int=1) + state > length(it.chunks) && return nothing + it.iter_calls += 1 + return (it.chunks[state], state + 1) + end + + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = String(read(http)) + HTTP.setstatus(http, 200) + HTTP.setheader(http, "Content-Type" => "text/plain") + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + chunks = ChunkedIterable([Vector{UInt8}("hello"), Vector{UInt8}(" "), Vector{UInt8}("world")]) + resp = HTTP.post("http://127.0.0.1:$port/"; body=chunks) + @test String(resp.body) == "hello world" + @test chunks.iter_calls == 3 + finally + close(server) + end + end + @testset "stream helpers" begin server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http body = String(read(http)) From 0bf5d04efcbf5958286f25c555f89b6c084c59c8 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:36:04 -0700 Subject: [PATCH 27/56] fix(client): fail on invalid body streams --- src/requestresponse.jl | 6 ++++++ test/client.jl | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/src/requestresponse.jl b/src/requestresponse.jl index f38f2fea..8e8482b7 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -267,19 +267,24 @@ function InputStream(allocator::Ptr{aws_allocator}, body) is.bodyref = body is.bodycursor = aws_byte_cursor(sizeof(body), pointer(body)) is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + is.ptr == C_NULL && aws_throw_error() elseif body isa Union{AbstractDict, NamedTuple} # hold a reference to the request body in order to gc-preserve it is.bodyref = URIs.escapeuri(body) is.bodycursor = aws_byte_cursor_from_c_str(is.bodyref) is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + is.ptr == C_NULL && aws_throw_error() elseif body isa IOStream + isopen(body) || throw(ArgumentError("request body IOStream is closed")) is.bodyref = body is.ptr = aws_input_stream_new_from_open_file(allocator, Libc.FILE(body)) + is.ptr == C_NULL && aws_throw_error() elseif body isa Form # we set the request.body to the Form bytes in order to gc-preserve them is.bodyref = read(body) is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + is.ptr == C_NULL && aws_throw_error() elseif body isa IO # we set the request.body to the IO bytes in order to gc-preserve them bytes = readavailable(body) @@ -289,6 +294,7 @@ function InputStream(allocator::Ptr{aws_allocator}, body) is.bodyref = bytes is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) + is.ptr == C_NULL && aws_throw_error() elseif Base.isiterable(typeof(body)) # assume a chunked request body; any kind of iterable where elements are RequestBodyTypes is.bodyref = body diff --git a/test/client.jl b/test/client.jl index 032fa8a3..6ff18663 100644 --- a/test/client.jl +++ b/test/client.jl @@ -363,6 +363,13 @@ end end + @testset "closed IOStream body errors" begin + path = tempname() + io = open(path, "w") + close(io) + @test_throws ArgumentError HTTP.request("POST", "http://example.com"; body=io, retry=false, status_exception=false) + end + @testset "Iterable request body streaming" begin mutable struct ChunkedIterable chunks::Vector{Vector{UInt8}} From bb4436252cdb8d4339ce76bf4ff6a08681b9ef1c Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:46:58 -0700 Subject: [PATCH 28/56] feat(metrics): track request body length --- src/client/stream.jl | 22 ++++++++++++++++++++++ test/client.jl | 21 +++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/src/client/stream.jl b/src/client/stream.jl index abb68fca..024a62ec 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -277,6 +277,9 @@ function writechunk(s::Stream, chunk::RequestBodyTypes) aws_http1_stream_write_chunk(s.ptr, FieldRef(s, :chunk_options)) != 0 && aws_throw_error() end wait(fut) + if !s.server_side && isdefined(s, :response) && s.response !== nothing + s.response.metrics.request_body_length += s.chunk.bodylen + end return s.chunk.bodylen end @@ -518,6 +521,21 @@ function Base.write(s::Stream, data::AbstractVector{UInt8}) return length(data) end +function Base.write(s::Stream, data::StridedVector{UInt8}) + startwrite(s) + if s.server_side + if s.ignore_writes + return length(data) + elseif s.http2 + s.responsebuf === nothing && (s.responsebuf = IOBuffer()) + write(s.responsebuf, data) + return length(data) + end + end + writechunk(s, data) + return length(data) +end + function Base.write(s::Stream, data::Union{String, SubString{String}}) startwrite(s) if s.server_side @@ -667,6 +685,7 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea stream.response = resp = Response(0, nothing, nothing, true, allocator) resp.metrics = RequestMetrics() resp.request = req + resp.metrics.request_body_length = bodylen(req) acquire_opts = aws_http2_stream_manager_acquire_stream_options( on_stream_acquired[], pointer_from_objref(acquire_fut), @@ -741,6 +760,7 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea stream.response = resp = Response(0, nothing, nothing, true, allocator) resp.metrics = RequestMetrics() resp.request = req + resp.metrics.request_body_length = bodylen(req) acquire_opts = aws_http2_stream_manager_acquire_stream_options( on_stream_acquired[], pointer_from_objref(acquire_fut), @@ -820,6 +840,7 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, stream.response = resp = Response(0, nothing, nothing, http2, allocator) resp.metrics = RequestMetrics() resp.request = req + resp.metrics.request_body_length = bodylen(req) try aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() # write chunked body if provided @@ -891,6 +912,7 @@ function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, stream.response = resp = Response(0, nothing, nothing, http2, allocator) resp.metrics = RequestMetrics() resp.request = req + resp.metrics.request_body_length = bodylen(req) try aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() # write chunked body if provided diff --git a/test/client.jl b/test/client.jl index 6ff18663..82f9e747 100644 --- a/test/client.jl +++ b/test/client.jl @@ -266,6 +266,27 @@ end end + @testset "Request metrics" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http + body = read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, body) + end + try + port = HTTP.port(server) + resp = HTTP.post("http://127.0.0.1:$port/"; body="hello") + @test resp.metrics.request_body_length == 5 + @test resp.metrics.response_body_length == 5 + + resp = HTTP.post("http://127.0.0.1:$port/"; body=IOBuffer("chunked")) + @test resp.metrics.request_body_length == 7 + @test resp.metrics.response_body_length == 7 + finally + close(server) + end + end + @testset "Request Options Parity" begin headers = ["X-Test" => "1"] HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) From 39159237ceedc1a301ca574b73d419e089146939 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:54:49 -0700 Subject: [PATCH 29/56] feat(api): export header helpers --- src/requestresponse.jl | 28 +++++++++++++++++++++++++++- test/utils.jl | 14 ++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 8e8482b7..47217c6d 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -1,4 +1,7 @@ -export Header, Headers, Message, Request, Response +export Header, Headers, Message, Request, Response, + header, headers, hasheader, headercontains, + setheader, setheaderifabsent, defaultheader!, appendheader, removeheader, + canonicalizeheaders, canonicalizeheaders!, mkheaders # working with headers headereq(a::String, b::String) = GC.@preserve a b aws_http_header_name_eq(aws_byte_cursor_from_c_str(a), aws_byte_cursor_from_c_str(b)) @@ -81,6 +84,7 @@ Base.deleteat!(h::Headers, i::Int) = aws_http_headers_erase_index(h.ptr, i - 1) Base.empty!(h::Headers) = aws_http_headers_clear(h.ptr) != 0 && aws_throw_error() setheaderifabsent(headers, k, v) = !hasheader(headers, k) && setheader(headers, k, v) +setheaderifabsent(m::Message, k, v) = setheaderifabsent(m.headers, k, v) field_name_isequal(a, b) = headereq(String(a), String(b)) @@ -149,6 +153,28 @@ function setheader(h::AbstractVector{<:Pair}, v::Pair) return h end +appendheader(m::Message, v) = appendheader(m.headers, v) +appendheader(h::Headers, v::Header) = addheader(h, v) +appendheader(h::Headers, v::Pair) = addheader(h, String(v.first), String(v.second)) +function appendheader(h::AbstractVector{<:Pair}, v::Pair) + push!(h, String(v.first) => String(v.second)) + return h +end + +removeheader(m::Message, k) = removeheader(m.headers, k) +removeheader(m::Message, k, v) = removeheader(m.headers, k, v) +function removeheader(h::AbstractVector{<:Pair}, k) + key = String(k) + filter!(kv -> !field_name_isequal(kv.first, key), h) + return h +end +function removeheader(h::AbstractVector{<:Pair}, k, v) + key = String(k) + val = String(v) + filter!(kv -> !(field_name_isequal(kv.first, key) && field_name_isequal(kv.second, val)), h) + return h +end + """ defaultheader!(::Message, key => value) diff --git a/test/utils.jl b/test/utils.jl index b381afb2..70395fb3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -79,6 +79,12 @@ @test :Message in exported @test :Header in exported @test :Headers in exported + @test :header in exported + @test :headers in exported + @test :hasheader in exported + @test :setheader in exported + @test :appendheader in exported + @test :removeheader in exported @test :bytes in exported @test :nbytes in exported @test :nobytes in exported @@ -86,6 +92,14 @@ @test :tocameldash in exported @test :iso8859_1_to_utf8 in exported @test :ascii_lc_isequal in exported + + req = HTTP.Request("GET", "/") + HTTP.setheader(req, "X-Test" => "a") + @test HTTP.header(req, "X-Test") == "a" + HTTP.appendheader(req, "X-Test" => "b") + @test HTTP.headers(req, "X-Test") == ["a", "b"] + HTTP.removeheader(req, "X-Test") + @test HTTP.header(req, "X-Test") == "" @test HTTP.nobody isa Vector{UInt8} @test isempty(HTTP.nobody) @test isdefined(HTTP, :streamhandler) From 932dfa821859dc6f3c1f9a3417154d0257c9d218 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 03:58:39 -0700 Subject: [PATCH 30/56] test(client): avoid duplicate headers --- test/client.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/client.jl b/test/client.jl index 82f9e747..901e2df3 100644 --- a/test/client.jl +++ b/test/client.jl @@ -193,6 +193,23 @@ @test length(r.body) > 0 end + @testset "Header insertion" begin + server = HTTP.serve!(req -> begin + accept_count = length(HTTP.headers(req, "accept")) + host_count = length(HTTP.headers(req, "host")) + return HTTP.Response(200, "$accept_count,$host_count") + end; listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port"; headers=["Accept" => "*/*", "Host" => "example.com"]) + @test String(resp.body) == "1,1" + resp = HTTP.get("http://127.0.0.1:$port") + @test String(resp.body) == "1,1" + finally + close(server) + end + end + @testset "readtimeout" begin @test_throws HTTP.TimeoutError HTTP.get("http://$httpbin/delay/5"; readtimeout=1, max_retries=0) @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) From 24eb4d5a7bdfee642ca6e327eb85319496291104 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 04:11:50 -0700 Subject: [PATCH 31/56] feat(client): add http2 stream manager options --- docs/src/manual/client.md | 5 +++++ src/client/client.jl | 15 ++++++++++----- test/client.jl | 20 ++++++++++++++++++++ 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index 9d6db70a..9e649988 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -103,6 +103,11 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po -- HTTP/2 options: - **http2_prior_knowledge**: Default `false`. If `true`, assume HTTP/2 without ALPN negotiation. - **http2_stream_manager**: Default `false`. If `true`, enable the HTTP/2 stream manager for multiplexed requests. + - **http2_close_connection_on_server_error**: Default `false`. If `true`, close HTTP/2 connections when a 5xx response is received. + - **http2_connection_ping_period_ms**: Default `0` (disabled). Period in milliseconds for sending HTTP/2 PING frames. + - **http2_connection_ping_timeout_ms**: Default `0` (AWS default). Timeout in milliseconds for PING responses. + - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. + - **http2_max_concurrent_streams_per_connection**: Default `0` (no explicit limit). Upper bound for streams per connection. -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. diff --git a/src/client/client.jl b/src/client/client.jl index 9c784d1e..a656073b 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -106,6 +106,11 @@ Base.@kwdef struct ClientSettings monitoring_statistics_observer::Union{Nothing, Function} = nothing http2_prior_knowledge::Bool = false http2_stream_manager::Bool = false + http2_close_connection_on_server_error::Bool = false + http2_connection_ping_period_ms::Int = 0 + http2_connection_ping_timeout_ms::Int = 0 + http2_ideal_concurrent_streams_per_connection::Int = 0 + http2_max_concurrent_streams_per_connection::Int = 0 end ClientSettings( @@ -309,11 +314,11 @@ function Client(cs::ClientSettings) client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), C_NULL, # shutdown_complete_user_data C_NULL, # shutdown_complete_callback - false, # close_connection_on_server_error - 0, # connection_ping_period_ms - 0, # connection_ping_timeout_ms - 0, # ideal_concurrent_streams_per_connection - 0, # max_concurrent_streams_per_connection + cs.http2_close_connection_on_server_error, # close_connection_on_server_error + cs.http2_connection_ping_period_ms, # connection_ping_period_ms + cs.http2_connection_ping_timeout_ms, # connection_ping_timeout_ms + cs.http2_ideal_concurrent_streams_per_connection, # ideal_concurrent_streams_per_connection + cs.http2_max_concurrent_streams_per_connection, # max_concurrent_streams_per_connection cs.max_connections, ) client.http2_stream_manager_opts = opts diff --git a/test/client.jl b/test/client.jl index 901e2df3..63758766 100644 --- a/test/client.jl +++ b/test/client.jl @@ -504,6 +504,26 @@ finalize(client) end + @testset "HTTP/2 stream manager options" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_close_connection_on_server_error=true, + http2_connection_ping_period_ms=1234, + http2_connection_ping_timeout_ms=2345, + http2_ideal_concurrent_streams_per_connection=7, + http2_max_concurrent_streams_per_connection=9, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.close_connection_on_server_error == true + @test opts.connection_ping_period_ms == Csize_t(1234) + @test opts.connection_ping_timeout_ms == Csize_t(2345) + @test opts.ideal_concurrent_streams_per_connection == Csize_t(7) + @test opts.max_concurrent_streams_per_connection == Csize_t(9) + finalize(client) + end + @testset "HTTP manager metrics" begin client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443))) metrics = HTTP.manager_metrics(client) From b2085d20e9288a864e81bd905479cf01904a3f67 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 04:14:09 -0700 Subject: [PATCH 32/56] docs(migrate): update v2 migration notes --- docs/src/manual/migrate.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index ef9008e4..9a8c3e0f 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -43,7 +43,8 @@ header_value = HTTP.header(r, "Content-Type") Key differences: - Headers access works directly on `Request`/`Response`, or on the `headers` field if you already have it -- The `.body` field can now be any type, not just `Vector{UInt8}` +- Request/response *body arguments* accept strings, byte vectors, Dict/NamedTuple (form-encoded), `HTTP.Form`, or IO +- `Response.body` remains a `Vector{UInt8}` for buffered responses (or `nothing` when streaming into `response_body`) - Context dictionary access is now through `.context` rather than request-specific fields ### Making Requests @@ -338,6 +339,7 @@ cookies = HTTP.getcookies(jar, "example.com") - **TLS Implementation**: OpenSSL is now the default TLS provider instead of MbedTLS - **Multithreading**: Improved thread safety throughout the codebase - **Performance**: Significant performance improvements, especially for high-throughput servers +- **Parser APIs**: Low-level parser APIs from v1.x have been removed in v2.0 - **Trailing Headers**: Trailing headers are now captured on `Request.trailers` and `Response.trailers`, and can be sent with `HTTP.addtrailer` when streaming. - **HTTP/2 Controls**: HTTP/2 helpers (ping, settings, GOAWAY) are available for advanced connection management. - **Metrics**: Responses include `response.metrics` and clients expose `HTTP.manager_metrics` for connection manager stats. From 200d94dfbfa67f50a61045e74367a7795e6751c1 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 04:25:37 -0700 Subject: [PATCH 33/56] feat(server): stream http2 responses --- src/client/stream.jl | 26 +++++++++++++++----------- test/server.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/client/stream.jl b/src/client/stream.jl index 024a62ec..ee3f09d9 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -310,6 +310,9 @@ function _server_startwrite(s::Stream) setinputstream!(resp, nothing) end if s.http2 + if !s.response_started + _send_response!(s) + end s.write_started = true return end @@ -333,12 +336,17 @@ function _server_closewrite(s::Stream) if !s.response_started if s.ignore_writes setinputstream!(resp, nothing) - else - body = s.responsebuf === nothing ? UInt8[] : take!(s.responsebuf) - setinputstream!(resp, body) end _send_response!(s) end + if s.ignore_writes + s.final_chunk_written = true + return + end + if resp.trailers !== nothing + aws_http2_stream_add_trailing_headers(s.ptr, resp.trailers.ptr) != 0 && aws_throw_error() + end + writechunk(s, "") s.final_chunk_written = true return end @@ -512,8 +520,7 @@ function Base.write(s::Stream, data::AbstractVector{UInt8}) if s.ignore_writes return length(data) elseif s.http2 - s.responsebuf === nothing && (s.responsebuf = IOBuffer()) - write(s.responsebuf, data) + writechunk(s, data) return length(data) end end @@ -527,8 +534,7 @@ function Base.write(s::Stream, data::StridedVector{UInt8}) if s.ignore_writes return length(data) elseif s.http2 - s.responsebuf === nothing && (s.responsebuf = IOBuffer()) - write(s.responsebuf, data) + writechunk(s, data) return length(data) end end @@ -542,8 +548,7 @@ function Base.write(s::Stream, data::Union{String, SubString{String}}) if s.ignore_writes return sizeof(data) elseif s.http2 - s.responsebuf === nothing && (s.responsebuf = IOBuffer()) - write(s.responsebuf, data) + writechunk(s, data) return sizeof(data) end end @@ -561,8 +566,7 @@ function Base.write(s::Stream, b::UInt8) if s.ignore_writes return 1 elseif s.http2 - s.responsebuf === nothing && (s.responsebuf = IOBuffer()) - write(s.responsebuf, b) + writechunk(s, UInt8[b]) return 1 end end diff --git a/test/server.jl b/test/server.jl index 7dc6143e..73aa7284 100644 --- a/test/server.jl +++ b/test/server.jl @@ -82,6 +82,38 @@ end end end +@testset "HTTP/2 stream handler writes" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + saw_http2 = Threads.Atomic{Bool}(false) + buffered = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + stream.http2 && (saw_http2[] = true) + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, "hello") + if stream.http2 && stream.responsebuf !== nothing + buffered[] = true + end + HTTP.closewrite(stream) + end + try + port = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test saw_http2[] + @test !buffered[] + @test String(resp.body) == "hello" + else + @info "HTTP/2 not negotiated for stream handler test" + @test true + end + finally + close(server) + end +end + @testset "HTTP/2 server push promise" begin cert = joinpath(@__DIR__, "fixtures", "http2.crt") key = joinpath(@__DIR__, "fixtures", "http2.key") From b445bdd2b865ce3b0723d1e979eabbe6e8477668 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 08:19:03 -0700 Subject: [PATCH 34/56] feat(client): add proxy basic auth --- docs/src/manual/client.md | 3 +++ src/client/client.jl | 40 ++++++++++++++++++++++++++++++++++----- src/client/makerequest.jl | 9 ++++++++- test/client.jl | 28 +++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 6 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index 9e649988..db2ba8c8 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -81,6 +81,9 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **proxy_ssl_cacert**: The path to the CA certificate file for the proxy. - **proxy_ssl_insecure**: Default `false`. If `true`, SSL certificate verification will be disabled for the proxy. - **proxy_ssl_alpn_list**: A list of ALPN protocols to use for the proxy connection. Default is `"h2;http/1.1"`. + - **proxy_auth**: Optional. Set to `:basic` to enable basic auth on proxy requests. + - **proxy_username**: Username for proxy basic auth (requires explicit `proxy_host`/`proxy_port`). + - **proxy_password**: Password for proxy basic auth (requires explicit `proxy_host`/`proxy_port`). -- Retry options: - **max_retries**: The maximum number of times to retry a request. Default is 4. - **retry_partition**: Requests utilizing the same retry partition (an arbitrary string) will coordinate retries against each other to not overwhelm a temporarily unresponsive server. diff --git a/src/client/client.jl b/src/client/client.jl index a656073b..18d1b9ea 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -89,6 +89,9 @@ Base.@kwdef struct ClientSettings proxy_ssl_cacert::Union{Nothing, String} = nothing proxy_ssl_insecure::Bool = false proxy_ssl_alpn_list::String = "h2;http/1.1" + proxy_auth::Union{Nothing, Symbol} = nothing + proxy_username::Union{Nothing, String} = nothing + proxy_password::Union{Nothing, String} = nothing retry_partition::Union{Nothing, String} = nothing max_retries::Int = DEFAULT_MAX_RETRIES backoff_scale_factor_ms::Int = 25 @@ -151,6 +154,7 @@ mutable struct Client # only 1 of proxy_options or proxy_env_settings is set proxy_options::Union{Nothing, aws_http_proxy_options} proxy_env_settings::Union{Nothing, proxy_env_var_settings} + proxy_strategy::Ptr{aws_http_proxy_strategy} monitoring_options::Union{Nothing, aws_http_connection_monitoring_options} monitoring_observer::Union{Nothing, Any} retry_options::aws_standard_retry_options @@ -196,9 +200,29 @@ function Client(cs::ClientSettings) # proxy options client.proxy_options = nothing client.proxy_env_settings = nothing + client.proxy_strategy = C_NULL + proxy_connection_type = cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL if cs.proxy_host !== nothing && cs.proxy_port !== nothing + proxy_auth = cs.proxy_auth + if proxy_auth === nothing && (cs.proxy_username !== nothing || cs.proxy_password !== nothing) + proxy_auth = :basic + end + if proxy_auth !== nothing + proxy_auth == :basic || throw(ArgumentError("unsupported proxy_auth: $proxy_auth")) + cs.proxy_username === nothing && throw(ArgumentError("proxy_username required for basic proxy auth")) + cs.proxy_password === nothing && throw(ArgumentError("proxy_password required for basic proxy auth")) + auth_opts = aws_http_proxy_strategy_basic_auth_options( + proxy_connection_type, + aws_byte_cursor_from_c_str(cs.proxy_username), + aws_byte_cursor_from_c_str(cs.proxy_password), + ) + GC.@preserve cs begin + client.proxy_strategy = aws_http_proxy_strategy_new_basic_auth(cs.allocator, Ref(auth_opts)) + end + client.proxy_strategy == C_NULL && aws_throw_error() + end client.proxy_options = aws_http_proxy_options( - cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL, + proxy_connection_type, aws_byte_cursor_from_c_str(cs.proxy_host), cs.proxy_port % UInt32, cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; @@ -209,16 +233,18 @@ function Client(cs::ClientSettings) cs.proxy_ssl_insecure, cs.proxy_ssl_alpn_list ), - #TODO: support proxy_strategy - C_NULL, # proxy_strategy::Ptr{aws_http_proxy_strategy} - 0, # auth_type::aws_http_proxy_authentication_type + client.proxy_strategy, # proxy_strategy::Ptr{aws_http_proxy_strategy} + AWS_HPAT_NONE, # auth_type::aws_http_proxy_authentication_type aws_byte_cursor_from_c_str(""), # auth_username::aws_byte_cursor aws_byte_cursor_from_c_str(""), # auth_password::aws_byte_cursor ) elseif cs.proxy_allow_env_var + if cs.proxy_auth !== nothing || cs.proxy_username !== nothing || cs.proxy_password !== nothing + throw(ArgumentError("proxy auth requires explicit proxy_host/proxy_port")) + end client.proxy_env_settings = proxy_env_var_settings( AWS_HPEV_ENABLE, - cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL, + proxy_connection_type, cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; cs.proxy_ssl_cert, cs.proxy_ssl_key, @@ -335,6 +361,10 @@ function Client(cs::ClientSettings) aws_http2_stream_manager_release(x.http2_stream_manager) x.http2_stream_manager = C_NULL end + if x.proxy_strategy != C_NULL + aws_http_proxy_strategy_release(x.proxy_strategy) + x.proxy_strategy = C_NULL + end if x.retry_strategy != C_NULL aws_retry_strategy_release(x.retry_strategy) x.retry_strategy = C_NULL diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 2480ba62..37a2fcec 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -22,7 +22,14 @@ function proxy_kwargs(proxy, req_scheme) isempty(p.host) && throw(ArgumentError("proxy URL must include a host")) port = isempty(p.port) ? (p.scheme == "https" ? 443 : 80) : parse(Int, p.port) conn_type = req_scheme in ("https", "wss") ? :tunnel : :forward - return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type) + if isempty(p.userinfo) + return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type) + end + parts = split(p.userinfo, ":"; limit=2) + proxy_user = unescapeuri(parts[1]) + proxy_pass = length(parts) == 2 ? unescapeuri(parts[2]) : "" + return (proxy_allow_env_var=false, proxy_host=p.host, proxy_port=UInt32(port), proxy_connection_type=conn_type, + proxy_auth=:basic, proxy_username=proxy_user, proxy_password=proxy_pass) else throw(ArgumentError("proxy must be a URL String, URI, nothing, or false")) end diff --git a/test/client.jl b/test/client.jl index 63758766..f0efcb69 100644 --- a/test/client.jl +++ b/test/client.jl @@ -579,6 +579,34 @@ finalize(client) end + @testset "Proxy basic auth strategy" begin + opts = HTTP.proxy_kwargs("http://user:pass@proxy.local:3128", "http") + @test opts.proxy_auth == :basic + @test opts.proxy_username == "user" + @test opts.proxy_password == "pass" + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + proxy_host="proxy.local", + proxy_port=UInt32(3128), + proxy_connection_type=:forward, + proxy_auth=:basic, + proxy_username="user", + proxy_password="pass", + ) + client = HTTP.Client(cs) + @test client.proxy_options !== nothing + @test client.proxy_strategy != C_NULL + @test client.proxy_options.proxy_strategy == client.proxy_strategy + finalize(client) + + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + proxy_host="proxy.local", + proxy_port=UInt32(3128), + proxy_auth=:basic, + proxy_username="user", + )) + end + @testset "HTTP/2 control APIs" begin resp = HTTP.get("https://$httpbin/ip") if resp.version == HTTP.HTTPVersion(2, 0) From 72afcb9ab51fddea20ed5ceaeed95454881a8ea6 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 08:36:56 -0700 Subject: [PATCH 35/56] feat(access-log): parse basic auth remote_user --- src/access_log.jl | 27 +++++++++++++++++++++++++-- test/server.jl | 7 +++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/access_log.jl b/src/access_log.jl index bbef3b9c..6f6e4415 100644 --- a/src/access_log.jl +++ b/src/access_log.jl @@ -68,7 +68,7 @@ function symbol_mapping(s::Symbol) elseif s === :remote_port :(HTTP.remote_port(http.connection)) elseif s === :remote_user - :("-") # TODO: find from Basic auth... + :(HTTP._remote_user(http)) elseif s === :time_iso8601 if !Sys.iswindows() :(Libc.strftime("%FT%T%z", time())) @@ -105,6 +105,30 @@ function symbol_mapping(s::Symbol) end end +function _remote_user(http) + auth = HTTP.getheader(http.request.headers, "authorization") + auth === nothing && return "-" + parts = split(auth; limit=2) + length(parts) < 2 && return "-" + ascii_lc_isequal(parts[1], "basic") || return "-" + token = parts[2] + isempty(token) && return "-" + decoded = try + Base64.base64decode(token) + catch + return "-" + end + userpass = try + String(decoded) + catch + return "-" + end + colon = findfirst(==(':'), userpass) + user = colon === nothing ? userpass : userpass[1:prevind(userpass, colon)] + isempty(user) && return "-" + return user +end + """ common_logfmt(io::IO, http::HTTP.Stream) @@ -118,4 +142,3 @@ const common_logfmt = logfmt"$remote_addr - $remote_user [$time_local] \"$reques Format a log message in the Combined Log Format and write to `io`. """ const combined_logfmt = logfmt"$remote_addr - $remote_user [$time_local] \"$request\" $status $body_bytes_sent \"$http_referer\" \"$http_user_agent\"" - diff --git a/test/server.jl b/test/server.jl index 73aa7284..e11c3a9c 100644 --- a/test/server.jl +++ b/test/server.jl @@ -1,4 +1,4 @@ -using Test, HTTP, Logging +using Test, HTTP, Logging, Base64 import Sockets @testset "HTTP.serve" begin @@ -224,11 +224,14 @@ end HTTP.get("http://127.0.0.1:$port/index.html") HTTP.get("http://127.0.0.1:$port/index.html?a=b") HTTP.head("http://127.0.0.1:$port") + auth = Base64.base64encode("alice:secret") + HTTP.get("http://127.0.0.1:$port/auth", ["Authorization" => "Basic $auth"]) end - @test length(logs) == 4 + @test length(logs) == 5 @test all(x -> x.group === :access, logs) @test occursin(r"^application/json text/plain GET / HTTP/1\.1 GET / 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[1].message) @test occursin(r"^\*/\* text/plain GET /index\.html HTTP/1\.1 GET /index\.html 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[2].message) @test occursin(r"^\*/\* text/plain GET /index\.html\?a=b HTTP/1\.1 GET /index\.html\?a=b 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[3].message) @test occursin(r"^\*/\* text/plain HEAD / HTTP/1\.1 HEAD / 127\.0\.0\.1 \d+ - HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 0$", logs[4].message) + @test occursin(r"^\*/\* text/plain GET /auth HTTP/1\.1 GET /auth 127\.0\.0\.1 \d+ alice HTTP/1\.1 \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}.* \d+/.*/\d{4}:\d{2}:\d{2}:\d{2}.* 200 12$", logs[5].message) end From 9617630d21af6199a321605b9996c3447a161b08 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 10:25:35 -0700 Subject: [PATCH 36/56] feat(websockets): enforce frame and fragment limits --- src/websockets.jl | 33 +++++++++++++++++++++++++++++++++ test/websockets_basic.jl | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/src/websockets.jl b/src/websockets.jl index d9a28f6d..5a95eef5 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -55,6 +55,8 @@ mutable struct WebSocket incoming_payload::Vector{UInt8} fragment_opcode::Union{Nothing, UInt8} fragment_payload::Vector{UInt8} + fragment_count::Int + drop_incoming::Bool closebody::Union{Nothing, CloseFrameBody} WebSocket(host::AbstractString, path::AbstractString; maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = new( @@ -78,6 +80,8 @@ mutable struct WebSocket UInt8[], nothing, UInt8[], + 0, + false, nothing, ) end @@ -181,6 +185,14 @@ function c_on_incoming_frame_begin(websocket::Ptr{aws_websocket}, frame::Ptr{aws ws.incoming_opcode = fr.opcode ws.incoming_fin = fr.fin empty!(ws.incoming_payload) + ws.drop_incoming = false + if fr.payload_length > ws.maxframesize + close_body = CloseFrameBody(1009, "frame too large") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + ws.drop_incoming = true + return true + end fr.payload_length > 0 && sizehint!(ws.incoming_payload, Int(fr.payload_length)) return true end @@ -189,6 +201,7 @@ const on_incoming_frame_payload = Ref{Ptr{Cvoid}}(C_NULL) function c_on_incoming_frame_payload(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, data::aws_byte_cursor, ws_ptr) ws = unsafe_pointer_to_objref(ws_ptr) + ws.drop_incoming && return true try n = Int(data.len) n == 0 && return true @@ -213,6 +226,10 @@ function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{ Threads.@spawn close(ws, close_body) return true end + if ws.drop_incoming + ws.drop_incoming = false + return true + end fr = unsafe_load(frame) opcode = fr.opcode fin = fr.fin @@ -255,12 +272,20 @@ function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{ Threads.@spawn close(ws, close_body) return true end + ws.fragment_count += 1 + if ws.fragment_count > ws.maxfragmentation + close_body = CloseFrameBody(1009, "message too large") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true + end append!(ws.fragment_payload, payload) if fin msg_opcode = ws.fragment_opcode data = ws.fragment_payload ws.fragment_opcode = nothing ws.fragment_payload = UInt8[] + ws.fragment_count = 0 if msg_opcode == UInt8(TEXT) _enqueue_message!(ws, String(copy(data))) else @@ -282,9 +307,17 @@ function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{ else _enqueue_message!(ws, copy(payload)) end + ws.fragment_count = 0 else ws.fragment_opcode = opcode ws.fragment_payload = copy(payload) + ws.fragment_count = 1 + if ws.fragment_count > ws.maxfragmentation + close_body = CloseFrameBody(1009, "message too large") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true + end end end return true diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl index 3ac81ee6..e7db8057 100644 --- a/test/websockets_basic.jl +++ b/test/websockets_basic.jl @@ -97,3 +97,41 @@ end close(server) end end + +@testset "WebSockets max frame size" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, "0123456789") + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; maxframesize=5, suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1009 +end + +@testset "WebSockets max fragmentation" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + send(ws, ["a", "b", "c"]) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; maxfragmentation=2, suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1009 +end From 161c9e047088f0542d345f0527b08cd725ffc9d2 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 10:34:00 -0700 Subject: [PATCH 37/56] feat(retry): integrate aws retry strategy partition --- src/client/retry.jl | 86 ++++++++++++++++++++++++++++++++++++++++++--- test/client.jl | 6 ++++ 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/src/client/retry.jl b/src/client/retry.jl index 6743d4ea..e5f312b4 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -71,6 +71,22 @@ function _normalize_retry_delays(retry_delays, max_retries::Int) end end +function _retry_error_type(err) + if err isa StatusError + status = err.status + if status == 429 + return AWS_RETRY_ERROR_TYPE_THROTTLING + elseif 500 <= status < 600 + return AWS_RETRY_ERROR_TYPE_SERVER_ERROR + elseif 400 <= status < 500 + return AWS_RETRY_ERROR_TYPE_CLIENT_ERROR + else + return AWS_RETRY_ERROR_TYPE_TRANSIENT + end + end + return AWS_RETRY_ERROR_TYPE_TRANSIENT +end + function _set_nretries!(x, nretries::Int) if x isa Response x.metrics.nretries = nretries @@ -98,6 +114,17 @@ function with_retry_token( req_ref=nothing, context=nothing, ) + retry_token = Ptr{aws_retry_token}(C_NULL) + partition = client.settings.retry_partition + partition_ref = Ref{aws_byte_cursor}() + partition_ptr = C_NULL + if partition !== nothing + GC.@preserve partition begin + partition_ref[] = aws_byte_cursor_from_c_str(partition) + end + partition_ptr = partition_ref + end + use_retry_strategy = retry_delays === nothing && client.retry_strategy != C_NULL # If max_retries is 0, we don't need to bother with any retrying max_retries = client.settings.max_retries if max_retries == 0 @@ -135,6 +162,11 @@ function with_retry_token( ret = f() context === nothing || _record_layer!(context, :retrylayer, attempt_start) _set_nretries!(ret, nretries) + if retry_token != C_NULL + aws_retry_token_record_success(retry_token) != 0 && aws_throw_error() + aws_retry_token_release(retry_token) + retry_token = C_NULL + end return ret catch e context === nothing || _record_layer!(context, :retrylayer, attempt_start) @@ -156,12 +188,26 @@ function with_retry_token( end err = err.error _set_nretries!(err, nretries) + if retry_token != C_NULL + aws_retry_token_release(retry_token) + retry_token = C_NULL + end throw(err) end - nretries >= max_retries && (_set_nretries!(err, nretries); throw(err)) - delay_iter = delay_state === nothing ? iterate(delays) : iterate(delays, delay_state) - delay_iter === nothing && (_set_nretries!(err, nretries); throw(err)) - delay, delay_state = delay_iter + if nretries >= max_retries + _set_nretries!(err, nretries) + if retry_token != C_NULL + aws_retry_token_release(retry_token) + retry_token = C_NULL + end + throw(err) + end + delay = 0.0 + if !use_retry_strategy + delay_iter = delay_state === nothing ? iterate(delays) : iterate(delays, delay_state) + delay_iter === nothing && (_set_nretries!(err, nretries); throw(err)) + delay, delay_state = delay_iter + end req = req_ref === nothing ? nothing : req_ref[] resp = err isa StatusError ? err.response : nothing resp_body = resp === nothing ? nothing : resp.body @@ -171,8 +217,40 @@ function with_retry_token( end if !retry _set_nretries!(err, nretries) + if retry_token != C_NULL + aws_retry_token_release(retry_token) + retry_token = C_NULL + end throw(err) end + if use_retry_strategy + try + if retry_token == C_NULL + fut = Future{Ptr{aws_retry_token}}() + GC.@preserve fut begin + rc = aws_retry_strategy_acquire_retry_token(client.retry_strategy, partition_ptr, on_acquired[], pointer_from_objref(fut), UInt64(client.settings.retry_timeout_ms)) + rc != 0 && aws_throw_error() + retry_token = wait(fut) + end + end + fut = Future{Ptr{aws_retry_token}}() + error_type = _retry_error_type(err) + GC.@preserve fut begin + rc = aws_retry_strategy_schedule_retry(retry_token, error_type, retry_ready[], pointer_from_objref(fut)) + rc != 0 && aws_throw_error() + retry_token = wait(fut) + end + catch + if retry_token != C_NULL + aws_retry_token_release(retry_token) + retry_token = C_NULL + end + _set_nretries!(err, nretries) + throw(err) + end + nretries += 1 + continue + end nretries += 1 sleep(delay) end diff --git a/test/client.jl b/test/client.jl index f0efcb69..a1de58a7 100644 --- a/test/client.jl +++ b/test/client.jl @@ -278,6 +278,12 @@ end @test err isa HTTP.StatusError @test attempts[] == 2 + + reset_attempts!(1) + resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_delays=[0.0], retry_partition="test") + @test resp.status == 200 + @test resp.metrics.nretries == 1 + @test attempts[] == 2 finally close(server) end From 57598b521929f270e0e3fe2c886720a028ddbac6 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 10:40:07 -0700 Subject: [PATCH 38/56] feat(access-log): track streamed response bytes --- src/client/stream.jl | 8 ++++++-- src/requestresponse.jl | 7 +++++++ test/server.jl | 22 ++++++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/src/client/stream.jl b/src/client/stream.jl index ee3f09d9..f2d2ab76 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -277,8 +277,12 @@ function writechunk(s::Stream, chunk::RequestBodyTypes) aws_http1_stream_write_chunk(s.ptr, FieldRef(s, :chunk_options)) != 0 && aws_throw_error() end wait(fut) - if !s.server_side && isdefined(s, :response) && s.response !== nothing - s.response.metrics.request_body_length += s.chunk.bodylen + if isdefined(s, :response) && s.response !== nothing + if s.server_side + s.response.metrics.response_body_length += s.chunk.bodylen + else + s.response.metrics.request_body_length += s.chunk.bodylen + end end return s.chunk.bodylen end diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 47217c6d..ffbd2c8f 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -568,6 +568,13 @@ getresponse(r::Response) = r bodylen(m::Message) = isdefined(m, :inputstream) && m.inputstream !== nothing ? m.inputstream.bodylen : 0 +function bodylen(r::Response) + if isdefined(r, :inputstream) && r.inputstream !== nothing + return r.inputstream.bodylen + end + return r.metrics.response_body_length +end + function Base.getproperty(x::Response, s::Symbol) if s == :status ref = Ref{Cint}() diff --git a/test/server.jl b/test/server.jl index e11c3a9c..8dbd6022 100644 --- a/test/server.jl +++ b/test/server.jl @@ -13,6 +13,28 @@ import Sockets end end +@testset "access logging stream handler" begin + logger = Test.TestLogger() + with_logger(logger) do + server = HTTP.listen!("127.0.0.1", 0; listenany=true, access_log=common_logfmt) do http + read(http) + HTTP.setstatus(http, 200) + HTTP.startwrite(http) + write(http, "hello") + end + port = HTTP.port(server) + try + HTTP.post("http://127.0.0.1:$port"; body="x") + sleep(1) + finally + close(server) + end + end + logs = filter!(x -> x.group == :access, logger.logs) + @test length(logs) == 1 + @test occursin(r" 200 5$", logs[1].message) +end + @testset "HTTP.listen stream handler" begin server = HTTP.listen!("127.0.0.1", 0; listenany=true) do http body = String(read(http)) From 6726a44dff05a9e8d94c4b0daee15bc9b58be780 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 16:14:20 -0700 Subject: [PATCH 39/56] fix(retry): use aws strategy only with partition --- src/client/retry.jl | 2 +- test/client.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/client/retry.jl b/src/client/retry.jl index e5f312b4..83df096a 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -124,7 +124,7 @@ function with_retry_token( end partition_ptr = partition_ref end - use_retry_strategy = retry_delays === nothing && client.retry_strategy != C_NULL + use_retry_strategy = retry_delays === nothing && partition !== nothing && client.retry_strategy != C_NULL # If max_retries is 0, we don't need to bother with any retrying max_retries = client.settings.max_retries if max_retries == 0 diff --git a/test/client.jl b/test/client.jl index a1de58a7..1ac37c31 100644 --- a/test/client.jl +++ b/test/client.jl @@ -280,7 +280,7 @@ @test attempts[] == 2 reset_attempts!(1) - resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_delays=[0.0], retry_partition="test") + resp = HTTP.get("http://127.0.0.1:$port/"; retries=1, retry_partition="test") @test resp.status == 200 @test resp.metrics.nretries == 1 @test attempts[] == 2 @@ -627,7 +627,7 @@ @test HTTP.http2_get_received_goaway(io) === nothing end else - @test_skip "HTTP/2 not available for $httpbin" + @info "HTTP/2 not available for $httpbin" end end From 10d9b861ee168e97e2702fe4e437a68f7cd08245 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 16:32:00 -0700 Subject: [PATCH 40/56] feat(http2): add initial settings options --- src/client/client.jl | 34 +++++++++++++++++++++++++++++----- src/client/connection.jl | 3 ++- test/client.jl | 29 +++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/src/client/client.jl b/src/client/client.jl index 18d1b9ea..5b775161 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -114,6 +114,7 @@ Base.@kwdef struct ClientSettings http2_connection_ping_timeout_ms::Int = 0 http2_ideal_concurrent_streams_per_connection::Int = 0 http2_max_concurrent_streams_per_connection::Int = 0 + http2_initial_settings::Union{Nothing, AbstractVector} = nothing end ClientSettings( @@ -128,7 +129,14 @@ ClientSettings( max_retries::Integer=DEFAULT_MAX_RETRIES, require_ssl_verification::Bool=true, ssl_insecure::Bool=false, - kw...) = + kw...) = begin + http2_initial_settings = Base.get(() -> nothing, kw, :http2_initial_settings) + if http2_initial_settings !== nothing && !(http2_initial_settings isa AbstractVector) + throw(ArgumentError("http2_initial_settings must be a vector of pairs or aws_http2_setting")) + end + if haskey(kw, :http2_initial_settings) + kw = Base.structdiff((; kw...), (; http2_initial_settings=nothing)) + end ClientSettings(; scheme=String(scheme), host=String(host), @@ -136,7 +144,9 @@ ClientSettings( connect_timeout_ms=(connect_timeout !== nothing ? connect_timeout * 1000 : connect_timeout_ms), max_retries=(retry ? (retries != DEFAULT_MAX_RETRIES ? retries : max_retries) : 0), ssl_insecure=(!require_ssl_verification || ssl_insecure), + http2_initial_settings=http2_initial_settings, kw...) +end # make a new ClientSettings object from an existing one w/ just different url values @generated function ClientSettings(cs::ClientSettings, scheme::AbstractString, host::AbstractString, port::Integer) @@ -163,6 +173,7 @@ mutable struct Client connection_manager::Ptr{aws_http_connection_manager} http2_stream_manager_opts::Union{Nothing, aws_http2_stream_manager_options} http2_stream_manager::Ptr{aws_http2_stream_manager} + http2_initial_settings::Union{Nothing, Vector{aws_http2_setting}} Client() = new() end @@ -291,6 +302,19 @@ function Client(cs::ClientSettings) ) client.retry_strategy = aws_retry_strategy_new_standard(cs.allocator, FieldRef(client, :retry_options)) client.retry_strategy == C_NULL && aws_throw_error() + settings_input = cs.http2_initial_settings + if settings_input === nothing + client.http2_initial_settings = nothing + elseif settings_input isa AbstractVector{aws_http2_setting} + client.http2_initial_settings = collect(settings_input) + elseif settings_input isa AbstractVector{<:Pair} + client.http2_initial_settings = _settings_from_pairs(settings_input) + else + throw(ArgumentError("http2_initial_settings must be a vector of pairs or aws_http2_setting")) + end + settings_ptr = client.http2_initial_settings === nothing ? C_NULL : pointer(client.http2_initial_settings) + settings_len = client.http2_initial_settings === nothing ? 0 : length(client.http2_initial_settings) + client.conn_manager_opts = aws_http_connection_manager_options( cs.bootstrap, typemax(Csize_t), # initial_window_size::Csize_t @@ -301,8 +325,8 @@ function Client(cs::ClientSettings) monitoring_ptr, # monitoring_options::Ptr{aws_http_connection_monitoring_options} aws_byte_cursor_from_c_str(cs.host), cs.port % UInt32, - C_NULL, # initial_settings_array::Ptr{aws_http2_setting} - 0, # num_initial_settings::Csize_t + settings_ptr, # initial_settings_array::Ptr{aws_http2_setting} + settings_len, # num_initial_settings::Csize_t 0, # max_closed_streams::Csize_t false, # http2_conn_manual_window_management::Bool client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), # proxy_options::Ptr{aws_http_proxy_options} @@ -329,8 +353,8 @@ function Client(cs::ClientSettings) cs.http2_prior_knowledge, aws_byte_cursor_from_c_str(cs.host), cs.port % UInt32, - C_NULL, # initial_settings_array - 0, # num_initial_settings + settings_ptr, # initial_settings_array + settings_len, # num_initial_settings 0, # max_closed_streams false, # conn_manual_window_management cs.enable_read_back_pressure, diff --git a/src/client/connection.jl b/src/client/connection.jl index 00063ae5..b0c82183 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -112,7 +112,8 @@ http2_ping(client::Client; data=nothing) = _with_http2_connection(conn -> http2_ function _settings_from_pairs(settings::AbstractVector{<:Pair}) out = Vector{aws_http2_setting}(undef, length(settings)) for (i, (k, v)) in enumerate(settings) - out[i] = aws_http2_setting(aws_http2_settings_id(k), UInt32(v)) + id = k isa aws_http2_settings_id ? k : aws_http2_settings_id(k) + out[i] = aws_http2_setting(id, UInt32(v)) end return out end diff --git a/test/client.jl b/test/client.jl index 1ac37c31..06781661 100644 --- a/test/client.jl +++ b/test/client.jl @@ -530,6 +530,35 @@ finalize(client) end + @testset "HTTP/2 initial settings options" begin + settings = [ + HTTP.AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS => 10, + HTTP.AWS_HTTP2_SETTINGS_INITIAL_WINDOW_SIZE => 65535, + ] + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_initial_settings=settings) + client = HTTP.Client(cs) + @test client.http2_initial_settings !== nothing + @test length(client.http2_initial_settings) == 2 + @test client.conn_manager_opts.num_initial_settings == Csize_t(2) + @test client.conn_manager_opts.initial_settings_array != C_NULL + finalize(client) + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_initial_settings=settings, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.num_initial_settings == Csize_t(2) + @test opts.initial_settings_array != C_NULL + finalize(client) + + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_settings=1, + )) + end + @testset "HTTP manager metrics" begin client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443))) metrics = HTTP.manager_metrics(client) From d9a0b8a1bc39ce41a4d674d952ba20e3e162615e Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 18:52:27 -0700 Subject: [PATCH 41/56] docs(migrate): align with current v2 behavior --- docs/src/manual/migrate.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/src/manual/migrate.md b/docs/src/manual/migrate.md index 9a8c3e0f..acffd9bd 100644 --- a/docs/src/manual/migrate.md +++ b/docs/src/manual/migrate.md @@ -53,7 +53,8 @@ While the basic request syntax remains similar, there are some changes to keywor #### Changes to Keyword Arguments -- The `response_stream` keyword argument is still supported, but HTTP.jl no longer automatically closes this stream when done - you need to handle this yourself +- The preferred keyword for streaming responses is `response_body`. `response_stream` is still supported for compatibility +- Response streams are not automatically closed; you need to handle this yourself - `retry` behavior has been overhauled with more consistent rules for what is retryable - The default `max_retries` is now 4 (was 10 in v1.x) - Some connection-related options have new defaults (e.g., TLS is now OpenSSL-based by default rather than MbedTLS) @@ -285,7 +286,8 @@ end close(server) ``` -`listen!` and `serve!` are both supported in v2.0; `serve!` is an alias that matches the HTTP server naming. +`serve`/`serve!` are the primary request/response handlers. `listen`/`listen!` are stream-handler shorthands +equivalent to `serve(...; stream=true)`. ## Error Handling @@ -306,7 +308,7 @@ catch e if e isa HTTP.ConnectError println("Connection failed: $(e.error)") elseif e isa HTTP.TimeoutError - println("Request timed out after $(e.timeout) seconds") + println("Request timed out after $(e.readtimeout) seconds") elseif e isa HTTP.StatusError println("Server returned error status: $(e.status)") elseif e isa HTTP.RequestError @@ -329,14 +331,14 @@ jar = HTTP.CookieJar() response = HTTP.get("https://example.com", cookiejar=jar) # Checking cookies -cookies = HTTP.getcookies(jar, "example.com") +cookies = HTTP.Cookies.getcookies!(jar, "https", "example.com", "/") ``` ## Other Notable Changes - **URI Handling**: URIs are now handled by the separate URIs.jl package (this change actually occurred in v1.0) - **Default Headers**: Headers like `Accept: */*` are now included by default in requests -- **TLS Implementation**: OpenSSL is now the default TLS provider instead of MbedTLS +- **TLS Implementation**: TLS is handled by AWS CRT (s2n-tls) instead of MbedTLS - **Multithreading**: Improved thread safety throughout the codebase - **Performance**: Significant performance improvements, especially for high-throughput servers - **Parser APIs**: Low-level parser APIs from v1.x have been removed in v2.0 From 353364b0fd27937bc433eec6001b1a3874146764 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 21:07:21 -0700 Subject: [PATCH 42/56] feat(http2): add max closed streams option --- docs/src/manual/client.md | 1 + src/client/client.jl | 5 +- test/client.jl | 224 +++++++++++++++++++++++--------------- test/multipart.jl | 14 ++- test/runtests.jl | 8 ++ 5 files changed, 156 insertions(+), 96 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index db2ba8c8..7f4a4b02 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -111,6 +111,7 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **http2_connection_ping_timeout_ms**: Default `0` (AWS default). Timeout in milliseconds for PING responses. - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. - **http2_max_concurrent_streams_per_connection**: Default `0` (no explicit limit). Upper bound for streams per connection. + - **http2_max_closed_streams**: Default `0` (AWS default). Max closed streams to remember before ignoring late frames. -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. diff --git a/src/client/client.jl b/src/client/client.jl index 5b775161..d58210f3 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -114,6 +114,7 @@ Base.@kwdef struct ClientSettings http2_connection_ping_timeout_ms::Int = 0 http2_ideal_concurrent_streams_per_connection::Int = 0 http2_max_concurrent_streams_per_connection::Int = 0 + http2_max_closed_streams::Int = 0 http2_initial_settings::Union{Nothing, AbstractVector} = nothing end @@ -327,7 +328,7 @@ function Client(cs::ClientSettings) cs.port % UInt32, settings_ptr, # initial_settings_array::Ptr{aws_http2_setting} settings_len, # num_initial_settings::Csize_t - 0, # max_closed_streams::Csize_t + cs.http2_max_closed_streams, # max_closed_streams::Csize_t false, # http2_conn_manual_window_management::Bool client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), # proxy_options::Ptr{aws_http_proxy_options} client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), # proxy_env_settings::Ptr{proxy_env_var_settings} @@ -355,7 +356,7 @@ function Client(cs::ClientSettings) cs.port % UInt32, settings_ptr, # initial_settings_array settings_len, # num_initial_settings - 0, # max_closed_streams + cs.http2_max_closed_streams, # max_closed_streams false, # conn_manual_window_management cs.enable_read_back_pressure, typemax(Csize_t), # initial_window_size diff --git a/test/client.jl b/test/client.jl index 06781661..d75f41e8 100644 --- a/test/client.jl +++ b/test/client.jl @@ -1,4 +1,5 @@ @testset "Client.jl" begin + if HAVE_HTTPBIN @testset "GET, HEAD, POST, PUT, DELETE, PATCH: $scheme" for scheme in ["http", "https"] @test isok(HTTP.get("$scheme://$httpbin/ip")) @test isok(HTTP.head("$scheme://$httpbin/ip")) @@ -192,6 +193,9 @@ @test r.request.method == "GET" @test length(r.body) > 0 end + else + @info "Skipping HTTPBin-dependent client tests" + end @testset "Header insertion" begin server = HTTP.serve!(req -> begin @@ -211,8 +215,21 @@ end @testset "readtimeout" begin - @test_throws HTTP.TimeoutError HTTP.get("http://$httpbin/delay/5"; readtimeout=1, max_retries=0) - @test isok(HTTP.get("http://$httpbin/delay/1"; readtimeout=2, max_retries=0)) + server = HTTP.serve!("127.0.0.1", 0; listenany=true) do req + if req.target == "/delay/5" + sleep(5) + elseif req.target == "/delay/1" + sleep(1) + end + return HTTP.Response(200, "ok") + end + try + port = HTTP.port(server) + @test_throws HTTP.TimeoutError HTTP.get("http://127.0.0.1:$port/delay/5"; readtimeout=1, max_retries=0) + @test isok(HTTP.get("http://127.0.0.1:$port/delay/1"; readtimeout=2, max_retries=0)) + finally + close(server) + end end @testset "Retry semantics" begin @@ -310,39 +327,43 @@ end end - @testset "Request Options Parity" begin - headers = ["X-Test" => "1"] - HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) - @test headers == ["X-Test" => "1"] + if HAVE_HTTPBIN + @testset "Request Options Parity" begin + headers = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers, copyheaders=true) + @test headers == ["X-Test" => "1"] - headers2 = ["X-Test" => "1"] - HTTP.get("https://$httpbin/headers"; headers=headers2, copyheaders=false) - @test any(h -> lowercase(String(h.first)) == "accept", headers2) - @test any(h -> lowercase(String(h.first)) == "x-test", headers2) + headers2 = ["X-Test" => "1"] + HTTP.get("https://$httpbin/headers"; headers=headers2, copyheaders=false) + @test any(h -> lowercase(String(h.first)) == "accept", headers2) + @test any(h -> lowercase(String(h.first)) == "x-test", headers2) - resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=false) - @test HTTP.getheader(resp.request.headers, "authorization") === nothing + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=false) + @test HTTP.getheader(resp.request.headers, "authorization") === nothing - resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=true) - auth = HTTP.getheader(resp.request.headers, "authorization") - @test auth !== nothing && startswith(auth, "Basic ") + resp = HTTP.get("https://user:pwd@$httpbin/headers"; basicauth=true) + auth = HTTP.getheader(resp.request.headers, "authorization") + @test auth !== nothing && startswith(auth, "Basic ") - resp = HTTP.post("https://$httpbin/anything"; body="hello", detect_content_type=true) - @test HTTP.getheader(resp.request.headers, "content-type") == "text/plain; charset=utf-8" + resp = HTTP.post("https://$httpbin/anything"; body="hello", detect_content_type=true) + @test HTTP.getheader(resp.request.headers, "content-type") == "text/plain; charset=utf-8" - orig_agent = HTTP.USER_AGENT[] - try - HTTP.setuseragent!(nothing) - resp = HTTP.get("https://$httpbin/headers") - @test HTTP.getheader(resp.request.headers, "user-agent") === nothing - finally - HTTP.setuseragent!(orig_agent) - end + orig_agent = HTTP.USER_AGENT[] + try + HTTP.setuseragent!(nothing) + resp = HTTP.get("https://$httpbin/headers") + @test HTTP.getheader(resp.request.headers, "user-agent") === nothing + finally + HTTP.setuseragent!(orig_agent) + end - pool = HTTP.Pool(1) - @test isempty(pool.clients.clients) - HTTP.get("https://$httpbin/ip"; pool=pool) - @test !isempty(pool.clients.clients) + pool = HTTP.Pool(1) + @test isempty(pool.clients.clients) + HTTP.get("https://$httpbin/ip"; pool=pool) + @test !isempty(pool.clients.clients) + end + else + @info "Skipping HTTPBin-dependent Request Options Parity tests" end @testset "observelayers" begin @@ -484,23 +505,27 @@ end end - @testset "HTTP.open streaming" begin - resp = HTTP.open("GET", "https://$httpbin/stream/5") do io - r = HTTP.startread(io) - @test r.status == 200 - data = String(read(io)) - @test length(split(chomp(data), '\n')) == 5 - end - @test resp.status == 200 - - resp = HTTP.open("POST", "https://$httpbin/anything") do io - write(io, "hello") - HTTP.closewrite(io) - r = HTTP.startread(io) - data = String(read(io)) - @test occursin("\"data\":\"hello\"", data) + if HAVE_HTTPBIN + @testset "HTTP.open streaming" begin + resp = HTTP.open("GET", "https://$httpbin/stream/5") do io + r = HTTP.startread(io) + @test r.status == 200 + data = String(read(io)) + @test length(split(chomp(data), '\n')) == 5 + end + @test resp.status == 200 + + resp = HTTP.open("POST", "https://$httpbin/anything") do io + write(io, "hello") + HTTP.closewrite(io) + r = HTTP.startread(io) + data = String(read(io)) + @test occursin("\"data\":\"hello\"", data) + end + @test resp.status == 200 end - @test resp.status == 200 + else + @info "Skipping HTTPBin-dependent HTTP.open streaming tests" end @testset "HTTP/2 stream manager smoke" begin @@ -530,6 +555,23 @@ finalize(client) end + @testset "HTTP/2 max closed streams option" begin + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_max_closed_streams=7) + client = HTTP.Client(cs) + @test client.conn_manager_opts.max_closed_streams == Csize_t(7) + finalize(client) + + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_stream_manager=true, + http2_max_closed_streams=9, + ) + client = HTTP.Client(cs) + opts = client.http2_stream_manager_opts + @test opts !== nothing + @test opts.max_closed_streams == Csize_t(9) + finalize(client) + end + @testset "HTTP/2 initial settings options" begin settings = [ HTTP.AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS => 10, @@ -642,53 +684,57 @@ )) end - @testset "HTTP/2 control APIs" begin - resp = HTTP.get("https://$httpbin/ip") - if resp.version == HTTP.HTTPVersion(2, 0) - HTTP.open("GET", "https://$httpbin/ip") do io - r = HTTP.startread(io) - @test r.status == 200 - rtt = HTTP.http2_ping(io) - @test rtt isa UInt64 - HTTP.http2_change_settings(io, Pair{Int, Int}[]) - @test length(HTTP.http2_local_settings(io)) == HTTP.AWS_HTTP2_SETTINGS_COUNT - @test HTTP.http2_get_sent_goaway(io) === nothing - @test HTTP.http2_get_received_goaway(io) === nothing + if HAVE_HTTPBIN + @testset "HTTP/2 control APIs" begin + resp = HTTP.get("https://$httpbin/ip") + if resp.version == HTTP.HTTPVersion(2, 0) + HTTP.open("GET", "https://$httpbin/ip") do io + r = HTTP.startread(io) + @test r.status == 200 + rtt = HTTP.http2_ping(io) + @test rtt isa UInt64 + HTTP.http2_change_settings(io, Pair{Int, Int}[]) + @test length(HTTP.http2_local_settings(io)) == HTTP.AWS_HTTP2_SETTINGS_COUNT + @test HTTP.http2_get_sent_goaway(io) === nothing + @test HTTP.http2_get_received_goaway(io) === nothing + end + else + @info "HTTP/2 not available for $httpbin" end - else - @info "HTTP/2 not available for $httpbin" end - end - @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin - headers = Dict("User-Agent" => "HTTP.jl") - query = Dict("hello" => "world") - body = UInt8[1, 2, 3] - for uri in ("https://$httpbin/anything", HTTP.URI("https://$httpbin/anything")) - # HTTP.request - @test isok(HTTP.request("GET", uri; headers=headers, body=body, query=query)) - @test isok(HTTP.request("GET", uri, headers; body=body, query=query)) - @test isok(HTTP.request("GET", uri, headers, body; query=query)) - # HTTP.get - @test isok(HTTP.get(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.get(uri, headers; body=body, query=query)) - @test isok(HTTP.get(uri, headers, body; query=query)) - # HTTP.put - @test isok(HTTP.put(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.put(uri, headers; body=body, query=query)) - @test isok(HTTP.put(uri, headers, body; query=query)) - # HTTP.post - @test isok(HTTP.post(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.post(uri, headers; body=body, query=query)) - @test isok(HTTP.post(uri, headers, body; query=query)) - # HTTP.patch - @test isok(HTTP.patch(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.patch(uri, headers; body=body, query=query)) - @test isok(HTTP.patch(uri, headers, body; query=query)) - # HTTP.delete - @test isok(HTTP.delete(uri; headers=headers, body=body, query=query)) - @test isok(HTTP.delete(uri, headers; body=body, query=query)) - @test isok(HTTP.delete(uri, headers, body; query=query)) + @testset "Public entry point of HTTP.request and friends (e.g. issue #463)" begin + headers = Dict("User-Agent" => "HTTP.jl") + query = Dict("hello" => "world") + body = UInt8[1, 2, 3] + for uri in ("https://$httpbin/anything", HTTP.URI("https://$httpbin/anything")) + # HTTP.request + @test isok(HTTP.request("GET", uri; headers=headers, body=body, query=query)) + @test isok(HTTP.request("GET", uri, headers; body=body, query=query)) + @test isok(HTTP.request("GET", uri, headers, body; query=query)) + # HTTP.get + @test isok(HTTP.get(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.get(uri, headers; body=body, query=query)) + @test isok(HTTP.get(uri, headers, body; query=query)) + # HTTP.put + @test isok(HTTP.put(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.put(uri, headers; body=body, query=query)) + @test isok(HTTP.put(uri, headers, body; query=query)) + # HTTP.post + @test isok(HTTP.post(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.post(uri, headers; body=body, query=query)) + @test isok(HTTP.post(uri, headers, body; query=query)) + # HTTP.patch + @test isok(HTTP.patch(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.patch(uri, headers; body=body, query=query)) + @test isok(HTTP.patch(uri, headers, body; query=query)) + # HTTP.delete + @test isok(HTTP.delete(uri; headers=headers, body=body, query=query)) + @test isok(HTTP.delete(uri, headers; body=body, query=query)) + @test isok(HTTP.delete(uri, headers, body; query=query)) + end end + else + @info "Skipping HTTPBin-dependent HTTP/2 control and request entry point tests" end end diff --git a/test/multipart.jl b/test/multipart.jl index 1defa2c0..e1c70446 100644 --- a/test/multipart.jl +++ b/test/multipart.jl @@ -11,11 +11,15 @@ end headers = Dict("User-Agent" => "HTTP.jl") body = HTTP.Form(Dict()) mark(body) - @testset "Setting of Content-Type" begin - test_multipart(HTTP.request("POST", "https://$httpbin/post", headers, body), body) - test_multipart(HTTP.post("https://$httpbin/post", headers, body), body) - test_multipart(HTTP.request("PUT", "https://$httpbin/put", headers, body), body) - test_multipart(HTTP.put("https://$httpbin/put", headers, body), body) + if HAVE_HTTPBIN + @testset "Setting of Content-Type" begin + test_multipart(HTTP.request("POST", "https://$httpbin/post", headers, body), body) + test_multipart(HTTP.post("https://$httpbin/post", headers, body), body) + test_multipart(HTTP.request("PUT", "https://$httpbin/put", headers, body), body) + test_multipart(HTTP.put("https://$httpbin/put", headers, body), body) + end + else + @info "Skipping HTTPBin-dependent multipart tests" end @testset "HTTP.Multipart ensure show() works correctly" begin # testing that there is no error in printing when nothing is set for filename diff --git a/test/runtests.jl b/test/runtests.jl index 97d34c38..eded5cdf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,14 @@ using Test, HTTP, URIs, JSON const httpbin = get(ENV, "JULIA_TEST_HTTPBINGO_SERVER", "httpbingo.julialang.org") isok(r) = r.status == 200 +const HAVE_HTTPBIN = let + try + resp = HTTP.get("http://$httpbin/ip"; readtimeout=2, max_retries=0) + isok(resp) + catch + false + end +end include("utils.jl") include("headers.jl") From 518ecc8ed83e1faf00c0fe251ea8a1840a0f2cbf Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 21:14:14 -0700 Subject: [PATCH 43/56] feat(http2): add initial window size option --- docs/src/manual/client.md | 1 + src/client/client.jl | 5 +++-- test/client.jl | 12 ++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index 7f4a4b02..2a13ecca 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -112,6 +112,7 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. - **http2_max_concurrent_streams_per_connection**: Default `0` (no explicit limit). Upper bound for streams per connection. - **http2_max_closed_streams**: Default `0` (AWS default). Max closed streams to remember before ignoring late frames. + - **http2_initial_window_size**: Default `typemax(Int)` (AWS default). Initial flow-control window size for HTTP/2. -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. diff --git a/src/client/client.jl b/src/client/client.jl index d58210f3..b2cfea79 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -115,6 +115,7 @@ Base.@kwdef struct ClientSettings http2_ideal_concurrent_streams_per_connection::Int = 0 http2_max_concurrent_streams_per_connection::Int = 0 http2_max_closed_streams::Int = 0 + http2_initial_window_size::Int = typemax(Int) http2_initial_settings::Union{Nothing, AbstractVector} = nothing end @@ -318,7 +319,7 @@ function Client(cs::ClientSettings) client.conn_manager_opts = aws_http_connection_manager_options( cs.bootstrap, - typemax(Csize_t), # initial_window_size::Csize_t + cs.http2_initial_window_size, # initial_window_size::Csize_t pointer(FieldRef(client, :socket_options)), cs.response_first_byte_timeout_ms, (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, @@ -359,7 +360,7 @@ function Client(cs::ClientSettings) cs.http2_max_closed_streams, # max_closed_streams false, # conn_manual_window_management cs.enable_read_back_pressure, - typemax(Csize_t), # initial_window_size + cs.http2_initial_window_size, # initial_window_size monitoring_ptr, # monitoring_options client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), diff --git a/test/client.jl b/test/client.jl index d75f41e8..14f2cf60 100644 --- a/test/client.jl +++ b/test/client.jl @@ -226,7 +226,7 @@ try port = HTTP.port(server) @test_throws HTTP.TimeoutError HTTP.get("http://127.0.0.1:$port/delay/5"; readtimeout=1, max_retries=0) - @test isok(HTTP.get("http://127.0.0.1:$port/delay/1"; readtimeout=2, max_retries=0)) + @test isok(HTTP.get("http://127.0.0.1:$port/delay/1"; readtimeout=5, max_retries=0)) finally close(server) end @@ -543,6 +543,7 @@ http2_connection_ping_timeout_ms=2345, http2_ideal_concurrent_streams_per_connection=7, http2_max_concurrent_streams_per_connection=9, + http2_initial_window_size=12345, ) client = HTTP.Client(cs) opts = client.http2_stream_manager_opts @@ -552,23 +553,30 @@ @test opts.connection_ping_timeout_ms == Csize_t(2345) @test opts.ideal_concurrent_streams_per_connection == Csize_t(7) @test opts.max_concurrent_streams_per_connection == Csize_t(9) + @test opts.initial_window_size == Csize_t(12345) finalize(client) end @testset "HTTP/2 max closed streams option" begin - cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_max_closed_streams=7) + cs = HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_max_closed_streams=7, + http2_initial_window_size=54321, + ) client = HTTP.Client(cs) @test client.conn_manager_opts.max_closed_streams == Csize_t(7) + @test client.conn_manager_opts.initial_window_size == Csize_t(54321) finalize(client) cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true, http2_max_closed_streams=9, + http2_initial_window_size=65432, ) client = HTTP.Client(cs) opts = client.http2_stream_manager_opts @test opts !== nothing @test opts.max_closed_streams == Csize_t(9) + @test opts.initial_window_size == Csize_t(65432) finalize(client) end From a9d780e12ecd80d8b0cfe5a5c476feae26b456e5 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 21:29:26 -0700 Subject: [PATCH 44/56] feat(http2): add manual window controls --- docs/src/manual/client.md | 5 ++++- src/client/client.jl | 5 +++-- src/client/connection.jl | 11 +++++++++++ src/client/stream.jl | 10 ++++++++++ test/client.jl | 5 +++++ 5 files changed, 33 insertions(+), 3 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index 2a13ecca..b16ce346 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -107,6 +107,7 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **http2_prior_knowledge**: Default `false`. If `true`, assume HTTP/2 without ALPN negotiation. - **http2_stream_manager**: Default `false`. If `true`, enable the HTTP/2 stream manager for multiplexed requests. - **http2_close_connection_on_server_error**: Default `false`. If `true`, close HTTP/2 connections when a 5xx response is received. + - **http2_connection_manual_window_management**: Default `false`. If `true`, use manual connection-level flow control (call `HTTP.http2_update_window`). - **http2_connection_ping_period_ms**: Default `0` (disabled). Period in milliseconds for sending HTTP/2 PING frames. - **http2_connection_ping_timeout_ms**: Default `0` (AWS default). Timeout in milliseconds for PING responses. - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. @@ -206,8 +207,10 @@ When a connection negotiates HTTP/2, you can use the following helpers: - `HTTP.http2_local_settings(client)` / `HTTP.http2_remote_settings(client)` -> returns current settings. - `HTTP.http2_send_goaway(client, error_code; allow_more_streams=true, debug_data=nothing)` - `HTTP.http2_get_sent_goaway(client)` / `HTTP.http2_get_received_goaway(client)` -> returns `nothing` if no GOAWAY. +- `HTTP.http2_update_window(client_or_stream, increment)` -> increases the HTTP/2 connection flow-control window. +- `HTTP.update_window(stream, increment)` -> increases the stream flow-control window (useful with `enable_read_back_pressure=true`). -These helpers require an HTTP/2 connection and will throw an `ArgumentError` if the connection is HTTP/1.1. +HTTP/2-specific helpers require an HTTP/2 connection and will throw an `ArgumentError` if the connection is HTTP/1.1. ## Trailing Headers diff --git a/src/client/client.jl b/src/client/client.jl index b2cfea79..110106ba 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -110,6 +110,7 @@ Base.@kwdef struct ClientSettings http2_prior_knowledge::Bool = false http2_stream_manager::Bool = false http2_close_connection_on_server_error::Bool = false + http2_connection_manual_window_management::Bool = false http2_connection_ping_period_ms::Int = 0 http2_connection_ping_timeout_ms::Int = 0 http2_ideal_concurrent_streams_per_connection::Int = 0 @@ -330,7 +331,7 @@ function Client(cs::ClientSettings) settings_ptr, # initial_settings_array::Ptr{aws_http2_setting} settings_len, # num_initial_settings::Csize_t cs.http2_max_closed_streams, # max_closed_streams::Csize_t - false, # http2_conn_manual_window_management::Bool + cs.http2_connection_manual_window_management, # http2_conn_manual_window_management::Bool client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), # proxy_options::Ptr{aws_http_proxy_options} client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), # proxy_env_settings::Ptr{proxy_env_var_settings} cs.max_connections, # max_connections::Csize_t, 512 @@ -358,7 +359,7 @@ function Client(cs::ClientSettings) settings_ptr, # initial_settings_array settings_len, # num_initial_settings cs.http2_max_closed_streams, # max_closed_streams - false, # conn_manual_window_management + cs.http2_connection_manual_window_management, # conn_manual_window_management cs.enable_read_back_pressure, cs.http2_initial_window_size, # initial_window_size monitoring_ptr, # monitoring_options diff --git a/src/client/connection.jl b/src/client/connection.jl index b0c82183..d58f7c82 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -176,6 +176,17 @@ end http2_send_goaway(client::Client, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), client) +function http2_update_window(conn::Ptr{aws_http_connection}, increment::Integer) + _ensure_http2_connection(conn) + increment < 0 && throw(ArgumentError("increment must be >= 0")) + increment > typemax(UInt32) && throw(ArgumentError("increment must be <= $(typemax(UInt32))")) + aws_http2_connection_update_window(conn, UInt32(increment)) + return +end + +http2_update_window(client::Client, increment::Integer) = + _with_http2_connection(conn -> http2_update_window(conn, increment), client) + function _get_goaway(get_fn, conn::Ptr{aws_http_connection}) _ensure_http2_connection(conn) diff --git a/src/client/stream.jl b/src/client/stream.jl index f2d2ab76..937e94ed 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -236,6 +236,16 @@ http2_send_goaway(stream::Stream, http2_error::Integer; allow_more_streams::Bool _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), stream) http2_get_sent_goaway(stream::Stream) = _with_http2_connection(http2_get_sent_goaway, stream) http2_get_received_goaway(stream::Stream) = _with_http2_connection(http2_get_received_goaway, stream) +http2_update_window(stream::Stream, increment::Integer) = + _with_http2_connection(conn -> http2_update_window(conn, increment), stream) + +function update_window(stream::Stream, increment::Integer) + stream.ptr == C_NULL && throw(ArgumentError("HTTP stream is not initialized")) + increment < 0 && throw(ArgumentError("increment must be >= 0")) + increment > typemax(Csize_t) && throw(ArgumentError("increment must be <= $(typemax(Csize_t))")) + aws_http_stream_update_window(stream.ptr, Csize_t(increment)) + return +end const on_stream_write_on_complete = Ref{Ptr{Cvoid}}(C_NULL) diff --git a/test/client.jl b/test/client.jl index 14f2cf60..1ae8098f 100644 --- a/test/client.jl +++ b/test/client.jl @@ -539,6 +539,7 @@ cs = HTTP.ClientSettings("https", "example.com", UInt32(443); http2_stream_manager=true, http2_close_connection_on_server_error=true, + http2_connection_manual_window_management=true, http2_connection_ping_period_ms=1234, http2_connection_ping_timeout_ms=2345, http2_ideal_concurrent_streams_per_connection=7, @@ -549,11 +550,13 @@ opts = client.http2_stream_manager_opts @test opts !== nothing @test opts.close_connection_on_server_error == true + @test opts.conn_manual_window_management == true @test opts.connection_ping_period_ms == Csize_t(1234) @test opts.connection_ping_timeout_ms == Csize_t(2345) @test opts.ideal_concurrent_streams_per_connection == Csize_t(7) @test opts.max_concurrent_streams_per_connection == Csize_t(9) @test opts.initial_window_size == Csize_t(12345) + @test client.conn_manager_opts.http2_conn_manual_window_management == true finalize(client) end @@ -705,6 +708,8 @@ @test length(HTTP.http2_local_settings(io)) == HTTP.AWS_HTTP2_SETTINGS_COUNT @test HTTP.http2_get_sent_goaway(io) === nothing @test HTTP.http2_get_received_goaway(io) === nothing + @test_nowarn HTTP.http2_update_window(io, 1024) + @test_nowarn HTTP.update_window(io, 1024) end else @info "HTTP/2 not available for $httpbin" From eddc1b26f573a21020acef9a0c113284e6385e08 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Sun, 25 Jan 2026 21:37:56 -0700 Subject: [PATCH 45/56] fix(http2): validate window sizes --- docs/src/manual/client.md | 2 +- src/client/client.jl | 5 ++++- src/client/connection.jl | 2 +- src/client/stream.jl | 6 +++++- src/utils.jl | 3 +++ test/client.jl | 9 +++++++++ 6 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/src/manual/client.md b/docs/src/manual/client.md index b16ce346..62a05dd0 100644 --- a/docs/src/manual/client.md +++ b/docs/src/manual/client.md @@ -113,7 +113,7 @@ The following keyword arguments (which correspond to the non-`scheme`/`host`/`po - **http2_ideal_concurrent_streams_per_connection**: Default `0` (AWS default). Target streams per connection before opening new connections. - **http2_max_concurrent_streams_per_connection**: Default `0` (no explicit limit). Upper bound for streams per connection. - **http2_max_closed_streams**: Default `0` (AWS default). Max closed streams to remember before ignoring late frames. - - **http2_initial_window_size**: Default `typemax(Int)` (AWS default). Initial flow-control window size for HTTP/2. + - **http2_initial_window_size**: Default `65535`. Initial flow-control window size for HTTP/2 (must be <= 2^31-1). -- AWS runtime options: - **allocator**: The allocator to use for AWS-allocated memory during the request. - **bootstrap**: The AWS client bootstrap to use for the request. diff --git a/src/client/client.jl b/src/client/client.jl index 110106ba..b9dfca36 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -116,7 +116,7 @@ Base.@kwdef struct ClientSettings http2_ideal_concurrent_streams_per_connection::Int = 0 http2_max_concurrent_streams_per_connection::Int = 0 http2_max_closed_streams::Int = 0 - http2_initial_window_size::Int = typemax(Int) + http2_initial_window_size::Int = HTTP2_DEFAULT_WINDOW_SIZE http2_initial_settings::Union{Nothing, AbstractVector} = nothing end @@ -186,6 +186,9 @@ Client(scheme::AbstractString, host::AbstractString, port::Integer; kw...) = Cli function Client(cs::ClientSettings) client = Client() client.settings = cs + if cs.http2_initial_window_size < 0 || cs.http2_initial_window_size > HTTP2_MAX_WINDOW_SIZE + throw(ArgumentError("http2_initial_window_size must be between 0 and $(HTTP2_MAX_WINDOW_SIZE)")) + end # socket options client.socket_options = aws_socket_options( AWS_SOCKET_STREAM, # socket type diff --git a/src/client/connection.jl b/src/client/connection.jl index d58f7c82..9059a51a 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -179,7 +179,7 @@ http2_send_goaway(client::Client, http2_error::Integer; allow_more_streams::Bool function http2_update_window(conn::Ptr{aws_http_connection}, increment::Integer) _ensure_http2_connection(conn) increment < 0 && throw(ArgumentError("increment must be >= 0")) - increment > typemax(UInt32) && throw(ArgumentError("increment must be <= $(typemax(UInt32))")) + increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) aws_http2_connection_update_window(conn, UInt32(increment)) return end diff --git a/src/client/stream.jl b/src/client/stream.jl index 937e94ed..71492dd1 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -242,7 +242,11 @@ http2_update_window(stream::Stream, increment::Integer) = function update_window(stream::Stream, increment::Integer) stream.ptr == C_NULL && throw(ArgumentError("HTTP stream is not initialized")) increment < 0 && throw(ArgumentError("increment must be >= 0")) - increment > typemax(Csize_t) && throw(ArgumentError("increment must be <= $(typemax(Csize_t))")) + if stream.http2 + increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) + else + increment > typemax(Csize_t) && throw(ArgumentError("increment must be <= $(typemax(Csize_t))")) + end aws_http_stream_update_window(stream.ptr, Csize_t(increment)) return end diff --git a/src/utils.jl b/src/utils.jl index 564d6dae..205ff379 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,9 @@ export bytes, isbytes, nbytes, nobytes, escapehtml, tocameldash, iso8859_1_to_utf8, ascii_lc_isequal +const HTTP2_DEFAULT_WINDOW_SIZE = 65535 +const HTTP2_MAX_WINDOW_SIZE = 0x7fffffff + """ HTTPVersion(major, minor) diff --git a/test/client.jl b/test/client.jl index 1ae8098f..62db60f2 100644 --- a/test/client.jl +++ b/test/client.jl @@ -612,6 +612,15 @@ )) end + @testset "HTTP/2 window size validation" begin + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_window_size=-1, + )) + @test_throws ArgumentError HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); + http2_initial_window_size=HTTP.HTTP2_MAX_WINDOW_SIZE + 1, + )) + end + @testset "HTTP manager metrics" begin client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443))) metrics = HTTP.manager_metrics(client) From 6260637a89b09027442d73d61e2f5857c7697052 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 06:58:44 -0700 Subject: [PATCH 46/56] Switch to pure Julia AwsIO/AwsHTTP packages --- Project.toml | 10 +- src/HTTP.jl | 64 +-- src/client/client.jl | 501 ++++++++++----------- src/client/connection.jl | 197 ++++----- src/client/makerequest.jl | 33 +- src/client/open.jl | 44 +- src/client/redirects.jl | 7 +- src/client/request.jl | 2 +- src/client/retry.jl | 142 +++--- src/client/stream.jl | 886 ++++++++++++++++++------------------- src/download.jl | 10 +- src/requestresponse.jl | 339 +++++++-------- src/server.jl | 888 ++++++++++++++++++++------------------ src/statistics.jl | 82 ++++ src/utils.jl | 205 ++++++--- src/websockets.jl | 785 ++++++++++++++++----------------- test/client.jl | 43 +- test/runtests.jl | 2 +- test/server.jl | 145 ++++--- test/utils.jl | 3 +- 20 files changed, 2197 insertions(+), 2191 deletions(-) create mode 100644 src/statistics.jl diff --git a/Project.toml b/Project.toml index 6d1b5d97..3f1d5760 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,8 @@ Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -LibAwsCommon = "c6e421ba-b5f8-4792-a1c4-42948de3ed9d" -LibAwsHTTPFork = "d3f1d20b-921e-4930-8491-471e0be3121a" -LibAwsIO = "a5388770-19df-4151-b103-3d71de896ddf" +AwsHTTP = "d4eb1443-154a-48c0-b55a-2f1d1087a5c5" +AwsIO = "4047365c-aa37-44ec-b1fa-4c0d5495ccf1" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -21,9 +20,8 @@ URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" [compat] CodecZlib = "0.7" JSON = "0.21.4, 1" -LibAwsCommon = "1.3" -LibAwsHTTPFork = "1.0.2" -LibAwsIO = "1.2.0" +AwsHTTP = "0.1" +AwsIO = "1.1" PrecompileTools = "1.2.1" URIs = "1" julia = "1.10" diff --git a/src/HTTP.jl b/src/HTTP.jl index bc59b11d..62942e5c 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -1,8 +1,7 @@ module HTTP using CodecZlib, URIs, Mmap, Base64, Dates, Sockets -using LibAwsCommon, LibAwsIO, LibAwsHTTPFork -import LibAwsCommon: Future, FieldRef +using AwsIO, AwsHTTP export HTTPVersion export startwrite, startread, closewrite, closeread @@ -14,6 +13,7 @@ const nobody = UInt8[] Base.@deprecate escape escapeuri include("utils.jl") +include("statistics.jl") include("access_log.jl") include("sniff.jl"); using .Sniff include("forms.jl"); using .Forms @@ -21,14 +21,14 @@ include("requestresponse.jl") include("exceptions.jl"); using .Exceptions struct StatusError <: HTTPError request_method::String - request_uri::aws_uri + request_uri::URI response::Response end function Base.showerror(io::IO, e::StatusError) println(io, "HTTP.StatusError:") println(io, " Request method: $(e.request_method)") - println(io, " Request URI: $(makeuri(e.request_uri))") + println(io, " Request URI: $(e.request_uri)") println(io, " response:") print_response(io, e.response) return @@ -41,7 +41,7 @@ function Base.getproperty(e::StatusError, s::Symbol) elseif s == :method return e.request_method elseif s == :target - return makeuri(e.request_uri) + return e.request_uri else return getfield(e, s) end @@ -61,62 +61,16 @@ include("server.jl") include("handlers.jl"); using .Handlers include("statuses.jl") -#NOTE: this is global process logging in the aws-crt libraries; not appropriate for request-level +#NOTE: this is process-level logging; not appropriate for request-level # logging, but more for debugging the library itself -mutable struct AwsLogger - ptr::Ptr{aws_logger} - file_ref::Libc.FILE - options::aws_logger_standard_options - function AwsLogger(level::Integer, allocator::Ptr{aws_allocator}) - fr = Libc.FILE(Libc.RawFD(1), "w") - opts = aws_logger_standard_options(aws_log_level(0), C_NULL, Ptr{Libc.FILE}(fr.ptr)) - x = new(Ptr{aws_logger}(aws_mem_acquire(allocator, 64)), fr, opts) - aws_logger_init_standard(x.ptr, allocator, FieldRef(x, :options)) != 0 && aws_throw_error() - aws_logger_set(x.ptr) - return finalizer(x) do x - aws_logger_clean_up(x.ptr) - aws_mem_release(allocator, x.ptr) - end - end -end - -const LOGGER = Ref{AwsLogger}() - -function set_log_level!(level::Integer, allocator::Ptr{aws_allocator}=default_aws_allocator()) +function set_log_level!(level::Integer) @assert 0 <= level <= 7 "log level must be between 0 and 7" - LOGGER[] = AwsLogger(level, allocator) - @assert aws_logger_set_log_level(LOGGER[].ptr, aws_log_level(level)) == 0 + AwsIO.set_log_level!(AwsIO.logger_get(), AwsIO.LogLevel.T(level)) return end function __init__() - allocator = default_aws_allocator() - LibAwsHTTPFork.init(allocator) - # intialize c functions - on_acquired[] = @cfunction(c_on_acquired, Cvoid, (Ptr{Cvoid}, Cint, Ptr{aws_retry_token}, Ptr{Cvoid})) - # on_shutdown[] = @cfunction(c_on_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_setup[] = @cfunction(c_on_setup, Cvoid, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_change_settings_complete[] = @cfunction(c_on_change_settings_complete, Cvoid, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_ping_complete[] = @cfunction(c_on_ping_complete, Cvoid, (Ptr{aws_http_connection}, UInt64, Cint, Ptr{Cvoid})) - on_stream_write_on_complete[] = @cfunction(c_on_stream_write_on_complete, Cvoid, (Ptr{aws_http_stream}, Cint, Ptr{Cvoid})) - on_response_headers[] = @cfunction(c_on_response_headers, Cint, (Ptr{Cvoid}, Cint, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) - on_response_header_block_done[] = @cfunction(c_on_response_header_block_done, Cint, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_response_body[] = @cfunction(c_on_response_body, Cint, (Ptr{Cvoid}, Ptr{aws_byte_cursor}, Ptr{Cvoid})) - on_metrics[] = @cfunction(c_on_metrics, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_stream_metrics}, Ptr{Cvoid})) - on_complete[] = @cfunction(c_on_complete, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_destroy[] = @cfunction(c_on_destroy, Cvoid, (Ptr{Cvoid},)) - on_stream_acquired[] = @cfunction(c_on_stream_acquired, Cvoid, (Ptr{aws_http_stream}, Cint, Ptr{Cvoid})) - retry_ready[] = @cfunction(c_retry_ready, Cvoid, (Ptr{aws_retry_token}, Cint, Ptr{Cvoid})) - on_incoming_connection[] = @cfunction(c_on_incoming_connection, Cvoid, (Ptr{Cvoid}, Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{Cvoid}, Cint, Ptr{Cvoid})) - on_statistics_observer[] = @cfunction(c_on_statistics_observer, Cvoid, (Csize_t, Ptr{aws_array_list}, Ptr{Cvoid})) - on_incoming_request[] = @cfunction(c_on_incoming_request, Ptr{aws_http_stream}, (Ptr{aws_http_connection}, Ptr{Cvoid})) - on_request_headers[] = @cfunction(c_on_request_headers, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{aws_http_header}, Csize_t, Ptr{Cvoid})) - on_request_header_block_done[] = @cfunction(c_on_request_header_block_done, Cint, (Ptr{aws_http_stream}, aws_http_header_block, Ptr{Cvoid})) - on_request_body[] = @cfunction(c_on_request_body, Cint, (Ptr{aws_http_stream}, Ptr{aws_byte_cursor}, Ptr{Cvoid})) - on_request_done[] = @cfunction(c_on_request_done, Cint, (Ptr{aws_http_stream}, Ptr{Cvoid})) - on_server_stream_complete[] = @cfunction(c_on_server_stream_complete, Cint, (Ptr{aws_http_connection}, Cint, Ptr{Cvoid})) - on_destroy_complete[] = @cfunction(c_on_destroy_complete, Cvoid, (Ptr{Cvoid},)) + AwsHTTP.http_library_init() return end diff --git a/src/client/client.jl b/src/client/client.jl index b9dfca36..d454556a 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,71 +1,72 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 const DEFAULT_MAX_RETRIES = 4 -const on_statistics_observer = Ref{Ptr{Cvoid}}(C_NULL) +# ─── Shared infrastructure ─── +# Lazily initialized resources that replace the old C library globals +# (default_aws_allocator, default_aws_event_loop_group, default_aws_client_bootstrap). +const _RESOURCES_LOCK = ReentrantLock() +const _EVENT_LOOP_GROUP = Ref{Any}(nothing) +const _HOST_RESOLVER = Ref{Any}(nothing) +const _CLIENT_BOOTSTRAP = Ref{Any}(nothing) -mutable struct StatisticsObserver - cb::Function -end - -function _decode_statistics(stats_list_ptr::Ptr{aws_array_list}) - stats_list_ptr == C_NULL && return Any[] - stats_list = unsafe_load(stats_list_ptr) - len = Int(stats_list.length) - len == 0 && return Any[] - item_size = Int(stats_list.item_size) - data_ptr = Ptr{UInt8}(stats_list.data) - data_ptr == C_NULL && return Any[] - stats = Vector{Any}(undef, len) - for i in 1:len - item_ptr = data_ptr + (i - 1) * item_size - category = unsafe_load(Ptr{UInt32}(item_ptr)) - if category == UInt32(AWSCRT_STAT_CAT_HTTP1_CHANNEL) - entry = unsafe_load(Ptr{aws_crt_statistics_http1_channel}(item_ptr)) - stats[i] = ( - category = :http1_channel, - pending_outgoing_stream_ms = entry.pending_outgoing_stream_ms, - pending_incoming_stream_ms = entry.pending_incoming_stream_ms, - current_outgoing_stream_id = entry.current_outgoing_stream_id, - current_incoming_stream_id = entry.current_incoming_stream_id, - ) - elseif category == UInt32(AWSCRT_STAT_CAT_HTTP2_CHANNEL) - entry = unsafe_load(Ptr{aws_crt_statistics_http2_channel}(item_ptr)) - stats[i] = ( - category = :http2_channel, - pending_outgoing_stream_ms = entry.pending_outgoing_stream_ms, - pending_incoming_stream_ms = entry.pending_incoming_stream_ms, - was_inactive = entry.was_inactive, - ) - else - raw = Vector{UInt8}(undef, item_size) - GC.@preserve raw unsafe_copyto!(pointer(raw), item_ptr, item_size) - stats[i] = (category = :unknown, raw = raw) - end +function _ensure_resources!() + _CLIENT_BOOTSTRAP[] !== nothing && return + Base.@lock _RESOURCES_LOCK begin + _CLIENT_BOOTSTRAP[] !== nothing && return + elg_opts = AwsIO.EventLoopGroupOptions(; + type = _use_nw_sockets() ? AwsIO.EventLoopType.DISPATCH_QUEUE : AwsIO.EventLoopType.PLATFORM_DEFAULT, + ) + elg = AwsIO.event_loop_group_new(elg_opts) + elg isa AwsIO.ErrorResult && throw(AWSError("Failed to create event loop group; ensure sufficient interactive threads")) + _EVENT_LOOP_GROUP[] = elg + resolver = AwsIO.DefaultHostResolver(elg) + _HOST_RESOLVER[] = resolver + bootstrap = AwsIO.ClientBootstrap(AwsIO.ClientBootstrapOptions( + event_loop_group=elg, + host_resolver=resolver, + )) + _CLIENT_BOOTSTRAP[] = bootstrap end - return stats end -_decode_statistics(stats_list::Ref{aws_array_list}) = - _decode_statistics(Base.unsafe_convert(Ptr{aws_array_list}, stats_list)) +# ─── TLS helper ─── -function c_on_statistics_observer(connection_nonce::Csize_t, stats_list::Ptr{aws_array_list}, observer_ptr::Ptr{Cvoid}) - observer = unsafe_pointer_to_objref(observer_ptr)::StatisticsObserver - stats = _decode_statistics(stats_list) - try - Base.invokelatest(observer.cb, connection_nonce, stats) - catch e - @error "statistics observer error" exception=(e, catch_backtrace()) +function _make_tls_options(host::String; ssl_cert, ssl_key, ssl_capath, ssl_cacert, ssl_insecure, ssl_alpn_list) + alpn_list = _normalize_alpn_list(ssl_alpn_list) + if ssl_cert !== nothing && ssl_key !== nothing + # Mutual TLS: client certificate + key (file paths) + opts = AwsIO.tls_ctx_options_init_client_mtls_from_path(ssl_cert, ssl_key) + opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create mTLS options")) + AwsIO.tls_ctx_options_set_verify_peer!(opts, !ssl_insecure) + if alpn_list !== nothing && !isempty(alpn_list) + AwsIO.tls_ctx_options_set_alpn_list!(opts, alpn_list) + end + if ssl_cacert !== nothing || ssl_capath !== nothing + res = AwsIO.tls_ctx_options_override_default_trust_store_from_path!(opts; + ca_path=ssl_capath, ca_file=ssl_cacert) + res isa AwsIO.ErrorResult && throw(AWSError("Failed to set trust store")) + end + ctx = AwsIO.tls_context_new(opts) + ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create TLS context")) + else + # Standard client TLS (no client cert) + ctx = AwsIO.tls_context_new_client(; + verify_peer=!ssl_insecure, + ca_file=ssl_cacert, + ca_path=ssl_capath, + alpn_list=alpn_list, + ) + ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create TLS context")) end - return + return AwsIO.TlsConnectionOptions(ctx; server_name=host) end +# ─── Settings ─── + Base.@kwdef struct ClientSettings scheme::String host::String port::UInt32 - allocator::Ptr{aws_allocator} = default_aws_allocator() - bootstrap::Ptr{aws_client_bootstrap} = default_aws_client_bootstrap() - event_loop_group::Ptr{aws_event_loop_group} = default_aws_event_loop_group() socket_domain::Symbol = :ipv4 connect_timeout_ms::Int = DEFAULT_CONNECT_TIMEOUT keep_alive_interval_sec::Int = 0 @@ -96,7 +97,7 @@ Base.@kwdef struct ClientSettings max_retries::Int = DEFAULT_MAX_RETRIES backoff_scale_factor_ms::Int = 25 max_backoff_secs::Int = 20 - jitter_mode::aws_exponential_backoff_jitter_mode = AWS_EXPONENTIAL_BACKOFF_JITTER_DEFAULT + jitter_mode::Symbol = :default retry_timeout_ms::Int = 60000 initial_bucket_capacity::Int = 500 max_connections::Int = 512 @@ -135,7 +136,7 @@ ClientSettings( kw...) = begin http2_initial_settings = Base.get(() -> nothing, kw, :http2_initial_settings) if http2_initial_settings !== nothing && !(http2_initial_settings isa AbstractVector) - throw(ArgumentError("http2_initial_settings must be a vector of pairs or aws_http2_setting")) + throw(ArgumentError("http2_initial_settings must be a vector of pairs or AwsHTTP.Http2Setting")) end if haskey(kw, :http2_initial_settings) kw = Base.structdiff((; kw...), (; http2_initial_settings=nothing)) @@ -160,23 +161,49 @@ end return ex end +# ─── Compat option mirrors for tests ─── + +struct ConnManagerOptsCompat + http2_conn_manual_window_management::Bool + max_closed_streams::Csize_t + initial_window_size::Csize_t + num_initial_settings::Csize_t + initial_settings_array::Ptr{AwsHTTP.Http2Setting} + _initial_settings_storage::Vector{AwsHTTP.Http2Setting} +end + +struct Http2StreamManagerOptsCompat + close_connection_on_server_error::Bool + conn_manual_window_management::Bool + connection_ping_period_ms::Csize_t + connection_ping_timeout_ms::Csize_t + ideal_concurrent_streams_per_connection::Csize_t + max_concurrent_streams_per_connection::Csize_t + initial_window_size::Csize_t + max_closed_streams::Csize_t + num_initial_settings::Csize_t + initial_settings_array::Ptr{AwsHTTP.Http2Setting} + _initial_settings_storage::Vector{AwsHTTP.Http2Setting} +end + +# ─── Client ─── + mutable struct Client settings::ClientSettings - socket_options::aws_socket_options - tls_options::Union{Nothing, aws_tls_connection_options} + socket_options::AwsIO.SocketOptions + tls_options::Union{Nothing, AwsIO.TlsConnectionOptions} # only 1 of proxy_options or proxy_env_settings is set - proxy_options::Union{Nothing, aws_http_proxy_options} - proxy_env_settings::Union{Nothing, proxy_env_var_settings} - proxy_strategy::Ptr{aws_http_proxy_strategy} - monitoring_options::Union{Nothing, aws_http_connection_monitoring_options} - monitoring_observer::Union{Nothing, Any} - retry_options::aws_standard_retry_options - retry_strategy::Ptr{aws_retry_strategy} - conn_manager_opts::aws_http_connection_manager_options - connection_manager::Ptr{aws_http_connection_manager} - http2_stream_manager_opts::Union{Nothing, aws_http2_stream_manager_options} - http2_stream_manager::Ptr{aws_http2_stream_manager} - http2_initial_settings::Union{Nothing, Vector{aws_http2_setting}} + proxy_options::Union{Nothing, AwsHTTP.HttpProxyOptions} + proxy_env_settings::Union{Nothing, AwsHTTP.ProxyEnvVarSettings} + proxy_strategy::Union{Nothing, AwsHTTP.HttpProxyStrategy} + monitoring_options::Union{Nothing, AwsHTTP.HttpConnectionMonitoringOptions} + monitoring_observer::Union{Nothing, Function} + retry_strategy::AwsIO.StandardRetryStrategy + connection_manager::AwsHTTP.HttpConnectionManager + http2_stream_manager::Union{Nothing, AwsHTTP.Http2StreamManager} + http2_initial_settings::Union{Nothing, Vector{AwsHTTP.Http2Setting}} + conn_manager_opts::ConnManagerOptsCompat + http2_stream_manager_opts::Union{Nothing, Http2StreamManagerOptsCompat} Client() = new() end @@ -184,41 +211,40 @@ end Client(scheme::AbstractString, host::AbstractString, port::Integer; kw...) = Client(ClientSettings(scheme, host, port % UInt32; kw...)) function Client(cs::ClientSettings) + _ensure_resources!() client = Client() client.settings = cs if cs.http2_initial_window_size < 0 || cs.http2_initial_window_size > HTTP2_MAX_WINDOW_SIZE throw(ArgumentError("http2_initial_window_size must be between 0 and $(HTTP2_MAX_WINDOW_SIZE)")) end # socket options - client.socket_options = aws_socket_options( - AWS_SOCKET_STREAM, # socket type - cs.socket_domain == :ipv4 ? AWS_SOCKET_IPV4 : AWS_SOCKET_IPV6, # socket domain - AWS_SOCKET_IMPL_PLATFORM_DEFAULT, # aws_socket_impl_type - cs.connect_timeout_ms, - cs.keep_alive_interval_sec, - cs.keep_alive_timeout_sec, - cs.keep_alive_max_failed_probes, - cs.keepalive, - ntuple(x -> Cchar(0), 16) # network_interface_name + client.socket_options = AwsIO.SocketOptions(; + type=AwsIO.SocketType.STREAM, + domain=cs.socket_domain == :ipv4 ? AwsIO.SocketDomain.IPV4 : AwsIO.SocketDomain.IPV6, + connect_timeout_ms=cs.connect_timeout_ms, + keepalive=cs.keepalive, + keep_alive_interval_sec=cs.keep_alive_interval_sec, + keep_alive_timeout_sec=cs.keep_alive_timeout_sec, + keep_alive_max_failed_probes=cs.keep_alive_max_failed_probes, ) + if _use_nw_sockets() + client.socket_options.impl_type = AwsIO.SocketImplType.APPLE_NETWORK_FRAMEWORK + end # tls options if cs.scheme == "https" || cs.scheme == "wss" - client.tls_options = LibAwsIO.tlsoptions(cs.host; - cs.ssl_cert, - cs.ssl_key, - cs.ssl_capath, - cs.ssl_cacert, - cs.ssl_insecure, - cs.ssl_alpn_list - ) + client.tls_options = _make_tls_options(cs.host; + cs.ssl_cert, cs.ssl_key, cs.ssl_capath, cs.ssl_cacert, + cs.ssl_insecure, cs.ssl_alpn_list) else client.tls_options = nothing end # proxy options client.proxy_options = nothing client.proxy_env_settings = nothing - client.proxy_strategy = C_NULL - proxy_connection_type = cs.proxy_connection_type == :forward ? AWS_HPCT_HTTP_FORWARD : AWS_HPCT_HTTP_TUNNEL + client.proxy_strategy = nothing + proxy_connection_type = cs.proxy_connection_type == :forward ? + AwsHTTP.HttpProxyConnectionType.HTTP_FORWARD : + AwsHTTP.HttpProxyConnectionType.HTTP_TUNNEL if cs.proxy_host !== nothing && cs.proxy_port !== nothing proxy_auth = cs.proxy_auth if proxy_auth === nothing && (cs.proxy_username !== nothing || cs.proxy_password !== nothing) @@ -228,188 +254,175 @@ function Client(cs::ClientSettings) proxy_auth == :basic || throw(ArgumentError("unsupported proxy_auth: $proxy_auth")) cs.proxy_username === nothing && throw(ArgumentError("proxy_username required for basic proxy auth")) cs.proxy_password === nothing && throw(ArgumentError("proxy_password required for basic proxy auth")) - auth_opts = aws_http_proxy_strategy_basic_auth_options( - proxy_connection_type, - aws_byte_cursor_from_c_str(cs.proxy_username), - aws_byte_cursor_from_c_str(cs.proxy_password), + client.proxy_strategy = AwsHTTP.http_proxy_strategy_new_basic_auth( + AwsHTTP.HttpProxyStrategyBasicAuthOptions( + proxy_connection_type, + cs.proxy_username, + cs.proxy_password, + ) ) - GC.@preserve cs begin - client.proxy_strategy = aws_http_proxy_strategy_new_basic_auth(cs.allocator, Ref(auth_opts)) - end - client.proxy_strategy == C_NULL && aws_throw_error() end - client.proxy_options = aws_http_proxy_options( - proxy_connection_type, - aws_byte_cursor_from_c_str(cs.proxy_host), - cs.proxy_port % UInt32, - cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; - cs.proxy_ssl_cert, - cs.proxy_ssl_key, - cs.proxy_ssl_capath, - cs.proxy_ssl_cacert, - cs.proxy_ssl_insecure, - cs.proxy_ssl_alpn_list - ), - client.proxy_strategy, # proxy_strategy::Ptr{aws_http_proxy_strategy} - AWS_HPAT_NONE, # auth_type::aws_http_proxy_authentication_type - aws_byte_cursor_from_c_str(""), # auth_username::aws_byte_cursor - aws_byte_cursor_from_c_str(""), # auth_password::aws_byte_cursor + client.proxy_options = AwsHTTP.HttpProxyOptions(; + connection_type=proxy_connection_type, + host=cs.proxy_host, + port=cs.proxy_port % UInt32, + proxy_strategy=client.proxy_strategy, ) elseif cs.proxy_allow_env_var if cs.proxy_auth !== nothing || cs.proxy_username !== nothing || cs.proxy_password !== nothing throw(ArgumentError("proxy auth requires explicit proxy_host/proxy_port")) end - client.proxy_env_settings = proxy_env_var_settings( - AWS_HPEV_ENABLE, - proxy_connection_type, - cs.proxy_ssl_cert === nothing ? C_NULL : LibAwsIO.tlsoptions(cs.proxy_host; - cs.proxy_ssl_cert, - cs.proxy_ssl_key, - cs.proxy_ssl_capath, - cs.proxy_ssl_cacert, - cs.proxy_ssl_insecure, - cs.proxy_ssl_alpn_list - ) + client.proxy_env_settings = AwsHTTP.ProxyEnvVarSettings(; + env_var_type=AwsHTTP.HttpProxyEnvVarType.ENABLE, + connection_type=proxy_connection_type, ) end # connection monitoring options - monitoring_ptr = C_NULL - if cs.monitoring_statistics_observer !== nothing || - cs.monitoring_minimum_throughput_bytes_per_second != 0 || + if cs.monitoring_minimum_throughput_bytes_per_second != 0 || cs.monitoring_allowable_throughput_failure_interval_seconds != 0 - observer = cs.monitoring_statistics_observer === nothing ? nothing : StatisticsObserver(cs.monitoring_statistics_observer) - client.monitoring_observer = observer - client.monitoring_options = aws_http_connection_monitoring_options( + client.monitoring_options = AwsHTTP.HttpConnectionMonitoringOptions( UInt64(cs.monitoring_minimum_throughput_bytes_per_second), UInt32(cs.monitoring_allowable_throughput_failure_interval_seconds), - observer === nothing ? C_NULL : on_statistics_observer[], - observer === nothing ? C_NULL : pointer_from_objref(observer) ) - monitoring_ptr = pointer(FieldRef(client, :monitoring_options)) else client.monitoring_options = nothing - client.monitoring_observer = nothing end + client.monitoring_observer = cs.monitoring_statistics_observer # retry strategy - exp_back_opts = aws_exponential_backoff_retry_options( - cs.event_loop_group, - cs.max_retries, - cs.backoff_scale_factor_ms, - cs.max_backoff_secs, - cs.jitter_mode, - C_NULL, # generate_random - C_NULL, # generate_random_impl - C_NULL, # generate_random_user_data - C_NULL, # shutdown_options::Ptr{aws_shutdown_callback_options} + backoff_config = AwsIO.ExponentialBackoffConfig(; + backoff_scale_factor_ms=cs.backoff_scale_factor_ms, + max_backoff_secs=cs.max_backoff_secs, + max_retries=cs.max_retries, + jitter_mode=cs.jitter_mode, ) - client.retry_options = aws_standard_retry_options( - exp_back_opts, - cs.initial_bucket_capacity + retry_config = AwsIO.StandardRetryConfig(; + initial_bucket_capacity=cs.initial_bucket_capacity, + backoff_config=backoff_config, ) - client.retry_strategy = aws_retry_strategy_new_standard(cs.allocator, FieldRef(client, :retry_options)) - client.retry_strategy == C_NULL && aws_throw_error() + strategy = AwsIO.StandardRetryStrategy(_EVENT_LOOP_GROUP[], retry_config) + strategy isa AwsIO.ErrorResult && throw(AWSError("Failed to create retry strategy")) + client.retry_strategy = strategy + # http2 initial settings settings_input = cs.http2_initial_settings if settings_input === nothing client.http2_initial_settings = nothing - elseif settings_input isa AbstractVector{aws_http2_setting} + elseif settings_input isa AbstractVector{AwsHTTP.Http2Setting} client.http2_initial_settings = collect(settings_input) elseif settings_input isa AbstractVector{<:Pair} client.http2_initial_settings = _settings_from_pairs(settings_input) else - throw(ArgumentError("http2_initial_settings must be a vector of pairs or aws_http2_setting")) + throw(ArgumentError("http2_initial_settings must be a vector of pairs or AwsHTTP.Http2Setting")) end - settings_ptr = client.http2_initial_settings === nothing ? C_NULL : pointer(client.http2_initial_settings) - settings_len = client.http2_initial_settings === nothing ? 0 : length(client.http2_initial_settings) - - client.conn_manager_opts = aws_http_connection_manager_options( - cs.bootstrap, - cs.http2_initial_window_size, # initial_window_size::Csize_t - pointer(FieldRef(client, :socket_options)), - cs.response_first_byte_timeout_ms, - (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, - cs.http2_prior_knowledge, - monitoring_ptr, # monitoring_options::Ptr{aws_http_connection_monitoring_options} - aws_byte_cursor_from_c_str(cs.host), - cs.port % UInt32, - settings_ptr, # initial_settings_array::Ptr{aws_http2_setting} - settings_len, # num_initial_settings::Csize_t - cs.http2_max_closed_streams, # max_closed_streams::Csize_t - cs.http2_connection_manual_window_management, # http2_conn_manual_window_management::Bool - client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), # proxy_options::Ptr{aws_http_proxy_options} - client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), # proxy_env_settings::Ptr{proxy_env_var_settings} - cs.max_connections, # max_connections::Csize_t, 512 - C_NULL, # shutdown_complete_user_data::Ptr{Cvoid} - C_NULL, # shutdown_complete_callback::Ptr{aws_http_connection_manager_shutdown_complete_fn} - cs.enable_read_back_pressure, # enable_read_back_pressure::Bool - cs.max_connection_idle_in_milliseconds, - cs.connection_acquisition_timeout_ms, - cs.max_pending_connection_acquisitions, - C_NULL, # network_interface_names_array - 0, # num_network_interface_names + settings_storage = client.http2_initial_settings === nothing ? AwsHTTP.Http2Setting[] : client.http2_initial_settings + settings_ptr = isempty(settings_storage) ? Ptr{AwsHTTP.Http2Setting}(C_NULL) : pointer(settings_storage) + settings_count = Csize_t(length(settings_storage)) + # connection factory: creates connections for the pool managers. + # Calls AwsHTTP.http_client_connect (async) and blocks until setup completes. + conn_factory = let socket_opts=client.socket_options, tls_opts=client.tls_options, + host=cs.host, port=cs.port, + prior_knowledge=cs.http2_prior_knowledge, + manual_wm=cs.http2_connection_manual_window_management, + initial_ws=cs.http2_initial_window_size, + rfbt_ms=cs.response_first_byte_timeout_ms + function(_manager_opts) + result_ch = Base.Channel{Any}(1) + AwsHTTP.http_client_connect(AwsHTTP.HttpClientConnectionOptions( + bootstrap=_CLIENT_BOOTSTRAP[], + host_name=host, + port=port, + socket_options=socket_opts, + tls_connection_options=tls_opts, + prior_knowledge_http2=prior_knowledge, + manual_window_management=manual_wm, + initial_window_size=Csize_t(initial_ws), + response_first_byte_timeout_ms=UInt64(rfbt_ms), + on_setup=(conn, err, ud) -> put!(result_ch, err == AwsIO.OP_SUCCESS ? conn : nothing), + )) + return take!(result_ch) + end + end + # connection manager + client.connection_manager = AwsHTTP.http_connection_manager_new( + AwsHTTP.HttpConnectionManagerOptions(; + host=cs.host, + port=cs.port, + max_connections=cs.max_connections, + initial_window_size=Csize_t(cs.http2_initial_window_size), + manual_window_management=cs.http2_connection_manual_window_management, + http2_prior_knowledge=cs.http2_prior_knowledge, + enable_read_back_pressure=cs.enable_read_back_pressure, + max_connection_idle_in_milliseconds=UInt64(cs.max_connection_idle_in_milliseconds), + connection_acquisition_timeout_ms=UInt64(cs.connection_acquisition_timeout_ms), + max_pending_connection_acquisitions=cs.max_pending_connection_acquisitions, + response_first_byte_timeout_ms=UInt64(cs.response_first_byte_timeout_ms), + max_closed_streams=cs.http2_max_closed_streams, + http2_conn_manual_window_management=cs.http2_connection_manual_window_management, + on_connection_setup=conn_factory, + ) ) - client.connection_manager = aws_http_connection_manager_new(cs.allocator, FieldRef(client, :conn_manager_opts)) - client.connection_manager == C_NULL && aws_throw_error() + client.conn_manager_opts = ConnManagerOptsCompat( + cs.http2_connection_manual_window_management, + Csize_t(cs.http2_max_closed_streams), + Csize_t(cs.http2_initial_window_size), + settings_count, + settings_ptr, + settings_storage, + ) + # http2 stream manager (optional) + client.http2_stream_manager = nothing client.http2_stream_manager_opts = nothing - client.http2_stream_manager = C_NULL if cs.http2_stream_manager - opts = aws_http2_stream_manager_options( - cs.bootstrap, - pointer(FieldRef(client, :socket_options)), - (cs.scheme == "https" || cs.scheme == "wss") ? pointer(FieldRef(client, :tls_options)) : C_NULL, - cs.http2_prior_knowledge, - aws_byte_cursor_from_c_str(cs.host), - cs.port % UInt32, - settings_ptr, # initial_settings_array - settings_len, # num_initial_settings - cs.http2_max_closed_streams, # max_closed_streams - cs.http2_connection_manual_window_management, # conn_manual_window_management - cs.enable_read_back_pressure, - cs.http2_initial_window_size, # initial_window_size - monitoring_ptr, # monitoring_options - client.proxy_options === nothing ? C_NULL : pointer(FieldRef(client, :proxy_options)), - client.proxy_env_settings === nothing ? C_NULL : pointer(FieldRef(client, :proxy_env_settings)), - C_NULL, # shutdown_complete_user_data - C_NULL, # shutdown_complete_callback - cs.http2_close_connection_on_server_error, # close_connection_on_server_error - cs.http2_connection_ping_period_ms, # connection_ping_period_ms - cs.http2_connection_ping_timeout_ms, # connection_ping_timeout_ms - cs.http2_ideal_concurrent_streams_per_connection, # ideal_concurrent_streams_per_connection - cs.http2_max_concurrent_streams_per_connection, # max_concurrent_streams_per_connection - cs.max_connections, + client.http2_stream_manager = AwsHTTP.http2_stream_manager_new( + AwsHTTP.Http2StreamManagerOptions(; + host=cs.host, + port=cs.port, + max_connections=cs.max_connections, + ideal_concurrent_streams_per_connection=cs.http2_ideal_concurrent_streams_per_connection, + max_concurrent_streams_per_connection=cs.http2_max_concurrent_streams_per_connection, + close_connection_on_server_error=cs.http2_close_connection_on_server_error, + connection_ping_period_ms=UInt64(cs.http2_connection_ping_period_ms), + connection_ping_timeout_ms=UInt64(cs.http2_connection_ping_timeout_ms), + initial_window_size=Csize_t(cs.http2_initial_window_size), + manual_window_management=cs.http2_connection_manual_window_management, + http2_prior_knowledge=cs.http2_prior_knowledge, + enable_read_back_pressure=cs.enable_read_back_pressure, + max_closed_streams=cs.http2_max_closed_streams, + on_connection_setup=conn_factory, + ) + ) + client.http2_stream_manager_opts = Http2StreamManagerOptsCompat( + cs.http2_close_connection_on_server_error, + cs.http2_connection_manual_window_management, + Csize_t(cs.http2_connection_ping_period_ms), + Csize_t(cs.http2_connection_ping_timeout_ms), + Csize_t(cs.http2_ideal_concurrent_streams_per_connection), + Csize_t(cs.http2_max_concurrent_streams_per_connection), + Csize_t(cs.http2_initial_window_size), + Csize_t(cs.http2_max_closed_streams), + settings_count, + settings_ptr, + settings_storage, ) - client.http2_stream_manager_opts = opts - client.http2_stream_manager = aws_http2_stream_manager_new(cs.allocator, Ref(opts)) - client.http2_stream_manager == C_NULL && aws_throw_error() - end - - finalizer(client) do x - if x.connection_manager != C_NULL - aws_http_connection_manager_release(x.connection_manager) - x.connection_manager = C_NULL - end - if x.http2_stream_manager != C_NULL - aws_http2_stream_manager_release(x.http2_stream_manager) - x.http2_stream_manager = C_NULL - end - if x.proxy_strategy != C_NULL - aws_http_proxy_strategy_release(x.proxy_strategy) - x.proxy_strategy = C_NULL - end - if x.retry_strategy != C_NULL - aws_retry_strategy_release(x.retry_strategy) - x.retry_strategy = C_NULL - end end return client end -#TODO: this should probably be a LRU cache to not grow indefinitely +# ─── Client cache ─── + +const _CLIENT_CACHE_MAX = let val = get(ENV, "HTTP_CLIENT_CACHE_MAX", "64") + parsed = tryparse(Int, val) + parsed === nothing || parsed < 1 ? 64 : parsed +end + struct Clients lock::ReentrantLock clients::Dict{ClientSettings, Client} + order::Vector{ClientSettings} + max_clients::Int end -Clients() = Clients(ReentrantLock(), Dict{ClientSettings, Client}()) +Clients(max_clients::Int=_CLIENT_CACHE_MAX) = + Clients(ReentrantLock(), Dict{ClientSettings, Client}(), ClientSettings[], max_clients) struct Pool clients::Clients @@ -423,23 +436,28 @@ const CLIENTS = Clients() function getclient(key::ClientSettings, clients::Clients=CLIENTS) Base.@lock clients.lock begin if haskey(clients.clients, key) + idx = findfirst(==(key), clients.order) + idx !== nothing && deleteat!(clients.order, idx) + push!(clients.order, key) return clients.clients[key] - else - client = Client(key) - clients.clients[key] = client - return client end + client = Client(key) + clients.clients[key] = client + push!(clients.order, key) + if length(clients.order) > clients.max_clients + evict = popfirst!(clients.order) + delete!(clients.clients, evict) + end + return client end end function manager_metrics(client::Client) - metrics = Ref{aws_http_manager_metrics}() - if client.http2_stream_manager != C_NULL - aws_http2_stream_manager_fetch_metrics(client.http2_stream_manager, metrics) + if client.http2_stream_manager !== nothing + return AwsHTTP.http2_stream_manager_fetch_metrics(client.http2_stream_manager) else - aws_http_connection_manager_fetch_metrics(client.connection_manager, metrics) + return AwsHTTP.http_connection_manager_fetch_metrics(client.connection_manager) end - return metrics[] end getclient(key::ClientSettings, pool::Pool) = getclient(key, pool.clients) @@ -447,7 +465,10 @@ getclient(key::ClientSettings, pool::Pool) = getclient(key, pool.clients) function close_all_clients!(clients::Clients=CLIENTS) Base.@lock clients.lock begin for client in values(clients.clients) - finalize(client) + close(client.connection_manager) + if client.http2_stream_manager !== nothing + close(client.http2_stream_manager) + end end empty!(clients.clients) end diff --git a/src/client/connection.jl b/src/client/connection.jl index 9059a51a..0daa6659 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -1,39 +1,3 @@ -const on_setup = Ref{Ptr{Cvoid}}(C_NULL) -const on_change_settings_complete = Ref{Ptr{Cvoid}}(C_NULL) -const on_ping_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_setup(conn, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code == AWS_IO_DNS_INVALID_NAME# || error_code == AWS_IO_TLS_ERROR_NEGOTIATION_FAILURE - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - elseif error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, conn) - end - return -end - -function c_on_change_settings_complete(conn, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, nothing) - end - return -end - -function c_on_ping_complete(conn, round_trip_time_ns, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, round_trip_time_ns) - end - return -end - function _client_url(client::Client) host = client.settings.host port = client.settings.port @@ -41,43 +5,36 @@ function _client_url(client::Client) end function with_connection(f::Function, client::Client; context=nothing) - if context === nothing - fut = Future{Ptr{aws_http_connection}}() - GC.@preserve fut begin - aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) - connection = try - wait(fut) - catch e - throw(ConnectError(_client_url(client), e)) + start_time = context !== nothing ? time() : 0.0 + ch = Channel{Any}(1) + AwsHTTP.http_connection_manager_acquire_connection( + client.connection_manager; + callback = (conn, error_code, _) -> begin + if error_code != AwsHTTP.OP_SUCCESS + ec = AwsIO.last_error() + put!(ch, CapturedException(aws_error(ec), Base.backtrace())) + else + put!(ch, conn) end end - try - return f(connection) - finally - aws_http_connection_manager_release_connection(client.connection_manager, connection) - end - end - start_time = time() - fut = Future{Ptr{aws_http_connection}}() - GC.@preserve fut begin - aws_http_connection_manager_acquire_connection(client.connection_manager, on_setup[], pointer_from_objref(fut)) - connection = try - wait(fut) - catch e - throw(ConnectError(_client_url(client), e)) - end + ) + result = take!(ch) + connection = if result isa Exception + throw(ConnectError(_client_url(client), result)) + else + result end try return f(connection) finally - aws_http_connection_manager_release_connection(client.connection_manager, connection) - _record_layer!(context, :connectionlayer, start_time) + AwsHTTP.http_connection_manager_release_connection(client.connection_manager, connection) + context !== nothing && _record_layer!(context, :connectionlayer, start_time) end end -function _ensure_http2_connection(conn::Ptr{aws_http_connection}) - conn == C_NULL && throw(ArgumentError("HTTP/2 connection is null")) - aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 || throw(ArgumentError("HTTP/2 connection required")) +function _ensure_http2_connection(conn) + conn === nothing && throw(ArgumentError("HTTP/2 connection is null")) + AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 || throw(ArgumentError("HTTP/2 connection required")) return conn end @@ -87,100 +44,101 @@ function _with_http2_connection(f::Function, client::Client) end end -function http2_ping(conn::Ptr{aws_http_connection}; data=nothing) +function http2_ping(conn; data=nothing) _ensure_http2_connection(conn) fut = Future{UInt64}() - cursor_ref = Ref{aws_byte_cursor}() - cursor_ptr = C_NULL - bytes = nothing - if data !== nothing + opaque_data = if data !== nothing bytes = data isa AbstractString ? Vector{UInt8}(codeunits(data)) : Vector{UInt8}(data) - length(bytes) == AWS_HTTP2_PING_DATA_SIZE || throw(ArgumentError("PING data must be $(AWS_HTTP2_PING_DATA_SIZE) bytes")) - GC.@preserve bytes begin - cursor_ref[] = aws_byte_cursor_from_array(pointer(bytes), length(bytes)) - end - cursor_ptr = cursor_ref - end - GC.@preserve fut cursor_ref bytes begin - aws_http2_connection_ping(conn, cursor_ptr, on_ping_complete[], pointer_from_objref(fut)) != 0 && aws_throw_error() - return wait(fut) + length(bytes) == AwsHTTP.H2_PING_DATA_SIZE || throw(ArgumentError("PING data must be $(AwsHTTP.H2_PING_DATA_SIZE) bytes")) + bytes + else + zeros(UInt8, AwsHTTP.H2_PING_DATA_SIZE) end + AwsHTTP.h2_connection_send_ping!(conn, opaque_data; + on_completed = (rtt_ns, error_code, _) -> begin + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, rtt_ns) + end + end + ) + return wait(fut) end http2_ping(client::Client; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), client) function _settings_from_pairs(settings::AbstractVector{<:Pair}) - out = Vector{aws_http2_setting}(undef, length(settings)) + out = Vector{AwsHTTP.Http2Setting}(undef, length(settings)) for (i, (k, v)) in enumerate(settings) - id = k isa aws_http2_settings_id ? k : aws_http2_settings_id(k) - out[i] = aws_http2_setting(id, UInt32(v)) + id = k isa AwsHTTP.Http2SettingsId.T ? k : AwsHTTP.Http2SettingsId.T(k) + out[i] = AwsHTTP.Http2Setting(id, UInt32(v)) end return out end -function http2_change_settings(conn::Ptr{aws_http_connection}, settings::AbstractVector{aws_http2_setting}) +function http2_change_settings(conn, settings::Vector{AwsHTTP.Http2Setting}) _ensure_http2_connection(conn) fut = Future{Nothing}() - settings_ptr = isempty(settings) ? C_NULL : pointer(settings) - GC.@preserve settings fut begin - aws_http2_connection_change_settings(conn, settings_ptr, length(settings), on_change_settings_complete[], pointer_from_objref(fut)) != 0 && aws_throw_error() - wait(fut) - end + AwsHTTP.h2_connection_change_settings!(conn, settings; + on_completed = (error_code, _) -> begin + if error_code != 0 + notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) + else + notify(fut, nothing) + end + end + ) + wait(fut) return end -http2_change_settings(conn::Ptr{aws_http_connection}, settings::AbstractVector{<:Pair}) = +http2_change_settings(conn, settings::AbstractVector{<:Pair}) = http2_change_settings(conn, _settings_from_pairs(settings)) http2_change_settings(client::Client, settings) = _with_http2_connection(conn -> http2_change_settings(conn, settings), client) -function http2_local_settings(conn::Ptr{aws_http_connection}) +function http2_local_settings(conn) _ensure_http2_connection(conn) - settings = Vector{aws_http2_setting}(undef, AWS_HTTP2_SETTINGS_COUNT) - aws_http2_connection_get_local_settings(conn, pointer(settings)) - return settings + return AwsHTTP.h2_connection_get_local_settings(conn) end http2_local_settings(client::Client) = _with_http2_connection(http2_local_settings, client) -function http2_remote_settings(conn::Ptr{aws_http_connection}) +function http2_remote_settings(conn) _ensure_http2_connection(conn) - settings = Vector{aws_http2_setting}(undef, AWS_HTTP2_SETTINGS_COUNT) - aws_http2_connection_get_remote_settings(conn, pointer(settings)) - return settings + return AwsHTTP.h2_connection_get_remote_settings(conn) end http2_remote_settings(client::Client) = _with_http2_connection(http2_remote_settings, client) -function http2_send_goaway(conn::Ptr{aws_http_connection}, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) +function http2_send_goaway(conn, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) _ensure_http2_connection(conn) - cursor_ref = Ref{aws_byte_cursor}() - cursor_ptr = C_NULL - bytes = nothing - if debug_data !== nothing + dd = if debug_data !== nothing bytes = debug_data isa AbstractString ? Vector{UInt8}(codeunits(debug_data)) : Vector{UInt8}(debug_data) length(bytes) <= 16 * 1024 || throw(ArgumentError("debug_data must be <= 16KB")) - GC.@preserve bytes begin - cursor_ref[] = aws_byte_cursor_from_array(pointer(bytes), length(bytes)) - end - cursor_ptr = cursor_ref - end - GC.@preserve bytes cursor_ref begin - aws_http2_connection_send_goaway(conn, UInt32(http2_error), allow_more_streams, cursor_ptr) + bytes + else + UInt8[] end + AwsHTTP.h2_connection_send_goaway!(conn; + allow_more_streams=allow_more_streams, + error_code=UInt32(http2_error), + debug_data=dd + ) return end http2_send_goaway(client::Client, http2_error::Integer; allow_more_streams::Bool=true, debug_data=nothing) = _with_http2_connection(conn -> http2_send_goaway(conn, http2_error; allow_more_streams=allow_more_streams, debug_data=debug_data), client) -function http2_update_window(conn::Ptr{aws_http_connection}, increment::Integer) +function http2_update_window(conn, increment::Integer) _ensure_http2_connection(conn) increment < 0 && throw(ArgumentError("increment must be >= 0")) increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) - aws_http2_connection_update_window(conn, UInt32(increment)) + AwsHTTP.h2_connection_update_window!(conn, UInt32(increment)) return end @@ -188,23 +146,18 @@ http2_update_window(client::Client, increment::Integer) = _with_http2_connection(conn -> http2_update_window(conn, increment), client) -function _get_goaway(get_fn, conn::Ptr{aws_http_connection}) +function _get_goaway(get_fn, conn) _ensure_http2_connection(conn) - http2_error = Ref{UInt32}() - last_stream_id = Ref{UInt32}() - ret = get_fn(conn, http2_error, last_stream_id) - if ret == 0 - return (http2_error=http2_error[], last_stream_id=last_stream_id[]) - elseif ret == AWS_ERROR_HTTP_DATA_NOT_AVAILABLE - return nothing + sent_or_received, last_stream_id, error_code = get_fn(conn) + if sent_or_received + return (http2_error=error_code, last_stream_id=last_stream_id) else - aws_throw_error() + return nothing end end -http2_get_sent_goaway(conn::Ptr{aws_http_connection}) = _get_goaway(aws_http2_connection_get_sent_goaway, conn) -http2_get_received_goaway(conn::Ptr{aws_http_connection}) = _get_goaway(aws_http2_connection_get_received_goaway, conn) +http2_get_sent_goaway(conn) = _get_goaway(AwsHTTP.h2_connection_get_sent_goaway, conn) +http2_get_received_goaway(conn) = _get_goaway(AwsHTTP.h2_connection_get_received_goaway, conn) http2_get_sent_goaway(client::Client) = _with_http2_connection(http2_get_sent_goaway, client) - http2_get_received_goaway(client::Client) = _with_http2_connection(http2_get_received_goaway, client) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index 37a2fcec..e8125feb 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -38,7 +38,6 @@ end # main entrypoint for making an HTTP request # can provide method, url, headers, body, along with various keyword arguments function request(method, url, h=Header[], b=nothing; - allocator=default_aws_allocator(), headers=h, body=b, chunkedbody=nothing, @@ -96,9 +95,9 @@ function request(method, url, h=Header[], b=nothing; body isa Form ) headers = mkreqheaders(headers, copyheaders) - uri = parseuri(url, query, allocator) + uri = parseuri(url, query) proxy_kw = proxy_kwargs(proxy, scheme(uri)) - client_kw = (; allocator=allocator, kw...) + client_kw = (; kw...) if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) client_kw = merge(client_kw, (; max_connections=pool.max_connections)) end @@ -107,7 +106,7 @@ function request(method, url, h=Header[], b=nothing; end authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth - return with_redirect(allocator, method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body + return with_redirect(method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body reqclient = @something( client, pool === nothing ? @@ -118,7 +117,7 @@ function request(method, url, h=Header[], b=nothing; with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, retry_check=retry_check, retry_delays=retry_delays, retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref, context=context) do - resp = if reqclient.http2_stream_manager != C_NULL + resp = if reqclient.http2_stream_manager !== nothing path = resource(uri) with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; copyheaders=false, @@ -130,21 +129,18 @@ function request(method, url, h=Header[], b=nothing; ) do req req_ref[] = req if response_body isa AbstractVector{UInt8} - ref = Ref(1) - GC.@preserve ref begin - on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=context) - end + on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=context) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) else - with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout, allocator; context=context) + with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout; context=context) end end else with_connection(reqclient; context=context) do conn - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 path = resource(uri) with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; copyheaders=false, @@ -156,16 +152,13 @@ function request(method, url, h=Header[], b=nothing; ) do req req_ref[] = req if response_body isa AbstractVector{UInt8} - ref = Ref(1) - GC.@preserve ref begin - on_stream_response_body = BufferOnResponseBody(response_body, Base.unsafe_convert(Ptr{Int}, ref)) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=context) - end + on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=context) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) else - with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout, allocator; context=context) + with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout; context=context) end end end diff --git a/src/client/open.jl b/src/client/open.jl index e35c2caf..8a47a6d7 100644 --- a/src/client/open.jl +++ b/src/client/open.jl @@ -1,35 +1,21 @@ -function _open_stream(conn::Ptr{aws_http_connection}, req::Request, decompress, readtimeout, allocator) - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - stream = Stream{Ptr{aws_http_connection}}(allocator, decompress, http2, false) +function _open_stream(conn, req::Request, decompress, readtimeout) + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + stream = Stream{typeof(conn)}(decompress, http2, false) + stream.readtimeout = readtimeout stream.bufferstream = Base.BufferStream() stream.connection = conn stream.request = req - stream.response = resp = Response(0, nothing, nothing, http2, allocator) + stream.response = resp = Response(0, nothing, nothing, http2) resp.request = req - GC.@preserve stream begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - http2, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) - stream_ptr == C_NULL && aws_throw_error() - stream.ptr = stream_ptr - end - retain_stream!(stream) + opts = _make_request_options(stream, req; readtimeout=readtimeout) + aws_stream = AwsHTTP.http_connection_make_request(conn, opts) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + _activate_stream!(stream) return stream end function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; - allocator=default_aws_allocator(), headers=h, copyheaders::Bool=true, canonicalize_headers::Bool=false, @@ -65,7 +51,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; kw...) method_str = string(method) headers = mkreqheaders(headers, copyheaders) - uri = parseuri(url, query, allocator) + uri = parseuri(url, query) context = observelayers ? Dict{Symbol, Any}() : nothing context === nothing || _init_observations!(context) count = 0 @@ -74,7 +60,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; redirect_url = nothing resp = nothing proxy_kw = proxy_kwargs(proxy, scheme(uri)) - client_kw = (; allocator=allocator, kw...) + client_kw = (; kw...) if pool isa Pool && pool.max_connections !== nothing && !haskey(client_kw, :max_connections) client_kw = merge(client_kw, (; max_connections=pool.max_connections)) end @@ -90,7 +76,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) )::Client resp = with_connection(reqclient; context=context) do conn - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 path = resource(uri) with_request(reqclient, method_str, path, headers, nothing, nothing, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; copyheaders=false, @@ -107,7 +93,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; !hasheader(req.headers, "upgrade") setheader(req.headers, "transfer-encoding", "chunked") end - stream = _open_stream(conn, req, decompress, readtimeout, allocator) + stream = _open_stream(conn, req, decompress, readtimeout) stream_start = context === nothing ? 0.0 : time() try if redirect && issafe(method_str) @@ -149,7 +135,7 @@ function open(f::Function, method::Union{String, Symbol}, url, h=Header[]; end olduri = uri newuri = resolvereference(makeuri(uri), redirect_url) - uri = parseuri(newuri, nothing, allocator) + uri = parseuri(newuri, nothing) method_str = newmethod(method_str, resp.status, redirect_method) if forwardheaders headers = filter(headers) do (header, _) diff --git a/src/client/redirects.jl b/src/client/redirects.jl index 0cc441c2..00015d22 100644 --- a/src/client/redirects.jl +++ b/src/client/redirects.jl @@ -38,7 +38,7 @@ function newmethod(request_method, response_status, redirect_method) return "GET" end -function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true; context=nothing) +function with_redirect(f, method, uri, headers=nothing, body=nothing, redirect::Bool=true, redirect_limit::Int=3, redirect_method=nothing, forwardheaders::Bool=true; context=nothing) if !redirect || redirect_limit == 0 # no redirecting return f(method, uri, headers, body) @@ -63,7 +63,7 @@ function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, # follow redirect olduri = uri newuri = resolvereference(makeuri(uri), location) - uri = parseuri(newuri, nothing, allocator) + uri = parseuri(newuri, nothing) method = newmethod(method, resp.status, redirect_method) body = method == "GET" ? nothing : body if forwardheaders @@ -86,3 +86,6 @@ function with_redirect(f, allocator, method, uri, headers=nothing, body=nothing, end @assert false "Unreachable!" end + +# compatibility: old callers that pass allocator as first arg +with_redirect(f, _allocator, method, uri, args...; kw...) = with_redirect(f, method, uri, args...; kw...) diff --git a/src/client/request.jl b/src/client/request.jl index 1c1aa7a9..334e16ca 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -49,7 +49,7 @@ function with_request( # create request mutable_headers = (headers isa AbstractVector{<:Pair} && !copyheaders) ? headers : nothing req_headers = mkreqheaders(headers, copyheaders) - req = Request(method, path, req_headers, nothing, http2, client.settings.allocator; context=context) + req = Request(method, path, req_headers, nothing, http2; context=context) # add headers to request h = req.headers if http2 diff --git a/src/client/retry.jl b/src/client/retry.jl index 83df096a..1fea7958 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -24,30 +24,6 @@ isrecoverable(ex::Sockets.DNSError) = (ex.code == Base.UV_EAI_AGAIN) isrecoverable(::AWSError) = true isrecoverable(::Exception) = false -const on_acquired = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_acquired(retry_strategy, error_code, retry_token, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - else - notify(fut, retry_token) - end - return -end - -const retry_ready = Ref{Ptr{Cvoid}}(C_NULL) - -function c_retry_ready(token, error_code::Cint, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - else - notify(fut, token) - end - return -end - function _default_retryable(method, err, retryable_body::Bool, retry_non_idempotent::Bool) retryable_body || return false method === nothing && return false @@ -75,16 +51,16 @@ function _retry_error_type(err) if err isa StatusError status = err.status if status == 429 - return AWS_RETRY_ERROR_TYPE_THROTTLING + return AwsIO.RetryErrorType.THROTTLING elseif 500 <= status < 600 - return AWS_RETRY_ERROR_TYPE_SERVER_ERROR + return AwsIO.RetryErrorType.SERVER_ERROR elseif 400 <= status < 500 - return AWS_RETRY_ERROR_TYPE_CLIENT_ERROR + return AwsIO.RetryErrorType.CLIENT_ERROR else - return AWS_RETRY_ERROR_TYPE_TRANSIENT + return AwsIO.RetryErrorType.TRANSIENT end end - return AWS_RETRY_ERROR_TYPE_TRANSIENT + return AwsIO.RetryErrorType.TRANSIENT end function _set_nretries!(x, nretries::Int) @@ -114,42 +90,22 @@ function with_retry_token( req_ref=nothing, context=nothing, ) - retry_token = Ptr{aws_retry_token}(C_NULL) + retry_token = nothing partition = client.settings.retry_partition - partition_ref = Ref{aws_byte_cursor}() - partition_ptr = C_NULL - if partition !== nothing - GC.@preserve partition begin - partition_ref[] = aws_byte_cursor_from_c_str(partition) - end - partition_ptr = partition_ref - end - use_retry_strategy = retry_delays === nothing && partition !== nothing && client.retry_strategy != C_NULL + use_retry_strategy = retry_delays === nothing && partition !== nothing && client.retry_strategy !== nothing # If max_retries is 0, we don't need to bother with any retrying max_retries = client.settings.max_retries if max_retries == 0 - if context === nothing - try - return f() - catch e - if logerrors - url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) - @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=url logtag=logtag - end - rethrow() - end - end - start_time = time() + start_time = context !== nothing ? time() : 0.0 try return f() catch e if logerrors - url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) - @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=url logtag=logtag + @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=uri logtag=logtag end rethrow() finally - _record_layer!(context, :retrylayer, start_time) + context !== nothing && _record_layer!(context, :retrylayer, start_time) end end retry_check_fn = retry_check === nothing ? nothing : retry_check @@ -162,10 +118,10 @@ function with_retry_token( ret = f() context === nothing || _record_layer!(context, :retrylayer, attempt_start) _set_nretries!(ret, nretries) - if retry_token != C_NULL - aws_retry_token_record_success(retry_token) != 0 && aws_throw_error() - aws_retry_token_release(retry_token) - retry_token = C_NULL + if retry_token !== nothing + AwsIO.retry_token_record_success(retry_token) + AwsIO.retry_token_release!(retry_token) + retry_token = nothing end return ret catch e @@ -178,8 +134,7 @@ function with_retry_token( end if logerrors log_err = err isa DontRetry ? err.error : err - url = uri === nothing ? nothing : (uri isa aws_uri ? makeuri(uri) : uri) - @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=url logtag=logtag + @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=uri logtag=logtag end if err isa DontRetry if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing @@ -188,17 +143,17 @@ function with_retry_token( end err = err.error _set_nretries!(err, nretries) - if retry_token != C_NULL - aws_retry_token_release(retry_token) - retry_token = C_NULL + if retry_token !== nothing + AwsIO.retry_token_release!(retry_token) + retry_token = nothing end throw(err) end if nretries >= max_retries _set_nretries!(err, nretries) - if retry_token != C_NULL - aws_retry_token_release(retry_token) - retry_token = C_NULL + if retry_token !== nothing + AwsIO.retry_token_release!(retry_token) + retry_token = nothing end throw(err) end @@ -217,33 +172,50 @@ function with_retry_token( end if !retry _set_nretries!(err, nretries) - if retry_token != C_NULL - aws_retry_token_release(retry_token) - retry_token = C_NULL + if retry_token !== nothing + AwsIO.retry_token_release!(retry_token) + retry_token = nothing end throw(err) end if use_retry_strategy try - if retry_token == C_NULL - fut = Future{Ptr{aws_retry_token}}() - GC.@preserve fut begin - rc = aws_retry_strategy_acquire_retry_token(client.retry_strategy, partition_ptr, on_acquired[], pointer_from_objref(fut), UInt64(client.settings.retry_timeout_ms)) - rc != 0 && aws_throw_error() - retry_token = wait(fut) - end - end - fut = Future{Ptr{aws_retry_token}}() - error_type = _retry_error_type(err) - GC.@preserve fut begin - rc = aws_retry_strategy_schedule_retry(retry_token, error_type, retry_ready[], pointer_from_objref(fut)) - rc != 0 && aws_throw_error() + if retry_token === nothing + fut = Future{Any}() + AwsIO.retry_strategy_acquire_token!( + client.retry_strategy, + partition, + (token, error_code, _) -> begin + if error_code != 0 + notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) + else + notify(fut, token) + end + end, + nothing, + UInt64(client.settings.retry_timeout_ms) + ) retry_token = wait(fut) end + fut = Future{Any}() + error_type = _retry_error_type(err) + AwsIO.retry_token_schedule_retry( + retry_token, + error_type, + (token, error_code, _) -> begin + if error_code != 0 + notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) + else + notify(fut, token) + end + end, + nothing + ) + retry_token = wait(fut) catch - if retry_token != C_NULL - aws_retry_token_release(retry_token) - retry_token = C_NULL + if retry_token !== nothing + AwsIO.retry_token_release!(retry_token) + retry_token = nothing end _set_nretries!(err, nretries) throw(err) diff --git a/src/client/stream.jl b/src/client/stream.jl index 71492dd1..994780a6 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -1,138 +1,12 @@ export Stream, closebody, isaborted, readall!, setstatus -const on_response_headers = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_response_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if header_block == AWS_HTTP_HEADER_BLOCK_TRAILING - trailers = stream.response.trailers - if trailers === nothing - trailers = Headers(stream.response.allocator) - stream.response.trailers = trailers - end - addheaders(trailers, header_array, num_headers) - else - headers = stream.response.headers - addheaders(headers, header_array, num_headers) - end - return Cint(0) -end - writebuf(body, maxsize=length(body) == 0 ? typemax(Int64) : length(body)) = Base.GenericIOBuffer{AbstractVector{UInt8}}(body, true, true, true, false, maxsize) -function aws_http2_stream_add_trailing_headers(http2_stream::Ptr{aws_http_stream}, trailing_headers::Ptr{aws_http_headers}) - return ccall((:aws_http2_stream_add_trailing_headers, LibAwsHTTPFork.libaws_c_http_jq), - Cint, (Ptr{aws_http_stream}, Ptr{aws_http_headers}), http2_stream, trailing_headers) -end - -const on_response_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_response_header_block_done(aws_stream_ptr, header_block, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if aws_http_stream_get_incoming_response_status(aws_stream_ptr, FieldRef(stream, :status)) != 0 - return aws_raise_error(aws_last_error()) - end - stream.response.status = stream.status - # if this is the end of the main header block, prepare our response body to be written to, otherwise return - if header_block != AWS_HTTP_HEADER_BLOCK_MAIN - return Cint(0) - end - if stream.decompress !== false - val = getheader(stream.response.headers, "content-encoding") - stream.decompress = val !== nothing && val == "gzip" - end - notify(stream.headers_ready) - return Cint(0) -end - -const on_response_body = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_response_body(aws_stream_ptr, data::Ptr{aws_byte_cursor}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - bc = unsafe_load(data) - stream.response.metrics.response_body_length += bc.len - if stream.decompress - if stream.gzipstream === nothing - stream.bufferstream = b = Base.BufferStream() - stream.gzipstream = g = CodecZlib.GzipDecompressorStream(b) - unsafe_write(g, bc.ptr, bc.len) - else - unsafe_write(stream.gzipstream, bc.ptr, bc.len) - end - else - if stream.bufferstream === nothing - stream.bufferstream = b = Base.BufferStream() - unsafe_write(b, bc.ptr, bc.len) - else - unsafe_write(stream.bufferstream, bc.ptr, bc.len) - end - end - return Cint(0) -end - -const on_metrics = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_metrics(aws_stream_ptr, metrics::Ptr{aws_http_stream_metrics}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - m = unsafe_load(metrics) - if m.send_start_timestamp_ns != -1 - stream.response.metrics.stream_metrics = m - end - return -end - -const on_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_complete(aws_stream_ptr, error_code, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if stream.gzipstream !== nothing - close(stream.gzipstream) - end - if stream.bufferstream !== nothing - close(stream.bufferstream) - end - if error_code != 0 - if error_code == AWS_ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT && stream.readtimeout > 0 - notify(stream.fut, TimeoutError(stream.readtimeout)) - else - notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) - end - else - notify(stream.fut, nothing) - end - notify(stream.headers_ready) - release_stream!(stream) - return -end - -const on_destroy = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_destroy(stream) - return -end - -const on_stream_acquired = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_stream_acquired(aws_stream_ptr, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, aws_stream_ptr) - end - return -end - -if !@isdefined aws_websocket_server_upgrade_options - const aws_websocket_server_upgrade_options = Ptr{Cvoid} -end - mutable struct Stream{T} <: IO - allocator::Ptr{aws_allocator} decompress::Union{Nothing, Bool} http2::Bool server_side::Bool - status::Cint # used as a ref + status::Int fut::Future{Nothing} chunk::Union{Nothing, InputStream} final_chunk_written::Bool @@ -150,19 +24,11 @@ mutable struct Stream{T} <: IO on_complete::Union{Nothing, Function} released::Bool # remaining fields are initially undefined - ptr::Ptr{aws_http_stream} - connection::T # Connection{F, S} (in servers.jl) - request_options::aws_http_make_request_options + aws_stream::Any # H1Stream or H2Stream from AwsHTTP + connection::T response::Response - method::aws_byte_cursor - path::aws_byte_cursor - request_handler_options::aws_http_request_handler_options request::Request - http2_stream_write_data_options::aws_http2_stream_write_data_options - chunk_options::aws_http1_chunk_options - websocket_options::aws_websocket_server_upgrade_options - Stream{T}(allocator, decompress, http2, server_side::Bool=false) where {T} = new{T}( - allocator, + Stream{T}(decompress, http2, server_side::Bool=false) where {T} = new{T}( decompress, http2, server_side, @@ -186,46 +52,18 @@ mutable struct Stream{T} <: IO ) end -Base.hash(s::Stream, h::UInt) = hash(s.ptr, h) - -getrequest(s::Stream) = s.request - -const ACTIVE_STREAMS_LOCK = ReentrantLock() -const ACTIVE_STREAMS = IdDict{Stream, Bool}() - -function retain_stream!(s::Stream) - lock(ACTIVE_STREAMS_LOCK) - try - ACTIVE_STREAMS[s] = true - finally - unlock(ACTIVE_STREAMS_LOCK) - end - return -end +# compatibility: 4-arg version for callers that still pass allocator +Stream{T}(allocator, decompress, http2, server_side::Bool=false) where {T} = + Stream{T}(decompress, http2, server_side) -function release_stream!(s::Stream) - lock(ACTIVE_STREAMS_LOCK) - try - pop!(ACTIVE_STREAMS, s, nothing) - finally - unlock(ACTIVE_STREAMS_LOCK) - end - return -end +Base.hash(s::Stream, h::UInt) = hash(objectid(s), h) -function release_stream_ptr!(s::Stream) - if isdefined(s, :ptr) && s.ptr != C_NULL && !s.released - aws_http_stream_release(s.ptr) - s.released = true - s.ptr = Ptr{aws_http_stream}(C_NULL) - end - return -end +getrequest(s::Stream) = s.request function _with_http2_connection(f::Function, stream::Stream) - stream.ptr == C_NULL && throw(ArgumentError("HTTP stream is not initialized")) - conn = aws_http_stream_get_connection(stream.ptr) - return f(_ensure_http2_connection(conn)) + !isdefined(stream, :aws_stream) && throw(ArgumentError("HTTP stream is not initialized")) + conn = stream.aws_stream.owning_connection + return f(conn) end http2_ping(stream::Stream; data=nothing) = _with_http2_connection(conn -> http2_ping(conn; data=data), stream) @@ -240,25 +78,14 @@ http2_update_window(stream::Stream, increment::Integer) = _with_http2_connection(conn -> http2_update_window(conn, increment), stream) function update_window(stream::Stream, increment::Integer) - stream.ptr == C_NULL && throw(ArgumentError("HTTP stream is not initialized")) + !isdefined(stream, :aws_stream) && throw(ArgumentError("HTTP stream is not initialized")) increment < 0 && throw(ArgumentError("increment must be >= 0")) if stream.http2 increment > HTTP2_MAX_WINDOW_SIZE && throw(ArgumentError("increment must be <= $(HTTP2_MAX_WINDOW_SIZE)")) + AwsHTTP.h2_stream_update_window!(stream.aws_stream, UInt32(increment)) else - increment > typemax(Csize_t) && throw(ArgumentError("increment must be <= $(typemax(Csize_t))")) - end - aws_http_stream_update_window(stream.ptr, Csize_t(increment)) - return -end - -const on_stream_write_on_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_stream_write_on_complete(aws_stream_ptr, error_code, fut_ptr) - fut = unsafe_pointer_to_objref(fut_ptr) - if error_code != 0 - notify(fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(fut, nothing) + increment > typemax(UInt64) && throw(ArgumentError("increment too large")) + AwsHTTP.http_stream_update_window(stream.aws_stream, UInt64(increment)) end return end @@ -269,51 +96,138 @@ function writechunk(s::Stream, chunk::RequestBodyTypes) isdefined(s.response, :request) && s.response.request.method in ("POST", "PUT", "PATCH")) "write is only allowed for POST, PUT, and PATCH requests" end - s.chunk = InputStream(s.allocator, chunk) + s.chunk = InputStream() + is = s.chunk + if chunk isa AbstractVector{UInt8} + is.bodyref = chunk + is.bodylen = length(chunk) + elseif chunk isa AbstractString + is.bodyref = chunk + is.bodylen = sizeof(chunk) + else + is.bodyref = chunk + is.bodylen = nbytes(chunk) === nothing ? 0 : nbytes(chunk) + end fut = Future{Nothing}() if s.http2 - s.http2_stream_write_data_options = aws_http2_stream_write_data_options( - s.chunk.ptr, - chunk == "", - on_stream_write_on_complete[], - pointer_from_objref(fut) - ) - aws_http2_stream_write_data(s.ptr, FieldRef(s, :http2_stream_write_data_options)) != 0 && aws_throw_error() + data = if chunk isa AbstractString + Vector{UInt8}(codeunits(chunk)) + elseif chunk isa AbstractVector{UInt8} + chunk + else + UInt8[] + end + is_final = isempty(data) + AwsHTTP.h2_stream_write_data!(s.aws_stream, data; + end_stream=is_final, + on_complete=(err, ud) -> begin + if err != 0 + notify(fut, CapturedException(aws_error(err), Base.backtrace())) + else + notify(fut, nothing) + end + end, + ) != 0 && aws_throw_error() else - s.chunk_options = aws_http1_chunk_options( - s.chunk.ptr, - s.chunk.bodylen, - C_NULL, - 0, - on_stream_write_on_complete[], - pointer_from_objref(fut) + data = if chunk isa AbstractString + IOBuffer(codeunits(chunk)) + elseif chunk isa AbstractVector{UInt8} + IOBuffer(chunk) + else + IOBuffer(UInt8[]) + end + h1chunk = AwsHTTP.h1_chunk_new(data, is.bodylen; + on_complete=(stream, err, ud) -> begin + if err != 0 + notify(fut, CapturedException(aws_error(err), Base.backtrace())) + else + notify(fut, nothing) + end + end, ) - aws_http1_stream_write_chunk(s.ptr, FieldRef(s, :chunk_options)) != 0 && aws_throw_error() + AwsHTTP.h1_stream_write_chunk!(s.aws_stream, h1chunk) != 0 && aws_throw_error() + _h1_flush_outgoing!(s) end wait(fut) if isdefined(s, :response) && s.response !== nothing if s.server_side - s.response.metrics.response_body_length += s.chunk.bodylen + s.response.metrics.response_body_length += is.bodylen else - s.response.metrics.request_body_length += s.chunk.bodylen + s.response.metrics.request_body_length += is.bodylen end end - return s.chunk.bodylen + return is.bodylen end function _ensure_response!(s::Stream) if !isdefined(s, :response) || s.response === nothing - s.response = Response(200, nothing, nothing, s.http2, s.allocator) + s.response = Response(200, nothing, nothing, s.http2) end return s.response end +# Drive H1 outgoing encoder and send encoded bytes through the channel pipeline. +# Must be called after operations that produce outgoing data (send_response, write_chunk, activate). +# For H1 only; H2 encoding is handled differently. +function _h1_flush_outgoing!(s::Stream) + !isdefined(s, :aws_stream) && return + h1conn = s.aws_stream.owning_connection + slot = h1conn.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + if !AwsIO.channel_thread_is_callers_thread(channel) + fut = Future{Nothing}() + task = AwsIO.ChannelTask((task, ctx, status) -> begin + status == AwsIO.TaskStatus.RUN_READY || return notify(fut, nothing) + try + _h1_flush_outgoing!(s) + notify(fut, nothing) + catch e + notify(fut, CapturedException(e, catch_backtrace())) + end + return nothing + end, nothing, "http_h1_flush_outgoing") + AwsIO.channel_schedule_task_now!(channel, task) + wait(fut) + return + end + while true + status, encoded = AwsHTTP.h1_connection_encode_outgoing!(h1conn) + status != AwsHTTP.OP_SUCCESS && throw(AWSError("H1 encoding failed")) + isempty(encoded) && break + msg = AwsIO.IoMessage(length(encoded)) + buf = msg.message_data + @inbounds for i in 1:length(encoded) + buf.mem[i] = encoded[i] + end + buf.len = Csize_t(length(encoded)) + result = AwsIO.channel_slot_send_message(slot, msg, AwsIO.ChannelDirection.WRITE) + result isa AwsIO.ErrorResult && throw(AWSError("channel slot send failed")) + end + return +end + function _send_response!(s::Stream) if s.response_started return s.response end resp = _ensure_response!(s) - aws_http_stream_send_response(s.ptr, resp.ptr) != 0 && aws_throw_error() + msg = getfield(resp, :msg) + if s.http2 && AwsHTTP.http_message_get_protocol_version(msg) != AwsHTTP.HttpVersion.HTTP_2 + converted = AwsHTTP.http2_message_new_from_http1(msg) + converted === nothing && aws_throw_error() + setfield!(resp, :msg, converted) + msg = converted + end + if s.http2 + # H2 sends response via H2Stream API + conn = s.aws_stream.owning_connection + AwsHTTP.h2_stream_send_response!(s.aws_stream, conn, msg) != 0 && aws_throw_error() + else + AwsHTTP.h1_stream_send_response!(s.aws_stream, msg) != 0 && aws_throw_error() + _h1_flush_outgoing!(s) + end s.response_started = true return resp end @@ -325,9 +239,12 @@ function _server_startwrite(s::Stream) resp = _ensure_response!(s) if s.request.method == "HEAD" s.ignore_writes = true - setinputstream!(resp, nothing) + _head_response!(resp) end if s.http2 + if !s.ignore_writes && resp.inputstream === nothing && hasheader(resp.headers, "content-length") + removeheader(resp.headers, "content-length") + end if !s.response_started _send_response!(s) end @@ -362,7 +279,7 @@ function _server_closewrite(s::Stream) return end if resp.trailers !== nothing - aws_http2_stream_add_trailing_headers(s.ptr, resp.trailers.ptr) != 0 && aws_throw_error() + AwsHTTP.h2_stream_add_trailing_headers!(s.aws_stream, resp.trailers.hdrs) != 0 && aws_throw_error() end writechunk(s, "") s.final_chunk_written = true @@ -386,7 +303,7 @@ function _server_closewrite(s::Stream) return end if resp.trailers !== nothing - aws_http1_stream_add_chunked_trailer(s.ptr, resp.trailers.ptr) != 0 && aws_throw_error() + AwsHTTP.h1_stream_add_chunked_trailer!(s.aws_stream, resp.trailers.hdrs) != 0 && aws_throw_error() end writechunk(s, "") s.final_chunk_written = true @@ -399,8 +316,17 @@ function _activate_stream!(s::Stream) return end if !s.activated - aws_http_stream_activate(s.ptr) != 0 && aws_throw_error() - s.activated = true + if s.http2 + conn = s.aws_stream.owning_connection + status, _ = AwsHTTP.h2_stream_activate!(s.aws_stream, conn) + status != 0 && aws_throw_error() + s.activated = true + AwsHTTP._h2_connection_flush_outgoing!(conn) + else + AwsHTTP.h1_stream_activate!(s.aws_stream) != 0 && aws_throw_error() + s.activated = true + _h1_flush_outgoing!(s) + end end return end @@ -605,7 +531,7 @@ function closeread(s::Stream) rethrow() end finally - release_stream_ptr!(s) + s.released = true end return s.response end @@ -648,7 +574,7 @@ function setheaderifabsent(s::Stream, k, v) end function addtrailer(s::Stream, headers::Headers) - s.ptr == C_NULL && error("stream is not initialized") + !isdefined(s, :aws_stream) && error("stream is not initialized") if s.server_side resp = _ensure_response!(s) if resp.trailers === nothing @@ -661,325 +587,329 @@ function addtrailer(s::Stream, headers::Headers) return end if s.http2 - aws_http2_stream_add_trailing_headers(s.ptr, headers.ptr) != 0 && aws_throw_error() + AwsHTTP.h2_stream_add_trailing_headers!(s.aws_stream, headers.hdrs) != 0 && aws_throw_error() else - aws_http1_stream_add_chunked_trailer(s.ptr, headers.ptr) != 0 && aws_throw_error() + AwsHTTP.h1_stream_add_chunked_trailer!(s.aws_stream, headers.hdrs) != 0 && aws_throw_error() end return end function addtrailer(s::Stream, h::Pair) - trailers = Headers(s.allocator) + trailers = Headers() addheader(trailers, String(h.first), String(h.second)) return addtrailer(s, trailers) end function addtrailer(s::Stream, h::AbstractVector{<:Pair}) - trailers = Headers(s.allocator) + trailers = Headers() for (k, v) in h addheader(trailers, String(k), String(v)) end return addtrailer(s, trailers) end -function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout, allocator; context=nothing) - if context === nothing - stream = Stream{Nothing}(allocator, decompress, true, false) - stream.readtimeout = readtimeout - if on_stream_response_body !== nothing - stream.bufferstream = Base.BufferStream() - end - acquire_fut = Future{Ptr{aws_http_stream}}() - GC.@preserve stream acquire_fut begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - chunkedbody !== nothing, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream.response = resp = Response(0, nothing, nothing, true, allocator) - resp.metrics = RequestMetrics() - resp.request = req - resp.metrics.request_body_length = bodylen(req) - acquire_opts = aws_http2_stream_manager_acquire_stream_options( - on_stream_acquired[], - pointer_from_objref(acquire_fut), - FieldRef(stream, :request_options), - ) - aws_http2_stream_manager_acquire_stream(client.http2_stream_manager, Ref(acquire_opts)) - stream_ptr = wait(acquire_fut) - stream.ptr = stream_ptr - stream.activated = true - try - if chunkedbody !== nothing - foreach(chunk -> writechunk(stream, chunk), chunkedbody) - writechunk(stream, "") - end - if on_stream_response_body !== nothing - try - while !eof(stream.bufferstream) - on_stream_response_body(resp, _readavailable(stream.bufferstream)) - end - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) - end - catch e - rethrow(DontRetry(e)) - end - else - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) - end - if stream.bufferstream !== nothing - resp.body = _readavailable(stream.bufferstream) - else - resp.body = UInt8[] - end - end - return resp - finally - aws_http_stream_release(stream_ptr) - stream.released = true - stream.ptr = Ptr{aws_http_stream}(C_NULL) +# ─── Callback builders ─── +# These create the closure callbacks for AwsHTTP stream options. +# Each closure captures the HTTP.Stream and manipulates it directly +# when the AwsHTTP library fires the callback. + +function _on_response_headers(stream::Stream) + return (aws_stream, header_block, headers_vec, user_data) -> begin + if header_block == AwsHTTP.HttpHeaderBlock.TRAILING + trailers = stream.response.trailers + if trailers === nothing + trailers = Headers() + stream.response.trailers = trailers end + for h in headers_vec + addheader(trailers, h) + end + else + hdrs = stream.response.headers + for h in headers_vec + addheader(hdrs, h) + end + end + return AwsHTTP.OP_SUCCESS + end +end + +function _on_response_header_block_done(stream::Stream) + return (aws_stream, header_block, user_data) -> begin + stream.status = aws_stream.response_status + stream.response.status = stream.status + if header_block != AwsHTTP.HttpHeaderBlock.MAIN + return AwsHTTP.OP_SUCCESS + end + if stream.decompress !== false + val = getheader(stream.response.headers, "content-encoding") + stream.decompress = val !== nothing && val == "gzip" + end + notify(stream.headers_ready) + return AwsHTTP.OP_SUCCESS + end +end + +function _on_response_body(stream::Stream) + return (aws_stream, data::AbstractVector{UInt8}, user_data) -> begin + stream.response.metrics.response_body_length += length(data) + if stream.decompress + if stream.gzipstream === nothing + stream.bufferstream = b = Base.BufferStream() + stream.gzipstream = g = CodecZlib.GzipDecompressorStream(b) + write(g, data) + else + write(stream.gzipstream, data) + end + else + if stream.bufferstream === nothing + stream.bufferstream = b = Base.BufferStream() + write(b, data) + else + write(stream.bufferstream, data) + end + end + return AwsHTTP.OP_SUCCESS + end +end + +function _on_metrics(stream::Stream) + return (aws_stream, metrics, user_data) -> begin + if metrics.send_start_timestamp_ns != -1 + stream.response.metrics.stream_metrics = metrics end + return nothing end +end + +function _on_complete(stream::Stream) + return (aws_stream, error_code, user_data) -> begin + if stream.gzipstream !== nothing + close(stream.gzipstream) + end + if stream.bufferstream !== nothing + close(stream.bufferstream) + end + if error_code != 0 && !stream.http2 && isdefined(stream, :connection) && stream.connection !== nothing + AwsHTTP.http_connection_close(stream.connection) + end + if error_code != 0 + if error_code == AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT && stream.readtimeout > 0 + notify(stream.fut, TimeoutError(stream.readtimeout)) + else + notify(stream.fut, CapturedException(aws_error(error_code), Base.backtrace())) + end + else + notify(stream.fut, nothing) + end + notify(stream.headers_ready) + stream.released = true + return nothing + end +end + +function _make_request_options(stream::Stream, req::Request; chunkedbody=nothing, readtimeout=0) + msg = getfield(req, :msg) + return AwsHTTP.HttpMakeRequestOptions(; + request=msg, + on_response_headers=_on_response_headers(stream), + on_response_header_block_done=_on_response_header_block_done(stream), + on_response_body=_on_response_body(stream), + on_metrics=_on_metrics(stream), + on_complete=_on_complete(stream), + http2_use_manual_data_writes=(chunkedbody !== nothing), + response_first_byte_timeout_ms=UInt64(readtimeout * 1000), + ) +end + +# ─── with_stream_manager: H2 stream manager path ─── - start_time = time() - stream = Stream{Nothing}(allocator, decompress, true, false) +function with_stream_manager(client::Client, req::Request, chunkedbody, on_stream_response_body, decompress, readtimeout; context=nothing) + start_time = context !== nothing ? time() : 0.0 + stream = Stream{Nothing}(decompress, true, false) stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end - acquire_fut = Future{Ptr{aws_http_stream}}() - GC.@preserve stream acquire_fut begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - chunkedbody !== nothing, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream.response = resp = Response(0, nothing, nothing, true, allocator) - resp.metrics = RequestMetrics() - resp.request = req - resp.metrics.request_body_length = bodylen(req) - acquire_opts = aws_http2_stream_manager_acquire_stream_options( - on_stream_acquired[], - pointer_from_objref(acquire_fut), - FieldRef(stream, :request_options), - ) - aws_http2_stream_manager_acquire_stream(client.http2_stream_manager, Ref(acquire_opts)) - stream_ptr = wait(acquire_fut) - stream.ptr = stream_ptr - stream.activated = true - try - if chunkedbody !== nothing - foreach(chunk -> writechunk(stream, chunk), chunkedbody) - writechunk(stream, "") + stream.response = resp = Response(0, nothing, nothing, true) + resp.metrics = RequestMetrics() + resp.request = req + resp.metrics.request_body_length = bodylen(req) + request_options = _make_request_options(stream, req; chunkedbody=chunkedbody, readtimeout=readtimeout) + + # Acquire a connection from the H2 stream manager + acquire_ch = Base.Channel{Any}(1) + AwsHTTP.http2_stream_manager_acquire_stream(client.http2_stream_manager; + callback=(conn_or_nothing, error_code, ud) -> begin + if error_code != 0 || conn_or_nothing === nothing + put!(acquire_ch, error_code != 0 ? error_code : AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED) + else + put!(acquire_ch, conn_or_nothing) end - if on_stream_response_body !== nothing - try - while !eof(stream.bufferstream) - on_stream_response_body(resp, _readavailable(stream.bufferstream)) - end - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) - end - catch e - rethrow(DontRetry(e)) + end, + ) + acquired = take!(acquire_ch) + if acquired isa Integer + throw(CapturedException(aws_error(acquired), Base.backtrace())) + end + connection = acquired + + # Create stream on the acquired connection + aws_stream = AwsHTTP.http_connection_make_request(connection, request_options) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + + # Activate stream + _activate_stream!(stream) + + try + # Write chunked body if provided + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) end - else try wait(stream.fut) catch e e isa HTTPError && rethrow() throw(RequestError(req, e)) end - if stream.bufferstream !== nothing - resp.body = _readavailable(stream.bufferstream) - else - resp.body = UInt8[] - end + catch e + rethrow(DontRetry(e)) + end + else + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) + end + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] end - return resp - finally - aws_http_stream_release(stream_ptr) - stream.released = true - stream.ptr = Ptr{aws_http_stream}(C_NULL) + end + return resp + finally + stream.released = true + AwsHTTP.http2_stream_manager_release_stream(client.http2_stream_manager, connection) + if context !== nothing _record_layer!(context, :streamlayer, start_time) end end end -function with_stream(conn::Ptr{aws_http_connection}, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, allocator; context=nothing) - if context === nothing - stream = Stream{Nothing}(allocator, decompress, http2, false) - stream.readtimeout = readtimeout - if on_stream_response_body !== nothing - stream.bufferstream = Base.BufferStream() - end - GC.@preserve stream begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - http2 && chunkedbody !== nothing, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) - stream_ptr == C_NULL && aws_throw_error() - stream.ptr = stream_ptr - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - stream.response = resp = Response(0, nothing, nothing, http2, allocator) - resp.metrics = RequestMetrics() - resp.request = req - resp.metrics.request_body_length = bodylen(req) - try - aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() - # write chunked body if provided - if chunkedbody !== nothing - foreach(chunk -> writechunk(stream, chunk), chunkedbody) - # write final chunk - writechunk(stream, "") - end - if on_stream_response_body !== nothing - try - while !eof(stream.bufferstream) - on_stream_response_body(resp, _readavailable(stream.bufferstream)) - end - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) - end - catch e - rethrow(DontRetry(e)) - end - else - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) - end - if stream.bufferstream !== nothing - resp.body = _readavailable(stream.bufferstream) - else - resp.body = UInt8[] - end - end - return resp - finally - aws_http_stream_release(stream_ptr) - stream.released = true - stream.ptr = Ptr{aws_http_stream}(C_NULL) - end - end # GC.@preserve - end +# compatibility: 8-arg version for callers that still pass allocator +with_stream_manager(client, req, chunkedbody, on_stream_response_body, decompress, readtimeout, _allocator; context=nothing) = + with_stream_manager(client, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) + +# ─── with_stream: connection manager path ─── - start_time = time() - stream = Stream{Nothing}(allocator, decompress, http2, false) +function with_stream(conn, req::Request, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=nothing) + start_time = context !== nothing ? time() : 0.0 + stream = Stream{typeof(conn)}(decompress, http2, false) stream.readtimeout = readtimeout if on_stream_response_body !== nothing stream.bufferstream = Base.BufferStream() end - GC.@preserve stream begin - stream.request_options = aws_http_make_request_options( - 1, - req.ptr, - pointer_from_objref(stream), - on_response_headers[], - on_response_header_block_done[], - on_response_body[], - on_metrics[], - on_complete[], - on_destroy[], - http2 && chunkedbody !== nothing, # http2_use_manual_data_writes - readtimeout * 1000 # response_first_byte_timeout_ms - ) - stream_ptr = aws_http_connection_make_request(conn, FieldRef(stream, :request_options)) - stream_ptr == C_NULL && aws_throw_error() - stream.ptr = stream_ptr - http2 = aws_http_connection_get_version(conn) == AWS_HTTP_VERSION_2 - stream.response = resp = Response(0, nothing, nothing, http2, allocator) - resp.metrics = RequestMetrics() - resp.request = req - resp.metrics.request_body_length = bodylen(req) - try - aws_http_stream_activate(stream_ptr) != 0 && aws_throw_error() - # write chunked body if provided - if chunkedbody !== nothing - foreach(chunk -> writechunk(stream, chunk), chunkedbody) - # write final chunk - writechunk(stream, "") + stream.connection = conn + + request_options = _make_request_options(stream, req; + chunkedbody=(http2 ? chunkedbody : nothing), + readtimeout=readtimeout) + + aws_stream = AwsHTTP.http_connection_make_request(conn, request_options) + aws_stream === nothing && aws_throw_error() + stream.aws_stream = aws_stream + # Check actual connection version (may have been upgraded) + actual_http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + stream.http2 = actual_http2 + stream.response = resp = Response(0, nothing, nothing, actual_http2) + resp.metrics = RequestMetrics() + resp.request = req + resp.metrics.request_body_length = bodylen(req) + timeout_task = nothing + if readtimeout > 0 + timeout_task = errormonitor(Threads.@spawn begin + sleep(readtimeout) + (@atomic stream.fut.set) != 0 && return + if isdefined(stream, :connection) + conn = stream.connection + conn !== nothing && AwsHTTP.http_connection_close(conn) end - if on_stream_response_body !== nothing - try - while !eof(stream.bufferstream) - on_stream_response_body(resp, _readavailable(stream.bufferstream)) - end - try - wait(stream.fut) - catch e - e isa HTTPError && rethrow() - throw(RequestError(req, e)) + notify(stream.fut, TimeoutError(readtimeout)) + if isdefined(stream, :aws_stream) + if stream.http2 + AwsHTTP.h2_stream_cancel!(stream.aws_stream) + else + AwsHTTP.http_stream_cancel(stream.aws_stream) + if isdefined(stream, :connection) + conn = stream.connection + if conn !== nothing && conn.slot !== nothing && conn.slot.channel !== nothing + AwsIO.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) + elseif conn !== nothing + AwsHTTP.http_connection_close(conn) + end end - catch e - rethrow(DontRetry(e)) end - else + end + end) + end + + try + _activate_stream!(stream) + # Write chunked body if provided + if chunkedbody !== nothing + foreach(chunk -> writechunk(stream, chunk), chunkedbody) + writechunk(stream, "") + end + if on_stream_response_body !== nothing + try + while !eof(stream.bufferstream) + on_stream_response_body(resp, _readavailable(stream.bufferstream)) + end try wait(stream.fut) catch e e isa HTTPError && rethrow() throw(RequestError(req, e)) end - if stream.bufferstream !== nothing - resp.body = _readavailable(stream.bufferstream) - else - resp.body = UInt8[] - end + catch e + rethrow(DontRetry(e)) + end + else + try + wait(stream.fut) + catch e + e isa HTTPError && rethrow() + throw(RequestError(req, e)) end - return resp - finally - aws_http_stream_release(stream_ptr) - stream.released = true - stream.ptr = Ptr{aws_http_stream}(C_NULL) + if stream.bufferstream !== nothing + resp.body = _readavailable(stream.bufferstream) + else + resp.body = UInt8[] + end + end + return resp + finally + timeout_task = nothing + stream.released = true + if context !== nothing _record_layer!(context, :streamlayer, start_time) end - end # GC.@preserve + end end +# compatibility: 9-arg version for callers that still pass allocator +with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout, _allocator; context=nothing) = + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + # can be removed once https://github.com/JuliaLang/julia/pull/57211 is fully released function _readavailable(this::Base.BufferStream) bytes = lock(this.cond) do diff --git a/src/download.jl b/src/download.jl index bdd85d8c..4e3a40cb 100644 --- a/src/download.jl +++ b/src/download.jl @@ -57,7 +57,7 @@ Additional keyword arguments are forwarded to `HTTP.open`. """ function download(url::AbstractString, local_path=nothing, headers=Header[]; update_period=1, kw...) format_progress(x) = round(x, digits=4) - format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(round(Int, x)) + format_bytes(x) = !isfinite(x) ? "∞ B" : Base.format_bytes(round(Int, max(x, 0))) format_seconds(x) = "$(round(x; digits=2)) s" format_bytes_per_second(x) = format_bytes(x) * "/s" @@ -83,10 +83,12 @@ function download(url::AbstractString, local_path=nothing, headers=Header[]; upd function report_callback() prev_time = now() taken_time = (prev_time - start_time).value / 1000 - average_speed = downloaded_bytes / taken_time + average_speed = taken_time > 0 ? downloaded_bytes / taken_time : NaN remaining_bytes = total_bytes - downloaded_bytes - remaining_time = remaining_bytes / average_speed - completion_progress = downloaded_bytes / total_bytes + remaining_bytes = isfinite(remaining_bytes) && remaining_bytes < 0 ? 0 : remaining_bytes + remaining_time = average_speed > 0 ? remaining_bytes / average_speed : NaN + completion_progress = isfinite(total_bytes) && total_bytes > 0 ? downloaded_bytes / total_bytes : NaN + completion_progress = isfinite(completion_progress) ? clamp(completion_progress, 0, 1) : completion_progress @info("Downloading", source=url, dest=file, diff --git a/src/requestresponse.jl b/src/requestresponse.jl index ffbd2c8f..667f3fab 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -4,19 +4,19 @@ export Header, Headers, Message, Request, Response, canonicalizeheaders, canonicalizeheaders!, mkheaders # working with headers -headereq(a::String, b::String) = GC.@preserve a b aws_http_header_name_eq(aws_byte_cursor_from_c_str(a), aws_byte_cursor_from_c_str(b)) +headereq(a::String, b::String) = ascii_lc_isequal(a, b) -mutable struct Header - header::aws_http_header - Header() = new() - Header(header::aws_http_header) = new(header) +struct Header + header::AwsHTTP.HttpHeader end +Header() = Header(AwsHTTP.HttpHeader("", "")) +Header(name::AbstractString, value::AbstractString) = Header(AwsHTTP.HttpHeader(String(name), String(value))) function Base.getproperty(x::Header, s::Symbol) if s == :name - return str(getfield(x, :header).name) + return getfield(x, :header).name elseif s == :value - return str(getfield(x, :header).value) + return getfield(x, :header).value else return getfield(x, s) end @@ -25,32 +25,29 @@ end Base.show(io::IO, h::Header) = print_header(io, h) mutable struct Headers <: AbstractVector{Header} - const ptr::Ptr{aws_http_headers} - function Headers(allocator=default_aws_allocator()) - x = new(aws_http_headers_new(allocator)) - x.ptr == C_NULL && aws_throw_error() - return finalizer(_ -> aws_http_headers_release(x.ptr), x) + const hdrs::AwsHTTP.HttpHeaders + function Headers() + return new(AwsHTTP.http_headers_new()) end - # no finalizer in this constructor because whoever called aws_http_headers_new needs to do that - Headers(ptr::Ptr{aws_http_headers}) = new(ptr) + Headers(hdrs::AwsHTTP.HttpHeaders) = new(hdrs) end abstract type Message end -Base.size(h::Headers) = (Int(aws_http_headers_count(h.ptr)),) +Base.size(h::Headers) = (AwsHTTP.http_headers_count(h.hdrs),) function Base.getindex(h::Headers, i::Int) - header = Header() - aws_http_headers_get_index(h.ptr, i - 1, FieldRef(header, :header)) != 0 && aws_throw_error() - return header + hdr = AwsHTTP.http_headers_get_index(h.hdrs, i - 1) + hdr === nothing && throw(BoundsError(h, i)) + return Header(hdr) end -Base.Dict(h::Headers) = Dict(((h.name, h.value) for h in h)) +Base.Dict(h::Headers) = Dict(((h2.name, h2.value) for h2 in h)) -addheader(headers::Headers, h::Header) = aws_http_headers_add_header(headers.ptr, FieldRef(h, :header)) != 0 && aws_throw_error() -addheader(headers::Headers, k, v) = GC.@preserve k v aws_http_headers_add(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() -addheaders(headers::Headers, h::Vector{aws_http_header}) = GC.@preserve h aws_http_headers_add_array(headers.ptr, pointer(h), length(h)) != 0 && aws_throw_error() -addheaders(headers::Headers, h::Ptr{aws_http_header}, count::Integer) = aws_http_headers_add_array(headers.ptr, h, count) != 0 && aws_throw_error() +addheader(headers::Headers, h::Header) = AwsHTTP.http_headers_add_header(headers.hdrs, h.header) != 0 && aws_throw_error() +addheader(headers::Headers, h::AwsHTTP.HttpHeader) = AwsHTTP.http_headers_add_header(headers.hdrs, h) != 0 && aws_throw_error() +addheader(headers::Headers, k, v) = AwsHTTP.http_headers_add(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() +addheaders(headers::Headers, h::Vector{AwsHTTP.HttpHeader}) = AwsHTTP.http_headers_add_array(headers.hdrs, h) != 0 && aws_throw_error() function addheaders(headers::Headers, h::Vector{Pair{String, String}}) for (k, v) in h @@ -58,30 +55,28 @@ function addheaders(headers::Headers, h::Vector{Pair{String, String}}) end end -setheader(headers::Headers, k, v) = GC.@preserve k v aws_http_headers_set(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() -setscheme(headers::Headers, scheme) = GC.@preserve scheme aws_http2_headers_set_request_scheme(headers.ptr, aws_byte_cursor_from_c_str(scheme)) != 0 && aws_throw_error() -setauthority(headers::Headers, authority) = GC.@preserve authority aws_http2_headers_set_request_authority(headers.ptr, aws_byte_cursor_from_c_str(authority)) != 0 && aws_throw_error() +setheader(headers::Headers, k, v) = AwsHTTP.http_headers_set(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() +setscheme(headers::Headers, scheme) = AwsHTTP.http2_headers_set_request_scheme(headers.hdrs, String(scheme)) != 0 && aws_throw_error() +setauthority(headers::Headers, authority) = AwsHTTP.http2_headers_set_request_authority(headers.hdrs, String(authority)) != 0 && aws_throw_error() -#TODO: struct aws_string *aws_http_headers_get_all(const struct aws_http_headers *headers, struct aws_byte_cursor name); function getheader(headers::Headers, k) - out = Ref{aws_byte_cursor}() - GC.@preserve k out begin - aws_http_headers_get(headers.ptr, aws_byte_cursor_from_c_str(k), out) != 0 && return nothing - return str(out[]) - end + return AwsHTTP.http_headers_get(headers.hdrs, String(k)) end -hasheader(headers::Headers, k) = - GC.@preserve k aws_http_headers_has(headers.ptr, aws_byte_cursor_from_c_str(k)) +hasheader(headers::Headers, k) = AwsHTTP.http_headers_has(headers.hdrs, String(k)) -removeheader(headers::Headers, k) = - GC.@preserve k aws_http_headers_erase(headers.ptr, aws_byte_cursor_from_c_str(k)) != 0 && aws_throw_error() +removeheader(headers::Headers, k) = AwsHTTP.http_headers_erase(headers.hdrs, String(k)) != 0 && aws_throw_error() +removeheader(headers::Headers, k, v) = AwsHTTP.http_headers_erase_value(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() -removeheader(headers::Headers, k, v) = - GC.@preserve k v aws_http_headers_erase_value(headers.ptr, aws_byte_cursor_from_c_str(k), aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() +function Base.deleteat!(h::Headers, i::Int) + AwsHTTP.http_headers_erase_index(h.hdrs, i - 1) != 0 && aws_throw_error() + return h +end -Base.deleteat!(h::Headers, i::Int) = aws_http_headers_erase_index(h.ptr, i - 1) != 0 && aws_throw_error() -Base.empty!(h::Headers) = aws_http_headers_clear(h.ptr) != 0 && aws_throw_error() +function Base.empty!(h::Headers) + AwsHTTP.http_headers_clear(h.hdrs) + return h +end setheaderifabsent(headers, k, v) = !hasheader(headers, k) && setheader(headers, k, v) setheaderifabsent(m::Message, k, v) = setheaderifabsent(m.headers, k, v) @@ -240,14 +235,12 @@ end # request/response mutable struct InputStream - ptr::Ptr{aws_input_stream} bodyref::Any bodylen::Int64 - bodycursor::aws_byte_cursor - InputStream() = new() + InputStream() = new(nothing, 0) end -ischunked(is::InputStream) = is.ptr == C_NULL && is.bodyref !== nothing +ischunked(is::InputStream) = is.bodylen < 0 && is.bodyref !== nothing const RequestBodyTypes = Union{AbstractString, AbstractVector{UInt8}, IO, AbstractDict, NamedTuple, Form, Nothing} const DEFAULT_IO_CHUNK_SIZE = 64 * 1024 @@ -286,79 +279,55 @@ function _record_layer!(context::Dict{Symbol, Any}, name::Symbol, started::Float return end -function InputStream(allocator::Ptr{aws_allocator}, body) +function setinputstream!(m::Message, body) + AwsHTTP.http_message_set_body_stream(getfield(m, :msg), nothing) + m.inputstream = nothing + body === nothing && return is = InputStream() - if body !== nothing - if (body isa AbstractVector{UInt8}) || (body isa AbstractString) - is.bodyref = body - is.bodycursor = aws_byte_cursor(sizeof(body), pointer(body)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - is.ptr == C_NULL && aws_throw_error() - elseif body isa Union{AbstractDict, NamedTuple} - # hold a reference to the request body in order to gc-preserve it - is.bodyref = URIs.escapeuri(body) - is.bodycursor = aws_byte_cursor_from_c_str(is.bodyref) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - is.ptr == C_NULL && aws_throw_error() - elseif body isa IOStream - isopen(body) || throw(ArgumentError("request body IOStream is closed")) - is.bodyref = body - is.ptr = aws_input_stream_new_from_open_file(allocator, Libc.FILE(body)) - is.ptr == C_NULL && aws_throw_error() - elseif body isa Form - # we set the request.body to the Form bytes in order to gc-preserve them - is.bodyref = read(body) - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - is.ptr == C_NULL && aws_throw_error() - elseif body isa IO - # we set the request.body to the IO bytes in order to gc-preserve them - bytes = readavailable(body) - while !eof(body) - append!(bytes, readavailable(body)) - end - is.bodyref = bytes - is.bodycursor = aws_byte_cursor(sizeof(is.bodyref), pointer(is.bodyref)) - is.ptr = aws_input_stream_new_from_cursor(allocator, FieldRef(is, :bodycursor)) - is.ptr == C_NULL && aws_throw_error() - elseif Base.isiterable(typeof(body)) - # assume a chunked request body; any kind of iterable where elements are RequestBodyTypes - is.bodyref = body - else - throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, IO, or an iterable of those")) - end - if is.ptr != C_NULL - aws_input_stream_get_length(is.ptr, FieldRef(is, :bodylen)) != 0 && aws_throw_error() - if !(is.bodylen > 0) - aws_input_stream_release(is.ptr) - is.ptr = C_NULL - end + if (body isa AbstractVector{UInt8}) || (body isa AbstractString) + is.bodyref = body + is.bodylen = sizeof(body) + elseif body isa Union{AbstractDict, NamedTuple} + is.bodyref = URIs.escapeuri(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa IOStream + isopen(body) || throw(ArgumentError("request body IOStream is closed")) + is.bodyref = read(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa Form + is.bodyref = read(body) + is.bodylen = sizeof(is.bodyref) + elseif body isa IO + bytes = readavailable(body) + while !eof(body) + append!(bytes, readavailable(body)) end + is.bodyref = bytes + is.bodylen = sizeof(is.bodyref) + elseif Base.isiterable(typeof(body)) + # chunked request body; any kind of iterable where elements are RequestBodyTypes + is.bodyref = body + is.bodylen = -1 + else + throw(ArgumentError("request body must be a string, vector of UInt8, NamedTuple, AbstractDict, HTTP.Form, IO, or an iterable of those")) end - return finalizer(x -> x.ptr != C_NULL && aws_input_stream_release(x.ptr), is) -end - -function setinputstream!(msg::Message, body) - aws_http_message_set_body_stream(msg.ptr, C_NULL) - msg.inputstream = nothing - body === nothing && return - input_stream = InputStream(msg.allocator, body) - setfield!(msg, :inputstream, input_stream) - if input_stream.ptr != C_NULL - aws_http_message_set_body_stream(msg.ptr, input_stream.ptr) + setfield!(m, :inputstream, is) + if is.bodylen > 0 + # Wrap body in IOBuffer so the H1 encoder's readbytes! works + AwsHTTP.http_message_set_body_stream(getfield(m, :msg), IOBuffer(is.bodyref)) if body isa Union{AbstractDict, NamedTuple} - setheaderifabsent(msg.headers, "content-type", "application/x-www-form-urlencoded") + setheaderifabsent(m.headers, "content-type", "application/x-www-form-urlencoded") elseif body isa Form - setheaderifabsent(msg.headers, "content-type", content_type(body)) + setheaderifabsent(m.headers, "content-type", content_type(body)) end - setheader(msg.headers, "content-length", string(input_stream.bodylen)) + setheader(m.headers, "content-length", string(is.bodylen)) end return end + mutable struct Request <: Message - allocator::Ptr{aws_allocator} - ptr::Ptr{aws_http_message} + msg::AwsHTTP.HttpMessage inputstream::Union{Nothing, InputStream} # used for outgoing request body # only set in server-side request handlers body::Union{Nothing, Vector{UInt8}} @@ -368,37 +337,35 @@ mutable struct Request <: Message params::Union{Nothing, Dict{String, String}} cookies::Any # actually Union{Nothing, Vector{Cookie}} - function Request(method, path, headers=nothing, body=nothing, http2::Bool=false, allocator=default_aws_allocator(); context=nothing) - ptr = http2 ? - aws_http2_message_new_request(allocator) : - aws_http_message_new_request(allocator) - ptr == C_NULL && aws_throw_error() - try - GC.@preserve method aws_http_message_set_request_method(ptr, aws_byte_cursor_from_c_str(method)) != 0 && aws_throw_error() - GC.@preserve path aws_http_message_set_request_path(ptr, aws_byte_cursor_from_c_str(path)) != 0 && aws_throw_error() - request_headers = Headers(aws_http_message_get_headers(ptr)) - if headers !== nothing - for (k, v) in headers - addheader(request_headers, k, v) - end + function Request(method, path, headers=nothing, body=nothing, http2::Bool=false; context=nothing) + msg = http2 ? + AwsHTTP.http2_message_new_request() : + AwsHTTP.http_message_new_request() + AwsHTTP.http_message_set_request_method(msg, String(method)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_path(msg, String(path)) != 0 && aws_throw_error() + msg_headers = AwsHTTP.http_message_get_headers(msg) + if headers !== nothing + for (k, v) in headers + AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end - req = new(allocator, ptr) - req.body = nothing - req.inputstream = nothing - req.trailers = nothing - req.context = context === nothing ? Dict{Symbol, Any}() : context - req.route = nothing - req.params = nothing - req.cookies = nothing - body !== nothing && setinputstream!(req, body) - return finalizer(_ -> aws_http_message_release(ptr), req) - catch - aws_http_message_release(ptr) - rethrow() end + req = new(msg) + req.body = nothing + req.inputstream = nothing + req.trailers = nothing + req.context = context === nothing ? Dict{Symbol, Any}() : context + req.route = nothing + req.params = nothing + req.cookies = nothing + body !== nothing && setinputstream!(req, body) + return req end end +# compatibility: 6-arg version for callers that still pass allocator +Request(method, path, headers, body, http2::Bool, _allocator; context=nothing) = + Request(method, path, headers, body, http2; context=context) + getrequest(req::Request) = req function observelayer(f) @@ -417,26 +384,19 @@ function observelayer(f) end end -ptr(x) = getfield(x, :ptr) - function Base.getproperty(x::Request, s::Symbol) if s == :method - out = Ref{aws_byte_cursor}() - GC.@preserve out begin - aws_http_message_get_request_method(ptr(x), out) != 0 && return nothing - return str(out[]) - end - elseif s == :path || s == :target || s == :uri - out = Ref{aws_byte_cursor}() - GC.@preserve out begin - aws_http_message_get_request_path(ptr(x), out) != 0 && return nothing - path = str(out[]) - return s == :uri ? URI(path) : path - end + return AwsHTTP.http_message_get_request_method(getfield(x, :msg)) + elseif s == :path || s == :target + return AwsHTTP.http_message_get_request_path(getfield(x, :msg)) + elseif s == :uri + path = AwsHTTP.http_message_get_request_path(getfield(x, :msg)) + return path === nothing ? URI("/") : URI(path) elseif s == :headers - return Headers(aws_http_message_get_headers(ptr(x))) + return Headers(AwsHTTP.http_message_get_headers(getfield(x, :msg))) elseif s == :version - return aws_http_message_get_protocol_version(ptr(x)) == AWS_HTTP_VERSION_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) + v = AwsHTTP.http_message_get_protocol_version(getfield(x, :msg)) + return v == AwsHTTP.HttpVersion.HTTP_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -444,9 +404,9 @@ end function Base.setproperty!(x::Request, s::Symbol, v) if s == :method - GC.@preserve v aws_http_message_set_request_method(x.ptr, aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_method(getfield(x, :msg), String(v)) != 0 && aws_throw_error() elseif s == :path - GC.@preserve v aws_http_message_set_request_path(x.ptr, aws_byte_cursor_from_c_str(v)) != 0 && aws_throw_error() + AwsHTTP.http_message_set_request_path(getfield(x, :msg), String(v)) != 0 && aws_throw_error() elseif s == :headers addheaders(x.headers, v) else @@ -510,62 +470,68 @@ target(r::Request) = r.path headers(r::Request) = r.headers body(r::Request) = r.body -resource(uri::URI) = string(isempty(uri.path) ? "/" : uri.path, - !isempty(uri.query) ? "?" : "", uri.query, - !isempty(uri.fragment) ? "#" : "", uri.fragment) - mutable struct RequestMetrics request_body_length::Int response_body_length::Int nretries::Int - stream_metrics::Union{Nothing, aws_http_stream_metrics} + stream_metrics::Union{Nothing, AwsHTTP.HttpStreamMetrics} end RequestMetrics() = RequestMetrics(0, 0, 0, nothing) mutable struct Response <: Message - allocator::Ptr{aws_allocator} - ptr::Ptr{aws_http_message} + msg::AwsHTTP.HttpMessage inputstream::Union{Nothing, InputStream} body::Union{Nothing, Vector{UInt8}} # only set for client-side response body when no user-provided response_body trailers::Union{Nothing, Headers} metrics::RequestMetrics request::Union{Request, Nothing} - function Response(status::Integer, headers, body, http2::Bool=false, allocator=default_aws_allocator()) - ptr = http2 ? - aws_http2_message_new_response(allocator) : - aws_http_message_new_response(allocator) - ptr == C_NULL && aws_throw_error() - try - GC.@preserve status aws_http_message_set_response_status(ptr, status) != 0 && aws_throw_error() - response_headers = Headers(aws_http_message_get_headers(ptr)) - if headers !== nothing - for (k, v) in headers - addheader(response_headers, k, v) - end + function Response(status::Integer, headers, body, http2::Bool=false) + msg = http2 ? + AwsHTTP.http2_message_new_response() : + AwsHTTP.http_message_new_response() + AwsHTTP.http_message_set_response_status(msg, Int(status)) != 0 && aws_throw_error() + msg_headers = AwsHTTP.http_message_get_headers(msg) + if headers !== nothing + for (k, v) in headers + AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end - resp = new(allocator, ptr) - resp.body = nothing - resp.inputstream = nothing - resp.trailers = nothing - resp.metrics = RequestMetrics() - resp.request = nothing - body !== nothing && setinputstream!(resp, body) - return finalizer(_ -> aws_http_message_release(ptr), resp) - catch - aws_http_message_release(ptr) - rethrow() end + resp = new(msg) + resp.body = nothing + resp.inputstream = nothing + resp.trailers = nothing + resp.metrics = RequestMetrics() + resp.request = nothing + if body !== nothing + setinputstream!(resp, body) + else + if !hasheader(resp.headers, "content-length") && !hasheader(resp.headers, "transfer-encoding") + setheader(resp.headers, "content-length" => "0") + end + end + return resp end - Response() = new(C_NULL, C_NULL, nothing, nothing, nothing, RequestMetrics(), nothing) + Response() = new(AwsHTTP.http_message_new_response(), nothing, nothing, nothing, RequestMetrics(), nothing) end +# compatibility: 5-arg version for callers that still pass allocator +Response(status::Integer, headers, body, http2::Bool, _allocator) = + Response(status, headers, body, http2) + Response(status::Integer, body) = Response(status, nothing, Vector{UInt8}(string(body))) Response(status::Integer) = Response(status, nothing, nothing) getresponse(r::Response) = r +function _head_response!(resp::Response) + setinputstream!(resp, nothing) + hasheader(resp.headers, "transfer-encoding") && removeheader(resp.headers, "transfer-encoding") + setheader(resp.headers, "content-length" => "0") + return +end + bodylen(m::Message) = isdefined(m, :inputstream) && m.inputstream !== nothing ? m.inputstream.bodylen : 0 function bodylen(r::Response) @@ -577,13 +543,12 @@ end function Base.getproperty(x::Response, s::Symbol) if s == :status - ref = Ref{Cint}() - aws_http_message_get_response_status(x.ptr, ref) != 0 && return nothing - return Int(ref[]) + return AwsHTTP.http_message_get_response_status(getfield(x, :msg)) elseif s == :headers - return Headers(aws_http_message_get_headers(x.ptr)) + return Headers(AwsHTTP.http_message_get_headers(getfield(x, :msg))) elseif s == :version - return aws_http_message_get_protocol_version(x.ptr) == AWS_HTTP_VERSION_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) + v = AwsHTTP.http_message_get_protocol_version(getfield(x, :msg)) + return v == AwsHTTP.HttpVersion.HTTP_2 ? HTTPVersion(2, 0) : HTTPVersion(1, 1) else return getfield(x, s) end @@ -591,7 +556,7 @@ end function Base.setproperty!(x::Response, s::Symbol, v) if s == :status - GC.@preserve v aws_http_message_set_response_status(x.ptr, v) != 0 && aws_throw_error() + AwsHTTP.http_message_set_response_status(getfield(x, :msg), Int(v)) != 0 && aws_throw_error() elseif s == :headers addheaders(x.headers, v) else diff --git a/src/server.jl b/src/server.jl index 5df45312..91328ebd 100644 --- a/src/server.jl +++ b/src/server.jl @@ -1,124 +1,60 @@ -socket_endpoint(host, port) = aws_socket_endpoint( - ntuple(i -> i > sizeof(host) ? 0x00 : codeunit(host, i), Base._counttuple(fieldtype(aws_socket_endpoint, :address))), - port % UInt32 -) - -function server_tlsoptions(host::String; - allocator=default_aws_allocator(), +function server_tlsoptions(; ssl_cert=nothing, ssl_key=nothing, ssl_capath=nothing, ssl_cacert=nothing, ssl_insecure=false, - ssl_alpn_list=nothing, + ssl_alpn_list="h2;http/1.1", ) - tls_options = aws_tls_connection_options(C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, C_NULL, false, UInt32(0)) - tls_ctx_options = Ptr{aws_tls_ctx_options}(aws_mem_acquire(allocator, sizeof(aws_tls_ctx_options))) - tls_ctx = C_NULL - try - if ssl_cert !== nothing && ssl_key !== nothing - LibAwsIO.aws_tls_ctx_options_init_default_server_from_path(tls_ctx_options, allocator, ssl_cert, ssl_key) != 0 && sockerr("aws_tls_ctx_options_init_default_server_from_path failed") - elseif Sys.iswindows() && ssl_cert !== nothing && ssl_key === nothing - LibAwsIO.aws_tls_ctx_options_init_default_server_from_system_path(tls_ctx_options, allocator, ssl_cert) != 0 && sockerr("aws_tls_ctx_options_init_default_server_from_system_path failed") - else - throw(ArgumentError("ssl_cert and ssl_key are required for TLS server")) - end - if ssl_capath !== nothing && ssl_cacert !== nothing - LibAwsIO.aws_tls_ctx_options_override_default_trust_store_from_path(tls_ctx_options, ssl_capath, ssl_cacert) != 0 && sockerr("aws_tls_ctx_options_override_default_trust_store_from_path failed") - end - if ssl_insecure - LibAwsIO.aws_tls_ctx_options_set_verify_peer(tls_ctx_options, false) - end - if ssl_alpn_list !== nothing - LibAwsIO.aws_tls_ctx_options_set_alpn_list(tls_ctx_options, ssl_alpn_list) != 0 && sockerr("aws_tls_ctx_options_set_alpn_list failed") - end - tls_ctx = LibAwsIO.aws_tls_server_ctx_new(allocator, tls_ctx_options) - tls_ctx == C_NULL && sockerr("") - ref = Ref(tls_options) - LibAwsIO.aws_tls_connection_options_init_from_ctx(ref, tls_ctx) - tls_options = ref[] - finally - LibAwsIO.aws_tls_ctx_options_clean_up(tls_ctx_options) - LibAwsIO.aws_tls_ctx_release(tls_ctx) - aws_mem_release(allocator, tls_ctx_options) + alpn_list = _normalize_alpn_list(ssl_alpn_list) + if ssl_cert !== nothing && ssl_key !== nothing + ctx_opts = AwsIO.tls_ctx_options_init_default_server_from_path(ssl_cert, ssl_key; alpn_list=alpn_list) + ctx_opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS options")) + elseif Sys.iswindows() && ssl_cert !== nothing && ssl_key === nothing + ctx_opts = AwsIO.tls_ctx_options_init_default_server_from_system_path(ssl_cert) + ctx_opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS options from system path")) + else + throw(ArgumentError("ssl_cert and ssl_key are required for TLS server")) + end + if ssl_capath !== nothing || ssl_cacert !== nothing + res = AwsIO.tls_ctx_options_override_default_trust_store_from_path!(ctx_opts; + ca_path=ssl_capath, ca_file=ssl_cacert) + res isa AwsIO.ErrorResult && throw(AWSError("Failed to set trust store")) + end + if ssl_insecure + AwsIO.tls_ctx_options_set_verify_peer!(ctx_opts, false) end - return tls_options + ctx = AwsIO.tls_server_ctx_new(ctx_opts) + ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS context")) + return AwsIO.TlsConnectionOptions(ctx; alpn_list=alpn_list) end mutable struct Connection{S} const server::S # Server{F, C} - const allocator::Ptr{aws_allocator} - const connection::Ptr{aws_http_connection} + const h1conn::Any # AwsHTTP.H1Connection or AwsHTTP.H2Connection + const channel::Any # AwsIO.Channel const streams_lock::ReentrantLock const streams::Set{Stream} - connection_options::aws_http_server_connection_options + const remote_addr::String + const remote_port_num::Int - Connection( - server::S, - allocator::Ptr{aws_allocator}, - connection::Ptr{aws_http_connection}, - ) where {S} = new{S}(server, allocator, connection, ReentrantLock(), Set{Stream}()) + Connection(server::S, h1conn, channel, remote_addr::String, remote_port_num::Int) where {S} = + new{S}(server, h1conn, channel, ReentrantLock(), Set{Stream}(), remote_addr, remote_port_num) end -Base.hash(c::Connection, h::UInt) = hash(c.connection, h) - -if !@isdefined aws_http2_send_push_promise_options - struct aws_http2_send_push_promise_options - self_size::Csize_t - user_data::Ptr{Cvoid} - on_complete::Ptr{Cvoid} - on_destroy::Ptr{Cvoid} - pad_length::UInt8 - end -end +Base.hash(c::Connection, h::UInt) = hash(objectid(c), h) -function aws_http2_stream_send_push_promise( - parent_stream::Ptr{aws_http_stream}, - request::Ptr{aws_http_message}, - options::Ref{aws_http2_send_push_promise_options}, -) - return ccall((:aws_http2_stream_send_push_promise, LibAwsHTTPFork.libaws_c_http_jq), - Ptr{aws_http_stream}, - (Ptr{aws_http_stream}, Ptr{aws_http_message}, Ptr{aws_http2_send_push_promise_options}), - parent_stream, - request, - options, - ) -end - -function remote_address(c::Connection) - socket_ptr = aws_http_connection_get_remote_endpoint(c.connection) - addr = unsafe_load(socket_ptr).address - bytes = Vector{UInt8}(undef, length(addr)) - nul_i = 0 - for i in eachindex(bytes) - b = addr[i] - @inbounds bytes[i] = b - if b == 0x00 - nul_i = i - break - end - end - resize!(bytes, nul_i == 0 ? length(addr) : nul_i - 1) - return String(bytes) -end -remote_port(c::Connection) = Int(unsafe_load(aws_http_connection_get_remote_endpoint(c.connection)).port) +remote_address(c::Connection) = c.remote_addr +remote_port(c::Connection) = c.remote_port_num function http_version(c::Connection) - v = aws_http_connection_get_version(c.connection) - return v == AWS_HTTP_VERSION_2 ? "HTTP/2" : "HTTP/1.1" + v = AwsHTTP.http_connection_get_version(c.h1conn) + return v == AwsHTTP.HttpVersion.HTTP_2 ? "HTTP/2" : "HTTP/1.1" end -getinet(host::String, port::Integer) = Sockets.InetAddr(parse(IPAddr, host), port) -getinet(host::IPAddr, port::Integer) = Sockets.InetAddr(host, port) - mutable struct Server{F, C} const f::F const on_stream_complete::C const fut::Future{Symbol} - const allocator::Ptr{aws_allocator} - const endpoint::aws_socket_endpoint - const socket_options::aws_socket_options - const tls_options::Union{aws_tls_connection_options, Nothing} const connections_lock::ReentrantLock const connections::Set{Connection} const closed::Threads.Event @@ -126,17 +62,13 @@ mutable struct Server{F, C} const stream::Bool const logstate::Base.CoreLogging.LogState @atomic state::Symbol # :initializing, :running, :closed - server::Ptr{aws_http_server} - server_options::aws_http_server_options + bootstrap::Any # AwsIO.ServerBootstrap + bound_port::Int Server{F, C}( f::F, on_stream_complete::C, fut::Future{Symbol}, - allocator::Ptr{aws_allocator}, - endpoint::aws_socket_endpoint, - socket_options::aws_socket_options, - tls_options::Union{aws_tls_connection_options, Nothing}, connections_lock::ReentrantLock, connections::Set{Connection}, closed::Threads.Event, @@ -144,23 +76,224 @@ mutable struct Server{F, C} stream::Bool, logstate::Base.CoreLogging.LogState, state::Symbol, - ) where {F, C} = new{F, C}(f, on_stream_complete, fut, allocator, endpoint, socket_options, tls_options, connections_lock, connections, closed, access_log, stream, logstate, state) + ) where {F, C} = new{F, C}(f, on_stream_complete, fut, connections_lock, connections, closed, access_log, stream, logstate, state) end Base.wait(s::Server) = wait(s.closed) ftype(::Server{F}) where {F} = F -port(s::Server) = Int(s.endpoint.port) +port(s::Server) = s.bound_port + +function _should_log_stream_error(error_code::Integer)::Bool + error_code == 0 && return false + error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_STREAM_CANCELLED && return false + error_code == AwsHTTP.ERROR_HTTP_SERVER_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_SWITCHED_PROTOCOLS && return false + error_code == AwsHTTP.ERROR_HTTP_GOAWAY_RECEIVED && return false + error_code == AwsHTTP.ERROR_HTTP_RST_STREAM_RECEIVED && return false + error_code == AwsIO.ERROR_IO_SOCKET_CLOSED && return false + error_code == AwsIO.ERROR_IO_BROKEN_PIPE && return false + error_code == AwsIO.ERROR_IO_OPERATION_CANCELLED && return false + return true +end + +function _should_log_channel_shutdown_error(error_code::Integer)::Bool + error_code == 0 && return false + error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false + error_code == AwsHTTP.ERROR_HTTP_SERVER_CLOSED && return false + error_code == AwsIO.ERROR_IO_SOCKET_CLOSED && return false + error_code == AwsIO.ERROR_IO_BROKEN_PIPE && return false + error_code == AwsIO.ERROR_IO_OPERATION_CANCELLED && return false + return true +end + +function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) + server = conn.server + http_conn = aws_conn + stream = Stream{typeof(conn)}(nothing, http2, true) + stream.connection = conn + stream.request = Request("", "", nothing, nothing, http2) + + on_request_headers = (aws_stream, header_block, headers_vec, user_data) -> begin + if header_block == AwsHTTP.HttpHeaderBlock.TRAILING + trailers = stream.request.trailers + if trailers === nothing + trailers = Headers() + stream.request.trailers = trailers + end + for h in headers_vec + addheader(trailers, h.name, h.value) + end + else + hdrs = stream.request.headers + for h in headers_vec + if stream.http2 && !isempty(h.name) && h.name[1] == ':' + if h.name == ":scheme" || h.name == ":authority" || h.name == ":protocol" + addheader(hdrs, h.name, h.value) + if h.name == ":authority" && !hasheader(hdrs, "host") + addheader(hdrs, "host", h.value) + end + end + else + addheader(hdrs, h.name, h.value) + end + end + end + return AwsHTTP.OP_SUCCESS + end + + on_request_header_block_done = (aws_stream, header_block, user_data) -> begin + if header_block != AwsHTTP.HttpHeaderBlock.MAIN + return AwsHTTP.OP_SUCCESS + end + method = AwsHTTP.http_stream_get_incoming_request_method(aws_stream) + path = AwsHTTP.http_stream_get_incoming_request_uri(aws_stream) + method === nothing && (method = "") + path === nothing && (path = "") + stream.request.method = method + stream.request.path = path + notify(stream.headers_ready) + if server.stream && !stream.handler_started + stream.handler_started = true + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + Threads.@spawn begin + Base.CoreLogging.with_logstate(server.logstate) do + try + Base.invokelatest(server.f, stream) + catch e + @error "Request handler error; sending 500" exception=(e, catch_backtrace()) + if !stream.response_started + try setstatus(stream, 500) catch; end + end + finally + try closewrite(stream) catch; end + end + end + end + end + return AwsHTTP.OP_SUCCESS + end + + on_request_body = (aws_stream, data, user_data) -> begin + if server.stream + stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) + write(stream.bufferstream, data) + return AwsHTTP.OP_SUCCESS + end + body = stream.request.body + if body === nothing + stream.request.body = copy(data) + else + append!(body, data) + end + return AwsHTTP.OP_SUCCESS + end + + on_request_done = (aws_stream, user_data) -> begin + if server.stream + Base.CoreLogging.with_logstate(server.logstate) do + stream.bufferstream !== nothing && close(stream.bufferstream) + end + return + end + errormonitor(Threads.@spawn begin + Base.CoreLogging.with_logstate(server.logstate) do + try + stream.response = Base.invokelatest(server.f, stream.request)::Response + if stream.request.method == "HEAD" + _head_response!(stream.response) + end + catch e + @error "Request handler error; sending 500" exception=(e, catch_backtrace()) + stream.response = Response(500) + end + _send_response!(stream) + end + end) + return + end + + on_complete = (aws_stream, error_code, user_data) -> begin + stream.released && return + stream.released = true + Base.CoreLogging.with_logstate(server.logstate) do + if _should_log_stream_error(error_code) + @error "server stream complete error" error_code + end + if server.on_stream_complete !== nothing + try + Base.invokelatest(server.on_stream_complete, stream) + catch e + @error "on_stream_complete error" exception=(e, catch_backtrace()) + end + end + if stream.on_complete !== nothing + try + Base.invokelatest(stream.on_complete, stream) + catch e + @error "stream on_complete error" exception=(e, catch_backtrace()) + end + stream.on_complete = nothing + end + if server.access_log !== nothing + try + if isdefined(stream, :request) && isdefined(stream, :response) + @info sprint(server.access_log, stream) _group=:access + end + catch e + @error "access log error" exception=(e, catch_backtrace()) + end + end + @lock conn.streams_lock begin + delete!(conn.streams, stream) + end + # HTTP pipelining: create next request handler if connection allows + if !stream.http2 && AwsHTTP.http_connection_new_requests_allowed(http_conn) + _create_request_handler!(conn, http_conn; http2=false) + end + end + return + end + + on_destroy = (user_data) -> nothing + + opts = AwsHTTP.HttpRequestHandlerOptions( + http_conn, + nothing, + on_request_headers, + on_request_header_block_done, + on_request_body, + on_request_done, + on_complete, + on_destroy, + ) + if http2 + h2stream = AwsHTTP.h2_stream_new_request_handler(http_conn, opts; manual_write=server.stream) + stream.aws_stream = h2stream + @lock conn.streams_lock begin + push!(conn.streams, stream) + end + return h2stream + end + h1stream = AwsHTTP.http_connection_new_request_handler(http_conn, opts) + if h1stream === nothing + @error "failed to create request handler stream" + return + end + stream.aws_stream = h1stream + AwsHTTP.h1_stream_activate!(h1stream) + @lock conn.streams_lock begin + push!(conn.streams, stream) + end + return +end function serve!(f, host="127.0.0.1", port=8080; - allocator=default_aws_allocator(), - bootstrap::Ptr{aws_server_bootstrap}=default_aws_server_bootstrap(), - endpoint=nothing, - listenany::Bool=false, on_stream_complete=nothing, access_log::Union{Nothing, Function}=nothing, stream::Bool=false, + listenany::Bool=false, # socket options - socket_options=nothing, socket_domain=:ipv4, connect_timeout_ms::Integer=3000, keep_alive_interval_sec::Integer=0, @@ -177,60 +310,180 @@ function serve!(f, host="127.0.0.1", port=8080; ssl_alpn_list="h2;http/1.1", initial_window_size=typemax(UInt64), ) - addr = getinet(host, port) + _ensure_resources!() + host_str = string(host) + port_int = Int(port) if listenany - port, sock = Sockets.listenany(addr.host, addr.port) + addr = Sockets.InetAddr(parse(IPAddr, host_str), port_int) + port_int, sock = Sockets.listenany(addr.host, addr.port) close(sock) end + tls_conn_opts = if tls_options !== nothing + tls_options + elseif any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) + server_tlsoptions(; + ssl_cert, ssl_key, ssl_capath, ssl_cacert, ssl_insecure, ssl_alpn_list + ) + else + nothing + end server = Server{typeof(f), typeof(on_stream_complete)}( - f, # RequestHandler + f, on_stream_complete, Future{Symbol}(), - allocator, - endpoint !== nothing ? endpoint : socket_endpoint(host, port), - socket_options !== nothing ? socket_options : aws_socket_options( - AWS_SOCKET_STREAM, # socket type - socket_domain == :ipv4 ? AWS_SOCKET_IPV4 : AWS_SOCKET_IPV6, # socket domain - AWS_SOCKET_IMPL_PLATFORM_DEFAULT, # aws_socket_impl_type - connect_timeout_ms, - keep_alive_interval_sec, - keep_alive_timeout_sec, - keep_alive_max_failed_probes, - keepalive, - ntuple(x -> Cchar(0), 16) # network_interface_name - ), - tls_options !== nothing ? tls_options : - any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) ? server_tlsoptions(host; - ssl_cert, - ssl_key, - ssl_capath, - ssl_cacert, - ssl_insecure, - ssl_alpn_list - ) : nothing, - ReentrantLock(), # connections_lock - Set{Connection}(), # connections - Threads.Event(), # closed + ReentrantLock(), + Set{Connection}(), + Threads.Event(), access_log, stream, Base.CoreLogging.current_logstate(), - :initializing, # state + :initializing, ) - server.server_options = aws_http_server_options( - 1, - allocator, - bootstrap, - pointer(FieldRef(server, :endpoint)), - pointer(FieldRef(server, :socket_options)), - server.tls_options === nothing ? C_NULL : pointer(FieldRef(server, :tls_options)), - initial_window_size, - pointer_from_objref(server), - on_incoming_connection[], - on_destroy_complete[], - false # manual_window_management + socket_opts = AwsIO.SocketOptions(; + domain = socket_domain == :ipv4 ? AwsIO.SocketDomain.IPV4 : AwsIO.SocketDomain.IPV6, + connect_timeout_ms = connect_timeout_ms, + keep_alive_interval_sec = keep_alive_interval_sec, + keep_alive_timeout_sec = keep_alive_timeout_sec, + keep_alive_max_failed_probes = keep_alive_max_failed_probes, + keepalive = keepalive, ) - server.server = aws_http_server_new(FieldRef(server, :server_options)) - @assert server.server != C_NULL "failed to create server" + alpn_list = _tls_alpn_list(tls_conn_opts) + if _use_nw_sockets() + socket_opts.impl_type = AwsIO.SocketImplType.APPLE_NETWORK_FRAMEWORK + end + initial_window = Csize_t(min(UInt64(initial_window_size), UInt64(typemax(Csize_t)))) + on_protocol_negotiated = tls_conn_opts !== nothing ? + (new_slot, protocol, user_data) -> begin + protocol_str = AwsIO.byte_buffer_as_string(protocol) + version = protocol_str == "h2" ? AwsHTTP.HttpVersion.HTTP_2 : AwsHTTP.HttpVersion.HTTP_1_1 + return AwsHTTP.http_connection_new_channel_handler(; + is_server=true, + version, + initial_window_size=initial_window, + ) + end : nothing + on_incoming_channel_setup = (bootstrap, error_code, channel, user_data) -> begin + Base.CoreLogging.with_logstate(server.logstate) do + if error_code != 0 + @error "incoming channel setup error" error_code + return + end + # For TLS, ALPN may have installed the H1/H2 handler. Otherwise install H1 manually. + http_conn = if tls_conn_opts !== nothing + handler = channel.last.handler + if handler isa AwsHTTP.H1Connection || handler isa AwsHTTP.H2Connection + handler + else + h = AwsHTTP.http_connection_new_channel_handler(; + is_server=true, + version=AwsHTTP.HttpVersion.HTTP_1_1, + initial_window_size=initial_window, + ) + slot = AwsIO.channel_slot_new!(channel) + AwsIO.channel_slot_insert_end!(channel, slot) + AwsIO.channel_slot_set_handler!(slot, h) + h.slot = slot + h + end + else + h = AwsHTTP.http_connection_new_channel_handler(; + is_server=true, + version=AwsHTTP.HttpVersion.HTTP_1_1, + initial_window_size=initial_window, + ) + slot = AwsIO.channel_slot_new!(channel) + AwsIO.channel_slot_insert_end!(channel, slot) + AwsIO.channel_slot_set_handler!(slot, h) + h.slot = slot + h + end + # Extract remote endpoint from the socket handler (first slot in pipeline) + remote_addr = "0.0.0.0" + remote_port_num = 0 + try + socket_handler = channel.first.handler + ep = socket_handler.socket.remote_endpoint + remote_addr = AwsIO.get_address(ep) + remote_port_num = Int(ep.port) + catch + end + http_conn.remote_endpoint = "$remote_addr:$remote_port_num" + conn = Connection(server, http_conn, channel, remote_addr, remote_port_num) + @lock server.connections_lock begin + push!(server.connections, conn) + end + if AwsHTTP.http_connection_get_version(http_conn) == AwsHTTP.HttpVersion.HTTP_2 + opts = AwsHTTP.HttpServerConnectionOptions( + connection_user_data = conn, + on_incoming_request = (h2conn, ud) -> begin + try + return _create_request_handler!(ud, h2conn; http2=true) + catch e + @error "failed to create HTTP/2 request handler" exception=(e, catch_backtrace()) + return nothing + end + end, + on_shutdown = (h2conn, err, ud) -> nothing, + ) + status = AwsHTTP.http_connection_configure_server(http_conn, opts) + if status != AwsHTTP.OP_SUCCESS + @error "failed to configure HTTP/2 server connection" error_code=status + return + end + else + _create_request_handler!(conn, http_conn; http2=false) + end + if AwsIO.channel_thread_is_callers_thread(channel) + AwsIO.channel_trigger_read(channel) + else + task = AwsIO.ChannelTask((task, ctx, status) -> begin + status == AwsIO.TaskStatus.RUN_READY || return nothing + AwsIO.channel_trigger_read(ctx.channel) + return nothing + end, (channel = channel,), "http_server_trigger_read") + AwsIO.channel_schedule_task_now!(channel, task) + end + end + return + end + on_incoming_channel_shutdown = (bootstrap, error_code, channel, user_data) -> begin + Base.CoreLogging.with_logstate(server.logstate) do + if _should_log_channel_shutdown_error(error_code) + @error "incoming channel shutdown error" error_code + end + @lock server.connections_lock begin + filter!(c -> c.channel !== channel, server.connections) + end + end + return + end + on_listener_destroy = (bootstrap, user_data) -> begin + notify(server.fut, :destroyed) + return + end + bootstrap_opts = AwsIO.ServerBootstrapOptions(; + event_loop_group = _EVENT_LOOP_GROUP[], + socket_options = socket_opts, + host = host_str, + port = UInt32(port_int), + tls_connection_options = tls_conn_opts, + on_protocol_negotiated = on_protocol_negotiated, + on_incoming_channel_setup = on_incoming_channel_setup, + on_incoming_channel_shutdown = on_incoming_channel_shutdown, + on_listener_destroy = on_listener_destroy, + user_data = server, + enable_read_back_pressure = false, + ) + bs = AwsIO.ServerBootstrap(bootstrap_opts) + bs isa AwsIO.ErrorResult && throw(AWSError("Failed to create server bootstrap")) + server.bootstrap = bs + # Retrieve the actual bound port (useful when port=0 or listenany) + if bs.listener_socket !== nothing + ep = AwsIO.socket_get_bound_address(bs.listener_socket) + server.bound_port = ep isa AwsIO.ErrorResult ? port_int : Int(ep.port) + else + server.bound_port = port_int + end @atomic server.state = :running return server end @@ -261,285 +514,84 @@ end function push_promise(parent::Stream, req::Request; pad_length::Integer=0, scheme=nothing, authority=nothing) parent.server_side || error("push_promise is only supported for server streams") parent.http2 || throw(ArgumentError("HTTP/2 stream required for push promise")) - req.version == HTTPVersion(2, 0) || throw(ArgumentError("push promise request must be HTTP/2")) - 0 <= pad_length <= 0xff || throw(ArgumentError("pad_length must be between 0 and 255")) + pad_length < 0 && throw(ArgumentError("pad_length must be >= 0")) + pad_length > typemax(UInt8) && throw(ArgumentError("pad_length must be <= $(typemax(UInt8))")) + isdefined(parent, :aws_stream) || throw(ArgumentError("HTTP stream is not initialized")) _push_promise_headers!(req, parent; scheme=scheme, authority=authority) - push_stream = Stream{typeof(parent.connection)}(parent.allocator, false, true, true) - push_stream.connection = parent.connection - push_stream.request = req - opts = aws_http2_send_push_promise_options( - sizeof(aws_http2_send_push_promise_options), - pointer_from_objref(push_stream), - on_server_stream_complete[], - on_destroy[], - UInt8(pad_length), - ) - stream_ptr = aws_http2_stream_send_push_promise(parent.ptr, req.ptr, Ref(opts)) - stream_ptr == C_NULL && aws_throw_error() - push_stream.ptr = stream_ptr - @lock parent.connection.streams_lock begin - push!(parent.connection.streams, push_stream) - end - retain_stream!(push_stream) - return push_stream -end - -function push_promise(parent::Stream, method::Union{String, Symbol}, path; headers=Header[], pad_length::Integer=0, scheme=nothing, authority=nothing) - req = Request(String(method), String(path), headers, nothing, true, parent.allocator) - return push_promise(parent, req; pad_length=pad_length, scheme=scheme, authority=authority) -end - -const on_incoming_connection = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_connection(aws_server, aws_conn, error_code, server_ptr) - server = unsafe_pointer_to_objref(server_ptr) - Base.CoreLogging.with_logstate(server.logstate) do - if error_code != 0 - @error "incoming connection error" exception=(aws_error(error_code), Base.backtrace()) - return - end - conn = Connection( - server, - server.allocator, - aws_conn, - ) - conn.connection_options = aws_http_server_connection_options( - 1, - pointer_from_objref(conn), - on_incoming_request[], - on_connection_shutdown[] - ) - if aws_http_connection_configure_server( - aws_conn, - FieldRef(conn, :connection_options) - ) != 0 - @error "failed to configure connection" exception=(aws_error(), Base.backtrace()) - return - end - @lock server.connections_lock begin - push!(server.connections, conn) - end - return - end -end - -const on_connection_shutdown = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_connection_shutdown(aws_conn, error_code, conn_ptr) - conn = unsafe_pointer_to_objref(conn_ptr) - Base.CoreLogging.with_logstate(conn.server.logstate) do - if error_code != 0 - @error "connection shutdown error" exception=(aws_error(error_code), Base.backtrace()) - end - @lock conn.server.connections_lock begin - delete!(conn.server.connections, conn) - end - return + msg = getfield(req, :msg) + if AwsHTTP.http_message_get_protocol_version(msg) != AwsHTTP.HttpVersion.HTTP_2 + converted = AwsHTTP.http2_message_new_from_http1(msg) + converted === nothing && throw(AWSError("Failed to convert push promise request to HTTP/2")) + setfield!(req, :msg, converted) + msg = converted end -end - -const on_incoming_request = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_request(aws_conn, conn_ptr) - conn = unsafe_pointer_to_objref(conn_ptr) - Base.CoreLogging.with_logstate(conn.server.logstate) do - stream = Stream{typeof(conn)}( - conn.allocator, - false, # decompress - aws_http_connection_get_version(aws_conn) == AWS_HTTP_VERSION_2, # http2 - true, - ) - stream.connection = conn - stream.request_handler_options = aws_http_request_handler_options( - 1, - aws_conn, - pointer_from_objref(stream), - on_request_headers[], - on_request_header_block_done[], - on_request_body[], - on_request_done[], - on_server_stream_complete[], - on_destroy[] - ) - stream.request = Request("", "") - stream.ptr = aws_http_stream_new_server_request_handler( - FieldRef(stream, :request_handler_options) - ) - if stream.ptr == C_NULL - @error "failed to create stream" exception=(aws_error(), Base.backtrace()) - else - @lock conn.streams_lock begin - push!(conn.streams, stream) - end - end - return stream.ptr - end -end - -const on_request_headers = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_headers(aws_stream_ptr, header_block, header_array::Ptr{aws_http_header}, num_headers, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if header_block == AWS_HTTP_HEADER_BLOCK_TRAILING - trailers = stream.request.trailers - if trailers === nothing - trailers = Headers(stream.request.allocator) - stream.request.trailers = trailers - end - addheaders(trailers, header_array, num_headers) - else - headers = stream.request.headers - addheaders(headers, header_array, num_headers) - end - return Cint(0) -end - -const on_request_header_block_done = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_header_block_done(aws_stream_ptr, header_block, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - if header_block != AWS_HTTP_HEADER_BLOCK_MAIN - return Cint(0) - end - ret = aws_http_stream_get_incoming_request_method(aws_stream_ptr, FieldRef(stream, :method)) - ret != 0 && return ret - aws_http_message_set_request_method(stream.request.ptr, stream.method) - ret = aws_http_stream_get_incoming_request_uri(aws_stream_ptr, FieldRef(stream, :path)) - ret != 0 && return ret - aws_http_message_set_request_path(stream.request.ptr, stream.path) - notify(stream.headers_ready) - if stream.connection.server.stream && !stream.handler_started - stream.handler_started = true - stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) - Threads.@spawn begin - Base.CoreLogging.with_logstate(stream.connection.server.logstate) do - try - Base.invokelatest(stream.connection.server.f, stream) - catch e - @error "Request handler error; sending 500" exception=(e, catch_backtrace()) - if !stream.response_started - try - setstatus(stream, 500) - catch err - @error "failed to set 500 status" exception=(err, catch_backtrace()) - end - end - finally - try - closewrite(stream) - catch err - @error "failed to close response stream" exception=(err, catch_backtrace()) - end - end - end - end - end - return Cint(0) -end - -const on_request_body = Ref{Ptr{Cvoid}}(C_NULL) - -#TODO: how could we allow for streaming request bodies? -function c_on_request_body(aws_stream_ptr, data::Ptr{aws_byte_cursor}, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - bc = unsafe_load(data) - if stream.connection.server.stream - stream.bufferstream === nothing && (stream.bufferstream = Base.BufferStream()) - unsafe_write(stream.bufferstream, bc.ptr, bc.len) - return Cint(0) - end - body = stream.request.body - if body === nothing - body = Vector{UInt8}(undef, bc.len) - GC.@preserve body unsafe_copyto!(pointer(body), bc.ptr, bc.len) - stream.request.body = body - else - newlen = length(body) + bc.len - resize!(body, newlen) - GC.@preserve body unsafe_copyto!(pointer(body, length(body) - bc.len + 1), bc.ptr, bc.len) - end - return Cint(0) -end - -const on_request_done = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_request_done(aws_stream_ptr, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - Base.CoreLogging.with_logstate(stream.connection.server.logstate) do - if stream.connection.server.stream - stream.bufferstream !== nothing && close(stream.bufferstream) - return Cint(0) - end - try - stream.response = Base.invokelatest(stream.connection.server.f, stream.request)::Response - if stream.request.method == "HEAD" - setinputstream!(stream.response, nothing) + h2conn = parent.aws_stream.owning_connection + h2conn === nothing && throw(ArgumentError("HTTP/2 connection is not initialized")) + promised_id = h2conn.next_stream_id + promised_id > AwsHTTP.H2_STREAM_ID_MAX && throw(AWSError("HTTP/2 stream IDs exhausted")) + h2conn.next_stream_id += UInt32(2) + h2stream = _create_request_handler!(parent.connection, h2conn; http2=true) + h2stream === nothing && throw(AWSError("Failed to create push promise stream")) + push_stream = nothing + @lock parent.connection.streams_lock begin + for s in parent.connection.streams + if s.aws_stream === h2stream + push_stream = s + break end - #TODO: is it possible to stream the response body? - #TODO: support transfer-encoding: gzip - catch e - @error "Request handler error; sending 500" exception=(e, catch_backtrace()) - stream.response = Response(500) - end - ret = aws_http_stream_send_response(aws_stream_ptr, stream.response.ptr) - if ret != 0 - @error "failed to send response" exception=(aws_error(ret), Base.backtrace()) - return Cint(AWS_ERROR_HTTP_UNKNOWN) end - return Cint(0) end -end - -const on_server_stream_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_server_stream_complete(aws_stream_ptr, error_code, stream_ptr) - stream = unsafe_pointer_to_objref(stream_ptr) - Base.CoreLogging.with_logstate(stream.connection.server.logstate) do - if error_code != 0 - @error "server complete error" exception=(aws_error(error_code), Base.backtrace()) - end - if stream.connection.server.on_stream_complete !== nothing - try - Base.invokelatest(stream.connection.server.on_stream_complete, stream) - catch e - @error "on_stream_complete error" exception=(e, catch_backtrace()) - end - end - if stream.on_complete !== nothing - try - Base.invokelatest(stream.on_complete, stream) - catch e - @error "stream on_complete error" exception=(e, catch_backtrace()) - end - stream.on_complete = nothing - end - if stream.connection.server.access_log !== nothing - try - @info sprint(stream.connection.server.access_log, stream) _group=:access - catch e - @error "access log error" exception=(e, catch_backtrace()) - end - end - @lock stream.connection.streams_lock begin - delete!(stream.connection.streams, stream) + push_stream === nothing && throw(AWSError("Failed to locate push promise stream")) + push_stream.request = req + notify(push_stream.headers_ready) + method_val = req.method + path_val = req.path + method_val === nothing && (method_val = "") + path_val === nothing && (path_val = "") + h2stream.id = promised_id + AwsHTTP.h2_stream_init_window_sizes!(h2stream, h2conn) + h2stream.metrics = AwsHTTP.HttpStreamMetrics( + h2stream.metrics.send_start_timestamp_ns, + h2stream.metrics.send_end_timestamp_ns, + h2stream.metrics.sending_duration_ns, + h2stream.metrics.receive_start_timestamp_ns, + h2stream.metrics.receive_end_timestamp_ns, + h2stream.metrics.receiving_duration_ns, + promised_id, + ) + h2stream.state = AwsHTTP.H2StreamState.RESERVED_LOCAL + h2stream.request_method = AwsHTTP.http_str_to_method(String(method_val)) + h2stream.request_method_str = String(method_val) + h2stream.request_path = String(path_val) + h2conn.active_streams[promised_id] = h2stream + headers = AwsHTTP.http_message_get_headers(msg) + status = AwsHTTP.h2_stream_send_push_promise!(parent.aws_stream, h2conn, promised_id, headers; + pad_length=UInt8(pad_length)) + if status != AwsHTTP.OP_SUCCESS + delete!(h2conn.active_streams, promised_id) + @lock parent.connection.streams_lock begin + delete!(parent.connection.streams, push_stream) end - release_stream_ptr!(stream) - return Cint(0) + throw(AWSError("Failed to send push promise")) end + return push_stream end -const on_destroy_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_destroy_complete(server_ptr) - server = unsafe_pointer_to_objref(server_ptr) - notify(server.fut, :destroyed) - return +function push_promise(parent::Stream, method::Union{String, Symbol}, path; headers=Header[], pad_length::Integer=0, scheme=nothing, authority=nothing) + return push_promise(parent, Request(String(method), String(path), headers, nothing, true); pad_length=pad_length, scheme=scheme, authority=authority) end function Base.close(server::Server) state = @atomicswap server.state = :closed if state == :running - aws_http_server_release(server.server) + AwsIO.server_bootstrap_shutdown!(server.bootstrap) + conns = Connection[] + @lock server.connections_lock begin + append!(conns, server.connections) + end + for conn in conns + AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + end @assert wait(server.fut) == :destroyed notify(server.closed) end diff --git a/src/statistics.jl b/src/statistics.jl new file mode 100644 index 00000000..6d5ef1f1 --- /dev/null +++ b/src/statistics.jl @@ -0,0 +1,82 @@ +export AWSCRT_STAT_CAT_HTTP1_CHANNEL, AWSCRT_STAT_CAT_HTTP2_CHANNEL +export aws_crt_statistics_http1_channel, aws_crt_statistics_http2_channel +export _decode_statistics, _call_statistics_observer + +const AWSCRT_STAT_CAT_HTTP1_CHANNEL = :http1_channel +const AWSCRT_STAT_CAT_HTTP2_CHANNEL = :http2_channel + +struct aws_crt_statistics_http1_channel + category::Symbol + pending_outgoing_stream_ms::UInt64 + pending_incoming_stream_ms::UInt64 + current_outgoing_stream_id::UInt32 + current_incoming_stream_id::UInt32 +end + +struct aws_crt_statistics_http2_channel + category::Symbol + pending_outgoing_stream_ms::UInt64 + pending_incoming_stream_ms::UInt64 + was_inactive::Bool +end + +function _normalize_stat_category(category) + if category === AWSCRT_STAT_CAT_HTTP1_CHANNEL || category === AWSCRT_STAT_CAT_HTTP2_CHANNEL + return category + end + category isa Symbol && return category + return Symbol(category) +end + +function _normalize_stat(stat::aws_crt_statistics_http1_channel) + cat = _normalize_stat_category(stat.category) + cat === stat.category && return stat + return aws_crt_statistics_http1_channel( + cat, + stat.pending_outgoing_stream_ms, + stat.pending_incoming_stream_ms, + stat.current_outgoing_stream_id, + stat.current_incoming_stream_id, + ) +end + +function _normalize_stat(stat::aws_crt_statistics_http2_channel) + cat = _normalize_stat_category(stat.category) + cat === stat.category && return stat + return aws_crt_statistics_http2_channel( + cat, + stat.pending_outgoing_stream_ms, + stat.pending_incoming_stream_ms, + stat.was_inactive, + ) +end + +_normalize_stat(stat) = stat + +function _decode_statistics(stats_list) + list = stats_list isa Base.RefValue ? stats_list[] : stats_list + out = Any[] + if list isa AwsIO.ArrayList + for i in 1:length(list) + item = list[i] + item = item isa Base.RefValue ? item[] : item + push!(out, _normalize_stat(item)) + end + return out + end + if list isa AbstractVector + for item in list + item = item isa Base.RefValue ? item[] : item + push!(out, _normalize_stat(item)) + end + return out + end + throw(ArgumentError("stats_list must be ArrayList or AbstractVector")) +end + +function _call_statistics_observer(observer, nonce, stats_list) + observer === nothing && return nothing + stats = _decode_statistics(stats_list) + Base.invokelatest(observer, nonce, stats) + return nothing +end diff --git a/src/utils.jl b/src/utils.jl index 205ff379..7d503d03 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -3,6 +3,56 @@ export bytes, isbytes, nbytes, nobytes, const HTTP2_DEFAULT_WINDOW_SIZE = 65535 const HTTP2_MAX_WINDOW_SIZE = 0x7fffffff +const AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = AwsHTTP.Http2SettingsId.MAX_CONCURRENT_STREAMS +const AWS_HTTP2_SETTINGS_INITIAL_WINDOW_SIZE = AwsHTTP.Http2SettingsId.INITIAL_WINDOW_SIZE +const AWS_HTTP2_SETTINGS_COUNT = Int(AwsHTTP.HTTP2_SETTINGS_END_RANGE - AwsHTTP.HTTP2_SETTINGS_BEGIN_RANGE) +const _H2_CHANNEL_SUPPORTED = AwsHTTP.H2Connection <: AwsIO.AbstractChannelHandler + +function _normalize_alpn_list(alpn_list::Union{String, Nothing}) + alpn_list === nothing && return nothing + isempty(alpn_list) && return alpn_list + _H2_CHANNEL_SUPPORTED && return alpn_list + parts = split(alpn_list, ';'; keepempty = false) + filtered = [p for p in parts if lowercase(p) != "h2"] + isempty(filtered) && return "http/1.1" + return join(filtered, ';') +end + +@inline function _alpn_includes_h2(alpn_list::Union{String, Nothing})::Bool + alpn_list === nothing && return false + for part in split(alpn_list, ';'; keepempty = false) + lowercase(part) == "h2" && return true + end + return false +end + +function _should_use_nw_tls(alpn_list::Union{String, Nothing})::Bool + @static if Sys.isapple() + return _use_nw_sockets() && _alpn_includes_h2(alpn_list) + else + return false + end +end + +function _use_nw_sockets()::Bool + @static if Sys.isapple() + AwsIO._tls_set_use_secitem_from_env() + return AwsIO._NW_SHIM_LIB != "" && AwsIO.is_using_secitem() + else + return false + end +end + +function _tls_alpn_list(tls_opts) + tls_opts === nothing && return nothing + if hasproperty(tls_opts, :alpn_list) && tls_opts.alpn_list !== nothing + return tls_opts.alpn_list + end + if hasproperty(tls_opts, :ctx) && tls_opts.ctx !== nothing + return tls_opts.ctx.options.alpn_list + end + return nothing +end """ HTTPVersion(major, minor) @@ -149,8 +199,7 @@ function ascii_lc_isequal(a, b) return true end -function parseuri(url, query, allocator) - uri_ref = Ref{aws_uri}() +function parseuri(url, query) if url isa AbstractString url_str = String(url) * (query === nothing ? "" : ("?" * URIs.escapeuri(query))) elseif url isa URI @@ -158,13 +207,12 @@ function parseuri(url, query, allocator) else throw(ArgumentError("url must be an AbstractString or URI")) end - GC.@preserve url_str begin - url_ref = Ref(aws_byte_cursor(sizeof(url_str), pointer(url_str))) - aws_uri_init_parse(uri_ref, allocator, url_ref) != 0 && aws_throw_error() - end - return uri_ref[] + return URIs.URI(url_str) end +# compatibility: 3-arg version for callers that still pass allocator +parseuri(url, query, _allocator) = parseuri(url, query) + """ bytes(x) @@ -194,74 +242,123 @@ nbytes(x::Vector{IOBuffer}) = sum(bytesavailable, x) const nobytes = view(UInt8[], 1:0) -str(bc::aws_byte_cursor) = bc.ptr == C_NULL || bc.len == 0 ? "" : unsafe_string(bc.ptr, bc.len) - -function print_uri(io, uri::aws_uri) - print(io, "scheme: ", str(uri.scheme), "\n") - print(io, "userinfo: ", str(uri.userinfo), "\n") - print(io, "host_name: ", str(uri.host_name), "\n") - print(io, "port: ", Int(uri.port), "\n") - print(io, "path: ", str(uri.path), "\n") - print(io, "query: ", str(uri.query_string), "\n") - return -end - -scheme(uri::aws_uri) = str(uri.scheme) -userinfo(uri::aws_uri) = str(uri.userinfo) -host(uri::aws_uri) = str(uri.host_name) -port(uri::aws_uri) = uri.port -path(uri::aws_uri) = str(uri.path) -query(uri::aws_uri) = str(uri.query_string) - -function resource(uri::aws_uri) - ref = Ref(uri) - GC.@preserve ref begin - bc = aws_uri_path_and_query(ref) - path = str(unsafe_load(bc)) - return isempty(path) ? "/" : path +# URI accessor helpers that work on URIs.URI +scheme(uri::URI) = uri.scheme +userinfo(uri::URI) = uri.userinfo +host(uri::URI) = uri.host +function port(uri::URI) + p = uri.port + if p === nothing || isempty(p) + return UInt32(0) end + return UInt32(parse(Int, p)) +end +path(uri::URI) = uri.path +query(uri::URI) = uri.query + +function resource(uri::URI) + p = uri.path + q = uri.query + r = isempty(p) ? "/" : p + return isempty(q) ? r : string(r, "?", q) end const URI_SCHEME_HTTPS = "https" const URI_SCHEME_WSS = "wss" -ishttps(sch) = aws_byte_cursor_eq_c_str_ignore_case(sch, URI_SCHEME_HTTPS) -iswss(sch) = aws_byte_cursor_eq_c_str_ignore_case(sch, URI_SCHEME_WSS) -function getport(uri::aws_uri) - sch = Ref(uri.scheme) - GC.@preserve sch begin - return UInt32(uri.port != 0 ? uri.port : (ishttps(sch) || iswss(sch)) ? 443 : 80) - end +ishttps(sch::AbstractString) = lowercase(sch) == URI_SCHEME_HTTPS +iswss(sch::AbstractString) = lowercase(sch) == URI_SCHEME_WSS +function getport(uri::URI) + p = port(uri) + return p != 0 ? p : (ishttps(scheme(uri)) || iswss(scheme(uri))) ? UInt32(443) : UInt32(80) end -function makeuri(u::aws_uri) - return URIs.URI( - scheme=str(u.scheme), - userinfo=isempty(str(u.userinfo)) ? URIs.absent : str(u.userinfo), - host=str(u.host_name), - port=u.port == 0 ? URIs.absent : u.port, - path=isempty(str(u.path)) ? URIs.absent : str(u.path), - query=isempty(str(u.query_string)) ? URIs.absent : str(u.query_string), - ) -end +makeuri(u::URI) = u struct AWSError <: Exception msg::String end -aws_error() = AWSError(unsafe_string(aws_error_debug_str(aws_last_error()))) -aws_error(error_code) = AWSError(unsafe_string(aws_error_str(error_code))) +function _resolve_error_str(error_code::Integer) + ec = Int(error_code) + # AwsHTTP has its own String-based error table for HTTP-range codes + if ec >= AwsHTTP.ERROR_HTTP_UNKNOWN && ec <= AwsHTTP.ERROR_HTTP_END_RANGE + return AwsHTTP.http_error_str(ec) + end + # AwsIO.error_str returns Ptr{UInt8}; convert to String + return unsafe_string(AwsIO.error_str(ec)) +end + +aws_error() = AWSError(_resolve_error_str(AwsIO.last_error())) +aws_error(error_code) = AWSError(_resolve_error_str(error_code)) aws_throw_error() = throw(aws_error()) +# Simple Future type for async callback coordination. +# Replaces LibAwsCommon.Future. Supports notify/wait pattern: +# notify(f, value::T) -> success +# notify(f, err::Exception) -> error +# wait(f) -> returns T or throws Exception +mutable struct Future{T} + const notify_cond::Threads.Condition + @atomic set::Int8 # 0=pending, 1=success, 2=error + result::Union{Exception, T} + Future{T}() where {T} = new{T}(Threads.Condition(), 0) +end + +Future() = Future{Nothing}() + +function Base.wait(f::Future{T}) where {T} + set = @atomic f.set + set == 1 && return f.result::T + set == 2 && throw(f.result::Exception) + lock(f.notify_cond) + try + set = f.set + set == 1 && return f.result::T + set == 2 && throw(f.result::Exception) + wait(f.notify_cond) + finally + unlock(f.notify_cond) + end + f.set == 1 && return f.result::T + throw(f.result::Exception) +end + +function Base.notify(f::Future{T}, result::T) where {T} + lock(f.notify_cond) + try + f.set != 0 && return + f.result = result + @atomic f.set = 1 + notify(f.notify_cond) + finally + unlock(f.notify_cond) + end + return +end + +function Base.notify(f::Future, err::Exception) + lock(f.notify_cond) + try + f.set != 0 && return + f.result = err + @atomic f.set = 2 + notify(f.notify_cond) + finally + unlock(f.notify_cond) + end + return +end + struct BufferOnResponseBody{T <: AbstractVector{UInt8}} buffer::T - pos::Ptr{Int} + pos::Ref{Int} end function (f::BufferOnResponseBody)(resp, buf) len = length(buf) - pos = unsafe_load(f.pos) + pos = f.pos[] copyto!(f.buffer, pos, buf, 1, len) - unsafe_store!(f.pos, pos + len) + f.pos[] = pos + len return len end diff --git a/src/websockets.jl b/src/websockets.jl index 5a95eef5..b3c773da 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -1,11 +1,21 @@ module WebSockets -using Base64, Random, LibAwsHTTPFork, LibAwsCommon, LibAwsIO - -import ..FieldRef, ..iswss, ..getport, ..makeuri, ..aws_throw_error, ..resource, ..Headers, ..Header, ..str, ..aws_error, ..aws_throw_error, ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient, ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..Client, ..Request, ..Response, ..Message, ..setinputstream!, ..getresponse, ..CookieJar, ..COOKIEJAR, ..addheaders, ..Stream, ..HTTP, ..getheader, ..hasheader, ..header +using Base64, Random, AwsHTTP, AwsIO + +import ..Headers, ..Header, ..Request, ..Response, ..Message, ..Stream +import ..setinputstream!, ..getresponse, ..getheader, ..hasheader, ..header +import ..addheader, ..setheader, ..removeheader +import ..Future, ..parseuri, ..with_redirect, ..with_request, ..getclient +import ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..resource +import ..Client, ..CookieJar, ..COOKIEJAR, ..with_connection, .._open_stream +import ..aws_throw_error, ..aws_error, ..AWSError +import .._h1_flush_outgoing!, ..iswss, ..makeuri, ..HTTP, ..mkreqheaders +import ..startread, ..closeread, ..startwrite, ..closewrite export WebSocket, send, receive, ping, pong +# ─── Types ─── + @enum OpCode::UInt8 CONTINUATION=0x00 TEXT=0x01 BINARY=0x02 CLOSE=0x08 PING=0x09 PONG=0x0A const DEFAULT_MAX_FRAG = 1024 @@ -34,22 +44,106 @@ isupgrade(s::Stream) = isupgrade(s.request) Base.@deprecate is_upgrade isupgrade +# ─── WsChannelHandler ─── +# Bridges the AwsIO channel pipeline with the AwsHTTP WebSocket codec. +# Installed into the H1Connection's channel slot after HTTP 101 upgrade. + +mutable struct WsChannelHandler <: AwsIO.AbstractChannelHandler + slot::Union{AwsIO.ChannelSlot, Nothing} + aws_ws::Any # AwsHTTP.WebSocket + wslock::ReentrantLock # protects outgoing_frames access +end + +WsChannelHandler(aws_ws) = WsChannelHandler(nothing, aws_ws, ReentrantLock()) + +function AwsIO.setchannelslot!(handler::WsChannelHandler, slot::AwsIO.ChannelSlot)::Nothing + handler.slot = slot + return nothing +end + +function AwsIO.handler_process_read_message(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, message::AwsIO.IoMessage)::Union{Nothing, AwsIO.ErrorResult} + data = AwsIO.byte_buffer_as_vector(message.message_data) + isempty(data) && return nothing + @lock handler.wslock begin + status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, data) + if status != AwsHTTP.OP_SUCCESS + return AwsIO.ErrorResult(status) + end + # Flush auto-responses (PONG, CLOSE echo) generated by ws_on_incoming_data! + _ws_channel_flush!(handler) + end + return nothing +end + +function AwsIO.handler_process_write_message(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, message::AwsIO.IoMessage)::Union{Nothing, AwsIO.ErrorResult} + # Pass through to lower pipeline (socket) + return AwsIO.channel_slot_send_message(slot, message, AwsIO.ChannelDirection.WRITE) +end + +function AwsIO.handler_increment_read_window(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, size::Csize_t)::Union{Nothing, AwsIO.ErrorResult} + return AwsIO.channel_slot_increment_read_window!(slot, size) +end + +function AwsIO.handler_shutdown( + handler::WsChannelHandler, + slot::AwsIO.ChannelSlot, + direction::AwsIO.ChannelDirection.T, + error_code::Int, + free_scarce_resources_immediately::Bool, + )::Union{Nothing, AwsIO.ErrorResult} + AwsIO.channel_slot_on_handler_shutdown_complete!(slot, direction, error_code, free_scarce_resources_immediately) + return nothing +end + +AwsIO.handler_initial_window_size(::WsChannelHandler)::Csize_t = Csize_t(typemax(UInt64)) +AwsIO.handler_message_overhead(::WsChannelHandler)::Csize_t = Csize_t(0) + +# Collect outgoing frames from the AwsHTTP WebSocket codec and send them +# through the channel pipeline. Caller MUST hold handler.wslock. +function _ws_channel_flush!(handler::WsChannelHandler) + outdata = AwsHTTP.ws_get_outgoing_data!(handler.aws_ws) + isempty(outdata) && return + slot = handler.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + msg = AwsIO.IoMessage(length(outdata)) + buf = msg.message_data + @inbounds for i in 1:length(outdata) + buf.mem[i] = outdata[i] + end + buf.len = Csize_t(length(outdata)) + if AwsIO.channel_thread_is_callers_thread(channel) + AwsIO.channel_slot_send_message(slot, msg, AwsIO.ChannelDirection.WRITE) + return + end + task = AwsIO.ChannelTask((task, ctx, status) -> begin + status == AwsIO.TaskStatus.RUN_READY || return nothing + AwsIO.channel_slot_send_message(ctx.slot, ctx.msg, AwsIO.ChannelDirection.WRITE) + return nothing + end, (slot=slot, msg=msg), "http_ws_flush") + AwsIO.channel_schedule_task_now!(channel, task) + return +end + +# ─── WebSocket struct ─── + mutable struct WebSocket id::String host::String path::String maxframesize::Int maxfragmentation::Int - connect_fut::Future{Nothing} readchannel::Channel{Union{String, Vector{UInt8}, WebSocketError}} - writebuffer::Vector{UInt8} - writepos::Int writeclosed::Bool closelock::ReentrantLock sendlock::ReentrantLock handshake_request::Union{Nothing, Request} - websocket_pointer::Ptr{aws_websocket} handshake_response::Union{Nothing, Response} + # AwsHTTP WebSocket codec + channel handler + aws_ws::Any # AwsHTTP.WebSocket + handler::Any # WsChannelHandler + # Fragment tracking incoming_opcode::UInt8 incoming_fin::Bool incoming_payload::Vector{UInt8} @@ -65,16 +159,14 @@ mutable struct WebSocket String(path), Int(maxframesize), Int(maxfragmentation), - Future{Nothing}(), Channel{Union{String, Vector{UInt8}, WebSocketError}}(Inf), - UInt8[], - 0, false, ReentrantLock(), ReentrantLock(), nothing, - C_NULL, nothing, + nothing, # aws_ws + nothing, # handler 0x00, false, UInt8[], @@ -88,6 +180,8 @@ end getresponse(ws::WebSocket) = ws.handshake_response +# ─── Internal helpers ─── + function _queue_close!(ws::WebSocket, body::CloseFrameBody) ws.closebody = body if isopen(ws.readchannel) @@ -126,393 +220,173 @@ function close_payload(body::CloseFrameBody) return payload end -mutable struct SendState - ws::WebSocket - fut::Future{Nothing} -end - -const on_connection_setup = Ref{Ptr{Cvoid}}(C_NULL) +isbinary(x) = x isa AbstractVector{UInt8} +istext(x) = x isa AbstractString +opcode(x) = isbinary(x) ? BINARY : TEXT -function c_on_connection_setup(connection_setup_data::Ptr{aws_websocket_on_connection_setup_data}, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - data = unsafe_load(connection_setup_data) - try - if data.error_code != 0 - notify(ws.connect_fut, CapturedException(aws_error(data.error_code), Base.backtrace())) - else - ws.websocket_pointer = data.websocket - resp = ws.handshake_response - @assert resp !== nothing - resp.status = unsafe_load(data.handshake_response_status) - addheaders(resp.headers, data.handshake_response_header_array, data.num_handshake_response_headers) - if data.handshake_response_body != C_NULL - handshake_response_body = unsafe_load(data.handshake_response_body) - response_body = str(handshake_response_body) - else - response_body = nothing - end - setinputstream!(resp, response_body) - notify(ws.connect_fut, nothing) - end - catch e - notify(ws.connect_fut, CapturedException(e, Base.backtrace())) - end - return -end +_to_bytes(x::AbstractVector{UInt8}) = x +_to_bytes(x) = Vector{UInt8}(codeunits(string(x))) -const on_connection_shutdown = Ref{Ptr{Cvoid}}(C_NULL) +# ─── AwsHTTP WebSocket callback builders ─── +# These create the closure callbacks passed to AwsHTTP.ws_new(). +# Each closure captures the HTTP.WebSocket and manipulates it directly. -function c_on_connection_shutdown(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - if error_code != 0 - @error "$(ws.id): connection shutdown error" exception=(aws_error(error_code), Base.backtrace()) - if ws.closebody === nothing - _queue_close!(ws, CloseFrameBody(1006, "")) +function _on_incoming_frame_begin(ws::WebSocket) + return (aws_ws, frame_info, user_data) -> begin + ws.incoming_opcode = frame_info.opcode + ws.incoming_fin = frame_info.fin + empty!(ws.incoming_payload) + ws.drop_incoming = false + if frame_info.payload_length > ws.maxframesize + close_body = CloseFrameBody(1009, "frame too large") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + ws.drop_incoming = true end - else - _close_channel!(ws) - end - ws.websocket_pointer = C_NULL - ws.writeclosed = true - return -end - -const on_incoming_frame_begin = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_frame_begin(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - fr = unsafe_load(frame) - ws.incoming_opcode = fr.opcode - ws.incoming_fin = fr.fin - empty!(ws.incoming_payload) - ws.drop_incoming = false - if fr.payload_length > ws.maxframesize - close_body = CloseFrameBody(1009, "frame too large") - _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) - ws.drop_incoming = true + frame_info.payload_length > 0 && sizehint!(ws.incoming_payload, Int(frame_info.payload_length)) return true end - fr.payload_length > 0 && sizehint!(ws.incoming_payload, Int(fr.payload_length)) - return true -end - -const on_incoming_frame_payload = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_frame_payload(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, data::aws_byte_cursor, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - ws.drop_incoming && return true - try - n = Int(data.len) - n == 0 && return true - payload = ws.incoming_payload - start = length(payload) + 1 - resize!(payload, length(payload) + n) - Base.unsafe_copyto!(pointer(payload, start), data.ptr, n) - catch e - @error "$(ws.id): incoming frame payload error" exception=(e, catch_backtrace()) - end - return true end -const on_incoming_frame_complete = Ref{Ptr{Cvoid}}(C_NULL) - -function c_on_incoming_frame_complete(websocket::Ptr{aws_websocket}, frame::Ptr{aws_websocket_incoming_frame}, error_code::Cint, ws_ptr) - ws = unsafe_pointer_to_objref(ws_ptr) - if error_code != 0 - @error "$(ws.id): incoming frame complete error" exception=(aws_error(error_code), Base.backtrace()) - close_body = CloseFrameBody(1006, "") - _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) - return true - end - if ws.drop_incoming - ws.drop_incoming = false - return true - end - fr = unsafe_load(frame) - opcode = fr.opcode - fin = fr.fin - payload = ws.incoming_payload - if opcode == UInt8(PING) - payload_copy = copy(payload) - Threads.@spawn begin - try - pong(ws, payload_copy) - catch e - @error "$(ws.id): failed to send pong" exception=(e, catch_backtrace()) - end - end - return true - elseif opcode == UInt8(PONG) - return true - elseif opcode == UInt8(CLOSE) - body = payload - close_body = if length(body) >= 2 - code = (Int(body[1]) << 8) | Int(body[2]) - reason = length(body) > 2 ? String(copy(body[3:end])) : "" - CloseFrameBody(code, reason) - else - CloseFrameBody(1005, "") - end - Threads.@spawn begin - try - ws.writeclosed || close(ws, close_body) - catch e - @error "$(ws.id): failed to close websocket" exception=(e, catch_backtrace()) - end +function _on_incoming_frame_payload(ws::WebSocket) + return (aws_ws, frame_info, data, user_data) -> begin + ws.drop_incoming && return true + try + n = length(data) + n == 0 && return true + append!(ws.incoming_payload, data) + catch e + @error "$(ws.id): incoming frame payload error" exception=(e, catch_backtrace()) end - _queue_close!(ws, close_body) return true end - if opcode == UInt8(CONTINUATION) - if ws.fragment_opcode === nothing - close_body = CloseFrameBody(1002, "unexpected continuation") +end + +function _on_incoming_frame_complete(ws::WebSocket) + return (aws_ws, frame_info, error_code, user_data) -> begin + if error_code != 0 + @error "$(ws.id): incoming frame complete error" error_code + close_body = CloseFrameBody(1006, "") _queue_close!(ws, close_body) Threads.@spawn close(ws, close_body) return true end - ws.fragment_count += 1 - if ws.fragment_count > ws.maxfragmentation - close_body = CloseFrameBody(1009, "message too large") - _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + if ws.drop_incoming + ws.drop_incoming = false + return true + end + op = frame_info.opcode + fin = frame_info.fin + payload = ws.incoming_payload + # PING/PONG: AwsHTTP handles auto-PONG, nothing to do here + if op == UInt8(PING) || op == UInt8(PONG) return true end - append!(ws.fragment_payload, payload) - if fin - msg_opcode = ws.fragment_opcode - data = ws.fragment_payload - ws.fragment_opcode = nothing - ws.fragment_payload = UInt8[] - ws.fragment_count = 0 - if msg_opcode == UInt8(TEXT) - _enqueue_message!(ws, String(copy(data))) + # CLOSE: AwsHTTP auto-echoes, we just track state + if op == UInt8(CLOSE) + body = payload + close_body = if length(body) >= 2 + code = (Int(body[1]) << 8) | Int(body[2]) + reason = length(body) > 2 ? String(copy(body[3:end])) : "" + CloseFrameBody(code, reason) else - _enqueue_message!(ws, copy(data)) + CloseFrameBody(1005, "") end - end - return true - end - if opcode == UInt8(TEXT) || opcode == UInt8(BINARY) - if ws.fragment_opcode !== nothing - close_body = CloseFrameBody(1002, "unexpected new data frame") + # AwsHTTP will echo the CLOSE and set close_sent=true, is_open=false + ws.writeclosed = true _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) return true end - if fin - if opcode == UInt8(TEXT) - _enqueue_message!(ws, String(copy(payload))) - else - _enqueue_message!(ws, copy(payload)) + # Data frames: TEXT, BINARY, CONTINUATION + if op == UInt8(CONTINUATION) + if ws.fragment_opcode === nothing + close_body = CloseFrameBody(1002, "unexpected continuation") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true end - ws.fragment_count = 0 - else - ws.fragment_opcode = opcode - ws.fragment_payload = copy(payload) - ws.fragment_count = 1 + ws.fragment_count += 1 if ws.fragment_count > ws.maxfragmentation close_body = CloseFrameBody(1009, "message too large") _queue_close!(ws, close_body) Threads.@spawn close(ws, close_body) return true end - end - end - return true -end - -function open(f::Function, url; - suppress_close_error::Bool=false, - headers=[], - maxframesize::Integer=typemax(Int), - maxfragmentation::Integer=DEFAULT_MAX_FRAG, - allocator::Ptr{aws_allocator}=default_aws_allocator(), - username=nothing, - password=nothing, - bearer=nothing, - query=nothing, - client::Union{Nothing, Client}=nothing, - # redirect options - redirect=true, - redirect_limit=3, - redirect_method=nothing, - forwardheaders=true, - # cookie options - cookies=true, - cookiejar::CookieJar=COOKIEJAR, - modifier=nothing, - verbose=0, - # client keywords - kw... - ) - key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) - uri = parseuri(url, query, allocator) - # add required websocket headers - append!(headers, [ - "upgrade" => "websocket", - "connection" => "upgrade", - "sec-websocket-key" => key, - "sec-websocket-version" => "13" - ]) - ws = with_redirect(allocator, "GET", uri, headers, nothing, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body - reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); allocator=allocator, ssl_alpn_list="http/1.1", kw...)))::Client - path = resource(uri) - with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req - host = str(uri.host_name) - ws = WebSocket(host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) - ws.handshake_request = req - ws.handshake_response = Response(0, nothing, nothing, false, allocator) - options = aws_websocket_client_connection_options( - allocator, - reqclient.settings.bootstrap, - pointer(FieldRef(reqclient, :socket_options)), - reqclient.tls_options === nothing ? C_NULL : pointer(FieldRef(reqclient, :tls_options)), - reqclient.proxy_options === nothing ? C_NULL : pointer(FieldRef(reqclient, :proxy_options)), - uri.host_name, - uri.port, - ws.handshake_request.ptr, - 0, # initial_window_size - Ptr{Cvoid}(pointer_from_objref(ws)), # user_data - on_connection_setup[], - on_connection_shutdown[], - on_incoming_frame_begin[], - on_incoming_frame_payload[], - on_incoming_frame_complete[], - false, # manual_window_management - C_NULL, # requested_event_loop - C_NULL, # host_resolution_config - ) - if aws_websocket_client_connect(Ref(options)) != 0 - aws_throw_error() + append!(ws.fragment_payload, payload) + if fin + msg_opcode = ws.fragment_opcode + data = ws.fragment_payload + ws.fragment_opcode = nothing + ws.fragment_payload = UInt8[] + ws.fragment_count = 0 + if msg_opcode == UInt8(TEXT) + _enqueue_message!(ws, String(copy(data))) + else + _enqueue_message!(ws, copy(data)) + end end - # wait until connected - wait(ws.connect_fut) - return ws - end - end - verbose > 0 && @info "$(ws.id): WebSocket opened" - try - f(ws) - catch e - if !isok(e) - suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + return true end - if !isclosed(ws) - if e isa WebSocketError && e.message isa CloseFrameBody - close(ws, e.message) - else - close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) + if op == UInt8(TEXT) || op == UInt8(BINARY) + if ws.fragment_opcode !== nothing + close_body = CloseFrameBody(1002, "unexpected new data frame") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true end - end - if !isok(e) - rethrow() - end - finally - if !isclosed(ws) - close(ws, CloseFrameBody(1000, "")) - end - end -end - -function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) - @lock ws.closelock begin - if ws.writeclosed - _close_channel!(ws) - return - end - ws.writeclosed = true - if ws.websocket_pointer != C_NULL - if body !== nothing - payload_bytes = close_payload(body) - @lock ws.sendlock begin - try - writeframe(ws, true, CLOSE, payload(ws, payload_bytes)) - catch - # ignore errors while closing - end + if fin + if op == UInt8(TEXT) + _enqueue_message!(ws, String(copy(payload))) + else + _enqueue_message!(ws, copy(payload)) + end + ws.fragment_count = 0 + else + ws.fragment_opcode = op + ws.fragment_payload = copy(payload) + ws.fragment_count = 1 + if ws.fragment_count > ws.maxfragmentation + close_body = CloseFrameBody(1009, "message too large") + _queue_close!(ws, close_body) + Threads.@spawn close(ws, close_body) + return true end end - aws_websocket_close(ws.websocket_pointer, false) - ws.websocket_pointer = C_NULL end - _close_channel!(ws) + return true end - return -end - -""" - WebSockets.isclosed(ws) -> Bool - -Check whether a `WebSocket` has sent and received CLOSE frames -""" -isclosed(ws::WebSocket) = !isopen(ws.readchannel) && ws.writeclosed - -isbinary(x) = x isa AbstractVector{UInt8} -istext(x) = x isa AbstractString -opcode(x) = isbinary(x) ? BINARY : TEXT - -function payload(ws, x) - pload = isbinary(x) ? x : codeunits(string(x)) - len = length(pload) - resize!(ws.writebuffer, len) - copyto!(ws.writebuffer, pload) - ws.writepos = 1 - return ws.writebuffer end -const stream_outgoing_payload = Ref{Ptr{Cvoid}}(C_NULL) - -function c_stream_outgoing_payload(websocket::Ptr{aws_websocket}, out_buf::Ptr{aws_byte_buf}, ws_ptr::Ptr{Cvoid}) - state = unsafe_pointer_to_objref(ws_ptr) - ws = state.ws - out = unsafe_load(out_buf) - try - space_available = out.capacity - out.len - amount_to_send = min(space_available, sizeof(ws.writebuffer) - ws.writepos + 1) - cursor = aws_byte_cursor(amount_to_send, pointer(ws.writebuffer, ws.writepos)) - @assert aws_byte_buf_write_from_whole_cursor(out_buf, cursor) - ws.writepos += amount_to_send - catch e - @error "$(ws.id): error" (e, catch_backtrace()) - return false - end - return true +# Create an AwsHTTP WebSocket and WsChannelHandler, then install +# the handler into the H1Connection's channel slot. +function _create_ws_handler!(ws::WebSocket, slot::AwsIO.ChannelSlot, is_client::Bool) + aws_ws = AwsHTTP.ws_new(; + is_client=is_client, + on_incoming_frame_begin=_on_incoming_frame_begin(ws), + on_incoming_frame_payload=_on_incoming_frame_payload(ws), + on_incoming_frame_complete=_on_incoming_frame_complete(ws), + ) + handler = WsChannelHandler(aws_ws) + ws.aws_ws = aws_ws + ws.handler = handler + AwsIO.channel_slot_set_handler!(slot, handler) + return end -const on_complete = Ref{Ptr{Cvoid}}(C_NULL) +# ─── writeframe ─── -function c_on_complete(websocket::Ptr{aws_websocket}, error_code::Cint, ws_ptr::Ptr{Cvoid}) - state = unsafe_pointer_to_objref(ws_ptr) - if error_code != 0 - notify(state.fut, CapturedException(aws_error(error_code), Base.backtrace())) - else - notify(state.fut, nothing) +function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload::AbstractVector{UInt8}) + handler = ws.handler + handler === nothing && throw(WebSocketError(CloseFrameBody(1006, "WebSocket not connected"))) + @lock handler.wslock begin + ret = AwsHTTP.ws_send_frame!(handler.aws_ws, UInt8(opcode), payload; fin=fin) + ret != AwsHTTP.OP_SUCCESS && throw(AWSError("ws_send_frame! failed")) + _ws_channel_flush!(handler) end - return + return length(payload) end -function writeframe(ws::WebSocket, fin::Bool, opcode::OpCode, payload) - n = sizeof(payload) - state = SendState(ws, Future{Nothing}()) - opts = aws_websocket_send_frame_options( - n % UInt64, - Ptr{Cvoid}(pointer_from_objref(state)), # user_data - stream_outgoing_payload[], - on_complete[], - UInt8(opcode), - fin - ) - GC.@preserve state opts begin - if aws_websocket_send_frame(ws.websocket_pointer, Ref(opts)) != 0 - aws_throw_error() - end - # wait until frame sent - wait(state.fut) - return n - end -end +# ─── Public API ─── """ send(ws::WebSocket, msg) @@ -531,23 +405,18 @@ function send(ws::WebSocket, x) @lock ws.sendlock begin @assert !ws.writeclosed "WebSocket is closed" if !isbinary(x) && !istext(x) - # if x is not single binary or text, then assume it's an iterable of binary or text - # and we'll send fragmented message + # iterable of binary or text → fragmented message first = true n = 0 state = iterate(x) if state === nothing - # x was not binary or text, but is an empty iterable, send single empty frame x = "" @goto write_single_frame end - @debug "$(ws.id): Writing fragmented message" item, st = state - # we prefetch next state so we know if we're on the last item or not - # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, payload(ws, item)) + n += writeframe(ws, nextstate === nothing, first ? opcode(item) : CONTINUATION, _to_bytes(item)) first = false nextstate === nothing && break item, st = nextstate @@ -555,14 +424,12 @@ function send(ws::WebSocket, x) end return n else - # single binary or text frame for message @label write_single_frame - return writeframe(ws, true, opcode(x), payload(ws, x)) + return writeframe(ws, true, opcode(x), _to_bytes(x)) end end end -# control frames """ ping(ws, data=[]) @@ -573,7 +440,7 @@ to when a PING message is received by a websocket connection. function ping(ws::WebSocket, data=UInt8[]) @lock ws.sendlock begin @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws, true, PING, payload(ws, data)) + return writeframe(ws, true, PING, _to_bytes(data)) end end @@ -589,7 +456,7 @@ used as a one-way heartbeat. function pong(ws::WebSocket, data=UInt8[]) @lock ws.sendlock begin @assert !ws.writeclosed "WebSocket is closed" - return writeframe(ws, true, PONG, payload(ws, data)) + return writeframe(ws, true, PONG, _to_bytes(data)) end end @@ -643,32 +510,163 @@ function Base.iterate(ws::WebSocket, st=nothing) end end +""" + WebSockets.isclosed(ws) -> Bool + +Check whether a `WebSocket` has sent and received CLOSE frames +""" +isclosed(ws::WebSocket) = !isopen(ws.readchannel) && ws.writeclosed + +function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) + @lock ws.closelock begin + if ws.writeclosed + _close_channel!(ws) + return + end + ws.writeclosed = true + handler = ws.handler + if handler !== nothing + if body !== nothing + code = UInt16(body.code) + reason = Vector{UInt8}(codeunits(body.reason)) + @lock handler.wslock begin + try + AwsHTTP.ws_close!(handler.aws_ws; status_code=code, reason=reason) + _ws_channel_flush!(handler) + catch + # ignore errors while closing + end + end + end + end + _close_channel!(ws) + end + return +end + @noinline handshakeerror() = throw(WebSocketError(CloseFrameBody(1002, "Websocket handshake failed"))) -# given a WebSocket request, return the 101 response +# ─── Client-side open ─── + +function open(f::Function, url; + suppress_close_error::Bool=false, + headers=[], + maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, + username=nothing, + password=nothing, + bearer=nothing, + query=nothing, + client::Union{Nothing, Client}=nothing, + # redirect options + redirect=true, + redirect_limit=3, + redirect_method=nothing, + forwardheaders=true, + # cookie options + cookies=true, + cookiejar::CookieJar=COOKIEJAR, + modifier=nothing, + verbose=0, + # client keywords + kw... + ) + key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) + uri = parseuri(url, query) + # add required websocket headers + headers = collect(headers) + append!(headers, [ + "upgrade" => "websocket", + "connection" => "upgrade", + "sec-websocket-key" => key, + "sec-websocket-version" => "13" + ]) + ws = with_redirect("GET", uri, headers, nothing, redirect, redirect_limit, redirect_method, forwardheaders) do method, uri, headers, body + reqclient = @something(client, getclient(ClientSettings(scheme(uri), host(uri), getport(uri); ssl_alpn_list="http/1.1", kw...)))::Client + path = resource(uri) + with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req + ws_host = header(req, "host", String(host(uri))) + ws = WebSocket(ws_host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) + ws.handshake_request = req + with_connection(reqclient) do conn + stream = _open_stream(conn, req, false, 0) + resp = startread(stream) + ws.handshake_response = resp + if resp.status != 101 + # Not a WebSocket upgrade; finish the stream and return ws + # so with_redirect can follow 3xx redirects via getresponse(ws) + try; closeread(stream); catch; end + return ws + end + # Wait for the H1 stream to complete (fires with ERROR_HTTP_SWITCHED_PROTOCOLS) + try + wait(stream.fut) + catch + # Expected: ERROR_HTTP_SWITCHED_PROTOCOLS + end + # Swap the H1Connection handler with our WsChannelHandler + h1conn = stream.aws_stream.owning_connection + slot = h1conn.slot + _create_ws_handler!(ws, slot, true) + verbose > 0 && @info "$(ws.id): WebSocket opened" + # Run WebSocket session + try + f(ws) + catch e + if !isok(e) + suppress_close_error || @error "$(ws.id): error" exception=(e, catch_backtrace()) + end + if !isclosed(ws) + if e isa WebSocketError && e.message isa CloseFrameBody + close(ws, e.message) + else + close(ws, CloseFrameBody(1008, "Unexpected client websocket error")) + end + end + if !isok(e) + rethrow() + end + finally + if !isclosed(ws) + close(ws, CloseFrameBody(1000, "")) + end + end + end + return ws + end + end + # After redirect loop: verify we actually got a 101 upgrade + resp = ws.handshake_response + if resp === nothing || resp.status != 101 + throw(WebSocketError(CloseFrameBody(1002, "Websocket handshake failed: status $(resp === nothing ? 0 : resp.status)"))) + end + return ws +end + +# ─── Server-side upgrade ─── + +# Given a WebSocket request, return the 101 response function websocket_upgrade_handler(req::Request) if !isupgrade(req) - return Response(400, ["content-type" => "text/plain"], "websocket upgrade required", false, req.allocator) + return Response(400, ["content-type" => "text/plain"], "websocket upgrade required") end if !hasheader(req, "Sec-WebSocket-Version", "13") - return Response(400, ["content-type" => "text/plain"], "unsupported websocket version", false, req.allocator) + return Response(400, ["content-type" => "text/plain"], "unsupported websocket version") end key = getheader(req.headers, "sec-websocket-key") if key === nothing || isempty(key) - return Response(400, ["content-type" => "text/plain"], "missing websocket key", false, req.allocator) + return Response(400, ["content-type" => "text/plain"], "missing websocket key") end - resp_ptr = aws_http_message_new_websocket_handshake_response(req.allocator, aws_byte_cursor_from_c_str(key)) - resp_ptr == C_NULL && aws_throw_error() - resp = Response() - resp.allocator = req.allocator - resp.ptr = resp_ptr - resp.request = req + accept_key = AwsHTTP.ws_compute_accept_key(key) + resp = Response(101, [ + "upgrade" => "websocket", + "connection" => "upgrade", + "sec-websocket-accept" => accept_key, + ], nothing) return resp end function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, done=nothing) - #TODO: return WebSocketUpgradeArgs - # then schedule a task to do the actual upgrade function websocket_upgrade(stream::Stream) resp = isdefined(stream, :response) ? stream.response : nothing if resp === nothing || resp.status != 101 @@ -679,17 +677,10 @@ function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxfram ws = WebSocket(header(req, "host", ""), req.path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) ws.handshake_request = req ws.handshake_response = resp - stream.websocket_options = aws_websocket_server_upgrade_options( - 0, - Ptr{Cvoid}(pointer_from_objref(ws)), - on_incoming_frame_begin[], - on_incoming_frame_payload[], - on_incoming_frame_complete[], - false # manual_window_management - ) - ws_ptr = aws_websocket_upgrade(stream.allocator, stream.ptr, FieldRef(stream, :websocket_options)) - ws_ptr == C_NULL && aws_throw_error() - ws.websocket_pointer = ws_ptr + # Get the H1Connection's channel slot and swap in our WebSocket handler + h1conn = stream.aws_stream.owning_connection + slot = h1conn.slot + _create_ws_handler!(ws, slot, false) errormonitor(Threads.@spawn begin err = nothing try @@ -721,7 +712,6 @@ function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxfram end done !== nothing && notify(done, nothing) end - aws_websocket_release(ws_ptr) end end) return @@ -737,8 +727,8 @@ function _upgrade(f::Function, stream::Stream; suppress_close_error::Bool=false, done = Future{Nothing}() stream.on_complete = websocket_upgrade_function(f; suppress_close_error=suppress_close_error, maxframesize=maxframesize, maxfragmentation=maxfragmentation, done=done) stream.response = websocket_upgrade_handler(stream.request) - HTTP.startwrite(stream) - HTTP.closewrite(stream) + startwrite(stream) + closewrite(stream) wait(done) return end @@ -757,15 +747,4 @@ function listen(f, host="127.0.0.1", port=8080; kw...) return server end -function __init__() - on_connection_setup[] = @cfunction(c_on_connection_setup, Cvoid, (Ptr{aws_websocket_on_connection_setup_data}, Ptr{Cvoid})) - on_connection_shutdown[] = @cfunction(c_on_connection_shutdown, Cvoid, (Ptr{aws_websocket}, Cint, Ptr{Cvoid})) - on_incoming_frame_begin[] = @cfunction(c_on_incoming_frame_begin, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, Ptr{Cvoid})) - on_incoming_frame_payload[] = @cfunction(c_on_incoming_frame_payload, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, aws_byte_cursor, Ptr{Cvoid})) - on_incoming_frame_complete[] = @cfunction(c_on_incoming_frame_complete, Bool, (Ptr{aws_websocket}, Ptr{aws_websocket_incoming_frame}, Cint, Ptr{Cvoid})) - stream_outgoing_payload[] = @cfunction(c_stream_outgoing_payload, Bool, (Ptr{aws_websocket}, Ptr{aws_byte_buf}, Ptr{Cvoid})) - on_complete[] = @cfunction(c_on_complete, Cvoid, (Ptr{aws_websocket}, Cint, Ptr{Cvoid})) - return -end - end # module diff --git a/test/client.jl b/test/client.jl index 62db60f2..1cfa5171 100644 --- a/test/client.jl +++ b/test/client.jl @@ -638,10 +638,15 @@ end @testset "HTTP connection monitoring stats" begin - list = Ref{HTTP.aws_array_list}() - HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http1_channel)) - stat1 = HTTP.aws_crt_statistics_http1_channel(HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, 10, 20, 1, 2) - HTTP.aws_array_list_push_back(list, Ref(stat1)) + list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http1_channel}() + stat1 = HTTP.aws_crt_statistics_http1_channel( + HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, + UInt64(10), + UInt64(20), + UInt32(1), + UInt32(2), + ) + AwsIO.push_back!(list, stat1) decoded = HTTP._decode_statistics(list) @test length(decoded) == 1 @test decoded[1].category == :http1_channel @@ -649,30 +654,36 @@ @test decoded[1].pending_incoming_stream_ms == 20 @test decoded[1].current_outgoing_stream_id == 1 @test decoded[1].current_incoming_stream_id == 2 - HTTP.aws_array_list_clean_up(list) - list = Ref{HTTP.aws_array_list}() - HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http2_channel)) - stat2 = HTTP.aws_crt_statistics_http2_channel(HTTP.AWSCRT_STAT_CAT_HTTP2_CHANNEL, 5, 6, true) - HTTP.aws_array_list_push_back(list, Ref(stat2)) + list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http2_channel}() + stat2 = HTTP.aws_crt_statistics_http2_channel( + HTTP.AWSCRT_STAT_CAT_HTTP2_CHANNEL, + UInt64(5), + UInt64(6), + true, + ) + AwsIO.push_back!(list, stat2) decoded = HTTP._decode_statistics(list) @test length(decoded) == 1 @test decoded[1].category == :http2_channel @test decoded[1].pending_outgoing_stream_ms == 5 @test decoded[1].pending_incoming_stream_ms == 6 @test decoded[1].was_inactive == true - HTTP.aws_array_list_clean_up(list) called = Ref(false) cb = (nonce, stats) -> (called[] = true) client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); monitoring_statistics_observer=cb)) - list = Ref{HTTP.aws_array_list}() - HTTP.aws_array_list_init_dynamic(list, HTTP.default_aws_allocator(), 1, sizeof(HTTP.aws_crt_statistics_http1_channel)) - stat3 = HTTP.aws_crt_statistics_http1_channel(HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, 1, 1, 1, 1) - HTTP.aws_array_list_push_back(list, Ref(stat3)) - HTTP.c_on_statistics_observer(Csize_t(0), Base.unsafe_convert(Ptr{HTTP.aws_array_list}, list), pointer_from_objref(client.monitoring_observer)) + list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http1_channel}() + stat3 = HTTP.aws_crt_statistics_http1_channel( + HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, + UInt64(1), + UInt64(1), + UInt32(1), + UInt32(1), + ) + AwsIO.push_back!(list, stat3) + HTTP._call_statistics_observer(client.monitoring_observer, Csize_t(0), list) @test called[] - HTTP.aws_array_list_clean_up(list) finalize(client) end diff --git a/test/runtests.jl b/test/runtests.jl index eded5cdf..edf374e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using Test, HTTP, URIs, JSON +using Test, HTTP, URIs, JSON, AwsIO const httpbin = get(ENV, "JULIA_TEST_HTTPBINGO_SERVER", "httpbingo.julialang.org") isok(r) = r.status == 200 diff --git a/test/server.jl b/test/server.jl index 8dbd6022..028d0395 100644 --- a/test/server.jl +++ b/test/server.jl @@ -1,11 +1,12 @@ -using Test, HTTP, Logging, Base64 +using Test, HTTP, Logging, Base64, AwsIO import Sockets @testset "HTTP.serve" begin server = HTTP.serve!(req -> HTTP.Response(200, "Hello, World!"); listenany=true) try @test server.state == :running - resp = HTTP.get("http://127.0.0.1:8080") + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") @test resp.status == 200 @test String(resp.body) == "Hello, World!" finally @@ -104,75 +105,81 @@ end end end -@testset "HTTP/2 stream handler writes" begin - cert = joinpath(@__DIR__, "fixtures", "http2.crt") - key = joinpath(@__DIR__, "fixtures", "http2.key") - saw_http2 = Threads.Atomic{Bool}(false) - buffered = Threads.Atomic{Bool}(false) - server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream - HTTP.startread(stream) - stream.http2 && (saw_http2[] = true) - HTTP.setstatus(stream, 200) - HTTP.startwrite(stream) - write(stream, "hello") - if stream.http2 && stream.responsebuf !== nothing - buffered[] = true - end - HTTP.closewrite(stream) - end - try - port = HTTP.port(server) - resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2") - if resp.version == HTTP.HTTPVersion(2, 0) - @test saw_http2[] - @test !buffered[] - @test String(resp.body) == "hello" - else - @info "HTTP/2 not negotiated for stream handler test" - @test true - end - finally - close(server) - end -end - -@testset "HTTP/2 server push promise" begin - cert = joinpath(@__DIR__, "fixtures", "http2.crt") - key = joinpath(@__DIR__, "fixtures", "http2.key") - port_ref = Ref{Int}(0) - push_called = Threads.Atomic{Bool}(false) - push_http2 = Threads.Atomic{Bool}(false) - push_server_side = Threads.Atomic{Bool}(false) - server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream - HTTP.startread(stream) - if stream.http2 - authority = "127.0.0.1:$(port_ref[])" - push = HTTP.push_promise(stream, "GET", "/pushed"; scheme="https", authority=authority) - push_called[] = true - push_http2[] = push.http2 - push_server_side[] = push.server_side - HTTP.setstatus(push, 200) - HTTP.setheader(push, "Content-Type" => "text/plain") - write(push, "pushed") - HTTP.closewrite(push) +@testset "HTTP/2 TLS support" begin + if !AwsIO.tls_is_alpn_available() + @info "Skipping HTTP/2 TLS tests; ALPN not available" + @test true + else + @testset "HTTP/2 stream handler writes" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + saw_http2 = Threads.Atomic{Bool}(false) + buffered = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + stream.http2 && (saw_http2[] = true) + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, "hello") + if stream.http2 && stream.responsebuf !== nothing + buffered[] = true + end + HTTP.closewrite(stream) + end + try + port = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test saw_http2[] + @test !buffered[] + @test String(resp.body) == "hello" + else + @info "HTTP/2 not negotiated for stream handler test" + @test true + end + finally + close(server) + end end - HTTP.setstatus(stream, 200) - write(stream, "ok") - end - try - port_ref[] = HTTP.port(server) - resp = HTTP.get("https://127.0.0.1:$(port_ref[])"; ssl_insecure=true, ssl_alpn_list="h2") - if resp.version == HTTP.HTTPVersion(2, 0) - @test push_called[] - @test push_http2[] - @test push_server_side[] - @test String(resp.body) == "ok" - else - @info "HTTP/2 not negotiated for push promise test" - @test true + @testset "HTTP/2 server push promise" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + port_ref = Ref{Int}(0) + push_called = Threads.Atomic{Bool}(false) + push_http2 = Threads.Atomic{Bool}(false) + push_server_side = Threads.Atomic{Bool}(false) + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + if stream.http2 + authority = "127.0.0.1:$(port_ref[])" + push = HTTP.push_promise(stream, "GET", "/pushed"; scheme="https", authority=authority) + push_called[] = true + push_http2[] = push.http2 + push_server_side[] = push.server_side + HTTP.setstatus(push, 200) + HTTP.setheader(push, "Content-Type" => "text/plain") + write(push, "pushed") + HTTP.closewrite(push) + end + HTTP.setstatus(stream, 200) + write(stream, "ok") + end + try + port_ref[] = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port_ref[])"; ssl_insecure=true, ssl_alpn_list="h2") + if resp.version == HTTP.HTTPVersion(2, 0) + @test push_called[] + @test push_http2[] + @test push_server_side[] + @test String(resp.body) == "ok" + else + @info "HTTP/2 not negotiated for push promise test" + @test true + end + finally + close(server) + end end - finally - close(server) end end diff --git a/test/utils.jl b/test/utils.jl index 70395fb3..d4422a9c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -66,7 +66,8 @@ end @test got - @test_throws HTTP.AWSError HTTP.parseuri("http://example.com:abc", nothing, HTTP.default_aws_allocator()) + # parseuri now delegates to URIs.jl (no C-level allocator needed) + @test HTTP.parseuri("http://example.com/path", nothing) isa URIs.URI exported = names(HTTP, all=false) @test :startwrite in exported From bbec902555353ef75cddc7edaa93787aa1b6f764 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 07:57:40 -0700 Subject: [PATCH 47/56] Avoid external redirect target in client tests --- test/client.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/client.jl b/test/client.jl index 1cfa5171..c27698b9 100644 --- a/test/client.jl +++ b/test/client.jl @@ -159,7 +159,8 @@ @test HTTP.request(read_method, "https://$httpbin/redirect/6", status_exception=false).status == 302 #over max number of redirects @test isok(HTTP.request(read_method, "https://$httpbin/relative-redirect/1")) @test isok(HTTP.request(read_method, "https://$httpbin/absolute-redirect/1")) - @test isok(HTTP.request(read_method, "https://$httpbin/redirect-to?url=http%3A%2F%2Fgoogle.com")) + redirect_target = URIs.escapeuri("http://$httpbin/get") + @test isok(HTTP.request(read_method, "https://$httpbin/redirect-to?url=$redirect_target")) end @testset "Client Basic Auth" begin From 4ccfcd78f53688413b886d9d0ce5cb8f87d0950c Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 16:23:30 -0700 Subject: [PATCH 48/56] Select server handler from TLS protocol --- src/server.jl | 65 +++++++++++++++++++++------------------------------ 1 file changed, 27 insertions(+), 38 deletions(-) diff --git a/src/server.jl b/src/server.jl index 91328ebd..8c085f6f 100644 --- a/src/server.jl +++ b/src/server.jl @@ -352,51 +352,40 @@ function serve!(f, host="127.0.0.1", port=8080; socket_opts.impl_type = AwsIO.SocketImplType.APPLE_NETWORK_FRAMEWORK end initial_window = Csize_t(min(UInt64(initial_window_size), UInt64(typemax(Csize_t)))) - on_protocol_negotiated = tls_conn_opts !== nothing ? - (new_slot, protocol, user_data) -> begin - protocol_str = AwsIO.byte_buffer_as_string(protocol) - version = protocol_str == "h2" ? AwsHTTP.HttpVersion.HTTP_2 : AwsHTTP.HttpVersion.HTTP_1_1 - return AwsHTTP.http_connection_new_channel_handler(; - is_server=true, - version, - initial_window_size=initial_window, - ) - end : nothing on_incoming_channel_setup = (bootstrap, error_code, channel, user_data) -> begin Base.CoreLogging.with_logstate(server.logstate) do if error_code != 0 @error "incoming channel setup error" error_code return end - # For TLS, ALPN may have installed the H1/H2 handler. Otherwise install H1 manually. - http_conn = if tls_conn_opts !== nothing - handler = channel.last.handler - if handler isa AwsHTTP.H1Connection || handler isa AwsHTTP.H2Connection - handler - else - h = AwsHTTP.http_connection_new_channel_handler(; - is_server=true, - version=AwsHTTP.HttpVersion.HTTP_1_1, - initial_window_size=initial_window, - ) - slot = AwsIO.channel_slot_new!(channel) - AwsIO.channel_slot_insert_end!(channel, slot) - AwsIO.channel_slot_set_handler!(slot, h) - h.slot = slot - h + slot = AwsIO.channel_slot_new!(channel) + AwsIO.channel_slot_insert_end!(channel, slot) + version = AwsHTTP.HttpVersion.HTTP_1_1 + if tls_conn_opts !== nothing + tls_slot = slot.adj_left + if tls_slot === nothing || tls_slot.handler === nothing || !(tls_slot.handler isa AwsIO.TlsChannelHandler) + @error "incoming channel setup error" error_code=AwsIO.ERROR_INVALID_STATE + AwsIO.channel_shutdown!(channel, AwsIO.ERROR_INVALID_STATE) + return + end + protocol = AwsIO.tls_handler_protocol(tls_slot.handler) + if protocol.len > 0 + protocol_str = AwsIO.byte_buffer_as_string(protocol) + if protocol_str == "h2" + version = AwsHTTP.HttpVersion.HTTP_2 + elseif protocol_str == "http/1.1" + version = AwsHTTP.HttpVersion.HTTP_1_1 + end end - else - h = AwsHTTP.http_connection_new_channel_handler(; - is_server=true, - version=AwsHTTP.HttpVersion.HTTP_1_1, - initial_window_size=initial_window, - ) - slot = AwsIO.channel_slot_new!(channel) - AwsIO.channel_slot_insert_end!(channel, slot) - AwsIO.channel_slot_set_handler!(slot, h) - h.slot = slot - h end + http_conn = AwsHTTP.http_connection_new_channel_handler(; + is_server=true, + version=version, + initial_window_size=initial_window, + ) + http_conn === nothing && return + AwsIO.channel_slot_set_handler!(slot, http_conn) + http_conn.slot = slot # Extract remote endpoint from the socket handler (first slot in pipeline) remote_addr = "0.0.0.0" remote_port_num = 0 @@ -467,7 +456,7 @@ function serve!(f, host="127.0.0.1", port=8080; host = host_str, port = UInt32(port_int), tls_connection_options = tls_conn_opts, - on_protocol_negotiated = on_protocol_negotiated, + on_protocol_negotiated = nothing, on_incoming_channel_setup = on_incoming_channel_setup, on_incoming_channel_shutdown = on_incoming_channel_shutdown, on_listener_destroy = on_listener_destroy, From 215b77a2885de916dcfc77cf7b97262d42ee4e59 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 19:24:58 -0700 Subject: [PATCH 49/56] Restore headers vector semantics --- src/cookies.jl | 12 +++-- src/handlers.jl | 4 +- src/requestresponse.jl | 109 +++++++++++++++++++++++++++++++++++------ test/headers.jl | 26 +++++++++- 4 files changed, 126 insertions(+), 25 deletions(-) diff --git a/src/cookies.jl b/src/cookies.jl index 7cc09e6c..6a6181c6 100644 --- a/src/cookies.jl +++ b/src/cookies.jl @@ -34,7 +34,7 @@ export Cookie, CookieJar, cookies, stringify, getcookies!, setcookies!, addcooki import Base: == using Dates, Sockets -import ..addheader, ..headereq, ..Headers, ..Request, ..Response +import ..addheader, ..headereq, ..Headers, ..Request, ..Response, .._header_name, .._header_value @enum SameSite SameSiteDefaultMode=1 SameSiteLaxMode SameSiteStrictMode SameSiteNoneMode @@ -211,8 +211,9 @@ const RFC1123GMTFormat = gmtformat(Dates.RFC1123Format) function readsetcookies(headers::Headers) result = Cookie[] for h in headers - headereq(h.name, "Set-Cookie") || continue - line = h.value + name = _header_name(h) + headereq(name, "Set-Cookie") || continue + line = _header_value(h) if length(line) == 0 continue end @@ -326,8 +327,9 @@ cookies(r::Request) = readcookies(r.headers, "") function readcookies(headers::Headers, filter::String="") result = Cookie[] for h in headers - headereq(h.name, "Cookie") || continue - line = h.value + name = _header_name(h) + headereq(name, "Cookie") || continue + line = _header_value(h) for part in split(strip(line), ';'; keepempty=false) part = strip(part) length(part) <= 1 && continue diff --git a/src/handlers.jl b/src/handlers.jl index 22ae9768..86f7367f 100644 --- a/src/handlers.jl +++ b/src/handlers.jl @@ -2,7 +2,7 @@ module Handlers export Handler, Middleware, serve, serve!, Router, register!, getroute, getparams, getparam, getcookies, streamhandler -import ..Request, ..Response, ..Stream, ..Cookies, ..getbody +import ..Request, ..Response, ..Stream, ..Cookies, ..getbody, .._header_name, .._header_value import ..startread, ..setstatus, ..setheader, ..addtrailer, ..closewrite, ..closeread """ @@ -53,7 +53,7 @@ function streamhandler(handler) resp.request = req setstatus(stream, resp.status) for h in resp.headers - setheader(stream, h.name, h.value) + setheader(stream, _header_name(h), _header_value(h)) end body = getbody(resp) if body isa IO diff --git a/src/requestresponse.jl b/src/requestresponse.jl index 667f3fab..a61d31ea 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -1,6 +1,6 @@ export Header, Headers, Message, Request, Response, header, headers, hasheader, headercontains, - setheader, setheaderifabsent, defaultheader!, appendheader, removeheader, + setheader, setheaderifabsent, setheaders!, defaultheader!, appendheader, removeheader, canonicalizeheaders, canonicalizeheaders!, mkheaders # working with headers @@ -24,32 +24,86 @@ end Base.show(io::IO, h::Header) = print_header(io, h) -mutable struct Headers <: AbstractVector{Header} +@inline _header_name(h::Header) = h.name +@inline _header_name(h::Pair) = String(first(h)) +@inline _header_name(h::AwsHTTP.HttpHeader) = h.name +@inline _header_value(h::Header) = h.value +@inline _header_value(h::Pair) = String(last(h)) +@inline _header_value(h::AwsHTTP.HttpHeader) = h.value +@inline _header_pair(h) = _header_name(h) => _header_value(h) + +Base.first(h::Header) = h.name +Base.last(h::Header) = h.value + +mutable struct Headers <: AbstractVector{Pair{String, String}} const hdrs::AwsHTTP.HttpHeaders function Headers() return new(AwsHTTP.http_headers_new()) end Headers(hdrs::AwsHTTP.HttpHeaders) = new(hdrs) + function Headers(h::AbstractVector{<:Pair}) + hdrs = AwsHTTP.http_headers_new() + for (k, v) in h + AwsHTTP.http_headers_add(hdrs, String(k), String(v)) != 0 && aws_throw_error() + end + return new(hdrs) + end end abstract type Message end +Base.eltype(::Type{Headers}) = Pair{String, String} +Base.IndexStyle(::Type{Headers}) = IndexLinear() Base.size(h::Headers) = (AwsHTTP.http_headers_count(h.hdrs),) +Base.length(h::Headers) = AwsHTTP.http_headers_count(h.hdrs) function Base.getindex(h::Headers, i::Int) hdr = AwsHTTP.http_headers_get_index(h.hdrs, i - 1) hdr === nothing && throw(BoundsError(h, i)) - return Header(hdr) + return _header_pair(hdr) +end + +function Base.setindex!(h::Headers, v::Pair, i::Int) + len = length(h) + (i < 1 || i > len) && throw(BoundsError(h, i)) + items = collect(h) + items[i] = String(v.first) => String(v.second) + empty!(h) + addheaders(h, items) + return h +end + +function Base.insert!(h::Headers, i::Int, v::Pair) + len = length(h) + (i < 1 || i > len + 1) && throw(BoundsError(h, i)) + items = collect(h) + insert!(items, i, String(v.first) => String(v.second)) + empty!(h) + addheaders(h, items) + return h end -Base.Dict(h::Headers) = Dict(((h2.name, h2.value) for h2 in h)) +function Base.push!(h::Headers, v::Pair) + addheader(h, v) + return h +end + +function Base.push!(h::Headers, v::Header) + addheader(h, v) + return h +end + +Base.Dict(h::Headers) = Dict((_header_pair(h2) for h2 in h)) +Base.copy(h::Headers) = mkheaders(h) +Base.convert(::Type{Vector{Pair{String, String}}}, h::Headers) = mkheaders(h) addheader(headers::Headers, h::Header) = AwsHTTP.http_headers_add_header(headers.hdrs, h.header) != 0 && aws_throw_error() addheader(headers::Headers, h::AwsHTTP.HttpHeader) = AwsHTTP.http_headers_add_header(headers.hdrs, h) != 0 && aws_throw_error() +addheader(headers::Headers, h::Pair) = AwsHTTP.http_headers_add(headers.hdrs, String(h.first), String(h.second)) != 0 && aws_throw_error() addheader(headers::Headers, k, v) = AwsHTTP.http_headers_add(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() addheaders(headers::Headers, h::Vector{AwsHTTP.HttpHeader}) = AwsHTTP.http_headers_add_array(headers.hdrs, h) != 0 && aws_throw_error() -function addheaders(headers::Headers, h::Vector{Pair{String, String}}) +function addheaders(headers::Headers, h::AbstractVector{<:Pair}) for (k, v) in h addheader(headers, k, v) end @@ -81,6 +135,27 @@ end setheaderifabsent(headers, k, v) = !hasheader(headers, k) && setheader(headers, k, v) setheaderifabsent(m::Message, k, v) = setheaderifabsent(m.headers, k, v) +function setheaders!(headers::Headers, newheaders) + newheaders === headers && return headers + if newheaders === nothing + empty!(headers) + return headers + end + if newheaders isa Headers + newheaders.hdrs === headers.hdrs && return headers + items = collect(newheaders) + elseif newheaders isa AwsHTTP.HttpHeaders + items = collect(Headers(newheaders)) + else + items = mkheaders(newheaders) + end + empty!(headers) + addheaders(headers, items) + return headers +end + +setheaders!(m::Message, newheaders) = setheaders!(m.headers, newheaders) + field_name_isequal(a, b) = headereq(String(a), String(b)) Base.getindex(m::Message, k) = header(m, k) @@ -106,7 +181,7 @@ end Get all headers with key `k` or empty if none. """ -headers(h::Headers, k) = [h2.value for h2 in h if field_name_isequal(h2.name, k)] +headers(h::Headers, k) = [_header_value(h2) for h2 in h if field_name_isequal(_header_name(h2), k)] headers(h::AbstractVector{<:Pair}, k) = [String(v) for (name, v) in h if field_name_isequal(name, k)] headers(m::Message, k) = headers(m.headers, k) @@ -183,7 +258,7 @@ function defaultheader!(m, v::Pair) end function canonicalizeheaders!(h::Headers) - items = [(h2.name, h2.value) for h2 in h] + items = [(_header_name(h2), _header_value(h2)) for h2 in h] for i in length(h):-1:1 deleteat!(h, i) end @@ -197,11 +272,11 @@ canonicalizeheaders(h::AbstractVector{<:Pair}) = [tocameldash(String(k)) => String(v) for (k, v) in h] mkheaders(::Nothing) = Pair{String, String}[] -mkheaders(h::Headers) = [h2.name => h2.value for h2 in h] +mkheaders(h::Headers) = [_header_pair(h2) for h2 in h] mkheaders(h::AbstractVector{Header}) = begin headers = Pair{String, String}[] for head in h - push!(headers, String(head.name) => String(head.value)) + push!(headers, _header_pair(head)) end return headers end @@ -227,7 +302,7 @@ end function sync_headers!(dest::AbstractVector{<:Pair}, src::Headers) empty!(dest) for h in src - push!(dest, String(h.name) => String(h.value)) + push!(dest, _header_pair(h)) end return dest end @@ -345,7 +420,8 @@ mutable struct Request <: Message AwsHTTP.http_message_set_request_path(msg, String(path)) != 0 && aws_throw_error() msg_headers = AwsHTTP.http_message_get_headers(msg) if headers !== nothing - for (k, v) in headers + src_headers = headers isa AbstractVector{<:Pair} ? headers : mkheaders(headers) + for (k, v) in src_headers AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end end @@ -408,15 +484,15 @@ function Base.setproperty!(x::Request, s::Symbol, v) elseif s == :path AwsHTTP.http_message_set_request_path(getfield(x, :msg), String(v)) != 0 && aws_throw_error() elseif s == :headers - addheaders(x.headers, v) + setheaders!(x, v) else setfield!(x, s, v) end end function print_header(io, h) - key = h.name - val = h.value + key = _header_name(h) + val = _header_value(h) if headereq(key, "authorization") write(io, string(key, ": ", "******", "\r\n")) return @@ -494,7 +570,8 @@ mutable struct Response <: Message AwsHTTP.http_message_set_response_status(msg, Int(status)) != 0 && aws_throw_error() msg_headers = AwsHTTP.http_message_get_headers(msg) if headers !== nothing - for (k, v) in headers + src_headers = headers isa AbstractVector{<:Pair} ? headers : mkheaders(headers) + for (k, v) in src_headers AwsHTTP.http_headers_add(msg_headers, String(k), String(v)) != 0 && aws_throw_error() end end @@ -558,7 +635,7 @@ function Base.setproperty!(x::Response, s::Symbol, v) if s == :status AwsHTTP.http_message_set_response_status(getfield(x, :msg), Int(v)) != 0 && aws_throw_error() elseif s == :headers - addheaders(x.headers, v) + setheaders!(x, v) else setfield!(x, s, v) end diff --git a/test/headers.jl b/test/headers.jl index c9f47653..9263de22 100644 --- a/test/headers.jl +++ b/test/headers.jl @@ -9,6 +9,28 @@ @test HTTP.headercontains(h, "x-test-header", "abc") HTTP.canonicalizeheaders!(h) - @test any(x -> x.name == "X-Test-Header", h) - @test any(x -> x.name == "Content-Type", h) + @test any(x -> first(x) == "X-Test-Header", h) + @test any(x -> first(x) == "Content-Type", h) +end + +@testset "Headers vector compatibility" begin + h = HTTP.Headers() + push!(h, "a" => "1") + push!(h, "b" => "2") + @test h[1] == ("a" => "1") + h[2] = "b" => "3" + @test h[2] == ("b" => "3") + insert!(h, 2, "c" => "4") + @test h[2] == ("c" => "4") + + req = HTTP.Request("GET", "/") + req.headers = ["x" => "1", "y" => "2"] + @test length(req.headers) == 2 + req.headers = ["z" => "3"] + @test length(req.headers) == 1 + @test HTTP.header(req, "z") == "3" + + hdrs = req.headers + push!(hdrs, "w" => "4") + @test HTTP.header(req, "w") == "4" end From 6636cec36331a700bc8b16127a5f8a7c647844df Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 19:46:37 -0700 Subject: [PATCH 50/56] Restore pool compatibility --- src/HTTP.jl | 1 + src/client/client.jl | 51 +++++++++++++++++++++++++++++++++++++++----- test/client.jl | 17 +++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/HTTP.jl b/src/HTTP.jl index 62942e5c..3b75297a 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -7,6 +7,7 @@ export HTTPVersion export startwrite, startread, closewrite, closeread export @logfmt_str, common_logfmt, combined_logfmt export WebSockets +export Pool, default_connection_limit, set_default_connection_limit!, closeall const nobody = UInt8[] diff --git a/src/client/client.jl b/src/client/client.jl index d454556a..3ad8d456 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,5 +1,6 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 const DEFAULT_MAX_RETRIES = 4 +const default_connection_limit = Ref{Int}(max(16, Threads.nthreads() * 4)) # ─── Shared infrastructure ─── # Lazily initialized resources that replace the old C library globals @@ -134,12 +135,22 @@ ClientSettings( require_ssl_verification::Bool=true, ssl_insecure::Bool=false, kw...) = begin - http2_initial_settings = Base.get(() -> nothing, kw, :http2_initial_settings) + kw_nt = (; kw...) + connection_limit = Base.get(() -> nothing, kw_nt, :connection_limit) + if connection_limit !== nothing + connection_limit_warning(connection_limit) + kw_nt = Base.structdiff(kw_nt, (; connection_limit=nothing)) + end + max_connections = Base.get(() -> default_connection_limit[], kw_nt, :max_connections) + if haskey(kw_nt, :max_connections) + kw_nt = Base.structdiff(kw_nt, (; max_connections=nothing)) + end + http2_initial_settings = Base.get(() -> nothing, kw_nt, :http2_initial_settings) if http2_initial_settings !== nothing && !(http2_initial_settings isa AbstractVector) throw(ArgumentError("http2_initial_settings must be a vector of pairs or AwsHTTP.Http2Setting")) end - if haskey(kw, :http2_initial_settings) - kw = Base.structdiff((; kw...), (; http2_initial_settings=nothing)) + if haskey(kw_nt, :http2_initial_settings) + kw_nt = Base.structdiff(kw_nt, (; http2_initial_settings=nothing)) end ClientSettings(; scheme=String(scheme), @@ -147,9 +158,10 @@ ClientSettings( port=port, connect_timeout_ms=(connect_timeout !== nothing ? connect_timeout * 1000 : connect_timeout_ms), max_retries=(retry ? (retries != DEFAULT_MAX_RETRIES ? retries : max_retries) : 0), + max_connections=max_connections, ssl_insecure=(!require_ssl_verification || ssl_insecure), http2_initial_settings=http2_initial_settings, - kw...) + kw_nt...) end # make a new ClientSettings object from an existing one w/ just different url values @@ -429,7 +441,15 @@ struct Pool max_connections::Union{Nothing, Int} end -Pool(max_connections::Union{Int, Nothing}=nothing) = Pool(Clients(), max_connections) +Pool() = Pool(default_connection_limit[]) +Pool(max_connections::Union{Int, Nothing}) = Pool(Clients(), max_connections) + +function Base.getproperty(pool::Pool, name::Symbol) + if name === :max + return getfield(pool, :max_connections) + end + return getfield(pool, name) +end const CLIENTS = Clients() @@ -471,7 +491,28 @@ function close_all_clients!(clients::Clients=CLIENTS) end end empty!(clients.clients) + empty!(clients.order) end end close_all_clients!(pool::Pool) = close_all_clients!(pool.clients) + +function set_default_connection_limit!(n::Integer) + default_connection_limit[] = Int(n) + return +end + +function closeall(pool::Union{Nothing, Pool}=nothing) + if pool === nothing + close_all_clients!(CLIENTS) + else + close_all_clients!(pool.clients) + end + return +end + +@noinline function connection_limit_warning(cl) + cl === nothing && return + @warn "connection_limit no longer supported as a keyword argument; use `HTTP.set_default_connection_limit!($cl)` before any requests are made or construct a shared pool via `POOL = HTTP.Pool($cl)` and pass to each request like `pool=POOL` instead." + return +end diff --git a/test/client.jl b/test/client.jl index c27698b9..f14bbab1 100644 --- a/test/client.jl +++ b/test/client.jl @@ -1,4 +1,21 @@ @testset "Client.jl" begin + @testset "Connection pool compatibility" begin + original_limit = HTTP.default_connection_limit[] + HTTP.set_default_connection_limit!(13) + pool = HTTP.Pool() + @test pool.max_connections == 13 + @test pool.max == 13 + cs = HTTP.ClientSettings("http", "example.com", UInt32(80)) + @test cs.max_connections == 13 + @test_logs (:warn, r"connection_limit no longer supported") begin + cs_warn = HTTP.ClientSettings("http", "example.com", UInt32(80); connection_limit=7) + @test cs_warn.max_connections == 13 + end + HTTP.set_default_connection_limit!(original_limit) + pool_default = HTTP.Pool() + @test pool_default.max_connections == original_limit + @test pool_default.max == original_limit + end if HAVE_HTTPBIN @testset "GET, HEAD, POST, PUT, DELETE, PATCH: $scheme" for scheme in ["http", "https"] @test isok(HTTP.get("$scheme://$httpbin/ip")) From 6fe2fa428b609a6cd1da129b6255327d9498c216 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 20:21:29 -0700 Subject: [PATCH 51/56] Restore server shutdown hooks --- src/server.jl | 124 ++++++++++++++++++++++++++++++++++++++++++++----- test/server.jl | 23 +++++++++ 2 files changed, 135 insertions(+), 12 deletions(-) diff --git a/src/server.jl b/src/server.jl index 8c085f6f..76226c22 100644 --- a/src/server.jl +++ b/src/server.jl @@ -54,6 +54,7 @@ end mutable struct Server{F, C} const f::F const on_stream_complete::C + const on_shutdown::Any const fut::Future{Symbol} const connections_lock::ReentrantLock const connections::Set{Connection} @@ -68,6 +69,7 @@ mutable struct Server{F, C} Server{F, C}( f::F, on_stream_complete::C, + on_shutdown::Any, fut::Future{Symbol}, connections_lock::ReentrantLock, connections::Set{Connection}, @@ -76,13 +78,28 @@ mutable struct Server{F, C} stream::Bool, logstate::Base.CoreLogging.LogState, state::Symbol, - ) where {F, C} = new{F, C}(f, on_stream_complete, fut, connections_lock, connections, closed, access_log, stream, logstate, state) + ) where {F, C} = new{F, C}(f, on_stream_complete, on_shutdown, fut, connections_lock, connections, closed, access_log, stream, logstate, state) end Base.wait(s::Server) = wait(s.closed) ftype(::Server{F}) where {F} = F port(s::Server) = s.bound_port +shutdown(fns::Vector{<:Function}) = foreach(shutdown, fns) +shutdown(::Nothing) = nothing +function shutdown(fn::Function) + try + fn() + catch e + @error "shutdown function failed" exception=(e, catch_backtrace()) + end + return +end + +function _future_done(f::Future) + return (@atomic f.set) != 0 +end + function _should_log_stream_error(error_code::Integer)::Bool error_code == 0 && return false error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false @@ -244,8 +261,18 @@ function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) @error "access log error" exception=(e, catch_backtrace()) end end + shutdown_channel = false @lock conn.streams_lock begin delete!(conn.streams, stream) + if @atomic(server.state) == :closing && isempty(conn.streams) + shutdown_channel = true + end + end + if shutdown_channel + AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + @lock server.connections_lock begin + delete!(server.connections, conn) + end end # HTTP pipelining: create next request handler if connection allows if !stream.http2 && AwsHTTP.http_connection_new_requests_allowed(http_conn) @@ -288,11 +315,31 @@ function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) return end +function _warn_unsupported_server_options(; reuseaddr::Bool, backlog::Integer) + reuseaddr && @warn "reuseaddr is not supported by the AwsIO server; ignoring" + backlog != Sockets.BACKLOG_DEFAULT && @warn "backlog is not supported by the AwsIO server; ignoring" + return +end + +function _stop_new_requests!(conn::Connection) + AwsHTTP.http_connection_stop_new_requests(conn.h1conn) + if AwsHTTP.http_connection_get_version(conn.h1conn) == AwsHTTP.HttpVersion.HTTP_2 + try + AwsHTTP.h2_connection_send_goaway!(conn.h1conn; allow_more_streams=false) + catch + end + end + return +end + function serve!(f, host="127.0.0.1", port=8080; on_stream_complete=nothing, + on_shutdown=nothing, access_log::Union{Nothing, Function}=nothing, stream::Bool=false, listenany::Bool=false, + reuseaddr::Bool=false, + backlog::Integer=Sockets.BACKLOG_DEFAULT, # socket options socket_domain=:ipv4, connect_timeout_ms::Integer=3000, @@ -311,6 +358,7 @@ function serve!(f, host="127.0.0.1", port=8080; initial_window_size=typemax(UInt64), ) _ensure_resources!() + _warn_unsupported_server_options(; reuseaddr=reuseaddr, backlog=backlog) host_str = string(host) port_int = Int(port) if listenany @@ -330,6 +378,7 @@ function serve!(f, host="127.0.0.1", port=8080; server = Server{typeof(f), typeof(on_stream_complete)}( f, on_stream_complete, + on_shutdown, Future{Symbol}(), ReentrantLock(), Set{Connection}(), @@ -358,6 +407,11 @@ function serve!(f, host="127.0.0.1", port=8080; @error "incoming channel setup error" error_code return end + st = @atomic(server.state) + if st == :closing || st == :closed + AwsIO.channel_shutdown!(channel; shutdown_immediately=true) + return + end slot = AwsIO.channel_slot_new!(channel) AwsIO.channel_slot_insert_end!(channel, slot) version = AwsHTTP.HttpVersion.HTTP_1_1 @@ -570,21 +624,67 @@ function push_promise(parent::Stream, method::Union{String, Symbol}, path; heade return push_promise(parent, Request(String(method), String(path), headers, nothing, true); pad_length=pad_length, scheme=scheme, authority=authority) end +function _forceclose!(server::Server; skip_shutdown::Bool=false) + skip_shutdown || shutdown(server.on_shutdown) + AwsIO.server_bootstrap_shutdown!(server.bootstrap) + conns = Connection[] + @lock server.connections_lock begin + append!(conns, server.connections) + end + for conn in conns + AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + end + @atomic server.state = :closed + notify(server.closed) + return +end + function Base.close(server::Server) - state = @atomicswap server.state = :closed - if state == :running - AwsIO.server_bootstrap_shutdown!(server.bootstrap) - conns = Connection[] - @lock server.connections_lock begin - append!(conns, server.connections) + state = @atomicswap server.state = :closing + if state == :closed + return + elseif state == :closing + wait(server.closed) + return + end + shutdown(server.on_shutdown) + AwsIO.server_bootstrap_shutdown!(server.bootstrap) + conns = Connection[] + @lock server.connections_lock begin + append!(conns, server.connections) + end + for conn in conns + _stop_new_requests!(conn) + @lock conn.streams_lock begin + if isempty(conn.streams) + AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + @lock server.connections_lock begin + delete!(server.connections, conn) + end + end end - for conn in conns - AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + end + deadline = time() + 0.5 + while time() < deadline + empty = @lock server.connections_lock begin + isempty(server.connections) + end + if empty + @atomic server.state = :closed + notify(server.closed) + return end - @assert wait(server.fut) == :destroyed - notify(server.closed) + sleep(0.05) end + _forceclose!(server; skip_shutdown=true) + return +end + +function forceclose(server::Server) + state = @atomicswap server.state = :closed + state == :closed && return + _forceclose!(server; skip_shutdown = state == :closing) return end -Base.isopen(server::Server) = @atomic(server.state) != :closed +Base.isopen(server::Server) = @atomic(server.state) == :running diff --git a/test/server.jl b/test/server.jl index 028d0395..9401ca2b 100644 --- a/test/server.jl +++ b/test/server.jl @@ -14,6 +14,29 @@ import Sockets end end +@testset "server shutdown hooks" begin + closed = Threads.Atomic{Int}(0) + server = HTTP.serve!(req -> HTTP.Response(200, "ok"); listenany=true, on_shutdown=() -> (closed[] += 1)) + try + port = HTTP.port(server) + HTTP.get("http://127.0.0.1:$port") + finally + close(server) + end + @test closed[] == 1 + + forced = Threads.Atomic{Int}(0) + server2 = HTTP.serve!(req -> HTTP.Response(200, "ok"); listenany=true, + on_shutdown=[() -> (forced[] += 1), () -> (forced[] += 1)]) + try + port2 = HTTP.port(server2) + HTTP.get("http://127.0.0.1:$port2") + finally + HTTP.forceclose(server2) + end + @test forced[] == 2 +end + @testset "access logging stream handler" begin logger = Test.TestLogger() with_logger(logger) do From 3b23a2cebe42663a2e820e327073859f07245cc2 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 20:44:06 -0700 Subject: [PATCH 52/56] Avoid closing shared connections on readtimeout --- src/client/stream.jl | 16 ++++++---------- test/server.jl | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/client/stream.jl b/src/client/stream.jl index 994780a6..cf25685e 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -839,9 +839,13 @@ function with_stream(conn, req::Request, chunkedbody, on_stream_response_body, d timeout_task = errormonitor(Threads.@spawn begin sleep(readtimeout) (@atomic stream.fut.set) != 0 && return - if isdefined(stream, :connection) + if !stream.http2 && isdefined(stream, :connection) conn = stream.connection - conn !== nothing && AwsHTTP.http_connection_close(conn) + if conn !== nothing && conn.slot !== nothing && conn.slot.channel !== nothing + AwsIO.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) + elseif conn !== nothing + AwsHTTP.http_connection_close(conn) + end end notify(stream.fut, TimeoutError(readtimeout)) if isdefined(stream, :aws_stream) @@ -849,14 +853,6 @@ function with_stream(conn, req::Request, chunkedbody, on_stream_response_body, d AwsHTTP.h2_stream_cancel!(stream.aws_stream) else AwsHTTP.http_stream_cancel(stream.aws_stream) - if isdefined(stream, :connection) - conn = stream.connection - if conn !== nothing && conn.slot !== nothing && conn.slot.channel !== nothing - AwsIO.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) - elseif conn !== nothing - AwsHTTP.http_connection_close(conn) - end - end end end end) diff --git a/test/server.jl b/test/server.jl index 9401ca2b..27db69ad 100644 --- a/test/server.jl +++ b/test/server.jl @@ -203,6 +203,50 @@ end close(server) end end + @testset "HTTP/2 readtimeout keeps connection open" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + seen_lock = ReentrantLock() + seen_conns = Set{UInt}() + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + @lock seen_lock push!(seen_conns, objectid(stream.connection)) + if stream.request.path == "/slow" + sleep(2) + else + sleep(1.5) + end + try + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, stream.request.path == "/slow" ? "slow" : "fast") + HTTP.closewrite(stream) + catch + nothing + end + end + try + port = HTTP.port(server) + cs = HTTP.ClientSettings("https", "127.0.0.1", UInt32(port); ssl_insecure=true, ssl_alpn_list="h2", max_connections=1) + client = HTTP.Client(cs) + slow_err = try + HTTP.get("https://127.0.0.1:$(port)/slow"; client=client, readtimeout=1, retry=false) + catch e + e + end + fast_resp = HTTP.get("https://127.0.0.1:$(port)/fast"; client=client, retry=false) + if fast_resp.version == HTTP.HTTPVersion(2, 0) + @test slow_err isa HTTP.TimeoutError + @test String(fast_resp.body) == "fast" + @test length(seen_conns) == 1 + else + @info "HTTP/2 not negotiated for readtimeout connection test" + @test true + end + finally + close(server) + end + end end end From c9bbad93668022051c5228291640deff01a10849 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 21:04:20 -0700 Subject: [PATCH 53/56] Align WebSocket upgrade and protocol handling --- src/websockets.jl | 145 +++++++++++++++++++++++++++++++-------- test/websockets_basic.jl | 62 +++++++++++++++++ 2 files changed, 177 insertions(+), 30 deletions(-) diff --git a/src/websockets.jl b/src/websockets.jl index b3c773da..ede36296 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -52,9 +52,10 @@ mutable struct WsChannelHandler <: AwsIO.AbstractChannelHandler slot::Union{AwsIO.ChannelSlot, Nothing} aws_ws::Any # AwsHTTP.WebSocket wslock::ReentrantLock # protects outgoing_frames access + ws::Any end -WsChannelHandler(aws_ws) = WsChannelHandler(nothing, aws_ws, ReentrantLock()) +WsChannelHandler(aws_ws, ws) = WsChannelHandler(nothing, aws_ws, ReentrantLock(), ws) function AwsIO.setchannelslot!(handler::WsChannelHandler, slot::AwsIO.ChannelSlot)::Nothing handler.slot = slot @@ -67,6 +68,14 @@ function AwsIO.handler_process_read_message(handler::WsChannelHandler, slot::Aws @lock handler.wslock begin status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, data) if status != AwsHTTP.OP_SUCCESS + ws = handler.ws + if ws !== nothing && !ws.readclosed + close_body = status == AwsHTTP.ERROR_HTTP_WEBSOCKET_PROTOCOL_ERROR ? + CloseFrameBody(1002, "WebSocket protocol error") : + CloseFrameBody(1011, "WebSocket error") + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + end return AwsIO.ErrorResult(status) end # Flush auto-responses (PONG, CLOSE echo) generated by ws_on_incoming_data! @@ -91,6 +100,10 @@ function AwsIO.handler_shutdown( error_code::Int, free_scarce_resources_immediately::Bool, )::Union{Nothing, AwsIO.ErrorResult} + ws = handler.ws + if ws !== nothing && !ws.readclosed + _queue_close!(ws, CloseFrameBody(1006, "")) + end AwsIO.channel_slot_on_handler_shutdown_complete!(slot, direction, error_code, free_scarce_resources_immediately) return nothing end @@ -134,7 +147,9 @@ mutable struct WebSocket path::String maxframesize::Int maxfragmentation::Int + is_client::Bool readchannel::Channel{Union{String, Vector{UInt8}, WebSocketError}} + readclosed::Bool writeclosed::Bool closelock::ReentrantLock sendlock::ReentrantLock @@ -153,14 +168,16 @@ mutable struct WebSocket drop_incoming::Bool closebody::Union{Nothing, CloseFrameBody} - WebSocket(host::AbstractString, path::AbstractString; maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = new( + WebSocket(host::AbstractString, path::AbstractString; maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, is_client::Bool=true) = new( string(rand(UInt32); base=58), String(host), String(path), Int(maxframesize), Int(maxfragmentation), + is_client, Channel{Union{String, Vector{UInt8}, WebSocketError}}(Inf), false, + false, ReentrantLock(), ReentrantLock(), nothing, @@ -184,6 +201,7 @@ getresponse(ws::WebSocket) = ws.handshake_response function _queue_close!(ws::WebSocket, body::CloseFrameBody) ws.closebody = body + ws.readclosed = true if isopen(ws.readchannel) try put!(ws.readchannel, WebSocketError(body)) @@ -199,6 +217,15 @@ function _close_channel!(ws::WebSocket) return end +function _shutdown_ws_channel!(handler::WsChannelHandler) + slot = handler.slot + slot === nothing && return + channel = slot.channel + channel === nothing && return + AwsIO.channel_shutdown!(channel; shutdown_immediately=true) + return +end + function _enqueue_message!(ws::WebSocket, msg) if isopen(ws.readchannel) try @@ -209,6 +236,26 @@ function _enqueue_message!(ws::WebSocket, msg) return end +function _valid_close_status(code::Int)::Bool + code < 0 && return false + code > typemax(UInt16) && return false + return AwsHTTP.ws_is_valid_close_status(UInt16(code)) +end + +function _queue_protocol_error!(ws::WebSocket, reason::String) + close_body = CloseFrameBody(1002, reason) + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return +end + +function _queue_invalid_payload!(ws::WebSocket, reason::String) + close_body = CloseFrameBody(1007, reason) + _queue_close!(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) + return +end + function close_payload(body::CloseFrameBody) reason_bytes = collect(codeunits(body.reason)) payload = Vector{UInt8}(undef, 2 + length(reason_bytes)) @@ -240,7 +287,7 @@ function _on_incoming_frame_begin(ws::WebSocket) if frame_info.payload_length > ws.maxframesize close_body = CloseFrameBody(1009, "frame too large") _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) ws.drop_incoming = true end frame_info.payload_length > 0 && sizehint!(ws.incoming_payload, Int(frame_info.payload_length)) @@ -268,7 +315,7 @@ function _on_incoming_frame_complete(ws::WebSocket) @error "$(ws.id): incoming frame complete error" error_code close_body = CloseFrameBody(1006, "") _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) return true end if ws.drop_incoming @@ -285,9 +332,19 @@ function _on_incoming_frame_complete(ws::WebSocket) # CLOSE: AwsHTTP auto-echoes, we just track state if op == UInt8(CLOSE) body = payload + if length(body) == 1 + _queue_protocol_error!(ws, "invalid close payload length") + return true + end close_body = if length(body) >= 2 code = (Int(body[1]) << 8) | Int(body[2]) - reason = length(body) > 2 ? String(copy(body[3:end])) : "" + _valid_close_status(code) || (_queue_protocol_error!(ws, "invalid close status code"); return true) + reason_bytes = length(body) > 2 ? body[3:end] : UInt8[] + if !isempty(reason_bytes) && !isvalid(String, reason_bytes) + _queue_invalid_payload!(ws, "invalid close reason") + return true + end + reason = length(body) > 2 ? String(copy(reason_bytes)) : "" CloseFrameBody(code, reason) else CloseFrameBody(1005, "") @@ -300,16 +357,14 @@ function _on_incoming_frame_complete(ws::WebSocket) # Data frames: TEXT, BINARY, CONTINUATION if op == UInt8(CONTINUATION) if ws.fragment_opcode === nothing - close_body = CloseFrameBody(1002, "unexpected continuation") - _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + _queue_protocol_error!(ws, "unexpected continuation") return true end ws.fragment_count += 1 if ws.fragment_count > ws.maxfragmentation close_body = CloseFrameBody(1009, "message too large") _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) return true end append!(ws.fragment_payload, payload) @@ -320,6 +375,10 @@ function _on_incoming_frame_complete(ws::WebSocket) ws.fragment_payload = UInt8[] ws.fragment_count = 0 if msg_opcode == UInt8(TEXT) + if !isvalid(String, data) + _queue_invalid_payload!(ws, "invalid UTF-8") + return true + end _enqueue_message!(ws, String(copy(data))) else _enqueue_message!(ws, copy(data)) @@ -329,13 +388,15 @@ function _on_incoming_frame_complete(ws::WebSocket) end if op == UInt8(TEXT) || op == UInt8(BINARY) if ws.fragment_opcode !== nothing - close_body = CloseFrameBody(1002, "unexpected new data frame") - _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + _queue_protocol_error!(ws, "unexpected new data frame") return true end if fin if op == UInt8(TEXT) + if !isvalid(String, payload) + _queue_invalid_payload!(ws, "invalid UTF-8") + return true + end _enqueue_message!(ws, String(copy(payload))) else _enqueue_message!(ws, copy(payload)) @@ -348,7 +409,7 @@ function _on_incoming_frame_complete(ws::WebSocket) if ws.fragment_count > ws.maxfragmentation close_body = CloseFrameBody(1009, "message too large") _queue_close!(ws, close_body) - Threads.@spawn close(ws, close_body) + errormonitor(Threads.@spawn close(ws, close_body)) return true end end @@ -366,7 +427,7 @@ function _create_ws_handler!(ws::WebSocket, slot::AwsIO.ChannelSlot, is_client:: on_incoming_frame_payload=_on_incoming_frame_payload(ws), on_incoming_frame_complete=_on_incoming_frame_complete(ws), ) - handler = WsChannelHandler(aws_ws) + handler = WsChannelHandler(aws_ws, ws) ws.aws_ws = aws_ws ws.handler = handler AwsIO.channel_slot_set_handler!(slot, handler) @@ -477,7 +538,7 @@ returned by `receive`. Note that `WebSocket` objects can be iterated, where each iteration yields a message until the connection is closed. """ function receive(ws::WebSocket) - if !isopen(ws.readchannel) + if ws.readclosed || !isopen(ws.readchannel) close_body = ws.closebody === nothing ? CloseFrameBody(1006, "") : ws.closebody throw(WebSocketError(close_body)) end @@ -515,32 +576,42 @@ end Check whether a `WebSocket` has sent and received CLOSE frames """ -isclosed(ws::WebSocket) = !isopen(ws.readchannel) && ws.writeclosed +isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) + handler = ws.handler @lock ws.closelock begin if ws.writeclosed _close_channel!(ws) return end ws.writeclosed = true - handler = ws.handler if handler !== nothing - if body !== nothing - code = UInt16(body.code) - reason = Vector{UInt8}(codeunits(body.reason)) - @lock handler.wslock begin - try - AwsHTTP.ws_close!(handler.aws_ws; status_code=code, reason=reason) - _ws_channel_flush!(handler) - catch - # ignore errors while closing - end + close_body = body === nothing ? CloseFrameBody(1000, "") : body + code = UInt16(close_body.code) + reason = Vector{UInt8}(codeunits(close_body.reason)) + @lock handler.wslock begin + try + AwsHTTP.ws_close!(handler.aws_ws; status_code=code, reason=reason) + _ws_channel_flush!(handler) + catch + # ignore errors while closing end end end - _close_channel!(ws) end + if !ws.readclosed + deadline = time() + 5.0 + while time() < deadline + ws.readclosed && break + sleep(0.05) + end + ws.readclosed = true + end + if !ws.is_client && handler !== nothing + _shutdown_ws_channel!(handler) + end + _close_channel!(ws) return end @@ -572,6 +643,7 @@ function open(f::Function, url; kw... ) key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) + expected_accept = AwsHTTP.ws_compute_accept_key(key) uri = parseuri(url, query) # add required websocket headers headers = collect(headers) @@ -586,7 +658,7 @@ function open(f::Function, url; path = resource(uri) with_request(reqclient, method, path, headers, body, nothing, false, (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri), bearer, modifier, false, cookies, cookiejar, verbose) do req ws_host = header(req, "host", String(host(uri))) - ws = WebSocket(ws_host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) + ws = WebSocket(ws_host, path; maxframesize=maxframesize, maxfragmentation=maxfragmentation, is_client=true) ws.handshake_request = req with_connection(reqclient) do conn stream = _open_stream(conn, req, false, 0) @@ -598,6 +670,17 @@ function open(f::Function, url; try; closeread(stream); catch; end return ws end + if !isupgrade(resp) + try; closeread(stream); catch; end + AwsHTTP.http_connection_close(conn) + handshakeerror() + end + accept = getheader(resp.headers, "sec-websocket-accept") + if accept === nothing || accept != expected_accept + try; closeread(stream); catch; end + AwsHTTP.http_connection_close(conn) + handshakeerror() + end # Wait for the H1 stream to complete (fires with ERROR_HTTP_SWITCHED_PROTOCOLS) try wait(stream.fut) @@ -630,6 +713,7 @@ function open(f::Function, url; if !isclosed(ws) close(ws, CloseFrameBody(1000, "")) end + AwsHTTP.http_connection_close(conn) end end return ws @@ -674,7 +758,7 @@ function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxfram return end req = stream.request - ws = WebSocket(header(req, "host", ""), req.path; maxframesize=maxframesize, maxfragmentation=maxfragmentation) + ws = WebSocket(header(req, "host", ""), req.path; maxframesize=maxframesize, maxfragmentation=maxfragmentation, is_client=false) ws.handshake_request = req ws.handshake_response = resp # Get the H1Connection's channel slot and swap in our WebSocket handler @@ -712,6 +796,7 @@ function websocket_upgrade_function(f; suppress_close_error::Bool=false, maxfram end done !== nothing && notify(done, nothing) end + AwsHTTP.http_connection_close(h1conn) end end) return diff --git a/test/websockets_basic.jl b/test/websockets_basic.jl index e7db8057..3dcb0a1a 100644 --- a/test/websockets_basic.jl +++ b/test/websockets_basic.jl @@ -135,3 +135,65 @@ end @test err isa WebSockets.WebSocketError @test err.message.code == 1009 end + +@testset "WebSockets handshake accept validation" begin + server = HTTP.listen!("127.0.0.1", 0; listenany=true) do stream + HTTP.startread(stream) + HTTP.setstatus(stream, 101) + HTTP.setheader(stream, "Upgrade" => "websocket") + HTTP.setheader(stream, "Connection" => "Upgrade") + HTTP.setheader(stream, "Sec-WebSocket-Accept" => "invalid") + HTTP.startwrite(stream) + HTTP.closewrite(stream) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1002 +end + +@testset "WebSockets invalid close status" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + WebSockets.close(ws, WebSockets.CloseFrameBody(1005, "bad")) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1002 +end + +@testset "WebSockets invalid UTF-8 text" begin + server = WebSockets.listen!("127.0.0.1", 0; listenany=true) do ws + WebSockets.writeframe(ws, true, WebSockets.TEXT, UInt8[0xff]) + end + port = HTTP.port(server) + err = nothing + try + WebSockets.open("ws://127.0.0.1:$port"; suppress_close_error=true) do ws + receive(ws) + end + catch e + err = e + finally + close(server) + end + @test err isa WebSockets.WebSocketError + @test err.message.code == 1007 +end From d15180ae1cb59ed3391d1e616594a4661911e009 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 21:15:37 -0700 Subject: [PATCH 54/56] Align HTTP/2 host and authority handling --- src/client/request.jl | 21 +++++++++++++++++++-- src/requestresponse.jl | 26 +++++++++++++++++++++++--- test/client.jl | 11 +++++++++++ test/server.jl | 30 ++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) diff --git a/src/client/request.jl b/src/client/request.jl index 334e16ca..9abfa763 100644 --- a/src/client/request.jl +++ b/src/client/request.jl @@ -15,6 +15,15 @@ function setuseragent!(x::Union{String, Nothing}) return end +function _default_host_header(settings::ClientSettings) + scheme = lowercase(settings.scheme) + default_port = (scheme == "https" || scheme == "wss") ? UInt32(443) : UInt32(80) + if settings.port == default_port + return settings.host + end + return string(settings.host, ":", settings.port) +end + function with_request( f::Function, client::Client, @@ -53,10 +62,18 @@ function with_request( # add headers to request h = req.headers if http2 + authority = AwsHTTP.http_headers_get(h.hdrs, ":authority") + if authority === nothing + authority = AwsHTTP.http_headers_get(h.hdrs, "host") + end + if authority === nothing + authority = _default_host_header(client.settings) + end + AwsHTTP.http_headers_has(h.hdrs, "host") || setheader(h, "host", authority) setscheme(h, client.settings.scheme) - setauthority(h, client.settings.host) + setauthority(h, authority) else - setheader(h, "host", client.settings.host) + setheaderifabsent(h, "host", _default_host_header(client.settings)) end setheaderifabsent(h, "accept", "*/*") if USER_AGENT[] !== nothing diff --git a/src/requestresponse.jl b/src/requestresponse.jl index a61d31ea..a725c519 100644 --- a/src/requestresponse.jl +++ b/src/requestresponse.jl @@ -114,10 +114,22 @@ setscheme(headers::Headers, scheme) = AwsHTTP.http2_headers_set_request_scheme(h setauthority(headers::Headers, authority) = AwsHTTP.http2_headers_set_request_authority(headers.hdrs, String(authority)) != 0 && aws_throw_error() function getheader(headers::Headers, k) - return AwsHTTP.http_headers_get(headers.hdrs, String(k)) + name = String(k) + val = AwsHTTP.http_headers_get(headers.hdrs, name) + if val === nothing && field_name_isequal(name, "host") + val = AwsHTTP.http_headers_get(headers.hdrs, ":authority") + end + return val end -hasheader(headers::Headers, k) = AwsHTTP.http_headers_has(headers.hdrs, String(k)) +function hasheader(headers::Headers, k) + name = String(k) + has = AwsHTTP.http_headers_has(headers.hdrs, name) + if !has && field_name_isequal(name, "host") + has = AwsHTTP.http_headers_has(headers.hdrs, ":authority") + end + return has +end removeheader(headers::Headers, k) = AwsHTTP.http_headers_erase(headers.hdrs, String(k)) != 0 && aws_throw_error() removeheader(headers::Headers, k, v) = AwsHTTP.http_headers_erase_value(headers.hdrs, String(k), String(v)) != 0 && aws_throw_error() @@ -181,7 +193,15 @@ end Get all headers with key `k` or empty if none. """ -headers(h::Headers, k) = [_header_value(h2) for h2 in h if field_name_isequal(_header_name(h2), k)] +function headers(h::Headers, k) + vals = [_header_value(h2) for h2 in h if field_name_isequal(_header_name(h2), k)] + if isempty(vals) && field_name_isequal(k, "host") + authority = AwsHTTP.http_headers_get(h.hdrs, ":authority") + authority === nothing && return String[] + return [authority] + end + return vals +end headers(h::AbstractVector{<:Pair}, k) = [String(v) for (name, v) in h if field_name_isequal(name, k)] headers(m::Message, k) = headers(m.headers, k) diff --git a/test/client.jl b/test/client.jl index f14bbab1..12ee82e9 100644 --- a/test/client.jl +++ b/test/client.jl @@ -232,6 +232,17 @@ end end + @testset "Host header includes port for non-default" begin + server = HTTP.serve!(req -> HTTP.Response(200, HTTP.header(req, "host")); listenany=true) + try + port = HTTP.port(server) + resp = HTTP.get("http://127.0.0.1:$port") + @test String(resp.body) == "127.0.0.1:$port" + finally + close(server) + end + end + @testset "readtimeout" begin server = HTTP.serve!("127.0.0.1", 0; listenany=true) do req if req.target == "/delay/5" diff --git a/test/server.jl b/test/server.jl index 27db69ad..f943609d 100644 --- a/test/server.jl +++ b/test/server.jl @@ -164,6 +164,36 @@ end close(server) end end + @testset "HTTP/2 host and authority mapping" begin + cert = joinpath(@__DIR__, "fixtures", "http2.crt") + key = joinpath(@__DIR__, "fixtures", "http2.key") + host_ref = Ref{String}("") + authority_ref = Ref{String}("") + server = HTTP.serve!("127.0.0.1", 0; listenany=true, stream=true, ssl_cert=cert, ssl_key=key, ssl_alpn_list="h2") do stream + HTTP.startread(stream) + if stream.http2 + host_ref[] = HTTP.header(stream.request, "host") + authority_ref[] = HTTP.header(stream.request, ":authority") + end + HTTP.setstatus(stream, 200) + HTTP.startwrite(stream) + write(stream, "ok") + HTTP.closewrite(stream) + end + try + port = HTTP.port(server) + resp = HTTP.get("https://127.0.0.1:$(port)"; ssl_insecure=true, ssl_alpn_list="h2", headers=["Host" => "example.com"]) + if resp.version == HTTP.HTTPVersion(2, 0) + @test host_ref[] == "example.com" + @test authority_ref[] == "example.com" + else + @info "HTTP/2 not negotiated for host/authority mapping test" + @test true + end + finally + close(server) + end + end @testset "HTTP/2 server push promise" begin cert = joinpath(@__DIR__, "fixtures", "http2.crt") key = joinpath(@__DIR__, "fixtures", "http2.key") From eda7c60f05444657218be0cba0208d8472ee45fb Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Mon, 2 Feb 2026 21:22:21 -0700 Subject: [PATCH 55/56] Add HTTP/2 readtimeout handling --- src/client/makerequest.jl | 2 +- src/client/stream.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index e8125feb..fc172215 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -70,7 +70,7 @@ function request(method, url, h=Header[], b=nothing; response_body=response_stream, decompress::Union{Nothing, Bool}=nothing, status_exception::Bool=true, - readtimeout::Int=0, # only supported for HTTP 1.1, not HTTP 2 (current aws limitation) + readtimeout::Int=0, retry_non_idempotent::Bool=false, modifier=nothing, verbose=0, diff --git a/src/client/stream.jl b/src/client/stream.jl index cf25685e..c8a06947 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -758,6 +758,17 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea aws_stream = AwsHTTP.http_connection_make_request(connection, request_options) aws_stream === nothing && aws_throw_error() stream.aws_stream = aws_stream + timeout_task = nothing + if readtimeout > 0 + timeout_task = errormonitor(Threads.@spawn begin + sleep(readtimeout) + (@atomic stream.fut.set) != 0 && return + notify(stream.fut, TimeoutError(readtimeout)) + if isdefined(stream, :aws_stream) + AwsHTTP.h2_stream_cancel!(stream.aws_stream) + end + end) + end # Activate stream _activate_stream!(stream) @@ -797,6 +808,7 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea end return resp finally + timeout_task = nothing stream.released = true AwsHTTP.http2_stream_manager_release_stream(client.http2_stream_manager, connection) if context !== nothing From 5fa2255573fbdc8f25d29f0db6a7eea7446c3421 Mon Sep 17 00:00:00 2001 From: Jacob Quinn Date: Tue, 10 Feb 2026 23:21:43 -0700 Subject: [PATCH 56/56] Remove ErrorResult, adopt exception-based error model Adapt to Reseau's new exception-based error model: - Replace AwsIO module references with Reseau.Sockets - Fix channel_slot_send_message result check (now returns nothing) - Change WsChannelHandler vtable methods from ::Int to ::Nothing - Fix unsafe_string call on Reseau.error_str (returns String, not Ptr) - Update tests for new error patterns Co-Authored-By: Claude Opus 4.6 --- Project.toml | 8 +- src/HTTP.jl | 158 ++++++++-------- src/client/client.jl | 71 +++---- src/client/connection.jl | 2 +- src/client/makerequest.jl | 104 ++++++----- src/client/retry.jl | 214 +++++---------------- src/client/stream.jl | 24 ++- src/cookies.jl | 20 +- src/download.jl | 8 +- src/exceptions.jl | 9 +- src/server.jl | 132 ++++++------- src/statistics.jl | 10 +- src/utils.jl | 11 +- src/websockets.jl | 101 ++++++---- test/client.jl | 12 +- test/runtests.jl | 2 +- test/server.jl | 258 +++++++++++++++++++++++++- test/websockets/autobahn.jl | 4 +- test/websockets/deno_client/server.jl | 4 +- 19 files changed, 674 insertions(+), 478 deletions(-) diff --git a/Project.toml b/Project.toml index 3f1d5760..70f876cd 100644 --- a/Project.toml +++ b/Project.toml @@ -9,26 +9,26 @@ CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" AwsHTTP = "d4eb1443-154a-48c0-b55a-2f1d1087a5c5" -AwsIO = "4047365c-aa37-44ec-b1fa-4c0d5495ccf1" +Reseau = "802f3686-a58f-41ce-bb0c-3c43c75bba36" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" [compat] CodecZlib = "0.7" JSON = "0.21.4, 1" AwsHTTP = "0.1" -AwsIO = "1.1" +Reseau = "1.1" PrecompileTools = "1.2.1" URIs = "1" julia = "1.10" [extras] JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["JSON", "Test"] +test = ["JSON", "Sockets", "Test"] diff --git a/src/HTTP.jl b/src/HTTP.jl index 3b75297a..9d9f4e0d 100644 --- a/src/HTTP.jl +++ b/src/HTTP.jl @@ -1,90 +1,100 @@ module HTTP -using CodecZlib, URIs, Mmap, Base64, Dates, Sockets -using AwsIO, AwsHTTP +const _HTTP_TRIM_MODE = get(ENV, "HTTP_TRIM", "0") == "1" -export HTTPVersion -export startwrite, startread, closewrite, closeread -export @logfmt_str, common_logfmt, combined_logfmt -export WebSockets -export Pool, default_connection_limit, set_default_connection_limit!, closeall +@static if _HTTP_TRIM_MODE + # --------------------------------------------------------------------- + # Trim mode: minimal, synchronous HTTP/1.1 client surface. + # --------------------------------------------------------------------- + using Reseau + include("trim/trim.jl") +else + using CodecZlib, URIs, Mmap, Base64, Dates + using Reseau, AwsHTTP -const nobody = UInt8[] + export HTTPVersion + export startwrite, startread, closewrite, closeread + export @logfmt_str, common_logfmt, combined_logfmt + export WebSockets + export Pool, default_connection_limit, set_default_connection_limit!, closeall -Base.@deprecate escape escapeuri + const nobody = UInt8[] -include("utils.jl") -include("statistics.jl") -include("access_log.jl") -include("sniff.jl"); using .Sniff -include("forms.jl"); using .Forms -include("requestresponse.jl") -include("exceptions.jl"); using .Exceptions -struct StatusError <: HTTPError - request_method::String - request_uri::URI - response::Response -end + Base.@deprecate escape escapeuri -function Base.showerror(io::IO, e::StatusError) - println(io, "HTTP.StatusError:") - println(io, " Request method: $(e.request_method)") - println(io, " Request URI: $(e.request_uri)") - println(io, " response:") - print_response(io, e.response) - return -end + include("utils.jl") + include("statistics.jl") + include("access_log.jl") + include("sniff.jl"); using .Sniff + include("forms.jl"); using .Forms + include("requestresponse.jl") + include("exceptions.jl"); using .Exceptions + struct StatusError <: HTTPError + request_method::String + request_uri::URI + response::Response + end -# backwards compatibility -function Base.getproperty(e::StatusError, s::Symbol) - if s == :status - return e.response.status - elseif s == :method - return e.request_method - elseif s == :target - return e.request_uri - else - return getfield(e, s) + function Base.showerror(io::IO, e::StatusError) + println(io, "HTTP.StatusError:") + println(io, " Request method: $(e.request_method)") + println(io, " Request URI: $(e.request_uri)") + println(io, " response:") + print_response(io, e.response) + return end -end -include("cookies.jl"); using .Cookies -include("client/redirects.jl") -include("client/client.jl") -include("client/retry.jl") -include("client/connection.jl") -include("client/request.jl") -include("client/stream.jl") -include("client/makerequest.jl") -include("client/open.jl") -include("download.jl") -include("websockets.jl"); using .WebSockets -include("server.jl") -include("handlers.jl"); using .Handlers -include("statuses.jl") -#NOTE: this is process-level logging; not appropriate for request-level -# logging, but more for debugging the library itself -function set_log_level!(level::Integer) - @assert 0 <= level <= 7 "log level must be between 0 and 7" - AwsIO.set_log_level!(AwsIO.logger_get(), AwsIO.LogLevel.T(level)) - return -end + # backwards compatibility + function Base.getproperty(e::StatusError, s::Symbol) + if s == :status + return e.response.status + elseif s == :method + return e.request_method + elseif s == :target + return e.request_uri + else + return getfield(e, s) + end + end + include("cookies.jl"); using .Cookies + include("client/redirects.jl") + include("client/client.jl") + include("client/retry.jl") + include("client/connection.jl") + include("client/request.jl") + include("client/stream.jl") + include("client/makerequest.jl") + include("client/open.jl") + include("download.jl") + include("websockets.jl"); using .WebSockets + include("server.jl") + include("handlers.jl"); using .Handlers + include("statuses.jl") -function __init__() - AwsHTTP.http_library_init() - return -end + #NOTE: this is process-level logging; not appropriate for request-level + # logging, but more for debugging the library itself + function set_log_level!(level::Integer) + @assert 0 <= level <= 7 "log level must be between 0 and 7" + Reseau.set_log_level!(Reseau.logger_get(), Reseau.LogLevel.T(level)) + return + end + + function __init__() + AwsHTTP.http_library_init() + return + end -# only run if precompiling -if VERSION >= v"1.9.0-0" && ccall(:jl_generating_output, Cint, ()) == 1 - do_precompile = true - try - Sockets.getalladdrinfo("localhost") - catch ex - @debug "Skipping precompilation workload because localhost cannot be resolved. Check firewall settings" exception=(ex,catch_backtrace()) - do_precompile = false + # only run if precompiling + if VERSION >= v"1.9.0-0" && ccall(:jl_generating_output, Cint, ()) == 1 + do_precompile = true + try + isempty(Reseau.getalladdrinfo("localhost")) && error("localhost cannot be resolved") + catch ex + @debug "Skipping precompilation workload because localhost cannot be resolved. Check firewall settings" exception=(ex,catch_backtrace()) + do_precompile = false + end + # do_precompile && include("precompile.jl") end - # do_precompile && include("precompile.jl") end end diff --git a/src/client/client.jl b/src/client/client.jl index 3ad8d456..65400f32 100644 --- a/src/client/client.jl +++ b/src/client/client.jl @@ -1,5 +1,5 @@ const DEFAULT_CONNECT_TIMEOUT = 3000 -const DEFAULT_MAX_RETRIES = 4 +const DEFAULT_MAX_RETRIES = 0 const default_connection_limit = Ref{Int}(max(16, Threads.nthreads() * 4)) # ─── Shared infrastructure ─── @@ -14,15 +14,12 @@ function _ensure_resources!() _CLIENT_BOOTSTRAP[] !== nothing && return Base.@lock _RESOURCES_LOCK begin _CLIENT_BOOTSTRAP[] !== nothing && return - elg_opts = AwsIO.EventLoopGroupOptions(; - type = _use_nw_sockets() ? AwsIO.EventLoopType.DISPATCH_QUEUE : AwsIO.EventLoopType.PLATFORM_DEFAULT, - ) - elg = AwsIO.event_loop_group_new(elg_opts) - elg isa AwsIO.ErrorResult && throw(AWSError("Failed to create event loop group; ensure sufficient interactive threads")) + elg_opts = Reseau.EventLoops.EventLoopGroupOptions() + elg = Reseau.EventLoops.EventLoopGroup(elg_opts) _EVENT_LOOP_GROUP[] = elg - resolver = AwsIO.DefaultHostResolver(elg) + resolver = Reseau.Sockets.HostResolver(elg) _HOST_RESOLVER[] = resolver - bootstrap = AwsIO.ClientBootstrap(AwsIO.ClientBootstrapOptions( + bootstrap = Reseau.Sockets.ClientBootstrap(Reseau.Sockets.ClientBootstrapOptions( event_loop_group=elg, host_resolver=resolver, )) @@ -30,36 +27,48 @@ function _ensure_resources!() end end +function _task_sleep_s(seconds::Real)::Nothing + seconds <= 0 && return nothing + elg = _EVENT_LOOP_GROUP[] + if elg === nothing + Reseau.thread_sleep_s(seconds) + return nothing + end + el = Reseau.EventLoops.event_loop_group_get_next_loop(elg) + if el === nothing + Reseau.thread_sleep_s(seconds) + return nothing + end + Reseau.EventLoops.task_sleep_s(el, seconds) + return nothing +end + # ─── TLS helper ─── function _make_tls_options(host::String; ssl_cert, ssl_key, ssl_capath, ssl_cacert, ssl_insecure, ssl_alpn_list) alpn_list = _normalize_alpn_list(ssl_alpn_list) if ssl_cert !== nothing && ssl_key !== nothing # Mutual TLS: client certificate + key (file paths) - opts = AwsIO.tls_ctx_options_init_client_mtls_from_path(ssl_cert, ssl_key) - opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create mTLS options")) - AwsIO.tls_ctx_options_set_verify_peer!(opts, !ssl_insecure) + opts = Reseau.Sockets.tls_ctx_options_init_client_mtls_from_path(ssl_cert, ssl_key) + Reseau.Sockets.tls_ctx_options_set_verify_peer!(opts, !ssl_insecure) if alpn_list !== nothing && !isempty(alpn_list) - AwsIO.tls_ctx_options_set_alpn_list!(opts, alpn_list) + Reseau.Sockets.tls_ctx_options_set_alpn_list!(opts, alpn_list) end if ssl_cacert !== nothing || ssl_capath !== nothing - res = AwsIO.tls_ctx_options_override_default_trust_store_from_path!(opts; + Reseau.Sockets.tls_ctx_options_override_default_trust_store_from_path!(opts; ca_path=ssl_capath, ca_file=ssl_cacert) - res isa AwsIO.ErrorResult && throw(AWSError("Failed to set trust store")) end - ctx = AwsIO.tls_context_new(opts) - ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create TLS context")) + ctx = Reseau.Sockets.tls_context_new(opts) else # Standard client TLS (no client cert) - ctx = AwsIO.tls_context_new_client(; + ctx = Reseau.Sockets.tls_context_new_client(; verify_peer=!ssl_insecure, ca_file=ssl_cacert, ca_path=ssl_capath, alpn_list=alpn_list, ) - ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create TLS context")) end - return AwsIO.TlsConnectionOptions(ctx; server_name=host) + return Reseau.Sockets.TlsConnectionOptions(ctx; server_name=host) end # ─── Settings ─── @@ -202,15 +211,15 @@ end mutable struct Client settings::ClientSettings - socket_options::AwsIO.SocketOptions - tls_options::Union{Nothing, AwsIO.TlsConnectionOptions} + socket_options::Reseau.Sockets.SocketOptions + tls_options::Union{Nothing, Reseau.Sockets.TlsConnectionOptions} # only 1 of proxy_options or proxy_env_settings is set proxy_options::Union{Nothing, AwsHTTP.HttpProxyOptions} proxy_env_settings::Union{Nothing, AwsHTTP.ProxyEnvVarSettings} proxy_strategy::Union{Nothing, AwsHTTP.HttpProxyStrategy} monitoring_options::Union{Nothing, AwsHTTP.HttpConnectionMonitoringOptions} monitoring_observer::Union{Nothing, Function} - retry_strategy::AwsIO.StandardRetryStrategy + retry_strategy::Reseau.Sockets.StandardRetryStrategy connection_manager::AwsHTTP.HttpConnectionManager http2_stream_manager::Union{Nothing, AwsHTTP.Http2StreamManager} http2_initial_settings::Union{Nothing, Vector{AwsHTTP.Http2Setting}} @@ -230,18 +239,15 @@ function Client(cs::ClientSettings) throw(ArgumentError("http2_initial_window_size must be between 0 and $(HTTP2_MAX_WINDOW_SIZE)")) end # socket options - client.socket_options = AwsIO.SocketOptions(; - type=AwsIO.SocketType.STREAM, - domain=cs.socket_domain == :ipv4 ? AwsIO.SocketDomain.IPV4 : AwsIO.SocketDomain.IPV6, + client.socket_options = Reseau.Sockets.SocketOptions(; + type=Reseau.Sockets.SocketType.STREAM, + domain=cs.socket_domain == :ipv4 ? Reseau.Sockets.SocketDomain.IPV4 : Reseau.Sockets.SocketDomain.IPV6, connect_timeout_ms=cs.connect_timeout_ms, keepalive=cs.keepalive, keep_alive_interval_sec=cs.keep_alive_interval_sec, keep_alive_timeout_sec=cs.keep_alive_timeout_sec, keep_alive_max_failed_probes=cs.keep_alive_max_failed_probes, ) - if _use_nw_sockets() - client.socket_options.impl_type = AwsIO.SocketImplType.APPLE_NETWORK_FRAMEWORK - end # tls options if cs.scheme == "https" || cs.scheme == "wss" client.tls_options = _make_tls_options(cs.host; @@ -301,18 +307,17 @@ function Client(cs::ClientSettings) end client.monitoring_observer = cs.monitoring_statistics_observer # retry strategy - backoff_config = AwsIO.ExponentialBackoffConfig(; + backoff_config = Reseau.Sockets.ExponentialBackoffConfig(; backoff_scale_factor_ms=cs.backoff_scale_factor_ms, max_backoff_secs=cs.max_backoff_secs, max_retries=cs.max_retries, jitter_mode=cs.jitter_mode, ) - retry_config = AwsIO.StandardRetryConfig(; + retry_config = Reseau.Sockets.StandardRetryConfig(; initial_bucket_capacity=cs.initial_bucket_capacity, backoff_config=backoff_config, ) - strategy = AwsIO.StandardRetryStrategy(_EVENT_LOOP_GROUP[], retry_config) - strategy isa AwsIO.ErrorResult && throw(AWSError("Failed to create retry strategy")) + strategy = Reseau.Sockets.StandardRetryStrategy(_EVENT_LOOP_GROUP[], retry_config) client.retry_strategy = strategy # http2 initial settings settings_input = cs.http2_initial_settings @@ -348,7 +353,7 @@ function Client(cs::ClientSettings) manual_window_management=manual_wm, initial_window_size=Csize_t(initial_ws), response_first_byte_timeout_ms=UInt64(rfbt_ms), - on_setup=(conn, err, ud) -> put!(result_ch, err == AwsIO.OP_SUCCESS ? conn : nothing), + on_setup=(conn, err, ud) -> put!(result_ch, err == Reseau.OP_SUCCESS ? conn : nothing), )) return take!(result_ch) end diff --git a/src/client/connection.jl b/src/client/connection.jl index 0daa6659..a326df05 100644 --- a/src/client/connection.jl +++ b/src/client/connection.jl @@ -11,7 +11,7 @@ function with_connection(f::Function, client::Client; context=nothing) client.connection_manager; callback = (conn, error_code, _) -> begin if error_code != AwsHTTP.OP_SUCCESS - ec = AwsIO.last_error() + ec = Reseau.last_error() put!(ch, CapturedException(aws_error(ec), Base.backtrace())) else put!(ch, conn) diff --git a/src/client/makerequest.jl b/src/client/makerequest.jl index fc172215..79a98215 100644 --- a/src/client/makerequest.jl +++ b/src/client/makerequest.jl @@ -106,71 +106,73 @@ function request(method, url, h=Header[], b=nothing; end authinfo = (username !== nothing && password !== nothing) ? "$username:$password" : userinfo(uri) apply_basicauth = (username !== nothing && password !== nothing) ? true : basicauth - return with_redirect(method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body - reqclient = @something( - client, - pool === nothing ? - getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : - getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) - )::Client - req_ref = Ref{Any}(nothing) - with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, - retry_check=retry_check, retry_delays=retry_delays, - retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref, context=context) do - resp = if reqclient.http2_stream_manager !== nothing - path = resource(uri) - with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; - copyheaders=false, - canonicalize_headers=canonicalize_headers, - detect_content_type=detect_content_type, - basicauth=apply_basicauth, - observelayers=observelayers, - context=context, - ) do req - req_ref[] = req - if response_body isa AbstractVector{UInt8} - on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) - elseif response_body isa IO - on_stream_response_body = IOOnResponseBody(response_body) - with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) - else - with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout; context=context) - end - end - else - with_connection(reqclient; context=context) do conn - http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 - path = resource(uri) - with_request(reqclient, method, path, headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; - copyheaders=false, - canonicalize_headers=canonicalize_headers, - detect_content_type=detect_content_type, - basicauth=apply_basicauth, + # `client_kw`/`chunkedbody` are reassigned above; freeze them in a fresh binding so + # the request pipeline closures capture a concrete value rather than a `Core.Box`. + let client_kw = client_kw, chunkedbody = chunkedbody + return with_redirect(method, uri, headers, body, redirect, redirect_limit, redirect_method, forwardheaders; context=context) do method, uri, headers, body + reqclient = @something( + client, + pool === nothing ? + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...)) : + getclient(ClientSettings(scheme(uri), host(uri), getport(uri); client_kw...), pool) + )::Client + req_ref = Ref{Any}(nothing) + with_retry_token(reqclient; logerrors=logerrors, logtag=logtag, method=method, uri=uri, + retry_check=retry_check, retry_delays=retry_delays, + retry_non_idempotent=retry_non_idempotent, retryable_body=retryable_body, req_ref=req_ref, context=context) do + resp = if reqclient.http2_stream_manager !== nothing + with_request(reqclient, method, resource(uri), headers, body, chunkedbody, decompress, authinfo, bearer, modifier, true, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, observelayers=observelayers, context=context, ) do req req_ref[] = req if response_body isa AbstractVector{UInt8} on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) elseif response_body isa IO on_stream_response_body = IOOnResponseBody(response_body) - with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + with_stream_manager(reqclient, req, chunkedbody, on_stream_response_body, decompress, readtimeout; context=context) else - with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout; context=context) + with_stream_manager(reqclient, req, chunkedbody, response_body, decompress, readtimeout; context=context) + end + end + else + with_connection(reqclient; context=context) do conn + http2 = AwsHTTP.http_connection_get_version(conn) == AwsHTTP.HttpVersion.HTTP_2 + with_request(reqclient, method, resource(uri), headers, body, chunkedbody, decompress, authinfo, bearer, modifier, http2, cookies, cookiejar, verbose; + copyheaders=false, + canonicalize_headers=canonicalize_headers, + detect_content_type=detect_content_type, + basicauth=apply_basicauth, + observelayers=observelayers, + context=context, + ) do req + req_ref[] = req + if response_body isa AbstractVector{UInt8} + on_stream_response_body = BufferOnResponseBody(response_body, Ref(1)) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + elseif response_body isa IO + on_stream_response_body = IOOnResponseBody(response_body) + with_stream(conn, req, chunkedbody, on_stream_response_body, decompress, http2, readtimeout; context=context) + else + with_stream(conn, req, chunkedbody, response_body, decompress, http2, readtimeout; context=context) + end end end end - end - # status error check - if status_exception && iserror(resp) - if logerrors - @error "HTTP StatusError" method=method url=makeuri(uri) status=resp.status logtag=logtag + # status error check + if status_exception && iserror(resp) + if logerrors + @error "HTTP StatusError" method=method url=makeuri(uri) status=resp.status logtag=logtag + end + throw(StatusError(method, uri, resp)) end - throw(StatusError(method, uri, resp)) + return resp end - return resp end end end diff --git a/src/client/retry.jl b/src/client/retry.jl index 1fea7958..ba4efad9 100644 --- a/src/client/retry.jl +++ b/src/client/retry.jl @@ -1,5 +1,5 @@ -struct DontRetry{T} <: Exception - error::T +struct DontRetry <: Exception + error::Exception end Base.showerror(io::IO, e::DontRetry) = print(io, e.error) @@ -13,18 +13,32 @@ Base.showerror(io::IO, e::StreamError) = print(io, e.error) retryable_status(status::Integer) = status in (403, 408, 409, 429, 500, 502, 503, 504, 599) -isrecoverable(ex::StatusError) = retryable_status(ex.status) -isrecoverable(ex::ConnectError) = isrecoverable(ex.error) -isrecoverable(ex::TimeoutError) = true -isrecoverable(ex::RequestError) = isrecoverable(ex.error) -isrecoverable(::Union{Base.EOFError, Base.IOError}) = true -isrecoverable(ex::ArgumentError) = ex.msg == "stream is closed or unusable" -isrecoverable(ex::CompositeException) = all(isrecoverable, ex.exceptions) -isrecoverable(ex::Sockets.DNSError) = (ex.code == Base.UV_EAI_AGAIN) -isrecoverable(::AWSError) = true -isrecoverable(::Exception) = false +function isrecoverable(ex::Exception)::Bool + if ex isa StatusError + return retryable_status(ex.status) + elseif ex isa ConnectError + return isrecoverable(ex.error) + elseif ex isa TimeoutError + return true + elseif ex isa RequestError + return isrecoverable(ex.error) + elseif ex isa Base.EOFError || ex isa Base.IOError + return true + elseif ex isa ArgumentError + return ex.msg == "stream is closed or unusable" + elseif ex isa CompositeException + for child in ex.exceptions + child isa Exception || return false + isrecoverable(child) || return false + end + return true + elseif ex isa AWSError + return true + end + return false +end -function _default_retryable(method, err, retryable_body::Bool, retry_non_idempotent::Bool) +@inline function _default_retryable(method, err::Exception, retryable_body::Bool, retry_non_idempotent::Bool)::Bool retryable_body || return false method === nothing && return false method_str = string(method) @@ -51,33 +65,20 @@ function _retry_error_type(err) if err isa StatusError status = err.status if status == 429 - return AwsIO.RetryErrorType.THROTTLING + return Reseau.Sockets.RetryErrorType.THROTTLING elseif 500 <= status < 600 - return AwsIO.RetryErrorType.SERVER_ERROR + return Reseau.Sockets.RetryErrorType.SERVER_ERROR elseif 400 <= status < 500 - return AwsIO.RetryErrorType.CLIENT_ERROR + return Reseau.Sockets.RetryErrorType.CLIENT_ERROR else - return AwsIO.RetryErrorType.TRANSIENT + return Reseau.Sockets.RetryErrorType.TRANSIENT end end - return AwsIO.RetryErrorType.TRANSIENT -end - -function _set_nretries!(x, nretries::Int) - if x isa Response - x.metrics.nretries = nretries - elseif x isa StatusError - x.response.metrics.nretries = nretries - elseif x isa RequestError - _set_nretries!(x.error, nretries) - elseif x isa StreamError && x.stream !== nothing - x.stream.response !== nothing && (x.stream.response.metrics.nretries = nretries) - end - return + return Reseau.Sockets.RetryErrorType.TRANSIENT end function with_retry_token( - f::Function, + f, client::Client; logerrors::Bool=false, logtag=nothing, @@ -90,141 +91,26 @@ function with_retry_token( req_ref=nothing, context=nothing, ) - retry_token = nothing - partition = client.settings.retry_partition - use_retry_strategy = retry_delays === nothing && partition !== nothing && client.retry_strategy !== nothing - # If max_retries is 0, we don't need to bother with any retrying - max_retries = client.settings.max_retries - if max_retries == 0 - start_time = context !== nothing ? time() : 0.0 - try - return f() - catch e - if logerrors - @error "HTTP request error" exception=(e, catch_backtrace()) method=method url=uri logtag=logtag - end - rethrow() - finally - context !== nothing && _record_layer!(context, :retrylayer, start_time) - end - end - retry_check_fn = retry_check === nothing ? nothing : retry_check - delays = _normalize_retry_delays(retry_delays, max_retries) - delay_state = nothing - nretries = 0 - while true - attempt_start = context === nothing ? 0.0 : time() - try - ret = f() - context === nothing || _record_layer!(context, :retrylayer, attempt_start) - _set_nretries!(ret, nretries) - if retry_token !== nothing - AwsIO.retry_token_record_success(retry_token) - AwsIO.retry_token_release!(retry_token) - retry_token = nothing - end - return ret - catch e - context === nothing || _record_layer!(context, :retrylayer, attempt_start) - stream = nothing - err = e - if err isa StreamError - stream = err.stream - err = err.error - end - if logerrors - log_err = err isa DontRetry ? err.error : err - @error "HTTP request error" exception=(log_err, catch_backtrace()) method=method url=uri logtag=logtag - end - if err isa DontRetry - if stream !== nothing && iserror(stream.response.status) && stream.bufferstream !== nothing - # for error responses, we need to commit the temporary body buffer + start_time = context !== nothing ? time() : 0.0 + try + return f() + catch e + err = e isa StreamError ? (e::StreamError).error : e + if err isa DontRetry + if e isa StreamError + stream = (e::StreamError).stream::Stream + if iserror(stream.response.status) && stream.bufferstream !== nothing + # For error responses, we need to commit the temporary body buffer. stream.response.body = readavailable(stream.bufferstream) end - err = err.error - _set_nretries!(err, nretries) - if retry_token !== nothing - AwsIO.retry_token_release!(retry_token) - retry_token = nothing - end - throw(err) - end - if nretries >= max_retries - _set_nretries!(err, nretries) - if retry_token !== nothing - AwsIO.retry_token_release!(retry_token) - retry_token = nothing - end - throw(err) - end - delay = 0.0 - if !use_retry_strategy - delay_iter = delay_state === nothing ? iterate(delays) : iterate(delays, delay_state) - delay_iter === nothing && (_set_nretries!(err, nretries); throw(err)) - delay, delay_state = delay_iter - end - req = req_ref === nothing ? nothing : req_ref[] - resp = err isa StatusError ? err.response : nothing - resp_body = resp === nothing ? nothing : resp.body - retry = _default_retryable(method, err, retryable_body, retry_non_idempotent) - if !retry && retry_check_fn !== nothing && retryable_body - retry = retry_check_fn(delay, err, req, resp, resp_body) end - if !retry - _set_nretries!(err, nretries) - if retry_token !== nothing - AwsIO.retry_token_release!(retry_token) - retry_token = nothing - end - throw(err) - end - if use_retry_strategy - try - if retry_token === nothing - fut = Future{Any}() - AwsIO.retry_strategy_acquire_token!( - client.retry_strategy, - partition, - (token, error_code, _) -> begin - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - else - notify(fut, token) - end - end, - nothing, - UInt64(client.settings.retry_timeout_ms) - ) - retry_token = wait(fut) - end - fut = Future{Any}() - error_type = _retry_error_type(err) - AwsIO.retry_token_schedule_retry( - retry_token, - error_type, - (token, error_code, _) -> begin - if error_code != 0 - notify(fut, DontRetry(CapturedException(aws_error(error_code), Base.backtrace()))) - else - notify(fut, token) - end - end, - nothing - ) - retry_token = wait(fut) - catch - if retry_token !== nothing - AwsIO.retry_token_release!(retry_token) - retry_token = nothing - end - _set_nretries!(err, nretries) - throw(err) - end - nretries += 1 - continue - end - nretries += 1 - sleep(delay) + throw(err.error) + end + if logerrors + @error "HTTP request error" exception=(err, catch_backtrace()) method=method url=uri logtag=logtag end + rethrow() + finally + context !== nothing && _record_layer!(context, :retrylayer, start_time) end end diff --git a/src/client/stream.jl b/src/client/stream.jl index c8a06947..857136bf 100644 --- a/src/client/stream.jl +++ b/src/client/stream.jl @@ -176,10 +176,10 @@ function _h1_flush_outgoing!(s::Stream) slot === nothing && return channel = slot.channel channel === nothing && return - if !AwsIO.channel_thread_is_callers_thread(channel) + if !Reseau.Sockets.channel_thread_is_callers_thread(channel) fut = Future{Nothing}() - task = AwsIO.ChannelTask((task, ctx, status) -> begin - status == AwsIO.TaskStatus.RUN_READY || return notify(fut, nothing) + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return notify(fut, nothing) try _h1_flush_outgoing!(s) notify(fut, nothing) @@ -188,7 +188,7 @@ function _h1_flush_outgoing!(s::Stream) end return nothing end, nothing, "http_h1_flush_outgoing") - AwsIO.channel_schedule_task_now!(channel, task) + Reseau.Sockets.channel_schedule_task_now!(channel, task) wait(fut) return end @@ -196,14 +196,18 @@ function _h1_flush_outgoing!(s::Stream) status, encoded = AwsHTTP.h1_connection_encode_outgoing!(h1conn) status != AwsHTTP.OP_SUCCESS && throw(AWSError("H1 encoding failed")) isempty(encoded) && break - msg = AwsIO.IoMessage(length(encoded)) + msg = Reseau.Sockets.IoMessage(length(encoded)) buf = msg.message_data @inbounds for i in 1:length(encoded) buf.mem[i] = encoded[i] end buf.len = Csize_t(length(encoded)) - result = AwsIO.channel_slot_send_message(slot, msg, AwsIO.ChannelDirection.WRITE) - result isa AwsIO.ErrorResult && throw(AWSError("channel slot send failed")) + try + Reseau.Sockets.channel_slot_send_message(slot, msg, Reseau.Sockets.ChannelDirection.WRITE) + catch e + e isa Reseau.ReseauError || rethrow() + throw(AWSError("channel slot send failed")) + end end return end @@ -761,7 +765,7 @@ function with_stream_manager(client::Client, req::Request, chunkedbody, on_strea timeout_task = nothing if readtimeout > 0 timeout_task = errormonitor(Threads.@spawn begin - sleep(readtimeout) + _task_sleep_s(readtimeout) (@atomic stream.fut.set) != 0 && return notify(stream.fut, TimeoutError(readtimeout)) if isdefined(stream, :aws_stream) @@ -849,12 +853,12 @@ function with_stream(conn, req::Request, chunkedbody, on_stream_response_body, d timeout_task = nothing if readtimeout > 0 timeout_task = errormonitor(Threads.@spawn begin - sleep(readtimeout) + _task_sleep_s(readtimeout) (@atomic stream.fut.set) != 0 && return if !stream.http2 && isdefined(stream, :connection) conn = stream.connection if conn !== nothing && conn.slot !== nothing && conn.slot.channel !== nothing - AwsIO.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(conn.slot.channel, AwsHTTP.ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT; shutdown_immediately=true) elseif conn !== nothing AwsHTTP.http_connection_close(conn) end diff --git a/src/cookies.jl b/src/cookies.jl index 6a6181c6..c2fc01f7 100644 --- a/src/cookies.jl +++ b/src/cookies.jl @@ -33,7 +33,7 @@ module Cookies export Cookie, CookieJar, cookies, stringify, getcookies!, setcookies!, addcookie! import Base: == -using Dates, Sockets +using Dates import ..addheader, ..headereq, ..Headers, ..Request, ..Response, .._header_name, .._header_value @enum SameSite SameSiteDefaultMode=1 SameSiteLaxMode SameSiteStrictMode SameSiteNoneMode @@ -301,14 +301,18 @@ function readsetcookies(headers::Headers) return result end -function isIP(host) - try - Base.parse(IPAddr, host) - return true - catch e - isa(e, ArgumentError) && return false - rethrow(e) +function isIP(host::AbstractString)::Bool + # Minimal IPv4 literal check (avoid `Sockets` stdlib dependency). + parts = split(host, '.'; keepempty = false) + length(parts) == 4 || return false + for p in parts + isempty(p) && return false + all(isdigit, p) || return false + v = tryparse(Int, p) + v === nothing && return false + (0 <= v <= 255) || return false end + return true end """ diff --git a/src/download.jl b/src/download.jl index 4e3a40cb..841dc38c 100644 --- a/src/download.jl +++ b/src/download.jl @@ -38,7 +38,7 @@ function determine_file(path, resp, hdrs) filename = something( try_get_filename_from_headers(hdrs), resp.request === nothing ? nothing : try_get_filename_from_request(resp.request), - basename(tempname()) + basename(tempname(; cleanup = false)) ) return safer_joinpath(path, filename) end @@ -102,9 +102,11 @@ function download(url::AbstractString, local_path=nothing, headers=Header[]; upd ) end - Base.open(file, "w") do fh + Base.open(file, "w") do io while !eof(stream) - downloaded_bytes += write(fh, readavailable(stream)) + buf = readavailable(stream) + wrote = write(io, buf) + downloaded_bytes += wrote if !isinf(update_period) if now() - prev_time > Millisecond(round(1000update_period)) report_callback() diff --git a/src/exceptions.jl b/src/exceptions.jl index c5e2e84d..b11d8261 100644 --- a/src/exceptions.jl +++ b/src/exceptions.jl @@ -2,6 +2,9 @@ module Exceptions export @try, HTTPError, ConnectError, TimeoutError, RequestError, current_exceptions_to_string +# Pull parent-module bindings needed for concrete field types. +import ..Request + @eval begin """ @try PermittedErrorTypes expr @@ -33,7 +36,7 @@ the remote server. The underlying error is stored in `error`. """ struct ConnectError <: HTTPError url::String - error::Any + error::Exception end function Base.showerror(io::IO, e::ConnectError) @@ -60,8 +63,8 @@ Raised when an error occurs while physically sending a request to the remote ser or reading the response back. The underlying error is stored in `error`. """ struct RequestError <: HTTPError - request::Any - error::Any + request::Request + error::Exception end function Base.showerror(io::IO, e::RequestError) diff --git a/src/server.jl b/src/server.jl index 76226c22..e2d94748 100644 --- a/src/server.jl +++ b/src/server.jl @@ -5,34 +5,32 @@ function server_tlsoptions(; ssl_cacert=nothing, ssl_insecure=false, ssl_alpn_list="h2;http/1.1", -) + ) alpn_list = _normalize_alpn_list(ssl_alpn_list) if ssl_cert !== nothing && ssl_key !== nothing - ctx_opts = AwsIO.tls_ctx_options_init_default_server_from_path(ssl_cert, ssl_key; alpn_list=alpn_list) - ctx_opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS options")) + ctx_opts = Reseau.Sockets.tls_ctx_options_init_default_server_from_path(ssl_cert, ssl_key; alpn_list=alpn_list) elseif Sys.iswindows() && ssl_cert !== nothing && ssl_key === nothing - ctx_opts = AwsIO.tls_ctx_options_init_default_server_from_system_path(ssl_cert) - ctx_opts isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS options from system path")) + ctx_opts = Reseau.Sockets.tls_ctx_options_init_default_server_from_system_path(ssl_cert) else throw(ArgumentError("ssl_cert and ssl_key are required for TLS server")) end if ssl_capath !== nothing || ssl_cacert !== nothing - res = AwsIO.tls_ctx_options_override_default_trust_store_from_path!(ctx_opts; + Reseau.Sockets.tls_ctx_options_override_default_trust_store_from_path!(ctx_opts; ca_path=ssl_capath, ca_file=ssl_cacert) - res isa AwsIO.ErrorResult && throw(AWSError("Failed to set trust store")) end if ssl_insecure - AwsIO.tls_ctx_options_set_verify_peer!(ctx_opts, false) + Reseau.Sockets.tls_ctx_options_set_verify_peer!(ctx_opts, false) end - ctx = AwsIO.tls_server_ctx_new(ctx_opts) - ctx isa AwsIO.ErrorResult && throw(AWSError("Failed to create server TLS context")) - return AwsIO.TlsConnectionOptions(ctx; alpn_list=alpn_list) + ctx = Reseau.Sockets.tls_server_ctx_new(ctx_opts) + return Reseau.Sockets.TlsConnectionOptions(ctx; alpn_list=alpn_list) end +const _BACKLOG_DEFAULT = 511 + mutable struct Connection{S} const server::S # Server{F, C} const h1conn::Any # AwsHTTP.H1Connection or AwsHTTP.H2Connection - const channel::Any # AwsIO.Channel + const channel::Any # Reseau.Channel const streams_lock::ReentrantLock const streams::Set{Stream} const remote_addr::String @@ -63,7 +61,7 @@ mutable struct Server{F, C} const stream::Bool const logstate::Base.CoreLogging.LogState @atomic state::Symbol # :initializing, :running, :closed - bootstrap::Any # AwsIO.ServerBootstrap + bootstrap::Any # Reseau.ServerBootstrap bound_port::Int Server{F, C}( @@ -108,9 +106,9 @@ function _should_log_stream_error(error_code::Integer)::Bool error_code == AwsHTTP.ERROR_HTTP_SWITCHED_PROTOCOLS && return false error_code == AwsHTTP.ERROR_HTTP_GOAWAY_RECEIVED && return false error_code == AwsHTTP.ERROR_HTTP_RST_STREAM_RECEIVED && return false - error_code == AwsIO.ERROR_IO_SOCKET_CLOSED && return false - error_code == AwsIO.ERROR_IO_BROKEN_PIPE && return false - error_code == AwsIO.ERROR_IO_OPERATION_CANCELLED && return false + error_code == Reseau.EventLoops.ERROR_IO_SOCKET_CLOSED && return false + error_code == Reseau.EventLoops.ERROR_IO_BROKEN_PIPE && return false + error_code == Reseau.EventLoops.ERROR_IO_OPERATION_CANCELLED && return false return true end @@ -118,9 +116,9 @@ function _should_log_channel_shutdown_error(error_code::Integer)::Bool error_code == 0 && return false error_code == AwsHTTP.ERROR_HTTP_CONNECTION_CLOSED && return false error_code == AwsHTTP.ERROR_HTTP_SERVER_CLOSED && return false - error_code == AwsIO.ERROR_IO_SOCKET_CLOSED && return false - error_code == AwsIO.ERROR_IO_BROKEN_PIPE && return false - error_code == AwsIO.ERROR_IO_OPERATION_CANCELLED && return false + error_code == Reseau.EventLoops.ERROR_IO_SOCKET_CLOSED && return false + error_code == Reseau.EventLoops.ERROR_IO_BROKEN_PIPE && return false + error_code == Reseau.EventLoops.ERROR_IO_OPERATION_CANCELLED && return false return true end @@ -269,7 +267,7 @@ function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) end end if shutdown_channel - AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) @lock server.connections_lock begin delete!(server.connections, conn) end @@ -316,8 +314,8 @@ function _create_request_handler!(conn::Connection, aws_conn; http2::Bool=false) end function _warn_unsupported_server_options(; reuseaddr::Bool, backlog::Integer) - reuseaddr && @warn "reuseaddr is not supported by the AwsIO server; ignoring" - backlog != Sockets.BACKLOG_DEFAULT && @warn "backlog is not supported by the AwsIO server; ignoring" + reuseaddr && @warn "reuseaddr is not supported by the Reseau server; ignoring" + backlog != _BACKLOG_DEFAULT && @warn "backlog is not supported by the Reseau server; ignoring" return end @@ -339,7 +337,7 @@ function serve!(f, host="127.0.0.1", port=8080; stream::Bool=false, listenany::Bool=false, reuseaddr::Bool=false, - backlog::Integer=Sockets.BACKLOG_DEFAULT, + backlog::Integer=_BACKLOG_DEFAULT, # socket options socket_domain=:ipv4, connect_timeout_ms::Integer=3000, @@ -360,12 +358,9 @@ function serve!(f, host="127.0.0.1", port=8080; _ensure_resources!() _warn_unsupported_server_options(; reuseaddr=reuseaddr, backlog=backlog) host_str = string(host) - port_int = Int(port) - if listenany - addr = Sockets.InetAddr(parse(IPAddr, host_str), port_int) - port_int, sock = Sockets.listenany(addr.host, addr.port) - close(sock) - end + # `listenany=true` should pick an ephemeral port (port=0), avoiding collisions + # with any existing process bound to the default `port` (e.g. 8080). + port_int = listenany ? 0 : Int(port) tls_conn_opts = if tls_options !== nothing tls_options elseif any(x -> x !== nothing, (ssl_cert, ssl_key, ssl_capath, ssl_cacert)) @@ -388,8 +383,10 @@ function serve!(f, host="127.0.0.1", port=8080; Base.CoreLogging.current_logstate(), :initializing, ) - socket_opts = AwsIO.SocketOptions(; - domain = socket_domain == :ipv4 ? AwsIO.SocketDomain.IPV4 : AwsIO.SocketDomain.IPV6, + server.bound_port = port_int + listener_ready = Threads.Event() + socket_opts = Reseau.Sockets.SocketOptions(; + domain = socket_domain == :ipv4 ? Reseau.Sockets.SocketDomain.IPV4 : Reseau.Sockets.SocketDomain.IPV6, connect_timeout_ms = connect_timeout_ms, keep_alive_interval_sec = keep_alive_interval_sec, keep_alive_timeout_sec = keep_alive_timeout_sec, @@ -397,9 +394,6 @@ function serve!(f, host="127.0.0.1", port=8080; keepalive = keepalive, ) alpn_list = _tls_alpn_list(tls_conn_opts) - if _use_nw_sockets() - socket_opts.impl_type = AwsIO.SocketImplType.APPLE_NETWORK_FRAMEWORK - end initial_window = Csize_t(min(UInt64(initial_window_size), UInt64(typemax(Csize_t)))) on_incoming_channel_setup = (bootstrap, error_code, channel, user_data) -> begin Base.CoreLogging.with_logstate(server.logstate) do @@ -409,22 +403,22 @@ function serve!(f, host="127.0.0.1", port=8080; end st = @atomic(server.state) if st == :closing || st == :closed - AwsIO.channel_shutdown!(channel; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(channel; shutdown_immediately=true) return end - slot = AwsIO.channel_slot_new!(channel) - AwsIO.channel_slot_insert_end!(channel, slot) + slot = Reseau.Sockets.channel_slot_new!(channel) + Reseau.Sockets.channel_slot_insert_end!(channel, slot) version = AwsHTTP.HttpVersion.HTTP_1_1 if tls_conn_opts !== nothing tls_slot = slot.adj_left - if tls_slot === nothing || tls_slot.handler === nothing || !(tls_slot.handler isa AwsIO.TlsChannelHandler) - @error "incoming channel setup error" error_code=AwsIO.ERROR_INVALID_STATE - AwsIO.channel_shutdown!(channel, AwsIO.ERROR_INVALID_STATE) + if tls_slot === nothing || tls_slot.handler === nothing || !(tls_slot.handler isa Reseau.Sockets.TlsChannelHandler) + @error "incoming channel setup error" error_code=Reseau.ERROR_INVALID_STATE + Reseau.Sockets.channel_shutdown!(channel, Reseau.ERROR_INVALID_STATE) return end - protocol = AwsIO.tls_handler_protocol(tls_slot.handler) + protocol = Reseau.Sockets.tls_handler_protocol(tls_slot.handler) if protocol.len > 0 - protocol_str = AwsIO.byte_buffer_as_string(protocol) + protocol_str = Reseau.byte_buffer_as_string(protocol) if protocol_str == "h2" version = AwsHTTP.HttpVersion.HTTP_2 elseif protocol_str == "http/1.1" @@ -438,7 +432,7 @@ function serve!(f, host="127.0.0.1", port=8080; initial_window_size=initial_window, ) http_conn === nothing && return - AwsIO.channel_slot_set_handler!(slot, http_conn) + Reseau.Sockets.channel_slot_set_handler!(slot, http_conn) http_conn.slot = slot # Extract remote endpoint from the socket handler (first slot in pipeline) remote_addr = "0.0.0.0" @@ -446,7 +440,7 @@ function serve!(f, host="127.0.0.1", port=8080; try socket_handler = channel.first.handler ep = socket_handler.socket.remote_endpoint - remote_addr = AwsIO.get_address(ep) + remote_addr = Reseau.Sockets.get_address(ep) remote_port_num = Int(ep.port) catch end @@ -476,15 +470,15 @@ function serve!(f, host="127.0.0.1", port=8080; else _create_request_handler!(conn, http_conn; http2=false) end - if AwsIO.channel_thread_is_callers_thread(channel) - AwsIO.channel_trigger_read(channel) + if Reseau.Sockets.channel_thread_is_callers_thread(channel) + Reseau.Sockets.channel_trigger_read(channel) else - task = AwsIO.ChannelTask((task, ctx, status) -> begin - status == AwsIO.TaskStatus.RUN_READY || return nothing - AwsIO.channel_trigger_read(ctx.channel) + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return nothing + Reseau.Sockets.channel_trigger_read(ctx.channel) return nothing end, (channel = channel,), "http_server_trigger_read") - AwsIO.channel_schedule_task_now!(channel, task) + Reseau.Sockets.channel_schedule_task_now!(channel, task) end end return @@ -504,29 +498,37 @@ function serve!(f, host="127.0.0.1", port=8080; notify(server.fut, :destroyed) return end - bootstrap_opts = AwsIO.ServerBootstrapOptions(; + bootstrap_opts = Reseau.Sockets.ServerBootstrapOptions(; event_loop_group = _EVENT_LOOP_GROUP[], socket_options = socket_opts, host = host_str, port = UInt32(port_int), tls_connection_options = tls_conn_opts, on_protocol_negotiated = nothing, + on_listener_setup = (bootstrap, error_code, user_data) -> begin + if error_code == 0 && bootstrap.listener_socket !== nothing + server.bound_port = try + ep = Reseau.Sockets.socket_get_bound_address(bootstrap.listener_socket) + Int(ep.port) + catch + port_int + end + else + server.bound_port = port_int + end + notify(listener_ready) + return nothing + end, on_incoming_channel_setup = on_incoming_channel_setup, on_incoming_channel_shutdown = on_incoming_channel_shutdown, on_listener_destroy = on_listener_destroy, user_data = server, enable_read_back_pressure = false, ) - bs = AwsIO.ServerBootstrap(bootstrap_opts) - bs isa AwsIO.ErrorResult && throw(AWSError("Failed to create server bootstrap")) + bs = Reseau.Sockets.ServerBootstrap(bootstrap_opts) server.bootstrap = bs - # Retrieve the actual bound port (useful when port=0 or listenany) - if bs.listener_socket !== nothing - ep = AwsIO.socket_get_bound_address(bs.listener_socket) - server.bound_port = ep isa AwsIO.ErrorResult ? port_int : Int(ep.port) - else - server.bound_port = port_int - end + # Wait until the listener is ready so `port(server)` is accurate immediately. + wait(listener_ready) @atomic server.state = :running return server end @@ -626,13 +628,13 @@ end function _forceclose!(server::Server; skip_shutdown::Bool=false) skip_shutdown || shutdown(server.on_shutdown) - AwsIO.server_bootstrap_shutdown!(server.bootstrap) + Reseau.Sockets.server_bootstrap_shutdown!(server.bootstrap) conns = Connection[] @lock server.connections_lock begin append!(conns, server.connections) end for conn in conns - AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) end @atomic server.state = :closed notify(server.closed) @@ -648,7 +650,7 @@ function Base.close(server::Server) return end shutdown(server.on_shutdown) - AwsIO.server_bootstrap_shutdown!(server.bootstrap) + Reseau.Sockets.server_bootstrap_shutdown!(server.bootstrap) conns = Connection[] @lock server.connections_lock begin append!(conns, server.connections) @@ -657,7 +659,7 @@ function Base.close(server::Server) _stop_new_requests!(conn) @lock conn.streams_lock begin if isempty(conn.streams) - AwsIO.channel_shutdown!(conn.channel; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(conn.channel; shutdown_immediately=true) @lock server.connections_lock begin delete!(server.connections, conn) end @@ -674,7 +676,7 @@ function Base.close(server::Server) notify(server.closed) return end - sleep(0.05) + _task_sleep_s(0.05) end _forceclose!(server; skip_shutdown=true) return diff --git a/src/statistics.jl b/src/statistics.jl index 6d5ef1f1..5513ed95 100644 --- a/src/statistics.jl +++ b/src/statistics.jl @@ -56,14 +56,6 @@ _normalize_stat(stat) = stat function _decode_statistics(stats_list) list = stats_list isa Base.RefValue ? stats_list[] : stats_list out = Any[] - if list isa AwsIO.ArrayList - for i in 1:length(list) - item = list[i] - item = item isa Base.RefValue ? item[] : item - push!(out, _normalize_stat(item)) - end - return out - end if list isa AbstractVector for item in list item = item isa Base.RefValue ? item[] : item @@ -71,7 +63,7 @@ function _decode_statistics(stats_list) end return out end - throw(ArgumentError("stats_list must be ArrayList or AbstractVector")) + throw(ArgumentError("stats_list must be an AbstractVector")) end function _call_statistics_observer(observer, nonce, stats_list) diff --git a/src/utils.jl b/src/utils.jl index 7d503d03..d282d18e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,7 +6,7 @@ const HTTP2_MAX_WINDOW_SIZE = 0x7fffffff const AWS_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = AwsHTTP.Http2SettingsId.MAX_CONCURRENT_STREAMS const AWS_HTTP2_SETTINGS_INITIAL_WINDOW_SIZE = AwsHTTP.Http2SettingsId.INITIAL_WINDOW_SIZE const AWS_HTTP2_SETTINGS_COUNT = Int(AwsHTTP.HTTP2_SETTINGS_END_RANGE - AwsHTTP.HTTP2_SETTINGS_BEGIN_RANGE) -const _H2_CHANNEL_SUPPORTED = AwsHTTP.H2Connection <: AwsIO.AbstractChannelHandler +const _H2_CHANNEL_SUPPORTED = AwsHTTP.H2Connection <: Reseau.Sockets.AbstractChannelHandler function _normalize_alpn_list(alpn_list::Union{String, Nothing}) alpn_list === nothing && return nothing @@ -36,8 +36,8 @@ end function _use_nw_sockets()::Bool @static if Sys.isapple() - AwsIO._tls_set_use_secitem_from_env() - return AwsIO._NW_SHIM_LIB != "" && AwsIO.is_using_secitem() + Reseau.Sockets._tls_set_use_secitem_from_env() + return Reseau.Sockets.is_using_secitem() else return false end @@ -284,11 +284,10 @@ function _resolve_error_str(error_code::Integer) if ec >= AwsHTTP.ERROR_HTTP_UNKNOWN && ec <= AwsHTTP.ERROR_HTTP_END_RANGE return AwsHTTP.http_error_str(ec) end - # AwsIO.error_str returns Ptr{UInt8}; convert to String - return unsafe_string(AwsIO.error_str(ec)) + return Reseau.error_str(ec) end -aws_error() = AWSError(_resolve_error_str(AwsIO.last_error())) +aws_error() = AWSError(_resolve_error_str(Reseau.last_error())) aws_error(error_code) = AWSError(_resolve_error_str(error_code)) aws_throw_error() = throw(aws_error()) diff --git a/src/websockets.jl b/src/websockets.jl index ede36296..1aca9eb4 100644 --- a/src/websockets.jl +++ b/src/websockets.jl @@ -1,6 +1,6 @@ module WebSockets -using Base64, Random, AwsHTTP, AwsIO +using Base64, Random, AwsHTTP, Reseau import ..Headers, ..Header, ..Request, ..Response, ..Message, ..Stream import ..setinputstream!, ..getresponse, ..getheader, ..hasheader, ..header @@ -10,7 +10,7 @@ import ..ClientSettings, ..scheme, ..host, ..getport, ..userinfo, ..resource import ..Client, ..CookieJar, ..COOKIEJAR, ..with_connection, .._open_stream import ..aws_throw_error, ..aws_error, ..AWSError import .._h1_flush_outgoing!, ..iswss, ..makeuri, ..HTTP, ..mkreqheaders -import ..startread, ..closeread, ..startwrite, ..closewrite +import ..startread, ..closeread, ..startwrite, ..closewrite, .._task_sleep_s export WebSocket, send, receive, ping, pong @@ -45,11 +45,11 @@ isupgrade(s::Stream) = isupgrade(s.request) Base.@deprecate is_upgrade isupgrade # ─── WsChannelHandler ─── -# Bridges the AwsIO channel pipeline with the AwsHTTP WebSocket codec. +# Bridges the Reseau channel pipeline with the AwsHTTP WebSocket codec. # Installed into the H1Connection's channel slot after HTTP 101 upgrade. -mutable struct WsChannelHandler <: AwsIO.AbstractChannelHandler - slot::Union{AwsIO.ChannelSlot, Nothing} +mutable struct WsChannelHandler <: Reseau.Sockets.AbstractChannelHandler + slot::Union{Reseau.Sockets.ChannelSlot, Nothing} aws_ws::Any # AwsHTTP.WebSocket wslock::ReentrantLock # protects outgoing_frames access ws::Any @@ -57,13 +57,13 @@ end WsChannelHandler(aws_ws, ws) = WsChannelHandler(nothing, aws_ws, ReentrantLock(), ws) -function AwsIO.setchannelslot!(handler::WsChannelHandler, slot::AwsIO.ChannelSlot)::Nothing +function Reseau.Sockets.setchannelslot!(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot)::Nothing handler.slot = slot return nothing end -function AwsIO.handler_process_read_message(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, message::AwsIO.IoMessage)::Union{Nothing, AwsIO.ErrorResult} - data = AwsIO.byte_buffer_as_vector(message.message_data) +function Reseau.Sockets.handler_process_read_message(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, message::Reseau.Sockets.IoMessage)::Nothing + data = Reseau.byte_buffer_as_vector(message.message_data) isempty(data) && return nothing @lock handler.wslock begin status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, data) @@ -76,7 +76,7 @@ function AwsIO.handler_process_read_message(handler::WsChannelHandler, slot::Aws _queue_close!(ws, close_body) errormonitor(Threads.@spawn close(ws, close_body)) end - return AwsIO.ErrorResult(status) + Reseau.throw_error(status) end # Flush auto-responses (PONG, CLOSE echo) generated by ws_on_incoming_data! _ws_channel_flush!(handler) @@ -84,32 +84,34 @@ function AwsIO.handler_process_read_message(handler::WsChannelHandler, slot::Aws return nothing end -function AwsIO.handler_process_write_message(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, message::AwsIO.IoMessage)::Union{Nothing, AwsIO.ErrorResult} +function Reseau.Sockets.handler_process_write_message(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, message::Reseau.Sockets.IoMessage)::Nothing # Pass through to lower pipeline (socket) - return AwsIO.channel_slot_send_message(slot, message, AwsIO.ChannelDirection.WRITE) + Reseau.Sockets.channel_slot_send_message(slot, message, Reseau.Sockets.ChannelDirection.WRITE) + return nothing end -function AwsIO.handler_increment_read_window(handler::WsChannelHandler, slot::AwsIO.ChannelSlot, size::Csize_t)::Union{Nothing, AwsIO.ErrorResult} - return AwsIO.channel_slot_increment_read_window!(slot, size) +function Reseau.Sockets.handler_increment_read_window(handler::WsChannelHandler, slot::Reseau.Sockets.ChannelSlot, size::Csize_t)::Nothing + Reseau.Sockets.channel_slot_increment_read_window!(slot, size) + return nothing end -function AwsIO.handler_shutdown( +function Reseau.Sockets.handler_shutdown( handler::WsChannelHandler, - slot::AwsIO.ChannelSlot, - direction::AwsIO.ChannelDirection.T, + slot::Reseau.Sockets.ChannelSlot, + direction::Reseau.Sockets.ChannelDirection.T, error_code::Int, free_scarce_resources_immediately::Bool, - )::Union{Nothing, AwsIO.ErrorResult} + )::Nothing ws = handler.ws if ws !== nothing && !ws.readclosed _queue_close!(ws, CloseFrameBody(1006, "")) end - AwsIO.channel_slot_on_handler_shutdown_complete!(slot, direction, error_code, free_scarce_resources_immediately) + Reseau.Sockets.channel_slot_on_handler_shutdown_complete!(slot, direction, error_code, free_scarce_resources_immediately) return nothing end -AwsIO.handler_initial_window_size(::WsChannelHandler)::Csize_t = Csize_t(typemax(UInt64)) -AwsIO.handler_message_overhead(::WsChannelHandler)::Csize_t = Csize_t(0) +Reseau.Sockets.handler_initial_window_size(::WsChannelHandler)::Csize_t = Csize_t(typemax(UInt64)) +Reseau.Sockets.handler_message_overhead(::WsChannelHandler)::Csize_t = Csize_t(0) # Collect outgoing frames from the AwsHTTP WebSocket codec and send them # through the channel pipeline. Caller MUST hold handler.wslock. @@ -120,22 +122,22 @@ function _ws_channel_flush!(handler::WsChannelHandler) slot === nothing && return channel = slot.channel channel === nothing && return - msg = AwsIO.IoMessage(length(outdata)) + msg = Reseau.Sockets.IoMessage(length(outdata)) buf = msg.message_data @inbounds for i in 1:length(outdata) buf.mem[i] = outdata[i] end buf.len = Csize_t(length(outdata)) - if AwsIO.channel_thread_is_callers_thread(channel) - AwsIO.channel_slot_send_message(slot, msg, AwsIO.ChannelDirection.WRITE) + if Reseau.Sockets.channel_thread_is_callers_thread(channel) + Reseau.Sockets.channel_slot_send_message(slot, msg, Reseau.Sockets.ChannelDirection.WRITE) return end - task = AwsIO.ChannelTask((task, ctx, status) -> begin - status == AwsIO.TaskStatus.RUN_READY || return nothing - AwsIO.channel_slot_send_message(ctx.slot, ctx.msg, AwsIO.ChannelDirection.WRITE) + task = Reseau.Sockets.ChannelTask((task, ctx, status) -> begin + status == Reseau.TaskStatus.RUN_READY || return nothing + Reseau.Sockets.channel_slot_send_message(ctx.slot, ctx.msg, Reseau.Sockets.ChannelDirection.WRITE) return nothing end, (slot=slot, msg=msg), "http_ws_flush") - AwsIO.channel_schedule_task_now!(channel, task) + Reseau.Sockets.channel_schedule_task_now!(channel, task) return end @@ -222,7 +224,7 @@ function _shutdown_ws_channel!(handler::WsChannelHandler) slot === nothing && return channel = slot.channel channel === nothing && return - AwsIO.channel_shutdown!(channel; shutdown_immediately=true) + Reseau.Sockets.channel_shutdown!(channel; shutdown_immediately=true) return end @@ -420,7 +422,7 @@ end # Create an AwsHTTP WebSocket and WsChannelHandler, then install # the handler into the H1Connection's channel slot. -function _create_ws_handler!(ws::WebSocket, slot::AwsIO.ChannelSlot, is_client::Bool) +function _create_ws_handler!(ws::WebSocket, slot::Reseau.Sockets.ChannelSlot, is_client::Bool) aws_ws = AwsHTTP.ws_new(; is_client=is_client, on_incoming_frame_begin=_on_incoming_frame_begin(ws), @@ -430,7 +432,7 @@ function _create_ws_handler!(ws::WebSocket, slot::AwsIO.ChannelSlot, is_client:: handler = WsChannelHandler(aws_ws, ws) ws.aws_ws = aws_ws ws.handler = handler - AwsIO.channel_slot_set_handler!(slot, handler) + Reseau.Sockets.channel_slot_set_handler!(slot, handler) return end @@ -538,14 +540,22 @@ returned by `receive`. Note that `WebSocket` objects can be iterated, where each iteration yields a message until the connection is closed. """ function receive(ws::WebSocket) + # If a CLOSE arrives after application data, `_queue_close!` marks `readclosed=true` + # and closes `readchannel`. We still need to deliver any already-queued messages + # before throwing the close error. + if isready(ws.readchannel) + msg = take!(ws.readchannel) + msg isa WebSocketError && throw(msg) + return msg + end + if ws.readclosed || !isopen(ws.readchannel) close_body = ws.closebody === nothing ? CloseFrameBody(1006, "") : ws.closebody throw(WebSocketError(close_body)) end + msg = take!(ws.readchannel) - if msg isa WebSocketError - throw(msg) - end + msg isa WebSocketError && throw(msg) return msg end @@ -604,7 +614,7 @@ function Base.close(ws::WebSocket, body::Union{Nothing, CloseFrameBody}=nothing) deadline = time() + 5.0 while time() < deadline ws.readclosed && break - sleep(0.05) + _task_sleep_s(0.05) end ws.readclosed = true end @@ -682,15 +692,38 @@ function open(f::Function, url; handshakeerror() end # Wait for the H1 stream to complete (fires with ERROR_HTTP_SWITCHED_PROTOCOLS) + # + # Note: servers are allowed to start sending websocket frames immediately after the 101 + # response. Depending on timing, the HTTP layer may have already read some post-upgrade + # bytes before we swap in the websocket handler. After the stream completes, drain any + # such buffered bytes and feed them to the websocket decoder so they aren't lost. try wait(stream.fut) catch # Expected: ERROR_HTTP_SWITCHED_PROTOCOLS end + buffered = UInt8[] + if stream.bufferstream !== nothing + # BufferStream is closed in the stream on_complete callback, so read() should not block here. + buffered = try + read(stream.bufferstream) + catch + UInt8[] + end + end # Swap the H1Connection handler with our WsChannelHandler h1conn = stream.aws_stream.owning_connection slot = h1conn.slot _create_ws_handler!(ws, slot, true) + if !isempty(buffered) + handler = ws.handler + handler !== nothing || throw(WebSocketError(CloseFrameBody(1006, "WebSocket not connected"))) + @lock handler.wslock begin + status, _ = AwsHTTP.ws_on_incoming_data!(handler.aws_ws, buffered) + status != AwsHTTP.OP_SUCCESS && throw(AWSError("ws_on_incoming_data! failed")) + _ws_channel_flush!(handler) + end + end verbose > 0 && @info "$(ws.id): WebSocket opened" # Run WebSocket session try diff --git a/test/client.jl b/test/client.jl index 12ee82e9..dd76840c 100644 --- a/test/client.jl +++ b/test/client.jl @@ -667,7 +667,7 @@ end @testset "HTTP connection monitoring stats" begin - list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http1_channel}() + list = HTTP.aws_crt_statistics_http1_channel[] stat1 = HTTP.aws_crt_statistics_http1_channel( HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, UInt64(10), @@ -675,7 +675,7 @@ UInt32(1), UInt32(2), ) - AwsIO.push_back!(list, stat1) + push!(list, stat1) decoded = HTTP._decode_statistics(list) @test length(decoded) == 1 @test decoded[1].category == :http1_channel @@ -684,14 +684,14 @@ @test decoded[1].current_outgoing_stream_id == 1 @test decoded[1].current_incoming_stream_id == 2 - list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http2_channel}() + list = HTTP.aws_crt_statistics_http2_channel[] stat2 = HTTP.aws_crt_statistics_http2_channel( HTTP.AWSCRT_STAT_CAT_HTTP2_CHANNEL, UInt64(5), UInt64(6), true, ) - AwsIO.push_back!(list, stat2) + push!(list, stat2) decoded = HTTP._decode_statistics(list) @test length(decoded) == 1 @test decoded[1].category == :http2_channel @@ -702,7 +702,7 @@ called = Ref(false) cb = (nonce, stats) -> (called[] = true) client = HTTP.Client(HTTP.ClientSettings("https", "example.com", UInt32(443); monitoring_statistics_observer=cb)) - list = AwsIO.ArrayList{HTTP.aws_crt_statistics_http1_channel}() + list = HTTP.aws_crt_statistics_http1_channel[] stat3 = HTTP.aws_crt_statistics_http1_channel( HTTP.AWSCRT_STAT_CAT_HTTP1_CHANNEL, UInt64(1), @@ -710,7 +710,7 @@ UInt32(1), UInt32(1), ) - AwsIO.push_back!(list, stat3) + push!(list, stat3) HTTP._call_statistics_observer(client.monitoring_observer, Csize_t(0), list) @test called[] finalize(client) diff --git a/test/runtests.jl b/test/runtests.jl index edf374e8..0ca425e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using Test, HTTP, URIs, JSON, AwsIO +using Test, HTTP, URIs, JSON, Reseau const httpbin = get(ENV, "JULIA_TEST_HTTPBINGO_SERVER", "httpbingo.julialang.org") isok(r) = r.status == 200 diff --git a/test/server.jl b/test/server.jl index f943609d..abcf7181 100644 --- a/test/server.jl +++ b/test/server.jl @@ -1,6 +1,259 @@ -using Test, HTTP, Logging, Base64, AwsIO +using Test, HTTP, Logging, Base64, Reseau import Sockets +# Find a subsequence in a byte vector (naive; good enough for tests). +function _find_subseq(hay::Vector{UInt8}, needle::Vector{UInt8}, from::Int = 1)::Int + n = length(needle) + m = length(hay) + (n == 0 || from > m) && return 0 + last_i = m - n + 1 + last_i < from && return 0 + @inbounds for i in from:last_i + if hay[i] == needle[1] + ok = true + for j in 2:n + if hay[i + j - 1] != needle[j] + ok = false + break + end + end + ok && return i + end + end + return 0 +end + +const _HTTP1_CHUNKED_TERMINATOR = Vector{UInt8}(codeunits("\r\n0\r\n")) +const _HTTP1_HEADERS_TERMINATOR = Vector{UInt8}(codeunits("\r\n\r\n")) + +function _http1_chunked_with_trailers_done(buf::Vector{UInt8})::Bool + idx0 = _find_subseq(buf, _HTTP1_CHUNKED_TERMINATOR) + idx0 == 0 && return false + idx_end = _find_subseq(buf, _HTTP1_HEADERS_TERMINATOR, idx0 + length(_HTTP1_CHUNKED_TERMINATOR)) + return idx_end != 0 +end + +# Minimal raw TCP client for tests. Avoids the `Sockets` stdlib. +@static if Sys.iswindows() + const _WS2_32 = "ws2_32" + const _AF_INET = Cint(2) + const _SOCK_STREAM = Cint(1) + const _IPPROTO_TCP = Cint(6) + const _INVALID_SOCKET = typemax(UInt) + const _SOCKET_ERROR = Cint(-1) + const _POLLIN = Int16(0x0001) + const _RAW_READ_TIMEOUT_MS = Cint(5_000) + + struct _sockaddr_in + sin_family::UInt16 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + + struct _WSAPOLLFD + fd::UInt + events::Int16 + revents::Int16 + end + + function _raw_tcp_connect_readall( + host::AbstractString, + port::Integer, + request::AbstractString, + ; + stop_pred = nothing, + )::Vector{UInt8} + wsadata = Vector{UInt8}(undef, 512) + ret = ccall((:WSAStartup, _WS2_32), Cint, (UInt16, Ptr{UInt8}), UInt16(0x0202), wsadata) + ret == 0 || error("WSAStartup() failed: $ret") + sock = ccall((:socket, _WS2_32), UInt, (Cint, Cint, Cint), _AF_INET, _SOCK_STREAM, _IPPROTO_TCP) + if sock == _INVALID_SOCKET + err = ccall((:WSAGetLastError, _WS2_32), Cint, ()) + _ = ccall((:WSACleanup, _WS2_32), Cint, ()) + error("socket() failed: $err") + end + + try + addr = ccall((:inet_addr, _WS2_32), UInt32, (Cstring,), host) + addr == 0xffffffff && error("inet_addr() failed for host=$host") + port_be = ccall((:htons, _WS2_32), UInt16, (UInt16,), UInt16(port)) + sin = _sockaddr_in(UInt16(_AF_INET), port_be, addr, ntuple(_ -> UInt8(0), 8)) + + cres = ccall((:connect, _WS2_32), Cint, (UInt, Ref{_sockaddr_in}, Cint), sock, sin, Cint(sizeof(_sockaddr_in))) + cres == 0 || error("connect() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + + bytes = Vector{UInt8}(codeunits(request)) + sent = GC.@preserve bytes ccall( + (:send, _WS2_32), + Cint, + (UInt, Ptr{UInt8}, Cint, Cint), + sock, + pointer(bytes), + Cint(length(bytes)), + Cint(0), + ) + sent == length(bytes) || error("send() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + # Some server stacks treat EOF as end-of-request; half-close our send side. + _ = ccall((:shutdown, _WS2_32), Cint, (UInt, Cint), sock, Cint(1)) # SD_SEND = 1 + + buf = UInt8[] + tmp = Vector{UInt8}(undef, 4096) + while true + pollfd = Ref(_WSAPOLLFD(sock, _POLLIN, Int16(0))) + pres = ccall( + (:WSAPoll, _WS2_32), + Cint, + (Ptr{_WSAPOLLFD}, UInt32, Cint), + pollfd, + UInt32(1), + _RAW_READ_TIMEOUT_MS, + ) + pres == 0 && error("WSAPoll() timeout") + pres == _SOCKET_ERROR && error("WSAPoll() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + + r = GC.@preserve tmp ccall( + (:recv, _WS2_32), + Cint, + (UInt, Ptr{UInt8}, Cint, Cint), + sock, + pointer(tmp), + Cint(length(tmp)), + Cint(0), + ) + r == _SOCKET_ERROR && error("recv() failed: $(ccall((:WSAGetLastError, _WS2_32), Cint, ()))") + r == 0 && break + append!(buf, view(tmp, 1:r)) + stop_pred !== nothing && stop_pred(buf) && return buf + end + return buf + finally + _ = ccall((:closesocket, _WS2_32), Cint, (UInt,), sock) + _ = ccall((:WSACleanup, _WS2_32), Cint, ()) + end + end +else + const _AF_INET = Cint(2) + const _SOCK_STREAM = Cint(1) + const _POLLIN = Int16(0x0001) + const _RAW_READ_TIMEOUT_MS = Cint(5_000) + + @static if Sys.isapple() || Sys.isbsd() + struct _sockaddr_in + sin_len::UInt8 + sin_family::UInt8 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + else + struct _sockaddr_in + sin_family::UInt16 + sin_port::UInt16 + sin_addr::UInt32 + sin_zero::NTuple{8, UInt8} + end + end + + struct _pollfd + fd::Cint + events::Int16 + revents::Int16 + end + + const _poll_nfds_t = @static (Sys.isapple() || Sys.isbsd()) ? Cuint : Culong + + function _raw_tcp_connect_readall( + host::AbstractString, + port::Integer, + request::AbstractString, + ; + stop_pred = nothing, + )::Vector{UInt8} + fd = ccall(:socket, Cint, (Cint, Cint, Cint), _AF_INET, _SOCK_STREAM, 0) + fd < 0 && error("socket() failed: errno=$(Libc.errno())") + try + addr = Ref{UInt32}(0) + ok = ccall(:inet_pton, Cint, (Cint, Cstring, Ptr{Cvoid}), _AF_INET, host, addr) + ok == 1 || error("inet_pton() failed for host=$host") + + port16 = UInt16(port) + port_be = ccall(:htons, UInt16, (UInt16,), port16) + + sin = @static if Sys.isapple() || Sys.isbsd() + _sockaddr_in( + UInt8(sizeof(_sockaddr_in)), + UInt8(_AF_INET), + port_be, + addr[], + ntuple(_ -> UInt8(0), 8), + ) + else + _sockaddr_in( + UInt16(_AF_INET), + port_be, + addr[], + ntuple(_ -> UInt8(0), 8), + ) + end + + ret = ccall(:connect, Cint, (Cint, Ref{_sockaddr_in}, Cuint), fd, sin, Cuint(sizeof(_sockaddr_in))) + ret == 0 || error("connect() failed: errno=$(Libc.errno())") + + bytes = Vector{UInt8}(codeunits(request)) + n = GC.@preserve bytes ccall( + :write, + Cssize_t, + (Cint, Ptr{UInt8}, Csize_t), + fd, + pointer(bytes), + Csize_t(length(bytes)), + ) + n == length(bytes) || error("write() failed: errno=$(Libc.errno())") + # Some server stacks treat EOF as end-of-request; half-close our send side. + _ = ccall(:shutdown, Cint, (Cint, Cint), fd, Cint(1)) # SHUT_WR = 1 + + buf = UInt8[] + tmp = Vector{UInt8}(undef, 4096) + while true + pollfd = Ref(_pollfd(fd, _POLLIN, Int16(0))) + pres = ccall( + :poll, + Cint, + (Ptr{_pollfd}, _poll_nfds_t, Cint), + pollfd, + _poll_nfds_t(1), + _RAW_READ_TIMEOUT_MS, + ) + if pres == 0 + error("poll() timeout") + end + if pres < 0 + err = Libc.errno() + err == Libc.EINTR && continue + error("poll() failed: errno=$err") + end + + r = GC.@preserve tmp ccall( + :read, + Cssize_t, + (Cint, Ptr{UInt8}, Csize_t), + fd, + pointer(tmp), + Csize_t(length(tmp)), + ) + r < 0 && error("read() failed: errno=$(Libc.errno())") + r == 0 && break + append!(buf, view(tmp, 1:Int(r))) + stop_pred !== nothing && stop_pred(buf) && return buf + end + return buf + finally + _ = ccall(:close, Cint, (Cint,), fd) + end + end +end + @testset "HTTP.serve" begin server = HTTP.serve!(req -> HTTP.Response(200, "Hello, World!"); listenany=true) try @@ -111,6 +364,7 @@ end HTTP.startwrite(http) write(http, "hello") HTTP.addtrailer(http, "X-Trailer" => "ok") + HTTP.closewrite(http) end try port = HTTP.port(server) @@ -129,7 +383,7 @@ end end @testset "HTTP/2 TLS support" begin - if !AwsIO.tls_is_alpn_available() + if !Reseau.Sockets.tls_is_alpn_available() @info "Skipping HTTP/2 TLS tests; ALPN not available" @test true else diff --git a/test/websockets/autobahn.jl b/test/websockets/autobahn.jl index 5b373c7c..7a62cba3 100644 --- a/test/websockets/autobahn.jl +++ b/test/websockets/autobahn.jl @@ -1,4 +1,4 @@ -using Test, Sockets, HTTP, HTTP.WebSockets, JSON +using Test, HTTP, HTTP.WebSockets, JSON const DIR = abspath(joinpath(dirname(pathof(HTTP)), "../test/websockets")) @@ -96,4 +96,4 @@ end end # @testset "WebSockets" -end # 64-bit only \ No newline at end of file +end # 64-bit only diff --git a/test/websockets/deno_client/server.jl b/test/websockets/deno_client/server.jl index 00ef3c98..d5ee5ec3 100644 --- a/test/websockets/deno_client/server.jl +++ b/test/websockets/deno_client/server.jl @@ -1,4 +1,4 @@ -using Test, Sockets, Deno_jll, HTTP +using Test, Deno_jll, HTTP # Not all architectures have a Deno_jll hasproperty(Deno_jll, :deno) && @testset "WebSocket server" begin @@ -25,4 +25,4 @@ hasproperty(Deno_jll, :deno) && @testset "WebSocket server" begin finally close(server) end -end \ No newline at end of file +end