Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/diarize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .clustering import cluster_speakers, estimate_speakers # noqa: F401
from .embeddings import extract_embeddings
from .utils import (
DiarizeArtifacts,
DiarizeResult,
Segment,
SpeakerEstimationDetails, # noqa: F401
Expand All @@ -45,6 +46,7 @@ class _RawSegment(NamedTuple):
__all__ = [
"diarize",
"DiarizeResult",
"DiarizeArtifacts",
"Segment",
"SpeakerEstimationDetails",
"estimate_speakers",
Expand Down Expand Up @@ -129,6 +131,7 @@ def diarize(
min_speakers: int = 1,
max_speakers: int = 20,
num_speakers: int | None = None,
return_artifacts: bool = False,
) -> DiarizeResult:
"""Run the full speaker diarization pipeline on an audio file.

Expand All @@ -146,6 +149,9 @@ def diarize(
max_speakers: Maximum number of speakers for auto-detection.
num_speakers: If set, skip auto-detection and use this exact
number of speakers.
return_artifacts: If ``True``, populate
:attr:`DiarizeResult.artifacts` with speaker embeddings and
subsegments. Defaults to ``False``.

Returns:
:class:`DiarizeResult` containing segments, speaker info, and
Expand Down Expand Up @@ -202,6 +208,10 @@ def diarize(
audio_path=audio_path_str,
audio_duration=duration,
estimation_details=estimation_details,
artifacts=DiarizeArtifacts(
embeddings=embeddings.tolist(),
subsegments=subsegments,
) if return_artifacts else None,
)

logger.info(
Expand Down
23 changes: 23 additions & 0 deletions src/diarize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Iterator

import soundfile as sf

from pydantic import BaseModel, Field, computed_field, model_validator

logger = logging.getLogger(__name__)
Expand All @@ -21,6 +22,7 @@
"SubSegment",
"SpeakerEstimationDetails",
"DiarizeResult",
"DiarizeArtifacts",
"get_audio_duration",
"format_timestamp",
]
Expand Down Expand Up @@ -121,6 +123,26 @@ class SpeakerEstimationDetails(BaseModel):
cosine_sim_p10: float | None = None


class DiarizeArtifacts(BaseModel):
"""Internal pipeline artifacts, available when ``return_artifacts=True``.

All fields use plain Python types for clean serialization.

Attributes:
embeddings: Speaker embedding vectors (one per subsegment).
subsegments: Embedding windows used during extraction.
"""

embeddings: list[list[float]] = Field(
default_factory=list,
description="Speaker embedding vectors as nested lists",
)
subsegments: list[SubSegment] = Field(
default_factory=list,
description="Embedding windows used during extraction",
)


class DiarizeResult(BaseModel):
"""Result of speaker diarization.

Expand All @@ -145,6 +167,7 @@ class DiarizeResult(BaseModel):
audio_path: str = ""
audio_duration: float = Field(default=0.0, ge=0)
estimation_details: SpeakerEstimationDetails | None = None
artifacts: DiarizeArtifacts | None = None

@computed_field # type: ignore[prop-decorator]
@property
Expand Down