Files
pyannote-speaker-diarizatio…/tests/unit/test_diar_powerset.py
transcrilive 2b1a3c1312 feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with the pyannote PyTorch reference (cosine similarity 0.763718,
identical to 6 decimal places, on 200 cross-window slot pairs). 2.5x faster
than pyannote on MPS, running natively on Apple Silicon.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
2026-05-09 16:05:39 +02:00

40 lines
1.1 KiB
Python

import numpy as np
import mlx.core as mx
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
def test_static_mapping_matches_pyannote():
    """Pin the 7-class powerset table for 3 speakers with at most 2 active.

    Each row is one powerset class; column j is 1.0 when speaker j is active
    in that class. Expected row order: no speaker, the three single speakers,
    then the speaker pairs (0, 1), (0, 2), (1, 2).
    """
    assert POWERSET_3_2_MAPPING.shape == (7, 3)

    active_speakers = (
        (),        # class 0: no active speaker
        (0,),      # class 1
        (1,),      # class 2
        (2,),      # class 3
        (0, 1),    # class 4
        (0, 2),    # class 5
        (1, 2),    # class 6
    )
    expected = np.zeros((len(active_speakers), 3), dtype=np.float32)
    for cls, speakers in enumerate(active_speakers):
        for spk in speakers:
            expected[cls, spk] = 1.0

    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)
def test_to_multilabel_hard_argmax():
    """Hard (argmax) decoding maps each frame's top class to its speaker row.

    Powerset class -> active speaker columns (0-based):
    0 -> none, 1 -> {0}, 2 -> {1}, 3 -> {2}, 4 -> {0, 1}, 5 -> {0, 2},
    6 -> {1, 2}.
    """
    p = Powerset()

    # Hand-picked frames: frame 0 -> class 1 (speaker 0 only),
    # frame 1 -> class 4 (speakers 0 and 1), frame 2 -> class 0 (no speaker).
    logits = mx.array([
        [0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
        [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ])
    out_np = np.asarray(p.to_multilabel(logits))
    np.testing.assert_array_equal(out_np[0], [1, 0, 0])
    np.testing.assert_array_equal(out_np[1], [1, 1, 0])
    np.testing.assert_array_equal(out_np[2], [0, 0, 0])

    # Exhaustive check: the frames above only exercise classes 0, 1 and 4.
    # A logit row peaked on class k must decode to row k of the mapping
    # table, for every powerset class.
    num_classes = POWERSET_3_2_MAPPING.shape[0]
    peaked = mx.array([
        [5.0 if col == k else 0.0 for col in range(num_classes)]
        for k in range(num_classes)
    ])
    np.testing.assert_array_equal(
        np.asarray(p.to_multilabel(peaked)),
        np.asarray(POWERSET_3_2_MAPPING),
    )
def test_to_multilabel_shape():
    """Decoding keeps the frame axis and yields one column per speaker."""
    # 589 frames; presumably one pyannote segmentation chunk — TODO confirm.
    num_frames, num_classes, num_speakers = 589, 7, 3
    decoded = Powerset().to_multilabel(mx.zeros((num_frames, num_classes)))
    assert decoded.shape == (num_frames, num_speakers)