import mlx.core as mx

from pyannote_diarization_3_1_mlx._sincnet import SincNet


def test_sincnet_output_shape_589_frames():
    """Check that SincNet maps a 10 s, 16 kHz mono waveform to (1, 60, 589).

    For pyannote 3.1, a 10 s chunk at 16 kHz is expected to produce 589
    output frames with 60 channels, matching upstream PyanNet's SincNet.
    """
    sample_rate = 16000
    duration_s = 10

    net = SincNet(sample_rate=sample_rate)
    x = mx.zeros((1, 1, sample_rate * duration_s))  # (B, C, T)
    out = net(x)

    # NOTE(review): 589 assumes the upstream pyannote.audio SincNet stack
    # (sinc conv k=251 s=10 -> maxpool 3 -> conv k=5 -> maxpool 3 -> conv k=5
    # -> maxpool 3): 160000 -> 15975 -> 5325 -> 5321 -> 1773 -> 1769 -> 589.
    # Confirm against _sincnet.py if the layer config changes.
    expected = (1, 60, 589)

    # Compare the whole shape: checking only [1] and [-1] would miss a wrong
    # batch dimension or an output of the wrong rank. tuple() normalizes
    # .shape in case it is returned as a list.
    actual = tuple(out.shape)
    assert actual == expected, f"expected shape {expected}, got {actual}"