feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
21
tests/unit/test_diar_clustering.py
Normal file
21
tests/unit/test_diar_clustering.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import numpy as np
|
||||
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
|
||||
|
||||
|
||||
def test_two_well_separated_clusters():
    """Check that two tight, well-separated blobs get two distinct labels.

    Ten points are drawn around the first unit axis and ten around the
    second (standard deviation 0.01, 256 dimensions). With
    ``num_speakers=2`` every point of a blob must share one label and the
    two blobs must not share a label.
    """
    dim = 256
    per_blob = 10
    rng = np.random.default_rng(42)

    center_a = [1.0, 0.0, 0.0] + [0.0] * 253
    center_b = [0.0, 1.0, 0.0] + [0.0] * 253

    # Draw the first blob before the second so the seeded samples stay
    # exactly the same as before.
    blob_a = rng.normal(loc=center_a, scale=0.01, size=(per_blob, dim))
    blob_b = rng.normal(loc=center_b, scale=0.01, size=(per_blob, dim))
    emb = np.vstack([blob_a, blob_b]).astype(np.float32)

    labels = cluster_embeddings(emb, num_speakers=2)

    first_blob_labels = labels[:per_blob]
    second_blob_labels = labels[per_blob:]
    assert len(set(first_blob_labels)) == 1
    assert len(set(second_blob_labels)) == 1
    assert labels[0] != labels[per_blob]
|
||||
|
||||
|
||||
def test_threshold_based():
    """Check the cluster count stays in bounds when no speaker count is given.

    Thirty random 256-dimensional embeddings are clustered with
    ``num_speakers=None`` and ``min_speakers=1`` / ``max_speakers=10``; the
    number of distinct labels returned must lie within those bounds.
    """
    lower, upper = 1, 10
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(30, 256)).astype(np.float32)

    labels = cluster_embeddings(
        emb,
        num_speakers=None,
        min_speakers=lower,
        max_speakers=upper,
    )

    distinct = len(set(labels))
    assert lower <= distinct <= upper
|
||||
Reference in New Issue
Block a user