Byte-parity with the pyannote PyTorch reference (cosine 0.763718, identical to 6 decimal places across 200 cross-window slot pairs). 2.5x faster than pyannote on MPS, running natively on Apple Silicon. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
22 lines
820 B
Python
22 lines
820 B
Python
import numpy as np
|
|
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
|
|
|
|
|
|
def test_two_well_separated_clusters():
    """Two tight, orthogonal groups of embeddings must get two distinct labels.

    Ten embeddings are drawn around the first basis vector and ten around the
    second, with very small noise (std 0.01), then clustered with
    ``num_speakers=2``. Each group must map to a single label, and the two
    groups must map to different labels.
    """
    rng = np.random.default_rng(42)
    dim = 256
    # Derive both centres from ``dim`` so the centre length and the sample
    # width cannot drift apart (the old literal ``[0.0]*253`` silently
    # depended on ``dim == 256``). The sampled values are unchanged.
    centre_a, centre_b = np.eye(dim)[:2]
    a = rng.normal(loc=centre_a, scale=0.01, size=(10, dim))
    b = rng.normal(loc=centre_b, scale=0.01, size=(10, dim))
    emb = np.vstack([a, b]).astype(np.float32)

    labels = cluster_embeddings(emb, num_speakers=2)

    # Exactly one label per input embedding.
    assert len(labels) == len(emb)
    # Each group is internally consistent...
    assert len(set(labels[:10])) == 1
    assert len(set(labels[10:])) == 1
    # ...and the two groups are told apart.
    assert labels[0] != labels[10]
|
|
|
|
|
|
def test_threshold_based():
    """With no fixed speaker count, the cluster count must respect the bounds.

    Thirty unstructured Gaussian embeddings are clustered with
    ``num_speakers=None``, so ``cluster_embeddings`` chooses the number of
    clusters itself. The test name suggests a threshold-based criterion; that
    is not verified here. The test checks two things: one label per
    embedding, and a distinct-label count within
    ``[min_speakers, max_speakers]``.
    """
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(30, 256)).astype(np.float32)

    # Name the bounds once so the call and the assertion cannot drift apart.
    min_speakers, max_speakers = 1, 10
    labels = cluster_embeddings(emb, num_speakers=None,
                                min_speakers=min_speakers,
                                max_speakers=max_speakers)

    # Without this, a result with too few labels could still satisfy the
    # cluster-count bounds below.
    assert len(labels) == len(emb)
    assert min_speakers <= len(set(labels)) <= max_speakers
|