"""Numerical-parity tests for the MLX Kaldi fbank and the waveform loader."""

import os
import tempfile

import mlx.core as mx
import numpy as np
import torch
from torchaudio.compliance import kaldi as ta_kaldi

from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform


def _fixed_signal(seconds: float = 3.0, sr: int = 16000) -> np.ndarray:
    """Return a deterministic two-tone test signal.

    The signal is ``0.5*sin(2*pi*220*t) + 0.3*sin(2*pi*880*t)`` sampled at
    *sr* Hz for *seconds* seconds, returned as a 1-D float32 array of
    ``int(seconds * sr)`` samples. Its peak magnitude is at most 0.8.
    """
    t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
    sig = 0.5 * np.sin(2 * np.pi * 220 * t) + 0.3 * np.sin(2 * np.pi * 880 * t)
    return sig.astype(np.float32)


def test_fbank_matches_torchaudio_within_1pct():
    """Compare ``kaldi_fbank`` against torchaudio's Kaldi-compliant fbank.

    Both sides use the same parameters (80 mel bins, 25 ms frames with a
    10 ms shift, Hamming window, no dither, no energy term, 16 kHz), so the
    outputs should agree up to floating-point differences.
    """
    sig = _fixed_signal()

    # torchaudio reference: same params as pyannote WeSpeaker. The waveform
    # is scaled by 2**15 (int16 range) before computing the fbank.
    sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
    ref = ta_kaldi.fbank(
        sig_torch,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_frequency=16000,
    ).numpy()  # (T, 80)

    # Our MLX implementation is given the *unscaled* signal.
    # NOTE(review): this presumes kaldi_fbank applies the 2**15 scaling
    # internally and does no mean normalisation of its own (the torchaudio
    # call above leaves subtract_mean at its default) — confirm against the
    # kaldi_fbank implementation.
    sig_mx = mx.array(sig)
    out = kaldi_fbank(
        sig_mx,
        num_mel_bins=80,
        frame_length_ms=25,
        frame_shift_ms=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_rate=16000,
    )
    out_np = np.asarray(out)

    assert out_np.shape == ref.shape
    # dither=0.0 makes both sides deterministic, so a tight bound is meaningful.
    # NOTE(review): despite the test name, this is an absolute max-diff bound,
    # not a 1% relative check.
    diff = np.abs(out_np - ref).max()
    assert diff < 0.05, f"max abs diff {diff:.4f}"


def test_load_waveform_resamples_to_16k():
    """A 1 s clip written at 22.05 kHz must load as 16000 float32 samples."""
    import soundfile as sf

    src_sr = 22050
    # Only reserve a unique path here; close the handle before writing so
    # soundfile can reopen the file (an open NamedTemporaryFile cannot be
    # reopened by name on Windows).
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        path = f.name
    try:
        sf.write(path, _fixed_signal(seconds=1.0, sr=src_sr), src_sr)
        wav_mx = load_waveform(path)
    finally:
        # Always remove the temp file, even if writing or loading fails.
        os.unlink(path)

    assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
    assert wav_mx.dtype == mx.float32