Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/anam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ async def consume_audio():
MessageRole,
MessageStreamEvent,
PersonaConfig,
SessionOptions,
SessionReplayOptions,
)

__all__ = [
Expand All @@ -79,6 +81,8 @@ async def consume_audio():
"MessageRole",
"MessageStreamEvent",
"PersonaConfig",
"SessionOptions",
"SessionReplayOptions",
"VideoFrame",
# Errors
"AnamError",
Expand Down
16 changes: 8 additions & 8 deletions src/anam/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from ._version import __version__
from .errors import AnamError, AuthenticationError, ErrorCode, SessionError
from .types import ClientOptions, PersonaConfig, SessionInfo
from .types import ClientOptions, PersonaConfig, SessionInfo, SessionOptions

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -39,12 +39,14 @@ def _api_url(self) -> str:
"""Get the full API URL."""
return f"{self._base_url}/{self._api_version}"

async def get_session_token(self, persona_config: PersonaConfig) -> str:
async def get_session_token(
self, persona_config: PersonaConfig, session_options: SessionOptions
) -> str:
"""Get a session token using the API key.

Args:
persona_config: The persona configuration to use.

session_options: Session options (optional).
Returns:
The session token string.

Expand All @@ -62,6 +64,7 @@ async def get_session_token(self, persona_config: PersonaConfig) -> str:
body = {
"clientLabel": client_label,
"personaConfig": persona_config.to_dict(),
"sessionOptions": session_options.to_dict(),
}

logger.debug("Requesting session token from %s", url)
Expand Down Expand Up @@ -98,7 +101,7 @@ async def get_session_token(self, persona_config: PersonaConfig) -> str:
async def start_session(
self,
persona_config: PersonaConfig,
session_options: dict[str, Any] | None = None,
session_options: SessionOptions,
) -> SessionInfo:
"""Start a new streaming session.

Expand All @@ -114,19 +117,16 @@ async def start_session(
"""
# Get session token if we don't have one
if not self._session_token:
await self.get_session_token(persona_config)
await self.get_session_token(persona_config, session_options)

url = f"{self._api_url}/engine/session"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self._session_token}",
}
body: dict[str, Any] = {
"personaConfig": persona_config.to_dict(),
"clientMetadata": CLIENT_METADATA,
}
if session_options:
body["sessionOptions"] = session_options

logger.debug("Starting session at %s", url)

Expand Down
8 changes: 6 additions & 2 deletions src/anam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MessageStreamEvent,
PersonaConfig,
SessionInfo,
SessionOptions,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -215,15 +216,17 @@ def connect(self) -> "_SessionContextManager":
"""
return _SessionContextManager(self)

async def connect_async(self) -> "Session":
async def connect_async(self, session_options: SessionOptions = SessionOptions()) -> "Session":
"""Connect to Anam and start streaming (without context manager).

Args:
session_options: Session options (default: SessionOptions(enable_session_replay=True)).

Returns:
A Session object for interacting with the avatar.

Note:
You must call session.close() when done.
Prefer using `async with client.connect()` instead.
"""
if self.is_streaming:
raise SessionError("Already connected. Call close() first.")
Expand All @@ -238,6 +241,7 @@ async def connect_async(self) -> "Session":

self._session_info = await self._api_client.start_session(
persona_config=self._persona_config,
session_options=session_options,
)

# Create streaming client with callbacks
Expand Down
35 changes: 35 additions & 0 deletions src/anam/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,41 @@ def to_dict(self) -> dict[str, Any]:
return result


@dataclass
class SessionReplayOptions:
"""Session replay options. Maps to anam-lab sessionReplay schema.

Args:
enable_session_replay: If True (default), session is recorded. Set False to disable.
"""

enable_session_replay: bool = True

def to_dict(self) -> dict[str, Any]:
return {"enableSessionReplay": self.enable_session_replay}


@dataclass
class SessionOptions:
"""Configuration for an Anam session.

Args:
enable_session_replay: If True (default), session is recorded. Set False to disable.
"""

enable_session_replay: bool = True

def __post_init__(self) -> None:
self._session_replay = SessionReplayOptions(
enable_session_replay=self.enable_session_replay
)

def to_dict(self) -> dict[str, Any]:
result: dict[str, Any] = {}
result["sessionReplay"] = self._session_replay.to_dict()
return result


@dataclass
class ClientOptions:
"""Optional configuration for AnamClient.
Expand Down
Loading