Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,9 @@ async with client.connect() as session:
await session.talk("This will be spoken immediately")

# Stream text to TTS incrementally (for streaming scenarios)
await session.send_talk_stream(
content="Hello",
start_of_speech=True,
end_of_speech=False,
)
await session.send_talk_stream(
content=" world!",
start_of_speech=False,
end_of_speech=True,
)
talk_stream = session.create_talk_stream()
await talk_stream.send("Hello", end_of_speech=False)
await talk_stream.send(" world!", end_of_speech=True)

# Interrupt the avatar if speaking
await session.interrupt()
Expand Down
15 changes: 8 additions & 7 deletions examples/persona_interactive_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,15 @@ async def interactive_loop(session, display: VideoDisplay) -> None:
message_text = " ".join(parts[1:])
try:
if command == "t":
await session.talk(message_text)
await session.send_talk_stream(message_text)
elif command == "ts":
await session.send_talk_stream(
message_text,
start_of_speech=True,
end_of_speech=True,
correlation_id=None,
)
talk_stream = session.create_talk_stream()
# As a demo, split the message into two chunks of roughly equal word count
words = message_text.split()
chunk1 = " ".join(words[: len(words) // 2])
chunk2 = " ".join(words[len(words) // 2 :])
await talk_stream.send(chunk1, end_of_speech=False)
await talk_stream.send(chunk2, end_of_speech=True)
print(f"✅ Sent talk (stream) command: {message_text}")
except Exception as e:
print(f"❌ Error sending talk (stream) command: {e}")
Expand Down
3 changes: 3 additions & 0 deletions src/anam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async def consume_audio():
from av.video.frame import VideoFrame

from ._agent_audio_input_stream import AgentAudioInputStream
from ._talk_message_stream import TalkMessageStream, TalkMessageStreamState
from ._version import __version__
from .client import AnamClient, Session
from .errors import (
Expand Down Expand Up @@ -70,6 +71,8 @@ async def consume_audio():
# Main client
"AnamClient",
"Session",
"TalkMessageStream",
"TalkMessageStreamState",
# Types
"AgentAudioInputConfig",
"AgentAudioInputStream",
Expand Down
11 changes: 3 additions & 8 deletions src/anam/_signalling.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,26 +333,21 @@ async def send_agent_audio_input_end(self) -> None:
async def send_talk_stream_input(
self,
content: str,
correlation_id: str,
start_of_speech: bool = True,
end_of_speech: bool = True,
correlation_id: str | None = None,
) -> None:
"""Send talk stream input to make the avatar speak text directly.

This bypasses the LLM and sends text directly to TTS.

Args:
content: The text for the avatar to speak.
correlation_id: ID to correlate this message with interruptions.
Callers should use TalkMessageStream which manages this.
start_of_speech: Whether this is the start of a speech sequence.
end_of_speech: Whether this is the end of a speech sequence.
correlation_id: Optional ID to correlate this message with interruptions.
"""
import uuid

# Generate correlation ID if not provided
if correlation_id is None:
correlation_id = str(uuid.uuid4())

message = {
"actionType": SignalAction.TALK_STREAM_INPUT.value,
"sessionId": self._session_id,
Expand Down
13 changes: 9 additions & 4 deletions src/anam/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
on_connection_established: Callable[[], Awaitable[None]] | None = None,
on_connection_closed: Callable[[str, str | None], Awaitable[None]] | None = None,
on_session_ready: Callable[[], Awaitable[None]] | None = None,
on_talk_stream_interrupted: Callable[[str], Awaitable[None]] | None = None,
custom_ice_servers: list[dict[str, Any]] | None = None,
):
"""Initialize the streaming client.
Expand All @@ -52,6 +53,7 @@ def __init__(
on_connection_established: Callback when connected.
on_connection_closed: Callback when disconnected.
on_session_ready: Callback when sessionready signal is received (ready to receive TTS).
on_talk_stream_interrupted: Callback(correlation_id) when talk stream is interrupted.
custom_ice_servers: Custom ICE servers (optional).
"""
self._session_info = session_info
Expand All @@ -62,6 +64,7 @@ def __init__(
self._on_connection_established = on_connection_established
self._on_connection_closed = on_connection_closed
self._on_session_ready = on_session_ready
self._on_talk_stream_interrupted = on_talk_stream_interrupted

# Configuration
self._ice_servers = custom_ice_servers or session_info.ice_servers
Expand Down Expand Up @@ -154,6 +157,8 @@ async def _handle_signal_message(self, message: dict[str, Any]) -> None:
elif action_type == SignalAction.TALK_STREAM_INTERRUPTED.value:
correlation_id = payload.get("correlationId") if isinstance(payload, dict) else None
logger.debug("Talk stream interrupted: %s", correlation_id)
if self._on_talk_stream_interrupted and correlation_id:
await self._on_talk_stream_interrupted(correlation_id)

async def _handle_answer(self, payload: dict[str, Any]) -> None:
"""Handle SDP answer from server."""
Expand Down Expand Up @@ -555,11 +560,11 @@ def send_interrupt(self) -> None:
self.send_data_message(json.dumps(message))

async def send_talk(self, content: str) -> None:
"""Send text for the avatar to speak directly (bypasses LLM).
"""Send a single text message directly to TTS via REST API.

This sends the talk command via REST API to the engine, which is the
correct method for simple talk commands. For streaming text, use the
signalling client's send_talk_stream_input method.
Convenience method for one-off messages. Sends text directly to TTS,
bypassing the LLM, but with slightly higher latency than send_talk_stream().
For streaming multiple chunks, use create_talk_stream() to manage the stream.

Args:
content: The text for the avatar to speak.
Expand Down
145 changes: 145 additions & 0 deletions src/anam/_talk_message_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Talk message stream for sending streaming text to TTS via WebSocket signalling."""

import logging
from enum import Enum
from typing import TYPE_CHECKING

from ._signalling import SignallingClient
from .types import AnamEvent

if TYPE_CHECKING:
from .client import AnamClient

logger = logging.getLogger(__name__)


class TalkMessageStreamState(str, Enum):
    """State of a talk message stream.

    The str mixin makes each member compare equal to its string value, so
    states can be logged or serialized as plain strings.
    """

    UNSTARTED = "unstarted"  # created; no chunk sent yet
    STREAMING = "streaming"  # at least one chunk sent, end_of_speech not yet sent
    INTERRUPTED = "interrupted"  # a TALK_STREAM_INTERRUPTED event matched this stream
    ENDED = "ended"  # end_of_speech sent; stream accepts no more chunks


class TalkMessageStream:
    """Stream for sending text chunks to TTS with a stable correlation ID.

    Manages correlation_id internally so callers don't need to track it across
    chunks. All chunks in the same speech sequence share the same correlation_id,
    which is used for interruption correlation.

    Example:
        ```python
        # Streaming multiple chunks
        stream = session.create_talk_stream()
        for i, chunk in enumerate(llm_chunks):
            await stream.send(chunk, end_of_speech=(i == len(llm_chunks) - 1))

        # Single message (or use session.send_talk_stream for convenience)
        stream = session.create_talk_stream()
        await stream.send("Hello!", end_of_speech=True)
        ```
    """

    def __init__(
        self,
        correlation_id: str,
        signalling_client: SignallingClient,
        client: "AnamClient",
    ):
        """Initialize the talk message stream.

        Args:
            correlation_id: ID to correlate this stream with interruptions.
            signalling_client: Signalling client for sending messages.
            client: AnamClient for registering interrupt listener.
        """
        self._correlation_id = correlation_id
        self._signalling_client = signalling_client
        self._client = client
        self._state = TalkMessageStreamState.UNSTARTED
        # Keep a stable reference to the bound handler so that
        # remove_listener() in _deactivate() removes exactly what was added.
        self._interrupt_handler = self._on_talk_stream_interrupted
        client.add_listener(AnamEvent.TALK_STREAM_INTERRUPTED, self._interrupt_handler)

    def _on_talk_stream_interrupted(self, correlation_id: str) -> None:
        """Handle TALK_STREAM_INTERRUPTED event if it matches this stream."""
        # Events carrying another stream's correlation_id are ignored.
        if correlation_id == self._correlation_id:
            self._state = TalkMessageStreamState.INTERRUPTED
            self._deactivate()

    @property
    def correlation_id(self) -> str:
        """The correlation ID for this stream (for interruption correlation)."""
        return self._correlation_id

    @property
    def is_active(self) -> bool:
        """Whether the stream can accept more data."""
        return self._state in (
            TalkMessageStreamState.UNSTARTED,
            TalkMessageStreamState.STREAMING,
        )

    @property
    def state(self) -> TalkMessageStreamState:
        """Current state of the stream."""
        return self._state

    async def send(self, content: str, end_of_speech: bool = False) -> None:
        """Send a text chunk to TTS.

        Args:
            content: The text chunk to speak.
            end_of_speech: Whether this is the final chunk of the speech.

        Raises:
            RuntimeError: If the stream is not in an active state (already
                ended or interrupted).
        """
        # Reuse is_active instead of duplicating the state tuple, so the
        # definition of "active" lives in exactly one place.
        if not self.is_active:
            raise RuntimeError(f"Talk stream is not in an active state: {self._state}")

        # Only the very first chunk of the sequence carries start_of_speech.
        start_of_speech = self._state == TalkMessageStreamState.UNSTARTED

        await self._signalling_client.send_talk_stream_input(
            content=content,
            correlation_id=self._correlation_id,
            start_of_speech=start_of_speech,
            end_of_speech=end_of_speech,
        )

        if end_of_speech:
            self._state = TalkMessageStreamState.ENDED
            self._deactivate()
        else:
            self._state = TalkMessageStreamState.STREAMING

    async def end(self) -> None:
        """Signal end of speech with an empty chunk.

        Use when you've sent all content chunks but need to explicitly end
        the stream. No-op if the stream is already ended.
        """
        if self._state == TalkMessageStreamState.ENDED:
            logger.debug("Talk stream is already ended via end of speech. No need to call end().")
            return

        if self._state != TalkMessageStreamState.STREAMING:
            # NOTE(review): ending an UNSTARTED stream returns here without
            # _deactivate(), so its interrupt listener stays registered until
            # the stream is used or interrupted — confirm whether abandoned
            # streams should be deactivated instead.
            logger.warning("Talk stream is not in streaming state: %s", self._state)
            return

        await self._signalling_client.send_talk_stream_input(
            content="",
            correlation_id=self._correlation_id,
            start_of_speech=False,
            end_of_speech=True,
        )
        self._state = TalkMessageStreamState.ENDED
        self._deactivate()

    def _deactivate(self) -> None:
        """Clean up listeners when stream ends or is interrupted."""
        self._client.remove_listener(AnamEvent.TALK_STREAM_INTERRUPTED, self._interrupt_handler)
Loading
Loading