Skip to content
2 changes: 1 addition & 1 deletion src/anam/_signalling.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def connect(self) -> None:
"""
logger.debug("Connecting to signalling server: %s", self._ws_url)
try:
self._ws = await websockets.asyncio.client.connect(self._ws_url)
self._ws = await websockets.asyncio.client.connect(self._ws_url, close_timeout=2.0)
self._connection_attempts = 0
logger.info("WebSocket connection established")

Expand Down
26 changes: 17 additions & 9 deletions src/anam/_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ._agent_audio_input_stream import AgentAudioInputStream
from ._signalling import SignalAction, SignallingClient
from ._user_audio_input_track import UserAudioInputTrack
from .types import AgentAudioInputConfig, SessionInfo
from .types import AgentAudioInputConfig, ConnectionClosedCode, SessionInfo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,6 +75,7 @@ def __init__(
self._agent_audio_input_stream: AgentAudioInputStream | None = None
self._user_audio_input_track: UserAudioInputTrack | None = None
self._audio_transceiver = None # Store transceiver for lazy track creation
self._closing = False

async def connect(self, timeout: float = 30.0) -> None:
"""Start the streaming connection.
Expand Down Expand Up @@ -136,7 +137,7 @@ async def _handle_signal_message(self, message: dict[str, Any]) -> None:
reason = payload if isinstance(payload, str) else "Session ended by server"
logger.info("Session ended by server: %s", reason)
if self._on_connection_closed:
await self._on_connection_closed("server_closed", reason)
await self._on_connection_closed(ConnectionClosedCode.SERVER_CLOSED.value, reason)
await self.close()

elif action_type == SignalAction.WARNING.value:
Expand Down Expand Up @@ -282,7 +283,7 @@ def on_ice_connection_state_change() -> None:
if not self._peer_connection:
return
state = self._peer_connection.iceConnectionState
logger.info("ICE connection state: %s", state)
logger.debug("ICE connection state: %s", state)
if state in ("connected", "completed"):
if not self._is_connected:
self._is_connected = True
Expand All @@ -296,16 +297,19 @@ def on_ice_connection_state_change() -> None:
)
if hasattr(self, "_connection_ready"):
self._connection_ready.set()
elif state == "closed":
if self._on_connection_closed:
asyncio.create_task(self._on_connection_closed("connection_closed", None))

@self._peer_connection.on("connectionstatechange")
def on_connection_state_change() -> None:
if not self._peer_connection:
return
state = self._peer_connection.connectionState
logger.debug("Connection state: %s", state)
if state == "closed":
# Only emit CONNECTION_CLOSED when the connection was lost (e.g. network)
if not self._closing and self._on_connection_closed:
asyncio.create_task(
self._on_connection_closed(ConnectionClosedCode.WEBRTC_FAILURE.value, None)
)

@self._peer_connection.on("track")
def on_track(track: MediaStreamTrack) -> None:
Expand Down Expand Up @@ -348,12 +352,12 @@ async def _setup_data_channel(self) -> None:

@self._data_channel.on("open")
def on_open() -> None:
logger.info("Data channel opened")
logger.debug("Data channel opened")
self._data_channel_open = True

@self._data_channel.on("close")
def on_close() -> None:
logger.info("Data channel closed")
logger.debug("Data channel closed")
self._data_channel_open = False

@self._data_channel.on("message")
Expand Down Expand Up @@ -629,6 +633,9 @@ def audio_track(self) -> MediaStreamTrack | None:

async def close(self) -> None:
"""Close the streaming connection and clean up resources."""
if self._closing:
return
self._closing = True
logger.debug("Closing streaming client")

# Close signalling
Expand Down Expand Up @@ -660,8 +667,9 @@ async def close(self) -> None:
finally:
self._peer_connection = None

self._closing = False
self._is_connected = False
logger.info("Streaming client closed")
logger.debug("Streaming client closed")

def send_user_audio(
self,
Expand Down
4 changes: 0 additions & 4 deletions src/anam/_user_audio_input_track.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ def __init__(self, sample_rate: int, num_channels: int):
# Flag to indicate if track is closed
self._is_closed = False

# Flag to flush buffer on first recv() - handles audio that accumulated
# between track connection and WebRTC starting to pull frames
self._first_recv = True

# Lock for thread-safe buffer access
self._lock = asyncio.Lock()

Expand Down
18 changes: 10 additions & 8 deletions src/anam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AgentAudioInputConfig,
AnamEvent,
ClientOptions,
ConnectionClosedCode,
Message,
MessageRole,
MessageStreamEvent,
Expand Down Expand Up @@ -224,7 +225,7 @@ async def connect_async(self) -> "Session":
You must call session.close() when done.
Prefer using `async with client.connect()` instead.
"""
if self._is_streaming:
if self.is_streaming:
raise SessionError("Already connected. Call close() first.")

logger.info("Connecting to Anam...")
Expand Down Expand Up @@ -343,7 +344,7 @@ async def _handle_connection_established(self) -> None:

async def _handle_connection_closed(self, code: str, reason: str | None) -> None:
"""Handle connection closed."""
logger.info("Connection closed: %s %s", code, reason)
logger.debug("Connection closed")
self._is_streaming = False
await self._emit(AnamEvent.CONNECTION_CLOSED, code, reason)

Expand All @@ -367,13 +368,14 @@ def create_agent_audio_input_stream(

async def close(self) -> None:
"""Close the connection and clean up resources."""
if self._streaming_client:
if self._streaming_client and self.is_streaming:
self._is_streaming = False
await self._handle_connection_closed(ConnectionClosedCode.NORMAL.value, None)
await self._streaming_client.close()
self._streaming_client = None

self._session_info = None
self._is_streaming = False
logger.info("Client closed")
self._session_info = None
self._message_history.clear()
logger.info("Client closed")

@property
def is_streaming(self) -> bool:
Expand Down Expand Up @@ -677,9 +679,9 @@ async def wait_until_closed(self) -> None:

async def close(self) -> None:
"""Close the session."""
await self._client.close()
self._closed = True
self._close_event.set()
await self._client.close()

@property
def is_active(self) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading