diff --git a/README.md b/README.md index 7a5f419..a727872 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,8 @@ asyncio.run(main()) - 🤖 **Audio-passthrough** - Send TTS generated audio input and receive rendered synchronized audio/video avatar - 🗣️ **Direct text-to-speech** - Send text directly to TTS for immediate speech output (bypasses LLM processing) - 🎤 **Real-time user audio input** - Send raw audio samples (e.g. from microphone) to Anam for processing (turnkey solution: STT → LLM → TTS → Avatar) +- 🔧 **Client tool events** - Receive function/tool invocations from the LLM for custom client-side logic (function calling) +- 🧠 **Reasoning stream events** - Receive LLM chain-of-thought/reasoning as it streams (REASONING_STREAM_EVENT_RECEIVED, REASONING_HISTORY_UPDATED) - 📡 **Async iterator API** - Clean, Pythonic async/await patterns for continuous stream of audio/video frames - 🎯 **Event-driven API** - Simple decorator-based event handlers for discrete events - 📝 **Fully typed** - Complete type hints for IDE support @@ -144,7 +146,7 @@ For best performance, we suggest using 24kHz mono audio. The provided audio is r Register callbacks for connection and message events using the `@client.on()` decorator: ```python -from anam import AnamEvent, Message, MessageRole, MessageStreamEvent +from anam import AnamEvent, ClientToolEvent, Message, MessageRole, MessageStreamEvent, ReasoningMessage, ReasoningStreamEvent @client.on(AnamEvent.CONNECTION_ESTABLISHED) async def on_connected(): @@ -197,6 +199,29 @@ async def on_message_history_updated(messages: list[Message]): print(f"📝 Conversation history: {len(messages)} messages") for msg in messages: print(f" {msg.role}: {msg.content[:50]}...") + +@client.on(AnamEvent.CLIENT_TOOL_EVENT_RECEIVED) +async def on_client_tool_event(event: ClientToolEvent): + """Called when the LLM invokes a client-side tool (function calling). 
+ + Use this to implement custom tools: handle the event, execute your logic, + and optionally send a response back via the talk stream. + """ + if event.event_name == "redirect": + url = event.event_data.get("url", "") + print(f"🔧 Tool 'redirect' called with url={url}") + +@client.on(AnamEvent.REASONING_STREAM_EVENT_RECEIVED) +async def on_reasoning_stream_event(event: ReasoningStreamEvent): + """Called for each chunk of the LLM's reasoning/chain-of-thought as it streams.""" + print(event.content, end="", flush=True) + if event.end_of_thought: + print() # New line when thought completes + +@client.on(AnamEvent.REASONING_HISTORY_UPDATED) +async def on_reasoning_history_updated(messages: list[ReasoningMessage]): + """Called when the reasoning history is updated (after a thought completes).""" + print(f"🧠 Reasoning history: {len(messages)} thought(s)") ``` ### Session @@ -219,10 +244,12 @@ async with client.connect() as session: # Interrupt the avatar if speaking await session.interrupt() - # Get message history + # Get message history and reasoning history history = client.get_message_history() for msg in history: print(f"{msg.role}: {msg.content}") + for thought in client.get_reasoning_history(): + print(f" [reasoning] {thought.content[:50]}...") # Wait until the session ends await session.wait_until_closed() diff --git a/src/anam/__init__.py b/src/anam/__init__.py index 686804e..565f290 100644 --- a/src/anam/__init__.py +++ b/src/anam/__init__.py @@ -58,11 +58,14 @@ async def consume_audio(): AgentAudioInputConfig, AnamEvent, ClientOptions, + ClientToolEvent, ConnectionClosedCode, Message, MessageRole, MessageStreamEvent, PersonaConfig, + ReasoningMessage, + ReasoningStreamEvent, SessionOptions, SessionReplayOptions, ) @@ -79,11 +82,14 @@ async def consume_audio(): "AnamEvent", "AudioFrame", "ClientOptions", + "ClientToolEvent", "ConnectionClosedCode", "Message", "MessageRole", "MessageStreamEvent", "PersonaConfig", + "ReasoningMessage", + "ReasoningStreamEvent", 
"SessionOptions", "SessionReplayOptions", "VideoFrame", diff --git a/src/anam/client.py b/src/anam/client.py index 99936e0..345c496 100644 --- a/src/anam/client.py +++ b/src/anam/client.py @@ -20,11 +20,14 @@ AgentAudioInputConfig, AnamEvent, ClientOptions, + ClientToolEvent, ConnectionClosedCode, Message, MessageRole, MessageStreamEvent, PersonaConfig, + ReasoningMessage, + ReasoningStreamEvent, SessionInfo, SessionOptions, ) @@ -137,6 +140,7 @@ def __init__( self._streaming_client: StreamingClient | None = None self._is_streaming = False self._message_history: list[Message] = [] + self._reasoning_history: list[ReasoningMessage] = [] def on(self, event: AnamEvent) -> Callable[[T], T]: """Decorator to register an event handler. @@ -316,6 +320,68 @@ async def _handle_data_message(self, data: dict[str, Any]) -> None: AnamEvent.MESSAGE_HISTORY_UPDATED, self._message_history.copy() ) + elif message_type == "clientToolEvent": + # Convert WebRTC format (snake_case) to ClientToolEvent + tool_data = data.get("data", {}) + client_tool_event = ClientToolEvent( + event_uid=tool_data.get("event_uid", ""), + session_id=tool_data.get("session_id", ""), + event_name=tool_data.get("event_name", ""), + event_data=tool_data.get("event_data", {}), + timestamp=tool_data.get("timestamp", ""), + timestamp_user_action=tool_data.get("timestamp_user_action", ""), + user_action_correlation_id=tool_data.get("user_action_correlation_id", ""), + ) + await self._emit(AnamEvent.CLIENT_TOOL_EVENT_RECEIVED, client_tool_event) + + elif message_type == "reasoningText": + # Convert WebRTC format to ReasoningStreamEvent + reason_data = data.get("data", {}) + message_id = reason_data.get("message_id", "") + role = reason_data.get("role", "persona") + content = reason_data.get("content", "") + end_of_thought = reason_data.get("end_of_thought", False) + + stream_event_id = f"{role}::{message_id}" + stream_event = ReasoningStreamEvent( + id=stream_event_id, + content=content, + role=role, + 
def _process_reasoning_stream_event(self, event: ReasoningStreamEvent) -> None:
    """Fold one reasoning chunk into ``self._reasoning_history``.

    The history is scanned from the front for the first entry whose ``id``
    equals ``event.id``. When one is found its text is extended with the
    chunk's ``content``; otherwise a new entry is appended for this thought.
    ``event.end_of_thought`` is not consulted here.

    Args:
        event: The reasoning chunk to merge into the history.
    """
    thoughts = self._reasoning_history

    for position, thought in enumerate(thoughts):
        if thought.id != event.id:
            continue
        # Store a new object instead of mutating the old one: lists handed
        # out earlier via ``.copy()`` keep referencing the previous object.
        thoughts[position] = ReasoningMessage(
            id=thought.id,
            content=thought.content + event.content,
            role=thought.role,
        )
        return

    # First chunk seen for this id: start a new history entry.
    first_chunk = ReasoningMessage(
        id=event.id,
        content=event.content,
        role=event.role,
    )
    thoughts.append(first_chunk)
@dataclass
class ReasoningStreamEvent:
    """One streamed fragment of the LLM's reasoning (chain-of-thought).

    The client builds one instance per ``reasoningText`` data-channel message
    and hands it to ``REASONING_STREAM_EVENT_RECEIVED`` handlers.

    Attributes:
        id: Identifier shared by all chunks of one thought. The client
            composes it as ``"<role>::<message_id>"``.
        content: Text carried by this chunk only, not the accumulated thought.
        role: Role the reasoning is attributed to (e.g., "persona").
        end_of_thought: True on the chunk that closes the thought.
    """

    id: str
    content: str
    role: str
    end_of_thought: bool


@dataclass
class ReasoningMessage:
    """Accumulated text of a single reasoning thought.

    The client concatenates the ``content`` of every ReasoningStreamEvent
    sharing an ``id`` into one ReasoningMessage. A snapshot of all messages
    is passed to ``REASONING_HISTORY_UPDATED`` handlers when a chunk arrives
    with ``end_of_thought`` set. The history can also be read at any time via
    ``get_reasoning_history()``, in which case the newest entry may still be
    accumulating chunks.

    Attributes:
        id: Identifier of the thought (same value as the chunks' ``id``).
        content: Text accumulated for the thought so far.
        role: Role the reasoning is attributed to (e.g., "persona").
    """

    id: str
    content: str
    role: str


@dataclass
class ClientToolEvent:
    """A client-side tool invocation requested by the LLM (function calling).

    The client builds one instance per ``clientToolEvent`` data-channel
    message and hands it to ``CLIENT_TOOL_EVENT_RECEIVED`` handlers. Keys
    missing from the incoming payload are filled with an empty string
    (an empty dict for ``event_data``).

    Attributes:
        event_uid: Unique ID for this event.
        session_id: Session ID.
        event_name: The tool name (e.g., "redirect", "get_weather").
        event_data: LLM-generated parameters for the tool.
        timestamp: ISO timestamp when the event was created.
        timestamp_user_action: ISO timestamp of the user action that
            triggered this event.
        user_action_correlation_id: Correlation ID for tracking.
    """

    event_uid: str
    session_id: str
    event_name: str
    event_data: dict[str, Any]
    timestamp: str
    timestamp_user_action: str
    user_action_correlation_id: str
class TestReasoningStreamEvent:
    """Exercise how ``reasoningText`` data-channel messages are handled."""

    @staticmethod
    def _reasoning_text(
        message_id: str, content_index: int, content: str, *, end_of_thought: bool
    ) -> dict:
        """Build a ``reasoningText`` data-channel message with role "persona"."""
        return {
            "messageType": "reasoningText",
            "data": {
                "message_id": message_id,
                "content_index": content_index,
                "content": content,
                "role": "persona",
                "end_of_thought": end_of_thought,
            },
        }

    @pytest.mark.asyncio
    async def test_handle_reasoning_text_emits_stream_event(self) -> None:
        """A reasoningText message emits one REASONING_STREAM_EVENT_RECEIVED."""
        client = AnamClient(api_key="test-key", persona_id="test-persona")
        captured: list[ReasoningStreamEvent] = []

        @client.on(AnamEvent.REASONING_STREAM_EVENT_RECEIVED)
        async def on_reasoning_event(event: ReasoningStreamEvent) -> None:
            captured.append(event)

        message = self._reasoning_text(
            "msg-123", 0, "Let me think...", end_of_thought=False
        )
        await client._handle_data_message(message)

        assert len(captured) == 1
        emitted = captured[0]
        # The emitted id is the role and message id joined by "::".
        assert emitted.id == "persona::msg-123"
        assert emitted.content == "Let me think..."
        assert emitted.role == "persona"
        assert emitted.end_of_thought is False

    @pytest.mark.asyncio
    async def test_reasoning_end_of_thought_emits_history_updated(self) -> None:
        """REASONING_HISTORY_UPDATED fires once, on the end_of_thought chunk."""
        client = AnamClient(api_key="test-key", persona_id="test-persona")
        stream_events: list[ReasoningStreamEvent] = []
        history_updates: list[list[ReasoningMessage]] = []

        @client.on(AnamEvent.REASONING_STREAM_EVENT_RECEIVED)
        async def on_stream(event: ReasoningStreamEvent) -> None:
            stream_events.append(event)

        @client.on(AnamEvent.REASONING_HISTORY_UPDATED)
        async def on_history(messages: list[ReasoningMessage]) -> None:
            history_updates.append(messages)

        # Two chunks of the same thought; only the second one closes it.
        chunks = (
            self._reasoning_text("msg-1", 0, "First ", end_of_thought=False),
            self._reasoning_text("msg-1", 1, "part.", end_of_thought=True),
        )
        for chunk in chunks:
            await client._handle_data_message(chunk)

        assert len(stream_events) == 2
        assert stream_events[1].end_of_thought is True

        assert len(history_updates) == 1
        snapshot = history_updates[0]
        assert len(snapshot) == 1
        assert snapshot[0].content == "First part."
        assert snapshot[0].id == "persona::msg-1"