diff --git a/livekit-rtc/livekit/rtc/__init__.py b/livekit-rtc/livekit/rtc/__init__.py index 565b8882..f641bfd9 100644 --- a/livekit-rtc/livekit/rtc/__init__.py +++ b/livekit-rtc/livekit/rtc/__init__.py @@ -108,6 +108,7 @@ ByteStreamWriter, ByteStreamReader, ) +from .frame_processor import FrameProcessor __all__ = [ "ConnectionQuality", @@ -184,6 +185,7 @@ "ByteStreamReader", "ByteStreamWriter", "AudioProcessingModule", + "FrameProcessor", "__version__", ] diff --git a/livekit-rtc/livekit/rtc/audio_stream.py b/livekit-rtc/livekit/rtc/audio_stream.py index b33e668f..b2e8bbaa 100644 --- a/livekit-rtc/livekit/rtc/audio_stream.py +++ b/livekit-rtc/livekit/rtc/audio_stream.py @@ -27,6 +27,7 @@ from .audio_frame import AudioFrame from .participant import Participant from .track import Track +from .frame_processor import FrameProcessor @dataclass @@ -62,7 +63,7 @@ def __init__( sample_rate: int = 48000, num_channels: int = 1, frame_size_ms: int | None = None, - noise_cancellation: Optional[NoiseCancellationOptions] = None, + noise_cancellation: Optional[NoiseCancellationOptions | FrameProcessor[AudioFrame]] = None, **kwargs, ) -> None: """Initialize an `AudioStream` instance. @@ -76,8 +77,8 @@ def __init__( sample_rate (int, optional): The sample rate for the audio stream in Hz. Defaults to 48000. num_channels (int, optional): The number of audio channels. Defaults to 1. - noise_cancellation (Optional[NoiseCancellationOptions], optional): - If noise cancellation is used, pass a `NoiseCancellationOptions` instance + noise_cancellation (Optional[NoiseCancellationOptions | FrameProcessor[AudioFrame]], optional): + If noise cancellation is used, pass a `NoiseCancellationOptions` or `FrameProcessor[AudioFrame]` instance created by the noise cancellation module. 
Example:
@@ -105,9 +106,12 @@ def __init__(
 
         self._audio_filter_module = None
         self._audio_filter_options = None
-        if noise_cancellation is not None:
+        self._processor: FrameProcessor[AudioFrame] | None = None  # always set: _run() reads it
+        if isinstance(noise_cancellation, NoiseCancellationOptions):
             self._audio_filter_module = noise_cancellation.module_id
             self._audio_filter_options = noise_cancellation.options
+        elif isinstance(noise_cancellation, FrameProcessor):
+            self._processor = noise_cancellation
         self._task = self._loop.create_task(self._run())
         self._task.add_done_callback(task_done_logger)
 
@@ -132,7 +136,7 @@ def from_participant(
         sample_rate: int = 48000,
         num_channels: int = 1,
         frame_size_ms: int | None = None,
-        noise_cancellation: Optional[NoiseCancellationOptions] = None,
+        noise_cancellation: Optional[NoiseCancellationOptions | FrameProcessor[AudioFrame]] = None,
     ) -> AudioStream:
         """Create an `AudioStream` from a participant's audio track.
 
@@ -182,7 +186,7 @@ def from_track(
         sample_rate: int = 48000,
         num_channels: int = 1,
         frame_size_ms: int | None = None,
-        noise_cancellation: Optional[NoiseCancellationOptions] = None,
+        noise_cancellation: Optional[NoiseCancellationOptions | FrameProcessor[AudioFrame]] = None,
     ) -> AudioStream:
         """Create an `AudioStream` from an existing audio track.
 
@@ -268,6 +272,8 @@ async def _run(self):
             if audio_event.HasField("frame_received"):
                 owned_buffer_info = audio_event.frame_received.frame
                 frame = AudioFrame._from_owned_info(owned_buffer_info)
+                if self._processor is not None:
+                    frame = self._processor._process(frame)
                 event = AudioFrameEvent(frame)
                 self._queue.put(event)
             elif audio_event.HasField("eos"):
diff --git a/livekit-rtc/livekit/rtc/frame_processor.py b/livekit-rtc/livekit/rtc/frame_processor.py
new file mode 100644
index 00000000..8704e82f
--- /dev/null
+++ b/livekit-rtc/livekit/rtc/frame_processor.py
@@ -0,0 +1,54 @@
+from abc import ABC, abstractmethod
+from typing import Generic, TypeVar, Union
+from .audio_frame import AudioFrame
+from .video_frame import VideoFrame
+
+
+# Frame type handled by a processor: an audio frame or a video frame.
+T = TypeVar("T", bound=Union[AudioFrame, VideoFrame])
+
+
+class FrameProcessor(Generic[T], ABC):
+    """Abstract interface for objects that transform media frames.
+
+    `AudioStream` accepts a `FrameProcessor[AudioFrame]` as its `noise_cancellation`
+    argument and replaces each received frame with the result of `_process`.
+
+    The underscore-prefixed methods look like hooks meant for the SDK rather than for
+    application code; only `_process` has a caller in `AudioStream` so far.
+    """
+
+    @property
+    @abstractmethod
+    def is_enabled(self) -> bool:
+        """Return whether the processor is currently enabled.
+
+        `AudioStream` does not check this before calling `_process`, so honoring the
+        flag is presumably left to the implementation.
+        """
+
+    @abstractmethod
+    def set_enabled(self, enable: bool):
+        """Enable (`True`) or disable (`False`) the processor."""
+
+    @abstractmethod
+    def _update_stream_info(
+        self,
+        *,
+        room_name: str,
+        participant_identity: str,
+        publication_sid: str,
+    ):
+        """Receive the room, participant and publication identifying the stream."""
+
+    @abstractmethod
+    def _update_credentials(self, *, token: str, url: str):
+        """Receive a new token and URL (presumably for the processor's own backend)."""
+
+    @abstractmethod
+    def _process(self, frame: T) -> T:
+        """Process one frame and return the frame to use in its place."""
+
+    @abstractmethod
+    def _close(self):
+        """Shut the processor down (presumably releasing its resources)."""