From b97374227792e07e4930189e7e2cd36b48c17284 Mon Sep 17 00:00:00 2001 From: loookashow Date: Mon, 4 May 2026 13:39:25 +0200 Subject: [PATCH 1/3] fix: reduce diarization speaker label noise --- src/diarize/__init__.py | 93 ++++++++++++++++++++++++++++++++------- src/diarize/clustering.py | 77 ++++++++++++++++++++++++++++++-- tests/test_diarize.py | 87 ++++++++++++++++++++++++++++++++++-- 3 files changed, 234 insertions(+), 23 deletions(-) diff --git a/src/diarize/__init__.py b/src/diarize/__init__.py index fc1574f..240b6d8 100644 --- a/src/diarize/__init__.py +++ b/src/diarize/__init__.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from collections import Counter from pathlib import Path from typing import NamedTuple @@ -54,6 +55,50 @@ class _RawSegment(NamedTuple): logger = logging.getLogger(__name__) +def _majority_label(labels: list[int]) -> int | None: + """Return the unique majority label, or ``None`` on ties.""" + if not labels: + return None + + counts = Counter(labels) + best_label, best_count = counts.most_common(1)[0] + if sum(1 for count in counts.values() if count == best_count) > 1: + return None + return int(best_label) + + +def _smooth_window_labels(labels: list[int]) -> list[int]: + """Apply a 3-window majority filter while preserving ties.""" + if len(labels) < 3: + return labels + + smoothed: list[int] = [] + for idx, label in enumerate(labels): + window = labels[max(0, idx - 1) : min(len(labels), idx + 2)] + majority = _majority_label(window) + smoothed.append(label if majority is None else majority) + + return smoothed + + +def _window_boundaries( + speech_segment: SpeechSegment, + segment_subsegments: list[SubSegment], +) -> list[tuple[float, float]]: + """Convert overlapping windows into non-overlapping intervals.""" + if not segment_subsegments: + return [] + + centers = [(sub.start + sub.end) / 2 for sub in segment_subsegments] + boundaries = [speech_segment.start] + for left, right in zip(centers, centers[1:]): + midpoint = (left + right) / 2 + boundaries.append(min(speech_segment.end, max(speech_segment.start, midpoint))) + boundaries.append(speech_segment.end) + + return list(zip(boundaries[:-1], boundaries[1:])) + + def _build_diarization_segments( speech_segments: list[SpeechSegment], subsegments: list[SubSegment], @@ -61,8 +106,8 @@ def _build_diarization_segments( ) -> list[Segment]: """Assemble diarization segments from subsegments and cluster labels. - Each subsegment (embedding window) gets its speaker label. Adjacent - subsegments from the same speaker are merged. Short VAD segments + Overlapping embedding windows are converted to a non-overlapping + timeline and smoothed with a local majority filter. VAD segments without embeddings are assigned the nearest speaker. Args: @@ -73,16 +118,30 @@ def _build_diarization_segments( Returns: Merged :class:`Segment` list sorted by start time. """ - # Build raw segments from subsegments + labels raw_segments: list[_RawSegment] = [] - for sub, label in zip(subsegments, labels): - raw_segments.append( - _RawSegment( - start=sub.start, - end=sub.end, - speaker=f"SPEAKER_{int(label):02d}", + subsegments_by_parent: dict[int, list[int]] = {} + for idx, sub in enumerate(subsegments): + subsegments_by_parent.setdefault(sub.parent_idx, []).append(idx) + + for parent_idx, speech_segment in enumerate(speech_segments): + indices = subsegments_by_parent.get(parent_idx) + if not indices: + continue + + indices.sort(key=lambda idx: subsegments[idx].start) + parent_labels = _smooth_window_labels([int(labels[idx]) for idx in indices]) + parent_subsegments = [subsegments[idx] for idx in indices] + windows = _window_boundaries(speech_segment, parent_subsegments) + for (start, end), label in zip(windows, parent_labels): + if end <= start: + continue + raw_segments.append( + _RawSegment( + start=start, + end=end, + speaker=f"SPEAKER_{label:02d}", + ) ) - ) # Add short VAD segments that were skipped during embedding extraction covered_indices = {sub.parent_idx for sub in subsegments} @@ -93,12 +152,12 @@ def _build_diarization_segments( seg_mid = (seg.start + seg.end) / 2 best_speaker = "SPEAKER_00" best_dist = float("inf") - for sub, label in zip(subsegments, labels): - sub_mid = (sub.start + sub.end) / 2 - dist = abs(seg_mid - sub_mid) + for raw in raw_segments: + raw_mid = (raw.start + raw.end) / 2 + dist = abs(seg_mid - raw_mid) if dist < best_dist: best_dist = dist - best_speaker = f"SPEAKER_{int(label):02d}" + best_speaker = raw.speaker raw_segments.append(_RawSegment(start=seg.start, end=seg.end, speaker=best_speaker)) # Sort by time @@ -195,7 +254,11 @@ def diarize( ) # 4. Build result - segments = _build_diarization_segments(speech_segments, subsegments, labels) + segments = _build_diarization_segments( + speech_segments, + subsegments, + labels, + ) result = DiarizeResult( segments=segments, diff --git a/src/diarize/clustering.py b/src/diarize/clustering.py index b789856..79dcc0c 100644 --- a/src/diarize/clustering.py +++ b/src/diarize/clustering.py @@ -173,10 +173,61 @@ def estimate_speakers( # ── Spectral Clustering ────────────────────────────────────────────────────── +def _refine_labels_spherical( + embeddings: np.ndarray, + labels: np.ndarray, + *, + max_iter: int = 8, +) -> np.ndarray: + """Refine cluster labels with spherical centroid reassignment.""" + n = len(embeddings) + if n == 0: + return labels + + unique_labels = sorted({int(label) for label in labels}) + k = len(unique_labels) + if k <= 1: + return np.zeros(n, dtype=int) + + label_map = {label: idx for idx, label in enumerate(unique_labels)} + refined = np.array([label_map[int(label)] for label in labels], dtype=int) + emb = normalize(embeddings, norm="l2") + + for _ in range(max_iter): + centroids = np.zeros((k, emb.shape[1]), dtype=float) + valid = np.zeros(k, dtype=bool) + for label in range(k): + members = emb[refined == label] + if len(members) == 0: + continue + centroid = members.mean(axis=0) + norm = float(np.linalg.norm(centroid)) + if norm == 0: + continue + centroids[label] = centroid / norm + valid[label] = True + + if not np.all(valid): + break + + scores = emb @ centroids.T + updated = np.argmax(scores, axis=1) + if len(set(updated)) < k: + break + if np.array_equal(updated, refined): + break + refined = updated + + return refined + + def cluster_spectral(embeddings: np.ndarray, k: int) -> np.ndarray: """Cluster embeddings into *k* speakers using Spectral Clustering. Uses cosine similarity as the affinity metric, rescaled to [0, 1]. + The spectral assignment is then refined with a few spherical + centroid-reassignment iterations, which reduces noisy window labels + while preserving the selected number of speakers. Args: embeddings: Speaker embeddings of shape ``(N, D)``. @@ -212,6 +263,7 @@ def cluster_spectral(embeddings: np.ndarray, k: int) -> np.ndarray: n_init=10, ) labels: np.ndarray = sc.fit_predict(affinity) + labels = _refine_labels_spherical(embeddings, labels) logger.debug("Spectral clustering: %d clusters", k) return labels @@ -219,6 +271,23 @@ def cluster_spectral(embeddings: np.ndarray, k: int) -> np.ndarray: # ── High-level wrappers ────────────────────────────────────────────────────── +def _silhouette_candidate_counts( + k: int, + n_embeddings: int, + min_speakers: int, + max_speakers: int, +) -> list[int]: + """Return valid speaker counts for silhouette refinement.""" + if k < 2 or n_embeddings < 4: + return [] + + lower = max(2, min_speakers, k - 2) + upper = min(max_speakers, n_embeddings - 1, k + 3) + if upper < lower: + return [] + return list(range(lower, upper + 1)) + + def cluster_auto( embeddings: np.ndarray, min_speakers: int = 1, @@ -242,10 +311,10 @@ def cluster_auto( k, details = estimate_speakers(embeddings, min_speakers, max_speakers) n = len(embeddings) - # Silhouette refinement: BIC tends to undercount speakers. - # Try k, k+1, k+2 and pick the k with the best silhouette score. - if k >= 2 and n >= 4: - candidates = [c for c in range(k, k + 3) if c <= min(max_speakers, n - 1)] + # Silhouette refinement: use BIC as an anchor, then score a small + # neighbourhood around it. This catches both undercounts and overcounts. + if k >= 2: + candidates = _silhouette_candidate_counts(k, n, min_speakers, max_speakers) if len(candidates) > 1: distance = np.maximum(1 - (cosine_similarity(embeddings) + 1) / 2, 0) best_k, best_labels, best_sil = k, None, -1.0 diff --git a/tests/test_diarize.py b/tests/test_diarize.py index bbbd8a4..368df89 100644 --- a/tests/test_diarize.py +++ b/tests/test_diarize.py @@ -472,6 +472,25 @@ def test_precheck_gmm_not_run(self): class TestClusterSpectral: """Tests for Spectral Clustering.""" + def test_spherical_refinement_corrects_noisy_labels(self): + from diarize.clustering import _refine_labels_spherical + + embeddings = np.array( + [ + [1.0, 0.0], + [0.9, 0.1], + [0.0, 1.0], + [0.1, 0.9], + ] + ) + noisy_labels = np.array([0, 0, 0, 1]) + + labels = _refine_labels_spherical(embeddings, noisy_labels) + assert len(set(labels)) == 2 + assert labels[0] == labels[1] + assert labels[2] == labels[3] + assert labels[0] != labels[2] + def test_basic_clustering(self): from diarize.clustering import cluster_spectral @@ -493,6 +512,49 @@ def test_basic_clustering(self): class TestClusterSpeakers: """Tests for the high-level cluster_speakers wrapper.""" + def test_silhouette_candidates_cover_bic_neighbourhood(self): + from diarize.clustering import _silhouette_candidate_counts + + assert _silhouette_candidate_counts( + k=5, + n_embeddings=20, + min_speakers=1, + max_speakers=10, + ) == [ + 3, + 4, + 5, + 6, + 7, + 8, + ] + + def test_silhouette_candidates_skip_single_speaker(self): + from diarize.clustering import _silhouette_candidate_counts + + assert _silhouette_candidate_counts( + k=1, + n_embeddings=20, + min_speakers=1, + max_speakers=10, + ) == [] + assert _silhouette_candidate_counts( + k=3, + n_embeddings=3, + min_speakers=1, + max_speakers=10, + ) == [] + + def test_silhouette_candidates_respect_min_speakers(self): + from diarize.clustering import _silhouette_candidate_counts + + assert _silhouette_candidate_counts( + k=5, + n_embeddings=20, + min_speakers=4, + max_speakers=10, + ) == [4, 5, 6, 7, 8] + def test_fixed_num_speakers(self): from diarize.clustering import cluster_speakers @@ -581,6 +643,23 @@ def test_different_speakers_not_merged(self): assert segments[0].speaker == "SPEAKER_00" assert segments[1].speaker == "SPEAKER_01" + def test_long_segment_windows_become_non_overlapping(self): + from diarize import _build_diarization_segments + from diarize.utils import SpeechSegment, SubSegment + + speech = [SpeechSegment(start=0.0, end=8.0)] + subs = [ + SubSegment(start=0.0, end=1.2, parent_idx=0), + SubSegment(start=0.6, end=1.8, parent_idx=0), + SubSegment(start=1.2, end=2.4, parent_idx=0), + SubSegment(start=1.8, end=3.0, parent_idx=0), + ] + labels = np.array([0, 1, 0, 1]) + + segments = _build_diarization_segments(speech, subs, labels) + assert len(segments) >= 2 + assert all(a.end <= b.start for a, b in zip(segments, segments[1:])) + def test_short_segment_assigned_nearest_speaker(self): """VAD segments without embeddings should get the nearest speaker.""" from diarize import _build_diarization_segments @@ -598,10 +677,10 @@ def test_short_segment_assigned_nearest_speaker(self): labels = np.array([0, 1]) segments = _build_diarization_segments(speech, subs, labels) - # The short segment at 2.5-2.8 should be assigned SPEAKER_00 (nearest) - short_seg = [s for s in segments if s.start == pytest.approx(2.5)] - assert len(short_seg) == 1 - assert short_seg[0].speaker == "SPEAKER_00" + # The short segment at 2.5-2.8 should inherit SPEAKER_00 from the nearest speech. + covering = [s for s in segments if s.start <= 2.5 and s.end >= 2.8] + assert len(covering) == 1 + assert covering[0].speaker == "SPEAKER_00" def test_no_speech_segments(self): """No speech and no subsegments returns empty.""" From baf4908e6187c7d4a16302e7dd8568ab5e2be678 Mon Sep 17 00:00:00 2001 From: loookashow Date: Mon, 4 May 2026 13:58:35 +0200 Subject: [PATCH 2/3] style: format diarization tests --- tests/test_diarize.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/test_diarize.py b/tests/test_diarize.py index 368df89..9b243e0 100644 --- a/tests/test_diarize.py +++ b/tests/test_diarize.py @@ -532,18 +532,24 @@ def test_silhouette_candidates_cover_bic_neighbourhood(self): def test_silhouette_candidates_skip_single_speaker(self): from diarize.clustering import _silhouette_candidate_counts - assert _silhouette_candidate_counts( - k=1, - n_embeddings=20, - min_speakers=1, - max_speakers=10, - ) == [] - assert _silhouette_candidate_counts( - k=3, - n_embeddings=3, - min_speakers=1, - max_speakers=10, - ) == [] + assert ( + _silhouette_candidate_counts( + k=1, + n_embeddings=20, + min_speakers=1, + max_speakers=10, + ) + == [] + ) + assert ( + _silhouette_candidate_counts( + k=3, + n_embeddings=3, + min_speakers=1, + max_speakers=10, + ) + == [] + ) def test_silhouette_candidates_respect_min_speakers(self): from diarize.clustering import _silhouette_candidate_counts From bbb0445b2a9a24dbd1624d682a555df56a9593ea Mon Sep 17 00:00:00 2001 From: loookashow Date: Mon, 4 May 2026 14:08:07 +0200 Subject: [PATCH 3/3] test: cover diarization clustering edge cases --- src/diarize/clustering.py | 2 +- tests/test_diarize.py | 107 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/src/diarize/clustering.py b/src/diarize/clustering.py index 79dcc0c..c7459e1 100644 --- a/src/diarize/clustering.py +++ b/src/diarize/clustering.py @@ -198,7 +198,7 @@ def _refine_labels_spherical( valid = np.zeros(k, dtype=bool) for label in range(k): members = emb[refined == label] - if len(members) == 0: + if len(members) == 0: # pragma: no cover - defensive against future label changes. continue centroid = members.mean(axis=0) norm = float(np.linalg.norm(centroid)) diff --git a/tests/test_diarize.py b/tests/test_diarize.py index 9b243e0..581d55f 100644 --- a/tests/test_diarize.py +++ b/tests/test_diarize.py @@ -491,6 +491,42 @@ def test_spherical_refinement_corrects_noisy_labels(self): assert labels[2] == labels[3] assert labels[0] != labels[2] + def test_spherical_refinement_handles_empty_and_single_label(self): + from diarize.clustering import _refine_labels_spherical + + empty = _refine_labels_spherical(np.empty((0, 2)), np.array([], dtype=int)) + assert len(empty) == 0 + + labels = _refine_labels_spherical(np.ones((3, 2)), np.array([5, 5, 5])) + assert labels.tolist() == [0, 0, 0] + + def test_spherical_refinement_stops_on_degenerate_centroid(self): + from diarize.clustering import _refine_labels_spherical + + embeddings = np.array([[0.0, 0.0], [1.0, 0.0], [-1.0, 0.0]]) + labels = np.array([0, 1, 1]) + + refined = _refine_labels_spherical(embeddings, labels) + assert refined.tolist() == [0, 1, 1] + + def test_spherical_refinement_stops_on_empty_cluster(self): + from diarize.clustering import _refine_labels_spherical + + embeddings = np.array([[1.0, 0.0], [0.0, 1.0]]) + labels = np.array([0, 2]) + + refined = _refine_labels_spherical(embeddings, labels) + assert refined.tolist() == [0, 1] + + def test_spherical_refinement_preserves_nonempty_clusters(self): + from diarize.clustering import _refine_labels_spherical + + embeddings = np.array([[1.0, 0.0], [0.9, 0.0], [0.8, 0.0]]) + labels = np.array([0, 1, 1]) + + refined = _refine_labels_spherical(embeddings, labels) + assert refined.tolist() == [0, 1, 1] + def test_basic_clustering(self): from diarize.clustering import cluster_spectral @@ -561,6 +597,22 @@ def test_silhouette_candidates_respect_min_speakers(self): max_speakers=10, ) == [4, 5, 6, 7, 8] + def test_auto_updates_details_when_silhouette_changes_k(self): + from diarize.clustering import cluster_auto + from diarize.utils import SpeakerEstimationDetails + + rng = np.random.RandomState(42) + c1 = rng.randn(12, 8) * 0.01 + np.array([1.0, 0, 0, 0, 0, 0, 0, 0]) + c2 = rng.randn(12, 8) * 0.01 + np.array([0, 1.0, 0, 0, 0, 0, 0, 0]) + embeddings = np.vstack([c1, c2]) + details = SpeakerEstimationDetails(method="gmm_bic", best_k=3) + + with patch("diarize.clustering.estimate_speakers", return_value=(3, details)): + labels, details = cluster_auto(embeddings, min_speakers=2, max_speakers=5) + + assert len(labels) == len(embeddings) + assert details.best_k == 2 + def test_fixed_num_speakers(self): from diarize.clustering import cluster_speakers @@ -600,6 +652,39 @@ def test_single_embedding(self): class TestBuildDiarizationSegments: """Tests for segment assembly from subsegments and labels.""" + def test_majority_helpers_handle_empty_and_ties(self): + from diarize import _majority_label, _smooth_window_labels + + assert _majority_label([]) is None + assert _majority_label([0, 1]) is None + assert _smooth_window_labels([0, 1]) == [0, 1] + assert _smooth_window_labels([0, 1, 1]) == [0, 1, 1] + + def test_window_boundaries_empty_and_degenerate_windows(self): + from diarize import _build_diarization_segments, _window_boundaries + from diarize.utils import SpeechSegment, SubSegment + + speech = SpeechSegment(start=1.0, end=2.0) + assert _window_boundaries(speech, []) == [] + + segments = _build_diarization_segments( + [speech], + [SubSegment(start=1.0, end=1.0, parent_idx=0)], + np.array([0]), + ) + assert segments[0].start == pytest.approx(1.0) + assert segments[0].end == pytest.approx(2.0) + + skipped = _build_diarization_segments( + [SpeechSegment(start=1.0, end=2.0)], + [ + SubSegment(start=0.0, end=0.0, parent_idx=0), + SubSegment(start=1.5, end=1.5, parent_idx=0), + ], + np.array([0, 1]), + ) + assert all(seg.end > seg.start for seg in skipped) + def test_basic_assembly(self): from diarize import _build_diarization_segments from diarize.utils import SpeechSegment, SubSegment @@ -1033,6 +1118,16 @@ def test_path_object_accepted(self): assert result.audio_path == "test.wav" + def test_invalid_arguments_rejected(self): + from diarize import diarize + + with pytest.raises(ValueError, match="min_speakers"): + diarize("test.wav", min_speakers=0) + with pytest.raises(ValueError, match="max_speakers"): + diarize("test.wav", min_speakers=3, max_speakers=2) + with pytest.raises(ValueError, match="num_speakers"): + diarize("test.wav", num_speakers=0) + # ── Public API imports ─────────────────────────────────────────────────────── @@ -1125,6 +1220,18 @@ def test_cluster_spectral_k_equals_1(self): assert len(labels) == 10 assert set(labels) == {0} + def test_cluster_speakers_validation_errors(self): + """cluster_speakers should reject invalid speaker bounds.""" + from diarize.clustering import cluster_speakers + + embeddings = np.random.RandomState(42).randn(4, 256) + with pytest.raises(ValueError, match="min_speakers"): + cluster_speakers(embeddings, min_speakers=0) + with pytest.raises(ValueError, match="max_speakers"): + cluster_speakers(embeddings, min_speakers=3, max_speakers=2) + with pytest.raises(ValueError, match="num_speakers"): + cluster_speakers(embeddings, num_speakers=0) + def test_extract_embeddings_empty_returns_2d(self): """P3: extract_embeddings should return (0, 256) shape when no embeddings.""" from diarize.embeddings import extract_embeddings