Files
pyannote-speaker-diarizatio…/tests/integration/test_diar_60s_smoke.py
transcrilive 2b1a3c1312 feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Numerical parity with the pyannote PyTorch reference (cosine similarity
0.763718, identical to 6 decimal places, on 200 cross-window slot pairs).
2.5x faster than pyannote on MPS when running natively on Apple Silicon.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
2026-05-09 16:05:39 +02:00

55 lines
2.0 KiB
Python

"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
import time
import numpy as np
import librosa
import pytest
import soundfile as sf
import torch
from pyannote.audio import Pipeline
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
@pytest.mark.integration
def test_diar_60s_parity_vs_pyannote():
    """Gate the MLX diarization port against the pyannote PyTorch reference.

    Runs ``MlxDiarizationPipeline`` and ``pyannote/speaker-diarization-3.1``
    on the same 60 s, 16 kHz slice of a local test recording and checks:

    * the number of distinct speaker labels differs by at most one;
    * the DER of the MLX output scored against the reference is <= 0.30;
    * a second (already warmed-up) MLX run takes under 30 s of wall time.

    The test is skipped, not failed, when the source recording or the
    reference pipeline is unavailable. Those are environment problems and
    say nothing about the port itself.
    """
    from pathlib import Path

    from pyannote.metrics.diarization import DiarizationErrorRate

    source_path = Path("/tmp/audio_first_3min.wav")
    # The slice is also written here for manual inspection only. Nothing
    # below reads this file; both pipelines consume the in-memory waveform.
    audio_path = "/tmp/_diar_smoke_60s.wav"
    sample_rate = 16000
    # The same bounds are used for the parity run and the timed run, so the
    # timing gate measures the configuration that was actually validated.
    speaker_bounds = {"min_speakers": 1, "max_speakers": 3}

    if not source_path.is_file():
        pytest.skip(f"test audio not found: {source_path}")

    # use any 60s slice of the existing test audio
    sig, _ = librosa.load(str(source_path), sr=sample_rate, duration=60)
    sf.write(audio_path, sig, sample_rate)

    def make_input():
        # Build a fresh mapping per call rather than sharing one dict between
        # pipelines. Do not assume a pipeline leaves its input dict untouched.
        # unsqueeze(0) adds the leading channel axis: (1, num_samples).
        return {"waveform": torch.from_numpy(sig).unsqueeze(0),
                "sample_rate": sample_rate}

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if ref_pipe is None:
        # NOTE(review): some pyannote.audio versions return None instead of
        # raising when the gated checkpoint cannot be fetched (e.g. no
        # Hugging Face token). Confirm this against the pinned version.
        pytest.skip("could not load pyannote/speaker-diarization-3.1 reference")

    # MLX pipeline
    mlx_ann = mlx_pipe(make_input(), **speaker_bounds)
    mlx_speakers = set(mlx_ann.labels())

    # pyannote PyTorch reference. 3.x returns the annotation directly. Other
    # versions presumably wrap it in an object exposing
    # `exclusive_speaker_diarization`, which is what the original check relied on.
    # NOTE(review): the two branches may not be the same kind of reference
    # (exclusive vs. possibly overlapping turns). Confirm which one the MLX
    # port should be scored against.
    ref_out = ref_pipe(make_input(), **speaker_bounds)
    ref_ann = getattr(ref_out, "exclusive_speaker_diarization", ref_out)
    ref_speakers = set(ref_ann.labels())

    # gate: speaker count within ±1
    assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
        f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"

    # gate: DER <= 0.30. The metric maps speaker labels optimally
    # (Hungarian) before scoring. Argument order is (reference, hypothesis).
    der = DiarizationErrorRate()
    der_value = der(ref_ann, mlx_ann)
    assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"

    # gate: wall-clock under 30s (MLX should be fast on M2/M3). The pipeline
    # was already loaded and run once above, so model loading and first-run
    # warm-up are excluded. perf_counter is monotonic, unlike time.time.
    t0 = time.perf_counter()
    mlx_pipe(make_input(), **speaker_bounds)
    wall = time.perf_counter() - t0
    assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"