feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1

Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical
at 6 decimals on 200 cross-window slot pairs). 2.5x faster than
pyannote-MPS on Apple Silicon native.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
transcrilive
2026-05-09 16:05:39 +02:00
commit 2b1a3c1312
30 changed files with 2022 additions and 0 deletions

View File

@@ -0,0 +1,54 @@
"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
import time
import numpy as np
import librosa
import pytest
import soundfile as sf
import torch
from pyannote.audio import Pipeline
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
@pytest.mark.integration
def test_diar_60s_parity_vs_pyannote():
audio_path = "/tmp/_diar_smoke_60s.wav"
# use any 60s slice of the existing test audio
sig, _ = librosa.load("/tmp/audio_first_3min.wav", sr=16000, duration=60)
sf.write(audio_path, sig, 16000)
# MLX pipeline
mlx_pipe = MlxDiarizationPipeline.from_pretrained()
mlx_ann = mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
"sample_rate": 16000},
min_speakers=1, max_speakers=3)
mlx_speakers = set(mlx_ann.labels())
# pyannote PyTorch reference
ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
ref_out = ref_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
"sample_rate": 16000},
min_speakers=1, max_speakers=3)
# pyannote 3.x returns the annotation directly
if hasattr(ref_out, "exclusive_speaker_diarization"):
ref_ann = ref_out.exclusive_speaker_diarization
else:
ref_ann = ref_out
ref_speakers = set(ref_ann.labels())
# gate: speaker count within ±1
assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"
# gate: DER < 0.30 (Hungarian-aligned)
from pyannote.metrics.diarization import DiarizationErrorRate
der = DiarizationErrorRate()
der_value = der(ref_ann, mlx_ann)
assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"
# gate: wall-clock under 30s (MLX should be fast on M2/M3)
t0 = time.time()
mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
"sample_rate": 16000})
wall = time.time() - t0
assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"

View File

@@ -0,0 +1,59 @@
import numpy as np
import mlx.core as mx
import torch
from torchaudio.compliance import kaldi as ta_kaldi
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
sig = (
0.5 * np.sin(2 * np.pi * 220 * t)
+ 0.3 * np.sin(2 * np.pi * 880 * t)
).astype(np.float32)
return sig
def test_fbank_matches_torchaudio_within_1pct():
sig = _fixed_signal()
# torchaudio reference: same params as pyannote WeSpeaker
sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
ref = ta_kaldi.fbank(
sig_torch,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
dither=0.0,
window_type="hamming",
use_energy=False,
sample_frequency=16000,
).numpy() # (T, 80)
# Our MLX implementation, with same scaling and CMN
sig_mx = mx.array(sig)
out = kaldi_fbank(
sig_mx,
num_mel_bins=80,
frame_length_ms=25,
frame_shift_ms=10,
dither=0.0,
window_type="hamming",
use_energy=False,
sample_rate=16000,
)
out_np = np.asarray(out)
assert out_np.shape == ref.shape
# max abs diff should be small (kaldi-compliant, no random init)
diff = np.abs(out_np - ref).max()
assert diff < 0.05, f"max abs diff {diff:.4f}"
def test_load_waveform_resamples_to_16k():
import soundfile as sf
import tempfile, os
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
sf.write(f.name, _fixed_signal(seconds=1.0, sr=22050), 22050)
wav_mx = load_waveform(f.name)
os.unlink(f.name)
assert wav_mx.shape[-1] == 16000 # 1 second @ 16k after resample
assert wav_mx.dtype == mx.float32

View File

@@ -0,0 +1,11 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
def test_bilstm_output_shape():
# input (B, T, hidden_in) — pyannote feeds 60-channel sincnet output
# transposed to (B, T, 60). hidden=128, bidirectional → 256 out.
net = BiLSTM4(input_size=60, hidden_size=128)
x = mx.zeros((1, 589, 60))
out = net(x)
assert out.shape == (1, 589, 256), f"got {out.shape}"

View File

