feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
161
scripts/bench.py
Normal file
161
scripts/bench.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Benchmark MLX vs pyannote-MPS diarization on the same audio.
|
||||
|
||||
Usage:
|
||||
uv run python scripts/bench.py <audio> \
|
||||
[--min-speakers N] [--max-speakers M]
|
||||
|
||||
Runs both backends back-to-back, prints a Markdown table with wall time,
|
||||
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
|
||||
def _measure(label: str, fn) -> dict:
    """Run ``fn()`` once and record wall time and process RSS around the call.

    Args:
        label: Name stored under the ``"label"`` key of the result.
        fn: Zero-argument callable. Its return value is stored under the
            ``"annotation"`` key.

    Returns:
        A dict with the keys ``label``, ``wall`` (elapsed seconds),
        ``rss_delta_gb``, ``rss_peak_gb`` and ``annotation``.
    """
    proc = psutil.Process()
    # Collect garbage first so the "before" sample is not inflated by
    # objects that are already unreachable.
    gc.collect()
    rss_before = proc.memory_info().rss

    # perf_counter() is monotonic; time.time() can jump if the system clock
    # is adjusted while the benchmark is running.
    t0 = time.perf_counter()
    annotation = fn()
    wall = time.perf_counter() - t0

    # Sampled once, after fn() has returned: this is the resident set size at
    # that moment, not a true high-water mark. The "rss_peak_gb" key name is
    # kept so existing consumers of the result dict keep working.
    rss_after = proc.memory_info().rss

    return {
        "label": label,
        "wall": wall,
        "rss_delta_gb": (rss_after - rss_before) / 1e9,
        "rss_peak_gb": rss_after / 1e9,
        "annotation": annotation,
    }
|
||||
|
||||
|
||||
def _stats(annotation) -> dict:
|
||||
speakers = sorted(set(annotation.labels()))
|
||||
turns = list(annotation.itertracks(yield_label=True))
|
||||
total_speech = sum(seg.duration for seg, _, _ in turns)
|
||||
# per-speaker totals
|
||||
by_speaker = {}
|
||||
for seg, _, lab in turns:
|
||||
by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
|
||||
return {
|
||||
"speakers": len(speakers),
|
||||
"turns": len(turns),
|
||||
"total_speech": total_speech,
|
||||
"by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
    """Benchmark the MLX pipeline against pyannote on one audio file.

    Loads the audio as 16 kHz mono, runs the MLX pipeline and then the
    pyannote pipeline on the same in-memory waveform, computes the
    diarization error rate of the MLX output against the pyannote output,
    and prints a Markdown report to stdout. Progress messages go to stderr.

    Returns:
        0 on completion. Argument errors exit through ``argparse``.
    """
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("audio")
    # NOTE(review): these defaults constrain every run to 10-15 speakers
    # unless overridden on the command line — confirm that is intended.
    parser.add_argument("--min-speakers", type=int, default=10)
    parser.add_argument("--max-speakers", type=int, default=15)
    args = parser.parse_args()

    audio_path = Path(args.audio).expanduser().resolve()
    if not audio_path.is_file():
        # Fail with a clear CLI error instead of a loader traceback.
        parser.error(f"audio file not found: {audio_path}")

    print(f"Loading {audio_path.name} (sr=16000, mono) ...", file=sys.stderr)
    sig, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    duration_s = len(sig) / 16000
    print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)

    # Both backends receive the same in-memory waveform, so decoding and
    # resampling are excluded from the timed section.
    diar_input = {
        "waveform": torch.from_numpy(sig).unsqueeze(0),
        "sample_rate": 16000,
    }
    kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}

    def _log_summary(r: dict) -> None:
        # One-line progress summary for a finished backend run.
        print(
            f" wall={r['wall']:.1f}s speakers={r['speakers']} "
            f"speech={r['total_speech']:.0f}s "
            f"rss_delta={r['rss_delta_gb']:.2f}GB",
            file=sys.stderr,
        )

    results = []

    # 1. MLX pure
    print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
    from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
    r_mlx.update(_stats(r_mlx["annotation"]))
    results.append(r_mlx)
    _log_summary(r_mlx)

    # Free MLX before pyannote (both run in the same Python process).
    del mlx_pipe
    gc.collect()

    # 2. pyannote (MPS if available, else CPU)
    print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
    from pyannote.audio import Pipeline

    pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if torch.backends.mps.is_available():
        try:
            pa_pipe.to(torch.device("mps"))
            print(" device: mps", file=sys.stderr)
        except Exception as e:
            print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
            # A failed move may leave part of the pipeline on MPS; move it
            # back so the announced CPU fallback actually holds.
            pa_pipe.to(torch.device("cpu"))
    else:
        print(" device: cpu", file=sys.stderr)

    def _run_pa():
        out = pa_pipe(diar_input, **kwargs)
        # Prefer the exclusive diarization when the output object carries
        # one. Test against None explicitly: `x or out` would also discard
        # a present-but-falsy value (such as an annotation with no tracks)
        # and hand the raw pipeline output to _stats instead.
        exclusive = getattr(out, "exclusive_speaker_diarization", None)
        return out if exclusive is None else exclusive

    r_pa = _measure("pyannote", _run_pa)
    r_pa.update(_stats(r_pa["annotation"]))
    results.append(r_pa)
    _log_summary(r_pa)

    # 3. cross DER — pyannote output is the reference, MLX the hypothesis.
    der_value = None
    try:
        from pyannote.metrics.diarization import DiarizationErrorRate

        der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
        print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
    except Exception as e:
        # Best effort: the timing report is still useful without the DER.
        print(f"\nDER computation failed: {e}", file=sys.stderr)

    # Print Markdown table to stdout
    print()
    print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
    print("|---|---:|---:|---:|---:|---:|---:|")
    for r in results:
        # Realtime factor: seconds of audio processed per second of wall time.
        rt = duration_s / r["wall"] if r["wall"] > 0 else 0
        print(
            f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
            f"{r['speakers']} | {r['turns']} | "
            f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
        )
    print()
    if der_value is not None:
        print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")

    print()
    print("### Per-speaker speech time")
    for r in results:
        print(f"\n**{r['label']}** ({r['speakers']} speakers):")
        # by_speaker is already ordered by descending duration; show the top 10.
        for sp, dur in list(r["by_speaker"].items())[:10]:
            print(f" {sp}: {dur:.0f}s")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
