Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions .github/copilot-instructions.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ curl -LsSf https://astral.sh/uv/install.sh | sh
- **Check code style**: Run `make style` to check code formatting and linting
- **Auto-format code**: Run `make format` to automatically format code and fix linting issues
- **Build package**: Run `make build` to build the package
- **Run tests**: Run `make test` to run all tests
- **Help**: Run `make help` to see available make commands

All make commands use `uv` internally to run tools in an isolated environment.
Expand All @@ -74,11 +75,4 @@ The project is configured in `pyproject.toml` with:
- Line length: 119 characters
- Target Python version: 3.11
- Google Python style standards
- Import sorting with isort

### Testing

When testing code changes, use `uvx` to run commands:
- `uvx ruff check .` - Run linting
- `uvx ruff format .` - Format code
- `uv run python <script>.py` - Run Python scripts
- Import sorting with isort
27 changes: 27 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: Tests

on:
push:
branches:
- main
pull_request:
branches:
- main

jobs:
test:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: '3.11'
github-token: ${{ github.token }}

- name: Run tests
run: make test
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ build:
@echo "Building package..."
uv build

.PHONY: test
test:
@echo "Running tests..."
uv run pytest -vvv tests

.PHONY: help
help:
@echo "Available targets:"
@echo " style - Check code formatting and linting with Ruff"
@echo " format - Auto-format code and fix linting issues with Ruff"
@echo " build - Build the package"
@echo " test - Run tests"
@echo " help - Show this help message"
84 changes: 58 additions & 26 deletions mini_ema/bot/pretty_gemini_bot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""Pretty Gemini bot with structured outputs and character personality."""

import os
import threading
from collections.abc import Iterable
from typing import Literal
from typing import Any, Literal

from google import genai
from google.genai import types
Expand All @@ -12,6 +13,55 @@
from .bare_gemini_bot import BareGeminiBot


class ConversationHistory:
"""Thread-safe conversation history manager.

This class manages conversation history with a maximum number of rounds,
ensuring thread-safe operations when multiple threads access the history.
Each round consists of 2 messages (user message + assistant response).
"""

def __init__(self):
"""Initialize the conversation history manager.

Reads max_rounds from PRETTY_GEMINI_BOT_HISTORY_LENGTH environment variable.
"""
self._lock = threading.Lock()
self._history: list[Any] = []
# Read max_rounds from environment variable
history_length_str = os.getenv("PRETTY_GEMINI_BOT_HISTORY_LENGTH", "10")
max_rounds = max(0, int(history_length_str))
# Calculate max capacity once in init (max_rounds * 2 messages per round)
self._max_capacity = max_rounds * 2

def add_messages(self, messages: list[Any]) -> None:
"""Add messages to the conversation history in a thread-safe manner.

Messages are appended to the history and automatically trimmed to max_capacity.

Args:
messages: List of messages to add to the history.
"""
with self._lock:
self._history.extend(messages)
# Always trim to max_capacity
self._history = self._history[-self._max_capacity :] if self._max_capacity > 0 else []

def get_recent_messages(self) -> list[Any]:
"""Get all messages in the conversation history in a thread-safe manner.

Returns:
List of all messages in the history.
"""
with self._lock:
return self._history.copy()

def clear(self) -> None:
"""Clear all conversation history in a thread-safe manner."""
with self._lock:
self._history = []


class EmaMessage(BaseModel):
"""Structured message format with character personality."""

Expand Down Expand Up @@ -76,19 +126,15 @@ def __init__(self, api_key: str | None = None, model: str | None = None, thinkin
thinking_level_str = thinking_level or os.getenv("PRETTY_GEMINI_BOT_THINKING_LEVEL", "MINIMAL")
self.thinking_level = getattr(types.ThinkingLevel, thinking_level_str.upper(), types.ThinkingLevel.MINIMAL)

# Get conversation history length from environment variable
history_length_str = os.getenv("PRETTY_GEMINI_BOT_HISTORY_LENGTH", "10")
self.history_length = max(0, int(history_length_str)) # Ensure non-negative

# Initialize the Gemini client
self.client = genai.Client(api_key=self.api_key)

# Initialize conversation history array
self.conversation_history = []
# Initialize thread-safe conversation history manager
self.conversation_history = ConversationHistory()

def clear(self):
"""Clear conversation history."""
self.conversation_history = []
self.conversation_history.clear()

def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict]:
"""Generate a structured response using Gemini API with character personality.
Expand All @@ -108,10 +154,8 @@ def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict
# Format the message with XML tags to separate username and message
formatted_message = f"<username>{username}</username>\n<user_message>{message}</user_message>"

# Get the recent N rounds of history based on history_length
# Each round consists of a user message and an assistant response
max_history_messages = self.history_length * MESSAGES_PER_ROUND
recent_history = self.conversation_history[-max_history_messages:]
# Get the recent N rounds of history from the thread-safe history manager
recent_history = self.conversation_history.get_recent_messages()

