From dac5d2dd4240babfbb5b6d85d2820032e4ad4802 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 15:44:30 -0500 Subject: [PATCH 01/11] Update lms client --- codex-rs/lmstudio/src/client.rs | 415 ++++++++++++++++++++++++-------- 1 file changed, 319 insertions(+), 96 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index a2a8ee03bff..3862e61d226 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -1,7 +1,7 @@ use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::config::Config; use std::io; -use std::path::Path; +use std::io::Write; #[derive(Clone)] pub struct LMStudioClient { @@ -9,7 +9,7 @@ pub struct LMStudioClient { base_url: String, } -const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; +const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and start the LM Studio server."; impl LMStudioClient { pub async fn try_from_provider(config: &Config) -> std::io::Result { @@ -43,8 +43,20 @@ impl LMStudioClient { Ok(client) } + fn api_base_url(&self) -> String { + let base_url = self.base_url.trim_end_matches('/'); + let base_url = base_url + .strip_suffix("/api/v1") + .or_else(|| base_url.strip_suffix("/v1")) + .unwrap_or(base_url); + base_url.to_string() + } + async fn check_server(&self) -> io::Result<()> { - let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let url = format!( + "{base_url}/models", + base_url = self.base_url.trim_end_matches('/') + ); let response = self.client.get(&url).send().await; if let Ok(resp) = response { @@ -52,8 +64,8 @@ impl LMStudioClient { Ok(()) } else { Err(io::Error::other(format!( - "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}", - resp.status() + "Server returned error: {status} {LMSTUDIO_CONNECTION_ERROR}", + status = resp.status() ))) } } else { @@ -63,12 +75,11 @@ impl LMStudioClient { // Load a model by sending an empty request with max_tokens 1 pub async fn load_model(&self, model: &str) -> io::Result<()> { - let url = format!("{}/responses", self.base_url.trim_end_matches('/')); + let api_base_url = self.api_base_url(); + let url = format!("{api_base_url}/api/v1/models/load"); let request_body = serde_json::json!({ - "model": model, - "input": "", - "max_output_tokens": 1 + "model": model }); let response = self @@ -85,15 +96,18 @@ impl LMStudioClient { Ok(()) } else { Err(io::Error::other(format!( - "Failed to load model: {}", - response.status() + "Failed to load model: {status}", + status = response.status() ))) } } // Return the list of models available on the LM Studio server. pub async fn fetch_models(&self) -> io::Result> { - let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let url = format!( + "{base_url}/models", + base_url = self.base_url.trim_end_matches('/') + ); let response = self .client .get(&url) @@ -117,76 +131,164 @@ impl LMStudioClient { Ok(models) } else { Err(io::Error::other(format!( - "Failed to fetch models: {}", - response.status() + "Failed to fetch models: {status}", + status = response.status() ))) } } - // Find lms, checking fallback paths if not in PATH - fn find_lms() -> std::io::Result { - Self::find_lms_with_home_dir(None) - } + pub async fn download_model(&self, model: &str) -> std::io::Result<()> { + let api_base_url = self.api_base_url(); + let url = format!("{api_base_url}/api/v1/models/download"); - fn find_lms_with_home_dir(home_dir: Option<&str>) -> std::io::Result { - // First try 'lms' in PATH - if which::which("lms").is_ok() { - return Ok("lms".to_string()); - } + let request_body = serde_json::json!({ + "model": model + }); - // Platform-specific fallback paths - let home = match home_dir { - Some(dir) => dir.to_string(), - None => { - #[cfg(unix)] - { - std::env::var("HOME").unwrap_or_default() - } - #[cfg(windows)] - { - std::env::var("USERPROFILE").unwrap_or_default() - } - } - }; + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; - #[cfg(unix)] - let fallback_path = format!("{home}/.lmstudio/bin/lms"); + if !response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to download model: {status}", + status = response.status() + ))); + } - #[cfg(windows)] - let fallback_path = format!("{home}/.lmstudio/bin/lms.exe"); + let download_status = response.json::().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; - if Path::new(&fallback_path).exists() { - Ok(fallback_path) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::NotFound, - "LM Studio not found. Please install LM Studio from https://lmstudio.ai/", - )) - } - } + let parse_status = |json: &serde_json::Value| -> io::Result<( + String, + Option, + Option, + Option, + )> { + let status = json["status"] + .as_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing status"))? + .to_string(); + let job_id = json["job_id"].as_str().map(std::string::ToString::to_string); + let downloaded_bytes = json["downloaded_bytes"].as_u64(); + let total_size_bytes = json["total_size_bytes"].as_u64(); + Ok((status, job_id, downloaded_bytes, total_size_bytes)) + }; - pub async fn download_model(&self, model: &str) -> std::io::Result<()> { - let lms = Self::find_lms()?; - eprintln!("Downloading model: {model}"); - - let status = std::process::Command::new(&lms) - .args(["get", "--yes", model]) - .stdout(std::process::Stdio::inherit()) - .stderr(std::process::Stdio::null()) - .status() - .map_err(|e| { - std::io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}")) - })?; + let (status, job_id, _, _) = parse_status(&download_status)?; - if !status.success() { - return Err(std::io::Error::other(format!( - "Model download failed with exit code: {}", - status.code().unwrap_or(-1) - ))); + match status.as_str() { + "already_downloaded" | "completed" => { + tracing::info!("Model '{model}' is ready"); + Ok(()) + } + "failed" => Err(io::Error::other(format!( + "Model download failed for '{model}'" + ))), + "paused" => Err(io::Error::other(format!( + "Model download paused for '{model}'" + ))), + "downloading" => { + let job_id = job_id.as_deref().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "Download status missing job_id") + })?; + + let mut last_logged = + std::time::Instant::now() - std::time::Duration::from_secs(10); + + loop { + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let status_url = format!( + "{api_base_url}/api/v1/models/download/status/{job_id}", + api_base_url = api_base_url + ); + + let status_response = self + .client + .get(&status_url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + + if !status_response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to fetch download status: {status}", + status = status_response.status() + ))); + } + + let status = + status_response + .json::() + .await + .map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("JSON parse error: {e}"), + ) + })?; + let (status_value, _, downloaded_bytes, total_size_bytes) = + parse_status(&status)?; + + match status_value.as_str() { + "completed" => { + eprintln!(); + tracing::info!("Successfully downloaded model '{model}'"); + return Ok(()); + } + "failed" => { + eprintln!(); + return Err(io::Error::other(format!( + "Model download failed for '{model}'" + ))); + } + "paused" => { + eprintln!(); + return Err(io::Error::other(format!( + "Model download paused for '{model}'" + ))); + } + "downloading" => { + if let (Some(downloaded), Some(total)) = + (downloaded_bytes, total_size_bytes) + { + let now = std::time::Instant::now(); + if now.duration_since(last_logged) + >= std::time::Duration::from_millis(500) + { + let percent = (downloaded as f64 / total as f64) * 100.0; + let downloaded_mb = downloaded as f64 / (1024.0 * 1024.0); + let total_gb = total as f64 / (1024.0 * 1024.0 * 1024.0); + eprint!( + "\rDownloading '{model}': {downloaded_mb:.2} MB / {total_gb:.2} GB ({percent:.1}%)", + downloaded_mb = downloaded_mb, + total_gb = total_gb, + percent = percent + ); + let _ = std::io::stderr().flush(); + last_logged = now; + } + } + } + status_value => { + eprintln!(); + return Err(io::Error::other(format!( + "Unknown download status '{status_value}' for '{model}'" + ))); + } + } + } + } + status_value => Err(io::Error::other(format!( + "Unknown download status '{status_value}' for '{model}'" + ))), } - - tracing::info!("Successfully downloaded model '{model}'"); - Ok(()) } /// Low-level constructor given a raw host root, e.g. "http://localhost:1234". @@ -351,39 +453,160 @@ mod tests { ); } - #[test] - fn test_find_lms() { - let result = LMStudioClient::find_lms(); + #[tokio::test] + async fn test_load_model_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } - match result { - Ok(_) => { - // lms was found in PATH - that's fine - } - Err(e) => { - // Expected error when LM Studio not installed - assert!(e.to_string().contains("LM Studio not found")); - } + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/v1/models/load")) + .respond_with(wiremock::ResponseTemplate::new(200)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + client + .load_model("openai/gpt-oss-20b") + .await + .expect("load model"); + } + + #[tokio::test] + async fn test_load_model_error() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_error", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/v1/models/load")) + .respond_with(wiremock::ResponseTemplate::new(500)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + let result = client.load_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to load model: 500") + ); } - #[test] - fn test_find_lms_with_mock_home() { - // Test fallback path construction without touching env vars - #[cfg(unix)] - { - let result = LMStudioClient::find_lms_with_home_dir(Some("/test/home")); - if let Err(e) = result { - assert!(e.to_string().contains("LM Studio not found")); - } + #[tokio::test] + async fn test_download_model_happy_path() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_happy_path", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } - #[cfg(windows)] - { - let result = LMStudioClient::find_lms_with_home_dir(Some("C:\\test\\home")); - if let Err(e) = result { - assert!(e.to_string().contains("LM Studio not found")); - } + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/v1/models/download")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "downloading" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path( + "/api/v1/models/download/status/job-1", + )) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "completed" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + client + .download_model("openai/gpt-oss-20b") + .await + .expect("download model"); + } + + #[tokio::test] + async fn test_download_model_error() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_error", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/v1/models/download")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "downloading" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path( + "/api/v1/models/download/status/job-1", + )) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "failed" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + let result = client.download_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Model download failed") + ); } #[test] From 5bc500bde48789bffad0e68eaafc8df17e973ccc Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 17:23:12 -0500 Subject: [PATCH 02/11] Add is model loaded --- codex-rs/lmstudio/src/client.rs | 135 +++++++++++++++++++++++++++++--- 1 file changed, 126 insertions(+), 9 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 3862e61d226..0222dee8bd8 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -73,13 +73,49 @@ impl LMStudioClient { } } + async fn is_model_loaded(&self, model: &str) -> bool { + let api_base_url = self.api_base_url(); + let models_url = format!("{api_base_url}/api/v1/models"); + let Ok(response) = self.client.get(&models_url).send().await else { + return false; + }; + if !response.status().is_success() { + return false; + } + let Ok(json) = response.json::().await else { + return false; + }; + let models = json + .get("models") + .or_else(|| json.get("data")) + .and_then(|value| value.as_array()); + models.is_some_and(|entries| { + entries.iter().any(|entry| { + let key = entry + .get("key") + .and_then(|value| value.as_str()) + .or_else(|| entry.get("id").and_then(|value| value.as_str())); + let loaded_instances = entry + .get("loaded_instances") + .and_then(|value| value.as_array()); + key == Some(model) + && loaded_instances.is_some_and(|instances| !instances.is_empty()) + }) + }) + } + // Load a model by sending an empty request with max_tokens 1 pub async fn load_model(&self, model: &str) -> io::Result<()> { let api_base_url = self.api_base_url(); + if self.is_model_loaded(model).await { + tracing::info!("Model '{model}' already loaded; reusing existing instance"); + return Ok(()); + } let url = format!("{api_base_url}/api/v1/models/load"); let request_body = serde_json::json!({ - "model": model + "model": model, + "context_length": 7000 }); let response = self @@ -204,10 +240,8 @@ impl LMStudioClient { loop { tokio::time::sleep(std::time::Duration::from_secs(2)).await; - let status_url = format!( - "{api_base_url}/api/v1/models/download/status/{job_id}", - api_base_url = api_base_url - ); + let status_url = + format!("{api_base_url}/api/v1/models/download/status/{job_id}"); let status_response = self .client @@ -266,10 +300,7 @@ impl LMStudioClient { let downloaded_mb = downloaded as f64 / (1024.0 * 1024.0); let total_gb = total as f64 / (1024.0 * 1024.0 * 1024.0); eprint!( - "\rDownloading '{model}': {downloaded_mb:.2} MB / {total_gb:.2} GB ({percent:.1}%)", - downloaded_mb = downloaded_mb, - total_gb = total_gb, - percent = percent + "\rDownloading '{model}': {downloaded_mb:.2} MB / {total_gb:.2} GB ({percent:.1}%)" ); let _ = std::io::stderr().flush(); last_logged = now; @@ -477,6 +508,92 @@ mod tests { .expect("load model"); } + #[tokio::test] + async fn test_load_model_reuses_loaded_instance() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_reuses_loaded_instance", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/v1/models")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "openai/gpt-oss-20b", + "loaded_instances": [ + { + "config": { + "context_length": 7000 + } + } + ] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + wiremock::Mock::given(wiremock::matchers::method("POST")) + .and(wiremock::matchers::path("/api/v1/models/load")) + .respond_with(wiremock::ResponseTemplate::new(500)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.load_model("openai/gpt-oss-20b").await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_is_model_loaded_true() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_is_model_loaded_true", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/v1/models")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "openai/gpt-oss-20b", + "loaded_instances": [ + { + "config": { + "context_length": 7000 + } + } + ] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + assert!(client.is_model_loaded("openai/gpt-oss-20b").await); + } + #[tokio::test] async fn test_load_model_error() { if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { From 432e071727f546046c7e5ca380cb2255013d80e6 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 17:49:37 -0500 Subject: [PATCH 03/11] Cleanup cleanup 2 tweak test default context length cleanup 3 --- codex-rs/lmstudio/src/client.rs | 124 +++++++++++++++++++------------- 1 file changed, 75 insertions(+), 49 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 0222dee8bd8..3c96c546e84 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -9,7 +9,9 @@ pub struct LMStudioClient { base_url: String, } -const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and start the LM Studio server."; +const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; +// 8192 tokens provides headroom above the observed ~6000 token initial input. +const DEFAULT_CONTEXT_LENGTH: u32 = 8192; impl LMStudioClient { pub async fn try_from_provider(config: &Config) -> std::io::Result { @@ -43,20 +45,13 @@ impl LMStudioClient { Ok(client) } - fn api_base_url(&self) -> String { + fn host_root(&self) -> String { let base_url = self.base_url.trim_end_matches('/'); - let base_url = base_url - .strip_suffix("/api/v1") - .or_else(|| base_url.strip_suffix("/v1")) - .unwrap_or(base_url); - base_url.to_string() + base_url.strip_suffix("/v1").unwrap_or(base_url).to_string() } async fn check_server(&self) -> io::Result<()> { - let url = format!( - "{base_url}/models", - base_url = self.base_url.trim_end_matches('/') - ); + let url = format!("{}/v1/models", self.host_root()); let response = self.client.get(&url).send().await; if let Ok(resp) = response { @@ -73,9 +68,9 @@ impl LMStudioClient { } } + // Check if a model is already loaded with the same key async fn is_model_loaded(&self, model: &str) -> bool { - let api_base_url = self.api_base_url(); - let models_url = format!("{api_base_url}/api/v1/models"); + let models_url = format!("{}/api/v1/models", self.host_root()); let Ok(response) = self.client.get(&models_url).send().await else { return false; }; @@ -85,37 +80,34 @@ impl LMStudioClient { let Ok(json) = response.json::().await else { return false; }; - let models = json - .get("models") - .or_else(|| json.get("data")) - .and_then(|value| value.as_array()); + let models = json.get("models").and_then(|value| value.as_array()); models.is_some_and(|entries| { entries.iter().any(|entry| { - let key = entry - .get("key") - .and_then(|value| value.as_str()) - .or_else(|| entry.get("id").and_then(|value| value.as_str())); let loaded_instances = entry .get("loaded_instances") .and_then(|value| value.as_array()); - key == Some(model) - && loaded_instances.is_some_and(|instances| !instances.is_empty()) + // A model is considered loaded if any of its loaded_instances + // shares the same id as the requested model. + loaded_instances.is_some_and(|instances| { + instances.iter().any(|instance| { + instance.get("id").and_then(|value| value.as_str()) == Some(model) + }) + }) }) }) } // Load a model by sending an empty request with max_tokens 1 pub async fn load_model(&self, model: &str) -> io::Result<()> { - let api_base_url = self.api_base_url(); if self.is_model_loaded(model).await { tracing::info!("Model '{model}' already loaded; reusing existing instance"); return Ok(()); } - let url = format!("{api_base_url}/api/v1/models/load"); + let url = format!("{}/api/v1/models/load", self.host_root()); let request_body = serde_json::json!({ "model": model, - "context_length": 7000 + "context_length": DEFAULT_CONTEXT_LENGTH }); let response = self @@ -140,10 +132,7 @@ impl LMStudioClient { // Return the list of models available on the LM Studio server. pub async fn fetch_models(&self) -> io::Result> { - let url = format!( - "{base_url}/models", - base_url = self.base_url.trim_end_matches('/') - ); + let url = format!("{}/v1/models", self.host_root()); let response = self .client .get(&url) @@ -174,8 +163,7 @@ impl LMStudioClient { } pub async fn download_model(&self, model: &str) -> std::io::Result<()> { - let api_base_url = self.api_base_url(); - let url = format!("{api_base_url}/api/v1/models/download"); + let url = format!("{}/api/v1/models/download", self.host_root()); let request_body = serde_json::json!({ "model": model @@ -240,8 +228,10 @@ impl LMStudioClient { loop { tokio::time::sleep(std::time::Duration::from_secs(2)).await; - let status_url = - format!("{api_base_url}/api/v1/models/download/status/{job_id}"); + let status_url = format!( + "{}/api/v1/models/download/status/{job_id}", + self.host_root() + ); let status_response = self .client @@ -353,12 +343,12 @@ mod tests { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + .and(wiremock::matchers::path("/v1/models")) .respond_with( wiremock::ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "data": [ - {"id": "openai/gpt-oss-20b"}, + {"id": "test/test-model"}, ] }) .to_string(), @@ -370,7 +360,7 @@ mod tests { let client = LMStudioClient::from_host_root(server.uri()); let models = client.fetch_models().await.expect("fetch models"); - assert!(models.contains(&"openai/gpt-oss-20b".to_string())); + assert!(models.contains(&"test/test-model".to_string())); } #[tokio::test] @@ -385,7 +375,7 @@ mod tests { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + .and(wiremock::matchers::path("/v1/models")) .respond_with( wiremock::ResponseTemplate::new(200) .set_body_raw(serde_json::json!({}).to_string(), "application/json"), @@ -416,7 +406,7 @@ mod tests { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + .and(wiremock::matchers::path("/v1/models")) .respond_with(wiremock::ResponseTemplate::new(500)) .mount(&server) .await; @@ -444,7 +434,7 @@ mod tests { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + .and(wiremock::matchers::path("/v1/models")) .respond_with(wiremock::ResponseTemplate::new(200)) .mount(&server) .await; @@ -468,7 +458,7 @@ mod tests { let server = wiremock::MockServer::start().await; wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + .and(wiremock::matchers::path("/v1/models")) .respond_with(wiremock::ResponseTemplate::new(404)) .mount(&server) .await; @@ -503,7 +493,7 @@ mod tests { let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); client - .load_model("openai/gpt-oss-20b") + .load_model("test/test-model") .await .expect("load model"); } @@ -526,9 +516,10 @@ mod tests { serde_json::json!({ "models": [ { - "key": "openai/gpt-oss-20b", + "key": "test/test-model", "loaded_instances": [ { + "id": "test/test-model", "config": { "context_length": 7000 } @@ -550,7 +541,7 @@ mod tests { .await; let client = LMStudioClient::from_host_root(server.uri()); - let result = client.load_model("openai/gpt-oss-20b").await; + let result = client.load_model("test/test-model").await; assert!(result.is_ok()); } @@ -572,9 +563,10 @@ mod tests { serde_json::json!({ "models": [ { - "key": "openai/gpt-oss-20b", + "key": "test/test-model", "loaded_instances": [ { + "id": "test/test-model", "config": { "context_length": 7000 } @@ -591,7 +583,41 @@ mod tests { .await; let client = LMStudioClient::from_host_root(server.uri()); - assert!(client.is_model_loaded("openai/gpt-oss-20b").await); + assert!(client.is_model_loaded("test/test-model").await); + } + + #[tokio::test] + async fn test_is_model_loaded_false() { + if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_is_model_loaded_false", + codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = wiremock::MockServer::start().await; + wiremock::Mock::given(wiremock::matchers::method("GET")) + .and(wiremock::matchers::path("/api/v1/models")) + .respond_with( + wiremock::ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": [] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + assert!(!client.is_model_loaded("test/test-model").await); } #[tokio::test] @@ -612,7 +638,7 @@ mod tests { .await; let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); - let result = client.load_model("openai/gpt-oss-20b").await; + let result = client.load_model("test/test-model").await; assert!(result.is_err()); assert!( result @@ -667,7 +693,7 @@ mod tests { let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); client - .download_model("openai/gpt-oss-20b") + .download_model("test/test-model") .await .expect("download model"); } @@ -716,7 +742,7 @@ mod tests { .await; let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); - let result = client.download_model("openai/gpt-oss-20b").await; + let result = client.download_model("test/test-model").await; assert!(result.is_err()); assert!( result From bd9d37c6fdf5bcc3494667168d2d028a6c21772e Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 18:30:32 -0500 Subject: [PATCH 04/11] Remove unnecesarry diffs --- codex-rs/lmstudio/src/client.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 3c96c546e84..17b06865c93 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -59,8 +59,8 @@ impl LMStudioClient { Ok(()) } else { Err(io::Error::other(format!( - "Server returned error: {status} {LMSTUDIO_CONNECTION_ERROR}", - status = resp.status() + "Server returned error: {} {LMSTUDIO_CONNECTION_ERROR}", + resp.status() ))) } } else { @@ -348,7 +348,7 @@ mod tests { wiremock::ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "data": [ - {"id": "test/test-model"}, + {"id": "openai/gpt-oss-20b"}, ] }) .to_string(), @@ -360,7 +360,7 @@ mod tests { let client = LMStudioClient::from_host_root(server.uri()); let models = client.fetch_models().await.expect("fetch models"); - assert!(models.contains(&"test/test-model".to_string())); + assert!(models.contains(&"openai/gpt-oss-20b".to_string())); } #[tokio::test] @@ -493,7 +493,7 @@ mod tests { let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); client - .load_model("test/test-model") + .load_model("openai/gpt-oss-20b") .await .expect("load model"); } @@ -638,7 +638,7 @@ mod tests { .await; let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); - let result = client.load_model("test/test-model").await; + let result = client.load_model("openai/gpt-oss-20b").await; assert!(result.is_err()); assert!( result @@ -693,7 +693,7 @@ mod tests { let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); client - .download_model("test/test-model") + .download_model("openai/gpt-oss-20b") .await .expect("download model"); } @@ -742,7 +742,7 @@ mod tests { .await; let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); - let result = client.download_model("test/test-model").await; + let result = client.download_model("openai/gpt-oss-20b").await; assert!(result.is_err()); assert!( result From 8ad7c35e3e78af48cb041da642f4ce34642c6663 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 18:34:38 -0500 Subject: [PATCH 05/11] Fix warning tewt --- codex-rs/lmstudio/src/client.rs | 42 ++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 17b06865c93..00e340d895b 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -124,8 +124,8 @@ impl LMStudioClient { Ok(()) } else { Err(io::Error::other(format!( - "Failed to load model: {status}", - status = response.status() + "Failed to load model: {}", + response.status() ))) } } @@ -189,17 +189,16 @@ impl LMStudioClient { io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) })?; - let parse_status = |json: &serde_json::Value| -> io::Result<( - String, - Option, - Option, - Option, - )> { + type DownloadStatus = (String, Option, Option, Option); + + let parse_status = |json: &serde_json::Value| -> io::Result { let status = json["status"] .as_str() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing status"))? .to_string(); - let job_id = json["job_id"].as_str().map(std::string::ToString::to_string); + let job_id = json["job_id"] + .as_str() + .map(std::string::ToString::to_string); let downloaded_bytes = json["downloaded_bytes"].as_u64(); let total_size_bytes = json["total_size_bytes"].as_u64(); Ok((status, job_id, downloaded_bytes, total_size_bytes)) @@ -242,8 +241,8 @@ impl LMStudioClient { if !status_response.status().is_success() { return Err(io::Error::other(format!( - "Failed to fetch download status: {status}", - status = status_response.status() + "Failed to fetch download status: {}", + status_response.status() ))); } @@ -287,10 +286,10 @@ impl LMStudioClient { >= std::time::Duration::from_millis(500) { let percent = (downloaded as f64 / total as f64) * 100.0; - let downloaded_mb = downloaded as f64 / (1024.0 * 1024.0); - let total_gb = total as f64 / (1024.0 * 1024.0 * 1024.0); eprint!( - "\rDownloading '{model}': {downloaded_mb:.2} MB / {total_gb:.2} GB ({percent:.1}%)" + "\rDownloading '{model}': {} / {} ({percent:.1}%)", + format_bytes(downloaded), + format_bytes(total) ); let _ = std::io::stderr().flush(); last_logged = now; @@ -326,6 +325,21 @@ impl LMStudioClient { } } +fn format_bytes(bytes: u64) -> String { + const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB"]; + let mut size = bytes as f64; + let mut i = 0; + while size >= 1024.0 && i < UNITS.len() - 1 { + size /= 1024.0; + i += 1; + } + if i == 0 { + format!("{size} B") + } else { + format!("{size:.2} {}", UNITS[i]) + } +} + #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] From a7120081998885bcb748cd833a658cfcaa005175 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 18:52:49 -0500 Subject: [PATCH 06/11] Reduce diff --- codex-rs/lmstudio/src/client.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 00e340d895b..0ad2ae80cbc 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -156,8 +156,8 @@ impl LMStudioClient { Ok(models) } else { Err(io::Error::other(format!( - "Failed to fetch models: {status}", - status = response.status() + "Failed to fetch models: {}", + response.status() ))) } } @@ -180,8 +180,8 @@ impl LMStudioClient { if !response.status().is_success() { return Err(io::Error::other(format!( - "Failed to download model: {status}", - status = response.status() + "Failed to download model: {}", + response.status() ))); } From 57470993a65d75fdb08ea31a0b58a6a4e6825461 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Mon, 23 Feb 2026 19:05:16 -0500 Subject: [PATCH 07/11] Switch to api/v1/models --- codex-rs/Cargo.lock | 4 +- codex-rs/exec/src/lib.rs | 16 +- codex-rs/lmstudio/Cargo.toml | 6 +- codex-rs/lmstudio/src/client.rs | 639 +++++++++++++++++++++++++------- codex-rs/lmstudio/src/lib.rs | 1 + codex-rs/tui/src/lib.rs | 16 +- codex-rs/utils/oss/src/lib.rs | 17 + 7 files changed, 560 insertions(+), 139 deletions(-) diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index a68f99425cb..608ab49b968 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2032,11 +2032,13 @@ name = "codex-lmstudio" version = "0.0.0" dependencies = [ "codex-core", + "codex-protocol", + "pretty_assertions", "reqwest", + "serde", "serde_json", "tokio", "tracing", - "which", "wiremock", ] diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index e8ab6f8af74..11b602e6db8 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -47,6 +47,7 @@ use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_oss::ensure_oss_provider_ready; +use codex_utils_oss::fetch_oss_model_catalog; use codex_utils_oss::get_default_model_for_oss_provider; use event_processor_with_human_output::EventProcessorWithHumanOutput; use event_processor_with_jsonl_output::EventProcessorWithJsonOutput; @@ -289,7 +290,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result additional_writable_roots: add_dir, }; - let config = ConfigBuilder::default() + let mut config = ConfigBuilder::default() .cli_overrides(cli_kv_overrides) .harness_overrides(overrides) .cloud_requirements(cloud_requirements) @@ -373,6 +374,19 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?; } + if config.model_catalog.is_none() { + let provider_id = config.model_provider_id.as_str(); + match fetch_oss_model_catalog(provider_id, &config).await { + Ok(Some(catalog)) => { + config.model_catalog = Some(catalog); + } + Ok(None) => {} + Err(err) => { + warn!("Failed to fetch OSS model catalog for {provider_id}: {err}"); + } + } + } + let default_cwd = config.cwd.to_path_buf(); let default_approval_policy = config.permissions.approval_policy.value(); let default_sandbox_policy = config.permissions.sandbox_policy.get(); diff --git a/codex-rs/lmstudio/Cargo.toml b/codex-rs/lmstudio/Cargo.toml index 5f4849638ae..705be1870ef 100644 --- a/codex-rs/lmstudio/Cargo.toml +++ b/codex-rs/lmstudio/Cargo.toml @@ -11,13 +11,15 @@ path = "src/lib.rs" [dependencies] codex-core = { path = "../core" } +codex-protocol = { path = "../protocol" } reqwest = { version = "0.12", features = ["json", "stream"] } +serde = { workspace = true, features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["rt"] } +tokio = { version = "1", features = ["rt", "time"] } tracing = { version = "0.1.44", features = ["log"] } -which = "8.0" [dev-dependencies] +pretty_assertions = { workspace = true } wiremock = "0.6" tokio = { version = "1", features = ["full"] } diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 0ad2ae80cbc..5bd989325b0 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -1,7 +1,18 @@ use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::config::Config; +use codex_core::models_manager::model_info::BASE_INSTRUCTIONS; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::openai_models::ApplyPatchToolType; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::TruncationPolicyConfig; +use serde::Deserialize; use std::io; use std::io::Write; +use std::time::Duration; +use std::time::Instant; #[derive(Clone)] pub struct LMStudioClient { @@ -14,7 +25,7 @@ const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install fr const DEFAULT_CONTEXT_LENGTH: u32 = 8192; impl LMStudioClient { - pub async fn try_from_provider(config: &Config) -> std::io::Result { + pub async fn try_from_provider(config: &Config) -> io::Result { let provider = config .model_providers .get(LMSTUDIO_OSS_PROVIDER_ID) @@ -32,7 +43,7 @@ impl LMStudioClient { })?; let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) .build() .unwrap_or_else(|_| reqwest::Client::new()); @@ -51,7 +62,7 @@ impl LMStudioClient { } async fn check_server(&self) -> io::Result<()> { - let url = format!("{}/v1/models", self.host_root()); + let url = format!("{}/api/v1/models", self.host_root()); let response = self.client.get(&url).send().await; if let Ok(resp) = response { @@ -132,7 +143,7 @@ impl LMStudioClient { // Return the list of models available on the LM Studio server. pub async fn fetch_models(&self) -> io::Result> { - let url = format!("{}/v1/models", self.host_root()); + let url = format!("{}/api/v1/models", self.host_root()); let response = self .client .get(&url) @@ -144,14 +155,14 @@ impl LMStudioClient { let json: serde_json::Value = response.json().await.map_err(|e| { io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) })?; - let models = json["data"] + let models = json["models"] .as_array() .ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response") + io::Error::new(io::ErrorKind::InvalidData, "No 'models' array in response") })? .iter() - .filter_map(|model| model["id"].as_str()) - .map(std::string::ToString::to_string) + .filter_map(|model| model["key"].as_str()) + .map(String::from) .collect(); Ok(models) } else { @@ -162,7 +173,104 @@ impl LMStudioClient { } } - pub async fn download_model(&self, model: &str) -> std::io::Result<()> { + /// Return model metadata from the LM Studio server. + pub async fn fetch_model_metadata(&self) -> io::Result> { + let url = format!("{}/api/v1/models", self.host_root()); + let response = self + .client + .get(&url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + + if !response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to fetch models: {}", + response.status() + ))); + } + + let json: LMStudioModelsResponse = response.json().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; + + let models = json + .models + .into_iter() + .filter(|model| matches!(model.model_type.as_deref(), None | Some("llm"))) + .enumerate() + .map(|(index, model)| { + let context_window = model + .loaded_instances + .as_ref() + .and_then(|instances| { + instances + .iter() + .filter_map(|instance| { + instance + .config + .as_ref() + .and_then(|config| config.context_length) + }) + .max() + }) + .or(model.max_context_length); + let supports_vision = model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.vision) + .unwrap_or(false); + let trained_for_tool_use = model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.trained_for_tool_use) + .unwrap_or(false); + let input_modalities = if supports_vision { + vec![InputModality::Text, InputModality::Image] + } else { + vec![InputModality::Text] + }; + + ModelInfo { + slug: model.key.clone(), + display_name: model + .display_name + .clone() + .unwrap_or_else(|| model.key.clone()), + description: model.description, + default_reasoning_level: None, // TODO: Update after we surface it in the API + supported_reasoning_levels: Vec::new(), // TODO: Update after we surface it in the API + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: i32::try_from(index).unwrap_or(i32::MAX), + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::None, // Intentional as we don't support summary + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: trained_for_tool_use + .then_some(ApplyPatchToolType::Function), + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window, + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities, + prefer_websockets: false, + used_fallback_model_metadata: false, + } + }) + .collect(); + + Ok(models) + } + + pub async fn download_model(&self, model: &str) -> io::Result<()> { let url = format!("{}/api/v1/models/download", self.host_root()); let request_body = serde_json::json!({ @@ -189,22 +297,9 @@ impl LMStudioClient { io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) })?; - type DownloadStatus = (String, Option, Option, Option); - - let parse_status = |json: &serde_json::Value| -> io::Result { - let status = json["status"] - .as_str() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing status"))? - .to_string(); - let job_id = json["job_id"] - .as_str() - .map(std::string::ToString::to_string); - let downloaded_bytes = json["downloaded_bytes"].as_u64(); - let total_size_bytes = json["total_size_bytes"].as_u64(); - Ok((status, job_id, downloaded_bytes, total_size_bytes)) - }; - - let (status, job_id, _, _) = parse_status(&download_status)?; + let initial = parse_download_status(&download_status)?; + let status = initial.status; + let job_id = initial.job_id; match status.as_str() { "already_downloaded" | "completed" => { @@ -222,11 +317,18 @@ impl LMStudioClient { io::Error::new(io::ErrorKind::InvalidData, "Download status missing job_id") })?; - let mut last_logged = - std::time::Instant::now() - std::time::Duration::from_secs(10); + let mut last_logged = Instant::now() - Duration::from_secs(10); + let mut attempts = 0u32; loop { - tokio::time::sleep(std::time::Duration::from_secs(2)).await; + tokio::time::sleep(DOWNLOAD_POLL_INTERVAL).await; + attempts += 1; + if attempts > MAX_DOWNLOAD_POLL_ATTEMPTS { + eprintln!(); + return Err(io::Error::other(format!( + "Timed out waiting for model '{model}' to download" + ))); + } let status_url = format!( "{}/api/v1/models/download/status/{job_id}", self.host_root() @@ -256,8 +358,10 @@ impl LMStudioClient { format!("JSON parse error: {e}"), ) })?; - let (status_value, _, downloaded_bytes, total_size_bytes) = - parse_status(&status)?; + let poll = parse_download_status(&status)?; + let status_value = poll.status; + let downloaded_bytes = poll.downloaded_bytes; + let total_size_bytes = poll.total_size_bytes; match status_value.as_str() { "completed" => { @@ -278,20 +382,23 @@ impl LMStudioClient { ))); } "downloading" => { - if let (Some(downloaded), Some(total)) = - (downloaded_bytes, total_size_bytes) - { - let now = std::time::Instant::now(); - if now.duration_since(last_logged) - >= std::time::Duration::from_millis(500) - { - let percent = (downloaded as f64 / total as f64) * 100.0; - eprint!( - "\rDownloading '{model}': {} / {} ({percent:.1}%)", - format_bytes(downloaded), - format_bytes(total) - ); - let _ = std::io::stderr().flush(); + if let Some(downloaded) = downloaded_bytes { + let now = Instant::now(); + if now.duration_since(last_logged) >= Duration::from_millis(500) { + if let Some(total) = total_size_bytes { + let percent = (downloaded as f64 / total as f64) * 100.0; + eprint!( + "\rDownloading '{model}': {} / {} ({percent:.1}%)", + format_bytes(downloaded), + format_bytes(total) + ); + } else { + eprint!( + "\rDownloading '{model}': {}", + format_bytes(downloaded) + ); + } + let _ = io::stderr().flush(); last_logged = now; } } @@ -315,7 +422,7 @@ impl LMStudioClient { #[cfg(test)] fn from_host_root(host_root: impl Into) -> Self { let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) .build() .unwrap_or_else(|_| reqwest::Client::new()); Self { @@ -325,6 +432,72 @@ impl LMStudioClient { } } +#[derive(Debug, Deserialize)] +struct LMStudioModelsResponse { + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct LMStudioModel { + key: String, + display_name: Option, + description: Option, + #[serde(rename = "type")] + model_type: Option, + max_context_length: Option, + capabilities: Option, + loaded_instances: Option>, +} + +#[derive(Debug, Deserialize)] +struct LMStudioCapabilities { + vision: Option, + trained_for_tool_use: Option, +} + +#[derive(Debug, Deserialize)] +struct LMStudioLoadedInstance { + id: Option, + config: Option, +} + +#[derive(Debug, Deserialize)] +struct LMStudioInstanceConfig { + context_length: Option, +} + +// Poll every 2 seconds in production; use a short interval in tests to avoid slowness. +#[cfg(not(test))] +const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_secs(2); +#[cfg(test)] +const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); + +// Allow ~2 hours of polling (at 2s intervals) before giving up. +const MAX_DOWNLOAD_POLL_ATTEMPTS: u32 = 3600; + +struct DownloadStatusResponse { + status: String, + job_id: Option, + downloaded_bytes: Option, + total_size_bytes: Option, +} + +fn parse_download_status(json: &serde_json::Value) -> io::Result { + let status = json["status"] + .as_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing status"))? + .to_string(); + let job_id = json["job_id"].as_str().map(String::from); + let downloaded_bytes = json["downloaded_bytes"].as_u64(); + let total_size_bytes = json["total_size_bytes"].as_u64(); + Ok(DownloadStatusResponse { + status, + job_id, + downloaded_bytes, + total_size_bytes, + }) +} + fn format_bytes(bytes: u64) -> String { const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB"]; let mut size = bytes as f64; @@ -344,26 +517,30 @@ fn format_bytes(bytes: u64) -> String { mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] use super::*; + use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; + use pretty_assertions::assert_eq; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[tokio::test] async fn test_fetch_models_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/v1/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ - "data": [ - {"id": "openai/gpt-oss-20b"}, - ] + "models": [ + {"key": "openai/gpt-oss-20b"}, + ], }) .to_string(), "application/json", @@ -379,19 +556,19 @@ mod tests { #[tokio::test] async fn test_fetch_models_no_data_array() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_no_data_array", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/v1/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200) + ResponseTemplate::new(200) .set_body_raw(serde_json::json!({}).to_string(), "application/json"), ) .mount(&server) @@ -404,24 +581,24 @@ mod tests { result .unwrap_err() .to_string() - .contains("No 'data' array in response") + .contains("No 'models' array in response") ); } #[tokio::test] async fn test_fetch_models_server_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_server_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/v1/models")) - .respond_with(wiremock::ResponseTemplate::new(500)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(500)) .mount(&server) .await; @@ -436,20 +613,142 @@ mod tests { ); } + #[tokio::test] + async fn test_fetch_model_metadata_filters_and_maps_fields() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_model_metadata_filters_and_maps_fields", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "openai/gpt-oss-20b", + "display_name": "GPT-OSS-20B (LM Studio)", + "description": "OSS model", + "type": "llm", + "max_context_length": 100_000, + "capabilities": { + "vision": true, + "trained_for_tool_use": true + }, + "loaded_instances": [ + { + "config": { + "context_length": 90_000 + } + } + ] + }, + { + "key": "embed/text", + "display_name": "Embedding", + "type": "embedding", + "max_context_length": 1024 + }, + { + "key": "llm/second", + "type": "llm", + "max_context_length": 4096 + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let models = client.fetch_model_metadata().await.expect("fetch metadata"); + + let expected = vec![ + ModelInfo { + slug: "openai/gpt-oss-20b".to_string(), + display_name: "GPT-OSS-20B (LM Studio)".to_string(), + description: Some("OSS model".to_string()), + default_reasoning_level: None, + supported_reasoning_levels: Vec::new(), + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: 0, + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: Some(ApplyPatchToolType::Function), + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(90_000), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text, InputModality::Image], + prefer_websockets: false, + used_fallback_model_metadata: false, + }, + ModelInfo { + slug: "llm/second".to_string(), + display_name: "llm/second".to_string(), + description: None, + default_reasoning_level: None, + supported_reasoning_levels: Vec::new(), + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: 1, + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(4096), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text], + prefer_websockets: false, + used_fallback_model_metadata: false, + }, + ]; + + assert_eq!(models, expected); + } + #[tokio::test] async fn test_check_server_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_check_server_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/v1/models")) - .respond_with(wiremock::ResponseTemplate::new(200)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200)) .mount(&server) .await; @@ -462,18 +761,18 @@ mod tests { #[tokio::test] async fn test_check_server_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_check_server_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/v1/models")) - .respond_with(wiremock::ResponseTemplate::new(404)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(404)) .mount(&server) .await; @@ -490,18 +789,26 @@ mod tests { #[tokio::test] async fn test_load_model_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_load_model_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("POST")) - .and(wiremock::matchers::path("/api/v1/models/load")) - .respond_with(wiremock::ResponseTemplate::new(200)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ "models": [] }).to_string(), + "application/json", + )) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(200)) .mount(&server) .await; @@ -514,19 +821,19 @@ mod tests { #[tokio::test] async fn test_load_model_reuses_loaded_instance() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_load_model_reuses_loaded_instance", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/api/v1/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "models": [ { @@ -548,9 +855,9 @@ mod tests { ) .mount(&server) .await; - wiremock::Mock::given(wiremock::matchers::method("POST")) - .and(wiremock::matchers::path("/api/v1/models/load")) - .respond_with(wiremock::ResponseTemplate::new(500)) + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(500)) .mount(&server) .await; @@ -561,19 +868,19 @@ mod tests { #[tokio::test] async fn test_is_model_loaded_true() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_is_model_loaded_true", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/api/v1/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "models": [ { @@ -602,19 +909,19 @@ mod tests { #[tokio::test] async fn test_is_model_loaded_false() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_is_model_loaded_false", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/api/v1/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "models": [ { @@ -636,18 +943,18 @@ mod tests { #[tokio::test] async fn test_load_model_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_load_model_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("POST")) - .and(wiremock::matchers::path("/api/v1/models/load")) - .respond_with(wiremock::ResponseTemplate::new(500)) + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(500)) .mount(&server) .await; @@ -664,19 +971,19 @@ mod tests { #[tokio::test] async fn test_download_model_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_download_model_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("POST")) - .and(wiremock::matchers::path("/api/v1/models/download")) + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "job_id": "job-1", "status": "downloading" @@ -688,12 +995,10 @@ mod tests { .mount(&server) .await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path( - "/api/v1/models/download/status/job-1", - )) + Mock::given(method("GET")) + .and(path("/api/v1/models/download/status/job-1")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "job_id": "job-1", "status": "completed" @@ -714,19 +1019,19 @@ mod tests { #[tokio::test] async fn test_download_model_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_download_model_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("POST")) - .and(wiremock::matchers::path("/api/v1/models/download")) + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "job_id": "job-1", "status": "downloading" @@ -738,12 +1043,10 @@ mod tests { .mount(&server) .await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path( - "/api/v1/models/download/status/job-1", - )) + Mock::given(method("GET")) + .and(path("/api/v1/models/download/status/job-1")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ "job_id": "job-1", "status": "failed" @@ -766,6 +1069,74 @@ mod tests { ); } + #[tokio::test] + async fn test_download_model_already_downloaded() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_already_downloaded", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "status": "already_downloaded" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + client + .download_model("openai/gpt-oss-20b") + .await + .expect("already_downloaded should succeed"); + } + + #[tokio::test] + async fn test_download_model_paused() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_paused", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "status": "paused" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.download_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Model download paused") + ); + } + #[test] fn test_from_host_root() { let client = LMStudioClient::from_host_root("http://localhost:1234"); diff --git a/codex-rs/lmstudio/src/lib.rs b/codex-rs/lmstudio/src/lib.rs index fd4f82a728a..3df0ece2bc9 100644 --- a/codex-rs/lmstudio/src/lib.rs +++ b/codex-rs/lmstudio/src/lib.rs @@ -2,6 +2,7 @@ mod client; pub use client::LMStudioClient; use codex_core::config::Config; +pub use codex_protocol::openai_models::ModelsResponse; /// Default OSS model to use when `--oss` is passed without an explicit `-m`. pub const DEFAULT_OSS_MODEL: &str = "openai/gpt-oss-20b"; diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 8dfdbcd8717..219174df565 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -44,6 +44,7 @@ use codex_protocol::protocol::RolloutLine; use codex_state::log_db; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_oss::ensure_oss_provider_ready; +use codex_utils_oss::fetch_oss_model_catalog; use codex_utils_oss::get_default_model_for_oss_provider; use cwd_prompt::CwdPromptAction; use cwd_prompt::CwdPromptOutcome; @@ -591,7 +592,7 @@ async fn run_ratatui_app( should_show_onboarding(login_status, &initial_config, should_show_trust_screen_flag); let mut trust_decision_was_made = false; - let config = if should_show_onboarding { + let mut config = if should_show_onboarding { let show_login_screen = should_show_login_screen(login_status, &initial_config); let onboarding_result = run_onboarding_app( OnboardingScreenArgs { @@ -644,6 +645,19 @@ async fn run_ratatui_app( initial_config }; + if config.model_catalog.is_none() { + let provider_id = config.model_provider_id.as_str(); + match fetch_oss_model_catalog(provider_id, &config).await { + Ok(Some(catalog)) => { + config.model_catalog = Some(catalog); + } + Ok(None) => {} + Err(err) => { + tracing::warn!("Failed to fetch OSS model catalog for {provider_id}: {err}"); + } + } + } + let mut missing_session_exit = |id_str: &str, action: &str| { error!("Error finding conversation path: {id_str}"); restore(); diff --git a/codex-rs/utils/oss/src/lib.rs b/codex-rs/utils/oss/src/lib.rs index a44a6a7d326..8c513d8873f 100644 --- a/codex-rs/utils/oss/src/lib.rs +++ b/codex-rs/utils/oss/src/lib.rs @@ -3,6 +3,7 @@ use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::OLLAMA_OSS_PROVIDER_ID; use codex_core::config::Config; +use codex_lmstudio::ModelsResponse; /// Returns the default model for a given OSS provider. pub fn get_default_model_for_oss_provider(provider_id: &str) -> Option<&'static str> { @@ -37,6 +38,22 @@ pub async fn ensure_oss_provider_ready( Ok(()) } +/// Fetch a provider-specific model catalog, if supported. +pub async fn fetch_oss_model_catalog( + provider_id: &str, + config: &Config, +) -> Result, std::io::Error> { + match provider_id { + LMSTUDIO_OSS_PROVIDER_ID => { + let client = codex_lmstudio::LMStudioClient::try_from_provider(config).await?; + let models = client.fetch_model_metadata().await?; + Ok(Some(ModelsResponse { models })) + } + OLLAMA_OSS_PROVIDER_ID => Ok(None), + _ => Ok(None), + } +} + #[cfg(test)] mod tests { use super::*; From a3c3ca823614308b99be94f6b9fe5e2840f9a81a Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Thu, 19 Mar 2026 17:50:43 -0400 Subject: [PATCH 08/11] Add support for reasoning effort --- codex-rs/lmstudio/src/client.rs | 278 ++++++++++++++++++++++++++++++-- 1 file changed, 265 insertions(+), 13 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 5bd989325b0..b01a5dbe41a 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -7,6 +7,8 @@ use codex_protocol::openai_models::ConfigShellToolType; use codex_protocol::openai_models::InputModality; use codex_protocol::openai_models::ModelInfo; use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ReasoningEffort; +use codex_protocol::openai_models::ReasoningEffortPreset; use codex_protocol::openai_models::TruncationPolicyConfig; use serde::Deserialize; use std::io; @@ -21,8 +23,7 @@ pub struct LMStudioClient { } const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; -// 8192 tokens provides headroom above the observed ~6000 token initial input. -const DEFAULT_CONTEXT_LENGTH: u32 = 8192; +const DEFAULT_CONTEXT_LENGTH: u32 = 64000; impl LMStudioClient { pub async fn try_from_provider(config: &Config) -> io::Result { @@ -230,6 +231,15 @@ impl LMStudioClient { } else { vec![InputModality::Text] }; + let (default_reasoning_level, supported_reasoning_levels) = + parse_reasoning_capability( + model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.reasoning.as_ref()), + ); + let supports_reasoning = + default_reasoning_level.is_some() || !supported_reasoning_levels.is_empty(); ModelInfo { slug: model.key.clone(), @@ -238,8 +248,8 @@ impl LMStudioClient { .clone() .unwrap_or_else(|| model.key.clone()), description: model.description, - default_reasoning_level: None, // TODO: Update after we surface it in the API - supported_reasoning_levels: Vec::new(), // TODO: Update after we surface it in the API + default_reasoning_level, + supported_reasoning_levels, shell_type: ConfigShellToolType::Default, visibility: ModelVisibility::List, supported_in_api: true, @@ -248,7 +258,7 @@ impl LMStudioClient { upgrade: None, base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: None, - supports_reasoning_summaries: false, + supports_reasoning_summaries: supports_reasoning, default_reasoning_summary: ReasoningSummary::None, // Intentional as we don't support summary support_verbosity: false, default_verbosity: None, @@ -453,11 +463,25 @@ struct LMStudioModel { struct LMStudioCapabilities { vision: Option, trained_for_tool_use: Option, + reasoning: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum LMStudioReasoningCapability { + Enabled(bool), + Options(LMStudioReasoningOptions), +} + +#[derive(Debug, Deserialize)] +struct LMStudioReasoningOptions { + allowed_options: Option>, + #[serde(rename = "default")] + default_option: Option, } #[derive(Debug, Deserialize)] struct LMStudioLoadedInstance { - id: Option, config: Option, } @@ -513,14 +537,94 @@ fn format_bytes(bytes: u64) -> String { } } +fn parse_reasoning_effort(option: &str) -> Option { + let normalized = option.trim().to_ascii_lowercase(); + match normalized.as_str() { + "off" | "none" => Some(ReasoningEffort::None), + "on" => Some(ReasoningEffort::Medium), + "minimal" => Some(ReasoningEffort::Minimal), + "low" => Some(ReasoningEffort::Low), + "medium" => Some(ReasoningEffort::Medium), + "high" => Some(ReasoningEffort::High), + "xhigh" => Some(ReasoningEffort::XHigh), + _ => None, + } +} + +fn parse_reasoning_capability( + capability: Option<&LMStudioReasoningCapability>, +) -> (Option, Vec) { + let medium_only = ( + Some(ReasoningEffort::Medium), + vec![ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "medium".to_string(), + }], + ); + let fallback_presets = vec![ + ReasoningEffort::Low, + ReasoningEffort::Medium, + ReasoningEffort::High, + ] + .into_iter() + .map(|effort| ReasoningEffortPreset { + effort, + description: format!("{effort}"), + }) + .collect::>(); + let fallback = (Some(ReasoningEffort::Medium), fallback_presets); + + let Some(capability) = capability else { + return (None, Vec::new()); + }; + + match capability { + LMStudioReasoningCapability::Enabled(true) => medium_only, + LMStudioReasoningCapability::Enabled(false) => (None, Vec::new()), + LMStudioReasoningCapability::Options(options) => { + let mut efforts = Vec::new(); + if let Some(allowed_options) = options.allowed_options.as_ref() { + for option in allowed_options { + if let Some(effort) = parse_reasoning_effort(option) + && !efforts.contains(&effort) + { + efforts.push(effort); + } + } + } + + if efforts.is_empty() { + return fallback; + } + + let default_reasoning_level = options + .default_option + .as_deref() + .and_then(parse_reasoning_effort) + .or_else(|| efforts.first().copied()); + let supported_reasoning_levels = efforts + .into_iter() + .map(|effort| ReasoningEffortPreset { + effort, + description: format!("{effort}"), + }) + .collect(); + (default_reasoning_level, supported_reasoning_levels) + } + } +} + #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] use super::*; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use pretty_assertions::assert_eq; - use wiremock::matchers::{method, path}; - use wiremock::{Mock, MockServer, ResponseTemplate}; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; #[tokio::test] async fn test_fetch_models_happy_path() { @@ -638,7 +742,11 @@ mod tests { "max_context_length": 100_000, "capabilities": { "vision": true, - "trained_for_tool_use": true + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["low", "medium", "high"], + "default": "low" + } }, "loaded_instances": [ { @@ -676,8 +784,21 @@ mod tests { slug: "openai/gpt-oss-20b".to_string(), display_name: "GPT-OSS-20B (LM Studio)".to_string(), description: Some("OSS model".to_string()), - default_reasoning_level: None, - supported_reasoning_levels: Vec::new(), + default_reasoning_level: Some(ReasoningEffort::Low), + supported_reasoning_levels: vec![ + ReasoningEffortPreset { + effort: ReasoningEffort::Low, + description: "low".to_string(), + }, + ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "medium".to_string(), + }, + ReasoningEffortPreset { + effort: ReasoningEffort::High, + description: "high".to_string(), + }, + ], shell_type: ConfigShellToolType::Default, visibility: ModelVisibility::List, supported_in_api: true, @@ -686,7 +807,7 @@ mod tests { upgrade: None, base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: None, - supports_reasoning_summaries: false, + supports_reasoning_summaries: true, default_reasoning_summary: ReasoningSummary::None, support_verbosity: false, default_verbosity: None, @@ -715,7 +836,7 @@ mod tests { upgrade: None, base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: None, - supports_reasoning_summaries: false, + supports_reasoning_summaries: true, default_reasoning_summary: ReasoningSummary::None, support_verbosity: false, default_verbosity: None, @@ -735,6 +856,137 @@ mod tests { assert_eq!(models, expected); } + #[tokio::test] + async fn test_fetch_model_metadata_reasoning_variants() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_model_metadata_reasoning_variants", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "lmstudio-community/qwen3-0.6b", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true + } + }, + { + "key": "qwen/qwen3-0.6b", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "on"], + "default": "on" + } + } + }, + { + "key": "microsoft/phi-4-mini-reasoning", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": false, + "reasoning": true + } + }, + { + "key": "nvidia/nemotron-3-super", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "low", "on"], + "default": "on" + } + } + }, + { + "key": "test/missing-default", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "on"] + } + } + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let models = client.fetch_model_metadata().await.expect("fetch metadata"); + + let summary = models + .iter() + .map(|model| { + ( + model.slug.clone(), + model.default_reasoning_level, + model + .supported_reasoning_levels + .iter() + .map(|preset| preset.effort) + .collect::>(), + ) + }) + .collect::>(); + + let expected = vec![ + ( + "lmstudio-community/qwen3-0.6b".to_string(), + None, + Vec::::new(), + ), + ( + "qwen/qwen3-0.6b".to_string(), + Some(ReasoningEffort::Medium), + vec![ReasoningEffort::None, ReasoningEffort::Medium], + ), + ( + "microsoft/phi-4-mini-reasoning".to_string(), + Some(ReasoningEffort::Medium), + vec![ReasoningEffort::Medium], + ), + ( + "nvidia/nemotron-3-super".to_string(), + Some(ReasoningEffort::Medium), + vec![ + ReasoningEffort::None, + ReasoningEffort::Low, + ReasoningEffort::Medium, + ], + ), + ( + "test/missing-default".to_string(), + Some(ReasoningEffort::None), + vec![ReasoningEffort::None, ReasoningEffort::Medium], + ), + ]; + + assert_eq!(summary, expected); + } + #[tokio::test] async fn test_check_server_happy_path() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { From f351e3f762301ef77736a4ac9d11ed6039274526 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Thu, 19 Mar 2026 18:13:21 -0400 Subject: [PATCH 09/11] Fix test and loaded logic --- codex-rs/lmstudio/src/client.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index b01a5dbe41a..23fd4b13454 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -95,16 +95,13 @@ impl LMStudioClient { let models = json.get("models").and_then(|value| value.as_array()); models.is_some_and(|entries| { entries.iter().any(|entry| { - let loaded_instances = entry + let is_requested_model = + entry.get("key").and_then(|value| value.as_str()) == Some(model); + let has_loaded_instances = entry .get("loaded_instances") - .and_then(|value| value.as_array()); - // A model is considered loaded if any of its loaded_instances - // shares the same id as the requested model. - loaded_instances.is_some_and(|instances| { - instances.iter().any(|instance| { - instance.get("id").and_then(|value| value.as_str()) == Some(model) - }) - }) + .and_then(|value| value.as_array()) + .is_some_and(|instances| !instances.is_empty()); + is_requested_model && has_loaded_instances }) }) } @@ -836,7 +833,7 @@ mod tests { upgrade: None, base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: None, - supports_reasoning_summaries: true, + supports_reasoning_summaries: false, default_reasoning_summary: ReasoningSummary::None, support_verbosity: false, default_verbosity: None, @@ -1092,7 +1089,7 @@ mod tests { "key": "test/test-model", "loaded_instances": [ { - "id": "test/test-model", + "id": "instance-abc123", "config": { "context_length": 7000 } @@ -1139,7 +1136,7 @@ mod tests { "key": "test/test-model", "loaded_instances": [ { - "id": "test/test-model", + "id": "instance-abc123", "config": { "context_length": 7000 } From 46815e80f71f16f9624a5e0b7be8ca3c1372ba97 Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Fri, 20 Mar 2026 12:18:04 -0400 Subject: [PATCH 10/11] Handle empty catalog and refactor --- codex-rs/exec/src/lib.rs | 1 + codex-rs/lmstudio/src/client.rs | 221 ++++++++++++++++++++++++-------- codex-rs/lmstudio/src/lib.rs | 16 +-- codex-rs/utils/oss/src/lib.rs | 6 +- 4 files changed, 177 insertions(+), 67 deletions(-) diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index 11b602e6db8..781e2079143 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -386,6 +386,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result } } } + let config = config; let default_cwd = config.cwd.to_path_buf(); let default_approval_policy = config.permissions.approval_policy.value(); diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index 23fd4b13454..ebe5167962c 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -23,7 +23,7 @@ pub struct LMStudioClient { } const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; -const DEFAULT_CONTEXT_LENGTH: u32 = 64000; +const DEFAULT_CONTEXT_LENGTH: i64 = 64000; impl LMStudioClient { pub async fn try_from_provider(config: &Config) -> io::Result { @@ -80,20 +80,26 @@ impl LMStudioClient { } } - // Check if a model is already loaded with the same key - async fn is_model_loaded(&self, model: &str) -> bool { + async fn query_model_loaded(&self, model: &str) -> io::Result { let models_url = format!("{}/api/v1/models", self.host_root()); - let Ok(response) = self.client.get(&models_url).send().await else { - return false; - }; + let response = self + .client + .get(&models_url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; if !response.status().is_success() { - return false; + return Err(io::Error::other(format!( + "Failed to fetch models: {}", + response.status() + ))); } - let Ok(json) = response.json::().await else { - return false; - }; + + let json = response.json::().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; let models = json.get("models").and_then(|value| value.as_array()); - models.is_some_and(|entries| { + Ok(models.is_some_and(|entries| { entries.iter().any(|entry| { let is_requested_model = entry.get("key").and_then(|value| value.as_str()) == Some(model); @@ -103,7 +109,28 @@ impl LMStudioClient { .is_some_and(|instances| !instances.is_empty()); is_requested_model && has_loaded_instances }) - }) + })) + } + + // Check if a model is already loaded with the same key + async fn is_model_loaded(&self, model: &str) -> bool { + self.query_model_loaded(model).await.unwrap_or(false) + } + + pub(crate) async fn wait_until_model_loaded(&self, model: &str) -> io::Result<()> { + for attempt in 0..MAX_MODEL_LOAD_POLL_ATTEMPTS { + if self.query_model_loaded(model).await? { + return Ok(()); + } + if attempt + 1 == MAX_MODEL_LOAD_POLL_ATTEMPTS { + break; + } + tokio::time::sleep(MODEL_LOAD_POLL_INTERVAL).await; + } + + Err(io::Error::other(format!( + "Timed out waiting for model '{model}' to finish loading" + ))) } // Load a model by sending an empty request with max_tokens 1 @@ -139,40 +166,7 @@ impl LMStudioClient { } } - // Return the list of models available on the LM Studio server. - pub async fn fetch_models(&self) -> io::Result> { - let url = format!("{}/api/v1/models", self.host_root()); - let response = self - .client - .get(&url) - .send() - .await - .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; - - if response.status().is_success() { - let json: serde_json::Value = response.json().await.map_err(|e| { - io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) - })?; - let models = json["models"] - .as_array() - .ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "No 'models' array in response") - })? - .iter() - .filter_map(|model| model["key"].as_str()) - .map(String::from) - .collect(); - Ok(models) - } else { - Err(io::Error::other(format!( - "Failed to fetch models: {}", - response.status() - ))) - } - } - - /// Return model metadata from the LM Studio server. - pub async fn fetch_model_metadata(&self) -> io::Result> { + async fn fetch_models_response(&self) -> io::Result { let url = format!("{}/api/v1/models", self.host_root()); let response = self .client @@ -188,9 +182,23 @@ impl LMStudioClient { ))); } - let json: LMStudioModelsResponse = response.json().await.map_err(|e| { - io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) - })?; + response + .json::() + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + }) + } + + // Return the list of models available on the LM Studio server. + pub async fn fetch_models(&self) -> io::Result> { + let response = self.fetch_models_response().await?; + Ok(response.models.into_iter().map(|m| m.key).collect()) + } + + /// Return model metadata from the LM Studio server. + pub async fn fetch_model_metadata(&self) -> io::Result> { + let json = self.fetch_models_response().await?; let models = json .models @@ -213,6 +221,8 @@ impl LMStudioClient { .max() }) .or(model.max_context_length); + // LM Studio reports `capabilities.vision` for models that accept image input, so + // missing or false values are treated as text-only here. let supports_vision = model .capabilities .as_ref() @@ -256,7 +266,7 @@ impl LMStudioClient { base_instructions: BASE_INSTRUCTIONS.to_string(), model_messages: None, supports_reasoning_summaries: supports_reasoning, - default_reasoning_summary: ReasoningSummary::None, // Intentional as we don't support summary + default_reasoning_summary: ReasoningSummary::None, support_verbosity: false, default_verbosity: None, apply_patch_tool_type: trained_for_tool_use @@ -493,8 +503,17 @@ const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_secs(2); #[cfg(test)] const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); +#[cfg(not(test))] +const MODEL_LOAD_POLL_INTERVAL: Duration = Duration::from_secs(1); +#[cfg(test)] +const MODEL_LOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); + // Allow ~2 hours of polling (at 2s intervals) before giving up. const MAX_DOWNLOAD_POLL_ATTEMPTS: u32 = 3600; +#[cfg(not(test))] +const MAX_MODEL_LOAD_POLL_ATTEMPTS: u32 = 120; +#[cfg(test)] +const MAX_MODEL_LOAD_POLL_ATTEMPTS: u32 = 20; struct DownloadStatusResponse { status: String, @@ -617,8 +636,12 @@ mod tests { use super::*; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use pretty_assertions::assert_eq; + use std::sync::Arc; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; use wiremock::Mock; use wiremock::MockServer; + use wiremock::Request; use wiremock::ResponseTemplate; use wiremock::matchers::method; use wiremock::matchers::path; @@ -678,12 +701,7 @@ mod tests { let client = LMStudioClient::from_host_root(server.uri()); let result = client.fetch_models().await; assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("No 'models' array in response") - ); + assert!(result.unwrap_err().to_string().contains("JSON parse error")); } #[tokio::test] @@ -1190,6 +1208,89 @@ mod tests { assert!(!client.is_model_loaded("test/test-model").await); } + #[tokio::test] + async fn test_wait_until_model_loaded_polls_until_loaded() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_wait_until_model_loaded_polls_until_loaded", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + let counter = Arc::new(AtomicUsize::new(0)); + let request_counter = Arc::clone(&counter); + + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(move |_: &Request| { + let loaded_instances = if request_counter.fetch_add(1, Ordering::SeqCst) == 0 { + Vec::new() + } else { + vec![serde_json::json!({ + "id": "instance-abc123", + "config": { + "context_length": 7000 + } + })] + }; + ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": loaded_instances + } + ] + })) + }) + .expect(2) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + client + .wait_until_model_loaded("test/test-model") + .await + .expect("wait for model load"); + } + + #[tokio::test] + async fn test_wait_until_model_loaded_times_out() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_wait_until_model_loaded_times_out", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": [] + } + ] + }))) + .expect(MAX_MODEL_LOAD_POLL_ATTEMPTS as u64) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.wait_until_model_loaded("test/test-model").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Timed out waiting for model 'test/test-model' to finish loading") + ); + } + #[tokio::test] async fn test_load_model_error() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { @@ -1201,6 +1302,14 @@ mod tests { } let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ "models": [] }).to_string(), + "application/json", + )) + .mount(&server) + .await; Mock::given(method("POST")) .and(path("/api/v1/models/load")) .respond_with(ResponseTemplate::new(500)) diff --git a/codex-rs/lmstudio/src/lib.rs b/codex-rs/lmstudio/src/lib.rs index 3df0ece2bc9..89ee69e247f 100644 --- a/codex-rs/lmstudio/src/lib.rs +++ b/codex-rs/lmstudio/src/lib.rs @@ -11,6 +11,7 @@ pub const DEFAULT_OSS_MODEL: &str = "openai/gpt-oss-20b"; /// /// - Ensures a local LM Studio server is reachable. /// - Checks if the model exists locally and downloads it if missing. +/// - Waits for the selected model preload to settle before returning. pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { let model = match config.model.as_ref() { Some(model) => model, @@ -32,16 +33,11 @@ pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { } } - // Load the model in the background - tokio::spawn({ - let client = lmstudio_client.clone(); - let model = model.to_string(); - async move { - if let Err(e) = client.load_model(&model).await { - tracing::warn!("Failed to load model {}: {}", model, e); - } - } - }); + if let Err(err) = lmstudio_client.load_model(model).await { + tracing::warn!("Failed to load model {model}: {err}"); + } else if let Err(err) = lmstudio_client.wait_until_model_loaded(model).await { + tracing::warn!("Failed waiting for model {model} to finish loading: {err}"); + } Ok(()) } diff --git a/codex-rs/utils/oss/src/lib.rs b/codex-rs/utils/oss/src/lib.rs index 8c513d8873f..32b4028fe46 100644 --- a/codex-rs/utils/oss/src/lib.rs +++ b/codex-rs/utils/oss/src/lib.rs @@ -47,7 +47,11 @@ pub async fn fetch_oss_model_catalog( LMSTUDIO_OSS_PROVIDER_ID => { let client = codex_lmstudio::LMStudioClient::try_from_provider(config).await?; let models = client.fetch_model_metadata().await?; - Ok(Some(ModelsResponse { models })) + if models.is_empty() { + Ok(None) + } else { + Ok(Some(ModelsResponse { models })) + } } OLLAMA_OSS_PROVIDER_ID => Ok(None), _ => Ok(None), From 5f9de17bc9d2c36405a8af9df6ed6875613f1a0d Mon Sep 17 00:00:00 2001 From: Rugved Somwanshi Date: Fri, 20 Mar 2026 12:23:38 -0400 Subject: [PATCH 11/11] Drop waiting logic for load --- codex-rs/lmstudio/src/client.rs | 112 -------------------------------- codex-rs/lmstudio/src/lib.rs | 3 - 2 files changed, 115 deletions(-) diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index ebe5167962c..6ec9c16d2ec 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -117,22 +117,6 @@ impl LMStudioClient { self.query_model_loaded(model).await.unwrap_or(false) } - pub(crate) async fn wait_until_model_loaded(&self, model: &str) -> io::Result<()> { - for attempt in 0..MAX_MODEL_LOAD_POLL_ATTEMPTS { - if self.query_model_loaded(model).await? { - return Ok(()); - } - if attempt + 1 == MAX_MODEL_LOAD_POLL_ATTEMPTS { - break; - } - tokio::time::sleep(MODEL_LOAD_POLL_INTERVAL).await; - } - - Err(io::Error::other(format!( - "Timed out waiting for model '{model}' to finish loading" - ))) - } - // Load a model by sending an empty request with max_tokens 1 pub async fn load_model(&self, model: &str) -> io::Result<()> { if self.is_model_loaded(model).await { @@ -503,17 +487,8 @@ const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_secs(2); #[cfg(test)] const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); -#[cfg(not(test))] -const MODEL_LOAD_POLL_INTERVAL: Duration = Duration::from_secs(1); -#[cfg(test)] -const MODEL_LOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); - // Allow ~2 hours of polling (at 2s intervals) before giving up. const MAX_DOWNLOAD_POLL_ATTEMPTS: u32 = 3600; -#[cfg(not(test))] -const MAX_MODEL_LOAD_POLL_ATTEMPTS: u32 = 120; -#[cfg(test)] -const MAX_MODEL_LOAD_POLL_ATTEMPTS: u32 = 20; struct DownloadStatusResponse { status: String, @@ -636,12 +611,8 @@ mod tests { use super::*; use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; use pretty_assertions::assert_eq; - use std::sync::Arc; - use std::sync::atomic::AtomicUsize; - use std::sync::atomic::Ordering; use wiremock::Mock; use wiremock::MockServer; - use wiremock::Request; use wiremock::ResponseTemplate; use wiremock::matchers::method; use wiremock::matchers::path; @@ -1208,89 +1179,6 @@ mod tests { assert!(!client.is_model_loaded("test/test-model").await); } - #[tokio::test] - async fn test_wait_until_model_loaded_polls_until_loaded() { - if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { - tracing::info!( - "{} is set; skipping test_wait_until_model_loaded_polls_until_loaded", - CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR - ); - return; - } - - let server = MockServer::start().await; - let counter = Arc::new(AtomicUsize::new(0)); - let request_counter = Arc::clone(&counter); - - Mock::given(method("GET")) - .and(path("/api/v1/models")) - .respond_with(move |_: &Request| { - let loaded_instances = if request_counter.fetch_add(1, Ordering::SeqCst) == 0 { - Vec::new() - } else { - vec![serde_json::json!({ - "id": "instance-abc123", - "config": { - "context_length": 7000 - } - })] - }; - ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "models": [ - { - "key": "test/test-model", - "loaded_instances": loaded_instances - } - ] - })) - }) - .expect(2) - .mount(&server) - .await; - - let client = LMStudioClient::from_host_root(server.uri()); - client - .wait_until_model_loaded("test/test-model") - .await - .expect("wait for model load"); - } - - #[tokio::test] - async fn test_wait_until_model_loaded_times_out() { - if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { - tracing::info!( - "{} is set; skipping test_wait_until_model_loaded_times_out", - CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR - ); - return; - } - - let server = MockServer::start().await; - Mock::given(method("GET")) - .and(path("/api/v1/models")) - .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ - "models": [ - { - "key": "test/test-model", - "loaded_instances": [] - } - ] - }))) - .expect(MAX_MODEL_LOAD_POLL_ATTEMPTS as u64) - .mount(&server) - .await; - - let client = LMStudioClient::from_host_root(server.uri()); - let result = client.wait_until_model_loaded("test/test-model").await; - assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("Timed out waiting for model 'test/test-model' to finish loading") - ); - } - #[tokio::test] async fn test_load_model_error() { if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { diff --git a/codex-rs/lmstudio/src/lib.rs b/codex-rs/lmstudio/src/lib.rs index 89ee69e247f..34803522b8d 100644 --- a/codex-rs/lmstudio/src/lib.rs +++ b/codex-rs/lmstudio/src/lib.rs @@ -11,7 +11,6 @@ pub const DEFAULT_OSS_MODEL: &str = "openai/gpt-oss-20b"; /// /// - Ensures a local LM Studio server is reachable. /// - Checks if the model exists locally and downloads it if missing. -/// - Waits for the selected model preload to settle before returning. pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { let model = match config.model.as_ref() { Some(model) => model, @@ -35,8 +34,6 @@ pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { if let Err(err) = lmstudio_client.load_model(model).await { tracing::warn!("Failed to load model {model}: {err}"); - } else if let Err(err) = lmstudio_client.wait_until_model_loaded(model).await { - tracing::warn!("Failed waiting for model {model} to finish loading: {err}"); } Ok(())