# Byte-parity with the pyannote PyTorch reference: the cosine similarity
# (0.763718) is identical to 6 decimal places across 200 cross-window slot pairs.
# 2.5x faster than pyannote-MPS when running natively on Apple Silicon.
# Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
# (Listing metadata: 60 lines, 1.8 KiB, Python.)
import numpy as np
|
|
import mlx.core as mx
|
|
import torch
|
|
from torchaudio.compliance import kaldi as ta_kaldi
|
|
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
|
|
|
|
|
|
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
|
|
t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
|
|
sig = (
|
|
0.5 * np.sin(2 * np.pi * 220 * t)
|
|
+ 0.3 * np.sin(2 * np.pi * 880 * t)
|
|
).astype(np.float32)
|
|
return sig
|
|
|
|
|
|
def test_fbank_matches_torchaudio_within_1pct():
    """Compare the MLX ``kaldi_fbank`` against torchaudio's Kaldi-compliant fbank.

    Both sides use the front-end parameters the original test labels as
    pyannote WeSpeaker settings: 80 mel bins, 25 ms frames with a 10 ms shift,
    no dither, Hamming window, no energy term, 16 kHz.

    Despite "1pct" in the test name, the check is absolute. The largest
    element-wise difference between the two feature matrices must be below 0.05.
    """
    sig = _fixed_signal()

    # Reference: torchaudio takes a (channels, samples) tensor; the waveform is
    # scaled by 2**15 (1 << 15) before feature extraction.
    sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
    ref = ta_kaldi.fbank(
        sig_torch,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_frequency=16000,
    ).numpy()  # (T, 80)

    # Implementation under test. It receives the *unscaled* signal, and the
    # reference above has no mean normalization (CMN). NOTE(review): the
    # comparison therefore presumably relies on kaldi_fbank applying the
    # 2**15 scaling internally and returning features without CMN; confirm
    # against kaldi_fbank.
    out = kaldi_fbank(
        mx.array(sig),
        num_mel_bins=80,
        frame_length_ms=25,
        frame_shift_ms=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_rate=16000,
    )
    out_np = np.asarray(out)

    assert out_np.shape == ref.shape, (
        f"shape {out_np.shape} != reference {ref.shape}"
    )

    # dither=0.0 on both sides makes the features deterministic, so a tight
    # absolute tolerance is meaningful.
    diff = np.abs(out_np - ref).max()
    assert diff < 0.05, f"max abs diff {diff:.4f}"
|
|
|
|
|
|
def test_load_waveform_resamples_to_16k():
    """Write a 1 s WAV at 22.05 kHz and check ``load_waveform`` returns 16 kHz audio.

    The file is created inside a ``TemporaryDirectory``, which gives two
    guarantees. First, the file is removed even if ``load_waveform`` raises.
    Second, it is never reopened by name while another handle holds it
    open, which a ``NamedTemporaryFile`` does not allow on Windows.
    """
    import os
    import tempfile

    import soundfile as sf

    src_sr = 22050
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "tone_22050.wav")
        sf.write(path, _fixed_signal(seconds=1.0, sr=src_sr), src_sr)
        wav_mx = load_waveform(path)

    assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
    assert wav_mx.dtype == mx.float32
|