Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 78 additions & 15 deletions src/diarize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import logging
from collections import Counter
from pathlib import Path
from typing import NamedTuple

Expand Down Expand Up @@ -54,15 +55,59 @@ class _RawSegment(NamedTuple):
logger = logging.getLogger(__name__)


def _majority_label(labels: list[int]) -> int | None:
"""Return the unique majority label, or ``None`` on ties."""
if not labels:
return None

counts = Counter(labels)
best_label, best_count = counts.most_common(1)[0]
if sum(1 for count in counts.values() if count == best_count) > 1:
return None
return int(best_label)


def _smooth_window_labels(labels: list[int]) -> list[int]:
"""Apply a 3-window majority filter while preserving ties."""
if len(labels) < 3:
return labels

smoothed: list[int] = []
for idx, label in enumerate(labels):
window = labels[max(0, idx - 1) : min(len(labels), idx + 2)]
majority = _majority_label(window)
smoothed.append(label if majority is None else majority)

return smoothed


def _window_boundaries(
speech_segment: SpeechSegment,
segment_subsegments: list[SubSegment],
) -> list[tuple[float, float]]:
"""Convert overlapping windows into non-overlapping intervals."""
if not segment_subsegments:
return []

centers = [(sub.start + sub.end) / 2 for sub in segment_subsegments]
boundaries = [speech_segment.start]
for left, right in zip(centers, centers[1:]):
midpoint = (left + right) / 2
boundaries.append(min(speech_segment.end, max(speech_segment.start, midpoint)))
boundaries.append(speech_segment.end)

return list(zip(boundaries[:-1], boundaries[1:]))


