feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1

Byte-parity with the pyannote PyTorch reference (cosine similarity 0.763718,
identical to 6 decimal places on 200 cross-window slot pairs). 2.5x faster
than pyannote on MPS, running natively on Apple Silicon.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
transcrilive
2026-05-09 16:05:39 +02:00
commit 2b1a3c1312
30 changed files with 2022 additions and 0 deletions

View File

@@ -0,0 +1,59 @@
import numpy as np
import mlx.core as mx
import torch
from torchaudio.compliance import kaldi as ta_kaldi
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
    """Return a deterministic two-tone test signal as a float32 array.

    The signal is the sum of a 220 Hz sine with amplitude 0.5 and an
    880 Hz sine with amplitude 0.3, sampled at ``sr`` Hz for ``seconds``
    seconds (``int(seconds * sr)`` samples, end point excluded).
    """
    n_samples = int(seconds * sr)
    t = np.linspace(0, seconds, n_samples, endpoint=False)
    low_tone = 0.5 * np.sin(2 * np.pi * 220 * t)
    high_tone = 0.3 * np.sin(2 * np.pi * 880 * t)
    mixture = low_tone + high_tone
    return mixture.astype(np.float32)
def test_fbank_matches_torchaudio_within_1pct():
    """Compare the MLX ``kaldi_fbank`` output against torchaudio's Kaldi fbank.

    Both implementations are run on the same deterministic two-tone signal
    with 80 mel bins, 25 ms frames, 10 ms shift, a Hamming window, no dither
    and no energy feature. The outputs must have the same shape and a
    maximum absolute difference below 0.05.
    """
    signal = _fixed_signal()

    # Reference features from torchaudio. The float signal is scaled by
    # 2**15 before being handed to torchaudio.
    reference_input = torch.from_numpy(signal).unsqueeze(0)
    reference_input = reference_input * (1 << 15)
    ref = ta_kaldi.fbank(
        reference_input,
        sample_frequency=16000,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        window_type="hamming",
        dither=0.0,
        use_energy=False,
    ).numpy()  # (T, 80)

    # MLX features from the unscaled signal.
    # NOTE(review): presumably kaldi_fbank applies the equivalent 2**15
    # scaling (and any normalisation) internally — confirm against its
    # implementation; this test passes it the raw float signal.
    out = kaldi_fbank(
        mx.array(signal),
        sample_rate=16000,
        num_mel_bins=80,
        frame_length_ms=25,
        frame_shift_ms=10,
        window_type="hamming",
        dither=0.0,
        use_energy=False,
    )
    out_np = np.asarray(out)

    assert out_np.shape == ref.shape

    # The tolerance is an absolute bound on the largest element-wise error.
    diff = np.max(np.abs(out_np - ref))
    assert diff < 0.05, f"max abs diff {diff:.4f}"
def test_load_waveform_resamples_to_16k():
    """Check that ``load_waveform`` returns 16 kHz float32 audio.

    A 1-second two-tone signal is written as a 22.05 kHz WAV file and
    loaded back; the loaded waveform must contain 16000 samples along its
    last axis and have dtype ``mx.float32``.
    """
    import os
    import tempfile

    import soundfile as sf

    # A temporary directory is removed on exit from the ``with`` block even
    # when writing or loading raises, so a failing run leaves no stray file
    # behind. Writing by path also avoids holding an open handle on the file
    # while soundfile writes to it.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "tone_22050.wav")
        sf.write(wav_path, _fixed_signal(seconds=1.0, sr=22050), 22050)
        wav_mx = load_waveform(wav_path)

    assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
    assert wav_mx.dtype == mx.float32

View File

@@ -0,0 +1,11 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
def test_bilstm_output_shape():
    """BiLSTM4 must map a (1, 589, 60) input to a (1, 589, 256) output."""
    # Input layout is (B, T, features): 60 features per frame, matching the
    # SincNet channel count after transposition. With hidden_size=128 and
    # two directions the output carries 2 * 128 = 256 features per frame.
    batch, frames, features, hidden = 1, 589, 60, 128
    net = BiLSTM4(input_size=features, hidden_size=hidden)
    out = net(mx.zeros((batch, frames, features)))
    assert out.shape == (batch, frames, 2 * hidden), f"got {out.shape}"

View File

@@ -0,0 +1,21 @@
import numpy as np
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
def test_two_well_separated_clusters():
    """Two tight, distant groups must end up in two distinct clusters.

    Ten 256-dimensional points are drawn around the first unit vector and
    ten around the second (standard deviation 0.01). With
    ``num_speakers=2`` each group must receive a single label and the two
    labels must differ.
    """
    dim, per_group = 256, 10
    rng = np.random.default_rng(42)
    unit_vectors = np.eye(dim)

    # Draw order matters for reproducibility: first group, then second.
    first_group = rng.normal(
        loc=unit_vectors[0], scale=0.01, size=(per_group, dim)
    )
    second_group = rng.normal(
        loc=unit_vectors[1], scale=0.01, size=(per_group, dim)
    )
    emb = np.vstack([first_group, second_group]).astype(np.float32)

    labels = cluster_embeddings(emb, num_speakers=2)

    first_labels = set(labels[:per_group])
    second_labels = set(labels[per_group:])
    assert len(first_labels) == 1
    assert len(second_labels) == 1
    assert labels[0] != labels[per_group]
