Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ openai-rust2 = { version = "1.6.0" }

async-trait = "0.1.88"
log = "0.4.27"
env_logger = "0.11.8"
env_logger = "0.11.8"
bumpalo = "3.16"
5 changes: 3 additions & 2 deletions src/cloudllm/client_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use openai_rust2 as openai_rust;
/// and uses a ClientWrapper to interact with the LLM.
// src/client_wrapper
use std::error::Error;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Represents the possible roles for a message.
Expand All @@ -33,8 +34,8 @@ pub struct TokenUsage {
pub struct Message {
/// The role associated with the message.
pub role: Role,
/// The actual content of the message.
pub content: String,
/// The actual content of the message stored as Arc<str> to avoid clones.
pub content: Arc<str>,
}

/// Trait defining the interface to interact with various LLM services.
Expand Down
2 changes: 1 addition & 1 deletion src/cloudllm/clients/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ impl ClientWrapper for ClaudeClient {
fn usage_slot(&self) -> Option<&Mutex<Option<TokenUsage>>> {
self.delegate_client.usage_slot()
}
}
}
9 changes: 7 additions & 2 deletions src/cloudllm/clients/gemini.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@ use async_trait::async_trait;
use log::error;
use openai_rust::chat;
use openai_rust2 as openai_rust;
use std::env;
use std::error::Error;
use std::sync::Arc;
use tokio::runtime::Runtime;
use tokio::sync::Mutex;


pub struct GeminiClient {
client: openai_rust::Client,
pub model: String,
Expand Down Expand Up @@ -194,7 +199,7 @@ impl ClientWrapper for GeminiClient {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
},
content: msg.content.clone(),
content: msg.content.to_string(),
});
}

