diff --git a/src/relay/mod.rs b/src/relay/mod.rs index 71e9f2c..9d41ef1 100644 --- a/src/relay/mod.rs +++ b/src/relay/mod.rs @@ -69,7 +69,7 @@ pub struct SubdomainPick { // ── Relay server ──────────────────────────────────────────────────────── type PendingRequests = Arc>>>; -type TunnelSender = tokio::sync::mpsc::UnboundedSender; +type TunnelSender = tokio::sync::mpsc::Sender; struct TunnelConnection { sender: TunnelSender, @@ -327,7 +327,7 @@ async fn handle_tunnel_ws(socket: WebSocket, state: Arc) { } // Create channel for sending requests to this tunnel client - let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let (tx, mut rx) = tokio::sync::mpsc::channel::(64); let pending: PendingRequests = Arc::new(RwLock::new(HashMap::new())); let conn = Arc::new(TunnelConnection { @@ -483,19 +483,34 @@ async fn handle_tunnel_request( let (resp_tx, resp_rx) = oneshot::channel(); conn.pending.write().await.insert(req_id.clone(), resp_tx); - // Send request to tunnel client + // Send request to tunnel client (bounded channel — immediate backpressure) let msg = serde_json::to_string(&tunnel_req).unwrap(); - if conn.sender.send(msg).is_err() { - conn.pending.write().await.remove(&req_id); - relay_log( - &subdomain, - method.as_str(), - &path_str, - 502, - 0, - std::time::Duration::ZERO, - ); - return (StatusCode::BAD_GATEWAY, "tunnel disconnected").into_response(); + match conn.sender.try_send(msg) { + Ok(()) => {} + Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => { + conn.pending.write().await.remove(&req_id); + relay_log( + &subdomain, + method.as_str(), + &path_str, + 503, + 0, + std::time::Duration::ZERO, + ); + return (StatusCode::SERVICE_UNAVAILABLE, "tunnel overloaded").into_response(); + } + Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => { + conn.pending.write().await.remove(&req_id); + relay_log( + &subdomain, + method.as_str(), + &path_str, + 502, + 0, + std::time::Duration::ZERO, + ); + return (StatusCode::BAD_GATEWAY, "tunnel disconnected").into_response(); + } } // Wait for response with timeout