From 053108b96ca2fa41c7372022d284cd1b7a422956 Mon Sep 17 00:00:00 2001
From: Adil Hafeez
Date: Thu, 12 Mar 2026 12:27:38 +0000
Subject: [PATCH] add native Gemini provider support via hermesllm transforms

---
 .../brightstaff/src/handlers/router_chat.rs   |   4 +-
 crates/hermesllm/src/apis/gemini.rs           | 744 ++++++++++++++++++
 crates/hermesllm/src/apis/mod.rs              |   2 +
 crates/hermesllm/src/clients/endpoints.rs     | 106 ++-
 crates/hermesllm/src/lib.rs                   |   1 +
 crates/hermesllm/src/providers/id.rs          |  65 +-
 crates/hermesllm/src/providers/request.rs     | 276 ++++++-
 crates/hermesllm/src/providers/response.rs    | 148 ++++
 .../src/providers/streaming_response.rs       |   5 +
 crates/hermesllm/src/transforms/mod.rs        |   1 +
 .../src/transforms/request/from_gemini.rs     | 327 ++++++++
 .../hermesllm/src/transforms/request/mod.rs   |   2 +
 .../src/transforms/request/to_gemini.rs       | 323 ++++++++
 .../src/transforms/response/from_gemini.rs    | 417 ++++++++++
 .../hermesllm/src/transforms/response/mod.rs  |   1 +
 crates/llm_gateway/src/stream_context.rs      |   4 +-
 16 files changed, 2416 insertions(+), 10 deletions(-)
 create mode 100644 crates/hermesllm/src/apis/gemini.rs
 create mode 100644 crates/hermesllm/src/transforms/request/from_gemini.rs
 create mode 100644 crates/hermesllm/src/transforms/request/to_gemini.rs
 create mode 100644 crates/hermesllm/src/transforms/response/from_gemini.rs

diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs
index 910e5408e..30adb1f6f 100644
--- a/crates/brightstaff/src/handlers/router_chat.rs
+++ b/crates/brightstaff/src/handlers/router_chat.rs
@@ -53,7 +53,9 @@ pub async fn router_chat_get_upstream_model(
             ProviderRequestType::MessagesRequest(_)
             | ProviderRequestType::BedrockConverse(_)
             | ProviderRequestType::BedrockConverseStream(_)
-            | ProviderRequestType::ResponsesAPIRequest(_),
+            | ProviderRequestType::ResponsesAPIRequest(_)
+            | ProviderRequestType::GeminiGenerateContent(_)
+            | ProviderRequestType::GeminiStreamGenerateContent(_),
         ) => {
             warn!("unexpected: got non-ChatCompletions request after converting to OpenAI format");
             return Err(RoutingError::internal_error(
diff --git a/crates/hermesllm/src/apis/gemini.rs b/crates/hermesllm/src/apis/gemini.rs
new file mode 100644
index 000000000..39b4c81f3
--- /dev/null
+++ b/crates/hermesllm/src/apis/gemini.rs
@@ -0,0 +1,744 @@
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use serde_with::skip_serializing_none;
+use std::collections::HashMap;
+
+use super::ApiDefinition;
+use crate::providers::request::{ProviderRequest, ProviderRequestError};
+use crate::providers::response::TokenUsage;
+use crate::providers::streaming_response::ProviderStreamResponse;
+use crate::transforms::lib::ExtractText;
+use crate::GENERATE_CONTENT_PATH_SUFFIX;
+
+// ============================================================================
+// GEMINI GENERATE CONTENT API ENUMERATION
+// ============================================================================
+
+/// Enum for all supported Gemini GenerateContent APIs
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum GeminiApi {
+    GenerateContent,
+    StreamGenerateContent,
+}
+
+impl ApiDefinition for GeminiApi {
+    fn endpoint(&self) -> &'static str {
+        match self {
+            GeminiApi::GenerateContent => ":generateContent",
+            GeminiApi::StreamGenerateContent => ":streamGenerateContent",
+        }
+    }
+
+    fn from_endpoint(endpoint: &str) -> Option<Self> {
+        if endpoint.ends_with(":streamGenerateContent") {
+            Some(GeminiApi::StreamGenerateContent)
+        } else if endpoint.ends_with(GENERATE_CONTENT_PATH_SUFFIX) {
+            Some(GeminiApi::GenerateContent)
+        } else {
+            None
+        }
+    }
+
+    fn supports_streaming(&self) -> bool {
+        match self {
+            GeminiApi::GenerateContent => false,
+            GeminiApi::StreamGenerateContent => true,
+        }
+    }
+
+    fn supports_tools(&self) -> bool {
+        true
+    }
+
+    fn supports_vision(&self) -> bool {
+        true
+    }
+
+    fn all_variants() -> Vec<Self> {
+        vec![GeminiApi::GenerateContent, GeminiApi::StreamGenerateContent]
+    }
+}
+
+// ============================================================================
+// REQUEST TYPES
+// ============================================================================
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentRequest {
+    /// Internal model field — not part of Gemini wire format (model is in the URL).
+    /// Populated during parsing and used for routing.
+    #[serde(skip_serializing, default)]
+    pub model: String,
+
+    pub contents: Vec<Content>,
+    pub generation_config: Option<GenerationConfig>,
+    pub tools: Option<Vec<Tool>>,
+    pub tool_config: Option<ToolConfig>,
+    pub safety_settings: Option<Vec<SafetySetting>>,
+    pub system_instruction: Option<Content>,
+    pub cached_content: Option<String>,
+
+    #[serde(skip_serializing)]
+    pub metadata: Option<HashMap<String, Value>>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Content {
+    pub role: Option<String>,
+    pub parts: Vec<Part>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Part {
+    pub text: Option<String>,
+    pub inline_data: Option<InlineData>,
+    pub function_call: Option<FunctionCall>,
+    pub function_response: Option<FunctionResponse>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InlineData {
+    pub mime_type: String,
+    pub data: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionCall {
+    pub name: String,
+    pub args: Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionResponse {
+    pub name: String,
+    pub response: Value,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerationConfig {
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub top_k: Option<u32>,
+    pub max_output_tokens: Option<u32>,
+    pub stop_sequences: Option<Vec<String>>,
+    pub response_mime_type: Option<String>,
+    pub candidate_count: Option<u32>,
+    pub presence_penalty: Option<f32>,
+    pub frequency_penalty: Option<f32>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Tool {
+    pub function_declarations: Option<Vec<FunctionDeclaration>>,
+    pub code_execution: Option<Value>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionDeclaration {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolConfig {
+    pub function_calling_config: FunctionCallingConfig,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionCallingConfig {
+    pub mode: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetySetting {
+    pub category: String,
+    pub threshold: String,
+}
+
+// ============================================================================
+// RESPONSE TYPES
+// ============================================================================
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "camelCase")]
+pub struct GenerateContentResponse {
+    pub candidates: Option<Vec<Candidate>>,
+    pub usage_metadata: Option<UsageMetadata>,
+    pub prompt_feedback: Option<PromptFeedback>,
+    pub model_version: Option<String>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Candidate {
+    pub content: Option<Content>,
+    pub finish_reason: Option<String>,
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetadata {
+    pub prompt_token_count: Option<u32>,
+    pub candidates_token_count: Option<u32>,
+    pub total_token_count: Option<u32>,
+}
+
+impl TokenUsage for UsageMetadata {
+    fn completion_tokens(&self) -> usize {
+        self.candidates_token_count.unwrap_or(0) as usize
+    }
+
+    fn prompt_tokens(&self) -> usize {
+        self.prompt_token_count.unwrap_or(0) as usize
+    }
+
+    fn total_tokens(&self) -> usize {
+        self.total_token_count.unwrap_or(0) as usize
+    }
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PromptFeedback {
+    pub block_reason: Option<String>,
+    pub safety_ratings: Option<Vec<SafetyRating>>,
+}
+
+#[skip_serializing_none]
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SafetyRating {
+    pub category: String,
+    pub probability: String,
+    pub blocked: Option<bool>,
+}
+
+// ============================================================================
+// PROVIDER REQUEST TRAIT IMPLEMENTATION
+// ============================================================================
+
+impl ProviderRequest for GenerateContentRequest {
+    fn model(&self) -> &str {
+        &self.model
+    }
+
+    fn set_model(&mut self, model: String) {
+        self.model = model;
+    }
+
+    fn is_streaming(&self) -> bool {
+        // Gemini uses URL-based streaming, not a field in the request body
+        false
+    }
+
+    fn extract_messages_text(&self) -> String {
+        let mut parts_text = Vec::new();
+        for content in &self.contents {
+            for part in &content.parts {
+                if let Some(text) = &part.text {
+                    parts_text.push(text.clone());
+                }
+            }
+        }
+        if let Some(system) = &self.system_instruction {
+            for part in &system.parts {
+                if let Some(text) = &part.text {
+                    parts_text.push(text.clone());
+                }
+            }
+        }
+        parts_text.join(" ")
+    }
+
+    fn get_recent_user_message(&self) -> Option<String> {
+        self.contents
+            .iter()
+            .rev()
+            .find(|c| c.role.as_deref() == Some("user"))
+            .and_then(|c| {
+                c.parts
+                    .iter()
+                    .filter_map(|p| p.text.clone())
+                    .collect::<Vec<_>>()
+                    .first()
+                    .cloned()
+            })
+    }
+
+    fn get_tool_names(&self) -> Option<Vec<String>> {
+        self.tools.as_ref().map(|tools| {
+            tools
+                .iter()
+                .filter_map(|t| t.function_declarations.as_ref())
+                .flatten()
+                .map(|f| f.name.clone())
+                .collect()
+        })
+    }
+
+    fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError> {
+        serde_json::to_vec(self).map_err(|e| ProviderRequestError {
+            message: format!("Failed to serialize GenerateContentRequest: {}", e),
+            source: Some(Box::new(e)),
+        })
+    }
+
+    fn metadata(&self) -> &Option<HashMap<String, Value>> {
+        &self.metadata
+    }
+
+    fn remove_metadata_key(&mut self, key: &str) -> bool {
+        if let Some(ref mut metadata) = self.metadata {
+            metadata.remove(key).is_some()
+        } else {
+            false
+        }
+    }
+
+    fn get_temperature(&self) -> Option<f32> {
+        self.generation_config
+            .as_ref()
+            .and_then(|gc| gc.temperature)
+    }
+
+    fn get_messages(&self) -> Vec<crate::apis::openai::Message> {
+        use crate::apis::openai::{Message, MessageContent, Role};
+
+        let mut messages = Vec::new();
+
+        // Convert system instruction
+        if let Some(system) = &self.system_instruction {
+            let text = system
+                .parts
+                .iter()
+                .filter_map(|p| p.text.clone())
+                .collect::<Vec<_>>()
+                .join("");
+            if !text.is_empty() {
+                messages.push(Message {
+                    role: Role::System,
+                    content: Some(MessageContent::Text(text)),
+                    name: None,
+                    tool_calls: None,
+                    tool_call_id: None,
+                });
+            }
+        }
+
+        // Convert contents
+        for content in &self.contents {
+            let role = match content.role.as_deref() {
+                Some("model") => Role::Assistant,
+                _ => Role::User,
+            };
+
+            let text = content
+                .parts
+                .iter()
+                .filter_map(|p| p.text.clone())
+                .collect::<Vec<_>>()
+                .join("");
+
+            messages.push(Message {
+                role,
+                content: Some(MessageContent::Text(text)),
+                name: None,
+                tool_calls: None,
+                tool_call_id: None,
+            });
+        }
+
+        messages
+    }
+
+    fn set_messages(&mut self, messages: &[crate::apis::openai::Message]) {
+        use crate::apis::openai::Role;
+
+        self.contents.clear();
+        self.system_instruction = None;
+
+        for msg in messages {
+            let text = msg.content.extract_text();
+            match msg.role {
+                Role::System => {
+                    self.system_instruction = Some(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+                Role::User => {
+                    self.contents.push(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+                Role::Assistant => {
+                    self.contents.push(Content {
+                        role: Some("model".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+                Role::Tool => {
+                    self.contents.push(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+            }
+        }
+    }
+}
+
+// ============================================================================
+// PROVIDER STREAM RESPONSE TRAIT IMPLEMENTATION
+// ============================================================================
+
+impl ProviderStreamResponse for GenerateContentResponse {
+    fn content_delta(&self) -> Option<&str> {
+        self.candidates
+            .as_ref()
+            .and_then(|candidates| candidates.first())
+            .and_then(|candidate| candidate.content.as_ref())
+            .and_then(|content| content.parts.first())
+            .and_then(|part| part.text.as_deref())
+    }
+
+    fn is_final(&self) -> bool {
+        self.candidates
+            .as_ref()
+            .and_then(|candidates| candidates.first())
+            .and_then(|candidate| candidate.finish_reason.as_deref())
+            .map(|reason| reason == "STOP" || reason == "MAX_TOKENS" || reason == "SAFETY")
+            .unwrap_or(false)
+    }
+
+    fn role(&self) -> Option<&str> {
+        self.candidates
+            .as_ref()
+            .and_then(|candidates| candidates.first())
+            .and_then(|candidate| candidate.content.as_ref())
+            .and_then(|content| content.role.as_deref())
+    }
+
+    fn event_type(&self) -> Option<&str> {
+        None // Gemini doesn't use SSE event types
+    }
+}
+
+// ============================================================================
+// SERDE PARSING
+// ============================================================================
+
+impl TryFrom<&[u8]> for GenerateContentRequest {
+    type Error = serde_json::Error;
+
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        serde_json::from_slice(bytes)
+    }
+}
+
+impl TryFrom<&[u8]> for GenerateContentResponse {
+    type Error = serde_json::Error;
+
+    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
+        serde_json::from_slice(bytes)
+    }
+}
+
+// ============================================================================
+// TESTS
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_gemini_api_from_endpoint() {
+        assert_eq!(
+            GeminiApi::from_endpoint("/v1beta/models/gemini-pro:generateContent"),
+            Some(GeminiApi::GenerateContent)
+        );
+        assert_eq!(
+            GeminiApi::from_endpoint("/v1beta/models/gemini-pro:streamGenerateContent"),
+            Some(GeminiApi::StreamGenerateContent)
+        );
+        assert_eq!(GeminiApi::from_endpoint("/v1/chat/completions"), None);
+    }
+
+    #[test]
+    fn test_generate_content_request_serde() {
+        let json_str = json!({
+            "contents": [{
+                "role": "user",
+                "parts": [{"text": "Hello"}]
+            }],
+            "generationConfig": {
+                "temperature": 0.7,
+                "maxOutputTokens": 1024
+            }
+        });
+
+        let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap();
+        assert_eq!(req.contents.len(), 1);
+        assert_eq!(req.contents[0].role, Some("user".to_string()));
+        assert_eq!(
+            req.generation_config.as_ref().unwrap().temperature,
+            Some(0.7)
+        );
+        assert_eq!(
+            req.generation_config.as_ref().unwrap().max_output_tokens,
+            Some(1024)
+        );
+
+        // Roundtrip
+        let bytes = serde_json::to_vec(&req).unwrap();
+        let req2: GenerateContentRequest = serde_json::from_slice(&bytes).unwrap();
+        assert_eq!(req2.contents.len(), 1);
+    }
+
+    #[test]
+    fn test_generate_content_response_serde() {
+        let json_str = json!({
+            "candidates": [{
+                "content": {
+                    "role": "model",
+                    "parts": [{"text": "Hello! 
How can I help?"}] + }, + "finishReason": "STOP" + }], + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 7, + "totalTokenCount": 12 + } + }); + + let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap(); + assert!(resp.candidates.is_some()); + let candidates = resp.candidates.as_ref().unwrap(); + assert_eq!(candidates.len(), 1); + assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP")); + assert_eq!( + resp.usage_metadata.as_ref().unwrap().prompt_token_count, + Some(5) + ); + } + + #[test] + fn test_generate_content_request_with_tools() { + let json_str = json!({ + "contents": [{ + "role": "user", + "parts": [{"text": "What's the weather?"}] + }], + "tools": [{ + "functionDeclarations": [{ + "name": "get_weather", + "description": "Get weather info", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + }] + }], + "toolConfig": { + "functionCallingConfig": { + "mode": "AUTO" + } + } + }); + + let req: GenerateContentRequest = serde_json::from_value(json_str).unwrap(); + assert!(req.tools.is_some()); + let tools = req.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + let decls = tools[0].function_declarations.as_ref().unwrap(); + assert_eq!(decls[0].name, "get_weather"); + assert_eq!( + req.tool_config + .as_ref() + .unwrap() + .function_calling_config + .mode, + "AUTO" + ); + } + + #[test] + fn test_generate_content_response_with_function_call() { + let json_str = json!({ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "functionCall": { + "name": "get_weather", + "args": {"location": "NYC"} + } + }] + }, + "finishReason": "STOP" + }] + }); + + let resp: GenerateContentResponse = serde_json::from_value(json_str).unwrap(); + let candidates = resp.candidates.as_ref().unwrap(); + let parts = &candidates[0].content.as_ref().unwrap().parts; + assert!(parts[0].function_call.is_some()); + assert_eq!(parts[0].function_call.as_ref().unwrap().name, "get_weather"); + } + + #[test] + fn test_stream_response_content_delta() { + let resp = GenerateContentResponse { + candidates: Some(vec![Candidate { + content: Some(Content { + role: Some("model".to_string()), + parts: vec![Part { + text: Some("Hello".to_string()), + inline_data: None, + function_call: None, + function_response: None, + }], + }), + finish_reason: None, + safety_ratings: None, + }]), + usage_metadata: None, + prompt_feedback: None, + model_version: None, + }; + + assert_eq!(resp.content_delta(), Some("Hello")); + assert!(!resp.is_final()); + } + + #[test] + fn test_stream_response_is_final() { + let resp = GenerateContentResponse { + candidates: Some(vec![Candidate { + content: Some(Content { + role: Some("model".to_string()), + parts: vec![Part { + text: Some("Done".to_string()), + inline_data: None, + function_call: None, + function_response: None, + }], + }), + finish_reason: Some("STOP".to_string()), + safety_ratings: None, + }]), + usage_metadata: None, + prompt_feedback: None, + model_version: None, + }; + + assert!(resp.is_final()); + } + + #[test] + fn test_provider_request_extract_text() { + let req = GenerateContentRequest { + model: "gemini-pro".to_string(), + contents: vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: Some("Hello world".to_string()), + inline_data: None, + function_call: None, + function_response: None, + }], + }], + system_instruction: Some(Content { + role: Some("user".to_string()), + parts: vec![Part { + text: Some("Be helpful".to_string()), + inline_data: 
None,
+                    function_call: None,
+                    function_response: None,
+                }],
+            }),
+            ..Default::default()
+        };
+
+        let text = req.extract_messages_text();
+        assert!(text.contains("Hello world"));
+        assert!(text.contains("Be helpful"));
+    }
+
+    #[test]
+    fn test_provider_request_get_tool_names() {
+        let req = GenerateContentRequest {
+            model: "gemini-pro".to_string(),
+            contents: vec![],
+            tools: Some(vec![Tool {
+                function_declarations: Some(vec![
+                    FunctionDeclaration {
+                        name: "func_a".to_string(),
+                        description: None,
+                        parameters: None,
+                    },
+                    FunctionDeclaration {
+                        name: "func_b".to_string(),
+                        description: None,
+                        parameters: None,
+                    },
+                ]),
+                code_execution: None,
+            }]),
+            ..Default::default()
+        };
+
+        let names = req.get_tool_names().unwrap();
+        assert_eq!(names, vec!["func_a", "func_b"]);
+    }
+}
diff --git a/crates/hermesllm/src/apis/mod.rs b/crates/hermesllm/src/apis/mod.rs
index ea0563926..02ed36e21 100644
--- a/crates/hermesllm/src/apis/mod.rs
+++ b/crates/hermesllm/src/apis/mod.rs
@@ -1,5 +1,6 @@
 pub mod amazon_bedrock;
 pub mod anthropic;
+pub mod gemini;
 pub mod openai;
 pub mod openai_responses;
 pub mod streaming_shapes;
@@ -10,6 +11,7 @@ pub use amazon_bedrock::{
     Message as BedrockMessage, Tool as BedrockTool, ToolChoice as BedrockToolChoice,
 };
 pub use anthropic::{AnthropicApi, MessagesRequest, MessagesResponse, MessagesStreamEvent};
+pub use gemini::{GeminiApi, GenerateContentRequest, GenerateContentResponse};
 pub use openai::{
     ChatCompletionsRequest, ChatCompletionsResponse, ChatCompletionsStreamResponse, OpenAIApi,
 };
diff --git a/crates/hermesllm/src/clients/endpoints.rs b/crates/hermesllm/src/clients/endpoints.rs
index 23e146047..3984ea1a7 100644
--- a/crates/hermesllm/src/clients/endpoints.rs
+++ b/crates/hermesllm/src/clients/endpoints.rs
@@ -1,4 +1,4 @@
-use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, OpenAIApi};
+use crate::apis::{AmazonBedrockApi, AnthropicApi, ApiDefinition, GeminiApi, OpenAIApi};
 use crate::ProviderId;
 use std::fmt;
 
@@ -8,6 +8,7 @@ pub enum SupportedAPIsFromClient {
     OpenAIChatCompletions(OpenAIApi),
     AnthropicMessagesAPI(AnthropicApi),
     OpenAIResponsesAPI(OpenAIApi),
+    GeminiGenerateContentAPI(GeminiApi),
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -17,6 +18,8 @@ pub enum SupportedUpstreamAPIs {
     OpenAIChatCompletions(OpenAIApi),
     AnthropicMessagesAPI(AnthropicApi),
     AmazonBedrockConverse(AmazonBedrockApi),
     AmazonBedrockConverseStream(AmazonBedrockApi),
     OpenAIResponsesAPI(OpenAIApi),
+    GeminiGenerateContent(GeminiApi),
+    GeminiStreamGenerateContent(GeminiApi),
 }
 
 impl fmt::Display for SupportedAPIsFromClient {
@@ -31,6 +34,9 @@
             SupportedAPIsFromClient::OpenAIResponsesAPI(api) => {
                 write!(f, "OpenAI Responses ({})", api.endpoint())
             }
+            SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => {
+                write!(f, "Gemini ({})", api.endpoint())
+            }
         }
     }
 }
@@ -53,6 +59,12 @@ impl fmt::Display for SupportedUpstreamAPIs {
             SupportedUpstreamAPIs::OpenAIResponsesAPI(api) => {
                 write!(f, "OpenAI Responses ({})", api.endpoint())
             }
+            SupportedUpstreamAPIs::GeminiGenerateContent(api) => {
+                write!(f, "Gemini ({})", api.endpoint())
+            }
+            SupportedUpstreamAPIs::GeminiStreamGenerateContent(api) => {
+                write!(f, "Gemini Stream ({})", api.endpoint())
+            }
         }
     }
 }
@@ -60,6 +72,13 @@
 impl SupportedAPIsFromClient {
     /// Create a SupportedApi from an endpoint path
     pub fn from_endpoint(endpoint: &str) -> Option<Self> {
+        // Check Gemini first since it uses suffix matching (`:generateContent`)
+        if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
+            return Some(SupportedAPIsFromClient::GeminiGenerateContentAPI(
+                gemini_api,
+            ));
+        }
+
         if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
             // Check if this is the Responses API endpoint
             if openai_api == OpenAIApi::Responses {
@@ -82,6 +101,7 @@
             SupportedAPIsFromClient::OpenAIChatCompletions(api) => api.endpoint(),
             SupportedAPIsFromClient::AnthropicMessagesAPI(api) => api.endpoint(),
             SupportedAPIsFromClient::OpenAIResponsesAPI(api) => api.endpoint(),
+            SupportedAPIsFromClient::GeminiGenerateContentAPI(api) => api.endpoint(),
         }
     }
 
@@ -145,7 +165,18 @@
             }
             ProviderId::Gemini => {
                 if request_path.starts_with("/v1/") {
-                    build_endpoint("/v1beta/openai", endpoint_suffix)
+                    // Use native Gemini endpoint
+                    if !is_streaming {
+                        build_endpoint(
+                            "/v1beta",
+                            &format!("/models/{}:generateContent", model_id),
+                        )
+                    } else {
+                        build_endpoint(
+                            "/v1beta",
+                            &format!("/models/{}:streamGenerateContent?alt=sse", model_id),
+                        )
+                    }
                 } else {
                     build_endpoint("/v1", endpoint_suffix)
                 }
@@ -178,6 +209,20 @@
                     build_endpoint("/v1", "/chat/completions")
                 }
             }
+            ProviderId::Gemini => {
+                // Translate Anthropic → Gemini native
+                if !is_streaming {
+                    build_endpoint(
+                        "/v1beta",
+                        &format!("/models/{}:generateContent", model_id),
+                    )
+                } else {
+                    build_endpoint(
+                        "/v1beta",
+                        &format!("/models/{}:streamGenerateContent?alt=sse", model_id),
+                    )
+                }
+            }
             _ => build_endpoint("/v1", "/chat/completions"),
         }
     }
@@ -186,6 +231,20 @@
         match provider_id {
             // Providers that support /v1/responses natively
             ProviderId::OpenAI | ProviderId::XAI => route_by_provider("/responses"),
+            ProviderId::Gemini => {
+                // Translate Responses → Gemini native
+                if !is_streaming {
+                    build_endpoint(
+                        "/v1beta",
+                        &format!("/models/{}:generateContent", model_id),
+                    )
+                } else {
+                    build_endpoint(
+                        "/v1beta",
+                        &format!("/models/{}:streamGenerateContent?alt=sse", model_id),
+                    )
+                }
+            }
             // All other providers: translate to /chat/completions
             _ => route_by_provider("/chat/completions"),
         }
@@ -194,6 +253,33 @@
             // For Chat Completions API, use the standard chat/completions path
             route_by_provider("/chat/completions")
         }
+        SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
+            match provider_id {
+                ProviderId::Gemini => {
+                    // Native Gemini endpoint
+                    if !is_streaming {
+                        build_endpoint(
+                            "/v1beta",
+                            &format!("/models/{}:generateContent", model_id),
+                        )
+                    } else {
+                        build_endpoint(
+                            "/v1beta",
+                            &format!("/models/{}:streamGenerateContent?alt=sse", model_id),
+                        )
+                    }
+                }
+                ProviderId::Anthropic => build_endpoint("/v1", "/messages"),
+                ProviderId::AmazonBedrock => {
+                    if !is_streaming {
+                        build_endpoint("", &format!("/model/{}/converse", model_id))
+                    } else {
+                        build_endpoint("", &format!("/model/{}/converse-stream", model_id))
+                    }
+                }
+                _ => build_endpoint("/v1", "/chat/completions"),
+            }
+        }
     }
 }
 
@@ -201,6 +287,18 @@
 impl SupportedUpstreamAPIs {
     /// Create a SupportedUpstreamApi from an endpoint path
     pub fn from_endpoint(endpoint: &str) -> Option<Self> {
+        // Check Gemini first since it uses suffix matching
+        if let Some(gemini_api) = GeminiApi::from_endpoint(endpoint) {
+            return match gemini_api {
+                GeminiApi::GenerateContent => {
+                    Some(SupportedUpstreamAPIs::GeminiGenerateContent(gemini_api))
+                }
+                GeminiApi::StreamGenerateContent => Some(
+                    SupportedUpstreamAPIs::GeminiStreamGenerateContent(gemini_api),
+                ),
+            };
+        }
+
         if let Some(openai_api) = 
OpenAIApi::from_endpoint(endpoint) { // Check if this is the Responses API endpoint if openai_api == OpenAIApi::Responses { @@ -396,7 +494,7 @@ mod tests { "/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview" ); - // Test Gemini provider + // Test Gemini provider (uses native Gemini API with transforms) assert_eq!( api.target_endpoint_for_provider( &ProviderId::Gemini, @@ -405,7 +503,7 @@ mod tests { false, None ), - "/v1beta/openai/chat/completions" + "/v1beta/models/gemini-pro:generateContent" ); } diff --git a/crates/hermesllm/src/lib.rs b/crates/hermesllm/src/lib.rs index 997fc72a4..3b8d69784 100644 --- a/crates/hermesllm/src/lib.rs +++ b/crates/hermesllm/src/lib.rs @@ -20,6 +20,7 @@ pub use providers::streaming_response::{ProviderStreamResponse, ProviderStreamRe pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const OPENAI_RESPONSES_API_PATH: &str = "/v1/responses"; pub const MESSAGES_PATH: &str = "/v1/messages"; +pub const GENERATE_CONTENT_PATH_SUFFIX: &str = ":generateContent"; #[cfg(test)] mod tests { diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 110087112..67b7643cd 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -1,4 +1,4 @@ -use crate::apis::{AmazonBedrockApi, AnthropicApi, OpenAIApi}; +use crate::apis::{AmazonBedrockApi, AnthropicApi, GeminiApi, OpenAIApi}; use crate::clients::endpoints::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use serde::Deserialize; use std::collections::HashMap; @@ -116,7 +116,68 @@ impl ProviderId { is_streaming: bool, ) -> SupportedUpstreamAPIs { match (self, client_api) { + // ============================================================================ + // Gemini provider — use native Gemini APIs + // ============================================================================ + (ProviderId::Gemini, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::GeminiStreamGenerateContent( + GeminiApi::StreamGenerateContent, + ) + } else { + SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent) + } + } + (ProviderId::Gemini, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => { + if is_streaming { + SupportedUpstreamAPIs::GeminiStreamGenerateContent( + GeminiApi::StreamGenerateContent, + ) + } else { + SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent) + } + } + (ProviderId::Gemini, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::GeminiStreamGenerateContent( + GeminiApi::StreamGenerateContent, + ) + } else { + SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent) + } + } + (ProviderId::Gemini, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + if is_streaming { + SupportedUpstreamAPIs::GeminiStreamGenerateContent( + GeminiApi::StreamGenerateContent, + ) + } else { + SupportedUpstreamAPIs::GeminiGenerateContent(GeminiApi::GenerateContent) + } + } + + // ============================================================================ + // Non-Gemini providers receiving Gemini-format requests + // ============================================================================ + (ProviderId::Anthropic, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => { + SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) + } + (ProviderId::AmazonBedrock, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => { + if is_streaming { + 
SupportedUpstreamAPIs::AmazonBedrockConverseStream( + AmazonBedrockApi::ConverseStream, + ) + } else { + SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) + } + } + (_, SupportedAPIsFromClient::GeminiGenerateContentAPI(_)) => { + SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) + } + + // ============================================================================ // Claude/Anthropic providers natively support Anthropic APIs + // ============================================================================ (ProviderId::Anthropic, SupportedAPIsFromClient::AnthropicMessagesAPI(_)) => { SupportedUpstreamAPIs::AnthropicMessagesAPI(AnthropicApi::Messages) } @@ -136,7 +197,6 @@ impl ProviderId { | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch - | ProviderId::Gemini | ProviderId::GitHub | ProviderId::AzureOpenAI | ProviderId::XAI @@ -154,7 +214,6 @@ impl ProviderId { | ProviderId::Mistral | ProviderId::Deepseek | ProviderId::Arch - | ProviderId::Gemini | ProviderId::GitHub | ProviderId::AzureOpenAI | ProviderId::XAI diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 92688133c..598c4fb65 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,5 +1,7 @@ use crate::apis::anthropic::MessagesRequest; +use crate::apis::gemini::GenerateContentRequest; use crate::apis::openai::ChatCompletionsRequest; +use crate::apis::ApiDefinition; use crate::apis::amazon_bedrock::{ConverseRequest, ConverseStreamRequest}; use crate::apis::openai_responses::ResponsesAPIRequest; @@ -19,7 +21,8 @@ pub enum ProviderRequestType { BedrockConverse(ConverseRequest), BedrockConverseStream(ConverseStreamRequest), ResponsesAPIRequest(ResponsesAPIRequest), - //add more request types here + GeminiGenerateContent(GenerateContentRequest), + GeminiStreamGenerateContent(GenerateContentRequest), } pub trait ProviderRequest: Send + Sync { /// Extract the model name from the request @@ -69,6 +72,9 @@ impl ProviderRequestType { Self::BedrockConverse(r) => r.set_messages(messages), Self::BedrockConverseStream(r) => r.set_messages(messages), Self::ResponsesAPIRequest(r) => r.set_messages(messages), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.set_messages(messages) + } } } @@ -100,6 +106,7 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.model(), Self::BedrockConverseStream(r) => r.model(), Self::ResponsesAPIRequest(r) => r.model(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.model(), } } @@ -110,6 +117,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.set_model(model), Self::BedrockConverseStream(r) => r.set_model(model), Self::ResponsesAPIRequest(r) => r.set_model(model), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.set_model(model) + } } } @@ -120,6 +130,8 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(_) => false, Self::BedrockConverseStream(_) => true, Self::ResponsesAPIRequest(r) => r.is_streaming(), + Self::GeminiGenerateContent(_) => false, + Self::GeminiStreamGenerateContent(_) => true, } } @@ -130,6 +142,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.extract_messages_text(), Self::BedrockConverseStream(r) => r.extract_messages_text(), Self::ResponsesAPIRequest(r) => r.extract_messages_text(), + Self::GeminiGenerateContent(r) | 
Self::GeminiStreamGenerateContent(r) => { + r.extract_messages_text() + } } } @@ -140,6 +155,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.get_recent_user_message(), Self::BedrockConverseStream(r) => r.get_recent_user_message(), Self::ResponsesAPIRequest(r) => r.get_recent_user_message(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.get_recent_user_message() + } } } @@ -150,6 +168,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.get_tool_names(), Self::BedrockConverseStream(r) => r.get_tool_names(), Self::ResponsesAPIRequest(r) => r.get_tool_names(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.get_tool_names() + } } } @@ -160,6 +181,7 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.to_bytes(), Self::BedrockConverseStream(r) => r.to_bytes(), Self::ResponsesAPIRequest(r) => r.to_bytes(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.to_bytes(), } } @@ -170,6 +192,7 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.metadata(), Self::BedrockConverseStream(r) => r.metadata(), Self::ResponsesAPIRequest(r) => r.metadata(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => r.metadata(), } } @@ -180,6 +203,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.remove_metadata_key(key), Self::BedrockConverseStream(r) => r.remove_metadata_key(key), Self::ResponsesAPIRequest(r) => r.remove_metadata_key(key), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.remove_metadata_key(key) + } } } @@ -190,6 +216,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.get_temperature(), Self::BedrockConverseStream(r) => r.get_temperature(), Self::ResponsesAPIRequest(r) => r.get_temperature(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.get_temperature() + } } } @@ -200,6 +229,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.get_messages(), Self::BedrockConverseStream(r) => r.get_messages(), Self::ResponsesAPIRequest(r) => r.get_messages(), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.get_messages() + } } } @@ -210,6 +242,9 @@ impl ProviderRequest for ProviderRequestType { Self::BedrockConverse(r) => r.set_messages(messages), Self::BedrockConverseStream(r) => r.set_messages(messages), Self::ResponsesAPIRequest(r) => r.set_messages(messages), + Self::GeminiGenerateContent(r) | Self::GeminiStreamGenerateContent(r) => { + r.set_messages(messages) + } } } } @@ -245,6 +280,18 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient)> for ProviderRequestType { responses_apirequest, )) } + SupportedAPIsFromClient::GeminiGenerateContentAPI(gemini_api) => { + let gemini_request: GenerateContentRequest = + GenerateContentRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + if gemini_api.supports_streaming() { + Ok(ProviderRequestType::GeminiStreamGenerateContent( + gemini_request, + )) + } else { + Ok(ProviderRequestType::GeminiGenerateContent(gemini_request)) + } + } } } } @@ -309,6 +356,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT source: None, }) } + // ChatCompletions -> Gemini + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::GeminiGenerateContent(_), + ) => 
{ + let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)) + } + ( + ProviderRequestType::ChatCompletionsRequest(chat_req), + SupportedUpstreamAPIs::GeminiStreamGenerateContent(_), + ) => { + let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)) + } // ============================================================================ // MessagesRequest conversions @@ -370,6 +448,37 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT source: None, }) } + // Messages -> Gemini (chain: Anthropic -> OpenAI -> Gemini) + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedUpstreamAPIs::GeminiGenerateContent(_), + ) => { + let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert MessagesRequest to GenerateContentRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)) + } + ( + ProviderRequestType::MessagesRequest(messages_req), + SupportedUpstreamAPIs::GeminiStreamGenerateContent(_), + ) => { + let gemini_req = GenerateContentRequest::try_from(messages_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert MessagesRequest to GenerateContentRequest (stream): {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)) + } // ============================================================================ // ResponsesAPIRequest conversions (only converts TO other formats) @@ -480,6 +589,171 @@ impl TryFrom<(ProviderRequestType, &SupportedUpstreamAPIs)> for ProviderRequestT Ok(ProviderRequestType::BedrockConverseStream(bedrock_req)) } + // ResponsesAPI -> Gemini (via ChatCompletions) + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::GeminiGenerateContent(_), + ) => { + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to GenerateContentRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)) + } + ( + ProviderRequestType::ResponsesAPIRequest(responses_req), + SupportedUpstreamAPIs::GeminiStreamGenerateContent(_), + ) => { + let chat_req = ChatCompletionsRequest::try_from(responses_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ResponsesAPIRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + let gemini_req = GenerateContentRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to GenerateContentRequest (stream): 
{}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)) + } + + // ============================================================================ + // GeminiGenerateContent conversions (client sends Gemini format) + // ============================================================================ + ( + ProviderRequestType::GeminiGenerateContent(gemini_req), + SupportedUpstreamAPIs::GeminiGenerateContent(_), + ) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)), + ( + ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::GeminiStreamGenerateContent(_), + ) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)), + // Cross-streaming mode: non-streaming -> streaming and vice versa + ( + ProviderRequestType::GeminiGenerateContent(gemini_req), + SupportedUpstreamAPIs::GeminiStreamGenerateContent(_), + ) => Ok(ProviderRequestType::GeminiStreamGenerateContent(gemini_req)), + ( + ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::GeminiGenerateContent(_), + ) => Ok(ProviderRequestType::GeminiGenerateContent(gemini_req)), + ( + ProviderRequestType::GeminiGenerateContent(gemini_req) + | ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + ) => { + let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) + } + ( + ProviderRequestType::GeminiGenerateContent(gemini_req) + | ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + ) => { + let messages_req = MessagesRequest::try_from(gemini_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert GenerateContentRequest to MessagesRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::MessagesRequest(messages_req)) + } + ( + ProviderRequestType::GeminiGenerateContent(gemini_req) + | ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + ) => { + // Chain: Gemini -> OpenAI -> Bedrock + let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + let bedrock_req = ConverseRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert ChatCompletionsRequest to ConverseRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + Ok(ProviderRequestType::BedrockConverse(bedrock_req)) + } + ( + ProviderRequestType::GeminiGenerateContent(gemini_req) + | ProviderRequestType::GeminiStreamGenerateContent(gemini_req), + SupportedUpstreamAPIs::AmazonBedrockConverseStream(_), + ) => { + // Chain: Gemini -> OpenAI -> Bedrock Stream + let chat_req = ChatCompletionsRequest::try_from(gemini_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to convert GenerateContentRequest to ChatCompletionsRequest: {}", + e + ), + source: Some(Box::new(e)), + } + })?; + let bedrock_req = ConverseStreamRequest::try_from(chat_req).map_err(|e| { + ProviderRequestError { + message: format!( + "Failed to 
convert ChatCompletionsRequest to ConverseStreamRequest: {}",
+                            e
+                        ),
+                        source: Some(Box::new(e)),
+                    }
+                })?;
+                Ok(ProviderRequestType::BedrockConverseStream(bedrock_req))
+            }
+            (
+                ProviderRequestType::GeminiGenerateContent(_)
+                | ProviderRequestType::GeminiStreamGenerateContent(_),
+                SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
+            ) => {
+                Err(ProviderRequestError {
+                    message: "Conversion from GenerateContentRequest to ResponsesAPIRequest is not supported.".to_string(),
+                    source: None,
+                })
+            }
+
             // ============================================================================
             // Amazon Bedrock conversions (not supported as client API)
             // ============================================================================
diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs
index 5f46f97bd..a2d89dbf9 100644
--- a/crates/hermesllm/src/providers/response.rs
+++ b/crates/hermesllm/src/providers/response.rs
@@ -1,5 +1,6 @@
 use crate::apis::amazon_bedrock::ConverseResponse;
 use crate::apis::anthropic::MessagesResponse;
+use crate::apis::gemini::GenerateContentResponse;
 use crate::apis::openai::ChatCompletionsResponse;
 use crate::apis::openai_responses::ResponsesAPIResponse;
 use crate::clients::endpoints::SupportedAPIsFromClient;
@@ -16,6 +17,7 @@ pub enum ProviderResponseType {
     ChatCompletionsResponse(ChatCompletionsResponse),
     MessagesResponse(MessagesResponse),
     ResponsesAPIResponse(Box<ResponsesAPIResponse>),
+    GenerateContentResponse(GenerateContentResponse),
 }
 
 /// Trait for token usage information
@@ -44,6 +46,9 @@ impl ProviderResponse for ProviderResponseType {
             ProviderResponseType::ResponsesAPIResponse(resp) => {
                 resp.usage.as_ref().map(|u| u as &dyn TokenUsage)
             }
+            ProviderResponseType::GenerateContentResponse(resp) => {
+                resp.usage_metadata.as_ref().map(|u| u as &dyn TokenUsage)
+            }
         }
     }
 
@@ -58,6 +63,15 @@ impl ProviderResponse for ProviderResponseType {
                     u.total_tokens as usize,
                 )
             }),
+            ProviderResponseType::GenerateContentResponse(resp) => {
+                resp.usage_metadata.as_ref().map(|u| {
+                    (
+                        u.prompt_token_count.unwrap_or(0) as usize,
+                        u.candidates_token_count.unwrap_or(0) as usize,
+                        u.total_token_count.unwrap_or(0) as usize,
+                    )
+                })
+            }
         }
     }
 }
@@ -238,6 +252,140 @@ impl TryFrom<(&[u8], &SupportedAPIsFromClient, &ProviderId)> for ProviderRespons
                     response_api,
                 )))
             }
+            // ============================================================================
+            // Gemini upstream transformations
+            // ============================================================================
+            (
+                SupportedUpstreamAPIs::GeminiGenerateContent(_),
+                SupportedAPIsFromClient::GeminiGenerateContentAPI(_),
+            ) => {
+                // Passthrough: Gemini upstream -> Gemini client
+                let resp: GenerateContentResponse = serde_json::from_slice(bytes)
+                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+                Ok(ProviderResponseType::GenerateContentResponse(resp))
+            }
+            (
+                SupportedUpstreamAPIs::GeminiGenerateContent(_),
+                SupportedAPIsFromClient::OpenAIChatCompletions(_),
+            ) => {
+                // Gemini upstream -> OpenAI client
+                let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes)
+                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
+                let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| {
+                    std::io::Error::new(
+                        std::io::ErrorKind::InvalidData,
+                        format!("Transformation error: {}", e),
+                    )
+                })?;
+                Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
+            }
+            (
+                SupportedUpstreamAPIs::GeminiGenerateContent(_),
+                
SupportedAPIsFromClient::AnthropicMessagesAPI(_), + ) => { + // Chain: Gemini -> OpenAI -> Anthropic + let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + let messages_resp: MessagesResponse = chat_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::MessagesResponse(messages_resp)) + } + ( + SupportedUpstreamAPIs::GeminiGenerateContent(_), + SupportedAPIsFromClient::OpenAIResponsesAPI(_), + ) => { + // Chain: Gemini -> OpenAI -> ResponsesAPI + let gemini_resp: GenerateContentResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let chat_resp: ChatCompletionsResponse = gemini_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + let responses_resp: ResponsesAPIResponse = chat_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::ResponsesAPIResponse(Box::new( + responses_resp, + ))) + } + + // ============================================================================ + // Non-Gemini upstream -> Gemini client + // ============================================================================ + ( + SupportedUpstreamAPIs::OpenAIChatCompletions(_), + SupportedAPIsFromClient::GeminiGenerateContentAPI(_), + ) => { + // OpenAI upstream -> Gemini client + let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let gemini_resp: GenerateContentResponse = openai_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::GenerateContentResponse(gemini_resp)) + } + ( + SupportedUpstreamAPIs::AnthropicMessagesAPI(_), + SupportedAPIsFromClient::GeminiGenerateContentAPI(_), + ) => { + // Chain: Anthropic -> OpenAI -> Gemini + let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let chat_resp: ChatCompletionsResponse = + anthropic_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + let gemini_resp: GenerateContentResponse = chat_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + Ok(ProviderResponseType::GenerateContentResponse(gemini_resp)) + } + ( + SupportedUpstreamAPIs::AmazonBedrockConverse(_), + SupportedAPIsFromClient::GeminiGenerateContentAPI(_), + ) => { + // Chain: Bedrock -> OpenAI -> Gemini + let bedrock_resp: ConverseResponse = serde_json::from_slice(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let chat_resp: ChatCompletionsResponse = bedrock_resp.try_into().map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Transformation error: {}", e), + ) + })?; + let gemini_resp: GenerateContentResponse 
= chat_resp.try_into().map_err(|e| {
+                    std::io::Error::new(
+                        std::io::ErrorKind::InvalidData,
+                        format!("Transformation error: {}", e),
+                    )
+                })?;
+                Ok(ProviderResponseType::GenerateContentResponse(gemini_resp))
+            }
+
             _ => Err(std::io::Error::new(
                 std::io::ErrorKind::InvalidData,
                 "Unsupported API combination for response transformation",
diff --git a/crates/hermesllm/src/providers/streaming_response.rs b/crates/hermesllm/src/providers/streaming_response.rs
index 66ccc7354..d34c0a715 100644
--- a/crates/hermesllm/src/providers/streaming_response.rs
+++ b/crates/hermesllm/src/providers/streaming_response.rs
@@ -83,6 +83,11 @@ impl TryFrom<(&SupportedAPIsFromClient, &SupportedUpstreamAPIs)> for SseStreamBu
             SupportedAPIsFromClient::OpenAIResponsesAPI(_) => {
                 Ok(SseStreamBuffer::OpenAIResponses(Box::default()))
             }
+            SupportedAPIsFromClient::GeminiGenerateContentAPI(_) => {
+                // Gemini client with a different upstream - use passthrough
+                // since Gemini streaming uses SSE and doesn't need special buffering
+                Ok(SseStreamBuffer::Passthrough(PassthroughStreamBuffer::new()))
+            }
         }
     }
 }
diff --git a/crates/hermesllm/src/transforms/mod.rs b/crates/hermesllm/src/transforms/mod.rs
index ebb4bf203..8a6bc2b2d 100644
--- a/crates/hermesllm/src/transforms/mod.rs
+++ b/crates/hermesllm/src/transforms/mod.rs
@@ -15,6 +15,7 @@ pub mod response_streaming;
 
 // Re-export commonly used items for convenience
 pub use lib::*;
+#[allow(ambiguous_glob_reexports)]
 pub use request::*;
 pub use response::*;
 pub use response_streaming::*;
diff --git a/crates/hermesllm/src/transforms/request/from_gemini.rs b/crates/hermesllm/src/transforms/request/from_gemini.rs
new file mode 100644
index 000000000..9db8db5e5
--- /dev/null
+++ b/crates/hermesllm/src/transforms/request/from_gemini.rs
@@ -0,0 +1,327 @@
+use crate::apis::gemini::GenerateContentRequest;
+use crate::apis::openai::{
+    ChatCompletionsRequest, Function, FunctionCall as OpenAIFunctionCall, Message, MessageContent,
+    Role, Tool, ToolCall as OpenAIToolCall, ToolChoice, ToolChoiceType,
+};
+
+use crate::apis::anthropic::MessagesRequest;
+use crate::clients::TransformError;
+
+// ============================================================================
+// Gemini GenerateContent -> OpenAI ChatCompletions
+// ============================================================================
+
+impl TryFrom<GenerateContentRequest> for ChatCompletionsRequest {
+    type Error = TransformError;
+
+    fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
+        let mut messages: Vec<Message> = Vec::new();
+
+        // Convert system instruction
+        if let Some(system) = &req.system_instruction {
+            let text = system
+                .parts
+                .iter()
+                .filter_map(|p| p.text.clone())
+                .collect::<Vec<_>>()
+                .join("");
+            if !text.is_empty() {
+                messages.push(Message {
+                    role: Role::System,
+                    content: Some(MessageContent::Text(text)),
+                    name: None,
+                    tool_calls: None,
+                    tool_call_id: None,
+                });
+            }
+        }
+
+        // Convert contents
+        for content in &req.contents {
+            let role = match content.role.as_deref() {
+                Some("model") => Role::Assistant,
+                _ => Role::User,
+            };
+
+            // Check if this content has function_call parts (assistant with tool calls)
+            let has_function_calls = content.parts.iter().any(|p| p.function_call.is_some());
+            let has_function_responses =
+                content.parts.iter().any(|p| p.function_response.is_some());
+
+            if has_function_calls {
+                // Convert to assistant message with tool_calls
+                let mut tool_calls = Vec::new();
+                let mut text_parts = Vec::new();
+
+                for (i, part) in content.parts.iter().enumerate() {
+                    if let Some(fc) = &part.function_call {
+                        tool_calls.push(OpenAIToolCall {
+                            id: format!("call_{}", i),
+                            call_type: "function".to_string(),
+                            function: OpenAIFunctionCall {
+                                name: fc.name.clone(),
+                                arguments: serde_json::to_string(&fc.args).unwrap_or_default(),
+                            },
+                        });
+                    } else if let Some(text) = &part.text {
+                        text_parts.push(text.clone());
+                    }
+                }
+
+                let content_text = if text_parts.is_empty() {
+                    None
+                } else {
+                    Some(MessageContent::Text(text_parts.join("")))
+                };
+
+                messages.push(Message {
+                    role: Role::Assistant,
+                    content: content_text,
+                    name: None,
+                    tool_calls: if tool_calls.is_empty() {
+                        None
+                    } else {
+                        Some(tool_calls)
+                    },
+                    tool_call_id: None,
+                });
+            } else if has_function_responses {
+                // Convert each function_response to a tool message
+                for part in &content.parts {
+                    if let Some(fr) = &part.function_response {
+                        let result_text = serde_json::to_string(&fr.response).unwrap_or_default();
+                        messages.push(Message {
+                            role: Role::Tool,
+                            content: Some(MessageContent::Text(result_text)),
+                            name: None,
+                            tool_calls: None,
+                            tool_call_id: Some(fr.name.clone()),
+                        });
+                    }
+                }
+            } else {
+                // Regular text message
+                let text = content
+                    .parts
+                    .iter()
+                    .filter_map(|p| p.text.clone())
+                    .collect::<Vec<_>>()
+                    .join("");
+                messages.push(Message {
+                    role,
+                    content: Some(MessageContent::Text(text)),
+                    name: None,
+                    tool_calls: None,
+                    tool_call_id: None,
+                });
+            }
+        }
+
+        // Convert generation config
+        let (temperature, top_p, max_tokens, stop, presence_penalty, frequency_penalty) =
+            if let Some(gc) = &req.generation_config {
+                (
+                    gc.temperature,
+                    gc.top_p,
+                    gc.max_output_tokens,
+                    gc.stop_sequences.clone(),
+                    gc.presence_penalty,
+                    gc.frequency_penalty,
+                )
+            } else {
+                (None, None, None, None, None, None)
+            };
+
+        // Convert tools
+        let tools = req.tools.and_then(|gemini_tools| {
+            let openai_tools: Vec<Tool> = gemini_tools
+                .iter()
+                .filter_map(|t| t.function_declarations.as_ref())
+                .flatten()
+                .map(|fd| Tool {
+                    tool_type: "function".to_string(),
+                    function: Function {
+                        name: fd.name.clone(),
+                        description: fd.description.clone(),
+                        parameters: fd.parameters.clone().unwrap_or_default(),
+                        strict: None,
+                    },
+                })
+                .collect();
+            if openai_tools.is_empty() {
+                None
+            } else {
+                Some(openai_tools)
+            }
+        });
+
+        // Convert tool_config
+        let tool_choice =
+            req.tool_config
+                .and_then(|tc| match tc.function_calling_config.mode.as_str() {
+                    "AUTO" => Some(ToolChoice::Type(ToolChoiceType::Auto)),
+                    "NONE" => Some(ToolChoice::Type(ToolChoiceType::None)),
+                    "ANY" => Some(ToolChoice::Type(ToolChoiceType::Required)),
+                    _ => None,
+                });
+
+        Ok(ChatCompletionsRequest {
+            model: req.model,
+            messages,
+            temperature,
+            top_p,
+            max_completion_tokens: max_tokens,
+            stop,
+            tools,
+            tool_choice,
+            presence_penalty,
+            frequency_penalty,
+            metadata: req.metadata,
+            ..Default::default()
+        })
+    }
+}
+
+// ============================================================================
+// Gemini GenerateContent -> Anthropic Messages (via OpenAI)
+// ============================================================================
+
+impl TryFrom<GenerateContentRequest> for MessagesRequest {
+    type Error = TransformError;
+
+    fn try_from(req: GenerateContentRequest) -> Result<Self, Self::Error> {
+        // Chain: Gemini -> OpenAI -> Anthropic
+        let chat_req = ChatCompletionsRequest::try_from(req)?;
+        MessagesRequest::try_from(chat_req)
+    }
+}
+
+// ============================================================================
+// TESTS
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use 
+
+// ============================================================================
+// TESTS
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::apis::gemini::{Content, FunctionCall, Part};
+    use serde_json::json;
+
+    #[test]
+    fn test_gemini_to_openai_basic() {
+        let req = GenerateContentRequest {
+            model: "gemini-pro".to_string(),
+            contents: vec![
+                Content {
+                    role: Some("user".to_string()),
+                    parts: vec![Part {
+                        text: Some("Hello".to_string()),
+                        inline_data: None,
+                        function_call: None,
+                        function_response: None,
+                    }],
+                },
+                Content {
+                    role: Some("model".to_string()),
+                    parts: vec![Part {
+                        text: Some("Hi there!".to_string()),
+                        inline_data: None,
+                        function_call: None,
+                        function_response: None,
+                    }],
+                },
+            ],
+            system_instruction: Some(Content {
+                role: Some("user".to_string()),
+                parts: vec![Part {
+                    text: Some("Be helpful".to_string()),
+                    inline_data: None,
+                    function_call: None,
+                    function_response: None,
+                }],
+            }),
+            generation_config: Some(crate::apis::gemini::GenerationConfig {
+                temperature: Some(0.5),
+                max_output_tokens: Some(512),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+
+        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
+
+        // System + user + assistant = 3 messages
+        assert_eq!(openai_req.messages.len(), 3);
+        assert_eq!(openai_req.messages[0].role, Role::System);
+        assert_eq!(openai_req.messages[1].role, Role::User);
+        assert_eq!(openai_req.messages[2].role, Role::Assistant);
+
+        assert_eq!(openai_req.temperature, Some(0.5));
+        assert_eq!(openai_req.max_completion_tokens, Some(512));
+    }
+
+    #[test]
+    fn test_gemini_to_openai_with_function_calls() {
+        let req = GenerateContentRequest {
+            model: "gemini-pro".to_string(),
+            contents: vec![
+                Content {
+                    role: Some("user".to_string()),
+                    parts: vec![Part {
+                        text: Some("Weather?".to_string()),
+                        inline_data: None,
+                        function_call: None,
+                        function_response: None,
+                    }],
+                },
+                Content {
+                    role: Some("model".to_string()),
+                    parts: vec![Part {
+                        text: None,
+                        inline_data: None,
+                        function_call: Some(FunctionCall {
+                            name: "get_weather".to_string(),
+                            args: json!({"location": "NYC"}),
+                        }),
+                        function_response: None,
+                    }],
+                },
+            ],
+            ..Default::default()
+        };
+
+        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
+        assert_eq!(openai_req.messages.len(), 2);
+        assert!(openai_req.messages[1].tool_calls.is_some());
+        let tc = openai_req.messages[1].tool_calls.as_ref().unwrap();
+        assert_eq!(tc[0].function.name, "get_weather");
+    }
+
+    #[test]
+    fn test_gemini_to_openai_tool_config() {
+        let req = GenerateContentRequest {
+            model: "gemini-pro".to_string(),
+            contents: vec![Content {
+                role: Some("user".to_string()),
+                parts: vec![Part {
+                    text: Some("test".to_string()),
+                    inline_data: None,
+                    function_call: None,
+                    function_response: None,
+                }],
+            }],
+            tool_config: Some(crate::apis::gemini::ToolConfig {
+                function_calling_config: crate::apis::gemini::FunctionCallingConfig {
+                    mode: "ANY".to_string(),
+                },
+            }),
+            ..Default::default()
+        };
+
+        let openai_req = ChatCompletionsRequest::try_from(req).unwrap();
+        assert!(openai_req.tool_choice.is_some());
+        assert_eq!(
+            openai_req.tool_choice.as_ref().unwrap(),
+            &ToolChoice::Type(ToolChoiceType::Required)
+        );
+    }
+}
diff --git a/crates/hermesllm/src/transforms/request/mod.rs b/crates/hermesllm/src/transforms/request/mod.rs
index 5fbdf0b17..843693b4a 100644
--- a/crates/hermesllm/src/transforms/request/mod.rs
+++ b/crates/hermesllm/src/transforms/request/mod.rs
@@ -1,4 +1,6 @@
 //! Request transformation modules
 
 pub mod from_anthropic;
+pub mod from_gemini;
 pub mod from_openai;
+pub mod to_gemini;
diff --git a/crates/hermesllm/src/transforms/request/to_gemini.rs b/crates/hermesllm/src/transforms/request/to_gemini.rs
new file mode 100644
index 000000000..9815f6e87
--- /dev/null
+++ b/crates/hermesllm/src/transforms/request/to_gemini.rs
@@ -0,0 +1,323 @@
+use crate::apis::gemini::{
+    Content, FunctionCall, FunctionCallingConfig, FunctionDeclaration, FunctionResponse,
+    GenerateContentRequest, GenerationConfig, Part, Tool, ToolConfig,
+};
+use crate::apis::openai::{ChatCompletionsRequest, Role, ToolChoice, ToolChoiceType};
+
+use crate::apis::anthropic::MessagesRequest;
+use crate::clients::TransformError;
+use crate::transforms::lib::ExtractText;
+
+// ============================================================================
+// OpenAI ChatCompletions -> Gemini GenerateContent
+// ============================================================================
+
+impl TryFrom<ChatCompletionsRequest> for GenerateContentRequest {
+    type Error = TransformError;
+
+    fn try_from(req: ChatCompletionsRequest) -> Result<Self, Self::Error> {
+        let mut contents: Vec<Content> = Vec::new();
+        let mut system_instruction: Option<Content> = None;
+
+        for msg in &req.messages {
+            match msg.role {
+                Role::System => {
+                    let text = msg.content.extract_text();
+                    system_instruction = Some(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+                Role::User => {
+                    let text = msg.content.extract_text();
+                    contents.push(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        }],
+                    });
+                }
+                Role::Assistant => {
+                    let mut parts = Vec::new();
+
+                    // Check for tool calls
+                    if let Some(tool_calls) = &msg.tool_calls {
+                        for tc in tool_calls {
+                            let args: serde_json::Value =
+                                serde_json::from_str(&tc.function.arguments).unwrap_or_default();
+                            parts.push(Part {
+                                text: None,
+                                inline_data: None,
+                                function_call: Some(FunctionCall {
+                                    name: tc.function.name.clone(),
+                                    args,
+                                }),
+                                function_response: None,
+                            });
+                        }
+                    }
+
+                    // Also include text content if present
+                    let text = msg.content.extract_text();
+                    if !text.is_empty() {
+                        parts.push(Part {
+                            text: Some(text),
+                            inline_data: None,
+                            function_call: None,
+                            function_response: None,
+                        });
+                    }
+
+                    if !parts.is_empty() {
+                        contents.push(Content {
+                            role: Some("model".to_string()),
+                            parts,
+                        });
+                    }
+                }
+                Role::Tool => {
+                    let text = msg.content.extract_text();
+                    let tool_call_id = msg.tool_call_id.clone().unwrap_or_default();
+                    let response_value = serde_json::from_str(&text)
+                        .unwrap_or_else(|_| serde_json::json!({"result": text}));
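+                    // Note: OpenAI tool messages reference a call by `tool_call_id`,
+                    // while Gemini matches function responses by function *name*.
+                    // Reusing the id as the name round-trips with from_gemini.rs
+                    // (which emits the function name as the tool_call_id), but is
+                    // lossy for arbitrary OpenAI-style ids such as "call_abc123".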
+                    contents.push(Content {
+                        role: Some("user".to_string()),
+                        parts: vec![Part {
+                            text: None,
+                            inline_data: None,
+                            function_call: None,
+                            function_response: Some(FunctionResponse {
+                                name: tool_call_id,
+                                response: response_value,
+                            }),
+                        }],
+                    });
+                }
+            }
+        }
+
+        // Convert generation config
+        let generation_config = {
+            let gc = GenerationConfig {
+                temperature: req.temperature,
+                top_p: req.top_p,
+                top_k: None,
+                max_output_tokens: req.max_completion_tokens.or(req.max_tokens),
+                stop_sequences: req.stop,
+                response_mime_type: None,
+                candidate_count: None,
+                presence_penalty: req.presence_penalty,
+                frequency_penalty: req.frequency_penalty,
+            };
+            // Only include if any field is set
+            if gc.temperature.is_some()
+                || gc.top_p.is_some()
+                || gc.max_output_tokens.is_some()
+                || gc.stop_sequences.is_some()
+                || gc.presence_penalty.is_some()
+                || gc.frequency_penalty.is_some()
+            {
+                Some(gc)
+            } else {
+                None
+            }
+        };
+
+        // Convert tools
+        let tools = req.tools.map(|openai_tools| {
+            let declarations: Vec<FunctionDeclaration> = openai_tools
+                .iter()
+                .map(|t| FunctionDeclaration {
+                    name: t.function.name.clone(),
+                    description: t.function.description.clone(),
+                    parameters: Some(t.function.parameters.clone()),
+                })
+                .collect();
+            vec![Tool {
+                function_declarations: Some(declarations),
+                code_execution: None,
+            }]
+        });
+
+        // Convert tool_choice
+        let tool_config = req.tool_choice.and_then(|tc| {
+            let mode = match tc {
+                ToolChoice::Type(t) => match t {
+                    ToolChoiceType::Auto => Some("AUTO".to_string()),
+                    ToolChoiceType::None => Some("NONE".to_string()),
+                    ToolChoiceType::Required => Some("ANY".to_string()),
+                },
+                ToolChoice::Function { .. } => Some("AUTO".to_string()),
+            };
+            mode.map(|m| ToolConfig {
+                function_calling_config: FunctionCallingConfig { mode: m },
+            })
+        });
+
+        Ok(GenerateContentRequest {
+            model: req.model,
+            contents,
+            generation_config,
+            tools,
+            tool_config,
+            safety_settings: None,
+            system_instruction,
+            cached_content: None,
+            metadata: req.metadata,
+        })
+    }
+}
+
+// ============================================================================
+// Anthropic Messages -> Gemini GenerateContent (via OpenAI)
+// ============================================================================
+
+impl TryFrom<MessagesRequest> for GenerateContentRequest {
+    type Error = TransformError;
+
+    fn try_from(req: MessagesRequest) -> Result<Self, Self::Error> {
+        // Chain: Anthropic -> OpenAI -> Gemini
+        let chat_req = ChatCompletionsRequest::try_from(req)?;
+        GenerateContentRequest::try_from(chat_req)
+    }
+}
+
+// ============================================================================
+// TESTS
+// ============================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json::json;
+
+    #[test]
+    fn test_openai_to_gemini_basic() {
+        let req: ChatCompletionsRequest = serde_json::from_value(json!({
+            "model": "gemini-pro",
+            "messages": [
+                {"role": "system", "content": "You are helpful"},
+                {"role": "user", "content": "Hello"},
+                {"role": "assistant", "content": "Hi there!"},
+                {"role": "user", "content": "How are you?"}
+            ],
+            "temperature": 0.7,
+            "max_tokens": 1024
+        }))
+        .unwrap();
+
+        let gemini_req = GenerateContentRequest::try_from(req).unwrap();
+
+        // System should be in system_instruction
+        assert!(gemini_req.system_instruction.is_some());
+        let sys = gemini_req.system_instruction.as_ref().unwrap();
+        assert_eq!(sys.parts[0].text.as_deref(), Some("You are helpful"));
+
+        // 3 content messages (user, model, user)
+        assert_eq!(gemini_req.contents.len(), 3);
+        assert_eq!(gemini_req.contents[0].role.as_deref(), Some("user"));
+        assert_eq!(gemini_req.contents[1].role.as_deref(), Some("model"));
+        assert_eq!(gemini_req.contents[2].role.as_deref(), Some("user"));
+
+        // Generation config
+        assert_eq!(
+            gemini_req.generation_config.as_ref().unwrap().temperature,
+            Some(0.7)
+        );
+        assert_eq!(
+            gemini_req
+                .generation_config
+                .as_ref()
+                .unwrap()
+                .max_output_tokens,
+            Some(1024)
+        );
+    }
+
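+    // For reference (a sketch of the wire shape): serialized, the request
+    // built above becomes roughly
+    //   {"contents":[{"role":"user","parts":[{"text":"Hello"}]}, ...],
+    //    "generationConfig":{"temperature":0.7,"maxOutputTokens":1024},
+    //    "systemInstruction":{"role":"user","parts":[{"text":"You are helpful"}]}}
+    // via the camelCase rename and skip_serializing_none on the Gemini structs.
+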
"Get weather", + "parameters": {"type": "object", "properties": {"location": {"type": "string"}}} + } + }], + "tool_choice": "auto" + })) + .unwrap(); + + let gemini_req = GenerateContentRequest::try_from(req).unwrap(); + assert!(gemini_req.tools.is_some()); + let tools = gemini_req.tools.as_ref().unwrap(); + assert_eq!(tools.len(), 1); + let decls = tools[0].function_declarations.as_ref().unwrap(); + assert_eq!(decls[0].name, "get_weather"); + + assert!(gemini_req.tool_config.is_some()); + assert_eq!( + gemini_req + .tool_config + .as_ref() + .unwrap() + .function_calling_config + .mode, + "AUTO" + ); + } + + #[test] + fn test_openai_to_gemini_with_tool_calls() { + let req: ChatCompletionsRequest = serde_json::from_value(json!({ + "model": "gemini-pro", + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\": \"NYC\"}" + } + }] + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "Sunny, 72F" + } + ] + })) + .unwrap(); + + let gemini_req = GenerateContentRequest::try_from(req).unwrap(); + assert_eq!(gemini_req.contents.len(), 3); + + // Assistant with function_call + let model_content = &gemini_req.contents[1]; + assert_eq!(model_content.role.as_deref(), Some("model")); + assert!(model_content.parts[0].function_call.is_some()); + + // Tool response + let tool_content = &gemini_req.contents[2]; + assert_eq!(tool_content.role.as_deref(), Some("user")); + assert!(tool_content.parts[0].function_response.is_some()); + } +} diff --git a/crates/hermesllm/src/transforms/response/from_gemini.rs b/crates/hermesllm/src/transforms/response/from_gemini.rs new file mode 100644 index 000000000..94635bcb1 --- /dev/null +++ b/crates/hermesllm/src/transforms/response/from_gemini.rs @@ -0,0 +1,417 @@ +use crate::apis::anthropic::MessagesResponse; +use crate::apis::gemini::GenerateContentResponse; +use crate::apis::openai::{ + ChatCompletionsResponse, ChatCompletionsStreamResponse, Choice, FinishReason, + FunctionCall as OpenAIFunctionCall, MessageDelta, ResponseMessage, Role, StreamChoice, + ToolCall as OpenAIToolCall, Usage, +}; +use crate::clients::TransformError; + +// ============================================================================ +// Gemini GenerateContentResponse -> OpenAI ChatCompletionsResponse +// ============================================================================ + +fn map_finish_reason(gemini_reason: Option<&str>) -> Option { + gemini_reason.map(|r| match r { + "STOP" => FinishReason::Stop, + "MAX_TOKENS" => FinishReason::Length, + "SAFETY" | "RECITATION" => FinishReason::ContentFilter, + _ => FinishReason::Stop, + }) +} + +impl TryFrom for ChatCompletionsResponse { + type Error = TransformError; + + fn try_from(resp: GenerateContentResponse) -> Result { + let candidates = resp.candidates.unwrap_or_default(); + let candidate = candidates.first(); + + let mut content_text = String::new(); + let mut tool_calls: Vec = Vec::new(); + + if let Some(candidate) = candidate { + if let Some(ref content) = candidate.content { + for (i, part) in content.parts.iter().enumerate() { + if let Some(ref text) = part.text { + content_text.push_str(text); + } + if let Some(ref fc) = part.function_call { + tool_calls.push(OpenAIToolCall { + id: format!("call_{}", i), + call_type: "function".to_string(), + function: OpenAIFunctionCall { + name: fc.name.clone(), + arguments: 
serde_json::to_string(&fc.args).unwrap_or_default(), + }, + }); + } + } + } + } + + let finish_reason = candidate + .and_then(|c| map_finish_reason(c.finish_reason.as_deref())) + .unwrap_or(FinishReason::Stop); + + let message_content = if content_text.is_empty() { + None + } else { + Some(content_text) + }; + + let tool_calls_opt = if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }; + + let choice = Choice { + index: 0, + message: ResponseMessage { + role: Role::Assistant, + content: message_content, + tool_calls: tool_calls_opt, + refusal: None, + annotations: None, + audio: None, + function_call: None, + }, + finish_reason: Some(finish_reason), + logprobs: None, + }; + + let usage = resp + .usage_metadata + .map(|um| Usage { + prompt_tokens: um.prompt_token_count.unwrap_or(0), + completion_tokens: um.candidates_token_count.unwrap_or(0), + total_tokens: um.total_token_count.unwrap_or(0), + prompt_tokens_details: None, + completion_tokens_details: None, + }) + .unwrap_or_default(); + + Ok(ChatCompletionsResponse { + id: format!( + "chatcmpl-gemini-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + ), + object: Some("chat.completion".to_string()), + created: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + model: resp.model_version.unwrap_or_else(|| "gemini".to_string()), + choices: vec![choice], + usage, + system_fingerprint: None, + service_tier: None, + metadata: None, + }) + } +} + +// ============================================================================ +// Gemini GenerateContentResponse -> Anthropic MessagesResponse (via OpenAI) +// ============================================================================ + +impl TryFrom for MessagesResponse { + type Error = TransformError; + + fn try_from(resp: GenerateContentResponse) -> Result { + // Chain: Gemini -> OpenAI -> Anthropic + let chat_resp = ChatCompletionsResponse::try_from(resp)?; + MessagesResponse::try_from(chat_resp) + } +} + +// ============================================================================ +// Gemini GenerateContentResponse -> OpenAI ChatCompletionsStreamResponse +// ============================================================================ + +impl TryFrom for ChatCompletionsStreamResponse { + type Error = TransformError; + + fn try_from(resp: GenerateContentResponse) -> Result { + let candidates = resp.candidates.unwrap_or_default(); + let candidate = candidates.first(); + + let mut delta_content: Option = None; + + if let Some(candidate) = candidate { + if let Some(ref content) = candidate.content { + let mut text_parts = Vec::new(); + + for part in content.parts.iter() { + if let Some(ref text) = part.text { + text_parts.push(text.clone()); + } + } + + if !text_parts.is_empty() { + delta_content = Some(text_parts.join("")); + } + } + } + + let finish_reason = candidate.and_then(|c| map_finish_reason(c.finish_reason.as_deref())); + + let role = candidate + .and_then(|c| c.content.as_ref()) + .and_then(|c| c.role.as_deref()) + .map(|r| match r { + "model" => Role::Assistant, + _ => Role::User, + }); + + Ok(ChatCompletionsStreamResponse { + id: format!( + "chatcmpl-gemini-{}", + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() + ), + object: Some("chat.completion.chunk".to_string()), + created: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + 
+
+impl TryFrom<GenerateContentResponse> for ChatCompletionsStreamResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: GenerateContentResponse) -> Result<Self, Self::Error> {
+        let candidates = resp.candidates.unwrap_or_default();
+        let candidate = candidates.first();
+
+        let mut delta_content: Option<String> = None;
+
+        if let Some(candidate) = candidate {
+            if let Some(ref content) = candidate.content {
+                let mut text_parts = Vec::new();
+
+                for part in content.parts.iter() {
+                    if let Some(ref text) = part.text {
+                        text_parts.push(text.clone());
+                    }
+                }
+
+                if !text_parts.is_empty() {
+                    delta_content = Some(text_parts.join(""));
+                }
+            }
+        }
+
+        let finish_reason = candidate.and_then(|c| map_finish_reason(c.finish_reason.as_deref()));
+
+        let role = candidate
+            .and_then(|c| c.content.as_ref())
+            .and_then(|c| c.role.as_deref())
+            .map(|r| match r {
+                "model" => Role::Assistant,
+                _ => Role::User,
+            });
+
+        Ok(ChatCompletionsStreamResponse {
+            id: format!(
+                "chatcmpl-gemini-{}",
+                std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis()
+            ),
+            object: Some("chat.completion.chunk".to_string()),
+            created: std::time::SystemTime::now()
+                .duration_since(std::time::UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_secs(),
+            model: resp.model_version.unwrap_or_else(|| "gemini".to_string()),
+            choices: vec![StreamChoice {
+                index: 0,
+                delta: MessageDelta {
+                    role,
+                    content: delta_content,
+                    tool_calls: None,
+                    refusal: None,
+                    function_call: None,
+                },
+                finish_reason,
+                logprobs: None,
+            }],
+            usage: None,
+            system_fingerprint: None,
+            service_tier: None,
+        })
+    }
+}
+
+// ============================================================================
+// REVERSE: OpenAI ChatCompletionsResponse -> Gemini GenerateContentResponse
+// ============================================================================
+
+impl TryFrom<ChatCompletionsResponse> for GenerateContentResponse {
+    type Error = TransformError;
+
+    fn try_from(resp: ChatCompletionsResponse) -> Result<Self, Self::Error> {
+        use crate::apis::gemini::{Candidate, Content, FunctionCall, Part, UsageMetadata};
+
+        let candidates = if let Some(choice) = resp.choices.first() {
+            let mut parts = Vec::new();
+
+            // Text content
+            if let Some(ref content) = choice.message.content {
+                if !content.is_empty() {
+                    parts.push(Part {
+                        text: Some(content.clone()),
+                        inline_data: None,
+                        function_call: None,
+                        function_response: None,
+                    });
+                }
+            }
+
+            // Tool calls
+            if let Some(ref tool_calls) = choice.message.tool_calls {
+                for tc in tool_calls {
+                    let args: serde_json::Value =
+                        serde_json::from_str(&tc.function.arguments).unwrap_or_default();
+                    parts.push(Part {
+                        text: None,
+                        inline_data: None,
+                        function_call: Some(FunctionCall {
+                            name: tc.function.name.clone(),
+                            args,
+                        }),
+                        function_response: None,
+                    });
+                }
+            }
+
+            if parts.is_empty() {
+                parts.push(Part {
+                    text: Some(String::new()),
+                    inline_data: None,
+                    function_call: None,
+                    function_response: None,
+                });
+            }
+
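+            // Gemini has no dedicated tool-call finish reason, so ToolCalls and
+            // FunctionCall deliberately map back to "STOP" below.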
How can I help?")); + assert_eq!( + openai_resp.choices[0].finish_reason, + Some(FinishReason::Stop) + ); + assert_eq!(openai_resp.usage.prompt_tokens, 5); + assert_eq!(openai_resp.usage.completion_tokens, 7); + } + + #[test] + fn test_gemini_to_openai_stream_response() { + let resp: GenerateContentResponse = serde_json::from_value(json!({ + "candidates": [{ + "content": { + "role": "model", + "parts": [{"text": "Hello"}] + } + }] + })) + .unwrap(); + + let stream_resp = ChatCompletionsStreamResponse::try_from(resp).unwrap(); + assert_eq!(stream_resp.choices.len(), 1); + assert_eq!( + stream_resp.choices[0].delta.content, + Some("Hello".to_string()) + ); + assert_eq!(stream_resp.choices[0].delta.role, Some(Role::Assistant)); + } + + #[test] + fn test_gemini_to_openai_with_function_call() { + let resp: GenerateContentResponse = serde_json::from_value(json!({ + "candidates": [{ + "content": { + "role": "model", + "parts": [{ + "functionCall": { + "name": "get_weather", + "args": {"location": "NYC"} + } + }] + }, + "finishReason": "STOP" + }] + })) + .unwrap(); + + let openai_resp = ChatCompletionsResponse::try_from(resp).unwrap(); + let msg = &openai_resp.choices[0].message; + assert!(msg.tool_calls.is_some()); + let tc = msg.tool_calls.as_ref().unwrap(); + assert_eq!(tc[0].function.name, "get_weather"); + } + + #[test] + fn test_openai_to_gemini_response() { + let resp: ChatCompletionsResponse = serde_json::from_value(json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4", + "choices": [{ + "index": 0, + "message": {"role": "assistant", "content": "Hello!"}, + "finish_reason": "stop" + }], + "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12} + })) + .unwrap(); + + let gemini_resp = GenerateContentResponse::try_from(resp).unwrap(); + let candidates = gemini_resp.candidates.as_ref().unwrap(); + assert_eq!(candidates.len(), 1); + let parts = &candidates[0].content.as_ref().unwrap().parts; + assert_eq!(parts[0].text.as_deref(), Some("Hello!")); + assert_eq!(candidates[0].finish_reason.as_deref(), Some("STOP")); + } + + #[test] + fn test_finish_reason_mapping() { + assert_eq!(map_finish_reason(Some("STOP")), Some(FinishReason::Stop)); + assert_eq!( + map_finish_reason(Some("MAX_TOKENS")), + Some(FinishReason::Length) + ); + assert_eq!( + map_finish_reason(Some("SAFETY")), + Some(FinishReason::ContentFilter) + ); + assert_eq!( + map_finish_reason(Some("RECITATION")), + Some(FinishReason::ContentFilter) + ); + assert_eq!(map_finish_reason(None), None); + } +} diff --git a/crates/hermesllm/src/transforms/response/mod.rs b/crates/hermesllm/src/transforms/response/mod.rs index 1dd0d4ea4..d547c70f2 100644 --- a/crates/hermesllm/src/transforms/response/mod.rs +++ b/crates/hermesllm/src/transforms/response/mod.rs @@ -1,4 +1,5 @@ //! 
 //! Response transformation modules
+pub mod from_gemini;
 pub mod output_to_input;
 pub mod to_anthropic;
 pub mod to_openai;
diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs
index 7a353bcb7..ea06ee84b 100644
--- a/crates/llm_gateway/src/stream_context.rs
+++ b/crates/llm_gateway/src/stream_context.rs
@@ -217,7 +217,9 @@ impl StreamContext {
             SupportedUpstreamAPIs::OpenAIChatCompletions(_)
             | SupportedUpstreamAPIs::AmazonBedrockConverse(_)
             | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_)
-            | SupportedUpstreamAPIs::OpenAIResponsesAPI(_),
+            | SupportedUpstreamAPIs::OpenAIResponsesAPI(_)
+            | SupportedUpstreamAPIs::GeminiGenerateContent(_)
+            | SupportedUpstreamAPIs::GeminiStreamGenerateContent(_),
         ) | None => {
             // OpenAI and default: use Authorization Bearer token