Merged
10 changes: 5 additions & 5 deletions README.md
@@ -23,7 +23,7 @@ for seg in result.segments:
    print(f" [{seg.start:.1f}s - {seg.end:.1f}s] {seg.speaker}")
```

**~5.2% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers.
**~5.0% weighted DER** on VoxConverse dev. Processes audio **~8x faster than real-time** on CPU. Automatically detects the number of speakers.

> Benchmarked on a single dataset ([VoxConverse](https://github.com/joonson/voxconverse)). Cross-dataset validation is [in progress](#roadmap).

@@ -35,7 +35,7 @@ for seg in result.segments:
| GPU required | No | No (7x slower on CPU) | No |
| HuggingFace account | No | Yes | Yes |
| Auto speaker count | Yes | Yes | Yes |
| DER (VoxConverse dev) | **~5.2%** | ~11.2% | ~8.5% |
| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% |
| CPU speed (RTF) | **0.12** | 0.86 | — |
| Install | `pip install diarize` | `pip install pyannote.audio` | `pip install pyannote.audio` |
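
The RTF column is processing time divided by audio duration, which is where the "~8x faster than real-time" claim comes from. A minimal sketch (the 7.2 s / 60 s figures below are illustrative, not a measured benchmark):

```python
def real_time_factor(processing_seconds: float, audio_seconds: float) -> float:
    """RTF < 1.0 means faster than real time; the speedup is 1 / RTF."""
    return processing_seconds / audio_seconds

# A 60 s clip processed in 7.2 s gives the table's RTF of 0.12 (~8.3x real time).
rtf = real_time_factor(7.2, 60.0)
print(f"{rtf:.2f}, {1 / rtf:.1f}x")  # → 0.12, 8.3x
```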

@@ -89,7 +89,7 @@ Four-stage pipeline, all CPU, all open-source:
1. **Silero VAD** (MIT) — detects speech segments
2. **WeSpeaker ResNet34-LM** (Apache 2.0) — extracts 256-dim speaker embeddings via ONNX
3. **GMM BIC + silhouette refinement** — estimates the number of speakers
4. **Spectral Clustering** (scikit-learn, BSD) + smoothing — assigns speaker labels
4. **Spectral Clustering** (scikit-learn, BSD) + temporal smoothing — assigns speaker labels
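
The exact stage-3 criterion is not shown in this diff, but a minimal sketch of BIC-based speaker-count selection over embeddings might look like the following (the helper name, synthetic data, and `covariance_type` are illustrative assumptions, and the silhouette refinement step is omitted):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def estimate_num_speakers(embeddings: np.ndarray, max_speakers: int = 7, seed: int = 0) -> int:
    """Pick the k in [1, max_speakers] whose GMM gets the lowest BIC."""
    best_k, best_bic = 1, np.inf
    for k in range(1, max_speakers + 1):
        gmm = GaussianMixture(n_components=k, covariance_type="diag", random_state=seed)
        gmm.fit(embeddings)
        bic = gmm.bic(embeddings)
        if bic < best_bic:
            best_k, best_bic = k, bic
    return best_k

# Three tight synthetic "speaker" clusters in an 8-dim embedding space.
rng = np.random.default_rng(0)
emb = np.vstack([rng.normal(loc=c, scale=0.05, size=(40, 8)) for c in (-1.0, 0.0, 1.0)])
print(estimate_num_speakers(emb))  # → 3
```

BIC trades likelihood against model size, which is why the real pipeline pairs it with a silhouette check rather than trusting it alone.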

Details: [How It Works](https://foxnosetech.github.io/diarize/how-it-works/)

@@ -102,7 +102,7 @@ Evaluated on [VoxConverse](https://github.com/joonson/voxconverse) dev set (216
| System | Weighted DER | Notes |
|--------|----------|-------|
| pyannote precision-2 | ~8.5% | Commercial license |
| **diarize** | **~5.2%** | **Apache 2.0, CPU-only, no API key** |
| **diarize** | **~5.0%** | **Apache 2.0, CPU-only, no API key** |
| pyannote community-1 | ~11.2% | CC-BY-4.0, needs HF token |
| pyannote 3.1 (legacy) | ~11.2% | MIT, needs HF token |

@@ -121,7 +121,7 @@ Full benchmark results, speed comparison, and methodology: [benchmarks](https://
## When to use something else

- **You need commercial support or cross-dataset validation.** pyannote's commercial model has published production-oriented benchmarks beyond this single VoxConverse evaluation. If accuracy is the top priority and you have budget, compare on your own data.
- **You need very stable speaker labels in transcripts.** diarize can still show speaker fragmentation / label switching: one real speaker may be split across multiple `SPEAKER_XX` labels, or the label may briefly jump inside a continuous turn, especially on noisy real-world audio.
- **You need very stable speaker labels in transcripts.** Temporal smoothing reduces short label jumps, but diarize can still show speaker fragmentation / label switching: one real speaker may be split across multiple `SPEAKER_XX` labels, especially on noisy real-world audio.
- **Your audio has 8+ speakers.** Automatic speaker count estimation degrades above 7 speakers. You can pass `num_speakers` explicitly, but test carefully.
- **You need overlapping speech detection.** diarize assigns each segment to one speaker. Overlapping speech is not modeled.
- **You need GPU-accelerated throughput.** diarize is CPU-only by design. For processing thousands of hours with GPU infrastructure, NeMo or pyannote on GPU will be faster.
10 changes: 5 additions & 5 deletions docs/benchmarks.md
@@ -23,7 +23,7 @@ DER is the standard metric for speaker diarization, computed with
| System | Weighted DER | Median DER | Notes |
|--------|----------|------------|-------|
| pyannote precision-2 | ~8.5% | -- | Commercial license |
| **diarize** | **~5.2%** | **~2.4%** | **Apache 2.0, CPU-only, no API key** |
| **diarize** | **~5.0%** | **~2.2%** | **Apache 2.0, CPU-only, no API key** |
| pyannote community-1 | ~11.2% | -- | CC-BY-4.0, needs HF token |
| pyannote 3.1 (legacy) | ~11.2% | -- | MIT, needs HF token |
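
As a refresher on what the metric measures, DER sums three error durations relative to total reference speech (with `skip_overlap=True`, overlapping regions are excluded from scoring):

```
DER = (T_false_alarm + T_missed_speech + T_speaker_confusion) / T_total_speech
```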

@@ -94,10 +94,10 @@ Measured on VoxConverse dev files on Apple M2 Pro / M2 Max

- **Many speakers (8+):** Automatic speaker count estimation degrades.
Use ``num_speakers`` when the speaker count is known.
- **Speaker label switching / fragmentation:** On noisy real-world audio,
one actual speaker can be split across multiple ``SPEAKER_XX`` labels,
or the label can briefly jump inside a continuous turn. This is mostly
a clustering and embedding-assignment limitation, and it is visible in
- **Speaker label switching / fragmentation:** Temporal smoothing reduces
short label jumps, but on noisy real-world audio one actual speaker can
still be split across multiple ``SPEAKER_XX`` labels. This is mostly a
clustering and embedding-assignment limitation, and it is visible in
transcripts even when aggregate DER looks acceptable.
- **Overlapping speech:** DER is computed with ``skip_overlap=True``.
The pipeline does not model overlapping speech --- when two people
11 changes: 7 additions & 4 deletions docs/how-it-works.md
@@ -18,7 +18,7 @@ Audio File
[3] GMM BIC + silhouette -> Estimated speaker count (k)
|
v
[4] Spectral + smoothing -> Speaker timeline
[4] Spectral + temporal smoothing -> Speaker timeline
|
v
DiarizeResult
@@ -97,11 +97,14 @@ scikit-learn's `SpectralClustering`.

The initial spectral labels are refined with spherical centroid
reassignment over L2-normalised embeddings. This preserves the selected
speaker count while reducing unstable one-window label flips.
speaker count while reducing unstable label assignments.
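
The refinement step described above can be sketched in a few lines. This is a minimal stand-in, not the library's implementation; the helper name and toy data are assumptions:

```python
import numpy as np

def reassign_to_centroids(embeddings: np.ndarray, labels: list[int]) -> list[int]:
    # L2-normalise embeddings so cosine similarity is a plain dot product.
    X = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    uniq = sorted(set(labels))
    lab = np.asarray(labels)
    # Spherical centroids: mean direction per speaker, renormalised.
    C = np.vstack([X[lab == k].mean(axis=0) for k in uniq])
    C = C / np.linalg.norm(C, axis=1, keepdims=True)
    # Reassign each window to its most similar centroid.
    return [uniq[i] for i in np.argmax(X @ C.T, axis=1)]

emb = np.array([
    [1.0, 0.01], [0.99, 0.02], [1.0, -0.01],   # speaker-A-like directions
    [0.01, 1.0], [0.02, 0.99], [-0.01, 1.0],   # speaker-B-like directions
])
labels = [0, 0, 1, 1, 1, 1]  # window 2 starts out mislabeled as speaker 1
print(reassign_to_centroids(emb, labels))  # → [0, 0, 0, 1, 1, 1]
```

Because each window is reassigned to an existing centroid, the number of speakers never changes; only unstable assignments move.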

For long VAD segments, overlapping embedding windows are decoded into
non-overlapping timeline intervals using window-center midpoints. A
3-window majority filter smooths local label noise. Adjacent intervals
non-overlapping timeline intervals using window-center midpoints. Within
each VAD segment, centroid-scored Viterbi decoding applies a speaker
switch penalty to remove short label flicker. Sustained original label
runs are restored before final island cleanup, so the smoother targets
brief A-B-A fragments rather than real speaker turns. Adjacent intervals
assigned to the same speaker are merged, and short segments that were
skipped during embedding extraction are assigned the label of the nearest
speaker.
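
The midpoint decoding mentioned above (whose `_window_boundaries` implementation is elided in this diff) can be sketched as follows; the helper name and window values are illustrative:

```python
def windows_to_intervals(windows: list[tuple[float, float]]) -> list[tuple[float, float]]:
    """Split overlapping (start, end) windows into contiguous, non-overlapping
    intervals, cutting at the midpoints between consecutive window centers."""
    centers = [(s + e) / 2 for s, e in windows]
    bounds = [windows[0][0]]
    bounds += [(a + b) / 2 for a, b in zip(centers, centers[1:])]
    bounds.append(windows[-1][1])
    return list(zip(bounds, bounds[1:]))

# Three 1.5 s windows with 50% overlap become three disjoint intervals.
print(windows_to_intervals([(0.0, 1.5), (0.75, 2.25), (1.5, 3.0)]))
# → [(0.0, 1.125), (1.125, 1.875), (1.875, 3.0)]
```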
2 changes: 1 addition & 1 deletion docs/index.md
@@ -29,7 +29,7 @@ for seg in result.segments:
| GPU required | No | No (7x slower on CPU) | No |
| HuggingFace account | No | Yes | Yes |
| Auto speaker count | Yes | Yes | Yes |
| DER (VoxConverse dev) | **~5.2%** | ~11.2% | ~8.5% |
| DER (VoxConverse dev) | **~5.0%** | ~11.2% | ~8.5% |
| CPU speed (RTF) | **0.12** | 0.86 | --- |

DER and speed numbers for pyannote are from their
187 changes: 185 additions & 2 deletions src/diarize/__init__.py
@@ -54,6 +54,9 @@ class _RawSegment(NamedTuple):

logger = logging.getLogger(__name__)

_TEMPORAL_SWITCH_PENALTY = 0.18
_MAX_FRAGMENT_DURATION = 1.2


def _majority_label(labels: list[int]) -> int | None:
"""Return the unique majority label, or ``None`` on ties."""
Expand Down Expand Up @@ -81,6 +84,159 @@ def _smooth_window_labels(labels: list[int]) -> list[int]:
return smoothed


def _normalize_rows(values: np.ndarray) -> np.ndarray:
"""Return row-wise L2-normalised values, preserving zero rows."""
norms = np.linalg.norm(values, axis=1, keepdims=True)
return np.divide(
values,
norms,
out=np.zeros_like(values, dtype=float),
where=norms > 0,
)


def _speaker_centroids(
    embeddings: np.ndarray,
    labels: np.ndarray,
) -> tuple[list[int], np.ndarray]:
    """Build L2-normalised speaker centroids from clustered embeddings."""
    if len(embeddings) == 0 or len(embeddings) != len(labels):
        return [], np.empty((0, 0), dtype=float)

    norm_embeddings = _normalize_rows(np.asarray(embeddings, dtype=float))
    label_values = sorted({int(label) for label in labels})
    centroids: list[np.ndarray] = []
    valid_labels: list[int] = []

    for label in label_values:
        members = norm_embeddings[labels == label]
        if len(members) == 0:  # pragma: no cover - label_values are derived from labels.
            continue
        centroid = members.mean(axis=0, keepdims=True)
        centroid = _normalize_rows(centroid)[0]
        if not np.any(centroid):
            continue
        valid_labels.append(label)
        centroids.append(centroid)

    if not centroids:
        return [], np.empty((0, norm_embeddings.shape[1]), dtype=float)

    return valid_labels, np.vstack(centroids)


def _viterbi_smooth_scores(
    scores: np.ndarray,
    *,
    switch_penalty: float = _TEMPORAL_SWITCH_PENALTY,
) -> list[int]:
    """Find the best label path with a penalty for short speaker switches."""
    n_frames, n_labels = scores.shape
    if n_frames == 0:
        return []
    if n_labels <= 1:
        return [0] * n_frames

    dp = np.empty((n_frames, n_labels), dtype=float)
    back = np.zeros((n_frames, n_labels), dtype=int)
    dp[0] = scores[0]

    transition = np.full((n_labels, n_labels), -switch_penalty, dtype=float)
    np.fill_diagonal(transition, 0.0)

    for frame_idx in range(1, n_frames):
        previous = dp[frame_idx - 1][:, None] + transition
        back[frame_idx] = np.argmax(previous, axis=0)
        dp[frame_idx] = scores[frame_idx] + np.max(previous, axis=0)

    path = [int(np.argmax(dp[-1]))]
    for frame_idx in range(n_frames - 1, 0, -1):
        path.append(int(back[frame_idx, path[-1]]))
    path.reverse()
    return path


def _smooth_window_labels_temporal(
    labels: list[int],
    embeddings: np.ndarray,
    indices: list[int],
    label_values: list[int],
    centroids: np.ndarray,
) -> list[int]:
    """Smooth a single VAD segment using centroid scores and Viterbi decoding."""
    if len(labels) < 3 or len(label_values) <= 1:
        return labels

    label_to_idx = {label: idx for idx, label in enumerate(label_values)}
    norm_embeddings = _normalize_rows(np.asarray(embeddings[indices], dtype=float))
    scores = norm_embeddings @ centroids.T

    # Small anchor to the original clustering label keeps confident labels
    # stable while the transition penalty removes low-value label flicker.
    for row_idx, label in enumerate(labels):
        label_idx = label_to_idx.get(label)
        if label_idx is not None:
            scores[row_idx, label_idx] += 0.02

    path = _viterbi_smooth_scores(scores)
    return [label_values[state] for state in path]


def _collapse_short_label_islands(
    labels: list[int],
    windows: list[tuple[float, float]],
    *,
    max_duration: float = _MAX_FRAGMENT_DURATION,
) -> list[int]:
    """Collapse short A-B-A label islands inside a continuous speech segment."""
    if len(labels) < 3 or len(labels) != len(windows):
        return labels

    runs: list[tuple[int, int, int, float]] = []
    run_start = 0
    for idx in range(1, len(labels) + 1):
        if idx == len(labels) or labels[idx] != labels[run_start]:
            duration = windows[idx - 1][1] - windows[run_start][0]
            runs.append((run_start, idx, labels[run_start], duration))
            run_start = idx

    if len(runs) < 3:
        return labels

    smoothed = labels.copy()
    for run_idx in range(1, len(runs) - 1):
        start, end, label, duration = runs[run_idx]
        _, _, previous_label, _ = runs[run_idx - 1]
        _, _, next_label, _ = runs[run_idx + 1]
        if previous_label == next_label and label != previous_label and duration <= max_duration:
            smoothed[start:end] = [previous_label] * (end - start)

    return smoothed


def _restore_sustained_label_runs(
    original_labels: list[int],
    smoothed_labels: list[int],
    windows: list[tuple[float, float]],
    *,
    min_duration: float = _MAX_FRAGMENT_DURATION,
) -> list[int]:
    """Keep sustained original runs even when Viterbi prefers fewer switches."""
    if len(original_labels) != len(smoothed_labels) or len(original_labels) != len(windows):
        return smoothed_labels

    restored = smoothed_labels.copy()
    run_start = 0
    for idx in range(1, len(original_labels) + 1):
        if idx == len(original_labels) or original_labels[idx] != original_labels[run_start]:
            duration = windows[idx - 1][1] - windows[run_start][0]
            if duration > min_duration:
                restored[run_start:idx] = original_labels[run_start:idx]
            run_start = idx

    return restored


def _window_boundaries(
    speech_segment: SpeechSegment,
    segment_subsegments: list[SubSegment],
@@ -103,17 +259,19 @@ def _build_diarization_segments(
    speech_segments: list[SpeechSegment],
    subsegments: list[SubSegment],
    labels: np.ndarray,
    embeddings: np.ndarray | None = None,
) -> list[Segment]:
    """Assemble diarization segments from subsegments and cluster labels.

    Overlapping embedding windows are converted to a non-overlapping
    timeline and smoothed with a local majority filter. VAD segments
    timeline and smoothed with temporal label decoding. VAD segments
    without embeddings are assigned the nearest speaker.

    Args:
        speech_segments: Original speech segments from VAD.
        subsegments: Embedding windows with parent indices.
        labels: Cluster labels aligned with *subsegments*.
        embeddings: Optional speaker embeddings aligned with *subsegments*.

    Returns:
        Merged :class:`Segment` list sorted by start time.
Expand All @@ -123,15 +281,39 @@ def _build_diarization_segments(
for idx, sub in enumerate(subsegments):
subsegments_by_parent.setdefault(sub.parent_idx, []).append(idx)

label_values: list[int] = []
centroids = np.empty((0, 0), dtype=float)
if embeddings is not None:
label_values, centroids = _speaker_centroids(embeddings, labels)

for parent_idx, speech_segment in enumerate(speech_segments):
indices = subsegments_by_parent.get(parent_idx)
if not indices:
continue

indices.sort(key=lambda idx: subsegments[idx].start)
parent_labels = _smooth_window_labels([int(labels[idx]) for idx in indices])
parent_labels = [int(labels[idx]) for idx in indices]
parent_subsegments = [subsegments[idx] for idx in indices]
windows = _window_boundaries(speech_segment, parent_subsegments)

if embeddings is not None and len(label_values) > 1:
original_labels = parent_labels
parent_labels = _smooth_window_labels_temporal(
parent_labels,
embeddings,
indices,
label_values,
centroids,
)
parent_labels = _restore_sustained_label_runs(
original_labels,
parent_labels,
windows,
)
parent_labels = _collapse_short_label_islands(parent_labels, windows)
else:
parent_labels = _smooth_window_labels(parent_labels)

for (start, end), label in zip(windows, parent_labels):
if end <= start:
continue
@@ -258,6 +440,7 @@ def diarize(
        speech_segments,
        subsegments,
        labels,
        embeddings,
    )

    result = DiarizeResult(