def test_threshold_based():
    """Without a fixed speaker count, the cluster count stays within bounds.

    Thirty random 256-dimensional embeddings are clustered with
    ``num_speakers=None`` and limits of 1 to 10 speakers; the number of
    distinct labels must fall inside that range.
    """
    lower, upper = 1, 10
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(30, 256)).astype(np.float32)
    labels = cluster_embeddings(
        emb,
        num_speakers=None,
        min_speakers=lower,
        max_speakers=upper,
    )
    n_clusters = len(set(labels))
    assert lower <= n_clusters <= upper

View File

@@ -0,0 +1,28 @@
from pyannote_diarization_3_1_mlx._config import (
SEG_DURATION, SEG_HOP, SEG_FRAMES, SEG_CLASSES,
MAX_SPEAKERS_PER_CHUNK, MAX_SPEAKERS_PER_FRAME,
EMB_BATCH_SIZE, EMB_EXCLUDE_OVERLAP,
CLUSTER_METHOD, CLUSTER_THRESHOLD, CLUSTER_MIN_SIZE,
SEG_HF_REPO, SEG_HF_REV, EMB_HF_REPO, EMB_HF_REV,
)
def test_pyannote_3_1_locked_hyperparameters():
    """Pin the configuration constants to their expected values.

    Any change to one of these constants in ``_config`` makes this test
    fail, so a change has to be made deliberately in both places.
    """
    # (actual, expected) pairs compared with ``==``.
    locked_values = [
        (SEG_DURATION, 10.0),
        (SEG_HOP, 1.0),
        (SEG_FRAMES, 589),
        (SEG_CLASSES, 7),
        (MAX_SPEAKERS_PER_CHUNK, 3),
        (MAX_SPEAKERS_PER_FRAME, 2),
        (EMB_BATCH_SIZE, 32),
        (CLUSTER_METHOD, "centroid"),
        (CLUSTER_THRESHOLD, 0.7045654963945799),
        (CLUSTER_MIN_SIZE, 12),
    ]
    for actual, expected in locked_values:
        assert actual == expected

    # Checked by identity: the flag must be the bool True, not merely truthy.
    assert EMB_EXCLUDE_OVERLAP is True
def test_locked_hf_revisions():
    """Pin the Hugging Face repository names and revision hashes."""
    segmentation_source = (SEG_HF_REPO, SEG_HF_REV)
    embedding_source = (EMB_HF_REPO, EMB_HF_REV)

    expected_segmentation = (
        "mlx-community/pyannote-segmentation-3.0-mlx",
        "5189a69b35c5f7e48082a978f3476bac81590874",
    )
    expected_embedding = (
        "mlx-community/wespeaker-voxceleb-resnet34-LM",
        "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31",
    )

    for actual, expected in zip(segmentation_source, expected_segmentation):
        assert actual == expected
    for actual, expected in zip(embedding_source, expected_embedding):
        assert actual == expected

View File

@@ -0,0 +1,11 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
from pyannote_diarization_3_1_mlx._config import EMB_DIM
def test_embedding_output_shape():
    """EmbeddingModel must return one EMB_DIM-sized vector per batch item."""
    batch, frames, mel_bins = 2, 200, 80
    model = EmbeddingModel()
    features = mx.zeros((batch, frames, mel_bins))  # (B, T, mel)
    frame_weights = mx.ones((batch, frames))  # one weight per frame
    emb = model(features, frame_weights)
    assert emb.shape == (batch, EMB_DIM), f"got {emb.shape}"

View File

@@ -0,0 +1,25 @@
"""Smoke test for MlxDiarizationPipeline orchestrator.
Mocks all sub-components so no HF downloads or real inference is needed.
30 s of silence → powerset returns all zeros → no active speaker slots → empty annotation.
"""
import numpy as np
import mlx.core as mx
from unittest.mock import MagicMock
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
def test_pipeline_smoke_on_30s_zeros(mocker):
    """Run the pipeline on 30 s of zeros with every sub-model mocked.

    The instance is created with ``__new__`` so ``__init__`` is not run,
    and the segmentation, powerset and embedding components are replaced
    by mocks returning fixed arrays. The resulting annotation must contain
    no tracks.

    NOTE(review): the ``mocker`` fixture is requested but not used in the
    body; only ``MagicMock`` is used directly.
    """
    pipeline = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)

    # Segmentation mock: all-zero scores over 589 frames and 7 classes.
    pipeline._segmentation = MagicMock(return_value=mx.zeros((1, 589, 7)))

    # Powerset mock: no speaker active in any of the 589 frames.
    pipeline._powerset = MagicMock()
    pipeline._powerset.to_multilabel.return_value = mx.zeros((589, 3))

    # Embedding mock: a constant 256-dimensional vector.
    pipeline._embedding = MagicMock(return_value=mx.ones((1, 256)))

    # 30 s of silence at 16 kHz, shaped (1, num_samples).
    silence = np.zeros(30 * 16000, dtype=np.float32)
    file = {"waveform": mx.array(silence)[None, :], "sample_rate": 16000}
    annotation = pipeline(file)

    # Silence must produce no speaker turns.
    tracks = list(annotation.itertracks())
    assert len(tracks) == 0

