feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
59
tests/unit/test_diar_audio_fbank.py
Normal file
59
tests/unit/test_diar_audio_fbank.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
import torch
|
||||
from torchaudio.compliance import kaldi as ta_kaldi
|
||||
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
|
||||
|
||||
|
||||
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
|
||||
t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
|
||||
sig = (
|
||||
0.5 * np.sin(2 * np.pi * 220 * t)
|
||||
+ 0.3 * np.sin(2 * np.pi * 880 * t)
|
||||
).astype(np.float32)
|
||||
return sig
|
||||
|
||||
|
||||
def test_fbank_matches_torchaudio_within_1pct():
|
||||
sig = _fixed_signal()
|
||||
# torchaudio reference: same params as pyannote WeSpeaker
|
||||
sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
|
||||
ref = ta_kaldi.fbank(
|
||||
sig_torch,
|
||||
num_mel_bins=80,
|
||||
frame_length=25,
|
||||
frame_shift=10,
|
||||
dither=0.0,
|
||||
window_type="hamming",
|
||||
use_energy=False,
|
||||
sample_frequency=16000,
|
||||
).numpy() # (T, 80)
|
||||
|
||||
# Our MLX implementation, with same scaling and CMN
|
||||
sig_mx = mx.array(sig)
|
||||
out = kaldi_fbank(
|
||||
sig_mx,
|
||||
num_mel_bins=80,
|
||||
frame_length_ms=25,
|
||||
frame_shift_ms=10,
|
||||
dither=0.0,
|
||||
window_type="hamming",
|
||||
use_energy=False,
|
||||
sample_rate=16000,
|
||||
)
|
||||
out_np = np.asarray(out)
|
||||
assert out_np.shape == ref.shape
|
||||
# max abs diff should be small (kaldi-compliant, no random init)
|
||||
diff = np.abs(out_np - ref).max()
|
||||
assert diff < 0.05, f"max abs diff {diff:.4f}"
|
||||
|
||||
|
||||
def test_load_waveform_resamples_to_16k():
|
||||
import soundfile as sf
|
||||
import tempfile, os
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
sf.write(f.name, _fixed_signal(seconds=1.0, sr=22050), 22050)
|
||||
wav_mx = load_waveform(f.name)
|
||||
os.unlink(f.name)
|
||||
assert wav_mx.shape[-1] == 16000 # 1 second @ 16k after resample
|
||||
assert wav_mx.dtype == mx.float32
|
||||
11
tests/unit/test_diar_bilstm.py
Normal file
11
tests/unit/test_diar_bilstm.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
|
||||
|
||||
|
||||
def test_bilstm_output_shape():
|
||||
# input (B, T, hidden_in) — pyannote feeds 60-channel sincnet output
|
||||
# transposed to (B, T, 60). hidden=128, bidirectional → 256 out.
|
||||
net = BiLSTM4(input_size=60, hidden_size=128)
|
||||
x = mx.zeros((1, 589, 60))
|
||||
out = net(x)
|
||||
assert out.shape == (1, 589, 256), f"got {out.shape}"
|
||||
21
tests/unit/test_diar_clustering.py
Normal file
21
tests/unit/test_diar_clustering.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import numpy as np
|
||||
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
|
||||
|
||||
|
||||
def test_two_well_separated_clusters():
|
||||
rng = np.random.default_rng(42)
|
||||
a = rng.normal(loc=[1.0, 0.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
|
||||
b = rng.normal(loc=[0.0, 1.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
|
||||
emb = np.vstack([a, b]).astype(np.float32)
|
||||
labels = cluster_embeddings(emb, num_speakers=2)
|
||||
assert len(set(labels[:10])) == 1
|
||||
assert len(set(labels[10:])) == 1
|
||||
assert labels[0] != labels[10]
|
||||
|
||||
|
||||
def test_threshold_based():
|
||||
rng = np.random.default_rng(0)
|
||||
emb = rng.normal(size=(30, 256)).astype(np.float32)
|
||||
labels = cluster_embeddings(emb, num_speakers=None,
|
||||
min_speakers=1, max_speakers=10)
|
||||
assert 1 <= len(set(labels)) <= 10
|
||||
28
tests/unit/test_diar_config.py
Normal file
28
tests/unit/test_diar_config.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pyannote_diarization_3_1_mlx._config import (
|
||||
SEG_DURATION, SEG_HOP, SEG_FRAMES, SEG_CLASSES,
|
||||
MAX_SPEAKERS_PER_CHUNK, MAX_SPEAKERS_PER_FRAME,
|
||||
EMB_BATCH_SIZE, EMB_EXCLUDE_OVERLAP,
|
||||
CLUSTER_METHOD, CLUSTER_THRESHOLD, CLUSTER_MIN_SIZE,
|
||||
SEG_HF_REPO, SEG_HF_REV, EMB_HF_REPO, EMB_HF_REV,
|
||||
)
|
||||
|
||||
|
||||
def test_pyannote_3_1_locked_hyperparameters():
|
||||
assert SEG_DURATION == 10.0
|
||||
assert SEG_HOP == 1.0
|
||||
assert SEG_FRAMES == 589
|
||||
assert SEG_CLASSES == 7
|
||||
assert MAX_SPEAKERS_PER_CHUNK == 3
|
||||
assert MAX_SPEAKERS_PER_FRAME == 2
|
||||
assert EMB_BATCH_SIZE == 32
|
||||
assert EMB_EXCLUDE_OVERLAP is True
|
||||
assert CLUSTER_METHOD == "centroid"
|
||||
assert CLUSTER_THRESHOLD == 0.7045654963945799
|
||||
assert CLUSTER_MIN_SIZE == 12
|
||||
|
||||
|
||||
def test_locked_hf_revisions():
|
||||
assert SEG_HF_REPO == "mlx-community/pyannote-segmentation-3.0-mlx"
|
||||
assert SEG_HF_REV == "5189a69b35c5f7e48082a978f3476bac81590874"
|
||||
assert EMB_HF_REPO == "mlx-community/wespeaker-voxceleb-resnet34-LM"
|
||||
assert EMB_HF_REV == "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"
|
||||
11
tests/unit/test_diar_embedding_shape.py
Normal file
11
tests/unit/test_diar_embedding_shape.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
|
||||
from pyannote_diarization_3_1_mlx._config import EMB_DIM
|
||||
|
||||
|
||||
def test_embedding_output_shape():
|
||||
m = EmbeddingModel()
|
||||
fb = mx.zeros((2, 200, 80)) # (B, T, mel)
|
||||
weights = mx.ones((2, 200))
|
||||
emb = m(fb, weights)
|
||||
assert emb.shape == (2, EMB_DIM), f"got {emb.shape}"
|
||||
25
tests/unit/test_diar_pipeline_smoke.py
Normal file
25
tests/unit/test_diar_pipeline_smoke.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Smoke test for MlxDiarizationPipeline orchestrator.
|
||||
|
||||
Mocks all sub-components so no HF downloads or real inference is needed.
|
||||
30 s of silence → powerset returns all zeros → no active speaker slots → empty annotation.
|
||||
"""
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
from unittest.mock import MagicMock
|
||||
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
|
||||
|
||||
|
||||
def test_pipeline_smoke_on_30s_zeros(mocker):
|
||||
p = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)
|
||||
p._segmentation = MagicMock()
|
||||
p._embedding = MagicMock()
|
||||
# mock seg → all class 0 (silence) → no slots → empty annotation
|
||||
p._segmentation.return_value = mx.zeros((1, 589, 7))
|
||||
p._powerset = MagicMock()
|
||||
p._powerset.to_multilabel.return_value = mx.zeros((589, 3))
|
||||
p._embedding.return_value = mx.ones((1, 256))
|
||||
# 30 s of silence
|
||||
audio = np.zeros(30 * 16000, dtype=np.float32)
|
||||
annotation = p({"waveform": mx.array(audio)[None, :], "sample_rate": 16000})
|
||||
# silence → no turns
|
||||
assert len(list(annotation.itertracks())) == 0
|
||||
39
tests/unit/test_diar_powerset.py
Normal file
39
tests/unit/test_diar_powerset.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
|
||||
|
||||
|
||||
def test_static_mapping_matches_pyannote():
|
||||
assert POWERSET_3_2_MAPPING.shape == (7, 3)
|
||||
expected = np.array([
|
||||
[0, 0, 0],
|
||||
[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 0, 1],
|
||||
[0, 1, 1],
|
||||
], dtype=np.float32)
|
||||
np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)
|
||||
|
||||
|
||||
def test_to_multilabel_hard_argmax():
|
||||
p = Powerset()
|
||||
# frame 0 → class 1 (S1 only), frame 1 → class 4 (S1+S2), frame 2 → class 0
|
||||
logits = mx.array([
|
||||
[0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
|
||||
[5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
])
|
||||
out = p.to_multilabel(logits)
|
||||
out_np = np.asarray(out)
|
||||
np.testing.assert_array_equal(out_np[0], [1, 0, 0])
|
||||
np.testing.assert_array_equal(out_np[1], [1, 1, 0])
|
||||
np.testing.assert_array_equal(out_np[2], [0, 0, 0])
|
||||
|
||||
|
||||
def test_to_multilabel_shape():
|
||||
p = Powerset()
|
||||
logits = mx.zeros((589, 7))
|
||||
out = p.to_multilabel(logits)
|
||||
assert out.shape == (589, 3)
|
||||
9
tests/unit/test_diar_segmentation_load.py
Normal file
9
tests/unit/test_diar_segmentation_load.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Unit test: load SegmentationModel weights from HF mlx-community repo."""
|
||||
import pytest
|
||||
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_segmentation_loads_from_hf():
|
||||
m = SegmentationModel.from_hf()
|
||||
assert m is not None
|
||||
9
tests/unit/test_diar_segmentation_shape.py
Normal file
9
tests/unit/test_diar_segmentation_shape.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
|
||||
|
||||
|
||||
def test_segmentation_full_shape():
|
||||
m = SegmentationModel()
|
||||
x = mx.zeros((1, 1, 160000)) # 10s @ 16k mono
|
||||
out = m(x)
|
||||
assert out.shape == (1, 589, 7), f"got {out.shape}"
|
||||
12
tests/unit/test_diar_sincnet.py
Normal file
12
tests/unit/test_diar_sincnet.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx._sincnet import SincNet
|
||||
|
||||
|
||||
def test_sincnet_output_shape_589_frames():
|
||||
"""For pyannote 3.1, 10s @ 16kHz input → 589 frames out."""
|
||||
net = SincNet(sample_rate=16000)
|
||||
x = mx.zeros((1, 1, 16000 * 10)) # (B, C, T)
|
||||
out = net(x)
|
||||
# Expect (1, 60, 589) per upstream PyanNet.SincNet output
|
||||
assert out.shape[-1] == 589, f"got frames={out.shape[-1]}"
|
||||
assert out.shape[1] == 60, f"got channels={out.shape[1]}"
|
||||
16
tests/unit/test_diar_window.py
Normal file
16
tests/unit/test_diar_window.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from pyannote_diarization_3_1_mlx._window import sliding_windows
|
||||
import numpy as np
|
||||
|
||||
|
||||
def test_sliding_windows_full_coverage():
|
||||
sr = 16000
|
||||
audio = np.zeros(int(25 * sr), dtype=np.float32)
|
||||
windows = list(sliding_windows(audio, sr=sr, duration_s=10.0, hop_s=1.0))
|
||||
# Expect (25-10)/1 + 1 = 16 windows, all 10 s long
|
||||
assert len(windows) == 16
|
||||
for start, end, slice_ in windows:
|
||||
assert end - start == 10.0
|
||||
assert len(slice_) == 10 * sr
|
||||
# boundaries
|
||||
assert windows[0][0] == 0.0
|
||||
assert windows[-1][1] == 25.0
|
||||
Reference in New Issue
Block a user