"""Tests for the powerset <-> multilabel conversion used by the MLX diarization port.

The powerset here has 3 speakers with at most 2 active at once, giving 7
classes: silence, 3 single speakers and 3 speaker pairs.
"""

import mlx.core as mx
import numpy as np

from pyannote_diarization_3_1_mlx.powerset import POWERSET_3_2_MAPPING, Powerset


def test_static_mapping_matches_pyannote():
    """The static class -> speaker mapping must match pyannote's ordering exactly."""
    assert POWERSET_3_2_MAPPING.shape == (7, 3)
    expected = np.array(
        [
            [0, 0, 0],  # class 0: silence
            [1, 0, 0],  # class 1: S1
            [0, 1, 0],  # class 2: S2
            [0, 0, 1],  # class 3: S3
            [1, 1, 0],  # class 4: S1 + S2
            [1, 0, 1],  # class 5: S1 + S3
            [0, 1, 1],  # class 6: S2 + S3
        ],
        dtype=np.float32,
    )
    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)


def test_to_multilabel_hard_argmax():
    """Each frame's winning powerset class is decoded to its speaker row."""
    p = Powerset()
    # frame 0 -> class 1 (S1 only), frame 1 -> class 4 (S1+S2), frame 2 -> class 0
    logits = mx.array(
        [
            [0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
            [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ]
    )
    out = p.to_multilabel(logits)
    out_np = np.asarray(out)
    np.testing.assert_array_equal(out_np[0], [1, 0, 0])
    np.testing.assert_array_equal(out_np[1], [1, 1, 0])
    np.testing.assert_array_equal(out_np[2], [0, 0, 0])


def test_to_multilabel_each_class_maps_to_its_row():
    """A clear winner at class k must decode to row k of the static mapping."""
    p = Powerset()
    # Row k has its single large logit at column k, so the argmax of frame k is k.
    logits = mx.array(np.eye(7, dtype=np.float32) * 5.0)
    out_np = np.asarray(p.to_multilabel(logits))
    np.testing.assert_array_equal(out_np, POWERSET_3_2_MAPPING)


def test_to_multilabel_shape():
    """(frames, 7) powerset logits become (frames, 3) per-speaker activations."""
    p = Powerset()
    # NOTE(review): 589 is presumably the segmentation model's frames per chunk — confirm.
    logits = mx.zeros((589, 7))
    out = p.to_multilabel(logits)
    assert out.shape == (589, 3)