Expand All @@ -213,7 +218,7 @@ impl ClientWrapper for GeminiClient {
match result {
Ok(content) => Ok(Message {
role: Role::Assistant,
content,
content: Arc::from(content.as_str()),
}),
Err(err) => {
if log::log_enabled!(log::Level::Error) {
Expand Down
5 changes: 3 additions & 2 deletions src/cloudllm/clients/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//!
//! Make sure `OPENAI_API_KEY` is set and pick a valid model name (e.g. `"gpt-4.1-nano"`).
use std::error::Error;
use std::sync::Arc;

use async_trait::async_trait;
use openai_rust::chat;
Expand Down Expand Up @@ -153,7 +154,7 @@ impl ClientWrapper for OpenAIClient {
Role::User => "user".to_owned(),
Role::Assistant => "assistant".to_owned(),
},
content: msg.content.clone(),
content: msg.content.to_string(),
});
}

Expand All @@ -172,7 +173,7 @@ impl ClientWrapper for OpenAIClient {
match result {
Ok(c) => Ok(Message {
role: Role::Assistant,
content: c,
content: Arc::from(c.as_str()),
}),
Err(_) => {
if log::log_enabled!(log::Level::Error) {
Expand Down
33 changes: 29 additions & 4 deletions src/cloudllm/llm_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

use crate::client_wrapper;
use crate::cloudllm::client_wrapper::{ClientWrapper, Message, Role};
use bumpalo::Bump;
use openai_rust2 as openai_rust;
use std::sync::Arc;

Expand All @@ -65,6 +66,7 @@ use std::sync::Arc;
/// - `total_output_tokens`: sum of all completion tokens received so far.
/// - `total_context_tokens`: shortcut for input + output totals.
/// - `total_token_count`: total tokens used in the current session.
/// - `arena`: bump allocator for efficient message body allocation.
pub struct LLMSession {
client: Arc<dyn ClientWrapper>,
system_prompt: Message,
Expand All @@ -74,18 +76,25 @@ pub struct LLMSession {
total_input_tokens: usize,
total_output_tokens: usize,
total_token_count: usize,
arena: Bump,
}

impl LLMSession {
/// Creates a new `LLMSession` with the given client and system prompt.
/// Initializes the conversation history and sets a default maximum token limit.
pub fn new(client: Arc<dyn ClientWrapper>, system_prompt: String, max_tokens: usize) -> Self {
let arena = Bump::new();

// Allocate system prompt in arena and create Arc<str> from it
let system_prompt_str = arena.alloc_str(&system_prompt);
let system_prompt_arc: Arc<str> = Arc::from(system_prompt_str);

// Create the system prompt message
let system_prompt_message = Message {
role: Role::System,
content: system_prompt,
content: system_prompt_arc,
};
// Count tokens in the system prompt message

LLMSession {
client,
system_prompt: system_prompt_message,
Expand All @@ -95,6 +104,7 @@ impl LLMSession {
total_input_tokens: 0,
total_output_tokens: 0,
total_token_count: 0,
arena,
}
}

Expand All @@ -114,7 +124,14 @@ impl LLMSession {
content: String,
optional_search_parameters: Option<openai_rust::chat::SearchParameters>,
) -> Result<Message, Box<dyn std::error::Error>> {
let message = Message { role, content };
// Allocate message content in arena and create Arc<str>
let content_str = self.arena.alloc_str(&content);
let content_arc: Arc<str> = Arc::from(content_str);

let message = Message {
role,
content: content_arc,
};

// Cache the token count for the new message before adding it
let message_token_count = estimate_message_token_count(&message);
Expand Down Expand Up @@ -170,9 +187,13 @@ impl LLMSession {
/// Replaces the session's system prompt with `prompt`.
///
/// NOTE(review): this does NOT adjust any cached token counters — the old
/// doc claimed it "updates the token count accordingly", which the body
/// never did. Confirm whether the context-budget math needs recomputing
/// after a prompt swap.
pub fn set_system_prompt(&mut self, prompt: String) {
    // Convert the owned String straight into Arc<str> (From<String> for
    // Arc<str>). The previous code first copied the prompt into the bump
    // arena (`self.arena.alloc_str`) and then called `Arc::from(&str)`,
    // which copies the bytes *again* into the Arc's own allocation —
    // leaving a dead copy pinned in the arena for the rest of the
    // session's lifetime on every call.
    self.system_prompt = Message {
        role: Role::System,
        content: Arc::from(prompt),
    };
}

Expand All @@ -196,6 +217,10 @@ impl LLMSession {
/// Read-only view of the per-message token-count cache.
/// Entry `i` is the estimated token count cached when message `i` was
/// appended (see the caching in `send_message`).
/// NOTE(review): `&Vec<usize>` is kept for API compatibility; `&[usize]`
/// would be the idiomatic borrow type for new code.
pub fn get_cached_token_counts(&self) -> &Vec<usize> {
&self.cached_token_counts
}

/// Returns a reference to the current system prompt message
/// (role is always `Role::System`).
pub fn get_system_prompt(&self) -> &Message {
&self.system_prompt
}
}

/// Estimates the number of tokens in a string.
Expand Down
6 changes: 3 additions & 3 deletions tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fn test_claude_client() {
log::error!("Error: {}", e);
Message {
role: System,
content: format!("An error occurred: {:?}", e),
content: format!("An error occurred: {:?}", e).into(),
}
})
});
Expand Down Expand Up @@ -129,7 +129,7 @@ pub fn test_grok_client() {
log::error!("Error: {}", e);
Message {
role: crate::Role::System,
content: format!("An error occurred: {:?}", e),
content: format!("An error occurred: {:?}", e).into(),
}
})
});
Expand Down Expand Up @@ -168,7 +168,7 @@ fn test_openai_client() {
log::error!("Error: {}", e);
Message {
role: Role::System,
content: format!("An error occurred: {:?}", e),
content: format!("An error occurred: {:?}", e).into(),
}
})
});
Expand Down
94 changes: 94 additions & 0 deletions tests/llm_session_bump_allocations_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use cloudllm::client_wrapper::{ClientWrapper, Message, Role, TokenUsage};
use cloudllm::LLMSession;
use async_trait::async_trait;
use openai_rust2 as openai_rust;
use std::sync::Arc;
use tokio::sync::Mutex;

// Mock client for testing: answers every request with a fixed canned
// response so LLMSession can be exercised without any network calls.
struct MockClient {
// Slot handed out via `usage_slot()`; presumably the session writes
// TokenUsage here after a call — not visible from this file.
usage: Mutex<Option<TokenUsage>>,
// Canned text returned by every `send_message` invocation.
response_content: String,
}

impl MockClient {
fn new(response_content: String) -> Self {
Self {
usage: Mutex::new(None),
response_content,
}
}
}

#[async_trait]
impl ClientWrapper for MockClient {
/// Ignores the inbound messages entirely and returns the canned text as
/// an Assistant message. `.into()` converts the cloned String into the
/// `Arc<str>` that `Message.content` now stores.
async fn send_message(
&self,
_messages: &[Message],
_optional_search_parameters: Option<openai_rust::chat::SearchParameters>,
) -> Result<Message, Box<dyn std::error::Error>> {
Ok(Message {
role: Role::Assistant,
content: self.response_content.clone().into(),
})
}

/// Exposes the mock's usage slot so the session can record token usage.
fn usage_slot(&self) -> Option<&Mutex<Option<TokenUsage>>> {
Some(&self.usage)
}
}

#[tokio::test]
async fn test_arena_allocation() {
    // Drive one full round-trip through LLMSession using the canned mock.
    let client = Arc::new(MockClient::new("Mock response".to_string()));
    let mut session = LLMSession::new(client, "Test system prompt".to_string(), 1000);

    let reply = session
        .send_message(Role::User, "Test user message".to_string(), None)
        .await
        .expect("mock send_message should never fail");

    // The assistant reply carries the mock's canned text.
    assert_eq!(&*reply.content, "Mock response");

    // The system prompt survives construction intact.
    assert_eq!(&*session.get_system_prompt().content, "Test system prompt");

    // History holds exactly the user message plus the assistant response.
    assert_eq!(session.get_conversation_history().len(), 2);
}

#[test]
fn test_set_system_prompt() {
    // Build a session with an initial prompt, then swap it out.
    let client = Arc::new(MockClient::new("Response".to_string()));
    let mut session = LLMSession::new(client, "Initial prompt".to_string(), 1000);

    session.set_system_prompt("Updated prompt".to_string());

    // The getter must reflect the replacement prompt.
    assert_eq!(&*session.get_system_prompt().content, "Updated prompt");
}

#[test]
fn test_message_content_is_arc_str() {
    // Cloning a Message must share the content allocation rather than copy
    // the string: an Arc<str> clone is just a refcount bump.
    let original = Message {
        role: Role::User,
        content: Arc::from("Test message"),
    };
    let copy = original.clone();

    // Same allocation behind both handles proves the cheap-clone property.
    assert!(Arc::ptr_eq(&original.content, &copy.content));
}
8 changes: 4 additions & 4 deletions tests/llm_session_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl ClientWrapper for MockClient {
) -> Result<Message, Box<dyn std::error::Error>> {
Ok(Message {
role: Role::Assistant,
content: self.response_content.clone(),
content: self.response_content.clone().into(),
})
}

Expand Down Expand Up @@ -70,11 +70,11 @@ async fn test_token_caching() {
// Verify token counts are cached correctly
let expected_user_tokens = llm_session::estimate_message_token_count(&Message {
role: Role::User,
content: user_message.to_string(),
content: user_message.to_string().into(),
});
let expected_response_tokens = llm_session::estimate_message_token_count(&Message {
role: Role::Assistant,
content: "Response".to_string(),
content: "Response".to_string().into(),
});

assert_eq!(session.get_cached_token_counts()[0], expected_user_tokens);
Expand Down Expand Up @@ -126,7 +126,7 @@ fn test_estimate_token_count() {
fn test_estimate_message_token_count() {
let message = Message {
role: Role::User,
content: "test message".to_string(),
content: "test message".to_string().into(),
};
// "test message" = 12 characters = 3 tokens + 1 role token = 4 tokens
assert_eq!(llm_session::estimate_message_token_count(&message), 4);
Expand Down