feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1

Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical
at 6 decimals on 200 cross-window slot pairs). 2.5x faster than
pyannote-MPS on Apple Silicon native.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
transcrilive
2026-05-09 16:05:39 +02:00
commit 2b1a3c1312
30 changed files with 2022 additions and 0 deletions

161
scripts/bench.py Normal file
View File

@@ -0,0 +1,161 @@
"""Benchmark MLX vs pyannote-MPS diarization on the same audio.
Usage:
uv run python scripts/bench.py <audio> \
[--min-speakers N] [--max-speakers M]
Runs both backends back-to-back, prints a Markdown table with wall time,
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
"""
from __future__ import annotations
import argparse
import gc
import sys
import time
from pathlib import Path
import librosa
import numpy as np
import psutil
import torch
def _measure(label: str, fn) -> dict:
    """Run ``fn()`` once and record its wall time and resident-memory change.

    Args:
        label: Name stored under the ``"label"`` key of the result.
        fn: Zero-argument callable to time. Its return value is stored
            under the ``"annotation"`` key.

    Returns:
        A dict with the keys ``label``, ``wall`` (seconds),
        ``rss_delta_gb``, ``rss_peak_gb`` and ``annotation``.
    """
    proc = psutil.Process()
    # Collect garbage first so leftovers from earlier work are less likely
    # to be counted in the "before" memory sample.
    gc.collect()
    rss_before = proc.memory_info().rss
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted, which would corrupt the benchmark.
    started = time.perf_counter()
    annotation = fn()
    wall = time.perf_counter() - started
    # NOTE: this is a single RSS sample taken after fn() returns, not a true
    # high-water mark; the "rss_peak_gb" key name is kept for compatibility.
    rss_after = proc.memory_info().rss
    return {
        "label": label,
        "wall": wall,
        "rss_delta_gb": (rss_after - rss_before) / 1e9,
        "rss_peak_gb": rss_after / 1e9,
        "annotation": annotation,
    }
def _stats(annotation) -> dict:
speakers = sorted(set(annotation.labels()))
turns = list(annotation.itertracks(yield_label=True))
total_speech = sum(seg.duration for seg, _, _ in turns)
# per-speaker totals
by_speaker = {}
for seg, _, lab in turns:
by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
return {
"speakers": len(speakers),
"turns": len(turns),
"total_speech": total_speech,
"by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
}
def _log_result(r: dict) -> None:
    """Print the one-line stderr summary for a measured backend run."""
    print(
        f" wall={r['wall']:.1f}s speakers={r['speakers']} "
        f"speech={r['total_speech']:.0f}s "
        f"rss_delta={r['rss_delta_gb']:.2f}GB",
        file=sys.stderr,
    )


def _print_report(results: list, duration_s: float, der_value) -> None:
    """Print the Markdown comparison table and per-speaker totals to stdout."""
    print()
    print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
    print("|---|---:|---:|---:|---:|---:|---:|")
    for r in results:
        # Realtime factor: seconds of audio processed per second of wall time.
        rt = duration_s / r["wall"] if r["wall"] > 0 else 0
        print(
            f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
            f"{r['speakers']} | {r['turns']} | "
            f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
        )
    print()
    if der_value is not None:
        print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")
        print()
    print("### Per-speaker speech time")
    for r in results:
        print(f"\n**{r['label']}** ({r['speakers']} speakers):")
        # by_speaker is already ordered longest-first; show the top ten.
        for sp, dur in list(r["by_speaker"].items())[:10]:
            print(f" {sp}: {dur:.0f}s")


def main() -> int:
    """Benchmark the MLX and pyannote pipelines on one audio file.

    Loads the audio once (16 kHz mono), runs the MLX pipeline and then the
    pyannote pipeline on the same in-memory waveform, computes the
    diarization error rate of the MLX output against the pyannote output,
    and prints a Markdown report to stdout. Progress goes to stderr.

    Returns:
        0, used as the process exit status.
    """
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("audio", help="path to the audio file to diarize")
    # NOTE(review): these defaults force 10-15 speakers whenever the flags
    # are omitted, although the usage text presents them as optional.
    # Confirm this is intended rather than left over from a specific test.
    parser.add_argument("--min-speakers", type=int, default=10)
    parser.add_argument("--max-speakers", type=int, default=15)
    args = parser.parse_args()

    audio_path = Path(args.audio).expanduser().resolve()
    # Fail with a clear usage error before the slow model imports/loads.
    if not audio_path.exists():
        parser.error(f"audio file not found: {audio_path}")
    if args.min_speakers > args.max_speakers:
        parser.error("--min-speakers must not exceed --max-speakers")

    print(f"Loading {audio_path.name} (sr=16000, mono) ...", file=sys.stderr)
    sig, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    duration_s = len(sig) / 16000
    print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)

    # Both backends receive the same in-memory waveform dict.
    diar_input = {
        "waveform": torch.from_numpy(sig).unsqueeze(0),
        "sample_rate": 16000,
    }
    kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}
    results = []

    # 1. MLX pure
    print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
    # Imported lazily so the import cost is paid only when the stage runs.
    from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
    r_mlx.update(_stats(r_mlx["annotation"]))
    results.append(r_mlx)
    _log_result(r_mlx)

    # free MLX before pyannote (we'll reuse the same Python proc)
    del mlx_pipe
    gc.collect()

    # 2. pyannote (MPS if available, else CPU)
    print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
    from pyannote.audio import Pipeline

    pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if torch.backends.mps.is_available():
        try:
            pa_pipe.to(torch.device("mps"))
            print(" device: mps", file=sys.stderr)
        except Exception as e:
            # Best effort: keep benchmarking on whatever device remains.
            print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
    else:
        print(" device: cpu", file=sys.stderr)

    def _run_pa():
        out = pa_pipe(diar_input, **kwargs)
        # Prefer the exclusive diarization when the result exposes one.
        # Compare against None explicitly: `or` would discard an attribute
        # that is present but falsy (e.g. an empty annotation) and return
        # the raw pipeline output instead.
        ann = getattr(out, "exclusive_speaker_diarization", None)
        return out if ann is None else ann

    r_pa = _measure("pyannote", _run_pa)
    r_pa.update(_stats(r_pa["annotation"]))
    results.append(r_pa)
    _log_result(r_pa)

    # 3. cross DER (pyannote output is the reference, MLX the hypothesis)
    der_value = None
    try:
        from pyannote.metrics.diarization import DiarizationErrorRate

        der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
        print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
    except Exception as e:
        # DER is optional; report the failure and still print the table.
        print(f"\nDER computation failed: {e}", file=sys.stderr)

    # Print Markdown table to stdout
    _print_report(results, duration_s, der_value)
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())