feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with the pyannote PyTorch reference (cosine 0.763718, identical to 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS when running natively on Apple Silicon. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
39
tests/unit/test_diar_powerset.py
Normal file
39
tests/unit/test_diar_powerset.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import numpy as np
|
||||
import mlx.core as mx
|
||||
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
|
||||
|
||||
|
||||
def test_static_mapping_matches_pyannote():
    """Pin ``POWERSET_3_2_MAPPING`` to the expected 7x3 table.

    The reference rows are listed in the order the test expects them:
    the all-zero row first, then the three rows with a single 1, then
    the three rows with two 1s.
    """
    reference_rows = [
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [1, 0, 1],
        [0, 1, 1],
    ]

    # Check the shape first so a wrong-sized table fails with a short message.
    assert POWERSET_3_2_MAPPING.shape == (7, 3)

    reference = np.array(reference_rows, dtype=np.float32)
    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, reference)
|
||||
|
||||
|
||||
def test_to_multilabel_hard_argmax():
    """Check ``Powerset.to_multilabel`` on frames with one dominant class.

    Each frame carries a single logit of 5.0 (all others 0.0); the output
    row for that frame must equal the expected multilabel pattern.
    """
    converter = Powerset()

    # (index of the dominant class, expected multilabel row), one per frame:
    # class 1 -> S1 only, class 4 -> S1+S2, class 0 -> nobody active.
    cases = [
        (1, [1, 0, 0]),
        (4, [1, 1, 0]),
        (0, [0, 0, 0]),
    ]

    logits = mx.array(
        [[5.0 if k == winner else 0.0 for k in range(7)] for winner, _ in cases]
    )

    multilabel = np.asarray(converter.to_multilabel(logits))

    for frame, (_, wanted) in enumerate(cases):
        np.testing.assert_array_equal(multilabel[frame], wanted)
|
||||
|
||||
|
||||
def test_to_multilabel_shape():
    """Check that ``to_multilabel`` maps a (589, 7) input to a (589, 3) output."""
    num_frames = 589
    converter = Powerset()
    zero_logits = mx.zeros((num_frames, 7))

    assert converter.to_multilabel(zero_logits).shape == (num_frames, 3)
|
||||
Reference in New Issue
Block a user