# Create a new chat session with the recent history
chat = self.client.chats.create(
Expand Down Expand Up @@ -145,21 +189,9 @@ def get_response(self, message: str, username: str = "Phoenix") -> Iterable[dict
# Format the content with character information
content = self._format_message(ema_message)

# Add user message and assistant response to conversation history
# Get the full history from the chat session to capture all message parts
# Add the last 2 messages from chat history (user message and assistant response)
updated_history = chat.get_history()
# Since we created the chat with existing history and then sent one new message,
# the new messages are at the end. We need to get only the new user message and response.
history_before_length = len(recent_history)
# Validate that we have both user and assistant messages before appending
if len(updated_history) >= history_before_length + MESSAGES_PER_ROUND:
# Extract the new user message and assistant response
new_user_message = updated_history[history_before_length]
new_assistant_message = updated_history[history_before_length + 1]
# Verify both messages exist and are valid (not None)
if new_user_message is not None and new_assistant_message is not None:
self.conversation_history.append(new_user_message)
self.conversation_history.append(new_assistant_message)
self.conversation_history.add_messages(updated_history[-2:])

# Yield the response with metadata
yield {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"google-genai>=1.0.0",
"python-dotenv>=1.0.0",
"httpx[socks]>=0.28.1",
"pytest>=8.0.0",
]

[build-system]
Expand Down
187 changes: 187 additions & 0 deletions tests/test_conversation_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""Unit tests for ConversationHistory class."""

import os
import threading
import time

from mini_ema.bot.pretty_gemini_bot import ConversationHistory


def test_initialization():
"""Test basic initialization of ConversationHistory."""
# Set env var for testing
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5"
history = ConversationHistory()
assert history._max_capacity == 10 # 5 rounds * 2 messages per round
assert len(history._history) == 0
assert history.get_recent_messages() == []


def test_add_messages():
"""Test adding messages to history."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5"
history = ConversationHistory()
messages = ["user message", "assistant response"]
history.add_messages(messages)
assert len(history._history) == 2
assert history.get_recent_messages() == messages


def test_get_recent_messages_basic():
"""Test getting recent messages with basic scenarios."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "3"
history = ConversationHistory()

# Add 2 rounds (4 messages)
history.add_messages(["user1", "assistant1"])
history.add_messages(["user2", "assistant2"])

# Get all messages
recent = history.get_recent_messages()
assert recent == ["user1", "assistant1", "user2", "assistant2"]


def test_automatic_trimming():
"""Test that history automatically trims to max_capacity."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "2"
history = ConversationHistory() # max_capacity = 4

# Add 4 rounds (8 messages)
history.add_messages(["user1", "assistant1"])
history.add_messages(["user2", "assistant2"])
history.add_messages(["user3", "assistant3"])
history.add_messages(["user4", "assistant4"])

# Should only keep last 2 rounds (4 messages)
recent = history.get_recent_messages()
assert recent == ["user3", "assistant3", "user4", "assistant4"]
assert len(recent) == 4


def test_automatic_trimming_on_add():
"""Test that history is automatically trimmed when adding messages."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "2"
history = ConversationHistory() # max_capacity = 4

# Add 1 round
history.add_messages(["user1", "assistant1"])
assert len(history._history) == 2

# Add 1 more round
history.add_messages(["user2", "assistant2"])
assert len(history._history) == 4

# Add 1 more round - should trim the oldest round
history.add_messages(["user3", "assistant3"])
assert len(history._history) == 4 # Should still be 4 (not 6)
assert history.get_recent_messages() == ["user2", "assistant2", "user3", "assistant3"]

# Add another round - should trim again
history.add_messages(["user4", "assistant4"])
assert len(history._history) == 4
assert history.get_recent_messages() == ["user3", "assistant3", "user4", "assistant4"]


def test_clear():
"""Test clearing conversation history."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5"
history = ConversationHistory()

# Add messages
history.add_messages(["user1", "assistant1"])
assert len(history._history) == 2

# Clear history
history.clear()
assert len(history._history) == 0
assert history.get_recent_messages() == []


def test_empty_history():
"""Test operations on empty history."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "5"
history = ConversationHistory()

assert len(history._history) == 0
assert history.get_recent_messages() == []

# Clear empty history should not raise error
history.clear()
assert len(history._history) == 0


def test_zero_max_rounds():
"""Test when max_rounds is 0."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "0"
history = ConversationHistory()

# max_capacity should be 0
assert history._max_capacity == 0

# Add messages
history.add_messages(["user1", "assistant1"])

# Should return empty list
recent = history.get_recent_messages()
assert recent == []
assert len(history._history) == 0


def test_thread_safety():
"""Test thread safety of ConversationHistory operations."""
os.environ["PRETTY_GEMINI_BOT_HISTORY_LENGTH"] = "100"
history = ConversationHistory()
errors = []

def add_messages_thread(thread_id, count):
"""Add messages from a thread."""
try:
for i in range(count):
history.add_messages([f"user_{thread_id}_{i}", f"assistant_{thread_id}_{i}"])
except Exception as e:
errors.append(e)

def read_messages_thread(count):
"""Read messages from a thread."""
try:
for _ in range(count):
history.get_recent_messages()
time.sleep(0.001) # Small delay to interleave operations
except Exception as e:
errors.append(e)

# Create multiple threads that add and read messages concurrently
threads = []
for i in range(5):
t = threading.Thread(target=add_messages_thread, args=(i, 10))
threads.append(t)
t.start()

for _ in range(3):
t = threading.Thread(target=read_messages_thread, args=(20,))
threads.append(t)
t.start()

# Wait for all threads to complete
for t in threads:
t.join()

# Check no errors occurred
assert len(errors) == 0

# Verify we don't have more than max_capacity messages
assert len(history._history) <= 200 # 100 rounds * 2 messages


if __name__ == "__main__":
# Run all tests
test_initialization()
test_add_messages()
test_get_recent_messages_basic()
test_automatic_trimming()
test_automatic_trimming_on_add()
test_clear()
test_empty_history()
test_zero_max_rounds()
test_thread_safety()
print("All tests passed!")
Loading