From d8c7df3dc90d96f337912cc9e06e78022275c49a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 13:30:51 +0000 Subject: [PATCH 1/7] Initial plan From 9eb7ea8782be96d47fdafb3e3e351600843d7a04 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 13:35:32 +0000 Subject: [PATCH 2/7] Refactor conversation history management into thread-safe class Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> --- mini_ema/bot/conversation_history.py | 89 +++++++++++ mini_ema/bot/pretty_gemini_bot.py | 18 +-- tests/test_conversation_history.py | 220 +++++++++++++++++++++++++++ 3 files changed, 318 insertions(+), 9 deletions(-) create mode 100644 mini_ema/bot/conversation_history.py create mode 100644 tests/test_conversation_history.py diff --git a/mini_ema/bot/conversation_history.py b/mini_ema/bot/conversation_history.py new file mode 100644 index 0000000..32878de --- /dev/null +++ b/mini_ema/bot/conversation_history.py @@ -0,0 +1,89 @@ +"""Thread-safe conversation history management for chat bots.""" + +import threading +from typing import Any + + +class ConversationHistory: + """Thread-safe conversation history manager. + + This class manages conversation history with a maximum number of rounds, + ensuring thread-safe operations when multiple threads access the history. + Each round consists of a user message and an assistant response. + """ + + def __init__(self, max_rounds: int = 10, messages_per_round: int = 2): + """Initialize the conversation history manager. + + Args: + max_rounds: Maximum number of conversation rounds to keep in history. + A round consists of messages_per_round messages (typically user + assistant). + Must be >= 0. If 0, no history is kept. + messages_per_round: Number of messages per conversation round (default: 2). + """ + self._lock = threading.Lock() + self._history: list[Any] = [] + self._max_rounds = max(0, max_rounds) + self._messages_per_round = max(1, messages_per_round) + + def add_messages(self, messages: list[Any]) -> None: + """Add messages to the conversation history in a thread-safe manner. + + Args: + messages: List of messages to add to the history. + """ + with self._lock: + self._history.extend(messages) + + def get_recent_messages(self, max_rounds: int | None = None) -> list[Any]: + """Get the most recent N rounds of messages in a thread-safe manner. + + Args: + max_rounds: Maximum number of rounds to retrieve. If None, uses the + instance's max_rounds setting. Must be >= 0. + + Returns: + List of recent messages. If max_rounds is 0 or history is empty, + returns an empty list. + """ + with self._lock: + # Use instance max_rounds if not specified + rounds = self._max_rounds if max_rounds is None else max(0, max_rounds) + + # If rounds is 0, return empty list + if rounds == 0: + return [] + + # Calculate the maximum number of messages to return + max_messages = rounds * self._messages_per_round + + # Handle the edge case where max_messages is 0 + if max_messages == 0: + return [] + + # Return the last max_messages from history + # Handle the case where max_messages >= len(history) + return self._history[-max_messages:] if self._history else [] + + def clear(self) -> None: + """Clear all conversation history in a thread-safe manner.""" + with self._lock: + self._history = [] + + def get_all_messages(self) -> list[Any]: + """Get all messages in the conversation history in a thread-safe manner. + + Returns: + List of all messages in the history. + """ + with self._lock: + return self._history.copy() + + def __len__(self) -> int: + """Get the number of messages in the conversation history. + + Returns: + Number of messages in the history. + """ + with self._lock: + return len(self._history) diff --git a/mini_ema/bot/pretty_gemini_bot.py b/mini_ema/bot/pretty_gemini_bot.py index bd80f74..937b6b2 100644 --- a/mini_ema/bot/pretty_gemini_bot.py +++ b/mini_ema/bot/pretty_gemini_bot.py @@ -10,6 +10,7 @@ from pydantic import BaseModel, Field from .bare_gemini_bot import BareGeminiBot +from .conversation_history import ConversationHistory class EmaMessage(BaseModel): @@ -83,12 +84,14 @@ def __init__(self, api_key: str | None = None, model: str | None = None, thinkin # Initialize the Gemini client self.client = genai.Client(api_key=self.api_key) - # Initialize conversation history array - self.conversation_history = [] + # Initialize thread-safe conversation history manager + self.conversation_history = ConversationHistory( + max_rounds=self.history_length, messages_per_round=MESSAGES_PER_ROUND + ) def clear(self): """Clear conversation history.""" - self.conversation_history = [] + self.conversation_history.clear() def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict]: """Generate a structured response using Gemini API with character personality. @@ -108,10 +111,8 @@ def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict # Format the message with XML tags to separate username and message formatted_message = f"{username}\n{message}" - # Get the recent N rounds of history based on history_length - # Each round consists of a user message and an assistant response - max_history_messages = self.history_length * MESSAGES_PER_ROUND - recent_history = self.conversation_history[-max_history_messages:] + # Get the recent N rounds of history from the thread-safe history manager + recent_history = self.conversation_history.get_recent_messages() # Create a new chat session with the recent history chat = self.client.chats.create( @@ -158,8 +159,7 @@ def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict new_assistant_message = updated_history[history_before_length + 1] # Verify both messages exist and are valid (not None) if new_user_message is not None and new_assistant_message is not None: - self.conversation_history.append(new_user_message) - self.conversation_history.append(new_assistant_message) + self.conversation_history.add_messages([new_user_message, new_assistant_message]) # Yield the response with metadata yield { diff --git a/tests/test_conversation_history.py b/tests/test_conversation_history.py new file mode 100644 index 0000000..9c984e0 --- /dev/null +++ b/tests/test_conversation_history.py @@ -0,0 +1,220 @@ +"""Unit tests for ConversationHistory class.""" + +import threading +import time + +from mini_ema.bot.conversation_history import ConversationHistory + + +def test_initialization(): + """Test basic initialization of ConversationHistory.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + assert len(history) == 0 + assert history.get_all_messages() == [] + + +def test_add_messages(): + """Test adding messages to history.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + messages = ["user message", "assistant response"] + history.add_messages(messages) + assert len(history) == 2 + assert history.get_all_messages() == messages + + +def test_get_recent_messages_basic(): + """Test getting recent messages with basic scenarios.""" + history = ConversationHistory(max_rounds=3, messages_per_round=2) + + # Add 2 rounds (4 messages) + history.add_messages(["user1", "assistant1"]) + history.add_messages(["user2", "assistant2"]) + + # Get all messages (default to max_rounds=3) + recent = history.get_recent_messages() + assert recent == ["user1", "assistant1", "user2", "assistant2"] + + # Get last 1 round + recent = history.get_recent_messages(max_rounds=1) + assert recent == ["user2", "assistant2"] + + +def test_get_recent_messages_exceeds_history(): + """Test getting recent messages when requested rounds exceed actual history.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + + # Add only 1 round (2 messages) + history.add_messages(["user1", "assistant1"]) + + # Request 5 rounds, but only 1 exists + recent = history.get_recent_messages() + assert recent == ["user1", "assistant1"] + + +def test_get_recent_messages_zero_rounds(): + """Test edge case when max_rounds is 0.""" + history = ConversationHistory(max_rounds=0, messages_per_round=2) + + # Add messages + history.add_messages(["user1", "assistant1"]) + + # Should return empty list when max_rounds is 0 + recent = history.get_recent_messages() + assert recent == [] + + # Explicitly request 0 rounds + history2 = ConversationHistory(max_rounds=5, messages_per_round=2) + history2.add_messages(["user1", "assistant1"]) + recent2 = history2.get_recent_messages(max_rounds=0) + assert recent2 == [] + + +def test_automatic_trimming(): + """Test that history automatically keeps only recent N rounds.""" + history = ConversationHistory(max_rounds=2, messages_per_round=2) + + # Add 4 rounds (8 messages) + history.add_messages(["user1", "assistant1"]) + history.add_messages(["user2", "assistant2"]) + history.add_messages(["user3", "assistant3"]) + history.add_messages(["user4", "assistant4"]) + + # get_recent_messages should return only last 2 rounds (4 messages) + recent = history.get_recent_messages() + assert recent == ["user3", "assistant3", "user4", "assistant4"] + assert len(recent) == 4 + + +def test_clear(): + """Test clearing conversation history.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + + # Add messages + history.add_messages(["user1", "assistant1"]) + assert len(history) == 2 + + # Clear history + history.clear() + assert len(history) == 0 + assert history.get_all_messages() == [] + + +def test_empty_history(): + """Test operations on empty history.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + + assert len(history) == 0 + assert history.get_all_messages() == [] + assert history.get_recent_messages() == [] + + # Clear empty history should not raise error + history.clear() + assert len(history) == 0 + + +def test_thread_safety(): + """Test thread safety of ConversationHistory operations.""" + history = ConversationHistory(max_rounds=100, messages_per_round=2) + errors = [] + + def add_messages_thread(thread_id, count): + """Add messages from a thread.""" + try: + for i in range(count): + history.add_messages([f"user_{thread_id}_{i}", f"assistant_{thread_id}_{i}"]) + except Exception as e: + errors.append(e) + + def read_messages_thread(count): + """Read messages from a thread.""" + try: + for _ in range(count): + history.get_recent_messages() + time.sleep(0.001) # Small delay to interleave operations + except Exception as e: + errors.append(e) + + # Create multiple threads that add and read messages concurrently + threads = [] + for i in range(5): + t = threading.Thread(target=add_messages_thread, args=(i, 10)) + threads.append(t) + t.start() + + for _ in range(3): + t = threading.Thread(target=read_messages_thread, args=(20,)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # Check no errors occurred + assert len(errors) == 0 + + # Verify total messages added (5 threads * 10 rounds * 2 messages per round) + assert len(history) == 100 + + +def test_different_messages_per_round(): + """Test with different messages_per_round values.""" + # Test with 3 messages per round + history = ConversationHistory(max_rounds=2, messages_per_round=3) + + history.add_messages(["msg1", "msg2", "msg3"]) + history.add_messages(["msg4", "msg5", "msg6"]) + history.add_messages(["msg7", "msg8", "msg9"]) + + # Should return last 2 rounds (6 messages) + recent = history.get_recent_messages() + assert recent == ["msg4", "msg5", "msg6", "msg7", "msg8", "msg9"] + assert len(recent) == 6 + + +def test_negative_max_rounds(): + """Test that negative max_rounds is handled correctly.""" + # Should be converted to 0 + history = ConversationHistory(max_rounds=-5, messages_per_round=2) + history.add_messages(["user1", "assistant1"]) + + # Should return empty list + recent = history.get_recent_messages() + assert recent == [] + + +def test_array_slicing_edge_cases(): + """Test array slicing edge cases, especially when index is 0.""" + history = ConversationHistory(max_rounds=5, messages_per_round=2) + + # Add some messages + history.add_messages(["user1", "assistant1"]) + history.add_messages(["user2", "assistant2"]) + + # Request 0 rounds (edge case for array slicing with -0) + recent = history.get_recent_messages(max_rounds=0) + assert recent == [] + + # Verify regular slicing still works + recent = history.get_recent_messages(max_rounds=1) + assert recent == ["user2", "assistant2"] + + recent = history.get_recent_messages(max_rounds=2) + assert recent == ["user1", "assistant1", "user2", "assistant2"] + + +if __name__ == "__main__": + # Run all tests + test_initialization() + test_add_messages() + test_get_recent_messages_basic() + test_get_recent_messages_exceeds_history() + test_get_recent_messages_zero_rounds() + test_automatic_trimming() + test_clear() + test_empty_history() + test_thread_safety() + test_different_messages_per_round() + test_negative_max_rounds() + test_array_slicing_edge_cases() + print("All tests passed!") From 99e896d8e9e9329755c0ee71eca63e0048a51c9c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 13:38:03 +0000 Subject: [PATCH 3/7] Remove redundant check in ConversationHistory.get_recent_messages Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> --- mini_ema/bot/conversation_history.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mini_ema/bot/conversation_history.py b/mini_ema/bot/conversation_history.py index 32878de..c27d881 100644 --- a/mini_ema/bot/conversation_history.py +++ b/mini_ema/bot/conversation_history.py @@ -57,10 +57,6 @@ def get_recent_messages(self, max_rounds: int | None = None) -> list[Any]: # Calculate the maximum number of messages to return max_messages = rounds * self._messages_per_round - # Handle the edge case where max_messages is 0 - if max_messages == 0: - return [] - # Return the last max_messages from history # Handle the case where max_messages >= len(history) return self._history[-max_messages:] if self._history else [] From a19b93eefe8abbabdc161da9279aab043a462263 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 13:40:06 +0000 Subject: [PATCH 4/7] Add automatic history trimming to prevent memory leaks Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> --- mini_ema/bot/conversation_history.py | 9 +++++++++ tests/test_conversation_history.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/mini_ema/bot/conversation_history.py b/mini_ema/bot/conversation_history.py index c27d881..a44d585 100644 --- a/mini_ema/bot/conversation_history.py +++ b/mini_ema/bot/conversation_history.py @@ -24,16 +24,25 @@ def __init__(self, max_rounds: int = 10, messages_per_round: int = 2): self._lock = threading.Lock() self._history: list[Any] = [] self._max_rounds = max(0, max_rounds) + # Ensure at least 1 message per round to prevent degenerate cases self._messages_per_round = max(1, messages_per_round) def add_messages(self, messages: list[Any]) -> None: """Add messages to the conversation history in a thread-safe manner. + Messages are appended to the history. If the total number of messages exceeds + the maximum capacity (max_rounds * messages_per_round), older messages are + automatically removed to maintain the limit. + Args: messages: List of messages to add to the history. """ with self._lock: self._history.extend(messages) + # Trim history to keep only the most recent messages + max_capacity = self._max_rounds * self._messages_per_round + if max_capacity > 0 and len(self._history) > max_capacity: + self._history = self._history[-max_capacity:] def get_recent_messages(self, max_rounds: int | None = None) -> list[Any]: """Get the most recent N rounds of messages in a thread-safe manner. diff --git a/tests/test_conversation_history.py b/tests/test_conversation_history.py index 9c984e0..eda2efd 100644 --- a/tests/test_conversation_history.py +++ b/tests/test_conversation_history.py @@ -84,6 +84,11 @@ def test_automatic_trimming(): assert recent == ["user3", "assistant3", "user4", "assistant4"] assert len(recent) == 4 + # The internal history should also be trimmed to 4 messages (automatic trimming) + all_messages = history.get_all_messages() + assert len(all_messages) == 4 + assert all_messages == ["user3", "assistant3", "user4", "assistant4"] + def test_clear(): """Test clearing conversation history.""" From ced71f56209675ce65707cc18ffdabb2c61ad40b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 14:34:34 +0000 Subject: [PATCH 5/7] Simplify ConversationHistory: merge into PrettyGeminiBot, remove params, always trim Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> --- mini_ema/bot/conversation_history.py | 94 -------------- mini_ema/bot/pretty_gemini_bot.py | 76 +++++++---- tests/test_conversation_history.py | 186 +++++++++++---------------- 3 files changed, 128 insertions(+), 228 deletions(-) delete mode 100644 mini_ema/bot/conversation_history.py diff --git a/mini_ema/bot/conversation_history.py b/mini_ema/bot/conversation_history.py deleted file mode 100644 index a44d585..0000000 --- a/mini_ema/bot/conversation_history.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Thread-safe conversation history management for chat bots.""" - -import threading -from typing import Any - - -class ConversationHistory: - """Thread-safe conversation history manager. - - This class manages conversation history with a maximum number of rounds, - ensuring thread-safe operations when multiple threads access the history. - Each round consists of a user message and an assistant response. - """ - - def __init__(self, max_rounds: int = 10, messages_per_round: int = 2): - """Initialize the conversation history manager. - - Args: - max_rounds: Maximum number of conversation rounds to keep in history. - A round consists of messages_per_round messages (typically user + assistant). - Must be >= 0. If 0, no history is kept. - messages_per_round: Number of messages per conversation round (default: 2). - """ - self._lock = threading.Lock() - self._history: list[Any] = [] - self._max_rounds = max(0, max_rounds) - # Ensure at least 1 message per round to prevent degenerate cases - self._messages_per_round = max(1, messages_per_round) - - def add_messages(self, messages: list[Any]) -> None: - """Add messages to the conversation history in a thread-safe manner. - - Messages are appended to the history. If the total number of messages exceeds - the maximum capacity (max_rounds * messages_per_round), older messages are - automatically removed to maintain the limit. - - Args: - messages: List of messages to add to the history. - """ - with self._lock: - self._history.extend(messages) - # Trim history to keep only the most recent messages - max_capacity = self._max_rounds * self._messages_per_round - if max_capacity > 0 and len(self._history) > max_capacity: - self._history = self._history[-max_capacity:] - - def get_recent_messages(self, max_rounds: int | None = None) -> list[Any]: - """Get the most recent N rounds of messages in a thread-safe manner. - - Args: - max_rounds: Maximum number of rounds to retrieve. If None, uses the - instance's max_rounds setting. Must be >= 0. - - Returns: - List of recent messages. If max_rounds is 0 or history is empty, - returns an empty list. - """ - with self._lock: - # Use instance max_rounds if not specified - rounds = self._max_rounds if max_rounds is None else max(0, max_rounds) - - # If rounds is 0, return empty list - if rounds == 0: - return [] - - # Calculate the maximum number of messages to return - max_messages = rounds * self._messages_per_round - - # Return the last max_messages from history - # Handle the case where max_messages >= len(history) - return self._history[-max_messages:] if self._history else [] - - def clear(self) -> None: - """Clear all conversation history in a thread-safe manner.""" - with self._lock: - self._history = [] - - def get_all_messages(self) -> list[Any]: - """Get all messages in the conversation history in a thread-safe manner. - - Returns: - List of all messages in the history. - """ - with self._lock: - return self._history.copy() - - def __len__(self) -> int: - """Get the number of messages in the conversation history. - - Returns: - Number of messages in the history. - """ - with self._lock: - return len(self._history) diff --git a/mini_ema/bot/pretty_gemini_bot.py b/mini_ema/bot/pretty_gemini_bot.py index 937b6b2..989c60b 100644 --- a/mini_ema/bot/pretty_gemini_bot.py +++ b/mini_ema/bot/pretty_gemini_bot.py @@ -1,8 +1,9 @@ """Pretty Gemini bot with structured outputs and character personality.""" import os +import threading from collections.abc import Iterable -from typing import Literal +from typing import Any, Literal from google import genai from google.genai import types @@ -10,7 +11,55 @@ from pydantic import BaseModel, Field from .bare_gemini_bot import BareGeminiBot -from .conversation_history import ConversationHistory + + +class ConversationHistory: + """Thread-safe conversation history manager. + + This class manages conversation history with a maximum number of rounds, + ensuring thread-safe operations when multiple threads access the history. + Each round consists of 2 messages (user message + assistant response). + """ + + def __init__(self): + """Initialize the conversation history manager. + + Reads max_rounds from PRETTY_GEMINI_BOT_HISTORY_LENGTH environment variable. + """ + self._lock = threading.Lock() + self._history: list[Any] = [] + # Read max_rounds from environment variable + history_length_str = os.getenv("PRETTY_GEMINI_BOT_HISTORY_LENGTH", "10") + max_rounds = max(0, int(history_length_str)) + # Calculate max capacity once in init (max_rounds * 2 messages per round) + self._max_capacity = max_rounds * 2 + + def add_messages(self, messages: list[Any]) -> None: + """Add messages to the conversation history in a thread-safe manner. + + Messages are appended to the history and automatically trimmed to max_capacity. + + Args: + messages: List of messages to add to the history. + """ + with self._lock: + self._history.extend(messages) + # Always trim to max_capacity + self._history = self._history[-self._max_capacity :] if self._max_capacity > 0 else [] + + def get_recent_messages(self) -> list[Any]: + """Get all messages in the conversation history in a thread-safe manner. + + Returns: + List of all messages in the history. + """ + with self._lock: + return self._history.copy() + + def clear(self) -> None: + """Clear all conversation history in a thread-safe manner.""" + with self._lock: + self._history = [] class EmaMessage(BaseModel): @@ -77,17 +126,11 @@ def __init__(self, api_key: str | None = None, model: str | None = None, thinkin thinking_level_str = thinking_level or os.getenv("PRETTY_GEMINI_BOT_THINKING_LEVEL", "MINIMAL") self.thinking_level = getattr(types.ThinkingLevel, thinking_level_str.upper(), types.ThinkingLevel.MINIMAL) - # Get conversation history length from environment variable - history_length_str = os.getenv("PRETTY_GEMINI_BOT_HISTORY_LENGTH", "10") - self.history_length = max(0, int(history_length_str)) # Ensure non-negative - # Initialize the Gemini client self.client = genai.Client(api_key=self.api_key) # Initialize thread-safe conversation history manager - self.conversation_history = ConversationHistory( - max_rounds=self.history_length, messages_per_round=MESSAGES_PER_ROUND - ) + self.conversation_history = ConversationHistory() def clear(self): """Clear conversation history.""" @@ -146,20 +189,9 @@ def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict # Format the content with character information content = self._format_message(ema_message) - # Add user message and assistant response to conversation history - # Get the full history from the chat session to capture all message parts + # Add the last 2 messages from chat history (user message and assistant response) updated_history = chat.get_history() - # Since we created the chat with existing history and then sent one new message, - # the new messages are at the end. We need to get only the new user message and response. - history_before_length = len(recent_history) - # Validate that we have both user and assistant messages before appending - if len(updated_history) >= history_before_length + MESSAGES_PER_ROUND: - # Extract the new user message and assistant response - new_user_message = updated_history[history_before_length] - new_assistant_message = updated_history[history_before_length + 1] - # Verify both messages exist and are valid (not None) - if new_user_message is not None and new_assistant_message is not None: - self.conversation_history.add_messages([new_user_message, new_assistant_message]) + self.conversation_history.add_messages(updated_history[-2:]) # Yield the response with metadata yield { diff --git a/tests/test_conversation_history.py b/tests/test_conversation_history.py index eda2efd..6229e97 100644 --- a/tests/test_conversation_history.py +++ b/tests/test_conversation_history.py @@ -1,77 +1,50 @@ """Unit tests for ConversationHistory class.""" +import os import threading import time -from mini_ema.bot.conversation_history import ConversationHistory +from mini_ema.bot.pretty_gemini_bot import ConversationHistory def test_initialization(): """Test basic initialization of ConversationHistory.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) - assert len(history) == 0 - assert history.get_all_messages() == [] + # Set env var for testing + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5" + history = ConversationHistory() + assert history._max_capacity == 10 # 5 rounds * 2 messages per round + assert len(history._history) == 0 + assert history.get_recent_messages() == [] def test_add_messages(): """Test adding messages to history.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5" + history = ConversationHistory() messages = ["user message", "assistant response"] history.add_messages(messages) - assert len(history) == 2 - assert history.get_all_messages() == messages + assert len(history._history) == 2 + assert history.get_recent_messages() == messages def test_get_recent_messages_basic(): """Test getting recent messages with basic scenarios.""" - history = ConversationHistory(max_rounds=3, messages_per_round=2) + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "3" + history = ConversationHistory() # Add 2 rounds (4 messages) history.add_messages(["user1", "assistant1"]) history.add_messages(["user2", "assistant2"]) - # Get all messages (default to max_rounds=3) + # Get all messages recent = history.get_recent_messages() assert recent == ["user1", "assistant1", "user2", "assistant2"] - # Get last 1 round - recent = history.get_recent_messages(max_rounds=1) - assert recent == ["user2", "assistant2"] - - -def test_get_recent_messages_exceeds_history(): - """Test getting recent messages when requested rounds exceed actual history.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) - - # Add only 1 round (2 messages) - history.add_messages(["user1", "assistant1"]) - - # Request 5 rounds, but only 1 exists - recent = history.get_recent_messages() - assert recent == ["user1", "assistant1"] - - -def test_get_recent_messages_zero_rounds(): - """Test edge case when max_rounds is 0.""" - history = ConversationHistory(max_rounds=0, messages_per_round=2) - - # Add messages - history.add_messages(["user1", "assistant1"]) - - # Should return empty list when max_rounds is 0 - recent = history.get_recent_messages() - assert recent == [] - - # Explicitly request 0 rounds - history2 = ConversationHistory(max_rounds=5, messages_per_round=2) - history2.add_messages(["user1", "assistant1"]) - recent2 = history2.get_recent_messages(max_rounds=0) - assert recent2 == [] - def test_automatic_trimming(): - """Test that history automatically keeps only recent N rounds.""" - history = ConversationHistory(max_rounds=2, messages_per_round=2) + """Test that history automatically trims to max_capacity.""" + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "2" + history = ConversationHistory() # max_capacity = 4 # Add 4 rounds (8 messages) history.add_messages(["user1", "assistant1"]) @@ -79,47 +52,85 @@ def test_automatic_trimming(): history.add_messages(["user3", "assistant3"]) history.add_messages(["user4", "assistant4"]) - # get_recent_messages should return only last 2 rounds (4 messages) + # Should only keep last 2 rounds (4 messages) recent = history.get_recent_messages() assert recent == ["user3", "assistant3", "user4", "assistant4"] assert len(recent) == 4 - # The internal history should also be trimmed to 4 messages (automatic trimming) - all_messages = history.get_all_messages() - assert len(all_messages) == 4 - assert all_messages == ["user3", "assistant3", "user4", "assistant4"] + +def test_automatic_trimming_on_add(): + """Test that history is automatically trimmed when adding messages.""" + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "2" + history = ConversationHistory() # max_capacity = 4 + + # Add 1 round + history.add_messages(["user1", "assistant1"]) + assert len(history._history) == 2 + + # Add 1 more round + history.add_messages(["user2", "assistant2"]) + assert len(history._history) == 4 + + # Add 1 more round - should trim the oldest round + history.add_messages(["user3", "assistant3"]) + assert len(history._history) == 4 # Should still be 4 (not 6) + assert history.get_recent_messages() == ["user2", "assistant2", "user3", "assistant3"] + + # Add another round - should trim again + history.add_messages(["user4", "assistant4"]) + assert len(history._history) == 4 + assert history.get_recent_messages() == ["user3", "assistant3", "user4", "assistant4"] def test_clear(): """Test clearing conversation history.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5" + history = ConversationHistory() # Add messages history.add_messages(["user1", "assistant1"]) - assert len(history) == 2 + assert len(history._history) == 2 # Clear history history.clear() - assert len(history) == 0 - assert history.get_all_messages() == [] + assert len(history._history) == 0 + assert history.get_recent_messages() == [] def test_empty_history(): """Test operations on empty history.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5" + history = ConversationHistory() - assert len(history) == 0 - assert history.get_all_messages() == [] + assert len(history._history) == 0 assert history.get_recent_messages() == [] # Clear empty history should not raise error history.clear() - assert len(history) == 0 + assert len(history._history) == 0 + + +def test_zero_max_rounds(): + """Test when max_rounds is 0.""" + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "0" + history = ConversationHistory() + + # max_capacity should be 0 + assert history._max_capacity == 0 + + # Add messages + history.add_messages(["user1", "assistant1"]) + + # Should return empty list + recent = history.get_recent_messages() + assert recent == [] + assert len(history._history) == 0 def test_thread_safety(): """Test thread safety of ConversationHistory operations.""" - history = ConversationHistory(max_rounds=100, messages_per_round=2) + os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "100" + history = ConversationHistory() errors = [] def add_messages_thread(thread_id, count): @@ -158,54 +169,8 @@ def read_messages_thread(count): # Check no errors occurred assert len(errors) == 0 - # Verify total messages added (5 threads * 10 rounds * 2 messages per round) - assert len(history) == 100 - - -def test_different_messages_per_round(): - """Test with different messages_per_round values.""" - # Test with 3 messages per round - history = ConversationHistory(max_rounds=2, messages_per_round=3) - - history.add_messages(["msg1", "msg2", "msg3"]) - history.add_messages(["msg4", "msg5", "msg6"]) - history.add_messages(["msg7", "msg8", "msg9"]) - - # Should return last 2 rounds (6 messages) - recent = history.get_recent_messages() - assert recent == ["msg4", "msg5", "msg6", "msg7", "msg8", "msg9"] - assert len(recent) == 6 - - -def test_negative_max_rounds(): - """Test that negative max_rounds is handled correctly.""" - # Should be converted to 0 - history = ConversationHistory(max_rounds=-5, messages_per_round=2) - history.add_messages(["user1", "assistant1"]) - - # Should return empty list - recent = history.get_recent_messages() - assert recent == [] - - -def test_array_slicing_edge_cases(): - """Test array slicing edge cases, especially when index is 0.""" - history = ConversationHistory(max_rounds=5, messages_per_round=2) - - # Add some messages - history.add_messages(["user1", "assistant1"]) - history.add_messages(["user2", "assistant2"]) - - # Request 0 rounds (edge case for array slicing with -0) - recent = history.get_recent_messages(max_rounds=0) - assert recent == [] - - # Verify regular slicing still works - recent = history.get_recent_messages(max_rounds=1) - assert recent == ["user2", "assistant2"] - - recent = history.get_recent_messages(max_rounds=2) - assert recent == ["user1", "assistant1", "user2", "assistant2"] + # Verify we don't have more than max_capacity messages + assert len(history._history) <= 200 # 100 rounds * 2 messages if __name__ == "__main__": @@ -213,13 +178,10 @@ def test_array_slicing_edge_cases(): test_initialization() test_add_messages() test_get_recent_messages_basic() - test_get_recent_messages_exceeds_history() - test_get_recent_messages_zero_rounds() test_automatic_trimming() + test_automatic_trimming_on_add() test_clear() test_empty_history() + test_zero_max_rounds() test_thread_safety() - test_different_messages_per_round() - test_negative_max_rounds() - test_array_slicing_edge_cases() print("All tests passed!") From 892d33764a7d656399c1ceaac4d2d56ace07825d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 1 Jan 2026 14:47:42 +0000 Subject: [PATCH 6/7] Add make test command and update documentation Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> --- .github/copilot-instructions.md | 6 +++++- Makefile | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 753fc8b..976ca76 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -64,6 +64,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh - **Check code style**: Run `make style` to check code formatting and linting - **Auto-format code**: Run `make format` to automatically format code and fix linting issues - **Build package**: Run `make build` to build the package +- **Run tests**: Run `make test` to run all tests - **Help**: Run `make help` to see available make commands All make commands use `uv` internally to run tools in an isolated environment. @@ -81,4 +82,7 @@ The project is configured in `pyproject.toml` with: When testing code changes, use `uvx` to run commands: - `uvx ruff check .` - Run linting - `uvx ruff format .` - Format code -- `uv run python