diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 767ee8c8..35f6ab2f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 @@ -26,7 +26,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 @@ -48,7 +48,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 897ae427..10579d92 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,7 +10,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 @@ -26,7 +26,7 @@ jobs: - name: Set up python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 - name: Bootstrap poetry run: | curl -sSL https://install.python-poetry.org | python - -y --version 1.5.1 diff --git a/README.md b/README.md index f023d00c..a76894ad 100644 --- a/README.md +++ b/README.md @@ -253,6 +253,106 @@ client_tools.register("calculate_sum", calculate_sum, is_async=False) client_tools.register("fetch_data", fetch_data, is_async=True) ``` +### WebRTC Support + +ElevenLabs Python SDK supports WebRTC connections for real-time, low-latency conversations using LiveKit infrastructure. 
WebRTC provides better audio quality, lower latency, and improved connectivity compared to traditional WebSocket connections. + +#### Key Benefits +- **Lower Latency**: Direct peer-to-peer audio streaming +- **Better Audio Quality**: Optimized for real-time audio +- **Improved Connectivity**: NAT traversal and firewall handling +- **Adaptive Bitrate**: Automatic quality adjustment based on network conditions + +#### Basic WebRTC Usage + +```python +import asyncio +from elevenlabs import ElevenLabs +from elevenlabs.conversational_ai.conversation_factory import create_webrtc_conversation +from elevenlabs.conversational_ai.conversation import AsyncAudioInterface + +class SimpleAsyncAudioInterface(AsyncAudioInterface): + async def start(self, input_callback): + print("Audio interface started") + self.input_callback = input_callback + + async def stop(self): + print("Audio interface stopped") + + async def output(self, audio: bytes): + print(f"Received audio: {len(audio)} bytes") + + async def interrupt(self): + print("Audio interrupted") + +async def main(): + client = ElevenLabs(api_key="YOUR_API_KEY") + audio_interface = SimpleAsyncAudioInterface() + + # WebRTC conversation with automatic token fetching + conversation = create_webrtc_conversation( + client=client, + agent_id="your-agent-id", + audio_interface=audio_interface, + ) + + await conversation.start_session() + await conversation.send_user_message("Hello!") + await asyncio.sleep(10) + await conversation.end_session() + +asyncio.run(main()) +``` + +#### Connection Type Comparison + +```python +from elevenlabs.conversational_ai.conversation_factory import create_conversation +from elevenlabs.conversational_ai.base_connection import ConnectionType + +# WebSocket (existing) +ws_conversation = create_conversation( + client=client, + agent_id="your-agent-id", + connection_type=ConnectionType.WEBSOCKET, + # Uses sync AudioInterface +) + +# WebRTC (new) +webrtc_conversation = create_conversation( + client=client, + 
agent_id="your-agent-id", + connection_type=ConnectionType.WEBRTC, + audio_interface=AsyncAudioInterface(), # Async interface required +) +``` + +#### Authentication Methods + +WebRTC conversations support multiple authentication approaches: + +1. **Automatic Token Fetching**: Provide only `agent_id` and the SDK fetches the conversation token automatically +2. **Explicit Token**: Provide both `agent_id` and `conversation_token` for manual token management + +```python +# Automatic token fetching (recommended) +conversation = create_webrtc_conversation( + client=client, + agent_id="your-agent-id", + audio_interface=audio_interface +) + +# Explicit token +conversation = create_webrtc_conversation( + client=client, + agent_id="your-agent-id", + conversation_token="your-conversation-token", + audio_interface=audio_interface +) +``` + +**Requirements**: WebRTC conversations require the `livekit` dependency (`pip install livekit`), which is installed automatically with the SDK on Python 3.9 and later; `livekit` does not support Python 3.8, so WebRTC is unavailable there. All WebRTC conversations must use `AsyncAudioInterface` implementations. + ## Languages Supported Explore [all models & languages](https://elevenlabs.io/docs/models). diff --git a/poetry.lock b/poetry.lock index 10ee395f..0dc691e1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,16 @@ # This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." 
+optional = false +python-versions = ">=3.8" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -244,6 +255,27 @@ files = [ {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, ] +[[package]] +name = "livekit" +version = "1.0.13" +description = "Python Real-time SDK for LiveKit" +optional = false +python-versions = ">=3.9.0" +files = [ + {file = "livekit-1.0.13-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:7174723d75544e6942e1c1a99fb297bfee538d0f7b9bd3f3cdebf06e42a72abc"}, + {file = "livekit-1.0.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ef1f641bc622c0b15adf0e91dfc62740d20db51d09369d3a7f84e8314b0ce067"}, + {file = "livekit-1.0.13-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:d40a8b9d5cc931736e82bb723e1ae27436e0b2d20b0217627341030400784dc2"}, + {file = "livekit-1.0.13-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:d73bb327a1a711b09e0b39d574fb04af9b2f38381c6267330df8a713e44e1be3"}, + {file = "livekit-1.0.13-py3-none-win_amd64.whl", hash = "sha256:bbb2d17203d74991aac23a5d0519e33984f8b0c0d53b2182c837086742d1b813"}, + {file = "livekit-1.0.13.tar.gz", hash = "sha256:eb50b59b7320b1e960ea8f71b8e52fb832fb867e42806845659918dbe13e6a10"}, +] + +[package.dependencies] +aiofiles = ">=24" +numpy = ">=1.26" +protobuf = ">=4.25.0" +types-protobuf = ">=3" + [[package]] name = "mypy" version = "1.13.0" @@ -308,6 +340,60 @@ files = [ {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, ] +[[package]] +name = "numpy" +version = "2.0.2" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = 
"numpy-2.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66"}, + {file = "numpy-2.0.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b"}, + {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd"}, + {file = "numpy-2.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318"}, + {file = "numpy-2.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8"}, + {file = "numpy-2.0.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326"}, + {file = "numpy-2.0.2-cp310-cp310-win32.whl", hash = "sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97"}, + {file = "numpy-2.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57"}, + {file = "numpy-2.0.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = 
"sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a"}, + {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669"}, + {file = "numpy-2.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951"}, + {file = "numpy-2.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9"}, + {file = "numpy-2.0.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15"}, + {file = "numpy-2.0.2-cp311-cp311-win32.whl", hash = "sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4"}, + {file = "numpy-2.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c"}, + {file = "numpy-2.0.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c"}, + {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692"}, + {file = "numpy-2.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a"}, + {file = "numpy-2.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = 
"sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c"}, + {file = "numpy-2.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded"}, + {file = "numpy-2.0.2-cp312-cp312-win32.whl", hash = "sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5"}, + {file = "numpy-2.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b"}, + {file = "numpy-2.0.2-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729"}, + {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1"}, + {file = "numpy-2.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd"}, + {file = "numpy-2.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d"}, + {file = "numpy-2.0.2-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d"}, + {file = "numpy-2.0.2-cp39-cp39-win32.whl", hash = "sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa"}, + {file = "numpy-2.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = 
"sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-macosx_14_0_x86_64.whl", hash = "sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c"}, + {file = "numpy-2.0.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385"}, + {file = "numpy-2.0.2.tar.gz", hash = "sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78"}, +] + [[package]] name = "packaging" version = "25.0" @@ -334,6 +420,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "protobuf" +version = "6.32.1" +description = "" +optional = false +python-versions = ">=3.9" +files = [ + {file = "protobuf-6.32.1-cp310-abi3-win32.whl", hash = "sha256:a8a32a84bc9f2aad712041b8b366190f71dde248926da517bde9e832e4412085"}, + {file = "protobuf-6.32.1-cp310-abi3-win_amd64.whl", hash = "sha256:b00a7d8c25fa471f16bc8153d0e53d6c9e827f0953f3c09aaa4331c718cae5e1"}, + {file = "protobuf-6.32.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d8c7e6eb619ffdf105ee4ab76af5a68b60a9d0f66da3ea12d1640e6d8dab7281"}, + {file = "protobuf-6.32.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:2f5b80a49e1eb7b86d85fcd23fe92df154b9730a725c3b38c4e43b9d77018bf4"}, + {file = "protobuf-6.32.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:b1864818300c297265c83a4982fd3169f97122c299f56a56e2445c3698d34710"}, + {file = "protobuf-6.32.1-cp39-cp39-win32.whl", hash = "sha256:68ff170bac18c8178f130d1ccb94700cf72852298e016a2443bdb9502279e5f1"}, + {file = "protobuf-6.32.1-cp39-cp39-win_amd64.whl", hash = "sha256:d0975d0b2f3e6957111aa3935d08a0eb7e006b1505d825f862a1fffc8348e122"}, + {file = "protobuf-6.32.1-py3-none-any.whl", hash = 
"sha256:2601b779fc7d32a866c6b4404f9d42a3f67c5b9f3f15b4db3cccabe06b95c346"}, + {file = "protobuf-6.32.1.tar.gz", hash = "sha256:ee2469e4a021474ab9baafea6cd070e5bf27c7d29433504ddea1a4ee5850f68d"}, +] + [[package]] name = "pyaudio" version = "0.2.14" @@ -656,6 +760,17 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] +[[package]] +name = "types-protobuf" +version = "6.32.1.20250918" +description = "Typing stubs for protobuf" +optional = false +python-versions = ">=3.9" +files = [ + {file = "types_protobuf-6.32.1.20250918-py3-none-any.whl", hash = "sha256:22ba6133d142d11cc34d3788ad6dead2732368ebb0406eaa7790ea6ae46c8d0b"}, + {file = "types_protobuf-6.32.1.20250918.tar.gz", hash = "sha256:44ce0ae98475909ca72379946ab61a4435eec2a41090821e713c17e8faf5b88f"}, +] + [[package]] name = "types-pyaudio" version = "0.2.16.20240516" @@ -807,4 +922,4 @@ pyaudio = ["pyaudio"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "b7da141a4d5cf0383830ab57e4341326054c70f2d67cc19d189d4cd1cbaf21c0" +content-hash = "e7cb27516e124f02d2e959c2e4986786d7542bfe81362376069ca3ee3c7b0255" diff --git a/pyproject.toml b/pyproject.toml index 2696993b..92b62e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ pydantic-core = ">=2.18.2" requests = ">=2.20" typing_extensions = ">= 4.0.0" websockets = ">=11.0" +livekit = { version = "^1.0.13", python = ">=3.9" } [tool.poetry.group.dev.dependencies] mypy = "==1.13.0" diff --git a/src/elevenlabs/conversational_ai/base_connection.py b/src/elevenlabs/conversational_ai/base_connection.py new file mode 100644 index 00000000..5a160c50 --- /dev/null +++ b/src/elevenlabs/conversational_ai/base_connection.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +import asyncio +import json +from typing import Callable, Optional, Awaitable, Union, Any, Literal, Dict +from enum import Enum + + +class ConnectionType(str, Enum): + """Connection types 
class BaseConnection(ABC):
    """Abstract transport for a conversation.

    Concrete subclasses implement the four coroutine methods (``connect``,
    ``close``, ``send_message``, ``send_audio``). This base class provides
    the shared plumbing: the conversation id slot, a buffer for messages that
    arrive before a consumer is registered, and a synchronous wrapper around
    ``send_message``.
    """

    def __init__(self) -> None:
        self.conversation_id: Optional[str] = None
        # Messages received before ``on_message`` registers a callback.
        self._message_queue: list[dict] = []
        self._on_message_callback: Optional[Callable[[dict], Union[None, Awaitable[None]]]] = None
        # Strong references to fire-and-forget tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set = set()

    @abstractmethod
    async def connect(self) -> None:
        """Establish the connection."""
        pass

    @abstractmethod
    async def close(self) -> None:
        """Close the connection."""
        pass

    @abstractmethod
    async def send_message(self, message: dict) -> None:
        """Send a message through the connection."""
        pass

    @abstractmethod
    async def send_audio(self, audio_data: bytes) -> None:
        """Send audio data through the connection."""
        pass

    def _track_task(self, task: "asyncio.Task") -> None:
        """Hold a reference to *task* until it completes."""
        tasks = getattr(self, "_background_tasks", None)
        if tasks is None:
            # Tolerate subclasses whose __init__ did not run ours.
            tasks = self._background_tasks = set()
        tasks.add(task)
        task.add_done_callback(tasks.discard)

    def send_message_sync(self, message: dict) -> None:
        """Send *message* from synchronous code.

        * If an event loop is running in the calling thread, the send is
          scheduled on it as a background task and this method returns
          immediately.
        * Otherwise, if the thread has an idle, open event loop, the send is
          run to completion on that loop.
        * Otherwise the send is run to completion on a fresh loop via
          ``asyncio.run``.

        Only the loop *lookup* is guarded against ``RuntimeError``; an
        exception raised by ``send_message`` itself propagates to the caller
        and never triggers a second send attempt.
        """
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            self._track_task(running_loop.create_task(self.send_message(message)))
            return

        try:
            idle_loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop for this thread (e.g. a worker thread).
            idle_loop = None

        if idle_loop is None or idle_loop.is_closed() or idle_loop.is_running():
            # No usable loop owned by this thread: run on a private one.
            asyncio.run(self.send_message(message))
        else:
            idle_loop.run_until_complete(self.send_message(message))

    def on_message(self, callback: Callable[[dict], Union[None, Awaitable[None]]]) -> None:
        """Register *callback* and replay any messages buffered so far.

        Buffered messages are delivered in arrival order, then the buffer is
        emptied.
        """
        self._on_message_callback = callback
        if self._message_queue:
            for message in self._message_queue:
                self._handle_message(message)
            self._message_queue.clear()

    def _handle_message(self, message: dict) -> None:
        """Deliver *message* to the callback, or buffer it if none is set.

        Coroutine-function callbacks are scheduled as tasks, which requires a
        running event loop in the calling thread; plain callables are invoked
        inline.
        """
        callback = self._on_message_callback
        if not callback:
            self._message_queue.append(message)
            return

        if asyncio.iscoroutinefunction(callback):
            self._track_task(asyncio.create_task(callback(message)))
        else:
            callback(message)
compatibility + return ConnectionType.WEBSOCKET \ No newline at end of file diff --git a/src/elevenlabs/conversational_ai/conversation.py b/src/elevenlabs/conversational_ai/conversation.py index dbc72c5b..a95991b9 100644 --- a/src/elevenlabs/conversational_ai/conversation.py +++ b/src/elevenlabs/conversational_ai/conversation.py @@ -13,6 +13,9 @@ from ..base_client import BaseElevenLabs from ..version import __version__ +from .base_connection import ConnectionType, BaseConnection +from .connection_factory import create_connection, determine_connection_type +from .location_utils import Location, get_origin_for_location class ClientToOrchestratorEvent(str, Enum): @@ -295,11 +298,25 @@ def __init__( conversation_config_override: Optional[dict] = None, dynamic_variables: Optional[dict] = None, user_id: Optional[str] = None, + connection_type: Optional[ConnectionType] = None, + conversation_token: Optional[str] = None, + location: Optional[Location] = None, + livekit_url: Optional[str] = None, + api_origin: Optional[str] = None, + webrtc_overrides: Optional[dict] = None, + on_debug: Optional[Callable[[dict], None]] = None, ): self.extra_body = extra_body or {} self.conversation_config_override = conversation_config_override or {} self.dynamic_variables = dynamic_variables or {} self.user_id = user_id + self.connection_type = connection_type + self.conversation_token = conversation_token + self.location = location + self.livekit_url = livekit_url + self.api_origin = api_origin + self.webrtc_overrides = webrtc_overrides or {} + self.on_debug = on_debug class BaseConversation: @@ -326,10 +343,15 @@ def __init__( self._conversation_id = None self._last_interrupt_id = 0 + self._connection: Optional[BaseConnection] = None def _get_wss_url(self): - base_http_url = self.client._client_wrapper.get_base_url() - base_ws_url = base_http_url.replace("https://", "wss://").replace("http://", "ws://") + # Use location-based URL if location is specified + if self.config.location is not 
def _determine_connection_type(self) -> ConnectionType:
    """Resolve the transport for this conversation from its configuration.

    Delegates to ``determine_connection_type`` with the configured
    ``connection_type`` and ``conversation_token``.
    """
    config = self.config
    return determine_connection_type(
        connection_type=config.connection_type,
        conversation_token=config.conversation_token,
    )


def _create_connection(self):
    """Build the connection object matching the resolved transport.

    WebSocket: uses the signed URL when ``requires_auth`` is set, otherwise
    the plain WSS URL.

    WebRTC: forwards the conversation token, agent id, LiveKit URL, API
    origin, merged overrides and debug hook to ``create_connection``.

    Raises:
        ValueError: If the resolved connection type is neither WebSocket nor
            WebRTC.
    """
    transport = self._determine_connection_type()

    if transport == ConnectionType.WEBSOCKET:
        if self.requires_auth:
            url = self._get_signed_url()
        else:
            url = self._get_wss_url()
        return create_connection(transport, ws_url=url)

    if transport != ConnectionType.WEBRTC:
        raise ValueError(f"Unsupported connection type: {transport}")

    config = self.config
    base_http_url = self.client._client_wrapper.get_base_url()

    # Explicit configuration wins; otherwise fall back to the client's base
    # URL and the default LiveKit endpoint.
    api_origin = config.api_origin or base_http_url
    livekit_url = config.livekit_url or "wss://livekit.rtc.elevenlabs.io"

    # User-supplied WebRTC overrides go in first, so the SDK-controlled keys
    # below replace any same-named entries.
    overrides = dict(config.webrtc_overrides)
    overrides.update(
        {
            "client": {
                "version": __version__,
                "source": "python_sdk",
            },
            "custom_llm_extra_body": config.extra_body,
            "conversation_config_override": config.conversation_config_override,
            "dynamic_variables": config.dynamic_variables,
        }
    )

    return create_connection(
        transport,
        conversation_token=config.conversation_token,
        agent_id=self.agent_id,
        livekit_url=livekit_url,
        api_origin=api_origin,
        overrides=overrides,
        on_debug=config.on_debug,
    )
self.callback_end_session() @@ -567,17 +673,31 @@ def send_user_message(self, text: str): text: The text message to send to the agent. Raises: - RuntimeError: If the session is not active or websocket is not connected. + RuntimeError: If the session is not active or connection is not established. """ - if not self._ws: - raise RuntimeError("Session not started or websocket not connected.") + connection_type = self._determine_connection_type() - event = UserMessageClientToOrchestratorEvent(text=text) - try: - self._ws.send(json.dumps(event.to_dict())) - except Exception as e: - print(f"Error sending user message: {e}") - raise + if connection_type == ConnectionType.WEBSOCKET: + if not self._ws: + raise RuntimeError("Session not started or websocket not connected.") + + event = UserMessageClientToOrchestratorEvent(text=text) + try: + self._ws.send(json.dumps(event.to_dict())) + except Exception as e: + print(f"Error sending user message: {e}") + raise + + elif connection_type == ConnectionType.WEBRTC: + if not self._connection: + raise RuntimeError("Session not started or WebRTC connection not established.") + + event = UserMessageClientToOrchestratorEvent(text=text) + try: + self._connection.send_message_sync(event.to_dict()) + except Exception as e: + print(f"Error sending user message: {e}") + raise def register_user_activity(self): """Register user activity to prevent session timeout. 
@@ -619,7 +739,7 @@ def send_contextual_update(self, text: str): print(f"Error sending contextual update: {e}") raise - def _run(self, ws_url: str): + def _run_websocket(self, ws_url: str): with connect(ws_url, max_size=16 * 1024 * 1024) as ws: self._ws = ws ws.send(self._create_initiation_message()) @@ -657,6 +777,104 @@ def input_callback(audio): self._ws = None + def _run_webrtc(self): + """Run WebRTC conversation session.""" + try: + # Connect to WebRTC + import asyncio + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def webrtc_session(): + await self._connection.connect() + self._conversation_id = self._connection.conversation_id + + # Set up message callback + def message_callback(message): + self._handle_webrtc_message(message) + + self._connection.on_message(message_callback) + + # Set up audio input callback + def input_callback(audio): + try: + # Send audio through WebRTC connection + loop.create_task(self._connection.send_audio(audio)) + except Exception as e: + print(f"Error sending user audio chunk: {e}") + self.end_session() + + self.audio_interface.start(input_callback) + + # Keep running until stopped + while not self._should_stop.is_set(): + await asyncio.sleep(0.1) + + await self._connection.close() + + loop.run_until_complete(webrtc_session()) + + except Exception as e: + print(f"WebRTC session error: {e}") + self.end_session() + finally: + loop.close() + + def _handle_webrtc_message(self, message): + """Handle messages from WebRTC connection.""" + class WebRTCMessageHandler: + def __init__(self, conversation): + self.conversation = conversation + self.callback_agent_response = conversation.callback_agent_response + self.callback_agent_response_correction = conversation.callback_agent_response_correction + self.callback_user_transcript = 
conversation.callback_user_transcript + self.callback_latency_measurement = conversation.callback_latency_measurement + + def handle_audio_output(self, audio): + self.conversation.audio_interface.output(audio) + + def handle_agent_response(self, response): + if self.conversation.callback_agent_response: + self.conversation.callback_agent_response(response) + + def handle_agent_response_correction(self, original, corrected): + if self.conversation.callback_agent_response_correction: + self.conversation.callback_agent_response_correction(original, corrected) + + def handle_user_transcript(self, transcript): + if self.conversation.callback_user_transcript: + self.conversation.callback_user_transcript(transcript) + + def handle_interruption(self): + self.conversation.audio_interface.interrupt() + + def handle_ping(self, event): + # For WebRTC, pings are handled by the connection itself + pass + + def handle_latency_measurement(self, latency): + if self.conversation.callback_latency_measurement: + self.conversation.callback_latency_measurement(latency) + + def handle_client_tool_call(self, tool_name, parameters): + def send_response(response): + if not self.conversation._should_stop.is_set(): + # Send response through WebRTC connection + import asyncio + asyncio.create_task(self.conversation._connection.send_message(response)) + + self.conversation.client_tools.execute_tool(tool_name, parameters, send_response) + + handler = WebRTCMessageHandler(self) + self._handle_message_core(message, handler) + def _handle_message(self, message, ws): class SyncMessageHandler: def __init__(self, conversation, ws): @@ -779,8 +997,15 @@ async def start_session(self): Will run in background task until `end_session` is called. 
""" - ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url() - self._task = asyncio.create_task(self._run(ws_url)) + connection_type = self._determine_connection_type() + if connection_type == ConnectionType.WEBSOCKET: + ws_url = self._get_signed_url() if self.requires_auth else self._get_wss_url() + self._task = asyncio.create_task(self._run_websocket(ws_url)) + elif connection_type == ConnectionType.WEBRTC: + self._connection = self._create_connection() + self._task = asyncio.create_task(self._run_webrtc()) + else: + raise ValueError(f"Unsupported connection type: {connection_type}") async def end_session(self): """Ends the conversation session and cleans up resources.""" @@ -789,6 +1014,11 @@ async def end_session(self): self._ws = None self._should_stop.set() + # Close connection if it exists + if self._connection: + await self._connection.close() + self._connection = None + if self.callback_end_session: await self.callback_end_session() @@ -811,17 +1041,31 @@ async def send_user_message(self, text: str): text: The text message to send to the agent. Raises: - RuntimeError: If the session is not active or websocket is not connected. + RuntimeError: If the session is not active or connection is not established. 
""" - if not self._ws: - raise RuntimeError("Session not started or websocket not connected.") + connection_type = self._determine_connection_type() - event = UserMessageClientToOrchestratorEvent(text=text) - try: - await self._ws.send(json.dumps(event.to_dict())) - except Exception as e: - print(f"Error sending user message: {e}") - raise + if connection_type == ConnectionType.WEBSOCKET: + if not self._ws: + raise RuntimeError("Session not started or websocket not connected.") + + event = UserMessageClientToOrchestratorEvent(text=text) + try: + await self._ws.send(json.dumps(event.to_dict())) + except Exception as e: + print(f"Error sending user message: {e}") + raise + + elif connection_type == ConnectionType.WEBRTC: + if not self._connection: + raise RuntimeError("Session not started or WebRTC connection not established.") + + event = UserMessageClientToOrchestratorEvent(text=text) + try: + await self._connection.send_message(event.to_dict()) + except Exception as e: + print(f"Error sending user message: {e}") + raise async def register_user_activity(self): """Register user activity to prevent session timeout. 
@@ -863,7 +1107,7 @@ async def send_contextual_update(self, text: str): print(f"Error sending contextual update: {e}") raise - async def _run(self, ws_url: str): + async def _run_websocket(self, ws_url: str): async with websockets.connect(ws_url, max_size=16 * 1024 * 1024) as ws: self._ws = ws await ws.send(self._create_initiation_message()) @@ -905,6 +1149,84 @@ async def input_callback(audio): finally: self._ws = None + async def _run_webrtc(self): + """Run async WebRTC conversation session.""" + try: + await self._connection.connect() + self._conversation_id = self._connection.conversation_id + + # Set up message callback + async def message_callback(message): + await self._handle_webrtc_message(message) + + self._connection.on_message(message_callback) + + # Set up audio input callback + async def input_callback(audio): + try: + await self._connection.send_audio(audio) + except Exception as e: + print(f"Error sending user audio chunk: {e}") + await self.end_session() + + await self.audio_interface.start(input_callback) + + # Keep running until stopped + while not self._should_stop.is_set(): + await asyncio.sleep(0.1) + + await self._connection.close() + + except Exception as e: + print(f"WebRTC session error: {e}") + await self.end_session() + + async def _handle_webrtc_message(self, message): + """Handle messages from WebRTC connection.""" + class AsyncWebRTCMessageHandler: + def __init__(self, conversation): + self.conversation = conversation + self.callback_agent_response = conversation.callback_agent_response + self.callback_agent_response_correction = conversation.callback_agent_response_correction + self.callback_user_transcript = conversation.callback_user_transcript + self.callback_latency_measurement = conversation.callback_latency_measurement + + async def handle_audio_output(self, audio): + await self.conversation.audio_interface.output(audio) + + async def handle_agent_response(self, response): + if self.conversation.callback_agent_response: + await 
self.conversation.callback_agent_response(response) + + async def handle_agent_response_correction(self, original, corrected): + if self.conversation.callback_agent_response_correction: + await self.conversation.callback_agent_response_correction(original, corrected) + + async def handle_user_transcript(self, transcript): + if self.conversation.callback_user_transcript: + await self.conversation.callback_user_transcript(transcript) + + async def handle_interruption(self): + await self.conversation.audio_interface.interrupt() + + async def handle_ping(self, event): + # For WebRTC, pings are handled by the connection itself + pass + + async def handle_latency_measurement(self, latency): + if self.conversation.callback_latency_measurement: + await self.conversation.callback_latency_measurement(latency) + + def handle_client_tool_call(self, tool_name, parameters): + def send_response(response): + if not self.conversation._should_stop.is_set(): + asyncio.create_task(self.conversation._connection.send_message(response)) + + self.conversation.client_tools.execute_tool(tool_name, parameters, send_response) + + handler = AsyncWebRTCMessageHandler(self) + await self._handle_message_core_async(message, handler) + async def _handle_message(self, message, ws): class AsyncMessageHandler: def __init__(self, conversation, ws): diff --git a/src/elevenlabs/conversational_ai/conversation_factory.py b/src/elevenlabs/conversational_ai/conversation_factory.py new file mode 100644 index 00000000..33260b90 --- /dev/null +++ b/src/elevenlabs/conversational_ai/conversation_factory.py @@ -0,0 +1,273 @@ +from typing import Optional, Callable, Awaitable, Union + +from ..base_client import BaseElevenLabs +from .conversation import ( + Conversation, + AsyncConversation, + AudioInterface, + AsyncAudioInterface, + ConversationInitiationData, + ClientTools +) +from .webrtc_conversation import WebRTCConversation +from .base_connection import ConnectionType +from .location_utils import Location, 
get_origin_for_location, get_livekit_url_for_location + + +def create_conversation( + client: BaseElevenLabs, + agent_id: str, + user_id: Optional[str] = None, + *, + connection_type: ConnectionType = ConnectionType.WEBSOCKET, + conversation_token: Optional[str] = None, + requires_auth: bool = True, + location: Optional[Location] = None, + audio_interface: Optional[Union[AudioInterface, AsyncAudioInterface]] = None, + config: Optional[ConversationInitiationData] = None, + client_tools: Optional[ClientTools] = None, + # Sync callbacks (for websocket conversations) + callback_agent_response: Optional[Callable[[str], None]] = None, + callback_agent_response_correction: Optional[Callable[[str, str], None]] = None, + callback_user_transcript: Optional[Callable[[str], None]] = None, + callback_latency_measurement: Optional[Callable[[int], None]] = None, + callback_end_session: Optional[Callable] = None, + # Async callbacks (for WebRTC and async websocket conversations) + async_callback_agent_response: Optional[Callable[[str], Awaitable[None]]] = None, + async_callback_agent_response_correction: Optional[Callable[[str, str], Awaitable[None]]] = None, + async_callback_user_transcript: Optional[Callable[[str], Awaitable[None]]] = None, + async_callback_latency_measurement: Optional[Callable[[int], Awaitable[None]]] = None, + async_callback_end_session: Optional[Callable[[], Awaitable[None]]] = None, +) -> Union[Conversation, AsyncConversation, WebRTCConversation]: + """Create a conversation with the specified connection type. 
+ + Args: + client: ElevenLabs client instance + agent_id: ID of the agent to connect to + user_id: Optional user ID + connection_type: Type of connection (websocket or webrtc) + conversation_token: Token for WebRTC authentication + requires_auth: Whether authentication is required + location: Data residency location (us, eu-residency, in-residency, global) + audio_interface: Audio interface for the conversation + config: Conversation configuration + client_tools: Client tools for handling agent calls + callback_*: Synchronous callbacks for websocket conversations + async_callback_*: Asynchronous callbacks for WebRTC and async conversations + + Returns: + A conversation instance of the appropriate type + + Examples: + # WebSocket conversation (default) + conversation = create_conversation( + client=client, + agent_id="your-agent-id", + audio_interface=your_audio_interface + ) + + # WebRTC conversation with EU residency + conversation = create_conversation( + client=client, + agent_id="your-agent-id", + connection_type=ConnectionType.WEBRTC, + location=Location.EU_RESIDENCY, + audio_interface=your_async_audio_interface, + async_callback_agent_response=your_response_handler + ) + + # WebSocket conversation with specific location + conversation = create_conversation( + client=client, + agent_id="your-agent-id", + location=Location.IN_RESIDENCY, + audio_interface=your_audio_interface + ) + """ + + # Set up configuration + if config is None: + config = ConversationInitiationData() + + config.connection_type = connection_type + if conversation_token: + config.conversation_token = conversation_token + if location is not None: + config.location = location + + if connection_type == ConnectionType.WEBRTC: + # Create WebRTC conversation + if not isinstance(audio_interface, AsyncAudioInterface) and audio_interface is not None: + raise ValueError("WebRTC conversations require an AsyncAudioInterface") + + # Determine URLs based on location + livekit_url = None + api_origin = 
None + if location is not None: + livekit_url = get_livekit_url_for_location(location) + # Convert WSS to HTTPS for API origin + api_origin = get_origin_for_location(location).replace("wss://", "https://") + + return WebRTCConversation( + client=client, + agent_id=agent_id, + user_id=user_id, + conversation_token=conversation_token, + livekit_url=livekit_url, + api_origin=api_origin, + audio_interface=audio_interface, + config=config, + client_tools=client_tools, + callback_agent_response=async_callback_agent_response, + callback_agent_response_correction=async_callback_agent_response_correction, + callback_user_transcript=async_callback_user_transcript, + callback_latency_measurement=async_callback_latency_measurement, + callback_end_session=async_callback_end_session, + ) + + elif connection_type == ConnectionType.WEBSOCKET: + # Determine if we should use sync or async conversation + has_async_callbacks = any([ + async_callback_agent_response, + async_callback_agent_response_correction, + async_callback_user_transcript, + async_callback_latency_measurement, + async_callback_end_session, + ]) + + if has_async_callbacks or isinstance(audio_interface, AsyncAudioInterface): + # Use async conversation + return AsyncConversation( + client=client, + agent_id=agent_id, + user_id=user_id, + requires_auth=requires_auth, + audio_interface=audio_interface, # type: ignore + config=config, + client_tools=client_tools, + callback_agent_response=async_callback_agent_response, + callback_agent_response_correction=async_callback_agent_response_correction, + callback_user_transcript=async_callback_user_transcript, + callback_latency_measurement=async_callback_latency_measurement, + callback_end_session=async_callback_end_session, + ) + else: + # Use sync conversation + if not isinstance(audio_interface, AudioInterface) and audio_interface is not None: + raise ValueError("Synchronous WebSocket conversations require an AudioInterface") + + return Conversation( + client=client, + 
agent_id=agent_id, + user_id=user_id, + requires_auth=requires_auth, + audio_interface=audio_interface, # type: ignore + config=config, + client_tools=client_tools, + callback_agent_response=callback_agent_response, + callback_agent_response_correction=callback_agent_response_correction, + callback_user_transcript=callback_user_transcript, + callback_latency_measurement=callback_latency_measurement, + callback_end_session=callback_end_session, + ) + + else: + raise ValueError(f"Unsupported connection type: {connection_type}") + + +# Convenience functions for specific connection types + +def create_webrtc_conversation( + client: BaseElevenLabs, + agent_id: str, + user_id: Optional[str] = None, + *, + conversation_token: Optional[str] = None, + location: Optional[Location] = None, + livekit_url: Optional[str] = None, + api_origin: Optional[str] = None, + webrtc_overrides: Optional[dict] = None, + on_debug: Optional[Callable[[dict], None]] = None, + audio_interface: Optional[AsyncAudioInterface] = None, + config: Optional[ConversationInitiationData] = None, + client_tools: Optional[ClientTools] = None, + callback_agent_response: Optional[Callable[[str], Awaitable[None]]] = None, + callback_agent_response_correction: Optional[Callable[[str, str], Awaitable[None]]] = None, + callback_user_transcript: Optional[Callable[[str], Awaitable[None]]] = None, + callback_latency_measurement: Optional[Callable[[int], Awaitable[None]]] = None, + callback_end_session: Optional[Callable[[], Awaitable[None]]] = None, +) -> WebRTCConversation: + """Create a WebRTC conversation. + + Convenience function for creating WebRTC conversations with type safety. + + Args: + location: Data residency location. If provided, overrides livekit_url and api_origin. + livekit_url: Custom LiveKit URL (overridden by location if provided). + api_origin: Custom API origin (overridden by location if provided). 
+ """ + # Determine URLs based on location if provided + if location is not None: + livekit_url = get_livekit_url_for_location(location) + api_origin = get_origin_for_location(location).replace("wss://", "https://") + + return WebRTCConversation( + client=client, + agent_id=agent_id, + user_id=user_id, + conversation_token=conversation_token, + livekit_url=livekit_url, + api_origin=api_origin, + webrtc_overrides=webrtc_overrides, + on_debug=on_debug, + audio_interface=audio_interface, + config=config, + client_tools=client_tools, + callback_agent_response=callback_agent_response, + callback_agent_response_correction=callback_agent_response_correction, + callback_user_transcript=callback_user_transcript, + callback_latency_measurement=callback_latency_measurement, + callback_end_session=callback_end_session, + ) + + +def create_websocket_conversation( + client: BaseElevenLabs, + agent_id: str, + user_id: Optional[str] = None, + *, + requires_auth: bool = True, + location: Optional[Location] = None, + audio_interface: Optional[AudioInterface] = None, + config: Optional[ConversationInitiationData] = None, + client_tools: Optional[ClientTools] = None, + callback_agent_response: Optional[Callable[[str], None]] = None, + callback_agent_response_correction: Optional[Callable[[str, str], None]] = None, + callback_user_transcript: Optional[Callable[[str], None]] = None, + callback_latency_measurement: Optional[Callable[[int], None]] = None, + callback_end_session: Optional[Callable] = None, +) -> Union[Conversation, AsyncConversation]: + """Create a WebSocket conversation. + + Convenience function for creating WebSocket conversations with type safety. 
+ + Args: + location: Data residency location (us, eu-residency, in-residency, global) + """ + result = create_conversation( + client=client, + agent_id=agent_id, + user_id=user_id, + connection_type=ConnectionType.WEBSOCKET, + requires_auth=requires_auth, + location=location, + audio_interface=audio_interface, + config=config, + client_tools=client_tools, + callback_agent_response=callback_agent_response, + callback_agent_response_correction=callback_agent_response_correction, + callback_user_transcript=callback_user_transcript, + callback_latency_measurement=callback_latency_measurement, + callback_end_session=callback_end_session, + ) + return result # type: ignore[return-value] \ No newline at end of file diff --git a/src/elevenlabs/conversational_ai/location_utils.py b/src/elevenlabs/conversational_ai/location_utils.py new file mode 100644 index 00000000..0ff95111 --- /dev/null +++ b/src/elevenlabs/conversational_ai/location_utils.py @@ -0,0 +1,50 @@ +from enum import Enum +from typing import Dict + + +class Location(Enum): + """Location enum for data residency and region selection.""" + US = "us" + EU_RESIDENCY = "eu-residency" + IN_RESIDENCY = "in-residency" + GLOBAL = "global" + + +def get_origin_for_location(location: Location) -> str: + """ + Get the WebSocket API origin URL for a given location. + + Args: + location: The location enum value + + Returns: + The WebSocket URL for the specified location + """ + origin_map: Dict[Location, str] = { + Location.US: "wss://api.elevenlabs.io", + Location.EU_RESIDENCY: "wss://api.eu.residency.elevenlabs.io", + Location.IN_RESIDENCY: "wss://api.in.residency.elevenlabs.io", + Location.GLOBAL: "wss://api.elevenlabs.io", + } + + return origin_map[location] + + +def get_livekit_url_for_location(location: Location) -> str: + """ + Get the LiveKit WebRTC URL for a given location. 
+ + Args: + location: The location enum value + + Returns: + The LiveKit URL for the specified location + """ + livekit_url_map: Dict[Location, str] = { + Location.US: "wss://livekit.rtc.elevenlabs.io", + Location.EU_RESIDENCY: "wss://livekit.rtc.eu.residency.elevenlabs.io", + Location.IN_RESIDENCY: "wss://livekit.rtc.in.residency.elevenlabs.io", + Location.GLOBAL: "wss://livekit.rtc.elevenlabs.io", + } + + return livekit_url_map[location] \ No newline at end of file diff --git a/src/elevenlabs/conversational_ai/webrtc_connection.py b/src/elevenlabs/conversational_ai/webrtc_connection.py new file mode 100644 index 00000000..09750a39 --- /dev/null +++ b/src/elevenlabs/conversational_ai/webrtc_connection.py @@ -0,0 +1,363 @@ +import json +import asyncio +from typing import Optional, Dict, Any, Callable, Union, Awaitable +import httpx + +try: + from livekit.rtc import Room, TrackKind +except ImportError: + raise ImportError( + "livekit package is required for WebRTC support. " + "Install with: pip install livekit" + ) + +from .base_connection import BaseConnection + + +class WebRTCConnectionConfig: + """Configuration for WebRTC connection.""" + def __init__( + self, + conversation_token: Optional[str] = None, + agent_id: Optional[str] = None, + livekit_url: Optional[str] = None, + api_origin: Optional[str] = None, + overrides: Optional[Dict[str, Any]] = None, + on_debug: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> None: + self.conversation_token = conversation_token + self.agent_id = agent_id + self.livekit_url = livekit_url + self.api_origin = api_origin + self.overrides = overrides or {} + self.on_debug = on_debug + + +class WebRTCConnection(BaseConnection): + """WebRTC-based connection for conversations using LiveKit.""" + + DEFAULT_LIVEKIT_WS_URL = "wss://livekit.rtc.elevenlabs.io" + DEFAULT_API_ORIGIN = "https://api.elevenlabs.io" + + def __init__( + self, + conversation_token: Optional[str] = None, + agent_id: Optional[str] = None, + livekit_url: 
Optional[str] = None, + api_origin: Optional[str] = None, + overrides: Optional[Dict[str, Any]] = None, + on_debug: Optional[Callable[[Dict[str, Any]], None]] = None, + ) -> None: + super().__init__() + self.conversation_token = conversation_token + self.agent_id = agent_id + self.livekit_url = livekit_url or self.DEFAULT_LIVEKIT_WS_URL + self.api_origin = api_origin or self.DEFAULT_API_ORIGIN + self.overrides = overrides or {} + self.on_debug = on_debug + self._room: Optional[Room] = None + self._is_connected: bool = False + + @classmethod + async def create(cls, config: WebRTCConnectionConfig) -> "WebRTCConnection": + """Create and connect a WebRTC connection.""" + connection = cls( + conversation_token=config.conversation_token, + agent_id=config.agent_id, + livekit_url=config.livekit_url, + api_origin=config.api_origin, + overrides=config.overrides, + on_debug=config.on_debug, + ) + + await connection.connect() + return connection + + async def connect(self) -> None: + """Establish the WebRTC connection using LiveKit.""" + try: + # Get conversation token if not provided + if not self.conversation_token: + if not self.agent_id: + raise ValueError("Either conversation_token or agent_id is required for WebRTC connection") + self.conversation_token = await self._fetch_conversation_token() + + # Create room and connect + self._room = Room() + self._setup_room_callbacks() + + # Connect to LiveKit room using configurable URL + try: + await self._room.connect(self.livekit_url, self.conversation_token) + self._is_connected = True + except Exception as e: + self._is_connected = False + raise ConnectionError(f"Failed to connect to LiveKit room: {e}") from e + + # Set conversation ID from room name if available + if self._room.name: + # Extract conversation ID from room name if it contains one + import re + match = re.search(r'(conv_[a-zA-Z0-9]+)', self._room.name) + self.conversation_id = match.group(0) if match else self._room.name + else: + self.conversation_id = 
f"webrtc-{id(self)}" + + # Enable microphone + try: + await self._enable_microphone(True) + except Exception as e: + self.debug({ + "type": "microphone_enable_error", + "error": f"Failed to enable microphone: {e}" + }) + except Exception as e: + self.debug({ + "type": "microphone_enable_error", + "error": str(e) + }) + + # Send overrides if any + if self.overrides: + try: + await self.send_message(self._construct_overrides()) + except Exception as e: + self.debug({ + "type": "overrides_send_error", + "error": str(e) + }) + + self.debug({ + "type": "conversation_initiation_client_data", + "message": self._construct_overrides() + }) + + except Exception as e: + # Ensure cleanup on connection failure + if self._room: + try: + await self._room.disconnect() + except: + pass + self._room = None + self._is_connected = False + raise + + async def close(self) -> None: + """Close the WebRTC connection.""" + if self._room: + try: + await self._room.disconnect() + except Exception as e: + self.debug({ + "type": "disconnect_error", + "error": str(e) + }) + finally: + self._room = None + self._is_connected = False + + async def send_message(self, message: dict) -> None: + """Send a message through WebRTC data channel.""" + if not self._is_connected or not self._room: + raise RuntimeError("WebRTC room not connected") + + # In WebRTC mode, audio is sent via published tracks, not data messages + if "user_audio_chunk" in message: + return # Audio is handled separately + + try: + data = json.dumps(message).encode('utf-8') + await self._room.local_participant.publish_data(data, reliable=True) + except Exception as e: + print(f"Failed to send message via WebRTC: {e}") + raise + + async def send_audio(self, audio_data: bytes) -> None: + """Send audio data through WebRTC (handled by published tracks).""" + # In WebRTC mode, audio is sent through the microphone track + # This method can be used for custom audio streaming if needed + pass + + async def receive_messages(self) -> None: + 
"""Receive and handle messages - handled by LiveKit event callbacks.""" + # In WebRTC mode, messages are handled via LiveKit event callbacks + # This method exists for compatibility with the BaseConnection interface + if not self._is_connected: + return + + # Keep the connection alive while connected + while self._is_connected: + await asyncio.sleep(0.1) + + async def _fetch_conversation_token(self) -> str: + """Fetch conversation token from ElevenLabs API.""" + if not self.agent_id: + raise ValueError("Agent ID is required to fetch conversation token") + + try: + # Get version and source from overrides or use defaults + version = self.overrides.get("client", {}).get("version", "2.15.0") # From pyproject.toml + source = self.overrides.get("client", {}).get("source", "python_sdk") + + # Convert WSS origin to HTTPS for API calls + api_origin = self._convert_wss_to_https(self.api_origin) + + url = f"{api_origin}/v1/convai/conversation/token?agent_id={self.agent_id}&source={source}&version={version}" + + async with httpx.AsyncClient(timeout=30.0) as client: + try: + response = await client.get(url) + except httpx.TimeoutException: + raise ConnectionError(f"Timeout when fetching conversation token for agent {self.agent_id}") + except httpx.NetworkError as e: + raise ConnectionError(f"Network error when fetching conversation token: {e}") + + if not response.is_success: + error_msg = f"ElevenLabs API returned {response.status_code} {response.reason_phrase}" + if response.status_code == 401: + error_msg = "Your agent has authentication enabled, but no signed URL or conversation token was provided." + elif response.status_code == 404: + error_msg = f"Agent with ID {self.agent_id} not found" + elif response.status_code == 429: + error_msg = "Rate limit exceeded. Please try again later." 
+ + raise Exception(f"Failed to fetch conversation token for agent {self.agent_id}: {error_msg}") + + try: + data = response.json() + except Exception as e: + raise Exception(f"Invalid JSON response from API: {e}") + + token = data.get("token") + + if not token: + raise Exception("No conversation token received from API") + + return token + + except Exception as e: + self.debug({ + "type": "token_fetch_error", + "agent_id": self.agent_id, + "error": str(e) + }) + raise + + def _convert_wss_to_https(self, origin: str) -> str: + """Convert WSS origin to HTTPS for API calls.""" + return origin.replace("wss://", "https://") + + def _construct_overrides(self) -> Dict[str, Any]: + """Construct overrides message for conversation initiation.""" + return { + "type": "conversation_initiation_client_data", + "overrides": self.overrides + } + + def debug(self, info: Dict[str, Any]) -> None: + """Log debug information.""" + if self.on_debug: + self.on_debug(info) + + def _setup_room_callbacks(self) -> None: + """Setup LiveKit room event callbacks.""" + if not self._room: + return + + @self._room.on("connected") + def on_connected() -> None: + self._is_connected = True + self.debug({"type": "webrtc_connected", "message": "WebRTC room connected"}) + + @self._room.on("disconnected") + def on_disconnected(reason: Optional[str] = None) -> None: + self._is_connected = False + self.debug({"type": "webrtc_disconnected", "message": f"WebRTC room disconnected: {reason}"}) + + @self._room.on("connection_state_changed") + def on_connection_state_changed(state) -> None: + self.debug({"type": "connection_state_changed", "state": str(state)}) + # Handle disconnected state + if hasattr(state, 'name') and state.name == 'DISCONNECTED': + self._is_connected = False + + @self._room.on("data_received") + def on_data_received(data: bytes, participant) -> None: + try: + message = json.loads(data.decode('utf-8')) + + # Filter out audio messages for WebRTC - they're handled via audio tracks + if 
message.get("type") == "audio": + return + + self._handle_message(message) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + self.debug({ + "type": "data_parse_error", + "error": str(e), + "raw_data": data.decode('utf-8', errors='replace') + }) + + @self._room.on("track_subscribed") + def on_track_subscribed(track, publication, participant) -> None: + if track.kind == TrackKind.KIND_AUDIO and "agent" in participant.identity: + self.debug({ + "type": "agent_audio_track_subscribed", + "participant": participant.identity + }) + + @self._room.on("active_speakers_changed") + def on_active_speakers_changed(speakers) -> None: + # Update mode based on active speakers + if speakers and len(speakers) > 0: + is_agent_speaking = any("agent" in speaker.identity for speaker in speakers) + mode = "speaking" if is_agent_speaking else "listening" + else: + mode = "listening" + + self.debug({"type": "mode_changed", "mode": mode}) + + async def set_microphone_enabled(self, enabled: bool) -> None: + """Enable or disable the microphone.""" + if not self._room or not self._room.local_participant: + raise RuntimeError("Room not connected") + + await self._enable_microphone(enabled) + + async def _enable_microphone(self, enabled: bool) -> None: + """Internal method to enable/disable microphone via track muting.""" + if not self._room or not self._room.local_participant: + raise RuntimeError("Room not connected") + + # Find the audio track publication + for track_pub in self._room.local_participant.track_publications.values(): + if track_pub.kind == TrackKind.KIND_AUDIO: + if track_pub.track: + if enabled: + await track_pub.track.unmute() + else: + await track_pub.track.mute() + return + + self.debug({ + "type": "microphone_control_error", + "enabled": enabled, + "error": "No audio track found" + }) + + async def set_microphone_device(self, device_id: str) -> None: + """Set the microphone input device.""" + if not self._room or not self._room.local_participant: + raise 
class WebRTCConversation(BaseConversation):
    """WebRTC-based conversational AI session using LiveKit.

    This class provides WebRTC connectivity for real-time audio conversations
    with ElevenLabs agents, offering lower latency compared to WebSocket connections.

    Typical use is ``await start_session()``, any number of ``send_*`` calls,
    then ``await end_session()``. All callbacks are awaited, so they must be
    coroutine functions.
    """

    def __init__(
        self,
        client: BaseElevenLabs,
        agent_id: str,
        user_id: Optional[str] = None,
        *,
        conversation_token: Optional[str] = None,
        livekit_url: Optional[str] = None,
        api_origin: Optional[str] = None,
        webrtc_overrides: Optional[dict] = None,
        on_debug: Optional[Callable[[dict], None]] = None,
        audio_interface: Optional[AsyncAudioInterface] = None,
        config: Optional[ConversationInitiationData] = None,
        client_tools: Optional[ClientTools] = None,
        callback_agent_response: Optional[Callable[[str], Awaitable[None]]] = None,
        callback_agent_response_correction: Optional[Callable[[str, str], Awaitable[None]]] = None,
        callback_user_transcript: Optional[Callable[[str], Awaitable[None]]] = None,
        callback_latency_measurement: Optional[Callable[[int], Awaitable[None]]] = None,
        callback_end_session: Optional[Callable[[], Awaitable[None]]] = None,
    ):
        """Initialize a WebRTC conversation.

        Args:
            client: The ElevenLabs client to use for the conversation.
            agent_id: The ID of the agent to converse with.
            user_id: The ID of the user conversing with the agent.
            conversation_token: Token for WebRTC authentication. If not provided,
                will be fetched using the agent_id.
            livekit_url: Custom LiveKit WebSocket URL. If not provided, uses default.
            api_origin: Custom API origin for token fetching. If not provided, uses default.
            webrtc_overrides: Additional overrides specific to WebRTC connection.
            on_debug: Debug callback function for WebRTC connection events.
            audio_interface: The async audio interface to use for input and output.
            config: Configuration for the conversation. A supplied object is
                updated in place with the WebRTC settings below.
            client_tools: Client tools for handling agent tool calls.
            callback_agent_response: Async callback for agent responses.
            callback_agent_response_correction: Async callback for response corrections.
            callback_user_transcript: Async callback for user transcripts.
            callback_latency_measurement: Async callback for latency measurements.
            callback_end_session: Async callback for when session ends.
        """
        # Set up configuration with WebRTC specifics
        if config is None:
            config = ConversationInitiationData()
        # These assignments are unconditional: they overwrite whatever the
        # caller's config object previously held for the same fields.
        config.connection_type = ConnectionType.WEBRTC
        config.conversation_token = conversation_token
        config.livekit_url = livekit_url
        config.api_origin = api_origin
        config.webrtc_overrides = webrtc_overrides or {}
        config.on_debug = on_debug

        super().__init__(
            client=client,
            agent_id=agent_id,
            user_id=user_id,
            requires_auth=True,  # WebRTC requires authentication
            config=config,
            client_tools=client_tools,
        )

        self.audio_interface = audio_interface
        self.callback_agent_response = callback_agent_response
        self.callback_agent_response_correction = callback_agent_response_correction
        self.callback_user_transcript = callback_user_transcript
        self.callback_latency_measurement = callback_latency_measurement
        self.callback_end_session = callback_end_session

        self._connection: Optional[WebRTCConnection] = None
        self._should_stop: Optional[asyncio.Event] = None
        self._session_task: Optional[asyncio.Task] = None

    def _debug(self, info: dict) -> None:
        """Forward *info* to the configured ``on_debug`` callback, if any."""
        if self.config.on_debug:
            self.config.on_debug(info)

    async def start_session(self):
        """Start the WebRTC conversation session.

        Creates the connection, registers the message handler, connects,
        records the conversation ID and starts the audio interface (if one
        was supplied).

        Raises:
            Exception: Whatever the connection or audio interface raised.
                The error is reported through ``on_debug`` first and any
                connection created by this call is closed and cleared.
        """
        # Always use a fresh stop event: the one from a previous session
        # stays set after end_session(), which would make a restarted
        # session silently drop client-tool responses.
        self._should_stop = asyncio.Event()

        try:
            # Use the enhanced connection creation from BaseConversation
            self._connection = self._create_connection()

            # Set up message handler
            self._connection.on_message(self._handle_message)

            # Connect
            await self._connection.connect()

            # Update conversation ID
            self._conversation_id = self._connection.conversation_id

            # Start audio interface if provided
            if self.audio_interface:
                await self.audio_interface.start(self._audio_input_callback)

            self._debug({
                "type": "webrtc_conversation_started",
                "conversation_id": self._conversation_id
            })

        except Exception as e:
            self._debug({
                "type": "webrtc_session_start_error",
                "error": str(e)
            })

            # Do not leave a half-open connection behind; a cleanup failure
            # must not mask the original error.
            connection, self._connection = self._connection, None
            if connection is not None:
                try:
                    await connection.close()
                except Exception as close_error:
                    self._debug({
                        "type": "webrtc_session_cleanup_error",
                        "error": str(close_error)
                    })
            raise

    async def end_session(self):
        """End the WebRTC conversation session.

        Signals the stop event, stops the audio interface, closes the
        connection, stops the client tools and finally awaits
        ``callback_end_session``. The connection and client tools are
        released even when an earlier cleanup step raises; the first error
        still propagates to the caller.
        """
        if self._should_stop:
            self._should_stop.set()

        try:
            if self.audio_interface:
                await self.audio_interface.stop()
        finally:
            # Detach before closing so a failing close() cannot leave a dead
            # connection referenced.
            connection, self._connection = self._connection, None
            try:
                if connection:
                    await connection.close()
            finally:
                self.client_tools.stop()

        if self.callback_end_session:
            await self.callback_end_session()

    async def send_user_message(self, text: str):
        """Send a text message from the user to the agent.

        Raises:
            RuntimeError: If the session has not been started.
        """
        if not self._connection:
            raise RuntimeError("Session not started")

        message = {
            "type": "user_message",
            "text": text
        }
        await self._connection.send_message(message)

    async def send_contextual_update(self, text: str):
        """Send a contextual update to the conversation.

        Raises:
            RuntimeError: If the session has not been started.
        """
        if not self._connection:
            raise RuntimeError("Session not started")

        message = {
            "type": "contextual_update",
            "text": text
        }
        await self._connection.send_message(message)

    async def register_user_activity(self):
        """Register user activity to prevent session timeout.

        Raises:
            RuntimeError: If the session has not been started.
        """
        if not self._connection:
            raise RuntimeError("Session not started")

        message = {
            "type": "user_activity"
        }
        await self._connection.send_message(message)

    async def _audio_input_callback(self, audio_data: bytes):
        """Handle audio input from the audio interface.

        Currently a no-op: ``audio_data`` is discarded.
        """
        if self._connection and self._should_stop and not self._should_stop.is_set():
            # For WebRTC, audio is sent through the room's microphone track
            # This callback can be used for custom processing if needed
            pass

    async def _handle_message(self, message: dict):
        """Handle incoming messages from the WebRTC connection.

        Dispatches on ``message["type"]`` to the matching user callback,
        answers pings with a pong and hands client tool calls to
        ``self.client_tools``. Unknown types are ignored. Any exception is
        printed and reported through ``on_debug`` rather than raised, so one
        bad message does not break the receive path.
        """
        message_type = None
        try:
            message_type = message.get("type")

            if message_type == "conversation_initiation_metadata":
                event = message["conversation_initiation_metadata_event"]
                if not self._conversation_id:
                    self._conversation_id = event["conversation_id"]

            elif message_type == "audio":
                # Audio is handled through WebRTC audio tracks, not data messages
                pass

            elif message_type == "agent_response":
                if self.callback_agent_response:
                    event = message["agent_response_event"]
                    await self.callback_agent_response(event["agent_response"].strip())

            elif message_type == "agent_response_correction":
                if self.callback_agent_response_correction:
                    event = message["agent_response_correction_event"]
                    await self.callback_agent_response_correction(
                        event["original_agent_response"].strip(),
                        event["corrected_agent_response"].strip()
                    )

            elif message_type == "user_transcript":
                if self.callback_user_transcript:
                    event = message["user_transcription_event"]
                    await self.callback_user_transcript(event["user_transcript"].strip())

            elif message_type == "interruption":
                if self.audio_interface:
                    await self.audio_interface.interrupt()

            elif message_type == "ping":
                event = message["ping_event"]
                # Send pong response
                pong_message = {
                    "type": "pong",
                    "event_id": event["event_id"]
                }
                if self._connection:
                    await self._connection.send_message(pong_message)

                # Compare against None so a measured latency of 0 ms is
                # still reported instead of being treated as "missing".
                ping_ms = event.get("ping_ms")
                if self.callback_latency_measurement and ping_ms is not None:
                    await self.callback_latency_measurement(int(ping_ms))

            elif message_type == "client_tool_call":
                tool_call = message.get("client_tool_call", {})
                tool_name = tool_call.get("tool_name")
                parameters = {
                    "tool_call_id": tool_call["tool_call_id"],
                    **tool_call.get("parameters", {})
                }

                # Execute tool asynchronously
                async def send_response(response):
                    # The tool may finish after end_session() cleared the
                    # connection, so re-check both before sending.
                    connection = self._connection
                    if connection and self._should_stop and not self._should_stop.is_set():
                        await connection.send_message(response)

                self.client_tools.execute_tool(tool_name, parameters, send_response)

        except Exception as e:
            print(f"Error handling message: {e}")
            self._debug({
                "type": "webrtc_message_handling_error",
                "message_type": message_type,
                "error": str(e)
            })

    def get_webrtc_room(self):
        """Get the underlying LiveKit room for advanced WebRTC operations.

        Returns:
            The connection's room, or ``None`` when no session is active.
        """
        if self._connection:
            return self._connection.get_room()
        return None
class WebSocketConnection(BaseConnection):
    """WebSocket-based connection for conversations.

    Messages are exchanged as JSON text frames over a single socket opened
    by :meth:`connect` and released by :meth:`close`.
    """

    def __init__(self, ws_url: str):
        """Store the target URL; no network activity happens here.

        Args:
            ws_url: The WebSocket URL to connect to.
        """
        super().__init__()
        self.ws_url = ws_url
        self._ws: Optional[websockets.WebSocketClientProtocol] = None

    async def connect(self) -> None:
        """Establish the WebSocket connection."""
        # max_size caps a single incoming message at 16 MiB.
        self._ws = await websockets.connect(self.ws_url, max_size=16 * 1024 * 1024)

    async def close(self) -> None:
        """Close the WebSocket connection; a no-op when not connected."""
        # Detach first so a failing close() cannot leave a dead socket
        # referenced by this object.
        ws, self._ws = self._ws, None
        if ws:
            await ws.close()

    async def send_message(self, message: dict) -> None:
        """Send a message through the WebSocket.

        Args:
            message: JSON-serializable payload.

        Raises:
            RuntimeError: If the socket is not connected.
        """
        if not self._ws:
            raise RuntimeError("WebSocket not connected")
        await self._ws.send(json.dumps(message))

    async def send_audio(self, audio_data: bytes) -> None:
        """Send audio data through the WebSocket.

        The bytes are base64-encoded and wrapped in a ``user_audio_chunk``
        message.

        Raises:
            RuntimeError: If the socket is not connected.
        """
        await self.send_message({
            "user_audio_chunk": base64.b64encode(audio_data).decode()
        })

    async def receive_messages(self) -> None:
        """Receive and handle messages from the WebSocket.

        Returns immediately when not connected, and returns normally once
        the peer closes the connection cleanly. Frames that are not valid
        JSON are skipped.
        """
        if not self._ws:
            return

        try:
            async for message_str in self._ws:
                # Only the parse is guarded, so handler errors are not
                # mistaken for malformed frames.
                try:
                    message = json.loads(message_str)
                except json.JSONDecodeError:
                    continue
                self._handle_message(message)
        except ConnectionClosedOK:
            pass
class TestWebRTCConversation:
    """Test WebRTC conversation functionality."""

    # Patch target for the LiveKit ``Room`` class as imported by the
    # connection module, so no real room is ever created.
    ROOM_PATCH_TARGET = 'elevenlabs.conversational_ai.webrtc_connection.Room'

    @staticmethod
    def _make_mock_room():
        """Build a mock room exposing the async methods the tests rely on."""
        room = Mock()
        room.connect = AsyncMock()
        room.disconnect = AsyncMock()
        room.local_participant = Mock()
        room.local_participant.set_microphone_enabled = AsyncMock()
        room.local_participant.publish_data = AsyncMock()
        room.name = "test-room"
        return room

    @pytest.fixture
    def mock_client(self):
        """Create a mock ElevenLabs client."""
        return Mock()

    @pytest.fixture
    def mock_audio_interface(self):
        """Create a mock async audio interface."""
        from elevenlabs.conversational_ai.conversation import AsyncAudioInterface
        interface = Mock(spec=AsyncAudioInterface)
        interface.start = AsyncMock()
        interface.stop = AsyncMock()
        interface.output = AsyncMock()
        interface.interrupt = AsyncMock()
        return interface

    def test_connection_type_determination(self):
        """Test that connection types are determined correctly."""
        from elevenlabs.conversational_ai.connection_factory import determine_connection_type

        # Default should be websocket
        assert determine_connection_type() == ConnectionType.WEBSOCKET

        # Explicit connection type should be respected
        assert determine_connection_type(ConnectionType.WEBRTC) == ConnectionType.WEBRTC

        # Conversation token should imply WebRTC
        assert determine_connection_type(conversation_token="token") == ConnectionType.WEBRTC

        # Explicit type should override token inference
        assert determine_connection_type(
            ConnectionType.WEBSOCKET,
            conversation_token="token"
        ) == ConnectionType.WEBSOCKET

    def test_factory_creates_correct_conversation_types(self, mock_client):
        """Test that the factory creates the correct conversation types."""
        # Create event loop for WebRTC conversation
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # WebRTC conversation
            webrtc_conv = create_conversation(
                client=mock_client,
                agent_id="test-agent",
                connection_type=ConnectionType.WEBRTC
            )
            assert isinstance(webrtc_conv, WebRTCConversation)
        finally:
            # NOTE: the closed loop stays registered as the current loop.
            loop.close()

        # WebSocket conversation (sync)
        ws_conv = create_conversation(
            client=mock_client,
            agent_id="test-agent",
            connection_type=ConnectionType.WEBSOCKET
        )
        assert isinstance(ws_conv, (Conversation, AsyncConversation))

    def test_convenience_functions(self, mock_client, mock_audio_interface):
        """Test convenience functions for creating conversations."""
        # Create event loop for WebRTC conversation
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            # WebRTC convenience function with conversation token to avoid HTTP calls
            with patch(self.ROOM_PATCH_TARGET) as mock_room_class:
                mock_room_class.return_value = self._make_mock_room()

                webrtc_conv = create_webrtc_conversation(
                    client=mock_client,
                    agent_id="test-agent",
                    conversation_token="test-token",
                    audio_interface=mock_audio_interface
                )
                assert isinstance(webrtc_conv, WebRTCConversation)
        finally:
            # NOTE: the closed loop stays registered as the current loop.
            loop.close()

        # WebSocket convenience function
        ws_conv = create_websocket_conversation(
            client=mock_client,
            agent_id="test-agent"
        )
        assert isinstance(ws_conv, Conversation)

    @pytest.mark.asyncio
    async def test_webrtc_conversation_lifecycle(self, mock_client, mock_audio_interface):
        """Test WebRTC conversation lifecycle."""
        with patch(self.ROOM_PATCH_TARGET) as mock_room_class:
            mock_room = self._make_mock_room()
            mock_room_class.return_value = mock_room

            # Create conversation with a conversation token to avoid HTTP calls
            conversation = WebRTCConversation(
                client=mock_client,
                agent_id="test-agent",
                conversation_token="test-token",  # Provide token to avoid fetching
                audio_interface=mock_audio_interface
            )

            # Test start session
            await conversation.start_session()
            mock_room.connect.assert_called_once()
            mock_audio_interface.start.assert_called_once()

            # Test end session
            await conversation.end_session()
            mock_room.disconnect.assert_called_once()
            mock_audio_interface.stop.assert_called_once()

    @pytest.mark.asyncio
    async def test_webrtc_conversation_messaging(self, mock_client):
        """Test WebRTC conversation messaging functionality."""
        with patch(self.ROOM_PATCH_TARGET) as mock_room_class:
            mock_room = self._make_mock_room()
            mock_room_class.return_value = mock_room

            # Create conversation with a conversation token to avoid HTTP calls
            conversation = WebRTCConversation(
                client=mock_client,
                agent_id="test-agent",
                conversation_token="test-token"  # Provide token to avoid fetching
            )

            # Start session
            await conversation.start_session()
            # WebRTC messages are sent via publish_data
            publish_data = mock_room.local_participant.publish_data

            try:
                # Reset before every send: ``called`` is sticky, so without
                # a reset the later assertions could never fail.
                publish_data.reset_mock()
                await conversation.send_user_message("Hello, agent!")
                assert publish_data.called

                # Test sending contextual update
                publish_data.reset_mock()
                await conversation.send_contextual_update("Context update")
                assert publish_data.called

                # Test registering user activity
                publish_data.reset_mock()
                await conversation.register_user_activity()
                assert publish_data.called
            finally:
                # Release the session even when an assertion fails.
                await conversation.end_session()

    def test_webrtc_connection_creation(self):
        """Test WebRTC connection creation and configuration."""
        # Test with conversation token
        connection = WebRTCConnection(conversation_token="test-token")
        assert connection.conversation_token == "test-token"

        # Test with agent ID
        connection = WebRTCConnection(agent_id="test-agent")
        assert connection.agent_id == "test-agent"

    @pytest.mark.asyncio
    async def test_webrtc_connection_token_fetch(self):
        """Test fetching conversation token from API."""
        with patch('httpx.AsyncClient') as mock_client_class:
            mock_client = AsyncMock()
            mock_response = Mock()
            mock_response.is_success = True
            mock_response.json.return_value = {"token": "fetched-token"}
            mock_client.get.return_value = mock_response
            mock_client_class.return_value.__aenter__.return_value = mock_client

            connection = WebRTCConnection(agent_id="test-agent")
            token = await connection._fetch_conversation_token()

            assert token == "fetched-token"
            # Verify the call was made with required parameters
            call_args = mock_client.get.call_args[0][0]
            assert call_args.startswith("https://api.elevenlabs.io/v1/convai/conversation/token?")
            assert "agent_id=test-agent" in call_args
            assert "source=python_sdk" in call_args
            assert "version=" in call_args

    @pytest.mark.asyncio
    async def test_webrtc_connection_token_fetch_error(self):
        """Test error handling when fetching conversation token."""
        with patch('httpx.AsyncClient') as mock_client_class:
            mock_client = AsyncMock()
            mock_response = Mock()
            mock_response.is_success = False
            mock_response.status_code = 404
            mock_response.text = "Not Found"
            mock_client.get.return_value = mock_response
            mock_client_class.return_value.__aenter__.return_value = mock_client

            connection = WebRTCConnection(agent_id="test-agent")

            with pytest.raises(Exception, match="Failed to fetch conversation token"):
                await connection._fetch_conversation_token()

    def test_factory_validation(self, mock_client):
        """Test validation in factory functions."""
        from elevenlabs.conversational_ai.conversation import AudioInterface

        # Should raise error for wrong audio interface type with WebRTC
        sync_audio = Mock(spec=AudioInterface)
        with pytest.raises(ValueError, match="WebRTC conversations require an AsyncAudioInterface"):
            create_conversation(
                client=mock_client,
                agent_id="test-agent",
                connection_type=ConnectionType.WEBRTC,
                audio_interface=sync_audio
            )