View File

@@ -0,0 +1,39 @@
import numpy as np
import mlx.core as mx
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
def test_static_mapping_matches_pyannote():
    """POWERSET_3_2_MAPPING must be the expected 7 x 3 class-to-speaker table.

    Row order: the empty class, the three single-speaker classes, then the
    three two-speaker classes.
    """
    assert POWERSET_3_2_MAPPING.shape == (7, 3)

    no_speaker = [[0, 0, 0]]
    single_speaker = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
    speaker_pairs = [[1, 1, 0], [1, 0, 1], [0, 1, 1]]
    expected = np.array(
        no_speaker + single_speaker + speaker_pairs,
        dtype=np.float32,
    )
    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)
def test_to_multilabel_hard_argmax():
    """Each frame must decode to the speaker set of its highest-scoring class.

    Three frames are built with a single dominant class each: class 1
    (first speaker only), class 4 (first and second speaker) and class 0
    (no speaker).
    """
    num_classes = 7
    # (dominant class, expected multilabel row) for each frame, in order.
    cases = [
        (1, [1, 0, 0]),
        (4, [1, 1, 0]),
        (0, [0, 0, 0]),
    ]
    rows = [
        [5.0 if cls == winner else 0.0 for cls in range(num_classes)]
        for winner, _ in cases
    ]

    decoded = np.asarray(Powerset().to_multilabel(mx.array(rows)))

    for frame_index, (_, expected_row) in enumerate(cases):
        np.testing.assert_array_equal(decoded[frame_index], expected_row)
def test_to_multilabel_shape():
    """A (589, 7) score matrix must decode to a (589, 3) multilabel matrix."""
    frames = 589
    scores = mx.zeros((frames, 7))
    multilabel = Powerset().to_multilabel(scores)
    assert multilabel.shape == (frames, 3)

View File

@@ -0,0 +1,9 @@
"""Integration test: load SegmentationModel weights from the HF mlx-community repo."""
import pytest
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
@pytest.mark.integration
def test_segmentation_loads_from_hf():
    """``SegmentationModel.from_hf()`` must return a model object.

    Marked as an integration test so it can be selected or skipped
    separately from the rest of the suite.
    """
    model = SegmentationModel.from_hf()
    assert model is not None

View File

@@ -0,0 +1,9 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
def test_segmentation_full_shape():
    """A 10 s mono chunk at 16 kHz must yield a (1, 589, 7) output."""
    num_samples = 160000  # 10 s @ 16 kHz
    model = SegmentationModel()
    chunk = mx.zeros((1, 1, num_samples))
    out = model(chunk)
    assert out.shape == (1, 589, 7), f"got {out.shape}"

View File

@@ -0,0 +1,12 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx._sincnet import SincNet
def test_sincnet_output_shape_589_frames():
    """SincNet must turn 10 s of 16 kHz audio into 60 channels x 589 frames."""
    sample_rate = 16000
    net = SincNet(sample_rate=sample_rate)
    silence = mx.zeros((1, 1, sample_rate * 10))  # (B, C, T)
    out = net(silence)
    # Expected layout is (1, 60, 589): channels on axis 1, frames on the
    # last axis. The frame count is checked first, then the channel count.
    assert out.shape[-1] == 589, f"got frames={out.shape[-1]}"
    assert out.shape[1] == 60, f"got channels={out.shape[1]}"

View File

@@ -0,0 +1,16 @@
from pyannote_diarization_3_1_mlx._window import sliding_windows
import numpy as np
def test_sliding_windows_full_coverage():
    """10 s windows with a 1 s hop must tile a 25 s signal end to end.

    Expects (25 - 10) / 1 + 1 = 16 windows, each spanning exactly 10 s and
    holding 10 * sr samples, with the first window starting at 0.0 s and
    the last one ending at 25.0 s.
    """
    sr = 16000
    total_seconds = 25
    audio = np.zeros(total_seconds * sr, dtype=np.float32)

    windows = list(
        sliding_windows(audio, sr=sr, duration_s=10.0, hop_s=1.0)
    )
    assert len(windows) == 16

    # Each item unpacks as (start, end, samples).
    for window_start, window_end, samples in windows:
        assert window_end - window_start == 10.0
        assert len(samples) == 10 * sr

    # Boundaries: coverage starts at the beginning and reaches the end.
    first_window, last_window = windows[0], windows[-1]
    assert first_window[0] == 0.0
    assert last_window[1] == 25.0