diff --git a/Cargo.toml b/Cargo.toml
index fb1dd08..da53025 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,4 +17,5 @@
 openai-rust2 = { version = "1.6.0" }
 async-trait = "0.1.88"
 log = "0.4.27"
-env_logger = "0.11.8"
\ No newline at end of file
+env_logger = "0.11.8"
+bumpalo = "3.16"
\ No newline at end of file
diff --git a/src/cloudllm/client_wrapper.rs b/src/cloudllm/client_wrapper.rs
index c67aee5..b457b19 100644
--- a/src/cloudllm/client_wrapper.rs
+++ b/src/cloudllm/client_wrapper.rs
@@ -7,6 +7,7 @@ use openai_rust2 as openai_rust;
 /// and uses a ClientWrapper to interact with the LLM.
 // src/client_wrapper
 use std::error::Error;
+use std::sync::Arc;
 use tokio::sync::Mutex;

 /// Represents the possible roles for a message.
@@ -33,8 +34,8 @@ pub struct TokenUsage {
 pub struct Message {
     /// The role associated with the message.
     pub role: Role,
-    /// The actual content of the message.
-    pub content: String,
+    /// The actual content of the message, stored as an `Arc<str>` to avoid clones.
+    pub content: Arc<str>,
 }

 /// Trait defining the interface to interact with various LLM services.
diff --git a/src/cloudllm/clients/claude.rs b/src/cloudllm/clients/claude.rs
index 493132e..f84964a 100644
--- a/src/cloudllm/clients/claude.rs
+++ b/src/cloudllm/clients/claude.rs
@@ -81,4 +81,4 @@ impl ClientWrapper for ClaudeClient {
     fn usage_slot(&self) -> Option<&Mutex<Option<TokenUsage>>> {
         self.delegate_client.usage_slot()
     }
-}
+}
\ No newline at end of file
diff --git a/src/cloudllm/clients/gemini.rs b/src/cloudllm/clients/gemini.rs
index 3f56afe..308994b 100644
--- a/src/cloudllm/clients/gemini.rs
+++ b/src/cloudllm/clients/gemini.rs
@@ -5,8 +5,13 @@ use async_trait::async_trait;
 use log::error;
 use openai_rust::chat;
 use openai_rust2 as openai_rust;
+use std::env;
+use std::error::Error;
+use std::sync::Arc;
+use tokio::runtime::Runtime;
 use tokio::sync::Mutex;

+
 pub struct GeminiClient {
     client: openai_rust::Client,
     pub model: String,
@@ -194,7 +199,7 @@ impl ClientWrapper for GeminiClient {
                 Role::User => "user".to_owned(),
                 Role::Assistant => "assistant".to_owned(),
             },
-            content: msg.content.clone(),
+            content: msg.content.to_string(),
         });
     }

@@ -213,7 +218,7 @@
         match result {
             Ok(content) => Ok(Message {
                 role: Role::Assistant,
-                content,
+                content: Arc::from(content.as_str()),
             }),
             Err(err) => {
                 if log::log_enabled!(log::Level::Error) {
diff --git a/src/cloudllm/clients/openai.rs b/src/cloudllm/clients/openai.rs
index 9873a2d..604a36f 100644
--- a/src/cloudllm/clients/openai.rs
+++ b/src/cloudllm/clients/openai.rs
@@ -41,6 +41,7 @@
 //!
 //! Make sure `OPENAI_API_KEY` is set and pick a valid model name (e.g. `"gpt-4.1-nano"`).
 use std::error::Error;
+use std::sync::Arc;

 use async_trait::async_trait;
 use openai_rust::chat;
@@ -153,7 +154,7 @@ impl ClientWrapper for OpenAIClient {
                 Role::User => "user".to_owned(),
                 Role::Assistant => "assistant".to_owned(),
             },
-            content: msg.content.clone(),
+            content: msg.content.to_string(),
         });
     }

@@ -172,7 +173,7 @@
         match result {
             Ok(c) => Ok(Message {
                 role: Role::Assistant,
-                content: c,
+                content: Arc::from(c.as_str()),
             }),
             Err(_) => {
                 if log::log_enabled!(log::Level::Error) {
diff --git a/src/cloudllm/llm_session.rs b/src/cloudllm/llm_session.rs
index 1ae2614..ff53845 100644
--- a/src/cloudllm/llm_session.rs
+++ b/src/cloudllm/llm_session.rs
@@ -51,6 +51,7 @@ use crate::client_wrapper;
 use crate::cloudllm::client_wrapper::{ClientWrapper, Message, Role};
+use bumpalo::Bump;
 use openai_rust2 as openai_rust;
 use std::sync::Arc;

@@ -65,6 +66,7 @@ use std::sync::Arc;
 /// - `total_output_tokens`: sum of all completion tokens received so far.
 /// - `total_context_tokens`: shortcut for input + output totals.
 /// - `total_token_count`: total tokens used in the current session.
+/// - `arena`: bump allocator for efficient message body allocation.
 pub struct LLMSession {
     client: Arc<dyn ClientWrapper>,
     system_prompt: Message,
@@ -74,18 +76,25 @@
     total_input_tokens: usize,
     total_output_tokens: usize,
     total_token_count: usize,
+    arena: Bump,
 }

 impl LLMSession {
     /// Creates a new `LLMSession` with the given client and system prompt.
     /// Initializes the conversation history and sets a default maximum token limit.
     pub fn new(client: Arc<dyn ClientWrapper>, system_prompt: String, max_tokens: usize) -> Self {
+        let arena = Bump::new();
+
+        // Allocate the system prompt in the arena and create an Arc<str> from it
+        let system_prompt_str = arena.alloc_str(&system_prompt);
+        let system_prompt_arc: Arc<str> = Arc::from(system_prompt_str);
+
         // Create the system prompt message
         let system_prompt_message = Message {
             role: Role::System,
-            content: system_prompt,
+            content: system_prompt_arc,
         };
-        // Count tokens in the system prompt message
+
         LLMSession {
             client,
             system_prompt: system_prompt_message,
@@ -95,6 +104,7 @@
             total_input_tokens: 0,
             total_output_tokens: 0,
             total_token_count: 0,
+            arena,
         }
     }

@@ -114,7 +124,14 @@
         content: String,
         optional_search_parameters: Option<openai_rust::chat::SearchParameters>,
     ) -> Result<Message, Box<dyn Error>> {
-        let message = Message { role, content };
+        // Allocate the message content in the arena and create an Arc<str>
+        let content_str = self.arena.alloc_str(&content);
+        let content_arc: Arc<str> = Arc::from(content_str);
+
+        let message = Message {
+            role,
+            content: content_arc,
+        };

         // Cache the token count for the new message before adding it
         let message_token_count = estimate_message_token_count(&message);
@@ -170,9 +187,13 @@
     /// Sets a new system prompt for the session.
     /// Updates the token count accordingly.
     pub fn set_system_prompt(&mut self, prompt: String) {
+        // Allocate the prompt in the arena and create an Arc<str>
+        let prompt_str = self.arena.alloc_str(&prompt);
+        let prompt_arc: Arc<str> = Arc::from(prompt_str);
+
         self.system_prompt = Message {
             role: Role::System,
-            content: prompt,
+            content: prompt_arc,
         };
     }

@@ -196,6 +217,10 @@
     pub fn get_cached_token_counts(&self) -> &Vec<usize> {
         &self.cached_token_counts
     }
+
+    pub fn get_system_prompt(&self) -> &Message {
+        &self.system_prompt
+    }
 }

 /// Estimates the number of tokens in a string.
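
For context, a minimal standalone sketch (not part of the patch) of the allocation path `LLMSession` now uses for message bodies, assuming bumpalo 3.x's `Bump::alloc_str`: the content is copied into the bump arena, and `Arc::from(&str)` then copies those bytes into the reference-counted allocation that `Message.content` holds.

    use bumpalo::Bump;
    use std::sync::Arc;

    fn main() {
        let arena = Bump::new();

        // Copy the message body into the arena; alloc_str returns &mut str.
        let content = String::from("user message");
        let content_str: &str = arena.alloc_str(&content);

        // Arc::from(&str) copies the bytes into a new reference-counted
        // allocation; this Arc<str> is what Message.content stores.
        let content_arc: Arc<str> = Arc::from(content_str);

        assert_eq!(&*content_arc, "user message");
    }
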
diff --git a/tests/client_tests.rs b/tests/client_tests.rs
index b2224ae..9bf0d83 100644
--- a/tests/client_tests.rs
+++ b/tests/client_tests.rs
@@ -43,7 +43,7 @@ fn test_claude_client() {
             log::error!("Error: {}", e);
             Message {
                 role: System,
-                content: format!("An error occurred: {:?}", e),
+                content: format!("An error occurred: {:?}", e).into(),
             }
         })
     });
@@ -129,7 +129,7 @@ pub fn test_grok_client() {
             log::error!("Error: {}", e);
             Message {
                 role: crate::Role::System,
-                content: format!("An error occurred: {:?}", e),
+                content: format!("An error occurred: {:?}", e).into(),
             }
         })
     });
@@ -168,7 +168,7 @@ fn test_openai_client() {
             log::error!("Error: {}", e);
             Message {
                 role: Role::System,
-                content: format!("An error occurred: {:?}", e),
+                content: format!("An error occurred: {:?}", e).into(),
             }
         })
     });
diff --git a/tests/llm_session_bump_allocations_test.rs b/tests/llm_session_bump_allocations_test.rs
new file mode 100644
index 0000000..8488eca
--- /dev/null
+++ b/tests/llm_session_bump_allocations_test.rs
@@ -0,0 +1,94 @@
+use cloudllm::client_wrapper::{ClientWrapper, Message, Role, TokenUsage};
+use cloudllm::LLMSession;
+use async_trait::async_trait;
+use openai_rust2 as openai_rust;
+use std::sync::Arc;
+use tokio::sync::Mutex;
+
+// Mock client for testing
+struct MockClient {
+    usage: Mutex<Option<TokenUsage>>,
+    response_content: String,
+}
+
+impl MockClient {
+    fn new(response_content: String) -> Self {
+        Self {
+            usage: Mutex::new(None),
+            response_content,
+        }
+    }
+}
+
+#[async_trait]
+impl ClientWrapper for MockClient {
+    async fn send_message(
+        &self,
+        _messages: &[Message],
+        _optional_search_parameters: Option<openai_rust::chat::SearchParameters>,
+    ) -> Result<Message, Box<dyn std::error::Error>> {
+        Ok(Message {
+            role: Role::Assistant,
+            content: self.response_content.clone().into(),
+        })
+    }
+
+    fn usage_slot(&self) -> Option<&Mutex<Option<TokenUsage>>> {
+        Some(&self.usage)
+    }
+}
+
+#[tokio::test]
+async fn test_arena_allocation() {
+    let mock_client = Arc::new(MockClient::new("Mock response".to_string()));
+    let mut session = LLMSession::new(
+        mock_client,
+        "Test system prompt".to_string(),
+        1000,
+    );
+
+    // Send a message
+    let result = session.send_message(
+        Role::User,
+        "Test user message".to_string(),
+        None,
+    ).await;
+
+    assert!(result.is_ok());
+    let response = result.unwrap();
+    assert_eq!(&*response.content, "Mock response");
+
+    // Verify the system prompt is allocated correctly
+    assert_eq!(&*session.get_system_prompt().content, "Test system prompt");
+
+    // Verify the conversation history
+    assert_eq!(session.get_conversation_history().len(), 2); // user message + assistant response
+}
+
+#[test]
+fn test_set_system_prompt() {
+    let mock_client = Arc::new(MockClient::new("Response".to_string()));
+    let mut session = LLMSession::new(
+        mock_client,
+        "Initial prompt".to_string(),
+        1000,
+    );
+
+    // Change the system prompt
+    session.set_system_prompt("Updated prompt".to_string());
+    assert_eq!(&*session.get_system_prompt().content, "Updated prompt");
+}
+
+#[test]
+fn test_message_content_is_arc_str() {
+    // Verify that Message.content is Arc<str> and cloning is cheap
+    let msg = Message {
+        role: Role::User,
+        content: Arc::from("Test message"),
+    };
+
+    let cloned = msg.clone();
+
+    // Arc::ptr_eq checks if both Arcs point to the same allocation
+    assert!(Arc::ptr_eq(&msg.content, &cloned.content));
+}
diff --git a/tests/llm_session_tests.rs b/tests/llm_session_tests.rs
index 9e4b272..1e8c46b 100644
--- a/tests/llm_session_tests.rs
+++ b/tests/llm_session_tests.rs
@@ -41,7 +41,7 @@ impl ClientWrapper for MockClient {
     ) -> Result<Message, Box<dyn Error>> {
         Ok(Message {
             role: Role::Assistant,
-            content: self.response_content.clone(),
+            content: self.response_content.clone().into(),
         })
     }

@@ -70,11 +70,11 @@ async fn test_token_caching() {
     // Verify token counts are cached correctly
     let expected_user_tokens = llm_session::estimate_message_token_count(&Message {
         role: Role::User,
-        content: user_message.to_string(),
+        content: user_message.to_string().into(),
     });
     let expected_response_tokens = llm_session::estimate_message_token_count(&Message {
         role: Role::Assistant,
-        content: "Response".to_string(),
+        content: "Response".to_string().into(),
     });

     assert_eq!(session.get_cached_token_counts()[0], expected_user_tokens);
@@ -126,7 +126,7 @@ fn test_estimate_token_count() {
 fn test_estimate_message_token_count() {
     let message = Message {
         role: Role::User,
-        content: "test message".to_string(),
+        content: "test message".to_string().into(),
     };
     // "test message" = 12 characters = 3 tokens + 1 role token = 4 tokens
     assert_eq!(llm_session::estimate_message_token_count(&message), 4);
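
For context, a minimal standalone sketch (not part of the patch) of the property the new `test_message_content_is_arc_str` test asserts, and of why the tests can convert with `.into()`: std provides `From<String> for Arc<str>`, cloning an `Arc<str>` only bumps a reference count, while cloning a `String` copies the whole buffer.

    use std::sync::Arc;

    fn main() {
        let s: String = "hello".repeat(1_000);

        // String clone: O(n), duplicates the 5_000-byte buffer.
        let s2 = s.clone();
        assert_eq!(s, s2);

        // String -> Arc<str> via the std From impl the tests rely on (.into()).
        let a: Arc<str> = s.into();

        // Arc<str> clone: O(1), both handles share one allocation.
        let a2 = a.clone();
        assert!(Arc::ptr_eq(&a, &a2));
    }
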