Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
"""Benchmark MLX vs pyannote-MPS diarization on the same audio.
|
||
|
||
Usage:
|
||
uv run python scripts/benchmark_diar_backends.py <audio> \
|
||
[--min-speakers N] [--max-speakers M]
|
||
|
||
Runs both backends back-to-back, prints a Markdown table with wall time,
|
||
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import gc
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import librosa
|
||
import numpy as np
|
||
import psutil
|
||
import torch
|
||
|
||
|
||
def _measure(label: str, fn) -> dict:
    """Run ``fn()`` once and record its wall time and process RSS.

    Args:
        label: Name stored unchanged under the ``"label"`` key.
        fn: Zero-argument callable; whatever it returns is stored under
            the ``"annotation"`` key.

    Returns:
        A dict with the keys ``label``, ``wall`` (seconds), ``rss_delta_gb``,
        ``rss_peak_gb`` and ``annotation``. Memory figures are bytes divided
        by 1e9. Despite its name, ``rss_peak_gb`` is the RSS sampled right
        after ``fn()`` returns, not a true high-water mark: memory that was
        allocated and released during the call is not captured.
    """
    proc = psutil.Process()
    # Collect garbage first so leftovers from earlier work are not counted
    # in the baseline RSS sample.
    gc.collect()
    rss_before = proc.memory_info().rss

    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by a system clock adjustment happening during a long run (time.time()
    # offers no such guarantee).
    started = time.perf_counter()
    annotation = fn()
    wall = time.perf_counter() - started

    rss_after = proc.memory_info().rss
    return {
        "label": label,
        "wall": wall,
        "rss_delta_gb": (rss_after - rss_before) / 1e9,
        "rss_peak_gb": rss_after / 1e9,
        "annotation": annotation,
    }
|
||
|
||
|
||
def _stats(annotation) -> dict:
|
||
speakers = sorted(set(annotation.labels()))
|
||
turns = list(annotation.itertracks(yield_label=True))
|
||
total_speech = sum(seg.duration for seg, _, _ in turns)
|
||
# per-speaker totals
|
||
by_speaker = {}
|
||
for seg, _, lab in turns:
|
||
by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
|
||
return {
|
||
"speakers": len(speakers),
|
||
"turns": len(turns),
|
||
"total_speech": total_speech,
|
||
"by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
|
||
}
|
||
|
||
|
||
def _log_backend_summary(result: dict) -> None:
    """Write the one-line wall/speakers/speech/RSS summary of a run to stderr."""
    print(
        f" wall={result['wall']:.1f}s speakers={result['speakers']} "
        f"speech={result['total_speech']:.0f}s "
        f"rss_delta={result['rss_delta_gb']:.2f}GB",
        file=sys.stderr,
    )


def main() -> int:
    """Benchmark the MLX and pyannote diarization backends on one audio file.

    Parses the command line, loads the audio as 16 kHz mono, runs the MLX
    pipeline and then the pyannote pipeline on the same in-memory waveform,
    and computes a diarization error rate with the pyannote result as the
    reference. Progress goes to stderr; the Markdown report goes to stdout.

    Returns:
        0 on completion. Invalid arguments (missing audio file,
        ``--min-speakers`` greater than ``--max-speakers``) terminate through
        ``argparse`` with exit status 2.
    """
    # __doc__ is None when docstrings are stripped (python -OO).
    summary = __doc__.splitlines()[0] if __doc__ else None
    parser = argparse.ArgumentParser(description=summary)
    parser.add_argument("audio")
    parser.add_argument("--min-speakers", type=int, default=10)
    parser.add_argument("--max-speakers", type=int, default=15)
    args = parser.parse_args()

    if args.min_speakers > args.max_speakers:
        parser.error(
            f"--min-speakers ({args.min_speakers}) must not exceed "
            f"--max-speakers ({args.max_speakers})"
        )

    audio_path = Path(args.audio).expanduser().resolve()
    # Fail early with a usage error rather than a traceback from the loader.
    if not audio_path.is_file():
        parser.error(f"audio file not found: {audio_path}")

    sample_rate = 16000
    print(f"Loading {audio_path.name} (sr=16000, mono) ...", file=sys.stderr)
    sig, _ = librosa.load(str(audio_path), sr=sample_rate, mono=True)
    duration_s = len(sig) / sample_rate
    print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)

    # Both backends receive the same in-memory waveform with a leading
    # dimension added in front of the samples.
    diar_input = {
        "waveform": torch.from_numpy(sig).unsqueeze(0),
        "sample_rate": sample_rate,
    }
    kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}

    results = []

    # 1. MLX pure
    print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
    from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
    r_mlx.update(_stats(r_mlx["annotation"]))
    results.append(r_mlx)
    _log_backend_summary(r_mlx)

    # Free the MLX pipeline before loading pyannote: both run in this same
    # Python process.
    del mlx_pipe
    gc.collect()

    # 2. pyannote (MPS if available, else CPU)
    print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
    from pyannote.audio import Pipeline

    pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if torch.backends.mps.is_available():
        try:
            pa_pipe.to(torch.device("mps"))
            print(" device: mps", file=sys.stderr)
        except Exception as e:
            # NOTE(review): nothing moves the pipeline back to CPU here; the
            # "CPU fallback" relies on a failed .to() leaving it there -
            # confirm against pyannote.
            print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
    else:
        print(" device: cpu", file=sys.stderr)

    def _run_pa():
        out = pa_pipe(diar_input, **kwargs)
        # Use the result's `exclusive_speaker_diarization` when it exists,
        # otherwise the result itself. The test is an explicit `is None`
        # because an annotation without segments is falsy: `... or out`
        # would then hand back the wrapper object instead of the annotation.
        exclusive = getattr(out, "exclusive_speaker_diarization", None)
        return out if exclusive is None else exclusive

    r_pa = _measure("pyannote", _run_pa)
    r_pa.update(_stats(r_pa["annotation"]))
    results.append(r_pa)
    _log_backend_summary(r_pa)

    # 3. cross DER - best effort: the report is still printed when
    # pyannote.metrics is missing or the computation fails.
    der_value = None
    try:
        from pyannote.metrics.diarization import DiarizationErrorRate

        # Called as (reference, hypothesis): pyannote is the reference.
        der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
        print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
    except Exception as e:
        print(f"\nDER computation failed: {e}", file=sys.stderr)

    # Print Markdown table to stdout
    print()
    print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
    print("|---|---:|---:|---:|---:|---:|---:|")
    for r in results:
        # Realtime factor: seconds of audio processed per second of wall time.
        rt = duration_s / r["wall"] if r["wall"] > 0 else 0
        print(
            f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
            f"{r['speakers']} | {r['turns']} | "
            f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
        )
    print()
    if der_value is not None:
        print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")

    print()
    print("### Per-speaker speech time")
    for r in results:
        print(f"\n**{r['label']}** ({r['speakers']} speakers):")
        # by_speaker is already ordered longest-first; show the top ten.
        for sp, dur in list(r["by_speaker"].items())[:10]:
            print(f" {sp}: {dur:.0f}s")

    return 0
|
||
|
||
|
||
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|