Files
pyannote-speaker-diarizatio…/tests/unit/test_diar_audio_fbank.py
transcrilive 2b1a3c1312 feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical
at 6 decimals on 200 cross-window slot pairs). 2.5x faster than
pyannote-MPS on Apple Silicon native.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
2026-05-09 16:05:39 +02:00

60 lines
1.8 KiB
Python

import numpy as np
import mlx.core as mx
import torch
from torchaudio.compliance import kaldi as ta_kaldi
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
sig = (
0.5 * np.sin(2 * np.pi * 220 * t)
+ 0.3 * np.sin(2 * np.pi * 880 * t)
).astype(np.float32)
return sig
def test_fbank_matches_torchaudio_within_1pct():
    """Compare the MLX ``kaldi_fbank`` with torchaudio's Kaldi-compliant fbank.

    Both implementations are called with the same front-end settings:
    80 mel bins, 25 ms frames, 10 ms shift, no dither, a Hamming window,
    no energy term and a 16 kHz sample rate.

    The torchaudio reference receives the signal multiplied by ``1 << 15``,
    while the MLX side receives the raw float signal. This test presumably
    relies on ``kaldi_fbank`` applying that scaling internally (TODO confirm
    against its implementation). No mean normalisation is applied on either
    side here.

    NOTE(review): despite the "1pct" in the name, the check is an absolute
    bound: the largest element-wise difference must be below 0.05.
    """
    waveform = _fixed_signal()

    # Reference side: torchaudio expects a (channels, samples) tensor.
    # Scale it to the int16 range, following the settings pyannote's
    # WeSpeaker front-end uses.
    int16_scale = 1 << 15
    reference_input = torch.from_numpy(waveform)[None, :] * int16_scale
    expected = ta_kaldi.fbank(
        reference_input,
        sample_frequency=16000,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        window_type="hamming",
        dither=0.0,
        use_energy=False,
    ).numpy()  # (num_frames, 80)

    # Candidate side: the MLX port, fed the unscaled float32 waveform.
    actual = np.asarray(
        kaldi_fbank(
            mx.array(waveform),
            sample_rate=16000,
            num_mel_bins=80,
            frame_length_ms=25,
            frame_shift_ms=10,
            window_type="hamming",
            dither=0.0,
            use_energy=False,
        )
    )

    assert actual.shape == expected.shape

    # Dither is disabled on both sides, so the comparison is deterministic.
    max_abs_err = np.abs(actual - expected).max()
    assert max_abs_err < 0.05, f"max abs diff {max_abs_err:.4f}"
def test_load_waveform_resamples_to_16k():
    """Check that ``load_waveform`` resamples a 22.05 kHz WAV file to 16 kHz.

    A one-second two-tone signal is written to disk at 22050 Hz and loaded
    back. The loaded array's last axis must hold exactly 16000 samples, and
    its dtype must be ``mx.float32``.
    """
    import os
    import tempfile

    import soundfile as sf

    # Use a TemporaryDirectory instead of NamedTemporaryFile(delete=False)
    # plus a manual os.unlink, for two reasons:
    # - cleanup happens even if writing or loading raises, so no temp file
    #   is leaked;
    # - the test never reopens or unlinks a file that is still held open,
    #   which fails on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "tone_22050.wav")
        sf.write(wav_path, _fixed_signal(seconds=1.0, sr=22050), 22050)
        wav_mx = load_waveform(wav_path)

    assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
    assert wav_mx.dtype == mx.float32