import mlx.core as mx

from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel


def test_segmentation_full_shape():
    """Run SegmentationModel on 10 s of silent audio and check the output shape.

    The input is a zero tensor of shape (1, 1, 160000). The test asserts that
    the model returns an array of shape (1, 589, 7).
    """
    # Expected output dimensions for a 160000-sample input, as pinned by this test.
    # NOTE(review): 589 appears to be the number of output frames and 7 the
    # number of per-frame classes -- confirm against the model definition.
    expected_frames = 589
    expected_classes = 7

    m = SegmentationModel()
    x = mx.zeros((1, 1, 160000))  # 10s @ 16k mono
    out = m(x)
    # MLX is lazy: m(x) only builds the graph, and the output shape is already
    # known without computing anything. Force evaluation so the forward pass
    # actually executes and runtime errors in the model surface here.
    mx.eval(out)
    assert out.shape == (1, expected_frames, expected_classes), f"got {out.shape}"