diff --git a/crates/forge_repo/proto/forge.proto b/crates/forge_repo/proto/forge.proto index 5ea339a85d..44aee800e6 100644 --- a/crates/forge_repo/proto/forge.proto +++ b/crates/forge_repo/proto/forge.proto @@ -47,6 +47,9 @@ service ForgeService { // Searches for needle in haystack using fuzzy search rpc FuzzySearch(FuzzySearchRequest) returns (FuzzySearchResponse); + + // Lists all available LLM providers and their configurations + rpc ListProviders(ListProvidersRequest) returns (ListProvidersResponse); } // Node types @@ -360,3 +363,67 @@ message SearchMatch { uint32 start_line = 1; uint32 end_line = 2; } + +// Provider-related messages + +message ListProvidersRequest {} + +message ListProvidersResponse { + repeated Provider providers = 1; +} + +message Provider { + string id = 1; + optional string api_key_vars = 2; + repeated string url_param_vars = 3; + optional string response_type = 4; + string url = 5; + ProviderModels models = 6; + repeated AuthMethod auth_methods = 7; + map custom_headers = 8; + optional string provider_type = 9; +} + +message ProviderModels { + oneof kind { + string url = 1; + ModelList model_list = 2; + } +} + +message ModelList { + repeated Model models = 1; +} + +message Model { + string id = 1; + optional string name = 2; + optional string description = 3; + optional uint64 context_length = 4; + optional bool tools_supported = 5; + optional bool supports_parallel_tool_calls = 6; + optional bool supports_reasoning = 7; + repeated string input_modalities = 8; +} + +message AuthMethod { + oneof method { + string api_key = 1; + OAuthConfig oauth_device = 2; + OAuthConfig oauth_code = 3; + string google_adc = 4; + OAuthConfig codex_device = 5; + } +} + +message OAuthConfig { + string auth_url = 1; + string token_url = 2; + string client_id = 3; + repeated string scopes = 4; + optional string redirect_uri = 5; + bool use_pkce = 6; + optional string token_refresh_url = 7; + map custom_headers = 8; + map extra_auth_params = 9; +} diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index dde20149e0..29b44b8ec6 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -171,7 +171,7 @@ impl +impl ProviderRepository for ForgeRepo { async fn get_all_providers(&self) -> anyhow::Result> { diff --git a/crates/forge_repo/src/provider/provider_repo.rs b/crates/forge_repo/src/provider/provider_repo.rs index 57863a0ba4..4e7dfe6624 100644 --- a/crates/forge_repo/src/provider/provider_repo.rs +++ b/crates/forge_repo/src/provider/provider_repo.rs @@ -1,8 +1,9 @@ -use std::sync::{Arc, LazyLock}; +use std::sync::Arc; +use anyhow::Context; use bytes::Bytes; use forge_app::domain::{ProviderId, ProviderResponse}; -use forge_app::{EnvironmentInfra, FileReaderInfra, FileWriterInfra, HttpInfra}; +use forge_app::{EnvironmentInfra, FileReaderInfra, FileWriterInfra, GrpcInfra, HttpInfra}; use forge_domain::{ AnyProvider, ApiKey, AuthCredential, AuthDetails, Error, MigrationResult, Provider, ProviderRepository, ProviderType, URLParam, URLParamValue, @@ -10,6 +11,9 @@ use forge_domain::{ use merge::Merge; use serde::Deserialize; +use crate::proto_generated::ListProvidersRequest; +use crate::proto_generated::forge_service_client::ForgeServiceClient; + /// Represents the source of models for a provider #[derive(Debug, Clone, Deserialize)] #[serde(untagged)] @@ -101,15 +105,120 @@ impl From<&ProviderConfig> for forge_domain::ProviderTemplate { } } -static PROVIDER_CONFIGS: LazyLock> = LazyLock::new(|| { - let json_str = include_str!("provider.json"); - serde_json::from_str(json_str) - .map_err(|e| anyhow::anyhow!("Failed to parse embedded provider configs: {e}")) - .unwrap() -}); +impl TryFrom for forge_domain::OAuthConfig { + type Error = anyhow::Error; + + fn try_from(proto: crate::proto_generated::OAuthConfig) -> Result { + Ok(Self { + auth_url: proto.auth_url.parse().context("Invalid auth_url")?, + token_url: proto.token_url.parse().context("Invalid token_url")?, + client_id: forge_domain::ClientId::from(proto.client_id), + scopes: proto.scopes, + redirect_uri: proto.redirect_uri, + use_pkce: proto.use_pkce, + token_refresh_url: proto + .token_refresh_url + .map(|u| u.parse()) + .transpose() + .context("Invalid token_refresh_url")?, + custom_headers: if proto.custom_headers.is_empty() { + None + } else { + Some(proto.custom_headers) + }, + extra_auth_params: if proto.extra_auth_params.is_empty() { + None + } else { + Some(proto.extra_auth_params) + }, + }) + } +} -fn get_provider_configs() -> &'static Vec { - &PROVIDER_CONFIGS +impl TryFrom for forge_domain::AuthMethod { + type Error = anyhow::Error; + + fn try_from(proto: crate::proto_generated::AuthMethod) -> Result { + use crate::proto_generated::auth_method::Method; + let method = proto.method.context("AuthMethod has no method set")?; + match method { + Method::ApiKey(_) => Ok(Self::ApiKey), + Method::OauthDevice(cfg) => Ok(Self::OAuthDevice(cfg.try_into()?)), + Method::OauthCode(cfg) => Ok(Self::OAuthCode(cfg.try_into()?)), + Method::GoogleAdc(_) => Ok(Self::GoogleAdc), + Method::CodexDevice(cfg) => Ok(Self::CodexDevice(cfg.try_into()?)), + } + } +} + +impl TryFrom for ProviderConfig { + type Error = anyhow::Error; + + fn try_from(proto: crate::proto_generated::Provider) -> Result { + use crate::proto_generated::provider_models; + + let models = proto.models.and_then(|m| { + m.kind.map(|kind| match kind { + provider_models::Kind::Url(url) => Models::Url(url), + provider_models::Kind::ModelList(list) => { + let domain_models = list + .models + .into_iter() + .map(|m| forge_app::domain::Model { + id: forge_app::domain::ModelId::new(m.id), + name: m.name, + description: m.description, + context_length: m.context_length, + tools_supported: m.tools_supported, + supports_parallel_tool_calls: m.supports_parallel_tool_calls, + supports_reasoning: m.supports_reasoning, + input_modalities: m + .input_modalities + .into_iter() + .filter_map(|s| s.parse::().ok()) + .collect(), + }) + .collect(); + Models::Hardcoded(domain_models) + } + }) + }); + + let response_type = proto + .response_type + .map(|s| serde_json::from_value(serde_json::Value::String(s))) + .transpose() + .context("Invalid response_type")?; + + let provider_type = proto + .provider_type + .map(|s| s.parse::()) + .transpose() + .map_err(|e| anyhow::anyhow!("Invalid provider_type: {e}"))? + .unwrap_or_default(); + + let auth_methods = proto + .auth_methods + .into_iter() + .map(TryFrom::try_from) + .collect::, _>>()?; + + Ok(ProviderConfig { + id: ProviderId::from(proto.id), + provider_type, + api_key_vars: proto.api_key_vars, + url_param_vars: proto.url_param_vars, + response_type, + url: proto.url, + models, + auth_methods, + custom_headers: if proto.custom_headers.is_empty() { + None + } else { + Some(proto.custom_headers) + }, + }) + } } pub struct ForgeProviderRepository { @@ -122,9 +231,25 @@ impl ForgeProviderRepository { } } -impl +impl ForgeProviderRepository { + /// Fetches provider configurations from the remote gRPC server. + async fn get_provider_configs(&self) -> anyhow::Result> { + let channel = self.infra.channel(); + let mut client = ForgeServiceClient::new(channel); + let response: crate::proto_generated::ListProvidersResponse = client + .list_providers(tonic::Request::new(ListProvidersRequest {})) + .await + .context("Failed to call list_providers gRPC")? + .into_inner(); + response + .providers + .into_iter() + .map(TryFrom::try_from) + .collect() + } + async fn get_custom_provider_configs(&self) -> anyhow::Result> { let environment = self.infra.get_environment(); let provider_json_path = environment.base_path.join("provider.json"); @@ -134,8 +259,8 @@ impl Ok(configs) } - async fn get_providers(&self) -> Vec { - let configs = self.get_merged_configs().await; + async fn get_providers(&self) -> anyhow::Result> { + let configs = self.get_merged_configs().await?; let mut providers: Vec = Vec::new(); for config in configs { @@ -162,7 +287,7 @@ impl // ordering providers.sort_by_key(|a| a.id()); - providers + Ok(providers) } /// Migrates environment variable-based credentials to file-based @@ -178,7 +303,7 @@ impl let mut credentials = Vec::new(); let mut migrated_providers = Vec::new(); - let configs = self.get_merged_configs().await; + let configs = self.get_merged_configs().await?; let has_openai_url = self.infra.get_env_var("OPENAI_URL").is_some(); let has_anthropic_url = self.infra.get_env_var("ANTHROPIC_URL").is_some(); @@ -374,7 +499,7 @@ impl // Look up provider from cached providers - return configured template providers self.get_providers() - .await + .await? .iter() .find_map(|p| match p { AnyProvider::Template(tp) if tp.id == id && tp.credential.is_some() => { @@ -385,15 +510,15 @@ impl .ok_or_else(|| Error::provider_not_available(id).into()) } - /// Returns merged provider configs (embedded + custom) - async fn get_merged_configs(&self) -> Vec { - let mut configs = ProviderConfigs(get_provider_configs().clone()); - // Merge custom configs into embedded configs + /// Returns merged provider configs (gRPC + custom) + async fn get_merged_configs(&self) -> anyhow::Result> { + let mut configs = ProviderConfigs(self.get_provider_configs().await?); + // Merge custom configs into base configs configs.merge(ProviderConfigs( self.get_custom_provider_configs().await.unwrap_or_default(), )); - configs.0 + Ok(configs.0) } async fn read_credentials(&self) -> Vec { @@ -416,11 +541,11 @@ impl } #[async_trait::async_trait] -impl ProviderRepository - for ForgeProviderRepository +impl + ProviderRepository for ForgeProviderRepository { async fn get_all_providers(&self) -> anyhow::Result> { - Ok(self.get_providers().await) + self.get_providers().await } async fn get_provider(&self, id: ProviderId) -> anyhow::Result { @@ -460,168 +585,6 @@ impl } } -#[cfg(test)] -mod tests { - use forge_app::domain::{AuthMethod, ProviderResponse}; - use pretty_assertions::assert_eq; - - use super::*; - - #[test] - fn test_load_provider_configs() { - let configs = get_provider_configs(); - assert!(!configs.is_empty()); - - // Test that OpenRouter config is loaded correctly - let openrouter_config = configs - .iter() - .find(|c| c.id == ProviderId::OPEN_ROUTER) - .unwrap(); - assert_eq!( - openrouter_config.api_key_vars, - Some("OPENROUTER_API_KEY".to_string()) - ); - assert_eq!(openrouter_config.url_param_vars, Vec::::new()); - assert_eq!( - openrouter_config.response_type, - Some(ProviderResponse::OpenAI) - ); - assert_eq!( - openrouter_config.url.as_str(), - "https://openrouter.ai/api/v1/chat/completions" - ); - } - - #[test] - fn test_vertex_ai_config() { - let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::VERTEX_AI) - .unwrap(); - assert_eq!(config.id, ProviderId::VERTEX_AI); - assert_eq!( - config.api_key_vars, - Some("VERTEX_AI_AUTH_TOKEN".to_string()) - ); - assert_eq!( - config.url_param_vars, - vec!["PROJECT_ID".to_string(), "LOCATION".to_string()] - ); - assert_eq!(config.response_type, Some(ProviderResponse::Google)); - assert!(&config.url.contains("{{")); - assert!(&config.url.contains("}}")); - - // Verify both auth methods are supported - assert!(config.auth_methods.contains(&AuthMethod::ApiKey)); - assert!(config.auth_methods.contains(&AuthMethod::GoogleAdc)); - } - - #[test] - fn test_azure_config() { - let configs = get_provider_configs(); - let config = configs.iter().find(|c| c.id == ProviderId::AZURE).unwrap(); - assert_eq!(config.id, ProviderId::AZURE); - assert_eq!(config.api_key_vars, Some("AZURE_API_KEY".to_string())); - assert_eq!( - config.url_param_vars, - vec![ - "AZURE_RESOURCE_NAME".to_string(), - "AZURE_DEPLOYMENT_NAME".to_string(), - "AZURE_API_VERSION".to_string() - ] - ); - assert_eq!(config.response_type, Some(ProviderResponse::OpenAI)); - - // Check URL (now contains full chat completion URL) - let url = &config.url; - assert!(url.contains("{{")); - assert!(url.contains("}}")); - assert!(url.contains("openai.azure.com")); - assert!(url.contains("api-version")); - assert!(url.contains("deployments")); - assert!(url.contains("chat/completions")); - - // Check models exists and contains expected elements - match config.models.as_ref().unwrap() { - Models::Url(model_url) => { - assert!(model_url.contains("api-version")); - assert!(model_url.contains("/models")); - } - Models::Hardcoded(_) => panic!("Expected Models::Url variant"), - } - } - - #[test] - fn test_openai_compatible_config() { - let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::OPENAI_COMPATIBLE) - .unwrap(); - assert_eq!(config.id, ProviderId::OPENAI_COMPATIBLE); - assert_eq!(config.api_key_vars, Some("OPENAI_API_KEY".to_string())); - assert_eq!(config.url_param_vars, vec!["OPENAI_URL".to_string()]); - assert_eq!(config.response_type, Some(ProviderResponse::OpenAI)); - assert!(&config.url.contains("{{OPENAI_URL}}")); - } - - #[test] - fn test_openai_responses_compatible_config() { - let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::OPENAI_RESPONSES_COMPATIBLE) - .unwrap(); - assert_eq!(config.id, ProviderId::OPENAI_RESPONSES_COMPATIBLE); - assert_eq!(config.api_key_vars, Some("OPENAI_API_KEY".to_string())); - assert_eq!(config.url_param_vars, vec!["OPENAI_URL".to_string()]); - assert_eq!( - config.response_type, - Some(ProviderResponse::OpenAIResponses) - ); - assert_eq!(config.url, "{{OPENAI_URL}}/responses"); - match config.models.as_ref().unwrap() { - Models::Url(model_url) => assert_eq!(model_url, "{{OPENAI_URL}}/models"), - Models::Hardcoded(_) => panic!("Expected Models::Url variant"), - } - } - - #[test] - fn test_anthropic_compatible_config() { - let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::ANTHROPIC_COMPATIBLE) - .unwrap(); - assert_eq!(config.id, ProviderId::ANTHROPIC_COMPATIBLE); - assert_eq!(config.api_key_vars, Some("ANTHROPIC_API_KEY".to_string())); - assert_eq!(config.url_param_vars, vec!["ANTHROPIC_URL".to_string()]); - assert_eq!(config.response_type, Some(ProviderResponse::Anthropic)); - assert!(config.url.contains("{{ANTHROPIC_URL}}")); - } - - #[test] - fn test_io_intelligence_config() { - let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::IO_INTELLIGENCE) - .unwrap(); - assert_eq!(config.id, ProviderId::IO_INTELLIGENCE); - assert_eq!( - config.api_key_vars, - Some("IO_INTELLIGENCE_API_KEY".to_string()) - ); - assert_eq!(config.url_param_vars, Vec::::new()); - assert_eq!(config.response_type, Some(ProviderResponse::OpenAI)); - assert_eq!( - config.url.as_str(), - "https://api.intelligence.io.solutions/api/v1/chat/completions" - ); - } -} - #[cfg(test)] mod env_tests { use std::collections::{BTreeMap, HashMap}; @@ -770,6 +733,14 @@ mod env_tests { } } + impl GrpcInfra for MockInfra { + fn channel(&self) -> tonic::transport::Channel { + tonic::transport::Channel::from_static("http://[::1]:50051").connect_lazy() + } + + fn hydrate(&self) {} + } + #[async_trait::async_trait] impl ChatRepository for MockInfra { async fn chat( @@ -1021,7 +992,7 @@ mod env_tests { registry.migrate_env_to_file().await.unwrap(); // Get Azure config from embedded configs - let configs = get_provider_configs(); + let configs = registry.get_provider_configs().await.unwrap(); let azure_config = configs .iter() .find(|c| c.id == ProviderId::AZURE) @@ -1243,6 +1214,14 @@ mod env_tests { } } + impl GrpcInfra for CustomMockInfra { + fn channel(&self) -> tonic::transport::Channel { + tonic::transport::Channel::from_static("http://[::1]:50051").connect_lazy() + } + + fn hydrate(&self) {} + } + #[async_trait::async_trait] impl ChatRepository for CustomMockInfra { async fn chat( @@ -1298,7 +1277,7 @@ mod env_tests { let registry = ForgeProviderRepository::new(infra); // Get merged configs - let merged_configs = registry.get_merged_configs().await; + let merged_configs = registry.get_merged_configs().await.unwrap(); // Verify OpenAI config was overridden let openai_config = merged_configs