Byte-parity with the pyannote PyTorch reference: cosine similarity 0.763718, identical to 6 decimal places, on 200 cross-window slot pairs. 2.5x faster than pyannote on MPS, running natively on Apple Silicon. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR at commit 5f9eafa.
12 lines
368 B
Python
import mlx.core as mx
|
|
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
|
|
from pyannote_diarization_3_1_mlx._config import EMB_DIM
|
|
|
|
|
|
def test_embedding_output_shape():
    """EmbeddingModel maps a (B, T, mel) feature batch to a (B, EMB_DIM) output.

    MLX evaluates lazily: ``emb.shape`` is inferred when the compute graph is
    built, so checking the shape alone never runs the forward pass. ``mx.eval``
    forces the computation, so runtime/kernel errors surface in this test. The
    finiteness check then catches NaN/inf in the produced embeddings.
    """
    batch, frames, n_mels = 2, 200, 80  # (B, T, mel) sizes used by this test
    m = EmbeddingModel()

    # Seeded random features rather than all zeros: a constant input is a
    # degenerate forward pass (e.g. zero-variance statistics if the model pools
    # them), which is not representative. An explicit key keeps the global RNG
    # state untouched and the test deterministic.
    fb = mx.random.normal((batch, frames, n_mels), key=mx.random.key(0))  # (B, T, mel)
    # Uniform weights of shape (B, T); presumably per-frame weights — TODO confirm.
    weights = mx.ones((batch, frames))

    emb = m(fb, weights)
    mx.eval(emb)  # force the lazy graph to actually execute

    assert emb.shape == (batch, EMB_DIM), f"got {emb.shape}"
    assert mx.all(mx.isfinite(emb)).item(), "embedding contains NaN or inf"
|