feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
54
tests/integration/test_diar_60s_smoke.py
Normal file
54
tests/integration/test_diar_60s_smoke.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
|
||||
import time
|
||||
import numpy as np
|
||||
import librosa
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_diar_60s_parity_vs_pyannote():
|
||||
audio_path = "/tmp/_diar_smoke_60s.wav"
|
||||
# use any 60s slice of the existing test audio
|
||||
sig, _ = librosa.load("/tmp/audio_first_3min.wav", sr=16000, duration=60)
|
||||
sf.write(audio_path, sig, 16000)
|
||||
|
||||
# MLX pipeline
|
||||
mlx_pipe = MlxDiarizationPipeline.from_pretrained()
|
||||
mlx_ann = mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||
"sample_rate": 16000},
|
||||
min_speakers=1, max_speakers=3)
|
||||
mlx_speakers = set(mlx_ann.labels())
|
||||
|
||||
# pyannote PyTorch reference
|
||||
ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
|
||||
ref_out = ref_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||
"sample_rate": 16000},
|
||||
min_speakers=1, max_speakers=3)
|
||||
# pyannote 3.x returns the annotation directly
|
||||
if hasattr(ref_out, "exclusive_speaker_diarization"):
|
||||
ref_ann = ref_out.exclusive_speaker_diarization
|
||||
else:
|
||||
ref_ann = ref_out
|
||||
ref_speakers = set(ref_ann.labels())
|
||||
|
||||
# gate: speaker count within ±1
|
||||
assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
|
||||
f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"
|
||||
|
||||
# gate: DER < 0.30 (Hungarian-aligned)
|
||||
from pyannote.metrics.diarization import DiarizationErrorRate
|
||||
der = DiarizationErrorRate()
|
||||
der_value = der(ref_ann, mlx_ann)
|
||||
assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"
|
||||
|
||||
# gate: wall-clock under 30s (MLX should be fast on M2/M3)
|
||||
t0 = time.time()
|
||||
mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||
"sample_rate": 16000})
|
||||
wall = time.time() - t0
|
||||
assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"
|
||||
Reference in New Issue
Block a user