def _build_diarization_segments(
speech_segments: list[SpeechSegment],
subsegments: list[SubSegment],
labels: np.ndarray,
) -> list[Segment]:
"""Assemble diarization segments from subsegments and cluster labels.

Each subsegment (embedding window) gets its speaker label. Adjacent
subsegments from the same speaker are merged. Short VAD segments
Overlapping embedding windows are converted to a non-overlapping
timeline and smoothed with a local majority filter. VAD segments
without embeddings are assigned the nearest speaker.

Args:
Expand All @@ -73,16 +118,30 @@ def _build_diarization_segments(
Returns:
Merged :class:`Segment` list sorted by start time.
"""
# Build raw segments from subsegments + labels
raw_segments: list[_RawSegment] = []
for sub, label in zip(subsegments, labels):
raw_segments.append(
_RawSegment(
start=sub.start,
end=sub.end,
speaker=f"SPEAKER_{int(label):02d}",
subsegments_by_parent: dict[int, list[int]] = {}
for idx, sub in enumerate(subsegments):
subsegments_by_parent.setdefault(sub.parent_idx, []).append(idx)

for parent_idx, speech_segment in enumerate(speech_segments):
indices = subsegments_by_parent.get(parent_idx)
if not indices:
continue

indices.sort(key=lambda idx: subsegments[idx].start)
parent_labels = _smooth_window_labels([int(labels[idx]) for idx in indices])
parent_subsegments = [subsegments[idx] for idx in indices]
windows = _window_boundaries(speech_segment, parent_subsegments)
for (start, end), label in zip(windows, parent_labels):
if end <= start:
continue
raw_segments.append(
_RawSegment(
start=start,
end=end,
speaker=f"SPEAKER_{label:02d}",
)
)
)

# Add short VAD segments that were skipped during embedding extraction
covered_indices = {sub.parent_idx for sub in subsegments}
Expand All @@ -93,12 +152,12 @@ def _build_diarization_segments(
seg_mid = (seg.start + seg.end) / 2
best_speaker = "SPEAKER_00"
best_dist = float("inf")
for sub, label in zip(subsegments, labels):
sub_mid = (sub.start + sub.end) / 2
dist = abs(seg_mid - sub_mid)
for raw in raw_segments:
raw_mid = (raw.start + raw.end) / 2
dist = abs(seg_mid - raw_mid)
if dist < best_dist:
best_dist = dist
best_speaker = f"SPEAKER_{int(label):02d}"
best_speaker = raw.speaker
raw_segments.append(_RawSegment(start=seg.start, end=seg.end, speaker=best_speaker))

# Sort by time
Expand Down Expand Up @@ -195,7 +254,11 @@ def diarize(
)

# 4. Build result
segments = _build_diarization_segments(speech_segments, subsegments, labels)
segments = _build_diarization_segments(
speech_segments,
subsegments,
labels,
)

result = DiarizeResult(
segments=segments,
Expand Down
77 changes: 73 additions & 4 deletions src/diarize/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,61 @@ def estimate_speakers(
# ── Spectral Clustering ──────────────────────────────────────────────────────


def _refine_labels_spherical(
    embeddings: np.ndarray,
    labels: np.ndarray,
    *,
    max_iter: int = 8,
) -> np.ndarray:
    """Refine cluster labels with spherical centroid reassignment."""
    n_samples = len(embeddings)
    if n_samples == 0:
        return labels

    distinct = sorted({int(lbl) for lbl in labels})
    n_clusters = len(distinct)
    if n_clusters <= 1:
        return np.zeros(n_samples, dtype=int)

    # Compact the labels to 0..k-1 so they index the centroid matrix directly.
    remap = {old: new for new, old in enumerate(distinct)}
    assignments = np.array([remap[int(lbl)] for lbl in labels], dtype=int)
    unit = normalize(embeddings, norm="l2")

    for _ in range(max_iter):
        centroids = np.zeros((n_clusters, unit.shape[1]), dtype=float)
        all_valid = True
        for cluster in range(n_clusters):
            members = unit[assignments == cluster]
            if len(members) == 0:  # pragma: no cover - defensive against future label changes.
                all_valid = False
                continue
            mean_vec = members.mean(axis=0)
            length = float(np.linalg.norm(mean_vec))
            if length == 0:
                all_valid = False
                continue
            centroids[cluster] = mean_vec / length

        # A degenerate (empty or zero-norm) centroid means we cannot refine
        # further without collapsing a cluster; keep the current assignment.
        if not all_valid:
            break

        proposal = np.argmax(unit @ centroids.T, axis=1)
        # Reject a reassignment that would drop a cluster or that converged.
        if len(set(proposal)) < n_clusters:
            break
        if np.array_equal(proposal, assignments):
            break
        assignments = proposal

    return assignments


def cluster_spectral(embeddings: np.ndarray, k: int) -> np.ndarray:
"""Cluster embeddings into *k* speakers using Spectral Clustering.

Uses cosine similarity as the affinity metric, rescaled to [0, 1].
The spectral assignment is then refined with a few spherical
centroid-reassignment iterations, which reduces noisy window labels
while preserving the selected number of speakers.

Args:
embeddings: Speaker embeddings of shape ``(N, D)``.
Expand Down Expand Up @@ -212,13 +263,31 @@ def cluster_spectral(embeddings: np.ndarray, k: int) -> np.ndarray:
n_init=10,
)
labels: np.ndarray = sc.fit_predict(affinity)
labels = _refine_labels_spherical(embeddings, labels)
logger.debug("Spectral clustering: %d clusters", k)
return labels


# ── High-level wrappers ──────────────────────────────────────────────────────


def _silhouette_candidate_counts(
k: int,
n_embeddings: int,
min_speakers: int,
max_speakers: int,
) -> list[int]:
"""Return valid speaker counts for silhouette refinement."""
if k < 2 or n_embeddings < 4:
return []

lower = max(2, min_speakers, k - 2)
upper = min(max_speakers, n_embeddings - 1, k + 3)
if upper < lower:
return []
return list(range(lower, upper + 1))


def cluster_auto(
embeddings: np.ndarray,
min_speakers: int = 1,
Expand All @@ -242,10 +311,10 @@ def cluster_auto(
k, details = estimate_speakers(embeddings, min_speakers, max_speakers)
n = len(embeddings)

# Silhouette refinement: BIC tends to undercount speakers.
# Try k, k+1, k+2 and pick the k with the best silhouette score.
if k >= 2 and n >= 4:
candidates = [c for c in range(k, k + 3) if c <= min(max_speakers, n - 1)]
# Silhouette refinement: use BIC as an anchor, then score a small
# neighbourhood around it. This catches both undercounts and overcounts.
if k >= 2:
candidates = _silhouette_candidate_counts(k, n, min_speakers, max_speakers)
if len(candidates) > 1:
distance = np.maximum(1 - (cosine_similarity(embeddings) + 1) / 2, 0)
best_k, best_labels, best_sil = k, None, -1.0
Expand Down
Loading
Loading