diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index a68f99425cb..608ab49b968 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2032,11 +2032,13 @@ name = "codex-lmstudio" version = "0.0.0" dependencies = [ "codex-core", + "codex-protocol", + "pretty_assertions", "reqwest", + "serde", "serde_json", "tokio", "tracing", - "which", "wiremock", ] diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs index e8ab6f8af74..781e2079143 100644 --- a/codex-rs/exec/src/lib.rs +++ b/codex-rs/exec/src/lib.rs @@ -47,6 +47,7 @@ use codex_protocol::protocol::SubAgentSource; use codex_protocol::user_input::UserInput; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_oss::ensure_oss_provider_ready; +use codex_utils_oss::fetch_oss_model_catalog; use codex_utils_oss::get_default_model_for_oss_provider; use event_processor_with_human_output::EventProcessorWithHumanOutput; use event_processor_with_jsonl_output::EventProcessorWithJsonOutput; @@ -289,7 +290,7 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result additional_writable_roots: add_dir, }; - let config = ConfigBuilder::default() + let mut config = ConfigBuilder::default() .cli_overrides(cli_kv_overrides) .harness_overrides(overrides) .cloud_requirements(cloud_requirements) @@ -373,6 +374,20 @@ pub async fn run_main(cli: Cli, arg0_paths: Arg0DispatchPaths) -> anyhow::Result .map_err(|e| anyhow::anyhow!("OSS setup failed: {e}"))?; } + if config.model_catalog.is_none() { + let provider_id = config.model_provider_id.as_str(); + match fetch_oss_model_catalog(provider_id, &config).await { + Ok(Some(catalog)) => { + config.model_catalog = Some(catalog); + } + Ok(None) => {} + Err(err) => { + warn!("Failed to fetch OSS model catalog for {provider_id}: {err}"); + } + } + } + let config = config; + let default_cwd = config.cwd.to_path_buf(); let default_approval_policy = config.permissions.approval_policy.value(); let default_sandbox_policy = config.permissions.sandbox_policy.get(); diff --git a/codex-rs/lmstudio/Cargo.toml b/codex-rs/lmstudio/Cargo.toml index 5f4849638ae..705be1870ef 100644 --- a/codex-rs/lmstudio/Cargo.toml +++ b/codex-rs/lmstudio/Cargo.toml @@ -11,13 +11,15 @@ path = "src/lib.rs" [dependencies] codex-core = { path = "../core" } +codex-protocol = { path = "../protocol" } reqwest = { version = "0.12", features = ["json", "stream"] } +serde = { workspace = true, features = ["derive"] } serde_json = "1" -tokio = { version = "1", features = ["rt"] } +tokio = { version = "1", features = ["rt", "time"] } tracing = { version = "0.1.44", features = ["log"] } -which = "8.0" [dev-dependencies] +pretty_assertions = { workspace = true } wiremock = "0.6" tokio = { version = "1", features = ["full"] } diff --git a/codex-rs/lmstudio/src/client.rs b/codex-rs/lmstudio/src/client.rs index a2a8ee03bff..6ec9c16d2ec 100644 --- a/codex-rs/lmstudio/src/client.rs +++ b/codex-rs/lmstudio/src/client.rs @@ -1,7 +1,20 @@ use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::config::Config; +use codex_core::models_manager::model_info::BASE_INSTRUCTIONS; +use codex_protocol::config_types::ReasoningSummary; +use codex_protocol::openai_models::ApplyPatchToolType; +use codex_protocol::openai_models::ConfigShellToolType; +use codex_protocol::openai_models::InputModality; +use codex_protocol::openai_models::ModelInfo; +use codex_protocol::openai_models::ModelVisibility; +use codex_protocol::openai_models::ReasoningEffort; +use codex_protocol::openai_models::ReasoningEffortPreset; +use codex_protocol::openai_models::TruncationPolicyConfig; +use serde::Deserialize; use std::io; -use std::path::Path; +use std::io::Write; +use std::time::Duration; +use std::time::Instant; #[derive(Clone)] pub struct LMStudioClient { @@ -10,9 +23,10 @@ pub struct LMStudioClient { } const LMSTUDIO_CONNECTION_ERROR: &str = "LM Studio is not responding. Install from https://lmstudio.ai/download and run 'lms server start'."; +const DEFAULT_CONTEXT_LENGTH: i64 = 64000; impl LMStudioClient { - pub async fn try_from_provider(config: &Config) -> std::io::Result { + pub async fn try_from_provider(config: &Config) -> io::Result { let provider = config .model_providers .get(LMSTUDIO_OSS_PROVIDER_ID) @@ -30,7 +44,7 @@ impl LMStudioClient { })?; let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) .build() .unwrap_or_else(|_| reqwest::Client::new()); @@ -43,8 +57,13 @@ impl LMStudioClient { Ok(client) } + fn host_root(&self) -> String { + let base_url = self.base_url.trim_end_matches('/'); + base_url.strip_suffix("/v1").unwrap_or(base_url).to_string() + } + async fn check_server(&self) -> io::Result<()> { - let url = format!("{}/models", self.base_url.trim_end_matches('/')); + let url = format!("{}/api/v1/models", self.host_root()); let response = self.client.get(&url).send().await; if let Ok(resp) = response { @@ -61,14 +80,54 @@ impl LMStudioClient { } } + async fn query_model_loaded(&self, model: &str) -> io::Result { + let models_url = format!("{}/api/v1/models", self.host_root()); + let response = self + .client + .get(&models_url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + if !response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to fetch models: {}", + response.status() + ))); + } + + let json = response.json::().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; + let models = json.get("models").and_then(|value| value.as_array()); + Ok(models.is_some_and(|entries| { + entries.iter().any(|entry| { + let is_requested_model = + entry.get("key").and_then(|value| value.as_str()) == Some(model); + let has_loaded_instances = entry + .get("loaded_instances") + .and_then(|value| value.as_array()) + .is_some_and(|instances| !instances.is_empty()); + is_requested_model && has_loaded_instances + }) + })) + } + + // Check if a model is already loaded with the same key + async fn is_model_loaded(&self, model: &str) -> bool { + self.query_model_loaded(model).await.unwrap_or(false) + } + // Load a model by sending an empty request with max_tokens 1 pub async fn load_model(&self, model: &str) -> io::Result<()> { - let url = format!("{}/responses", self.base_url.trim_end_matches('/')); + if self.is_model_loaded(model).await { + tracing::info!("Model '{model}' already loaded; reusing existing instance"); + return Ok(()); + } + let url = format!("{}/api/v1/models/load", self.host_root()); let request_body = serde_json::json!({ "model": model, - "input": "", - "max_output_tokens": 1 + "context_length": DEFAULT_CONTEXT_LENGTH }); let response = self @@ -91,9 +150,8 @@ impl LMStudioClient { } } - // Return the list of models available on the LM Studio server. - pub async fn fetch_models(&self) -> io::Result> { - let url = format!("{}/models", self.base_url.trim_end_matches('/')); + async fn fetch_models_response(&self) -> io::Result { + let url = format!("{}/api/v1/models", self.host_root()); let response = self .client .get(&url) @@ -101,99 +159,271 @@ impl LMStudioClient { .await .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; - if response.status().is_success() { - let json: serde_json::Value = response.json().await.map_err(|e| { - io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) - })?; - let models = json["data"] - .as_array() - .ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "No 'data' array in response") - })? - .iter() - .filter_map(|model| model["id"].as_str()) - .map(std::string::ToString::to_string) - .collect(); - Ok(models) - } else { - Err(io::Error::other(format!( + if !response.status().is_success() { + return Err(io::Error::other(format!( "Failed to fetch models: {}", response.status() - ))) + ))); } + + response + .json::() + .await + .map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + }) } - // Find lms, checking fallback paths if not in PATH - fn find_lms() -> std::io::Result { - Self::find_lms_with_home_dir(None) + // Return the list of models available on the LM Studio server. + pub async fn fetch_models(&self) -> io::Result> { + let response = self.fetch_models_response().await?; + Ok(response.models.into_iter().map(|m| m.key).collect()) } - fn find_lms_with_home_dir(home_dir: Option<&str>) -> std::io::Result { - // First try 'lms' in PATH - if which::which("lms").is_ok() { - return Ok("lms".to_string()); - } + /// Return model metadata from the LM Studio server. + pub async fn fetch_model_metadata(&self) -> io::Result> { + let json = self.fetch_models_response().await?; - // Platform-specific fallback paths - let home = match home_dir { - Some(dir) => dir.to_string(), - None => { - #[cfg(unix)] - { - std::env::var("HOME").unwrap_or_default() - } - #[cfg(windows)] - { - std::env::var("USERPROFILE").unwrap_or_default() - } - } - }; - - #[cfg(unix)] - let fallback_path = format!("{home}/.lmstudio/bin/lms"); + let models = json + .models + .into_iter() + .filter(|model| matches!(model.model_type.as_deref(), None | Some("llm"))) + .enumerate() + .map(|(index, model)| { + let context_window = model + .loaded_instances + .as_ref() + .and_then(|instances| { + instances + .iter() + .filter_map(|instance| { + instance + .config + .as_ref() + .and_then(|config| config.context_length) + }) + .max() + }) + .or(model.max_context_length); + // LM Studio reports `capabilities.vision` for models that accept image input, so + // missing or false values are treated as text-only here. + let supports_vision = model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.vision) + .unwrap_or(false); + let trained_for_tool_use = model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.trained_for_tool_use) + .unwrap_or(false); + let input_modalities = if supports_vision { + vec![InputModality::Text, InputModality::Image] + } else { + vec![InputModality::Text] + }; + let (default_reasoning_level, supported_reasoning_levels) = + parse_reasoning_capability( + model + .capabilities + .as_ref() + .and_then(|capabilities| capabilities.reasoning.as_ref()), + ); + let supports_reasoning = + default_reasoning_level.is_some() || !supported_reasoning_levels.is_empty(); - #[cfg(windows)] - let fallback_path = format!("{home}/.lmstudio/bin/lms.exe"); + ModelInfo { + slug: model.key.clone(), + display_name: model + .display_name + .clone() + .unwrap_or_else(|| model.key.clone()), + description: model.description, + default_reasoning_level, + supported_reasoning_levels, + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: i32::try_from(index).unwrap_or(i32::MAX), + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: supports_reasoning, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: trained_for_tool_use + .then_some(ApplyPatchToolType::Function), + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window, + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities, + prefer_websockets: false, + used_fallback_model_metadata: false, + } + }) + .collect(); - if Path::new(&fallback_path).exists() { - Ok(fallback_path) - } else { - Err(std::io::Error::new( - std::io::ErrorKind::NotFound, - "LM Studio not found. Please install LM Studio from https://lmstudio.ai/", - )) - } + Ok(models) } - pub async fn download_model(&self, model: &str) -> std::io::Result<()> { - let lms = Self::find_lms()?; - eprintln!("Downloading model: {model}"); + pub async fn download_model(&self, model: &str) -> io::Result<()> { + let url = format!("{}/api/v1/models/download", self.host_root()); - let status = std::process::Command::new(&lms) - .args(["get", "--yes", model]) - .stdout(std::process::Stdio::inherit()) - .stderr(std::process::Stdio::null()) - .status() - .map_err(|e| { - std::io::Error::other(format!("Failed to execute '{lms} get --yes {model}': {e}")) - })?; + let request_body = serde_json::json!({ + "model": model + }); + + let response = self + .client + .post(&url) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; - if !status.success() { - return Err(std::io::Error::other(format!( - "Model download failed with exit code: {}", - status.code().unwrap_or(-1) + if !response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to download model: {}", + response.status() ))); } - tracing::info!("Successfully downloaded model '{model}'"); - Ok(()) + let download_status = response.json::().await.map_err(|e| { + io::Error::new(io::ErrorKind::InvalidData, format!("JSON parse error: {e}")) + })?; + + let initial = parse_download_status(&download_status)?; + let status = initial.status; + let job_id = initial.job_id; + + match status.as_str() { + "already_downloaded" | "completed" => { + tracing::info!("Model '{model}' is ready"); + Ok(()) + } + "failed" => Err(io::Error::other(format!( + "Model download failed for '{model}'" + ))), + "paused" => Err(io::Error::other(format!( + "Model download paused for '{model}'" + ))), + "downloading" => { + let job_id = job_id.as_deref().ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "Download status missing job_id") + })?; + + let mut last_logged = Instant::now() - Duration::from_secs(10); + let mut attempts = 0u32; + + loop { + tokio::time::sleep(DOWNLOAD_POLL_INTERVAL).await; + attempts += 1; + if attempts > MAX_DOWNLOAD_POLL_ATTEMPTS { + eprintln!(); + return Err(io::Error::other(format!( + "Timed out waiting for model '{model}' to download" + ))); + } + let status_url = format!( + "{}/api/v1/models/download/status/{job_id}", + self.host_root() + ); + + let status_response = self + .client + .get(&status_url) + .send() + .await + .map_err(|e| io::Error::other(format!("Request failed: {e}")))?; + + if !status_response.status().is_success() { + return Err(io::Error::other(format!( + "Failed to fetch download status: {}", + status_response.status() + ))); + } + + let status = + status_response + .json::() + .await + .map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("JSON parse error: {e}"), + ) + })?; + let poll = parse_download_status(&status)?; + let status_value = poll.status; + let downloaded_bytes = poll.downloaded_bytes; + let total_size_bytes = poll.total_size_bytes; + + match status_value.as_str() { + "completed" => { + eprintln!(); + tracing::info!("Successfully downloaded model '{model}'"); + return Ok(()); + } + "failed" => { + eprintln!(); + return Err(io::Error::other(format!( + "Model download failed for '{model}'" + ))); + } + "paused" => { + eprintln!(); + return Err(io::Error::other(format!( + "Model download paused for '{model}'" + ))); + } + "downloading" => { + if let Some(downloaded) = downloaded_bytes { + let now = Instant::now(); + if now.duration_since(last_logged) >= Duration::from_millis(500) { + if let Some(total) = total_size_bytes { + let percent = (downloaded as f64 / total as f64) * 100.0; + eprint!( + "\rDownloading '{model}': {} / {} ({percent:.1}%)", + format_bytes(downloaded), + format_bytes(total) + ); + } else { + eprint!( + "\rDownloading '{model}': {}", + format_bytes(downloaded) + ); + } + let _ = io::stderr().flush(); + last_logged = now; + } + } + } + status_value => { + eprintln!(); + return Err(io::Error::other(format!( + "Unknown download status '{status_value}' for '{model}'" + ))); + } + } + } + } + status_value => Err(io::Error::other(format!( + "Unknown download status '{status_value}' for '{model}'" + ))), + } } /// Low-level constructor given a raw host root, e.g. "http://localhost:1234". #[cfg(test)] fn from_host_root(host_root: impl Into) -> Self { let client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs(5)) + .connect_timeout(Duration::from_secs(5)) .build() .unwrap_or_else(|_| reqwest::Client::new()); Self { @@ -203,30 +433,209 @@ impl LMStudioClient { } } +#[derive(Debug, Deserialize)] +struct LMStudioModelsResponse { + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct LMStudioModel { + key: String, + display_name: Option, + description: Option, + #[serde(rename = "type")] + model_type: Option, + max_context_length: Option, + capabilities: Option, + loaded_instances: Option>, +} + +#[derive(Debug, Deserialize)] +struct LMStudioCapabilities { + vision: Option, + trained_for_tool_use: Option, + reasoning: Option, +} + +#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum LMStudioReasoningCapability { + Enabled(bool), + Options(LMStudioReasoningOptions), +} + +#[derive(Debug, Deserialize)] +struct LMStudioReasoningOptions { + allowed_options: Option>, + #[serde(rename = "default")] + default_option: Option, +} + +#[derive(Debug, Deserialize)] +struct LMStudioLoadedInstance { + config: Option, +} + +#[derive(Debug, Deserialize)] +struct LMStudioInstanceConfig { + context_length: Option, +} + +// Poll every 2 seconds in production; use a short interval in tests to avoid slowness. +#[cfg(not(test))] +const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_secs(2); +#[cfg(test)] +const DOWNLOAD_POLL_INTERVAL: Duration = Duration::from_millis(10); + +// Allow ~2 hours of polling (at 2s intervals) before giving up. +const MAX_DOWNLOAD_POLL_ATTEMPTS: u32 = 3600; + +struct DownloadStatusResponse { + status: String, + job_id: Option, + downloaded_bytes: Option, + total_size_bytes: Option, +} + +fn parse_download_status(json: &serde_json::Value) -> io::Result { + let status = json["status"] + .as_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing status"))? + .to_string(); + let job_id = json["job_id"].as_str().map(String::from); + let downloaded_bytes = json["downloaded_bytes"].as_u64(); + let total_size_bytes = json["total_size_bytes"].as_u64(); + Ok(DownloadStatusResponse { + status, + job_id, + downloaded_bytes, + total_size_bytes, + }) +} + +fn format_bytes(bytes: u64) -> String { + const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB"]; + let mut size = bytes as f64; + let mut i = 0; + while size >= 1024.0 && i < UNITS.len() - 1 { + size /= 1024.0; + i += 1; + } + if i == 0 { + format!("{size} B") + } else { + format!("{size:.2} {}", UNITS[i]) + } +} + +fn parse_reasoning_effort(option: &str) -> Option { + let normalized = option.trim().to_ascii_lowercase(); + match normalized.as_str() { + "off" | "none" => Some(ReasoningEffort::None), + "on" => Some(ReasoningEffort::Medium), + "minimal" => Some(ReasoningEffort::Minimal), + "low" => Some(ReasoningEffort::Low), + "medium" => Some(ReasoningEffort::Medium), + "high" => Some(ReasoningEffort::High), + "xhigh" => Some(ReasoningEffort::XHigh), + _ => None, + } +} + +fn parse_reasoning_capability( + capability: Option<&LMStudioReasoningCapability>, +) -> (Option, Vec) { + let medium_only = ( + Some(ReasoningEffort::Medium), + vec![ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "medium".to_string(), + }], + ); + let fallback_presets = vec![ + ReasoningEffort::Low, + ReasoningEffort::Medium, + ReasoningEffort::High, + ] + .into_iter() + .map(|effort| ReasoningEffortPreset { + effort, + description: format!("{effort}"), + }) + .collect::>(); + let fallback = (Some(ReasoningEffort::Medium), fallback_presets); + + let Some(capability) = capability else { + return (None, Vec::new()); + }; + + match capability { + LMStudioReasoningCapability::Enabled(true) => medium_only, + LMStudioReasoningCapability::Enabled(false) => (None, Vec::new()), + LMStudioReasoningCapability::Options(options) => { + let mut efforts = Vec::new(); + if let Some(allowed_options) = options.allowed_options.as_ref() { + for option in allowed_options { + if let Some(effort) = parse_reasoning_effort(option) + && !efforts.contains(&effort) + { + efforts.push(effort); + } + } + } + + if efforts.is_empty() { + return fallback; + } + + let default_reasoning_level = options + .default_option + .as_deref() + .and_then(parse_reasoning_effort) + .or_else(|| efforts.first().copied()); + let supported_reasoning_levels = efforts + .into_iter() + .map(|effort| ReasoningEffortPreset { + effort, + description: format!("{effort}"), + }) + .collect(); + (default_reasoning_level, supported_reasoning_levels) + } + } +} + #[cfg(test)] mod tests { #![allow(clippy::expect_used, clippy::unwrap_used)] use super::*; + use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR; + use pretty_assertions::assert_eq; + use wiremock::Mock; + use wiremock::MockServer; + use wiremock::ResponseTemplate; + use wiremock::matchers::method; + use wiremock::matchers::path; #[tokio::test] async fn test_fetch_models_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200).set_body_raw( + ResponseTemplate::new(200).set_body_raw( serde_json::json!({ - "data": [ - {"id": "openai/gpt-oss-20b"}, - ] + "models": [ + {"key": "openai/gpt-oss-20b"}, + ], }) .to_string(), "application/json", @@ -242,19 +651,19 @@ mod tests { #[tokio::test] async fn test_fetch_models_no_data_array() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_no_data_array", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) .respond_with( - wiremock::ResponseTemplate::new(200) + ResponseTemplate::new(200) .set_body_raw(serde_json::json!({}).to_string(), "application/json"), ) .mount(&server) @@ -263,28 +672,23 @@ mod tests { let client = LMStudioClient::from_host_root(server.uri()); let result = client.fetch_models().await; assert!(result.is_err()); - assert!( - result - .unwrap_err() - .to_string() - .contains("No 'data' array in response") - ); + assert!(result.unwrap_err().to_string().contains("JSON parse error")); } #[tokio::test] async fn test_fetch_models_server_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_fetch_models_server_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) - .respond_with(wiremock::ResponseTemplate::new(500)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(500)) .mount(&server) .await; @@ -299,20 +703,290 @@ mod tests { ); } + #[tokio::test] + async fn test_fetch_model_metadata_filters_and_maps_fields() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_model_metadata_filters_and_maps_fields", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "openai/gpt-oss-20b", + "display_name": "GPT-OSS-20B (LM Studio)", + "description": "OSS model", + "type": "llm", + "max_context_length": 100_000, + "capabilities": { + "vision": true, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["low", "medium", "high"], + "default": "low" + } + }, + "loaded_instances": [ + { + "config": { + "context_length": 90_000 + } + } + ] + }, + { + "key": "embed/text", + "display_name": "Embedding", + "type": "embedding", + "max_context_length": 1024 + }, + { + "key": "llm/second", + "type": "llm", + "max_context_length": 4096 + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let models = client.fetch_model_metadata().await.expect("fetch metadata"); + + let expected = vec![ + ModelInfo { + slug: "openai/gpt-oss-20b".to_string(), + display_name: "GPT-OSS-20B (LM Studio)".to_string(), + description: Some("OSS model".to_string()), + default_reasoning_level: Some(ReasoningEffort::Low), + supported_reasoning_levels: vec![ + ReasoningEffortPreset { + effort: ReasoningEffort::Low, + description: "low".to_string(), + }, + ReasoningEffortPreset { + effort: ReasoningEffort::Medium, + description: "medium".to_string(), + }, + ReasoningEffortPreset { + effort: ReasoningEffort::High, + description: "high".to_string(), + }, + ], + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: 0, + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: true, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: Some(ApplyPatchToolType::Function), + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(90_000), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text, InputModality::Image], + prefer_websockets: false, + used_fallback_model_metadata: false, + }, + ModelInfo { + slug: "llm/second".to_string(), + display_name: "llm/second".to_string(), + description: None, + default_reasoning_level: None, + supported_reasoning_levels: Vec::new(), + shell_type: ConfigShellToolType::Default, + visibility: ModelVisibility::List, + supported_in_api: true, + priority: 1, + availability_nux: None, + upgrade: None, + base_instructions: BASE_INSTRUCTIONS.to_string(), + model_messages: None, + supports_reasoning_summaries: false, + default_reasoning_summary: ReasoningSummary::None, + support_verbosity: false, + default_verbosity: None, + apply_patch_tool_type: None, + truncation_policy: TruncationPolicyConfig::bytes(10_000), + supports_parallel_tool_calls: false, + context_window: Some(4096), + auto_compact_token_limit: None, + effective_context_window_percent: 95, + experimental_supported_tools: Vec::new(), + input_modalities: vec![InputModality::Text], + prefer_websockets: false, + used_fallback_model_metadata: false, + }, + ]; + + assert_eq!(models, expected); + } + + #[tokio::test] + async fn test_fetch_model_metadata_reasoning_variants() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_fetch_model_metadata_reasoning_variants", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "lmstudio-community/qwen3-0.6b", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true + } + }, + { + "key": "qwen/qwen3-0.6b", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "on"], + "default": "on" + } + } + }, + { + "key": "microsoft/phi-4-mini-reasoning", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": false, + "reasoning": true + } + }, + { + "key": "nvidia/nemotron-3-super", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "low", "on"], + "default": "on" + } + } + }, + { + "key": "test/missing-default", + "type": "llm", + "capabilities": { + "vision": false, + "trained_for_tool_use": true, + "reasoning": { + "allowed_options": ["off", "on"] + } + } + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let models = client.fetch_model_metadata().await.expect("fetch metadata"); + + let summary = models + .iter() + .map(|model| { + ( + model.slug.clone(), + model.default_reasoning_level, + model + .supported_reasoning_levels + .iter() + .map(|preset| preset.effort) + .collect::>(), + ) + }) + .collect::>(); + + let expected = vec![ + ( + "lmstudio-community/qwen3-0.6b".to_string(), + None, + Vec::::new(), + ), + ( + "qwen/qwen3-0.6b".to_string(), + Some(ReasoningEffort::Medium), + vec![ReasoningEffort::None, ReasoningEffort::Medium], + ), + ( + "microsoft/phi-4-mini-reasoning".to_string(), + Some(ReasoningEffort::Medium), + vec![ReasoningEffort::Medium], + ), + ( + "nvidia/nemotron-3-super".to_string(), + Some(ReasoningEffort::Medium), + vec![ + ReasoningEffort::None, + ReasoningEffort::Low, + ReasoningEffort::Medium, + ], + ), + ( + "test/missing-default".to_string(), + Some(ReasoningEffort::None), + vec![ReasoningEffort::None, ReasoningEffort::Medium], + ), + ]; + + assert_eq!(summary, expected); + } + #[tokio::test] async fn test_check_server_happy_path() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_check_server_happy_path", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) - .respond_with(wiremock::ResponseTemplate::new(200)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200)) .mount(&server) .await; @@ -325,18 +999,18 @@ mod tests { #[tokio::test] async fn test_check_server_error() { - if std::env::var(codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { tracing::info!( "{} is set; skipping test_check_server_error", - codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR ); return; } - let server = wiremock::MockServer::start().await; - wiremock::Mock::given(wiremock::matchers::method("GET")) - .and(wiremock::matchers::path("/models")) - .respond_with(wiremock::ResponseTemplate::new(404)) + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(404)) .mount(&server) .await; @@ -351,39 +1025,362 @@ mod tests { ); } - #[test] - fn test_find_lms() { - let result = LMStudioClient::find_lms(); + #[tokio::test] + async fn test_load_model_happy_path() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_happy_path", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } - match result { - Ok(_) => { - // lms was found in PATH - that's fine - } - Err(e) => { - // Expected error when LM Studio not installed - assert!(e.to_string().contains("LM Studio not found")); - } + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ "models": [] }).to_string(), + "application/json", + )) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(200)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + client + .load_model("openai/gpt-oss-20b") + .await + .expect("load model"); + } + + #[tokio::test] + async fn test_load_model_reuses_loaded_instance() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_reuses_loaded_instance", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": [ + { + "id": "instance-abc123", + "config": { + "context_length": 7000 + } + } + ] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(500)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.load_model("test/test-model").await; + assert!(result.is_ok()); } - #[test] - fn test_find_lms_with_mock_home() { - // Test fallback path construction without touching env vars - #[cfg(unix)] - { - let result = LMStudioClient::find_lms_with_home_dir(Some("/test/home")); - if let Err(e) = result { - assert!(e.to_string().contains("LM Studio not found")); - } + #[tokio::test] + async fn test_is_model_loaded_true() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_is_model_loaded_true", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } - #[cfg(windows)] - { - let result = LMStudioClient::find_lms_with_home_dir(Some("C:\\test\\home")); - if let Err(e) = result { - assert!(e.to_string().contains("LM Studio not found")); - } + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": [ + { + "id": "instance-abc123", + "config": { + "context_length": 7000 + } + } + ] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + assert!(client.is_model_loaded("test/test-model").await); + } + + #[tokio::test] + async fn test_is_model_loaded_false() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_is_model_loaded_false", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "models": [ + { + "key": "test/test-model", + "loaded_instances": [] + } + ] + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + assert!(!client.is_model_loaded("test/test-model").await); + } + + #[tokio::test] + async fn test_load_model_error() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_load_model_error", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/api/v1/models")) + .respond_with(ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ "models": [] }).to_string(), + "application/json", + )) + .mount(&server) + .await; + Mock::given(method("POST")) + .and(path("/api/v1/models/load")) + .respond_with(ResponseTemplate::new(500)) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + let result = client.load_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Failed to load model: 500") + ); + } + + #[tokio::test] + async fn test_download_model_happy_path() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_happy_path", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "downloading" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api/v1/models/download/status/job-1")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "completed" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + client + .download_model("openai/gpt-oss-20b") + .await + .expect("download model"); + } + + #[tokio::test] + async fn test_download_model_error() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_error", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "downloading" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + Mock::given(method("GET")) + .and(path("/api/v1/models/download/status/job-1")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "job_id": "job-1", + "status": "failed" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(format!("{uri}/v1", uri = server.uri())); + let result = client.download_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Model download failed") + ); + } + + #[tokio::test] + async fn test_download_model_already_downloaded() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_already_downloaded", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "status": "already_downloaded" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + client + .download_model("openai/gpt-oss-20b") + .await + .expect("already_downloaded should succeed"); + } + + #[tokio::test] + async fn test_download_model_paused() { + if std::env::var(CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR).is_ok() { + tracing::info!( + "{} is set; skipping test_download_model_paused", + CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR + ); + return; + } + + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/api/v1/models/download")) + .respond_with( + ResponseTemplate::new(200).set_body_raw( + serde_json::json!({ + "status": "paused" + }) + .to_string(), + "application/json", + ), + ) + .mount(&server) + .await; + + let client = LMStudioClient::from_host_root(server.uri()); + let result = client.download_model("openai/gpt-oss-20b").await; + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .to_string() + .contains("Model download paused") + ); } #[test] diff --git a/codex-rs/lmstudio/src/lib.rs b/codex-rs/lmstudio/src/lib.rs index fd4f82a728a..34803522b8d 100644 --- a/codex-rs/lmstudio/src/lib.rs +++ b/codex-rs/lmstudio/src/lib.rs @@ -2,6 +2,7 @@ mod client; pub use client::LMStudioClient; use codex_core::config::Config; +pub use codex_protocol::openai_models::ModelsResponse; /// Default OSS model to use when `--oss` is passed without an explicit `-m`. pub const DEFAULT_OSS_MODEL: &str = "openai/gpt-oss-20b"; @@ -31,16 +32,9 @@ pub async fn ensure_oss_ready(config: &Config) -> std::io::Result<()> { } } - // Load the model in the background - tokio::spawn({ - let client = lmstudio_client.clone(); - let model = model.to_string(); - async move { - if let Err(e) = client.load_model(&model).await { - tracing::warn!("Failed to load model {}: {}", model, e); - } - } - }); + if let Err(err) = lmstudio_client.load_model(model).await { + tracing::warn!("Failed to load model {model}: {err}"); + } Ok(()) } diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 8dfdbcd8717..219174df565 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -44,6 +44,7 @@ use codex_protocol::protocol::RolloutLine; use codex_state::log_db; use codex_utils_absolute_path::AbsolutePathBuf; use codex_utils_oss::ensure_oss_provider_ready; +use codex_utils_oss::fetch_oss_model_catalog; use codex_utils_oss::get_default_model_for_oss_provider; use cwd_prompt::CwdPromptAction; use cwd_prompt::CwdPromptOutcome; @@ -591,7 +592,7 @@ async fn run_ratatui_app( should_show_onboarding(login_status, &initial_config, should_show_trust_screen_flag); let mut trust_decision_was_made = false; - let config = if should_show_onboarding { + let mut config = if should_show_onboarding { let show_login_screen = should_show_login_screen(login_status, &initial_config); let onboarding_result = run_onboarding_app( OnboardingScreenArgs { @@ -644,6 +645,19 @@ async fn run_ratatui_app( initial_config }; + if config.model_catalog.is_none() { + let provider_id = config.model_provider_id.as_str(); + match fetch_oss_model_catalog(provider_id, &config).await { + Ok(Some(catalog)) => { + config.model_catalog = Some(catalog); + } + Ok(None) => {} + Err(err) => { + tracing::warn!("Failed to fetch OSS model catalog for {provider_id}: {err}"); + } + } + } + let mut missing_session_exit = |id_str: &str, action: &str| { error!("Error finding conversation path: {id_str}"); restore(); diff --git a/codex-rs/utils/oss/src/lib.rs b/codex-rs/utils/oss/src/lib.rs index a44a6a7d326..32b4028fe46 100644 --- a/codex-rs/utils/oss/src/lib.rs +++ b/codex-rs/utils/oss/src/lib.rs @@ -3,6 +3,7 @@ use codex_core::LMSTUDIO_OSS_PROVIDER_ID; use codex_core::OLLAMA_OSS_PROVIDER_ID; use codex_core::config::Config; +use codex_lmstudio::ModelsResponse; /// Returns the default model for a given OSS provider. pub fn get_default_model_for_oss_provider(provider_id: &str) -> Option<&'static str> { @@ -37,6 +38,26 @@ pub async fn ensure_oss_provider_ready( Ok(()) } +/// Fetch a provider-specific model catalog, if supported. +pub async fn fetch_oss_model_catalog( + provider_id: &str, + config: &Config, +) -> Result, std::io::Error> { + match provider_id { + LMSTUDIO_OSS_PROVIDER_ID => { + let client = codex_lmstudio::LMStudioClient::try_from_provider(config).await?; + let models = client.fetch_model_metadata().await?; + if models.is_empty() { + Ok(None) + } else { + Ok(Some(ModelsResponse { models })) + } + } + OLLAMA_OSS_PROVIDER_ID => Ok(None), + _ => Ok(None), + } +} + #[cfg(test)] mod tests { use super::*;