Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
55 lines
2.0 KiB
Python
55 lines
2.0 KiB
Python
"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
|
|
import time
|
|
import numpy as np
|
|
import librosa
|
|
import pytest
|
|
import soundfile as sf
|
|
import torch
|
|
from pyannote.audio import Pipeline
|
|
|
|
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
|
|
|
|
|
|
@pytest.mark.integration
def test_diar_60s_parity_vs_pyannote():
    """Gate the MLX diarization pipeline against the pyannote PyTorch reference.

    Loads up to 60 s of ``/tmp/audio_first_3min.wav`` resampled to 16 kHz,
    runs both pipelines on the same in-memory waveform with 1 to 3 speakers
    allowed, and asserts three gates:

    * the number of speaker labels differs by at most one,
    * the DER of the MLX output scored against the reference is <= 0.30,
    * one additional MLX run (without speaker bounds) takes under 30 s.

    Raises:
        AssertionError: if any of the three gates fails.
    """
    sample_rate = 16000
    audio_path = "/tmp/_diar_smoke_60s.wav"

    # use any 60s slice of the existing test audio
    sig, _ = librosa.load("/tmp/audio_first_3min.wav", sr=sample_rate, duration=60)
    # NOTE(review): this clip is written but never read back in this test;
    # presumably kept for manual inspection -- confirm before removing.
    sf.write(audio_path, sig, sample_rate)

    # Built once and shared by every run; torch.from_numpy already aliases
    # `sig`, so all runs see the same samples either way.
    # assumes the pipelines expect a (channel, time) tensor -- TODO confirm
    waveform = torch.from_numpy(sig).unsqueeze(0)

    def make_input():
        # A fresh mapping per call, so nothing a pipeline stores in the dict
        # can carry over into the next run.
        return {"waveform": waveform, "sample_rate": sample_rate}

    # MLX pipeline
    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    mlx_ann = mlx_pipe(make_input(), min_speakers=1, max_speakers=3)
    mlx_speakers = set(mlx_ann.labels())

    # pyannote PyTorch reference
    ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    ref_out = ref_pipe(make_input(), min_speakers=1, max_speakers=3)
    # pyannote 3.x returns the annotation directly; an output object that
    # exposes `exclusive_speaker_diarization` is unwrapped to that attribute.
    if hasattr(ref_out, "exclusive_speaker_diarization"):
        ref_ann = ref_out.exclusive_speaker_diarization
    else:
        ref_ann = ref_out
    ref_speakers = set(ref_ann.labels())

    # gate: speaker count within ±1
    assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
        f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"

    # gate: DER <= 0.30 (Hungarian-aligned)
    from pyannote.metrics.diarization import DiarizationErrorRate
    der = DiarizationErrorRate()
    der_value = der(ref_ann, mlx_ann)  # reference first, hypothesis second
    assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"

    # gate: wall-clock under 30s (MLX should be fast on M2/M3)
    # perf_counter is monotonic, so a system clock adjustment during the run
    # cannot distort the measured duration the way time.time() could.
    t0 = time.perf_counter()
    mlx_pipe(make_input())
    wall = time.perf_counter() - t0
    assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"
|