@@ -0,0 +1,21 @@
import numpy as np
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
def test_two_well_separated_clusters():
rng = np.random.default_rng(42)
a = rng.normal(loc=[1.0, 0.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
b = rng.normal(loc=[0.0, 1.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
emb = np.vstack([a, b]).astype(np.float32)
labels = cluster_embeddings(emb, num_speakers=2)
assert len(set(labels[:10])) == 1
assert len(set(labels[10:])) == 1
assert labels[0] != labels[10]
def test_threshold_based():
rng = np.random.default_rng(0)
emb = rng.normal(size=(30, 256)).astype(np.float32)
labels = cluster_embeddings(emb, num_speakers=None,
min_speakers=1, max_speakers=10)
assert 1 <= len(set(labels)) <= 10

View File

@@ -0,0 +1,28 @@
from pyannote_diarization_3_1_mlx._config import (
SEG_DURATION, SEG_HOP, SEG_FRAMES, SEG_CLASSES,
MAX_SPEAKERS_PER_CHUNK, MAX_SPEAKERS_PER_FRAME,
EMB_BATCH_SIZE, EMB_EXCLUDE_OVERLAP,
CLUSTER_METHOD, CLUSTER_THRESHOLD, CLUSTER_MIN_SIZE,
SEG_HF_REPO, SEG_HF_REV, EMB_HF_REPO, EMB_HF_REV,
)
def test_pyannote_3_1_locked_hyperparameters():
assert SEG_DURATION == 10.0
assert SEG_HOP == 1.0
assert SEG_FRAMES == 589
assert SEG_CLASSES == 7
assert MAX_SPEAKERS_PER_CHUNK == 3
assert MAX_SPEAKERS_PER_FRAME == 2
assert EMB_BATCH_SIZE == 32
assert EMB_EXCLUDE_OVERLAP is True
assert CLUSTER_METHOD == "centroid"
assert CLUSTER_THRESHOLD == 0.7045654963945799
assert CLUSTER_MIN_SIZE == 12
def test_locked_hf_revisions():
assert SEG_HF_REPO == "mlx-community/pyannote-segmentation-3.0-mlx"
assert SEG_HF_REV == "5189a69b35c5f7e48082a978f3476bac81590874"
assert EMB_HF_REPO == "mlx-community/wespeaker-voxceleb-resnet34-LM"
assert EMB_HF_REV == "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"

View File

@@ -0,0 +1,11 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
from pyannote_diarization_3_1_mlx._config import EMB_DIM
def test_embedding_output_shape():
m = EmbeddingModel()
fb = mx.zeros((2, 200, 80)) # (B, T, mel)
weights = mx.ones((2, 200))
emb = m(fb, weights)
assert emb.shape == (2, EMB_DIM), f"got {emb.shape}"

View File

@@ -0,0 +1,25 @@
"""Smoke test for MlxDiarizationPipeline orchestrator.
Mocks all sub-components so no HF downloads or real inference is needed.
30 s of silence → powerset returns all zeros → no active speaker slots → empty annotation.
"""
import numpy as np
import mlx.core as mx
from unittest.mock import MagicMock
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
def test_pipeline_smoke_on_30s_zeros(mocker):
p = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)
p._segmentation = MagicMock()
p._embedding = MagicMock()
# mock seg → all class 0 (silence) → no slots → empty annotation
p._segmentation.return_value = mx.zeros((1, 589, 7))
p._powerset = MagicMock()
p._powerset.to_multilabel.return_value = mx.zeros((589, 3))
p._embedding.return_value = mx.ones((1, 256))
# 30 s of silence
audio = np.zeros(30 * 16000, dtype=np.float32)
annotation = p({"waveform": mx.array(audio)[None, :], "sample_rate": 16000})
# silence → no turns
assert len(list(annotation.itertracks())) == 0

View File

@@ -0,0 +1,39 @@
import numpy as np
import mlx.core as mx
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
def test_static_mapping_matches_pyannote():
assert POWERSET_3_2_MAPPING.shape == (7, 3)
expected = np.array([
[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 1, 0],
[1, 0, 1],
[0, 1, 1],
], dtype=np.float32)
np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)
def test_to_multilabel_hard_argmax():
p = Powerset()
# frame 0 → class 1 (S1 only), frame 1 → class 4 (S1+S2), frame 2 → class 0
logits = mx.array([
[0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
[5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
])
out = p.to_multilabel(logits)
out_np = np.asarray(out)
np.testing.assert_array_equal(out_np[0], [1, 0, 0])
np.testing.assert_array_equal(out_np[1], [1, 1, 0])
np.testing.assert_array_equal(out_np[2], [0, 0, 0])
def test_to_multilabel_shape():
p = Powerset()
logits = mx.zeros((589, 7))
out = p.to_multilabel(logits)
assert out.shape == (589, 3)

View File

@@ -0,0 +1,9 @@
"""Unit test: load SegmentationModel weights from HF mlx-community repo."""
import pytest
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
@pytest.mark.integration
def test_segmentation_loads_from_hf():
m = SegmentationModel.from_hf()
assert m is not None

View File

@@ -0,0 +1,9 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
def test_segmentation_full_shape():
m = SegmentationModel()
x = mx.zeros((1, 1, 160000)) # 10s @ 16k mono
out = m(x)
assert out.shape == (1, 589, 7), f"got {out.shape}"

View File

@@ -0,0 +1,12 @@
import mlx.core as mx
from pyannote_diarization_3_1_mlx._sincnet import SincNet
def test_sincnet_output_shape_589_frames():
"""For pyannote 3.1, 10s @ 16kHz input → 589 frames out."""
net = SincNet(sample_rate=16000)
x = mx.zeros((1, 1, 16000 * 10)) # (B, C, T)
out = net(x)
# Expect (1, 60, 589) per upstream PyanNet.SincNet output
assert out.shape[-1] == 589, f"got frames={out.shape[-1]}"
assert out.shape[1] == 60, f"got channels={out.shape[1]}"

View File

@@ -0,0 +1,16 @@
from pyannote_diarization_3_1_mlx._window import sliding_windows
import numpy as np
def test_sliding_windows_full_coverage():
sr = 16000
audio = np.zeros(int(25 * sr), dtype=np.float32)
windows = list(sliding_windows(audio, sr=sr, duration_s=10.0, hop_s=1.0))
# Expect (25-10)/1 + 1 = 16 windows, all 10 s long
assert len(windows) == 16
for start, end, slice_ in windows:
assert end - start == 10.0
assert len(slice_) == 10 * sr
# boundaries
assert windows[0][0] == 0.0
assert windows[-1][1] == 25.0