diff --git a/README.md b/README.md index 3dd6b26..403b297 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ for seg in result.segments: print(f" [{seg.start:.1f}s - {seg.end:.1f}s] {seg.speaker}") ``` -**~5.2% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers. +**~5.0% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers. > Benchmarked on a single dataset ([VoxConverse](https://github.com/joonson/voxconverse)). Cross-dataset validation is [in progress](#roadmap). @@ -35,7 +35,7 @@ for seg in result.segments: | GPU required | No | No (7x slower on CPU) | No | | HuggingFace account | No | Yes | Yes | | Auto speaker count | Yes | Yes | Yes | -| DER (VoxConverse dev) | **~5.2%** | ~11.2% | ~8.5% | +| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% | | CPU speed (RTF) | **0.12** | 0.86 | — | | Install | `pip install diarize` | `pip install pyannote.audio` | `pip install pyannote.audio` | @@ -89,7 +89,7 @@ Four-stage pipeline, all CPU, all open-source: 1. **Silero VAD** (MIT) — detects speech segments 2. **WeSpeaker ResNet34-LM** (Apache 2.0) — extracts 256-dim speaker embeddings via ONNX 3. **GMM BIC + silhouette refinement** — estimates the number of speakers -4. **Spectral Clustering** (scikit-learn, BSD) + smoothing — assigns speaker labels +4. 
**Spectral Clustering** (scikit-learn, BSD) + temporal smoothing — assigns speaker labels Details: [How It Works](https://foxnosetech.github.io/diarize/how-it-works/) @@ -102,7 +102,7 @@ Evaluated on [VoxConverse](https://github.com/joonson/voxconverse) dev set (216 | System | Weighted DER | Notes | |--------|----------|-------| | pyannote precision-2 | ~8.5% | Commercial license | -| **diarize** | **~5.2%** | **Apache 2.0, CPU-only, no API key** | +| **diarize** | **~5.0%** | **Apache 2.0, CPU-only, no API key** | | pyannote community-1 | ~11.2% | CC-BY-4.0, needs HF token | | pyannote 3.1 (legacy) | ~11.2% | MIT, needs HF token | @@ -121,7 +121,7 @@ Full benchmark results, speed comparison, and methodology: [benchmarks](https:// ## When to use something else - **You need commercial support or cross-dataset validation.** pyannote's commercial model has published production-oriented benchmarks beyond this single VoxConverse evaluation. If accuracy is the top priority and you have budget, compare on your own data. -- **You need very stable speaker labels in transcripts.** diarize can still show speaker fragmentation / label switching: one real speaker may be split across multiple `SPEAKER_XX` labels, or the label may briefly jump inside a continuous turn, especially on noisy real-world audio. +- **You need very stable speaker labels in transcripts.** Temporal smoothing reduces short label jumps, but diarize can still show speaker fragmentation / label switching: one real speaker may be split across multiple `SPEAKER_XX` labels, especially on noisy real-world audio. - **Your audio has 8+ speakers.** Automatic speaker count estimation degrades above 7 speakers. You can pass `num_speakers` explicitly, but test carefully. - **You need overlapping speech detection.** diarize assigns each segment to one speaker. Overlapping speech is not modeled. - **You need GPU-accelerated throughput.** diarize is CPU-only by design. 
For processing thousands of hours with GPU infrastructure, NeMo or pyannote on GPU will be faster. diff --git a/docs/benchmarks.md b/docs/benchmarks.md index af6037a..9dd8004 100644 --- a/docs/benchmarks.md +++ b/docs/benchmarks.md @@ -23,7 +23,7 @@ DER is the standard metric for speaker diarization, computed with | System | Weighted DER | Median DER | Notes | |--------|----------|------------|-------| | pyannote precision-2 | ~8.5% | -- | Commercial license | -| **diarize** | **~5.2%** | **~2.4%** | **Apache 2.0, CPU-only, no API key** | +| **diarize** | **~5.0%** | **~2.2%** | **Apache 2.0, CPU-only, no API key** | | pyannote community-1 | ~11.2% | -- | CC-BY-4.0, needs HF token | | pyannote 3.1 (legacy) | ~11.2% | -- | MIT, needs HF token | @@ -94,10 +94,10 @@ Measured on VoxConverse dev files on Apple M2 Pro / M2 Max - **Many speakers (8+):** Automatic speaker count estimation degrades. Use ``num_speakers`` when the speaker count is known. -- **Speaker label switching / fragmentation:** On noisy real-world audio, - one actual speaker can be split across multiple ``SPEAKER_XX`` labels, - or the label can briefly jump inside a continuous turn. This is mostly - a clustering and embedding-assignment limitation, and it is visible in +- **Speaker label switching / fragmentation:** Temporal smoothing reduces + short label jumps, but on noisy real-world audio one actual speaker can + still be split across multiple ``SPEAKER_XX`` labels. This is mostly a + clustering and embedding-assignment limitation, and it is visible in transcripts even when aggregate DER looks acceptable. - **Overlapping speech:** DER is computed with ``skip_overlap=True``. 
The pipeline does not model overlapping speech --- when two people diff --git a/docs/how-it-works.md b/docs/how-it-works.md index ed859a5..463b220 100644 --- a/docs/how-it-works.md +++ b/docs/how-it-works.md @@ -18,7 +18,7 @@ Audio File [3] GMM BIC + silhouette -> Estimated speaker count (k) | v -[4] Spectral + smoothing -> Speaker timeline +[4] Spectral + temporal smoothing -> Speaker timeline | v DiarizeResult @@ -97,11 +97,14 @@ scikit-learn's `SpectralClustering`. The initial spectral labels are refined with spherical centroid reassignment over L2-normalised embeddings. This preserves the selected -speaker count while reducing unstable one-window label flips. +speaker count while reducing unstable label assignments. For long VAD segments, overlapping embedding windows are decoded into -non-overlapping timeline intervals using window-center midpoints. A -3-window majority filter smooths local label noise. Adjacent intervals +non-overlapping timeline intervals using window-center midpoints. Within +each VAD segment, centroid-scored Viterbi decoding applies a speaker +switch penalty to remove short label flicker. Sustained original label +runs are restored before final island cleanup, so the smoother targets +brief A-B-A fragments rather than real speaker turns. Adjacent intervals assigned to the same speaker are merged, and short segments that were skipped during embedding extraction are assigned the label of the nearest speaker. 
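The temporal smoothing described in the how-it-works changes above — centroid-scored Viterbi decoding with a speaker-switch penalty, followed by cleanup of brief A-B-A fragments — can be sketched in isolation. This is a minimal illustration, not the package's exact code: `viterbi_smooth` and `collapse_short_islands` are simplified standalone versions, and the `0.18` switch penalty and `1.2` s island cap mirror the constants introduced in this diff.

```python
import numpy as np


def viterbi_smooth(scores: np.ndarray, switch_penalty: float = 0.18) -> list[int]:
    """Best label path when every speaker switch costs `switch_penalty`.

    `scores[t, k]` is the affinity (e.g. cosine similarity) of window t
    to speaker centroid k.
    """
    n_frames, n_labels = scores.shape
    if n_frames == 0:
        return []
    if n_labels <= 1:
        return [0] * n_frames

    transition = np.full((n_labels, n_labels), -switch_penalty)
    np.fill_diagonal(transition, 0.0)  # staying on the same speaker is free

    dp = np.empty((n_frames, n_labels))
    back = np.zeros((n_frames, n_labels), dtype=int)
    dp[0] = scores[0]
    for t in range(1, n_frames):
        prev = dp[t - 1][:, None] + transition  # rows index the previous label
        back[t] = np.argmax(prev, axis=0)
        dp[t] = scores[t] + np.max(prev, axis=0)

    path = [int(np.argmax(dp[-1]))]
    for t in range(n_frames - 1, 0, -1):
        path.append(int(back[t, path[-1]]))
    path.reverse()
    return path


def collapse_short_islands(labels, windows, max_duration=1.2):
    """Relabel an interior run when its neighbours agree and it is short."""
    if len(labels) < 3 or len(labels) != len(windows):
        return labels

    runs, start = [], 0  # runs of identical labels: (start, end, label, duration)
    for i in range(1, len(labels) + 1):
        if i == len(labels) or labels[i] != labels[start]:
            runs.append((start, i, labels[start], windows[i - 1][1] - windows[start][0]))
            start = i

    out = list(labels)
    for k in range(1, len(runs) - 1):
        s, e, lab, dur = runs[k]
        if runs[k - 1][2] == runs[k + 1][2] != lab and dur <= max_duration:
            out[s:e] = [runs[k - 1][2]] * (e - s)
    return out


# A lone middle flip survives only if its score margin beats two switch
# penalties: here 0.55 - 0.45 = 0.10 < 2 * 0.18, so the flip is removed...
print(viterbi_smooth(np.array([[0.9, 0.1], [0.45, 0.55], [0.9, 0.1]])))  # [0, 0, 0]
# ...while a sustained speaker change is kept.
print(viterbi_smooth(np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])))  # [0, 1, 1]

windows = [(0.0, 0.6), (0.6, 1.2), (1.2, 1.8), (1.8, 2.4)]
print(collapse_short_islands([0, 1, 0, 0], windows))  # [0, 0, 0, 0]
```

In the pipeline the scores come from L2-normalised embeddings scored against per-speaker centroids, and the island pass runs only after sustained original label runs have been restored, so real speaker turns are not collapsed.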
diff --git a/docs/index.md b/docs/index.md index 4d30013..6dbac20 100644 --- a/docs/index.md +++ b/docs/index.md @@ -29,7 +29,7 @@ for seg in result.segments: | GPU required | No | No (7x slower on CPU) | No | | HuggingFace account | No | Yes | Yes | | Auto speaker count | Yes | Yes | Yes | -| DER (VoxConverse dev) | **~5.2%** | ~11.2% | ~8.5% | +| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% | | CPU speed (RTF) | **0.12** | 0.86 | --- | DER and speed numbers for pyannote are from their diff --git a/src/diarize/__init__.py b/src/diarize/__init__.py index 240b6d8..9c1a383 100644 --- a/src/diarize/__init__.py +++ b/src/diarize/__init__.py @@ -54,6 +54,9 @@ class _RawSegment(NamedTuple): logger = logging.getLogger(__name__) +_TEMPORAL_SWITCH_PENALTY = 0.18 +_MAX_FRAGMENT_DURATION = 1.2 + def _majority_label(labels: list[int]) -> int | None: """Return the unique majority label, or ``None`` on ties.""" @@ -81,6 +84,159 @@ def _smooth_window_labels(labels: list[int]) -> list[int]: return smoothed +def _normalize_rows(values: np.ndarray) -> np.ndarray: + """Return row-wise L2-normalised values, preserving zero rows.""" + norms = np.linalg.norm(values, axis=1, keepdims=True) + return np.divide( + values, + norms, + out=np.zeros_like(values, dtype=float), + where=norms > 0, + ) + + +def _speaker_centroids( + embeddings: np.ndarray, + labels: np.ndarray, +) -> tuple[list[int], np.ndarray]: + """Build L2-normalised speaker centroids from clustered embeddings.""" + if len(embeddings) == 0 or len(embeddings) != len(labels): + return [], np.empty((0, 0), dtype=float) + + norm_embeddings = _normalize_rows(np.asarray(embeddings, dtype=float)) + label_values = sorted({int(label) for label in labels}) + centroids: list[np.ndarray] = [] + valid_labels: list[int] = [] + + for label in label_values: + members = norm_embeddings[labels == label] + if len(members) == 0: # pragma: no cover - label_values are derived from labels. 
+ continue + centroid = members.mean(axis=0, keepdims=True) + centroid = _normalize_rows(centroid)[0] + if not np.any(centroid): + continue + valid_labels.append(label) + centroids.append(centroid) + + if not centroids: + return [], np.empty((0, norm_embeddings.shape[1]), dtype=float) + + return valid_labels, np.vstack(centroids) + + +def _viterbi_smooth_scores( + scores: np.ndarray, + *, + switch_penalty: float = _TEMPORAL_SWITCH_PENALTY, +) -> list[int]: + """Find the best label path with a penalty for short speaker switches.""" + n_frames, n_labels = scores.shape + if n_frames == 0: + return [] + if n_labels <= 1: + return [0] * n_frames + + dp = np.empty((n_frames, n_labels), dtype=float) + back = np.zeros((n_frames, n_labels), dtype=int) + dp[0] = scores[0] + + transition = np.full((n_labels, n_labels), -switch_penalty, dtype=float) + np.fill_diagonal(transition, 0.0) + + for frame_idx in range(1, n_frames): + previous = dp[frame_idx - 1][:, None] + transition + back[frame_idx] = np.argmax(previous, axis=0) + dp[frame_idx] = scores[frame_idx] + np.max(previous, axis=0) + + path = [int(np.argmax(dp[-1]))] + for frame_idx in range(n_frames - 1, 0, -1): + path.append(int(back[frame_idx, path[-1]])) + path.reverse() + return path + + +def _smooth_window_labels_temporal( + labels: list[int], + embeddings: np.ndarray, + indices: list[int], + label_values: list[int], + centroids: np.ndarray, +) -> list[int]: + """Smooth a single VAD segment using centroid scores and Viterbi decoding.""" + if len(labels) < 3 or len(label_values) <= 1: + return labels + + label_to_idx = {label: idx for idx, label in enumerate(label_values)} + norm_embeddings = _normalize_rows(np.asarray(embeddings[indices], dtype=float)) + scores = norm_embeddings @ centroids.T + + # Small anchor to the original clustering label keeps confident labels + # stable while the transition penalty removes low-value label flicker. 
+ for row_idx, label in enumerate(labels): + label_idx = label_to_idx.get(label) + if label_idx is not None: + scores[row_idx, label_idx] += 0.02 + + path = _viterbi_smooth_scores(scores) + return [label_values[state] for state in path] + + +def _collapse_short_label_islands( + labels: list[int], + windows: list[tuple[float, float]], + *, + max_duration: float = _MAX_FRAGMENT_DURATION, +) -> list[int]: + """Collapse short A-B-A label islands inside a continuous speech segment.""" + if len(labels) < 3 or len(labels) != len(windows): + return labels + + runs: list[tuple[int, int, int, float]] = [] + run_start = 0 + for idx in range(1, len(labels) + 1): + if idx == len(labels) or labels[idx] != labels[run_start]: + duration = windows[idx - 1][1] - windows[run_start][0] + runs.append((run_start, idx, labels[run_start], duration)) + run_start = idx + + if len(runs) < 3: + return labels + + smoothed = labels.copy() + for run_idx in range(1, len(runs) - 1): + start, end, label, duration = runs[run_idx] + _, _, previous_label, _ = runs[run_idx - 1] + _, _, next_label, _ = runs[run_idx + 1] + if previous_label == next_label and label != previous_label and duration <= max_duration: + smoothed[start:end] = [previous_label] * (end - start) + + return smoothed + + +def _restore_sustained_label_runs( + original_labels: list[int], + smoothed_labels: list[int], + windows: list[tuple[float, float]], + *, + min_duration: float = _MAX_FRAGMENT_DURATION, +) -> list[int]: + """Keep sustained original runs even when Viterbi prefers fewer switches.""" + if len(original_labels) != len(smoothed_labels) or len(original_labels) != len(windows): + return smoothed_labels + + restored = smoothed_labels.copy() + run_start = 0 + for idx in range(1, len(original_labels) + 1): + if idx == len(original_labels) or original_labels[idx] != original_labels[run_start]: + duration = windows[idx - 1][1] - windows[run_start][0] + if duration > min_duration: + restored[run_start:idx] = 
original_labels[run_start:idx] + run_start = idx + + return restored + + def _window_boundaries( speech_segment: SpeechSegment, segment_subsegments: list[SubSegment], @@ -103,17 +259,19 @@ def _build_diarization_segments( speech_segments: list[SpeechSegment], subsegments: list[SubSegment], labels: np.ndarray, + embeddings: np.ndarray | None = None, ) -> list[Segment]: """Assemble diarization segments from subsegments and cluster labels. Overlapping embedding windows are converted to a non-overlapping - timeline and smoothed with a local majority filter. VAD segments + timeline and smoothed with temporal label decoding. VAD segments without embeddings are assigned the nearest speaker. Args: speech_segments: Original speech segments from VAD. subsegments: Embedding windows with parent indices. labels: Cluster labels aligned with *subsegments*. + embeddings: Optional speaker embeddings aligned with *subsegments*. Returns: Merged :class:`Segment` list sorted by start time. @@ -123,15 +281,39 @@ def _build_diarization_segments( for idx, sub in enumerate(subsegments): subsegments_by_parent.setdefault(sub.parent_idx, []).append(idx) + label_values: list[int] = [] + centroids = np.empty((0, 0), dtype=float) + if embeddings is not None: + label_values, centroids = _speaker_centroids(embeddings, labels) + for parent_idx, speech_segment in enumerate(speech_segments): indices = subsegments_by_parent.get(parent_idx) if not indices: continue indices.sort(key=lambda idx: subsegments[idx].start) - parent_labels = _smooth_window_labels([int(labels[idx]) for idx in indices]) + parent_labels = [int(labels[idx]) for idx in indices] parent_subsegments = [subsegments[idx] for idx in indices] windows = _window_boundaries(speech_segment, parent_subsegments) + + if embeddings is not None and len(label_values) > 1: + original_labels = parent_labels + parent_labels = _smooth_window_labels_temporal( + parent_labels, + embeddings, + indices, + label_values, + centroids, + ) + parent_labels = 
_restore_sustained_label_runs( + original_labels, + parent_labels, + windows, + ) + parent_labels = _collapse_short_label_islands(parent_labels, windows) + else: + parent_labels = _smooth_window_labels(parent_labels) + for (start, end), label in zip(windows, parent_labels): if end <= start: continue @@ -258,6 +440,7 @@ def diarize( speech_segments, subsegments, labels, + embeddings, ) result = DiarizeResult( diff --git a/tests/test_diarize.py b/tests/test_diarize.py index 581d55f..dec799d 100644 --- a/tests/test_diarize.py +++ b/tests/test_diarize.py @@ -660,6 +660,29 @@ def test_majority_helpers_handle_empty_and_ties(self): assert _smooth_window_labels([0, 1]) == [0, 1] assert _smooth_window_labels([0, 1, 1]) == [0, 1, 1] + def test_temporal_helper_edge_cases(self): + from diarize import ( + _collapse_short_label_islands, + _restore_sustained_label_runs, + _speaker_centroids, + _viterbi_smooth_scores, + ) + + labels, centroids = _speaker_centroids(np.ones((2, 2)), np.array([0])) + assert labels == [] + assert centroids.size == 0 + + labels, centroids = _speaker_centroids(np.zeros((2, 2)), np.array([0, 0])) + assert labels == [] + assert centroids.shape == (0, 2) + + assert _viterbi_smooth_scores(np.empty((0, 2))) == [] + assert _viterbi_smooth_scores(np.ones((3, 1))) == [0, 0, 0] + + windows = [(0.0, 0.6), (0.6, 1.2), (1.2, 1.8)] + assert _collapse_short_label_islands([0, 1, 0], windows) == [0, 0, 0] + assert _restore_sustained_label_runs([0], [1], []) == [1] + def test_window_boundaries_empty_and_degenerate_windows(self): from diarize import _build_diarization_segments, _window_boundaries from diarize.utils import SpeechSegment, SubSegment @@ -751,6 +774,96 @@ def test_long_segment_windows_become_non_overlapping(self): assert len(segments) >= 2 assert all(a.end <= b.start for a, b in zip(segments, segments[1:])) + def test_temporal_smoothing_collapses_single_turn_label_switches(self): + from diarize import _build_diarization_segments + from diarize.utils import 
SpeechSegment, SubSegment + + speech = [SpeechSegment(start=0.0, end=4.2)] + subs = [ + SubSegment(start=0.0, end=1.2, parent_idx=0), + SubSegment(start=0.6, end=1.8, parent_idx=0), + SubSegment(start=1.2, end=2.4, parent_idx=0), + SubSegment(start=1.8, end=3.0, parent_idx=0), + SubSegment(start=2.4, end=3.6, parent_idx=0), + ] + labels = np.array([0, 1, 0, 1, 0]) + embeddings = np.array( + [ + [1.00, 0.00], + [0.99, 0.01], + [1.00, 0.02], + [0.98, 0.01], + [1.00, 0.00], + ] + ) + + segments = _build_diarization_segments(speech, subs, labels, embeddings) + + assert len(segments) == 1 + assert segments[0].speaker == "SPEAKER_00" + + def test_temporal_smoothing_preserves_sustained_speaker_change(self): + from diarize import _build_diarization_segments + from diarize.utils import SpeechSegment, SubSegment + + speech = [SpeechSegment(start=0.0, end=4.2)] + subs = [ + SubSegment(start=0.0, end=1.2, parent_idx=0), + SubSegment(start=0.6, end=1.8, parent_idx=0), + SubSegment(start=1.2, end=2.4, parent_idx=0), + SubSegment(start=1.8, end=3.0, parent_idx=0), + SubSegment(start=2.4, end=3.6, parent_idx=0), + ] + labels = np.array([0, 0, 1, 1, 1]) + embeddings = np.array( + [ + [1.00, 0.00], + [0.98, 0.02], + [0.00, 1.00], + [0.01, 0.99], + [0.00, 1.00], + ] + ) + + segments = _build_diarization_segments(speech, subs, labels, embeddings) + + assert [segment.speaker for segment in segments] == ["SPEAKER_00", "SPEAKER_01"] + + def test_temporal_smoothing_keeps_sustained_interior_run(self): + from diarize import _build_diarization_segments + from diarize.utils import SpeechSegment, SubSegment + + speech = [SpeechSegment(start=0.0, end=5.4)] + subs = [ + SubSegment(start=0.0, end=1.2, parent_idx=0), + SubSegment(start=0.6, end=1.8, parent_idx=0), + SubSegment(start=1.2, end=2.4, parent_idx=0), + SubSegment(start=1.8, end=3.0, parent_idx=0), + SubSegment(start=2.4, end=3.6, parent_idx=0), + SubSegment(start=3.0, end=4.2, parent_idx=0), + SubSegment(start=3.6, end=4.8, 
parent_idx=0), + ] + labels = np.array([1, 1, 0, 0, 0, 1, 1]) + embeddings = np.array( + [ + [1.00, 0.00], + [0.99, 0.01], + [0.98, 0.02], + [0.99, 0.01], + [1.00, 0.00], + [0.99, 0.01], + [1.00, 0.00], + ] + ) + + segments = _build_diarization_segments(speech, subs, labels, embeddings) + + assert [segment.speaker for segment in segments] == [ + "SPEAKER_01", + "SPEAKER_00", + "SPEAKER_01", + ] + def test_short_segment_assigned_nearest_speaker(self): """VAD segments without embeddings should get the nearest speaker.""" from diarize import _build_diarization_segments