|
||||
52
scripts/install_remote.sh
Executable file
52
scripts/install_remote.sh
Executable file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env bash
# Create (or reuse) a uv project and install
# pyannote-speaker-diarization-3.1-mlx into it, then smoke-test the import.
#
# Usage: install_remote.sh [install-dir]
set -euo pipefail

# Kept separate from INSTALL_DIR so the help text always shows the real
# default, even when the first argument is -h/--help.
DEFAULT_INSTALL_DIR="$HOME/pyannote-diarization-3.1-mlx-test"
HTTPS_SPEC="pyannote-speaker-diarization-3.1-mlx @ git+https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
SSH_SPEC="git+ssh://git@gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"

usage() {
cat <<EOF
Usage:
$0 [install-dir]

Creates a uv project and installs pyannote-speaker-diarization-3.1-mlx.
Default install directory:
$DEFAULT_INSTALL_DIR
EOF
}

# Handle help before treating the first argument as a directory.
if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
  usage
  exit 0
fi

INSTALL_DIR="${1:-$DEFAULT_INSTALL_DIR}"
# Expand a leading literal "~" (e.g. when the argument was quoted).
INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}"

if ! command -v uv >/dev/null 2>&1; then
  cat >&2 <<'EOF'
uv is required but was not found.

Install it with:
curl -LsSf https://astral.sh/uv/install.sh | sh

Then restart your shell and run this script again.
EOF
  exit 1
fi

mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR"

if [[ ! -f pyproject.toml ]]; then
  uv init --python 3.12
else
  echo "Found existing pyproject.toml in $INSTALL_DIR; skipping uv init."
fi

echo "Installing from HTTPS..."
if ! uv add "$HTTPS_SPEC"; then
  echo "HTTPS install failed; falling back to SSH pip install..."
  # NOTE(review): this installs into the environment without recording the
  # dependency in pyproject.toml — confirm that is acceptable for the fallback.
  uv pip install "$SSH_SPEC"
fi

# Smoke test: the package must be importable from the project environment.
uv run python -c "from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline; print('OK')"
|
||||
Reference in New Issue
Block a user