Byte-parity with the pyannote PyTorch reference (cosine similarity 0.763718, identical to 6 decimal places across 200 cross-window slot pairs). Runs natively on Apple Silicon, 2.5x faster than pyannote on the MPS backend. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR, commit 5f9eafa.
40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
import numpy as np
|
|
import mlx.core as mx
|
|
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
|
|
|
|
|
|
def test_static_mapping_matches_pyannote():
    """Pin ``POWERSET_3_2_MAPPING`` to the expected powerset class order.

    Rows are powerset classes and columns are speakers (1 = active). The
    expected order is: the empty set, each single speaker, then each pair of
    speakers -- 7 classes over 3 speakers, with at most 2 active at once.
    """
    assert POWERSET_3_2_MAPPING.shape == (7, 3)

    reference = np.array(
        [
            [0, 0, 0],  # class 0: no speaker active
            [1, 0, 0],  # class 1: speaker 1
            [0, 1, 0],  # class 2: speaker 2
            [0, 0, 1],  # class 3: speaker 3
            [1, 1, 0],  # class 4: speakers 1 + 2
            [1, 0, 1],  # class 5: speakers 1 + 3
            [0, 1, 1],  # class 6: speakers 2 + 3
        ],
        dtype=np.float32,
    )
    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, reference)
|
|
|
|
|
|
def test_to_multilabel_hard_argmax():
    """Check that ``Powerset.to_multilabel`` decodes logits by hard argmax.

    The highest-scoring powerset class of each frame must be converted into
    that class's row of ``POWERSET_3_2_MAPPING`` (one 0/1 flag per speaker).
    """
    p = Powerset()

    # frame 0 -> class 1 (S1 only), frame 1 -> class 4 (S1+S2),
    # frame 2 -> class 0 (no speaker active)
    logits = mx.array([
        [0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
        [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ])

    out_np = np.asarray(p.to_multilabel(logits))
    assert out_np.shape == (3, 3)
    np.testing.assert_array_equal(out_np[0], [1, 0, 0])
    np.testing.assert_array_equal(out_np[1], [1, 1, 0])
    np.testing.assert_array_equal(out_np[2], [0, 0, 0])

    # The frames above only exercise classes 0, 1 and 4. Sweep all 7 classes:
    # a one-hot logit spike on class k must decode to row k of the mapping,
    # so a wrong decoded row for any class is caught, not just those three.
    spikes = mx.array(5.0 * np.eye(7, dtype=np.float32))
    decoded = np.asarray(p.to_multilabel(spikes))
    np.testing.assert_array_equal(decoded, POWERSET_3_2_MAPPING)
|
|
|
|
|
|
def test_to_multilabel_shape():
    """Check that ``to_multilabel`` turns (frames, 7) logits into (frames, 3)."""
    num_frames = 589
    decoded = Powerset().to_multilabel(mx.zeros((num_frames, 7)))
    assert decoded.shape == (num_frames, 3)
|