diff --git a/src/acp/pool.rs b/src/acp/pool.rs index e1d27bf..08b55a2 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -3,6 +3,7 @@ use crate::acp::protocol::ConfigOption; use crate::config::AgentConfig; use anyhow::{anyhow, Result}; use std::collections::HashMap; +use std::path::{Path, PathBuf}; use std::sync::Arc; use tokio::sync::{Mutex, RwLock}; use tokio::time::Instant; @@ -28,6 +29,7 @@ pub struct SessionPool { state: RwLock, config: AgentConfig, max_sessions: usize, + mapping_path: PathBuf, } type EvictionCandidate = ( @@ -63,15 +65,42 @@ fn get_or_insert_gate( impl SessionPool { pub fn new(config: AgentConfig, max_sessions: usize) -> Self { + let mapping_path = PathBuf::from(&config.working_dir).join("thread_map.json"); + let suspended = Self::load_mapping(&mapping_path); Self { state: RwLock::new(PoolState { active: HashMap::new(), cancel_handles: HashMap::new(), - suspended: HashMap::new(), + suspended, creating: HashMap::new(), }), config, max_sessions, + mapping_path, + } + } + + fn load_mapping(path: &Path) -> HashMap { + match std::fs::read_to_string(path) { + Ok(data) => serde_json::from_str(&data).unwrap_or_else(|e| { + warn!(path = %path.display(), error = %e, "corrupt thread_map.json, starting fresh"); + HashMap::new() + }), + Err(_) => HashMap::new(), + } + } + + fn save_mapping(&self, suspended: &HashMap) { + let data = match serde_json::to_string_pretty(suspended) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to serialize thread mapping"); + return; + } + }; + let tmp = self.mapping_path.with_extension("json.tmp"); + if let Err(e) = std::fs::write(&tmp, &data).and_then(|_| std::fs::rename(&tmp, &self.mapping_path)) { + warn!(path = %self.mapping_path.display(), error = %e, "failed to persist thread mapping"); } } @@ -213,6 +242,7 @@ impl SessionPool { state.suspended.remove(thread_id); state.active.insert(thread_id.to_string(), new_conn); + self.save_mapping(&state.suspended); if !cancel_session_id.is_empty() { state.cancel_handles.insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); } @@ -330,12 +360,26 @@ impl SessionPool { } } } + self.save_mapping(&state.suspended); } pub async fn shutdown(&self) { let mut state = self.state.write().await; + // Collect handles before borrowing suspended mutably. + let handles: Vec<(String, Arc>)> = state + .active + .iter() + .map(|(k, v)| (k.clone(), Arc::clone(v))) + .collect(); + for (key, conn) in handles { + let conn = conn.lock().await; + if let Some(sid) = conn.acp_session_id.clone() { + state.suspended.insert(key, sid); + } + } + self.save_mapping(&state.suspended); let count = state.active.len(); - state.active.clear(); // Drop impl kills process groups + state.active.clear(); info!(count, "pool shutdown complete"); } }