feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
59
tests/unit/test_diar_audio_fbank.py
Normal file
59
tests/unit/test_diar_audio_fbank.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from torchaudio.compliance import kaldi as ta_kaldi
|
||||
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
|
||||
|
||||
|
||||
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
    """Return a deterministic two-tone test signal as a float32 NumPy array.

    The signal is the sum of a 220 Hz sine at amplitude 0.5 and an 880 Hz
    sine at amplitude 0.3, sampled at ``sr`` Hz for ``seconds`` seconds.

    Args:
        seconds: Duration of the signal in seconds.
        sr: Sampling rate in Hz.

    Returns:
        A 1-D ``np.float32`` array holding ``int(seconds * sr)`` samples.
    """
    n_samples = int(seconds * sr)
    # endpoint=False keeps the sample spacing at exactly 1 / sr.
    timeline = np.linspace(0, seconds, n_samples, endpoint=False)
    low_tone = 0.5 * np.sin(2 * np.pi * 220 * timeline)
    high_tone = 0.3 * np.sin(2 * np.pi * 880 * timeline)
    return (low_tone + high_tone).astype(np.float32)
|
||||
|
||||
|
||||
def test_fbank_matches_torchaudio_within_1pct():
    """Compare ``kaldi_fbank`` against torchaudio's Kaldi-compliant fbank.

    Both implementations are run on the same synthetic two-tone signal with
    80 mel bins, 25 ms frames, 10 ms shift, a Hamming window, no dither and
    no energy feature.  The outputs must have the same shape and a maximum
    absolute element-wise difference below 0.05.

    NOTE(review): the test name says "1pct" but the check is an absolute
    tolerance of 0.05, not a relative one — confirm which is intended.
    """
    sig = _fixed_signal()

    # Options whose keyword names are identical in both implementations.
    shared_opts = dict(
        num_mel_bins=80,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
    )

    # torchaudio reference: same params as pyannote WeSpeaker.
    # The waveform gets a leading channel axis and is scaled by 2**15.
    int16_scale = 1 << 15
    reference_input = torch.from_numpy(sig).unsqueeze(0) * int16_scale
    reference_feats = ta_kaldi.fbank(
        reference_input,
        frame_length=25,
        frame_shift=10,
        sample_frequency=16000,
        **shared_opts,
    )
    ref = reference_feats.numpy()  # (T, 80)

    # Our MLX implementation.  The raw float signal is passed unscaled;
    # presumably kaldi_fbank applies the equivalent scaling (and CMN handling)
    # internally — TODO confirm against the audio module.
    mlx_feats = kaldi_fbank(
        mx.array(sig),
        frame_length_ms=25,
        frame_shift_ms=10,
        sample_rate=16000,
        **shared_opts,
    )
    out_np = np.asarray(mlx_feats)

    assert out_np.shape == ref.shape

    # Max abs diff should be small (kaldi-compliant, no random init).
    diff = np.max(np.abs(out_np - ref))
    assert diff < 0.05, f"max abs diff {diff:.4f}"
|
||||
|
||||
|
||||
def test_load_waveform_resamples_to_16k():
    """Check that ``load_waveform`` yields 16000 float32 samples for a 1 s clip.

    A one-second synthetic signal is written to disk at 22050 Hz, loaded back
    through ``load_waveform``, and the result is expected to have 16000
    samples along its last axis with dtype ``mx.float32``.
    """
    import os
    import tempfile

    import soundfile as sf

    # A private temporary directory is removed when the ``with`` block exits,
    # even if writing, loading, or an assertion raises, so no stray .wav file
    # is left behind.  It also avoids writing by path to a file that is still
    # held open by another handle.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "tone_22050.wav")
        sf.write(wav_path, _fixed_signal(seconds=1.0, sr=22050), 22050)
        wav_mx = load_waveform(wav_path)

        assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
        assert wav_mx.dtype == mx.float32
|
||||
Reference in New Issue
Block a user