import mlx.core as mx
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
from pyannote_diarization_3_1_mlx._config import EMB_DIM


def test_embedding_output_shape():
    """Check that ``EmbeddingModel`` maps a batch of features to ``(B, EMB_DIM)``.

    The model is called with an all-zeros feature tensor laid out as
    ``(B, T, mel)`` and an all-ones ``weights`` tensor of shape ``(B, T)``.
    NOTE(review): ``weights`` is presumably a per-frame weighting/mask used
    by the model's pooling; confirm against ``EmbeddingModel.__call__``.
    """
    batch, frames, n_mels = 2, 200, 80
    m = EmbeddingModel()
    fb = mx.zeros((batch, frames, n_mels))  # (B, T, mel)
    weights = mx.ones((batch, frames))
    emb = m(fb, weights)
    # MLX is lazy: `.shape` is known from graph construction alone, so force
    # evaluation to make sure the forward pass actually executes.
    mx.eval(emb)
    assert emb.shape == (batch, EMB_DIM), f"got {emb.shape}"