Byte-parity with the pyannote PyTorch reference (cosine similarity 0.763718, identical to 6 decimal places across 200 cross-window slot pairs). 2.5x faster than pyannote on the MPS backend, running natively on Apple Silicon. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR at commit 5f9eafa.
13 lines
483 B
Python
import mlx.core as mx
|
|
from pyannote_diarization_3_1_mlx._sincnet import SincNet
|
|
|
|
|
|
def test_sincnet_output_shape_589_frames():
    """For pyannote 3.1, 10s @ 16kHz input → 589 frames out.

    Feeds one 10-second mono clip at 16 kHz (16000 * 10 samples) through
    ``SincNet`` and checks that the output has the shape the upstream
    PyanNet SincNet front-end produces: (batch=1, channels=60, frames=589).
    """
    sample_rate = 16000
    duration_s = 10
    num_samples = sample_rate * duration_s

    net = SincNet(sample_rate=sample_rate)

    # All-zero input is sufficient: only the output shape is under test.
    x = mx.zeros((1, 1, num_samples))  # (B, C, T)

    out = net(x)

    # Expect (1, 60, 589) per upstream PyanNet.SincNet output.
    # Compare the full shape so that a wrong rank or batch size fails too,
    # not only a wrong channel count or frame count.
    # tuple() normalises the shape container before comparing.
    got = tuple(out.shape)
    assert got == (1, 60, 589), f"got shape={got}"