diff --git a/crates/openshell-sandbox/src/l7/provider.rs b/crates/openshell-sandbox/src/l7/provider.rs index a9bf8bf5..9c6b0fce 100644 --- a/crates/openshell-sandbox/src/l7/provider.rs +++ b/crates/openshell-sandbox/src/l7/provider.rs @@ -13,6 +13,21 @@ use miette::Result; use std::future::Future; use tokio::io::{AsyncRead, AsyncWrite}; +/// Outcome of relaying a single HTTP request/response pair. +#[derive(Debug)] +pub enum RelayOutcome { + /// Connection is reusable for further HTTP requests (keep-alive). + Reusable, + /// Connection was consumed (e.g. read-until-EOF or `Connection: close`). + Consumed, + /// Server responded with 101 Switching Protocols. + /// The connection has been upgraded (e.g. to WebSocket) and must be + /// relayed as raw bidirectional TCP from this point forward. + /// Contains any overflow bytes read from upstream after the 101 headers + /// that must be forwarded to the client before switching to copy mode. + Upgraded { overflow: Vec }, +} + /// Body framing for HTTP requests/responses. #[derive(Debug, Clone, Copy)] pub enum BodyLength { @@ -54,14 +69,15 @@ pub trait L7Provider: Send + Sync { /// Forward an allowed request to upstream and relay the response back. /// - /// Returns `true` if the upstream connection is reusable (keep-alive), - /// `false` if it was consumed (e.g. read-until-EOF or `Connection: close`). + /// Returns a [`RelayOutcome`] indicating whether the connection is + /// reusable (keep-alive), consumed, or has been upgraded (101 Switching + /// Protocols) and must be relayed as raw bidirectional TCP. fn relay( &self, req: &L7Request, client: &mut C, upstream: &mut U, - ) -> impl Future> + Send + ) -> impl Future> + Send where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send; diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 940e7f94..50857d3f 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -7,12 +7,12 @@ //! Parses each request within the tunnel, evaluates it against OPA policy, //! and either forwards or denies the request. -use crate::l7::provider::L7Provider; +use crate::l7::provider::{L7Provider, RelayOutcome}; use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; use crate::secrets::SecretResolver; use miette::{IntoDiagnostic, Result, miette}; use std::sync::{Arc, Mutex}; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, info, warn}; /// Context for L7 request policy evaluation. @@ -134,20 +134,42 @@ where if allowed || config.enforcement == EnforcementMode::Audit { // Forward request to upstream and relay response - let reusable = crate::l7::rest::relay_http_request_with_resolver( + let outcome = crate::l7::rest::relay_http_request_with_resolver( &req, client, upstream, ctx.secret_resolver.as_deref(), ) .await?; - if !reusable { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing L7 relay" - ); - return Ok(()); + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { overflow } => { + info!( + host = %ctx.host, + port = ctx.port, + overflow_bytes = overflow.len(), + "101 Switching Protocols — switching to raw bidirectional relay" + ); + // Forward any overflow bytes from the upgrade response + if !overflow.is_empty() { + client.write_all(&overflow).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + } + // Switch to raw bidirectional TCP copy for the upgraded + // protocol (WebSocket, HTTP/2, etc.) + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + return Ok(()); + } } } else { // Enforce mode: deny with 403 and close connection @@ -278,12 +300,29 @@ where // Forward request with credential rewriting and relay the response. // relay_http_request_with_resolver handles both directions: it sends // the request upstream and reads the response back to the client. - let reusable = + let outcome = crate::l7::rest::relay_http_request_with_resolver(&req, client, upstream, resolver) .await?; - if !reusable { - break; + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => break, + RelayOutcome::Upgraded { overflow } => { + info!( + host = %ctx.host, + port = ctx.port, + overflow_bytes = overflow.len(), + "101 Switching Protocols — switching to raw bidirectional relay" + ); + if !overflow.is_empty() { + client.write_all(&overflow).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + return Ok(()); + } } } diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index ebb34957..ca3e5d9b 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -7,7 +7,7 @@ //! policy, and relays allowed requests to upstream. Handles Content-Length //! and chunked transfer encoding for body framing. -use crate::l7::provider::{BodyLength, L7Provider, L7Request}; +use crate::l7::provider::{BodyLength, L7Provider, L7Request, RelayOutcome}; use crate::secrets::rewrite_http_header_block; use miette::{IntoDiagnostic, Result, miette}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -31,7 +31,12 @@ impl L7Provider for RestProvider { parse_http_request(client).await } - async fn relay(&self, req: &L7Request, client: &mut C, upstream: &mut U) -> Result + async fn relay( + &self, + req: &L7Request, + client: &mut C, + upstream: &mut U, + ) -> Result where C: AsyncRead + AsyncWrite + Unpin + Send, U: AsyncRead + AsyncWrite + Unpin + Send, @@ -140,8 +145,13 @@ async fn parse_http_request(client: &mut C) -> Result(req: &L7Request, client: &mut C, upstream: &mut U) -> Result +/// Returns the relay outcome indicating whether the connection is reusable, +/// consumed, or has been upgraded (e.g. WebSocket via 101 Switching Protocols). +async fn relay_http_request( + req: &L7Request, + client: &mut C, + upstream: &mut U, +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -154,7 +164,7 @@ pub(crate) async fn relay_http_request_with_resolver( client: &mut C, upstream: &mut U, resolver: Option<&crate::secrets::SecretResolver>, -) -> Result +) -> Result where C: AsyncRead + AsyncWrite + Unpin, U: AsyncRead + AsyncWrite + Unpin, @@ -191,8 +201,19 @@ where BodyLength::None => {} } upstream.flush().await.into_diagnostic()?; - let (reusable, _) = relay_response(&req.action, upstream, client).await?; - Ok(reusable) + let (reusable, status_code, resp_overflow) = + relay_response(&req.action, upstream, client).await?; + + if status_code == 101 { + return Ok(RelayOutcome::Upgraded { + overflow: resp_overflow, + }); + } + if reusable { + Ok(RelayOutcome::Reusable) + } else { + Ok(RelayOutcome::Consumed) + } } /// Send a 403 Forbidden JSON deny response. @@ -429,7 +450,7 @@ where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, { - let (reusable, _status) = relay_response(request_method, upstream, client).await?; + let (reusable, _status, _overflow) = relay_response(request_method, upstream, client).await?; Ok(reusable) } @@ -437,7 +458,7 @@ async fn relay_response( request_method: &str, upstream: &mut U, client: &mut C, -) -> Result<(bool, u16)> +) -> Result<(bool, u16, Vec)> where U: AsyncRead + Unpin, C: AsyncWrite + Unpin, @@ -458,7 +479,7 @@ where if !buf.is_empty() { client.write_all(&buf).await.into_diagnostic()?; } - return Ok((false, 0)); + return Ok((false, 0, Vec::new())); } buf.extend_from_slice(&tmp[..n]); @@ -484,6 +505,26 @@ where "relay_response framing" ); + // 101 Switching Protocols: the connection has been upgraded (e.g. to + // WebSocket). Forward the 101 headers to the client and signal the + // caller to switch to raw bidirectional TCP relay. Any bytes read + // from upstream beyond the headers are overflow that belong to the + // upgraded protocol and must be forwarded before switching. + if status_code == 101 { + client + .write_all(&buf[..header_end]) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + let overflow = buf[header_end..].to_vec(); + debug!( + request_method, + overflow_bytes = overflow.len(), + "101 Switching Protocols — signaling protocol upgrade" + ); + return Ok((false, status_code, overflow)); + } + // Bodiless responses (HEAD, 1xx, 204, 304): forward headers only, skip body if is_bodiless_response(request_method, status_code) { client @@ -491,7 +532,7 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((!server_wants_close, status_code)); + return Ok((!server_wants_close, status_code, Vec::new())); } // No explicit framing (no Content-Length, no Transfer-Encoding). @@ -511,7 +552,7 @@ where } relay_until_eof(upstream, client).await?; client.flush().await.into_diagnostic()?; - return Ok((false, status_code)); + return Ok((false, status_code, Vec::new())); } // No Connection: close — an HTTP/1.1 keep-alive server that omits // framing headers has an empty body. Forward headers and continue @@ -522,7 +563,7 @@ where .await .into_diagnostic()?; client.flush().await.into_diagnostic()?; - return Ok((true, status_code)); + return Ok((true, status_code, Vec::new())); } // Forward response headers + any overflow body bytes @@ -555,7 +596,7 @@ where // loop will exit via the normal error path. Exiting early here would // tear down the CONNECT tunnel before the client can detect the close, // causing ~30 s retry delays in clients like `gh`. - Ok((true, status_code)) + Ok((true, status_code, Vec::new())) } /// Parse the HTTP status code from a response status line. @@ -858,7 +899,7 @@ mod tests { .await .expect("relay_response should not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(!reusable, "connection consumed by read-until-EOF"); client_write.shutdown().await.unwrap(); @@ -896,7 +937,7 @@ mod tests { .await .expect("must not block when no Connection: close"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "keep-alive implied, connection reusable"); client_write.shutdown().await.unwrap(); @@ -929,7 +970,7 @@ mod tests { .await .expect("HEAD relay must not deadlock waiting for body"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "HEAD response should be reusable"); client_write.shutdown().await.unwrap(); @@ -959,7 +1000,7 @@ mod tests { .await .expect("204 relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "204 response should be reusable"); client_write.shutdown().await.unwrap(); @@ -991,7 +1032,7 @@ mod tests { .await .expect("must not block when chunked body is complete in overflow"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "connection should be reusable"); client_write.shutdown().await.unwrap(); @@ -1027,7 +1068,7 @@ mod tests { .await .expect("must not block when chunked response has trailers"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "chunked response should be reusable"); client_write.shutdown().await.unwrap(); @@ -1062,7 +1103,7 @@ mod tests { .await .expect("normal relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); assert!(reusable, "Content-Length response should be reusable"); client_write.shutdown().await.unwrap(); @@ -1090,7 +1131,7 @@ mod tests { .await .expect("relay must not deadlock"); - let (reusable, _status) = result.expect("relay_response should succeed"); + let (reusable, _status, _overflow) = result.expect("relay_response should succeed"); // With explicit framing, Connection: close is still reported as reusable // so the relay loop continues. The *next* upstream write will fail and // exit the loop via the normal error path. @@ -1105,6 +1146,51 @@ mod tests { assert!(String::from_utf8_lossy(&received).contains("hello")); } + #[tokio::test] + async fn relay_response_101_switching_protocols_returns_overflow() { + // Build a 101 response followed by WebSocket frame data (overflow). + let mut response = Vec::new(); + response.extend_from_slice(b"HTTP/1.1 101 Switching Protocols\r\n"); + response.extend_from_slice(b"Upgrade: websocket\r\n"); + response.extend_from_slice(b"Connection: Upgrade\r\n"); + response.extend_from_slice(b"\r\n"); + response.extend_from_slice(b"\x81\x05hello"); // WebSocket frame + + let (upstream_read, mut upstream_write) = tokio::io::duplex(4096); + let (mut client_read, client_write) = tokio::io::duplex(4096); + + upstream_write.write_all(&response).await.unwrap(); + drop(upstream_write); + + let mut upstream_read = upstream_read; + let mut client_write = client_write; + + let result = tokio::time::timeout( + std::time::Duration::from_secs(2), + relay_response("GET", &mut upstream_read, &mut client_write), + ) + .await + .expect("relay_response should not deadlock"); + + let (reusable, status, overflow) = result.expect("relay_response should succeed"); + assert!(!reusable, "101 should signal non-reusable"); + assert_eq!(status, 101); + assert_eq!( + &overflow, + b"\x81\x05hello", + "overflow should contain WebSocket frame data" + ); + + client_write.shutdown().await.unwrap(); + let mut received = Vec::new(); + client_read.read_to_end(&mut received).await.unwrap(); + let received_str = String::from_utf8_lossy(&received); + assert!( + received_str.contains("101 Switching Protocols"), + "client should receive the 101 response headers" + ); + } + #[test] fn rewrite_header_block_resolves_placeholder_auth_headers() { let (_, resolver) = SecretResolver::from_provider_env(