Files
transcrilive 2b1a3c1312 feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical
at 6 decimals on 200 cross-window slot pairs). 2.5x faster than
pyannote-MPS on Apple Silicon native.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
2026-05-09 16:05:39 +02:00

162 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Benchmark MLX vs pyannote-MPS diarization on the same audio.
Usage:
uv run python scripts/benchmark_diar_backends.py <audio> \
[--min-speakers N] [--max-speakers M]
Runs both backends back-to-back, prints a Markdown table with wall time,
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
"""
from __future__ import annotations
import argparse
import gc
import sys
import time
from pathlib import Path
import librosa
import numpy as np
import psutil
import torch
def _measure(label: str, fn) -> dict:
    """Run ``fn()`` once and record its wall time and resident-memory change.

    Args:
        label: Name stored under the ``"label"`` key of the result.
        fn: Zero-argument callable to benchmark. Its return value is stored
            under the ``"annotation"`` key.

    Returns:
        Dict with keys ``label``, ``wall`` (elapsed seconds), ``rss_delta_gb``
        (RSS after minus RSS before, in GB), ``rss_peak_gb`` and
        ``annotation``.

    Note:
        Despite its name, ``rss_peak_gb`` is the process RSS sampled right
        *after* ``fn()`` returns, not a high-water mark. Memory allocated and
        released inside ``fn()`` is not reflected in either RSS figure.
    """
    proc = psutil.Process()
    # Collect leftover garbage first so the "before" sample is not inflated
    # by objects from earlier work.
    gc.collect()
    rss_before = proc.memory_info().rss
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock and can jump (e.g. NTP adjustments), so it is unsuitable
    # for interval timing.
    t0 = time.perf_counter()
    annotation = fn()
    wall = time.perf_counter() - t0
    rss_after = proc.memory_info().rss
    return {
        "label": label,
        "wall": wall,
        "rss_delta_gb": (rss_after - rss_before) / 1e9,
        "rss_peak_gb": rss_after / 1e9,
        "annotation": annotation,
    }
def _stats(annotation) -> dict:
speakers = sorted(set(annotation.labels()))
turns = list(annotation.itertracks(yield_label=True))
total_speech = sum(seg.duration for seg, _, _ in turns)
# per-speaker totals
by_speaker = {}
for seg, _, lab in turns:
by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
return {
"speakers": len(speakers),
"turns": len(turns),
"total_speech": total_speech,
"by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
}
def _log_run(result: dict) -> None:
    """Print a one-line summary of a finished backend run to stderr."""
    print(
        f" wall={result['wall']:.1f}s speakers={result['speakers']} "
        f"speech={result['total_speech']:.0f}s "
        f"rss_delta={result['rss_delta_gb']:.2f}GB",
        file=sys.stderr,
    )


def _print_markdown_report(
    results: list[dict],
    duration_s: float,
    der_value: float | None,
    top_speakers: int = 10,
) -> None:
    """Write the comparison table, cross-DER and per-speaker times to stdout.

    Args:
        results: Per-backend dicts (``_measure`` output merged with
            ``_stats`` output), in the order they should appear.
        duration_s: Audio duration in seconds, used for the realtime factor.
        der_value: Cross-DER to print, or ``None`` to omit that line.
        top_speakers: Maximum number of speakers listed per backend. Since
            ``by_speaker`` is sorted longest-first, these are the
            ``top_speakers`` most talkative speakers.
    """
    print()
    print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
    print("|---|---:|---:|---:|---:|---:|---:|")
    for r in results:
        # Realtime factor: seconds of audio processed per second of wall time.
        rt = duration_s / r["wall"] if r["wall"] > 0 else 0
        print(
            f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
            f"{r['speakers']} | {r['turns']} | "
            f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
        )
    print()
    if der_value is not None:
        print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")
        print()
    print("### Per-speaker speech time")
    for r in results:
        print(f"\n**{r['label']}** ({r['speakers']} speakers):")
        for sp, dur in list(r["by_speaker"].items())[:top_speakers]:
            print(f" {sp}: {dur:.0f}s")


def main() -> int:
    """Benchmark the MLX and pyannote diarization backends on one audio file.

    The audio is loaded once as 16 kHz mono and fed to both pipelines as the
    same in-memory waveform: the MLX pipeline first, then pyannote (on MPS
    when available). The cross-DER is computed with pyannote as the
    reference. Progress goes to stderr and a Markdown report to stdout.

    Returns:
        Process exit code: 0 on success, 1 if the pyannote pipeline could not
        be loaded.
    """
    sr = 16000  # sample rate used for loading and passed to both pipelines

    # `__doc__` is None under `python -OO`; fall back to no description.
    description = (__doc__ or "").strip().partition("\n")[0] or None
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument("audio", help="path to the audio file to diarize")
    parser.add_argument(
        "--min-speakers",
        type=int,
        default=10,
        help="minimum number of speakers (default: %(default)s)",
    )
    parser.add_argument(
        "--max-speakers",
        type=int,
        default=15,
        help="maximum number of speakers (default: %(default)s)",
    )
    args = parser.parse_args()

    audio_path = Path(args.audio).expanduser().resolve()
    print(f"Loading {audio_path.name} (sr={sr}, mono) ...", file=sys.stderr)
    sig, _ = librosa.load(str(audio_path), sr=sr, mono=True)
    duration_s = len(sig) / sr
    print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)
    diar_input = {
        "waveform": torch.from_numpy(sig).unsqueeze(0),  # add channel dim
        "sample_rate": sr,
    }
    kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}
    results = []

    # 1. MLX pure
    print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
    from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
    r_mlx.update(_stats(r_mlx["annotation"]))
    results.append(r_mlx)
    _log_run(r_mlx)
    # Release the MLX pipeline before loading pyannote: both backends run in
    # this same Python process.
    del mlx_pipe
    gc.collect()

    # 2. pyannote (MPS if available, else CPU)
    print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
    from pyannote.audio import Pipeline

    pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if pa_pipe is None:
        # Some pyannote versions return None instead of raising when the
        # (gated) model cannot be fetched. Without this check the MPS `try`
        # below would swallow the AttributeError and the run would crash.
        print(
            " error: could not load pyannote/speaker-diarization-3.1 "
            "(gated model? check your Hugging Face token)",
            file=sys.stderr,
        )
        return 1
    if torch.backends.mps.is_available():
        try:
            pa_pipe.to(torch.device("mps"))
            print(" device: mps", file=sys.stderr)
        except Exception as e:
            print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
    else:
        print(" device: cpu", file=sys.stderr)

    def _run_pa():
        out = pa_pipe(diar_input, **kwargs)
        # Newer pyannote wraps the result in an object carrying
        # `exclusive_speaker_diarization`; older versions return the
        # Annotation directly. Compare against None explicitly: an empty
        # Annotation is falsy, and `or` would then return the wrapper.
        ann = getattr(out, "exclusive_speaker_diarization", None)
        return out if ann is None else ann

    r_pa = _measure("pyannote", _run_pa)
    r_pa.update(_stats(r_pa["annotation"]))
    results.append(r_pa)
    _log_run(r_pa)

    # 3. cross DER (pyannote output is the reference, MLX the hypothesis).
    # Best-effort: a metric failure should not discard the timing report.
    der_value = None
    try:
        from pyannote.metrics.diarization import DiarizationErrorRate

        der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
        print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
    except Exception as e:
        print(f"\nDER computation failed: {e}", file=sys.stderr)

    _print_markdown_report(results, duration_s, der_value)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())