diff --git a/src/diarize/__init__.py b/src/diarize/__init__.py
index fc1574f..68a907c 100644
--- a/src/diarize/__init__.py
+++ b/src/diarize/__init__.py
@@ -23,6 +23,7 @@
 from .clustering import cluster_speakers, estimate_speakers  # noqa: F401
 from .embeddings import extract_embeddings
 from .utils import (
+    DiarizeArtifacts,
     DiarizeResult,
     Segment,
     SpeakerEstimationDetails,  # noqa: F401
@@ -45,6 +46,7 @@ class _RawSegment(NamedTuple):
 __all__ = [
     "diarize",
     "DiarizeResult",
+    "DiarizeArtifacts",
     "Segment",
     "SpeakerEstimationDetails",
     "estimate_speakers",
@@ -129,6 +131,7 @@ def diarize(
     min_speakers: int = 1,
     max_speakers: int = 20,
     num_speakers: int | None = None,
+    return_artifacts: bool = False,
 ) -> DiarizeResult:
     """Run the full speaker diarization pipeline on an audio file.
 
@@ -146,6 +149,9 @@ def diarize(
         max_speakers: Maximum number of speakers for auto-detection.
         num_speakers: If set, skip auto-detection and use this exact
             number of speakers.
+        return_artifacts: If ``True``, set :attr:`DiarizeResult.artifacts`
+            to a :class:`DiarizeArtifacts` holding the embeddings and
+            subsegments; if ``False`` (the default) it is ``None``.
 
     Returns:
         :class:`DiarizeResult` containing segments, speaker info, and
@@ -202,6 +208,10 @@ def diarize(
         audio_path=audio_path_str,
         audio_duration=duration,
         estimation_details=estimation_details,
+        artifacts=DiarizeArtifacts(
+            embeddings=embeddings.tolist(),
+            subsegments=subsegments,
+        ) if return_artifacts else None,  # built only when requested
     )
 
     logger.info(
diff --git a/src/diarize/utils.py b/src/diarize/utils.py
index 4a74188..7b5b40c 100644
--- a/src/diarize/utils.py
+++ b/src/diarize/utils.py
@@ -11,6 +11,7 @@
 from typing import Iterator
 
 import soundfile as sf
+
 from pydantic import BaseModel, Field, computed_field, model_validator
 
 logger = logging.getLogger(__name__)
@@ -21,6 +22,7 @@
     "SubSegment",
     "SpeakerEstimationDetails",
     "DiarizeResult",
+    "DiarizeArtifacts",
     "get_audio_duration",
     "format_timestamp",
 ]
@@ -121,6 +123,26 @@ class SpeakerEstimationDetails(BaseModel):
     cosine_sim_p10: float | None = None
 
 
+class DiarizeArtifacts(BaseModel):
+    """Internal pipeline artifacts, available when ``return_artifacts=True``.
+
+    Embeddings are nested lists of floats so the model serializes cleanly.
+
+    Attributes:
+        embeddings: Speaker embedding vectors (one per subsegment).
+        subsegments: Embedding windows used during extraction.
+    """
+
+    embeddings: list[list[float]] = Field(
+        default_factory=list,
+        description="Speaker embedding vectors as nested lists",
+    )
+    subsegments: list[SubSegment] = Field(
+        default_factory=list,
+        description="Embedding windows used during extraction",
+    )
+
+
 class DiarizeResult(BaseModel):
     """Result of speaker diarization.
 
@@ -145,6 +167,7 @@ class DiarizeResult(BaseModel):
     audio_path: str = ""
     audio_duration: float = Field(default=0.0, ge=0)
     estimation_details: SpeakerEstimationDetails | None = None
+    artifacts: DiarizeArtifacts | None = None
 
     @computed_field  # type: ignore[prop-decorator]
     @property