import mlx.core as mx

from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4


def test_bilstm_output_shape():
    """Check that BiLSTM4 maps a (B, T, 60) input to a (B, T, 256) output.

    Pyannote feeds the 60-channel SincNet output transposed to (B, T, 60).
    With hidden_size=128 and both directions concatenated, the expected
    feature width of the output is 2 * 128 = 256.
    """
    batch, frames, in_features, hidden = 1, 589, 60, 128

    model = BiLSTM4(input_size=in_features, hidden_size=hidden)
    result = model(mx.zeros((batch, frames, in_features)))

    # Two directions are concatenated along the feature axis, so the width doubles.
    expected = (batch, frames, 2 * hidden)
    assert result.shape == expected, f"got {result.shape}"