Files
pyannote-speaker-diarizatio…/tests/unit/test_diar_clustering.py
transcrilive 2b1a3c1312 feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with the pyannote PyTorch reference (cosine similarity 0.763718,
identical to 6 decimal places, on 200 cross-window slot pairs). 2.5x faster
than pyannote on the MPS backend, running natively on Apple Silicon.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
2026-05-09 16:05:39 +02:00

22 lines
820 B
Python

import numpy as np
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
def test_two_well_separated_clusters():
    """Two tight, mutually orthogonal groups must be split into exactly two clusters.

    Group ``a`` is sampled around the first basis vector and group ``b`` around
    the second, each with per-dimension noise ``scale=0.01``. The within-group
    spread is therefore tiny compared with the between-group distance, so
    clustering with ``num_speakers=2`` is expected to recover the split exactly.
    """
    rng = np.random.default_rng(42)
    dim = 256
    per_group = 10

    # Explicit basis-vector centres (e_0 and e_1); equivalent to the former
    # list literals, so the sampled values are unchanged.
    centre_a = np.zeros(dim)
    centre_a[0] = 1.0
    centre_b = np.zeros(dim)
    centre_b[1] = 1.0

    a = rng.normal(loc=centre_a, scale=0.01, size=(per_group, dim))
    b = rng.normal(loc=centre_b, scale=0.01, size=(per_group, dim))
    emb = np.vstack([a, b]).astype(np.float32)

    labels = cluster_embeddings(emb, num_speakers=2)

    # The slicing below assumes one label per input row; check that explicitly
    # so a shape regression fails with a clear message instead of a vacuous pass.
    assert len(labels) == len(emb)
    assert len(set(labels[:per_group])) == 1
    assert len(set(labels[per_group:])) == 1
    assert labels[0] != labels[per_group]


def test_threshold_based():
    """With ``num_speakers=None`` the number of clusters must respect the bounds.

    The embeddings are i.i.d. standard-normal noise, so there is no single
    "correct" partition. This test (presumably exercising the threshold-driven
    path, per its name) only checks that one label is produced per embedding
    and that the count of distinct labels lies within
    ``[min_speakers, max_speakers]``.
    """
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(30, 256)).astype(np.float32)
    min_speakers, max_speakers = 1, 10

    labels = cluster_embeddings(
        emb,
        num_speakers=None,
        min_speakers=min_speakers,
        max_speakers=max_speakers,
    )

    assert len(labels) == len(emb)
    # Tie the assertion to the arguments actually passed, not duplicated literals.
    assert min_speakers <= len(set(labels)) <= max_speakers