diff --git a/cli/crates/agent-tui-app/src/app/daemon/ws_server.rs b/cli/crates/agent-tui-app/src/app/daemon/ws_server.rs index 2b7b872..f32614b 100644 --- a/cli/crates/agent-tui-app/src/app/daemon/ws_server.rs +++ b/cli/crates/agent-tui-app/src/app/daemon/ws_server.rs @@ -6,6 +6,8 @@ use axum::extract::ws::Message; use axum::extract::ws::WebSocket; use axum::extract::ws::WebSocketUpgrade; use axum::extract::ws::close_code; +use axum::http::HeaderMap; +use axum::http::header::COOKIE; use axum::response::Html; use axum::response::IntoResponse; use axum::response::Redirect; @@ -162,7 +164,6 @@ struct WsState { ws_queue_capacity: usize, shutdown_rx: watch::Receiver, auth_token: String, - ws_url: String, } pub(crate) fn start_ws_server( @@ -176,10 +177,12 @@ pub(crate) fn start_ws_server( let (listener, local_addr) = bind_listener(&config)?; let auth_token = generate_ws_auth_token(); - let ws_url = format_ws_url(&local_addr, &auth_token); + let ws_url = format_ws_url(&local_addr); + let ws_url_with_auth = format_ws_url_with_auth(&ws_url, &auth_token); let ui_url = format_ui_url(&local_addr, &ws_url); let listen_addr = local_addr.to_string(); - if let Err(err) = write_state_file(&config.state_path, &ws_url, &ui_url, &listen_addr) { + if let Err(err) = write_state_file(&config.state_path, &ws_url_with_auth, &ui_url, &listen_addr) + { warn!(error = %err, "Failed to write WS state file"); } @@ -191,7 +194,6 @@ pub(crate) fn start_ws_server( ws_queue_capacity: config.ws_queue_capacity, shutdown_rx: shutdown_rx.clone(), auth_token, - ws_url: ws_url.clone(), }); let state_path = config.state_path.clone(); @@ -293,12 +295,19 @@ fn build_router(state: Arc) -> axum::Router { } async fn ui_root_handler(State(state): State>) -> Response { - let ws = encode_url_query_value(&state.ws_url); - Redirect::temporary(&format!("/ui?ws={ws}")).into_response() + ( + [("set-cookie", ws_auth_cookie(&state.auth_token))], + Redirect::temporary("/ui"), + ) + .into_response() } -async fn ui_index_handler() -> Response { - Html(UI_INDEX_HTML).into_response() +async fn ui_index_handler(State(state): State>) -> Response { + ( + [("set-cookie", ws_auth_cookie(&state.auth_token))], + Html(UI_INDEX_HTML), + ) + .into_response() } async fn ui_app_js_handler() -> Response { @@ -325,9 +334,14 @@ struct WsAuthQuery { async fn ws_handler( State(state): State>, Query(query): Query, + headers: HeaderMap, ws: WebSocketUpgrade, ) -> Response { - if query.token.as_deref() != Some(state.auth_token.as_str()) { + let query_token_matches = query.token.as_deref() == Some(state.auth_token.as_str()); + let cookie_token_matches = read_cookie_value(&headers, "agent_tui_ws_token") + .as_deref() + .is_some_and(|token| token == state.auth_token.as_str()); + if !query_token_matches && !cookie_token_matches { let response = RpcResponse::error(0, -32001, "unauthorized"); return ( axum::http::StatusCode::UNAUTHORIZED, @@ -596,12 +610,16 @@ fn bind_listener(config: &WsConfig) -> Result<(std::net::TcpListener, SocketAddr Ok((listener, local_addr)) } -fn format_ws_url(addr: &SocketAddr, auth_token: &str) -> String { +fn format_ws_url(addr: &SocketAddr) -> String { let host = match addr.ip() { std::net::IpAddr::V4(ip) => ip.to_string(), std::net::IpAddr::V6(ip) => format!("[{ip}]"), }; - format!("ws://{}:{}/ws?token={}", host, addr.port(), auth_token) + format!("ws://{}:{}/ws", host, addr.port()) +} + +fn format_ws_url_with_auth(ws_url: &str, auth_token: &str) -> String { + format!("{ws_url}?token={auth_token}") } fn format_ui_url(addr: &SocketAddr, ws_url: &str) -> String { @@ -617,6 +635,20 @@ fn encode_url_query_value(value: &str) -> String { url::form_urlencoded::byte_serialize(value.as_bytes()).collect() } +fn ws_auth_cookie(auth_token: &str) -> String { + format!("agent_tui_ws_token={auth_token}; HttpOnly; SameSite=Strict; Path=/ws") +} + +fn read_cookie_value(headers: &HeaderMap, key: &str) -> Option { + let cookie_header = headers.get(COOKIE)?.to_str().ok()?; + cookie_header.split(';').find_map(|cookie| { + let mut parts = cookie.trim().splitn(2, '='); + let cookie_key = parts.next()?.trim(); + let cookie_value = parts.next()?.trim(); + (cookie_key == key).then(|| cookie_value.to_string()) + }) +} + fn generate_ws_auth_token() -> String { Uuid::new_v4().simple().to_string() } @@ -772,23 +804,42 @@ mod tests { } #[test] - fn format_ws_url_embeds_auth_token() { + fn format_ws_url_does_not_embed_auth_token() { let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().expect("valid addr"); - let url = super::format_ws_url(&addr, "secret-token"); + let url = super::format_ws_url(&addr); + assert_eq!(url, "ws://127.0.0.1:12345/ws"); + } + + #[test] + fn format_ws_url_with_auth_appends_token() { + let url = super::format_ws_url_with_auth("ws://127.0.0.1:12345/ws", "secret-token"); assert_eq!(url, "ws://127.0.0.1:12345/ws?token=secret-token"); } #[test] fn format_ui_url_embeds_encoded_ws_url() { let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().expect("valid addr"); - let ws_url = "ws://127.0.0.1:12345/ws?token=secret-token"; + let ws_url = "ws://127.0.0.1:12345/ws"; let ui_url = super::format_ui_url(&addr, ws_url); assert!( - ui_url.contains("ws=ws%3A%2F%2F127.0.0.1%3A12345%2Fws%3Ftoken%3Dsecret-token"), + ui_url.contains("ws=ws%3A%2F%2F127.0.0.1%3A12345%2Fws"), "{ui_url}" ); } + #[test] + fn read_cookie_value_extracts_cookie_from_header() { + let mut headers = axum::http::HeaderMap::new(); + headers.insert( + axum::http::header::COOKIE, + "foo=bar; agent_tui_ws_token=secret-token" + .parse() + .expect("valid header"), + ); + let token = super::read_cookie_value(&headers, "agent_tui_ws_token"); + assert_eq!(token.as_deref(), Some("secret-token")); + } + #[tokio::test(flavor = "current_thread")] async fn cancel_stream_task_drops_receiver_to_release_backpressure() { let (tx, rx) = mpsc::channel::(1);