diff --git a/src/anam/_signalling.py b/src/anam/_signalling.py index a8cabb7..dca9048 100644 --- a/src/anam/_signalling.py +++ b/src/anam/_signalling.py @@ -106,7 +106,7 @@ async def connect(self) -> None: """ logger.debug("Connecting to signalling server: %s", self._ws_url) try: - self._ws = await websockets.asyncio.client.connect(self._ws_url) + self._ws = await websockets.asyncio.client.connect(self._ws_url, close_timeout=2.0) self._connection_attempts = 0 logger.info("WebSocket connection established") diff --git a/src/anam/_streaming.py b/src/anam/_streaming.py index f582818..f153d19 100644 --- a/src/anam/_streaming.py +++ b/src/anam/_streaming.py @@ -22,7 +22,7 @@ from ._agent_audio_input_stream import AgentAudioInputStream from ._signalling import SignalAction, SignallingClient from ._user_audio_input_track import UserAudioInputTrack -from .types import AgentAudioInputConfig, SessionInfo +from .types import AgentAudioInputConfig, ConnectionClosedCode, SessionInfo logger = logging.getLogger(__name__) @@ -75,6 +75,7 @@ def __init__( self._agent_audio_input_stream: AgentAudioInputStream | None = None self._user_audio_input_track: UserAudioInputTrack | None = None self._audio_transceiver = None # Store transceiver for lazy track creation + self._closing = False async def connect(self, timeout: float = 30.0) -> None: """Start the streaming connection. @@ -136,7 +137,7 @@ async def _handle_signal_message(self, message: dict[str, Any]) -> None: reason = payload if isinstance(payload, str) else "Session ended by server" logger.info("Session ended by server: %s", reason) if self._on_connection_closed: - await self._on_connection_closed("server_closed", reason) + await self._on_connection_closed(ConnectionClosedCode.SERVER_CLOSED.value, reason) await self.close() elif action_type == SignalAction.WARNING.value: @@ -282,7 +283,7 @@ def on_ice_connection_state_change() -> None: if not self._peer_connection: return state = self._peer_connection.iceConnectionState - logger.info("ICE connection state: %s", state) + logger.debug("ICE connection state: %s", state) if state in ("connected", "completed"): if not self._is_connected: self._is_connected = True @@ -296,9 +297,6 @@ def on_ice_connection_state_change() -> None: ) if hasattr(self, "_connection_ready"): self._connection_ready.set() - elif state == "closed": - if self._on_connection_closed: - asyncio.create_task(self._on_connection_closed("connection_closed", None)) @self._peer_connection.on("connectionstatechange") def on_connection_state_change() -> None: @@ -306,6 +304,12 @@ def on_connection_state_change() -> None: return state = self._peer_connection.connectionState logger.debug("Connection state: %s", state) + if state == "closed": + # Only emit CONNECTION_CLOSED when the connection was lost (e.g. network) + if not self._closing and self._on_connection_closed: + asyncio.create_task( + self._on_connection_closed(ConnectionClosedCode.WEBRTC_FAILURE.value, None) + ) @self._peer_connection.on("track") def on_track(track: MediaStreamTrack) -> None: @@ -348,12 +352,12 @@ async def _setup_data_channel(self) -> None: @self._data_channel.on("open") def on_open() -> None: - logger.info("Data channel opened") + logger.debug("Data channel opened") self._data_channel_open = True @self._data_channel.on("close") def on_close() -> None: - logger.info("Data channel closed") + logger.debug("Data channel closed") self._data_channel_open = False @self._data_channel.on("message") @@ -629,6 +633,9 @@ def audio_track(self) -> MediaStreamTrack | None: async def close(self) -> None: """Close the streaming connection and clean up resources.""" + if self._closing: + return + self._closing = True logger.debug("Closing streaming client") # Close signalling @@ -660,8 +667,9 @@ async def close(self) -> None: finally: self._peer_connection = None + self._closing = False self._is_connected = False - logger.info("Streaming client closed") + logger.debug("Streaming client closed") def send_user_audio( self, diff --git a/src/anam/_user_audio_input_track.py b/src/anam/_user_audio_input_track.py index d8e52f0..9cd88df 100644 --- a/src/anam/_user_audio_input_track.py +++ b/src/anam/_user_audio_input_track.py @@ -53,10 +53,6 @@ def __init__(self, sample_rate: int, num_channels: int): # Flag to indicate if track is closed self._is_closed = False - # Flag to flush buffer on first recv() - handles audio that accumulated - # between track connection and WebRTC starting to pull frames - self._first_recv = True - # Lock for thread-safe buffer access self._lock = asyncio.Lock() diff --git a/src/anam/client.py b/src/anam/client.py index 235aba3..49bf421 100644 --- a/src/anam/client.py +++ b/src/anam/client.py @@ -18,6 +18,7 @@ AgentAudioInputConfig, AnamEvent, ClientOptions, + ConnectionClosedCode, Message, MessageRole, MessageStreamEvent, @@ -224,7 +225,7 @@ async def connect_async(self) -> "Session": You must call session.close() when done. Prefer using `async with client.connect()` instead. """ - if self._is_streaming: + if self.is_streaming: raise SessionError("Already connected. Call close() first.") logger.info("Connecting to Anam...") @@ -343,7 +344,7 @@ async def _handle_connection_established(self) -> None: async def _handle_connection_closed(self, code: str, reason: str | None) -> None: """Handle connection closed.""" - logger.info("Connection closed: %s %s", code, reason) + logger.debug("Connection closed") self._is_streaming = False await self._emit(AnamEvent.CONNECTION_CLOSED, code, reason) @@ -367,13 +368,14 @@ def create_agent_audio_input_stream( async def close(self) -> None: """Close the connection and clean up resources.""" - if self._streaming_client: + if self._streaming_client and self.is_streaming: + self._is_streaming = False + await self._handle_connection_closed(ConnectionClosedCode.NORMAL.value, None) await self._streaming_client.close() self._streaming_client = None - - self._session_info = None - self._is_streaming = False - logger.info("Client closed") + self._session_info = None + self._message_history.clear() + logger.info("Client closed") @property def is_streaming(self) -> bool: @@ -677,9 +679,9 @@ async def wait_until_closed(self) -> None: async def close(self) -> None: """Close the session.""" - await self._client.close() self._closed = True self._close_event.set() + await self._client.close() @property def is_active(self) -> bool: diff --git a/uv.lock b/uv.lock index 9d44fe9..4bc310a 100644 --- a/uv.lock +++ b/uv.lock @@ -188,7 +188,7 @@ wheels = [ [[package]] name = "anam" -version = "0.1.0" +version = "0.2.0a2" source = { editable = "." } dependencies = [ { name = "aiohttp" },