Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 66 additions & 15 deletions cli/crates/agent-tui-app/src/app/daemon/ws_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::extract::ws::WebSocketUpgrade;
use axum::extract::ws::close_code;
use axum::http::HeaderMap;
use axum::http::header::COOKIE;
use axum::response::Html;
use axum::response::IntoResponse;
use axum::response::Redirect;
Expand Down Expand Up @@ -162,7 +164,6 @@ struct WsState {
ws_queue_capacity: usize,
shutdown_rx: watch::Receiver<bool>,
auth_token: String,
ws_url: String,
}

pub(crate) fn start_ws_server(
Expand All @@ -176,10 +177,12 @@ pub(crate) fn start_ws_server(

let (listener, local_addr) = bind_listener(&config)?;
let auth_token = generate_ws_auth_token();
let ws_url = format_ws_url(&local_addr, &auth_token);
let ws_url = format_ws_url(&local_addr);
let ws_url_with_auth = format_ws_url_with_auth(&ws_url, &auth_token);
let ui_url = format_ui_url(&local_addr, &ws_url);
let listen_addr = local_addr.to_string();
if let Err(err) = write_state_file(&config.state_path, &ws_url, &ui_url, &listen_addr) {
if let Err(err) = write_state_file(&config.state_path, &ws_url_with_auth, &ui_url, &listen_addr)
{
warn!(error = %err, "Failed to write WS state file");
}

Expand All @@ -191,7 +194,6 @@ pub(crate) fn start_ws_server(
ws_queue_capacity: config.ws_queue_capacity,
shutdown_rx: shutdown_rx.clone(),
auth_token,
ws_url: ws_url.clone(),
});

let state_path = config.state_path.clone();
Expand Down Expand Up @@ -293,12 +295,19 @@ fn build_router(state: Arc<WsState>) -> axum::Router {
}

async fn ui_root_handler(State(state): State<Arc<WsState>>) -> Response {
let ws = encode_url_query_value(&state.ws_url);
Redirect::temporary(&format!("/ui?ws={ws}")).into_response()
(
[("set-cookie", ws_auth_cookie(&state.auth_token))],
Redirect::temporary("/ui"),
)
.into_response()
}

async fn ui_index_handler() -> Response {
Html(UI_INDEX_HTML).into_response()
async fn ui_index_handler(State(state): State<Arc<WsState>>) -> Response {
(
[("set-cookie", ws_auth_cookie(&state.auth_token))],
Html(UI_INDEX_HTML),
)
.into_response()
}

async fn ui_app_js_handler() -> Response {
Expand All @@ -325,9 +334,14 @@ struct WsAuthQuery {
async fn ws_handler(
State(state): State<Arc<WsState>>,
Query(query): Query<WsAuthQuery>,
headers: HeaderMap,
ws: WebSocketUpgrade,
) -> Response {
if query.token.as_deref() != Some(state.auth_token.as_str()) {
let query_token_matches = query.token.as_deref() == Some(state.auth_token.as_str());
let cookie_token_matches = read_cookie_value(&headers, "agent_tui_ws_token")
.as_deref()
.is_some_and(|token| token == state.auth_token.as_str());
if !query_token_matches && !cookie_token_matches {
let response = RpcResponse::error(0, -32001, "unauthorized");
return (
axum::http::StatusCode::UNAUTHORIZED,
Expand Down Expand Up @@ -596,12 +610,16 @@ fn bind_listener(config: &WsConfig) -> Result<(std::net::TcpListener, SocketAddr
Ok((listener, local_addr))
}

fn format_ws_url(addr: &SocketAddr, auth_token: &str) -> String {
fn format_ws_url(addr: &SocketAddr) -> String {
let host = match addr.ip() {
std::net::IpAddr::V4(ip) => ip.to_string(),
std::net::IpAddr::V6(ip) => format!("[{ip}]"),
};
format!("ws://{}:{}/ws?token={}", host, addr.port(), auth_token)
format!("ws://{}:{}/ws", host, addr.port())
}

fn format_ws_url_with_auth(ws_url: &str, auth_token: &str) -> String {
format!("{ws_url}?token={auth_token}")
}

fn format_ui_url(addr: &SocketAddr, ws_url: &str) -> String {
Expand All @@ -617,6 +635,20 @@ fn encode_url_query_value(value: &str) -> String {
url::form_urlencoded::byte_serialize(value.as_bytes()).collect()
}

fn ws_auth_cookie(auth_token: &str) -> String {
format!("agent_tui_ws_token={auth_token}; HttpOnly; SameSite=Strict; Path=/ws")
}

fn read_cookie_value(headers: &HeaderMap, key: &str) -> Option<String> {
let cookie_header = headers.get(COOKIE)?.to_str().ok()?;
cookie_header.split(';').find_map(|cookie| {
let mut parts = cookie.trim().splitn(2, '=');
let cookie_key = parts.next()?.trim();
let cookie_value = parts.next()?.trim();
(cookie_key == key).then(|| cookie_value.to_string())
})
}

fn generate_ws_auth_token() -> String {
Uuid::new_v4().simple().to_string()
}
Expand Down Expand Up @@ -772,23 +804,42 @@ mod tests {
}

#[test]
fn format_ws_url_embeds_auth_token() {
fn format_ws_url_does_not_embed_auth_token() {
let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().expect("valid addr");
let url = super::format_ws_url(&addr, "secret-token");
let url = super::format_ws_url(&addr);
assert_eq!(url, "ws://127.0.0.1:12345/ws");
}

#[test]
fn format_ws_url_with_auth_appends_token() {
let url = super::format_ws_url_with_auth("ws://127.0.0.1:12345/ws", "secret-token");
assert_eq!(url, "ws://127.0.0.1:12345/ws?token=secret-token");
}

#[test]
fn format_ui_url_embeds_encoded_ws_url() {
let addr: std::net::SocketAddr = "127.0.0.1:12345".parse().expect("valid addr");
let ws_url = "ws://127.0.0.1:12345/ws?token=secret-token";
let ws_url = "ws://127.0.0.1:12345/ws";
let ui_url = super::format_ui_url(&addr, ws_url);
assert!(
ui_url.contains("ws=ws%3A%2F%2F127.0.0.1%3A12345%2Fws%3Ftoken%3Dsecret-token"),
ui_url.contains("ws=ws%3A%2F%2F127.0.0.1%3A12345%2Fws"),
"{ui_url}"
);
}

#[test]
fn read_cookie_value_extracts_cookie_from_header() {
let mut headers = axum::http::HeaderMap::new();
headers.insert(
axum::http::header::COOKIE,
"foo=bar; agent_tui_ws_token=secret-token"
.parse()
.expect("valid header"),
);
let token = super::read_cookie_value(&headers, "agent_tui_ws_token");
assert_eq!(token.as_deref(), Some("secret-token"));
}

#[tokio::test(flavor = "current_thread")]
async fn cancel_stream_task_drops_receiver_to_release_backpressure() {
let (tx, rx) = mpsc::channel::<String>(1);
Expand Down
Loading