# Byte parity with the pyannote PyTorch reference: cosine similarity 0.763718,
# identical to 6 decimal places, on 200 cross-window slot pairs.
# Runs 2.5x faster than pyannote on MPS, natively on Apple Silicon.
# Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR, commit 5f9eafa.
"""Smoke test for the MlxDiarizationPipeline orchestrator.

Mocks the pipeline's sub-components (segmentation, powerset conversion,
embedding) so that no Hugging Face downloads or real inference are needed.

Scenario: 30 s of silence → the mocked powerset conversion returns all zeros
→ no speaker slot is active → the resulting annotation is empty.
"""
|
|
import numpy as np
|
|
import mlx.core as mx
|
|
from unittest.mock import MagicMock
|
|
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
|
|
|
|
|
|
def test_pipeline_smoke_on_30s_zeros(mocker=None):
    """Run the orchestrator on 30 s of silence with every model mocked out.

    The pipeline is built with ``__new__`` so ``__init__`` (and whatever model
    loading it performs) is skipped; the sub-components the orchestrator calls
    are attached by hand as ``MagicMock`` objects.

    ``mocker`` is unused. It keeps a default value so pytest does not request
    the pytest-mock fixture (pytest only injects fixtures for parameters
    without defaults); the test therefore no longer requires pytest-mock while
    its signature stays call-compatible.
    """
    sample_rate = 16000
    duration_s = 30
    # Shapes of the mocked model outputs. NOTE(review): presumably 589 frames
    # per segmentation chunk, 7 powerset classes and 3 local speakers as in
    # pyannote segmentation-3.0 — confirm against the real model.
    num_frames = 589
    num_powerset_classes = 7
    num_local_speakers = 3
    embedding_dim = 256

    p = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)

    # All-zero powerset logits, intended to decode as class 0 (silence).
    p._segmentation = MagicMock(
        return_value=mx.zeros((1, num_frames, num_powerset_classes))
    )
    # The powerset -> multilabel conversion is mocked as well; its all-zero
    # output is what actually guarantees that no speaker slot is active.
    p._powerset = MagicMock()
    p._powerset.to_multilabel.return_value = mx.zeros(
        (num_frames, num_local_speakers)
    )
    # Constant embedding; it should never matter since no slot is active.
    p._embedding = MagicMock(return_value=mx.ones((1, embedding_dim)))

    # 30 s of silence, passed as a (1, num_samples) waveform.
    audio = np.zeros(duration_s * sample_rate, dtype=np.float32)
    annotation = p(
        {"waveform": mx.array(audio)[None, :], "sample_rate": sample_rate}
    )

    # Silence must produce no speaker turns at all.
    tracks = list(annotation.itertracks())
    assert not tracks, f"expected no speaker turns on silence, got {tracks!r}"