diff --git a/README.md b/README.md index 219e4b7..7a5f419 100644 --- a/README.md +++ b/README.md @@ -212,16 +212,9 @@ async with client.connect() as session: await session.talk("This will be spoken immediately") # Stream text to TTS incrementally (for streaming scenarios) - await session.send_talk_stream( - content="Hello", - start_of_speech=True, - end_of_speech=False, - ) - await session.send_talk_stream( - content=" world!", - start_of_speech=False, - end_of_speech=True, - ) + talk_stream = session.create_talk_stream() + await talk_stream.send("Hello", end_of_speech=False) + await talk_stream.send(" world!", end_of_speech=True) # Interrupt the avatar if speaking await session.interrupt() diff --git a/examples/persona_interactive_video.py b/examples/persona_interactive_video.py index 0a641c2..082429f 100644 --- a/examples/persona_interactive_video.py +++ b/examples/persona_interactive_video.py @@ -137,14 +137,15 @@ async def interactive_loop(session, display: VideoDisplay) -> None: message_text = " ".join(parts[1:]) try: if command == "t": - await session.talk(message_text) + await session.send_talk_stream(message_text) elif command == "ts": - await session.send_talk_stream( - message_text, - start_of_speech=True, - end_of_speech=True, - correlation_id=None, - ) + talk_stream = session.create_talk_stream() + # As demo, split the message into two chunks of equal word count + words = message_text.split() + chunk1 = " ".join(words[: len(words) // 2]) + chunk2 = " ".join(words[len(words) // 2 :]) + await talk_stream.send(chunk1, end_of_speech=False) + await talk_stream.send(chunk2, end_of_speech=True) print(f"✅ Sent talk (stream) command: {message_text}") except Exception as e: print(f"❌ Error sending talk (stream) command: {e}") diff --git a/src/anam/__init__.py b/src/anam/__init__.py index fe42d13..686804e 100644 --- a/src/anam/__init__.py +++ b/src/anam/__init__.py @@ -43,6 +43,7 @@ async def consume_audio(): from av.video.frame import VideoFrame from ._agent_audio_input_stream import AgentAudioInputStream +from ._talk_message_stream import TalkMessageStream, TalkMessageStreamState from ._version import __version__ from .client import AnamClient, Session from .errors import ( @@ -70,6 +71,8 @@ async def consume_audio(): # Main client "AnamClient", "Session", + "TalkMessageStream", + "TalkMessageStreamState", # Types "AgentAudioInputConfig", "AgentAudioInputStream", diff --git a/src/anam/_signalling.py b/src/anam/_signalling.py index dca9048..03d8eb8 100644 --- a/src/anam/_signalling.py +++ b/src/anam/_signalling.py @@ -333,9 +333,9 @@ async def send_agent_audio_input_end(self) -> None: async def send_talk_stream_input( self, content: str, + correlation_id: str, start_of_speech: bool = True, end_of_speech: bool = True, - correlation_id: str | None = None, ) -> None: """Send talk stream input to make the avatar speak text directly. @@ -343,16 +343,11 @@ async def send_talk_stream_input( Args: content: The text for the avatar to speak. + correlation_id: ID to correlate this message with interruptions. + Callers should use TalkMessageStream which manages this. start_of_speech: Whether this is the start of a speech sequence. end_of_speech: Whether this is the end of a speech sequence. - correlation_id: Optional ID to correlate this message with interruptions. """ - import uuid - - # Generate correlation ID if not provided - if correlation_id is None: - correlation_id = str(uuid.uuid4()) - message = { "actionType": SignalAction.TALK_STREAM_INPUT.value, "sessionId": self._session_id, diff --git a/src/anam/_streaming.py b/src/anam/_streaming.py index fbe3dd2..c4f1222 100644 --- a/src/anam/_streaming.py +++ b/src/anam/_streaming.py @@ -42,6 +42,7 @@ def __init__( on_connection_established: Callable[[], Awaitable[None]] | None = None, on_connection_closed: Callable[[str, str | None], Awaitable[None]] | None = None, on_session_ready: Callable[[], Awaitable[None]] | None = None, + on_talk_stream_interrupted: Callable[[str], Awaitable[None]] | None = None, custom_ice_servers: list[dict[str, Any]] | None = None, ): """Initialize the streaming client. @@ -52,6 +53,7 @@ def __init__( on_connection_established: Callback when connected. on_connection_closed: Callback when disconnected. on_session_ready: Callback when sessionready signal is received (ready to receive TTS). + on_talk_stream_interrupted: Callback(correlation_id) when talk stream is interrupted. custom_ice_servers: Custom ICE servers (optional). """ self._session_info = session_info @@ -62,6 +64,7 @@ def __init__( self._on_connection_established = on_connection_established self._on_connection_closed = on_connection_closed self._on_session_ready = on_session_ready + self._on_talk_stream_interrupted = on_talk_stream_interrupted # Configuration self._ice_servers = custom_ice_servers or session_info.ice_servers @@ -154,6 +157,8 @@ async def _handle_signal_message(self, message: dict[str, Any]) -> None: elif action_type == SignalAction.TALK_STREAM_INTERRUPTED.value: correlation_id = payload.get("correlationId") if isinstance(payload, dict) else None logger.debug("Talk stream interrupted: %s", correlation_id) + if self._on_talk_stream_interrupted and correlation_id: + await self._on_talk_stream_interrupted(correlation_id) async def _handle_answer(self, payload: dict[str, Any]) -> None: """Handle SDP answer from server.""" @@ -555,11 +560,11 @@ def send_interrupt(self) -> None: self.send_data_message(json.dumps(message)) async def send_talk(self, content: str) -> None: - """Send text for the avatar to speak directly (bypasses LLM). + """Send a single text message directly to TTS via REST API. - This sends the talk command via REST API to the engine, which is the - correct method for simple talk commands. For streaming text, use the - signalling client's send_talk_stream_input method. + Convenience method for one-off messages. Sends text directly to TTS, + bypassing the LLM, but slightly higher latency than send_talk_stream(). + For streaming multiple chunks, use create_talk_stream() to manage the stream. Args: content: The text for the avatar to speak. diff --git a/src/anam/_talk_message_stream.py b/src/anam/_talk_message_stream.py new file mode 100644 index 0000000..de88396 --- /dev/null +++ b/src/anam/_talk_message_stream.py @@ -0,0 +1,145 @@ +"""Talk message stream for sending streaming text to TTS via WebSocket signalling.""" + +import logging +from enum import Enum +from typing import TYPE_CHECKING + +from ._signalling import SignallingClient +from .types import AnamEvent + +if TYPE_CHECKING: + from .client import AnamClient + +logger = logging.getLogger(__name__) + + +class TalkMessageStreamState(str, Enum): + """State of a talk message stream.""" + + UNSTARTED = "unstarted" + STREAMING = "streaming" + INTERRUPTED = "interrupted" + ENDED = "ended" + + +class TalkMessageStream: + """Stream for sending text chunks to TTS with a stable correlation ID. + + Manages correlation_id internally so callers don't need to track it across + chunks. All chunks in the same speech sequence share the same correlation_id, + which is used for interruption correlation. + + Example: + ```python + # Streaming multiple chunks + stream = session.create_talk_stream() + for i, chunk in enumerate(llm_chunks): + await stream.send(chunk, end_of_speech=(i == len(llm_chunks) - 1)) + + # Single message (or use session.send_talk_stream for convenience) + stream = session.create_talk_stream() + await stream.send("Hello!", end_of_speech=True) + ``` + """ + + def __init__( + self, + correlation_id: str, + signalling_client: SignallingClient, + client: "AnamClient", + ): + """Initialize the talk message stream. + + Args: + correlation_id: ID to correlate this stream with interruptions. + signalling_client: Signalling client for sending messages. + client: AnamClient for registering interrupt listener. + """ + self._correlation_id = correlation_id + self._signalling_client = signalling_client + self._client = client + self._state = TalkMessageStreamState.UNSTARTED + self._interrupt_handler = self._on_talk_stream_interrupted + client.add_listener(AnamEvent.TALK_STREAM_INTERRUPTED, self._interrupt_handler) + + def _on_talk_stream_interrupted(self, correlation_id: str) -> None: + """Handle TALK_STREAM_INTERRUPTED event if it matches this stream.""" + if correlation_id == self._correlation_id: + self._state = TalkMessageStreamState.INTERRUPTED + self._deactivate() + + @property + def correlation_id(self) -> str: + """The correlation ID for this stream (for interruption correlation).""" + return self._correlation_id + + @property + def is_active(self) -> bool: + """Whether the stream can accept more data.""" + return self._state in ( + TalkMessageStreamState.UNSTARTED, + TalkMessageStreamState.STREAMING, + ) + + @property + def state(self) -> TalkMessageStreamState: + """Current state of the stream.""" + return self._state + + async def send(self, content: str, end_of_speech: bool = False) -> None: + """Send a text chunk to TTS. + + Args: + content: The text chunk to speak. + end_of_speech: Whether this is the final chunk of the speech. + + Raises: + RuntimeError: If the stream is not in an active state (already + ended or interrupted). + """ + if self._state not in ( + TalkMessageStreamState.UNSTARTED, + TalkMessageStreamState.STREAMING, + ): + raise RuntimeError(f"Talk stream is not in an active state: {self._state}") + + start_of_speech = self._state == TalkMessageStreamState.UNSTARTED + + await self._signalling_client.send_talk_stream_input( + content=content, + correlation_id=self._correlation_id, + start_of_speech=start_of_speech, + end_of_speech=end_of_speech, + ) + + self._state = TalkMessageStreamState.STREAMING + if end_of_speech: + self._state = TalkMessageStreamState.ENDED + self._deactivate() + + async def end(self) -> None: + """Signal end of speech with an empty chunk. + + Use when you've sent all content chunks but need to explicitly end + the stream. No-op if the stream is already ended. + """ + if self._state == TalkMessageStreamState.ENDED: + logger.debug("Talk stream is already ended via end of speech. No need to call end().") + return + + if self._state != TalkMessageStreamState.STREAMING: + logger.warning("Talk stream is not in streaming state: %s", self._state) + return + + await self._signalling_client.send_talk_stream_input( + content="", + correlation_id=self._correlation_id, + start_of_speech=False, + end_of_speech=True, + ) + self._state = TalkMessageStreamState.ENDED + self._deactivate() + + def _deactivate(self) -> None: + """Clean up listeners when stream ends or is interrupted.""" + self._client.remove_listener(AnamEvent.TALK_STREAM_INTERRUPTED, self._interrupt_handler) diff --git a/src/anam/client.py b/src/anam/client.py index 0a9ccab..99936e0 100644 --- a/src/anam/client.py +++ b/src/anam/client.py @@ -4,6 +4,7 @@ import asyncio import logging +import uuid from collections.abc import AsyncIterator from typing import Any, Awaitable, Callable, TypeVar @@ -13,6 +14,7 @@ from ._agent_audio_input_stream import AgentAudioInputStream from ._api import CoreApiClient from ._streaming import StreamingClient +from ._talk_message_stream import TalkMessageStream from .errors import ConfigurationError, SessionError from .types import ( AgentAudioInputConfig, @@ -251,6 +253,7 @@ async def connect_async(self, session_options: SessionOptions = SessionOptions() on_connection_established=self._handle_connection_established, on_connection_closed=self._handle_connection_closed, on_session_ready=self._handle_session_ready, + on_talk_stream_interrupted=self._handle_talk_stream_interrupted, custom_ice_servers=self._options.ice_servers, ) @@ -351,6 +354,10 @@ async def _handle_session_ready(self) -> None: """Handle session ready (signalling: ready to receive user audio or TTS).""" await self._emit(AnamEvent.SESSION_READY) + async def _handle_talk_stream_interrupted(self, correlation_id: str) -> None: + """Handle talk stream interrupted signal from server.""" + await self._emit(AnamEvent.TALK_STREAM_INTERRUPTED, correlation_id) + async def _handle_connection_closed(self, code: str, reason: str | None) -> None: """Handle connection closed.""" logger.debug("Connection closed") @@ -545,27 +552,28 @@ async def interrupt(self) -> None: streaming.send_interrupt() - async def send_talk_stream( - self, - content: str, - start_of_speech: bool = True, - end_of_speech: bool = True, - correlation_id: str | None = None, - ) -> None: - """Stream text directly to TTS via WebSocket signalling. + def create_talk_stream(self, correlation_id: str | None = None) -> TalkMessageStream: + """Create a talk message stream for sending text chunks to TTS. - Sends text directly to TTS, bypassing the LLM. - Ideal for streaming scenarios with continuous text. - Lower latency than talk(). + The stream manages correlation_id internally so you don't need to track + it across chunks. Use this for streaming LLM output. All chunks in the + same speech share one correlation_id for interruption handling. Args: - content: The text for the avatar to speak. - start_of_speech: Whether this is the start of a speech sequence. - end_of_speech: Whether this is the end of a speech sequence. - correlation_id: Optional ID to correlate with interruptions. + correlation_id: Optional ID. If not provided, a UUID is generated. + + Returns: + TalkMessageStream with send() and end() methods. Raises: SessionError: If not connected. + + Example: + ```python + stream = session.create_talk_stream() + for i, chunk in enumerate(llm_chunks): + await stream.send(chunk, end_of_speech=(i == len(llm_chunks) - 1)) + ``` """ if not self._client._streaming_client: raise SessionError("Not connected") @@ -574,14 +582,31 @@ async def send_talk_stream( if not signalling_client: raise SessionError("Signalling client not initialized") - # The send_talk_stream_input method will buffer the message if WebSocket isn't ready - await signalling_client.send_talk_stream_input( - content=content, - start_of_speech=start_of_speech, - end_of_speech=end_of_speech, + if correlation_id is None or correlation_id.strip() == "": + correlation_id = str(uuid.uuid4()) + + return TalkMessageStream( correlation_id=correlation_id, + signalling_client=signalling_client, + client=self._client, ) + async def send_talk_stream(self, content: str) -> None: + """Send a single text message directly to TTS via WebSocket signalling. + + Convenience method for one-off messages. Sends text directly to TTS, + bypassing the LLM. For streaming multiple chunks, use create_talk_stream() + instead to manage the stream. + + Args: + content: The text for the avatar to speak. + + Raises: + SessionError: If not connected. + """ + stream = self.create_talk_stream() + await stream.send(content, end_of_speech=True) + def create_agent_audio_input_stream( self, config: AgentAudioInputConfig ) -> AgentAudioInputStream: diff --git a/tests/test_talk_message_stream.py b/tests/test_talk_message_stream.py new file mode 100644 index 0000000..e87ea1a --- /dev/null +++ b/tests/test_talk_message_stream.py @@ -0,0 +1,225 @@ +"""Tests for TalkMessageStream.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from anam import AnamEvent +from anam._talk_message_stream import TalkMessageStream, TalkMessageStreamState + + +@pytest.fixture +def mock_signalling_client() -> MagicMock: + """Create a mock SignallingClient with async send_talk_stream_input.""" + client = MagicMock() + client.send_talk_stream_input = AsyncMock() + return client + + +@pytest.fixture +def mock_anam_client() -> MagicMock: + """Create a mock AnamClient with add_listener and remove_listener.""" + client = MagicMock() + client.add_listener = MagicMock() + client.remove_listener = MagicMock() + return client + + +@pytest.fixture +def stream( + mock_signalling_client: MagicMock, + mock_anam_client: MagicMock, +) -> TalkMessageStream: + """Create a TalkMessageStream with mocked dependencies.""" + return TalkMessageStream( + correlation_id="test-correlation-123", + signalling_client=mock_signalling_client, + client=mock_anam_client, + ) + + +class TestTalkMessageStreamInit: + """Tests for TalkMessageStream initialization.""" + + def test_registers_interrupt_listener( + self, + mock_anam_client: MagicMock, + mock_signalling_client: MagicMock, + ) -> None: + """Test that stream registers for TALK_STREAM_INTERRUPTED on init.""" + TalkMessageStream( + correlation_id="cid-1", + signalling_client=mock_signalling_client, + client=mock_anam_client, + ) + mock_anam_client.add_listener.assert_called_once() + call_args = mock_anam_client.add_listener.call_args + assert call_args[0][0] == AnamEvent.TALK_STREAM_INTERRUPTED + assert callable(call_args[0][1]) + + def test_correlation_id_property(self, stream: TalkMessageStream) -> None: + """Test correlation_id property returns the configured value.""" + assert stream.correlation_id == "test-correlation-123" + + def test_initial_state(self, stream: TalkMessageStream) -> None: + """Test initial state is UNSTARTED and is_active is True.""" + assert stream.state == TalkMessageStreamState.UNSTARTED + assert stream.is_active is True + + +class TestTalkMessageStreamSend: + """Tests for TalkMessageStream.send().""" + + @pytest.mark.asyncio + async def test_first_chunk_has_start_of_speech( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test first send uses start_of_speech=True.""" + await stream.send("Hello", end_of_speech=False) + + mock_signalling_client.send_talk_stream_input.assert_called_once_with( + content="Hello", + correlation_id="test-correlation-123", + start_of_speech=True, + end_of_speech=False, + ) + + @pytest.mark.asyncio + async def test_subsequent_chunk_has_start_of_speech_false( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test second send uses start_of_speech=False.""" + await stream.send("Hello", end_of_speech=False) + mock_signalling_client.reset_mock() + + await stream.send(" world", end_of_speech=False) + + mock_signalling_client.send_talk_stream_input.assert_called_once_with( + content=" world", + correlation_id="test-correlation-123", + start_of_speech=False, + end_of_speech=False, + ) + + @pytest.mark.asyncio + async def test_end_of_speech_transitions_to_ended( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + mock_anam_client: MagicMock, + ) -> None: + """Test send with end_of_speech=True transitions to ENDED and deactivates.""" + await stream.send("Done", end_of_speech=True) + + assert stream.state == TalkMessageStreamState.ENDED + assert stream.is_active is False + mock_anam_client.remove_listener.assert_called_once() + + @pytest.mark.asyncio + async def test_send_after_ended_raises( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test send raises RuntimeError when stream is already ended.""" + await stream.send("Done", end_of_speech=True) + + with pytest.raises(RuntimeError, match="not in an active state"): + await stream.send("More", end_of_speech=False) + + assert mock_signalling_client.send_talk_stream_input.call_count == 1 + + @pytest.mark.asyncio + async def test_send_after_interrupted_raises( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test send raises RuntimeError when stream was interrupted.""" + await stream.send("Hello", end_of_speech=False) + stream._on_talk_stream_interrupted("test-correlation-123") + + with pytest.raises(RuntimeError, match="not in an active state"): + await stream.send("More", end_of_speech=False) + + +class TestTalkMessageStreamEnd: + """Tests for TalkMessageStream.end().""" + + @pytest.mark.asyncio + async def test_end_sends_empty_chunk( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test end() sends empty content with end_of_speech=True.""" + await stream.send("Hello", end_of_speech=False) + mock_signalling_client.reset_mock() + + await stream.end() + + mock_signalling_client.send_talk_stream_input.assert_called_once_with( + content="", + correlation_id="test-correlation-123", + start_of_speech=False, + end_of_speech=True, + ) + assert stream.state == TalkMessageStreamState.ENDED + + @pytest.mark.asyncio + async def test_end_when_already_ended_is_noop( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test end() when already ended does not send.""" + await stream.send("Done", end_of_speech=True) + mock_signalling_client.reset_mock() + + await stream.end() + + mock_signalling_client.send_talk_stream_input.assert_not_called() + + @pytest.mark.asyncio + async def test_end_when_unstarted_does_not_send( + self, + stream: TalkMessageStream, + mock_signalling_client: MagicMock, + ) -> None: + """Test end() when never sent (UNSTARTED) does not send.""" + await stream.end() + + mock_signalling_client.send_talk_stream_input.assert_not_called() + + +class TestTalkMessageStreamInterruption: + """Tests for TALK_STREAM_INTERRUPTED handling.""" + + def test_matching_correlation_id_sets_interrupted( + self, + stream: TalkMessageStream, + mock_anam_client: MagicMock, + ) -> None: + """Test interrupt event with matching correlation_id sets INTERRUPTED state.""" + assert stream.state == TalkMessageStreamState.UNSTARTED + + stream._on_talk_stream_interrupted("test-correlation-123") + + assert stream.state == TalkMessageStreamState.INTERRUPTED + assert stream.is_active is False + mock_anam_client.remove_listener.assert_called_once() + + def test_non_matching_correlation_id_ignored( + self, + stream: TalkMessageStream, + mock_anam_client: MagicMock, + ) -> None: + """Test interrupt event with different correlation_id is ignored.""" + stream._on_talk_stream_interrupted("other-correlation-456") + + assert stream.state == TalkMessageStreamState.UNSTARTED + mock_anam_client.remove_listener.assert_not_called() diff --git a/uv.lock b/uv.lock index 1200a81..89f9fe8 100644 --- a/uv.lock +++ b/uv.lock @@ -188,7 +188,7 @@ wheels = [ [[package]] name = "anam" -version = "0.2.0" +version = "0.3.0" source = { editable = "." } dependencies = [ { name = "aiohttp" },