From 1ed5655cd450bc0f9a044b76a3db1573594ca131 Mon Sep 17 00:00:00 2001
From: leafx54
Date: Thu, 28 Aug 2025 01:17:22 -0400
Subject: [PATCH] Enhance local model handling: add provider detection, fetch local models, and improve JSON extraction in responses

---
 crates/agentic-core/src/cloud.rs        |  19 +-
 crates/agentic-core/src/lib.rs          |  42 ++++
 crates/agentic-core/src/models.rs       | 312 ++++++++++++++++++++++--
 crates/agentic-core/src/orchestrator.rs |  21 +-
 crates/agentic-tui/src/ui/app.rs        |   8 +-
 5 files changed, 372 insertions(+), 30 deletions(-)

diff --git a/crates/agentic-core/src/cloud.rs b/crates/agentic-core/src/cloud.rs
index 1e605b2..22b3a77 100644
--- a/crates/agentic-core/src/cloud.rs
+++ b/crates/agentic-core/src/cloud.rs
@@ -125,8 +125,25 @@ pub async fn call_cloud_model(
         .map(|choice| &choice.message.content)
         .ok_or(CloudError::ParseError)?;
 
+    // Extract JSON from markdown code blocks if present
+    let clean_content = if message_content.contains("```json") {
+        // Extract content between ```json and ```
+        if let Some(json_start) = message_content.find("```json") {
+            let after_start = &message_content[json_start + 7..]; // Skip "```json"
+            if let Some(json_end) = after_start.find("```") {
+                after_start[..json_end].trim()
+            } else {
+                message_content
+            }
+        } else {
+            message_content
+        }
+    } else {
+        message_content
+    };
+
     let atomic_note: AtomicNote =
-        serde_json::from_str(message_content).map_err(|_| CloudError::ParseError)?;
+        serde_json::from_str(clean_content).map_err(|_| CloudError::ParseError)?;
 
     Ok(atomic_note)
 }
diff --git a/crates/agentic-core/src/lib.rs b/crates/agentic-core/src/lib.rs
index 44f712a..ce242fa 100644
--- a/crates/agentic-core/src/lib.rs
+++ b/crates/agentic-core/src/lib.rs
@@ -295,6 +295,48 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_provider_detection() {
+        use crate::models::{LocalProvider, ModelValidator};
+
+        println!("🧪 Testing provider detection system...");
+
+        let validator = ModelValidator::new();
+
+        // Test provider detection for common endpoints
+        let ollama_provider = validator.detect_provider_type("localhost:11434").await;
+        let lmstudio_provider = validator.detect_provider_type("localhost:1234").await;
+        let custom_provider = validator.detect_provider_type("localhost:8080").await;
+
+        println!("Provider detection results:");
+        println!("  Ollama (11434): {:?}", ollama_provider);
+        println!("  LM Studio (1234): {:?}", lmstudio_provider);
+        println!("  Custom (8080): {:?}", custom_provider);
+
+        // Test fetching local models if Ollama is available
+        if ollama_provider == LocalProvider::Ollama {
+            println!("Testing local model fetching...");
+            match validator.fetch_local_models("localhost:11434").await {
+                Ok(models) => {
+                    println!("✅ Successfully fetched {} local models", models.len());
+                    for model in models.iter().take(3) {
+                        println!("  - {} ({:?}, {})", model.name, model.provider, model.size);
+                    }
+                }
+                Err(e) => {
+                    println!(
+                        "⚠️ Could not fetch local models (Ollama not running): {}",
+                        e
+                    );
+                }
+            }
+        } else {
+            println!("⚠️ Ollama not detected, skipping local model test");
+        }
+
+        println!("✅ Provider detection test completed!");
+    }
+
     #[test]
     fn test_api_key_truncation() {
         // Test the API key display formatting (simulating the settings modal function)
diff --git a/crates/agentic-core/src/models.rs b/crates/agentic-core/src/models.rs
index 319b957..f433e11 100644
--- a/crates/agentic-core/src/models.rs
+++ b/crates/agentic-core/src/models.rs
@@ -17,6 +17,22 @@ pub struct OllamaModel {
     pub modified: String,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LocalModel {
+    pub name: String,
+    pub id: String,
+    pub provider: LocalProvider,
+    pub size: String,
+    pub modified: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub enum LocalProvider {
+    Ollama,
+    LMStudio,
+    OpenAI,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct OpenRouterModel {
     pub id: String,
@@ -64,6 +80,20 @@ struct ModelPricingRaw {
     completion: String,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+struct OpenAIListResponse {
+    data: Vec<OpenAIModelRaw>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct OpenAIModelRaw {
+    id: String,
+    #[serde(default)]
+    name: Option<String>,
+    #[serde(default)]
+    created: Option<u64>,
+}
+
 pub struct ModelValidator {
     client: Client,
 }
@@ -78,6 +108,83 @@ impl ModelValidator {
         Self { client }
     }
 
+    pub async fn detect_provider_type(&self, endpoint: &str) -> LocalProvider {
+        // Try OpenAI/LM Studio API first for port 1234
+        if endpoint.to_lowercase().contains("1234")
+            && self.test_openai_endpoint(endpoint).await.is_ok()
+        {
+            return LocalProvider::LMStudio;
+        }
+
+        // Try Ollama API
+        if self.test_ollama_endpoint(endpoint).await.is_ok() {
+            return LocalProvider::Ollama;
+        }
+
+        // Try generic OpenAI API
+        if self.test_openai_endpoint(endpoint).await.is_ok() {
+            return LocalProvider::OpenAI;
+        }
+
+        // Default to Ollama if all detection fails
+        LocalProvider::Ollama
+    }
+
+    async fn test_ollama_endpoint(&self, endpoint: &str) -> Result<()> {
+        let url = if endpoint.starts_with("http") {
+            format!("{}/api/tags", endpoint)
+        } else {
+            format!("http://{}/api/tags", endpoint)
+        };
+
+        let response = self.client.get(&url).send().await?;
+        if response.status().is_success() {
+            Ok(())
+        } else {
+            Err(anyhow::anyhow!("Ollama endpoint not accessible"))
+        }
+    }
+
+    async fn test_openai_endpoint(&self, endpoint: &str) -> Result<()> {
+        let normalized_endpoint = endpoint.to_lowercase();
+        let url = if normalized_endpoint.starts_with("http") {
+            format!("{}/v1/models", normalized_endpoint)
+        } else {
+            format!("http://{}/v1/models", normalized_endpoint)
+        };
+
+        let response = self.client.get(&url).send().await?;
+        if response.status().is_success() {
+            Ok(())
+        } else {
+            Err(anyhow::anyhow!("OpenAI endpoint not accessible"))
+        }
+    }
+
+    pub async fn fetch_local_models(&self, endpoint: &str) -> Result<Vec<LocalModel>> {
+        let provider = self.detect_provider_type(endpoint).await;
+
+        match provider {
+            LocalProvider::Ollama => {
+                let ollama_models = self.fetch_ollama_models(endpoint).await?;
+                let local_models = ollama_models
+                    .into_iter()
+                    .map(|model| LocalModel {
+                        name: model.name.clone(),
+                        id: model.name,
+                        provider: LocalProvider::Ollama,
+                        size: model.size,
+                        modified: model.modified,
+                    })
+                    .collect();
+                Ok(local_models)
+            }
+            LocalProvider::LMStudio | LocalProvider::OpenAI => {
+                self.fetch_openai_models(endpoint).await
+            }
+        }
+    }
+
     pub async fn fetch_ollama_models(&self, endpoint: &str) -> Result<Vec<OllamaModel>> {
         let url = if endpoint.starts_with("http") {
             format!("{}/api/tags", endpoint)
@@ -156,41 +263,116 @@ impl ModelValidator {
         Ok(models)
     }
 
-    pub async fn validate_local_endpoint(&self, endpoint: &str, model: &str) -> Result<()> {
+    pub async fn fetch_openai_models(&self, endpoint: &str) -> Result<Vec<LocalModel>> {
         let url = if endpoint.starts_with("http") {
-            format!("{}/api/tags", endpoint)
+            format!("{}/v1/models", endpoint)
         } else {
-            format!("http://{}/api/tags", endpoint)
+            format!("http://{}/v1/models", endpoint)
        };
 
         let response = self.client.get(&url).send().await?;
         if !response.status().is_success() {
-            return Err(anyhow::anyhow!("Local endpoint not accessible"));
+            return Err(anyhow::anyhow!("OpenAI/LM Studio endpoint not accessible"));
         }
 
-        let models: Value = response.json().await?;
+        let openai_response: OpenAIListResponse = response.json().await?;
 
-        if let Some(models_array) = models.get("models").and_then(|m| m.as_array()) {
-            let model_exists = models_array.iter().any(|m| {
-                m.get("name")
-                    .and_then(|name| name.as_str())
-                    .map(|name| name == model)
-                    .unwrap_or(false)
-            });
+        let provider = if endpoint.contains("1234") {
+            LocalProvider::LMStudio
+        } else {
+            LocalProvider::OpenAI
+        };
 
-            if model_exists {
-                Ok(())
-            } else {
-                Err(anyhow::anyhow!(
-                    "Model '{}' not found on local endpoint",
-                    model
-                ))
+        let models = openai_response
+            .data
+            .into_iter()
+            .map(|raw| LocalModel {
+                name: raw.name.unwrap_or_else(|| raw.id.clone()),
+                id: raw.id,
+                provider: provider.clone(),
+                size: "Unknown".to_string(),
+                modified: "recently".to_string(),
+            })
+            .collect();
+
+        Ok(models)
+    }
+
+    pub async fn validate_local_endpoint(&self, endpoint: &str, model: &str) -> Result<()> {
+        let provider = self.detect_provider_type(endpoint).await;
+
+        match provider {
+            LocalProvider::Ollama => {
+                let url = if endpoint.starts_with("http") {
+                    format!("{}/api/tags", endpoint)
+                } else {
+                    format!("http://{}/api/tags", endpoint)
+                };
+
+                let response = self.client.get(&url).send().await?;
+                if !response.status().is_success() {
+                    return Err(anyhow::anyhow!("Local endpoint not accessible"));
+                }
+
+                let models: Value = response.json().await?;
+                if let Some(models_array) = models.get("models").and_then(|m| m.as_array()) {
+                    let model_exists = models_array.iter().any(|m| {
+                        m.get("name")
+                            .and_then(|name| name.as_str())
+                            .map(|name| name == model)
+                            .unwrap_or(false)
+                    });
+
+                    if model_exists {
+                        Ok(())
+                    } else {
+                        Err(anyhow::anyhow!(
+                            "Model '{}' not found on local endpoint",
+                            model
+                        ))
+                    }
+                } else {
+                    Err(anyhow::anyhow!(
+                        "Invalid response format from local endpoint"
+                    ))
+                }
+            }
+            LocalProvider::LMStudio | LocalProvider::OpenAI => {
+                let url = if endpoint.starts_with("http") {
+                    format!("{}/v1/models", endpoint)
+                } else {
+                    format!("http://{}/v1/models", endpoint)
+                };
+
+                let response = self.client.get(&url).send().await?;
+                if !response.status().is_success() {
+                    return Err(anyhow::anyhow!("Local endpoint not accessible"));
+                }
+
+                let models: Value = response.json().await?;
+                if let Some(models_array) = models.get("data").and_then(|m| m.as_array()) {
+                    let model_exists = models_array.iter().any(|m| {
+                        m.get("id")
+                            .and_then(|id| id.as_str())
+                            .map(|id| id == model)
+                            .unwrap_or(false)
+                    });
+
+                    if model_exists {
+                        Ok(())
+                    } else {
+                        Err(anyhow::anyhow!(
+                            "Model '{}' not found on local endpoint",
+                            model
+                        ))
+                    }
+                } else {
+                    Err(anyhow::anyhow!(
+                        "Invalid response format from local endpoint"
+                    ))
+                }
             }
-        } else {
-            Err(anyhow::anyhow!(
-                "Invalid response format from local endpoint"
-            ))
         }
     }
 
@@ -301,6 +483,22 @@ pub async fn call_local_model(
     endpoint: &str,
     model: &str,
     prompt: &str,
+) -> Result<String> {
+    let validator = ModelValidator::new();
+    let provider = validator.detect_provider_type(endpoint).await;
+
+    match provider {
+        LocalProvider::Ollama => call_ollama_model(endpoint, model, prompt).await,
+        LocalProvider::LMStudio | LocalProvider::OpenAI => {
+            call_openai_model(endpoint, model, prompt).await
+        }
+    }
+}
+
+pub async fn call_ollama_model(
+    endpoint: &str,
+    model: &str,
+    prompt: &str,
 ) -> Result<String> {
     let client = Client::new();
     let url = if endpoint.starts_with("http") {
@@ -328,6 +526,74 @@ pub async fn call_local_model(
     }
 }
 
+#[derive(Serialize)]
+struct OpenAIGenerationRequest<'a> {
+    model: &'a str,
+    messages: Vec<serde_json::Value>,
+    max_tokens: u32,
+    temperature: f32,
+}
+
+#[derive(Deserialize)]
+struct OpenAIGenerationResponse {
+    choices: Vec<OpenAIChoice>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIChoice {
+    message: OpenAIMessage,
+}
+
+#[derive(Deserialize)]
+struct OpenAIMessage {
+    content: String,
+}
+
+pub async fn call_openai_model(
+    endpoint: &str,
+    model: &str,
+    prompt: &str,
+) -> Result<String> {
+    let client = Client::new();
+    let url = if endpoint.starts_with("http") {
+        format!("{}/v1/chat/completions", endpoint)
+    } else {
+        format!("http://{}/v1/chat/completions", endpoint)
+    };
+
+    let payload = OpenAIGenerationRequest {
+        model,
+        messages: vec![serde_json::json!({
+            "role": "user",
+            "content": prompt
+        })],
+        max_tokens: 2000,
+        temperature: 0.7,
+    };
+
+    let response = client.post(&url).json(&payload).send().await?;
+
+    if response.status().is_success() {
+        let gen_response: OpenAIGenerationResponse = response.json().await?;
+        if let Some(choice) = gen_response.choices.first() {
+            Ok(choice.message.content.clone())
+        } else {
+            Err(anyhow::anyhow!("No response choices from OpenAI model"))
+        }
+    } else {
+        let status = response.status();
+        let error_text = response
+            .text()
+            .await
+            .unwrap_or_else(|_| "Unknown error".to_string());
+        Err(anyhow::anyhow!(
+            "Failed to get response from OpenAI model. Status: {}. Error: {}",
+            status,
+            error_text
+        ))
+    }
+}
+
 impl Default for ModelValidator {
     fn default() -> Self {
         Self::new()
diff --git a/crates/agentic-core/src/orchestrator.rs b/crates/agentic-core/src/orchestrator.rs
index b34cca3..ef2e5a4 100644
--- a/crates/agentic-core/src/orchestrator.rs
+++ b/crates/agentic-core/src/orchestrator.rs
@@ -61,9 +61,26 @@ pub async fn generate_proposals(
     // Debug: Write the response to a file so we can see what came back
     std::fs::write("/tmp/debug_response.txt", &response_str).ok();
 
+    // Extract JSON from markdown code blocks if present
+    let clean_response = if response_str.contains("```json") {
+        // Extract content between ```json and ```
+        if let Some(json_start) = response_str.find("```json") {
+            let after_start = &response_str[json_start + 7..]; // Skip "```json"
+            if let Some(json_end) = after_start.find("```") {
+                after_start[..json_end].trim()
+            } else {
+                &response_str
+            }
+        } else {
+            &response_str
+        }
+    } else {
+        &response_str
+    };
+
     // Attempt to find the start of the JSON object
-    if let Some(json_start) = response_str.find("{") {
-        let json_str = &response_str[json_start..];
+    if let Some(json_start) = clean_response.find("{") {
+        let json_str = &clean_response[json_start..];
         match serde_json::from_str::(json_str) {
             Ok(response) => {
                 let proposals = response
diff --git a/crates/agentic-tui/src/ui/app.rs b/crates/agentic-tui/src/ui/app.rs
index a351a64..de9af0f 100644
--- a/crates/agentic-tui/src/ui/app.rs
+++ b/crates/agentic-tui/src/ui/app.rs
@@ -7,7 +7,7 @@ use super::{
 };
 use agentic_core::{
     cloud::{self, CloudError},
-    models::{AtomicNote, ModelValidator, OllamaModel, OpenRouterModel},
+    models::{AtomicNote, LocalModel, ModelValidator, OpenRouterModel},
     orchestrator,
     settings::{Settings, ValidationError},
     theme::{Element, Theme},
@@ -56,7 +56,7 @@ pub enum AgentStatus {
 pub enum ValidationMessage {
     LocalValidationComplete(Result<(), ValidationError>),
     CloudValidationComplete(Result<(), ValidationError>),
-    LocalModelsLoaded(Result<Vec<OllamaModel>, anyhow::Error>),
+    LocalModelsLoaded(Result<Vec<LocalModel>, anyhow::Error>),
     CloudModelsLoaded(Result<Vec<OpenRouterModel>, anyhow::Error>),
 }
 
@@ -149,7 +149,7 @@ pub struct App {
     agent_rx: mpsc::UnboundedReceiver<ValidationMessage>,
     agent_tx: mpsc::UnboundedSender<ValidationMessage>,
     edit_buffer: String,
-    available_local_models: Vec<OllamaModel>,
+    available_local_models: Vec<LocalModel>,
     available_cloud_models: Vec<OpenRouterModel>,
     selected_model_index: usize,
     current_page: usize,
@@ -1523,7 +1523,7 @@ impl App {
         let endpoint = self.settings.endpoint.clone();
         tokio::spawn(async move {
            let validator = ModelValidator::new();
-            let result = validator.fetch_ollama_models(&endpoint).await;
+            let result = validator.fetch_local_models(&endpoint).await;
            let _ = tx.send(ValidationMessage::LocalModelsLoaded(result));
        });
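
Note (not part of the patch): the fenced-JSON extraction is now duplicated verbatim in
cloud.rs and orchestrator.rs. A minimal sketch of a shared helper, assuming a hypothetical
name `extract_fenced_json` and free choice of module:

/// Return the body of a ```json ... ``` block if one is present,
/// otherwise return the input unchanged (hypothetical helper, not in this patch).
fn extract_fenced_json(raw: &str) -> &str {
    if let Some(start) = raw.find("```json") {
        let after = &raw[start + 7..]; // skip "```json"
        if let Some(end) = after.find("```") {
            return after[..end].trim();
        }
    }
    raw
}

Both call sites could then feed the returned slice straight into serde_json::from_str,
keeping the fallback-to-raw-input behaviour the patch currently open-codes twice.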
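A minimal usage sketch of the new local-model path, assuming a tokio runtime; the endpoint
and the "llama3" model name are placeholders, while the 11434 (Ollama) and 1234 (LM Studio)
defaults are the ports this patch probes:

use agentic_core::models::{call_local_model, ModelValidator};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let validator = ModelValidator::new();
    let endpoint = "localhost:11434";

    // Detection falls back to Ollama when neither API answers.
    let provider = validator.detect_provider_type(endpoint).await;
    println!("detected provider: {:?}", provider);

    // Lists models via /api/tags (Ollama) or /v1/models (LM Studio / OpenAI-compatible).
    for model in validator.fetch_local_models(endpoint).await? {
        println!("{} ({:?}, {})", model.name, model.provider, model.size);
    }

    // Dispatches to call_ollama_model or call_openai_model based on detection.
    let reply = call_local_model(endpoint, "llama3", "Say hi in one word").await?;
    println!("{}", reply);
    Ok(())
}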