feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1

Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical
at 6 decimals on 200 cross-window slot pairs). 2.5x faster than
pyannote-MPS on Apple Silicon native.

Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
transcrilive
2026-05-09 16:05:39 +02:00
commit 2b1a3c1312
30 changed files with 2022 additions and 0 deletions

.gitignore (new file, 18 lines)

@@ -0,0 +1,18 @@
__pycache__/
*.py[cod]
*.class
*.so
.Python
.venv/
venv/
ENV/
dist/
build/
*.egg-info/
.eggs/
.DS_Store
.env
*.log
.pytest_cache/
.ruff_cache/
*.orig

LICENSE (new file, 21 lines)

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2026 Olivier Dupont
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md (new file, 57 lines)

@@ -0,0 +1,57 @@
# pyannote-speaker-diarization-3.1-mlx
First MLX port of pyannote-speaker-diarization-3.1 with byte-parity to the PyTorch reference. 2.5x faster than pyannote-MPS on Apple Silicon native.
## Install
```bash
uv add "pyannote-speaker-diarization-3.1-mlx @ git+https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
```
## Quickstart
```python
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
pipeline = MlxDiarizationPipeline.from_pretrained("pyannote/speaker-diarization-3.1")
diarization = pipeline("audio.wav")
for turn, _, speaker in diarization.itertracks(yield_label=True):
    print(f"{turn.start:.1f}s - {turn.end:.1f}s {speaker}")
```
## Parity
| Evidence | MLX | Reference | Result |
| --- | --- | --- | --- |
| Cosine distance (200 cross-window pairs) | mean=0.763718 | pyannote-PyTorch mean=0.763718 | identical at 6 decimals |
| 5h10 bench | 173s / 44 speakers / 1.27 GB | pyannote-MPS 431s / 43 speakers / 1.72 GB | Cross-DER 0.076 |
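
The cosine-distance row can be read as the following kind of check: compute the mean cosine distance over the same 200 pairs for both backends and compare the two means at the sixth decimal. This is a minimal sketch on synthetic vectors, not the project's test harness. The helper names, the random data, the 1e-9 noise level, and the use of a 5e-7 tolerance as a stand-in for "identical at 6 decimals" are all assumptions made for illustration.

```python
import math
import random


def cosine_distance(u, v):
    """Return 1 - cos(u, v) for two equal-length vectors."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


def mean_cosine_distance(pairs):
    """Mean cosine distance over a list of (u, v) vector pairs."""
    return sum(cosine_distance(u, v) for u, v in pairs) / len(pairs)


random.seed(0)
DIM = 256      # embedding size, as in EMB_DIM
N_PAIRS = 200  # number of cross-window pairs in the table

# Synthetic stand-in for the reference embeddings (assumption).
ref_pairs = [
    (
        [random.gauss(0.0, 1.0) for _ in range(DIM)],
        [random.gauss(0.0, 1.0) for _ in range(DIM)],
    )
    for _ in range(N_PAIRS)
]

# Hypothetical port output: the reference plus tiny numerical noise.
port_pairs = [
    (
        [x + random.gauss(0.0, 1e-9) for x in u],
        [y + random.gauss(0.0, 1e-9) for y in v],
    )
    for u, v in ref_pairs
]

ref_mean = mean_cosine_distance(ref_pairs)
port_mean = mean_cosine_distance(port_pairs)

# Half a unit in the sixth decimal place, used here as the parity tolerance.
agree_6dp = abs(ref_mean - port_mean) < 5e-7
print(f"ref={ref_mean:.6f} port={port_mean:.6f} agree={agree_6dp}")
```

A tolerance on the difference is used instead of comparing rounded strings because two means can straddle a rounding boundary while differing by far less than 1e-6.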
## Architecture
SincNet → BiLSTM → Powerset(3,2) head + WeSpeaker ResNet34 speaker embedding + AgglomerativeClustering wrapper.
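
To make the `Powerset(3,2)` notation concrete: with at most 3 speakers per chunk and at most 2 active per frame, the head predicts one class per subset of speakers of size 0, 1, or 2, which gives 1 + 3 + 3 = 7 classes. The sketch below enumerates those classes and maps a class index back to a per-speaker activity vector. The function names and the class ordering (by subset size, then lexicographic) are assumptions for illustration, not code taken from this package.

```python
from itertools import combinations


def powerset_classes(num_speakers=3, max_per_frame=2):
    """List every speaker subset of size 0..max_per_frame."""
    classes = []
    for size in range(max_per_frame + 1):
        classes.extend(combinations(range(num_speakers), size))
    return classes


def to_multilabel(class_index, classes, num_speakers=3):
    """Convert a powerset class index into a 0/1 activity vector."""
    active = classes[class_index]
    return [1 if speaker in active else 0 for speaker in range(num_speakers)]


classes = powerset_classes()
for index, subset in enumerate(classes):
    print(index, subset, to_multilabel(index, classes))
```

Class 0 is the empty set (no speech), classes 1 to 3 are single speakers, and classes 4 to 6 are the overlapping pairs.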
## Module Naming
The repository name is `pyannote-speaker-diarization-3.1-mlx`; the Python import is `pyannote_diarization_3_1_mlx`. The import name follows PEP 8 and embeds the pyannote model version so future 4.0 ports can co-install.
## Citation
This project ports the pyannote speaker diarization 3.1 pipeline architecture to MLX. Please cite the original pyannote.audio work when using this package:
```bibtex
@inproceedings{Plaquet23,
author = {Alexis Plaquet and Hervé Bredin},
title = {{Powerset multi-class cross entropy loss for neural speaker diarization}},
booktitle = {Proc. INTERSPEECH 2023},
year = {2023},
}
```
## Provenance
Extracted from MLX_CONVERTOR/src/mlxconv/diar at commit 5f9eafa. Maintained at https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.
## License
MIT

docs/parity-evidence.md (new file, 7 lines)

@@ -0,0 +1,7 @@
# Parity Evidence
| Evidence | MLX | Reference | Result |
| --- | --- | --- | --- |
| Cosine distance parity | 200 cross-window pairs, mean 0.763718 | pyannote-PyTorch mean 0.763718 | identical at 6 decimals |
| 5h10 bench results | 173s wall / 44 speakers / 1.27 GB peak RSS | pyannote-MPS 431s / 43 speakers / 1.72 GB | Cross-DER 0.076 |
| Source commits | 8aa6c6d + 5f9eafa | feat/platform-abc in MLX_CONVERTOR | extraction source |
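
The headline "2.5x faster" figure follows from the wall times in the bench row. The short calculation below reproduces it and derives real-time factors. It assumes that "5h10" denotes 5 hours 10 minutes of audio (18600 s); the table does not state this explicitly, so the real-time factors are only valid under that reading.

```python
# Assumption: "5h10" is the audio duration, 5 hours and 10 minutes.
AUDIO_S = 5 * 3600 + 10 * 60

MLX_WALL_S = 173.0  # MLX wall time from the table
MPS_WALL_S = 431.0  # pyannote-MPS wall time from the table
MLX_RSS_GB = 1.27
MPS_RSS_GB = 1.72

speedup = MPS_WALL_S / MLX_WALL_S       # how many times faster MLX is
rt_mlx = AUDIO_S / MLX_WALL_S           # seconds of audio per second of compute
rt_mps = AUDIO_S / MPS_WALL_S
rss_ratio = MLX_RSS_GB / MPS_RSS_GB     # MLX memory relative to MPS

print(f"speedup: {speedup:.2f}x")
print(f"realtime factor: mlx {rt_mlx:.1f}x, mps {rt_mps:.1f}x")
print(f"rss ratio: {rss_ratio:.2f}")
```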

pyproject.toml (new file, 36 lines)

@@ -0,0 +1,36 @@
[project]
name = "pyannote-speaker-diarization-3.1-mlx"
version = "0.1.0"
description = "MLX port of pyannote/speaker-diarization-3.1 with byte-parity to PyTorch reference"
readme = "README.md"
requires-python = ">=3.12,<3.14"
authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
license = { text = "MIT" }
keywords = ["mlx", "pyannote", "speaker-diarization", "apple-silicon"]
classifiers = [
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: MacOS",
]
dependencies = [
"mlx>=0.21.0",
"torch>=2.5.0",
"torchaudio>=2.5.0",
"huggingface_hub>=0.26.0",
"safetensors>=0.4.5",
"librosa>=0.10.2",
"scipy>=1.14",
"numpy>=2.0",
"pyannote.audio>=4.0.4",
]
[project.optional-dependencies]
bench = ["psutil>=7.0"]
dev = ["pytest>=8.3", "pytest-mock>=3.14", "ruff>=0.7"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/pyannote_diarization_3_1_mlx"]

scripts/bench.py (new file, 161 lines)

@@ -0,0 +1,161 @@
"""Benchmark MLX vs pyannote-MPS diarization on the same audio.

Usage:
    uv run python scripts/bench.py <audio> \
        [--min-speakers N] [--max-speakers M]

Runs both backends back-to-back, prints a Markdown table with wall time,
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
"""
from __future__ import annotations
import argparse
import gc
import sys
import time
from pathlib import Path
import librosa
import numpy as np
import psutil
import torch


def _measure(label: str, fn) -> dict:
    """Run fn(), measure wall time + RSS delta + return result."""
    proc = psutil.Process()
    gc.collect()
    rss_before = proc.memory_info().rss
    t0 = time.time()
    annotation = fn()
    wall = time.time() - t0
    # RSS is sampled once after fn() returns; it is not a polled true peak.
    rss_peak = proc.memory_info().rss
    return {
        "label": label,
        "wall": wall,
        "rss_delta_gb": (rss_peak - rss_before) / 1e9,
        "rss_peak_gb": rss_peak / 1e9,
        "annotation": annotation,
    }


def _stats(annotation) -> dict:
    speakers = sorted(set(annotation.labels()))
    turns = list(annotation.itertracks(yield_label=True))
    total_speech = sum(seg.duration for seg, _, _ in turns)
    # per-speaker totals
    by_speaker = {}
    for seg, _, lab in turns:
        by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
    return {
        "speakers": len(speakers),
        "turns": len(turns),
        "total_speech": total_speech,
        "by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
    }


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument("audio")
    parser.add_argument("--min-speakers", type=int, default=10)
    parser.add_argument("--max-speakers", type=int, default=15)
    args = parser.parse_args()

    audio_path = Path(args.audio).expanduser().resolve()
    print(f"Loading {audio_path.name} (sr=16000, mono) ...", file=sys.stderr)
    sig, _ = librosa.load(str(audio_path), sr=16000, mono=True)
    duration_s = len(sig) / 16000
    print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)

    diar_input = {
        "waveform": torch.from_numpy(sig).unsqueeze(0),
        "sample_rate": 16000,
    }
    kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}
    results = []

    # 1. MLX pure
    print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
    from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline

    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
    r_mlx.update(_stats(r_mlx["annotation"]))
    results.append(r_mlx)
    print(
        f" wall={r_mlx['wall']:.1f}s speakers={r_mlx['speakers']} "
        f"speech={r_mlx['total_speech']:.0f}s "
        f"rss_delta={r_mlx['rss_delta_gb']:.2f}GB",
        file=sys.stderr,
    )

    # free MLX before pyannote (we'll reuse the same Python proc)
    del mlx_pipe
    gc.collect()

    # 2. pyannote (MPS if available, else CPU)
    print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
    from pyannote.audio import Pipeline

    pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    if torch.backends.mps.is_available():
        try:
            pa_pipe.to(torch.device("mps"))
            print(" device: mps", file=sys.stderr)
        except Exception as e:
            print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
    else:
        print(" device: cpu", file=sys.stderr)

    def _run_pa():
        out = pa_pipe(diar_input, **kwargs)
        ann = getattr(out, "exclusive_speaker_diarization", None) or out
        return ann

    r_pa = _measure("pyannote", _run_pa)
    r_pa.update(_stats(r_pa["annotation"]))
    results.append(r_pa)
    print(
        f" wall={r_pa['wall']:.1f}s speakers={r_pa['speakers']} "
        f"speech={r_pa['total_speech']:.0f}s "
        f"rss_delta={r_pa['rss_delta_gb']:.2f}GB",
        file=sys.stderr,
    )

    # 3. cross DER
    der_value = None
    try:
        from pyannote.metrics.diarization import DiarizationErrorRate

        der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
        print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
    except Exception as e:
        print(f"\nDER computation failed: {e}", file=sys.stderr)

    # Print Markdown table to stdout
    print()
    print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
    print("|---|---:|---:|---:|---:|---:|---:|")
    for r in results:
        rt = duration_s / r["wall"] if r["wall"] > 0 else 0
        print(
            f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
            f"{r['speakers']} | {r['turns']} | "
            f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
        )
    print()
    if der_value is not None:
        print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")
        print()
    print("### Per-speaker speech time")
    for r in results:
        print(f"\n**{r['label']}** ({r['speakers']} speakers):")
        for sp, dur in list(r["by_speaker"].items())[:10]:
            print(f" {sp}: {dur:.0f}s")
    return 0


if __name__ == "__main__":
    sys.exit(main())

scripts/install_remote.sh (new executable file, 52 lines)

@@ -0,0 +1,52 @@
#!/usr/bin/env bash
set -euo pipefail
INSTALL_DIR="${1:-$HOME/pyannote-diarization-3.1-mlx-test}"
INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}"
HTTPS_SPEC="pyannote-speaker-diarization-3.1-mlx @ git+https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
SSH_SPEC="git+ssh://git@gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
usage() {
cat <<EOF
Usage:
$0 [install-dir]
Creates a uv project and installs pyannote-speaker-diarization-3.1-mlx.
Default install directory:
$INSTALL_DIR
EOF
}
if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
usage
exit 0
fi
if ! command -v uv >/dev/null 2>&1; then
cat >&2 <<'EOF'
uv is required but was not found.
Install it with:
curl -LsSf https://astral.sh/uv/install.sh | sh
Then restart your shell and run this script again.
EOF
exit 1
fi
mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR"
if [[ ! -f pyproject.toml ]]; then
uv init --python 3.12
else
echo "Found existing pyproject.toml in $INSTALL_DIR; skipping uv init."
fi
echo "Installing from HTTPS..."
if ! uv add "$HTTPS_SPEC"; then
echo "HTTPS install failed; falling back to SSH pip install..."
uv pip install "$SSH_SPEC"
fi
uv run python -c "from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline; print('OK')"


@@ -0,0 +1,4 @@
"""Pyannote 3.1 port to MLX. See docs/superpowers/specs/2026-05-08-pyannote-mlx-port-design.md."""
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
__all__ = ["MlxDiarizationPipeline"]


@@ -0,0 +1,33 @@
"""4-layer monolithic bidirectional LSTM for pyannote PyanNet head.
Pyannote uses torch.nn.LSTM with bidirectional=True, num_layers=4. We split
into forward+backward stacks per layer; bias_ih + bias_hh are summed into a
single MLX bias vector (mlx.nn.LSTM convention).
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn


class BiLSTM4(nn.Module):
    def __init__(self, input_size: int, hidden_size: int = 128, num_layers: int = 4) -> None:
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # one fwd + one bwd LSTM per layer
        self.fwd = []
        self.bwd = []
        in_dim = input_size
        for _ in range(num_layers):
            self.fwd.append(nn.LSTM(in_dim, hidden_size))
            self.bwd.append(nn.LSTM(in_dim, hidden_size))
            in_dim = hidden_size * 2  # next layer ingests concat

    def __call__(self, x: mx.array) -> mx.array:
        for f, b in zip(self.fwd, self.bwd):
            f_out, _ = f(x)
            # Run the backward LSTM on the time-reversed input, then flip back.
            x_rev = x[:, ::-1, :]
            b_out_rev, _ = b(x_rev)
            b_out = b_out_rev[:, ::-1, :]
            x = mx.concatenate([f_out, b_out], axis=-1)
        return x
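
The bias merge mentioned in the module docstring works because an LSTM gate pre-activation is `W_ih @ x + b_ih + W_hh @ h + b_hh`, so the two PyTorch bias vectors only ever appear as their sum. The sketch below shows that merge on plain Python lists. The source keys (`bias_ih_l0`, `bias_hh_l0_reverse`, ...) are the standard `torch.nn.LSTM` parameter names; the target keys (`fwd.0.bias`, `bwd.0.bias`) and the helper name are assumptions about how a converter for `BiLSTM4` might lay things out, not the project's actual conversion script.

```python
def merge_lstm_biases(state, num_layers=4):
    """Sum bias_ih and bias_hh per layer and direction.

    `state` maps torch.nn.LSTM parameter names to flat lists of floats.
    Returns a dict keyed by hypothetical MLX-side names.
    """
    merged = {}
    for layer in range(num_layers):
        for direction, suffix in (("fwd", ""), ("bwd", "_reverse")):
            b_ih = state[f"bias_ih_l{layer}{suffix}"]
            b_hh = state[f"bias_hh_l{layer}{suffix}"]
            if len(b_ih) != len(b_hh):
                raise ValueError(f"bias length mismatch in layer {layer}{suffix}")
            merged[f"{direction}.{layer}.bias"] = [a + b for a, b in zip(b_ih, b_hh)]
    return merged


# Toy single-layer example with 4 * hidden_size = 4 bias entries.
toy_state = {
    "bias_ih_l0": [0.1, 0.2, 0.3, 0.4],
    "bias_hh_l0": [1.0, 1.0, 1.0, 1.0],
    "bias_ih_l0_reverse": [0.0, 0.0, 0.0, 0.0],
    "bias_hh_l0_reverse": [0.5, 0.5, 0.5, 0.5],
}
merged = merge_lstm_biases(toy_state, num_layers=1)
print(merged)
```

The weight matrices are left out on purpose: only the biases need combining, and any gate-order or layout differences between frameworks would have to be checked separately.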


@@ -0,0 +1,36 @@
"""Locked hyperparameters and HF revisions for the pyannote 3.1 MLX port.
These values come from the upstream pyannote/speaker-diarization-3.1 config
and the corresponding mlx-community port. Changing any of them requires
re-running the Day-1 sanity gate (Task 28).
"""
from __future__ import annotations
# Segmentation
SEG_DURATION = 10.0
SEG_HOP = 1.0
SEG_FRAMES = 589
SEG_CLASSES = 7
MAX_SPEAKERS_PER_CHUNK = 3
MAX_SPEAKERS_PER_FRAME = 2
MIN_DURATION_ON = 0.70
# Embedding
EMB_BATCH_SIZE = 32
EMB_EXCLUDE_OVERLAP = True
EMB_DIM = 256
# Clustering (pyannote.audio.pipelines.clustering.AgglomerativeClustering defaults)
CLUSTER_METHOD = "centroid"
CLUSTER_THRESHOLD = 0.7045654963945799
CLUSTER_MIN_SIZE = 12
CLUSTER_MAX_NUM_EMBEDDINGS = 1000
# Audio
SAMPLE_RATE = 16000
# HF revisions — pinned per Codex review
SEG_HF_REPO = "mlx-community/pyannote-segmentation-3.0-mlx"
SEG_HF_REV = "5189a69b35c5f7e48082a978f3476bac81590874"
EMB_HF_REPO = "mlx-community/wespeaker-voxceleb-resnet34-LM"
EMB_HF_REV = "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"


@@ -0,0 +1,289 @@
"""SincNet block — MLX port of pyannote.audio.models.blocks.sincnet.
Source of truth:
pyannote/audio/models/blocks/sincnet.py (MIT, CNRS)
asteroid_filterbanks/param_sinc_fb.py (MIT)
Key conventions difference vs PyTorch:
- PyTorch Conv1d / InstanceNorm1d use NCL (batch, channels, length).
- MLX Conv1d / MaxPool1d / InstanceNorm use NLC (batch, length, channels).
- We accept (B, C, T) inputs (PyTorch NCL) and return (B, C, T) outputs so
that the rest of the port can stay in PyTorch convention, but internally
we transpose to NLC for every MLX primitive.
"""
from __future__ import annotations
import math
import numpy as np
import mlx.core as mx
import mlx.nn as nn
# ---------------------------------------------------------------------------
# Helper: sinc function (normalised, matches PyTorch / numpy convention)
# ---------------------------------------------------------------------------
def _sinc(x: mx.array) -> mx.array:
    """Normalised sinc: sin(pi*x) / (pi*x), with sinc(0)=1."""
    # Avoid division by zero at x==0
    safe = mx.where(x == 0, mx.ones_like(x), x)
    result = mx.sin(math.pi * safe) / (math.pi * safe)
    return mx.where(x == 0, mx.ones_like(x), result)
# ---------------------------------------------------------------------------
# ParamSincFB — learnable sinc filterbank
# ---------------------------------------------------------------------------
class ParamSincFB(nn.Module):
    """Learnable sinc filterbank (MLX port of asteroid_filterbanks.ParamSincFB).

    Produces 2*n_filters_half (= n_filters) output channels: the first half
    are even (cos) filters and the second half are odd (sin) filters, following
    the SincNet paper and asteroid_filterbanks exactly.

    Parameters
    ----------
    n_filters : int
        Total number of output channels (must be even; n_filters//2 are
        parametric).
    kernel_size : int
        Length of each filter (must be odd; forced odd if even).
    stride : int
        Stride for the convolution over the waveform.
    sample_rate : float
        Sample rate (Hz). Used for Mel-scale initialisation.
    min_low_hz : float
        Minimum allowed low-cutoff frequency (Hz).
    min_band_hz : float
        Minimum allowed bandwidth (Hz).
    """

    def __init__(
        self,
        n_filters: int = 80,
        kernel_size: int = 251,
        stride: int = 1,
        sample_rate: float = 16000.0,
        min_low_hz: float = 50.0,
        min_band_hz: float = 50.0,
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            kernel_size += 1  # force odd
        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
        self.half_kernel = kernel_size // 2
        self.n_filters_half = n_filters // 2  # parametric filters (real part)

        # Initialise on Mel scale (mirrors _initialize_filters in upstream)
        low_hz = 30.0
        high_hz = sample_rate / 2.0 - (min_low_hz + min_band_hz)

        def to_mel(hz):
            return 2595.0 * np.log10(1.0 + hz / 700.0)

        def to_hz(mel):
            return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)

        mel = np.linspace(
            to_mel(low_hz),
            to_mel(high_hz),
            self.n_filters_half + 1,
            dtype="float32",
        )
        hz = to_hz(mel)

        # Learnable parameters, shape (n_filters_half, 1), stored as mx.array
        # so that MLX's tree_flatten / load_weights can see them as parameters.
        self.low_hz_ = mx.array(hz[:-1].reshape(-1, 1))
        self.band_hz_ = mx.array(np.diff(hz).reshape(-1, 1))

        # Hamming window for the left half (shape: (half_kernel,))
        window_np = np.hamming(kernel_size)[: self.half_kernel].astype("float32")
        self._window = mx.array(window_np)  # frozen buffer

        # Time vector: shape (1, half_kernel), values in seconds
        n_np = (
            2.0
            * np.pi
            * (np.arange(-self.half_kernel, 0.0, dtype="float32") / sample_rate)
        ).reshape(1, -1)
        self._n = mx.array(n_np)  # frozen buffer

    def _build_filters(self) -> mx.array:
        """Compute the (n_filters, kernel_size, 1) filter bank from parameters.

        Mirrors ParamSincFB.filters() + make_filters() in asteroid_filterbanks.
        """
        low_hz_ = self.low_hz_  # (n_filters_half, 1), already mx.array
        band_hz_ = self.band_hz_  # (n_filters_half, 1), already mx.array

        low = self.min_low_hz + mx.abs(low_hz_)  # (nf_h, 1)
        high = mx.clip(
            low + self.min_band_hz + mx.abs(band_hz_),
            self.min_low_hz,
            self.sample_rate / 2.0,
        )  # (nf_h, 1)
        band = (high - low)[:, 0]  # (nf_h,)

        # ft_low / ft_high: (nf_h, half_kernel) via outer product
        ft_low = mx.matmul(low, self._n)  # (nf_h, half_kernel)
        ft_high = mx.matmul(high, self._n)  # (nf_h, half_kernel)

        # --- Even (cos) filters ---
        bp_left_cos = (
            (mx.sin(ft_high) - mx.sin(ft_low)) / (self._n / 2.0)
        ) * self._window  # (nf_h, half_kernel)
        bp_center_cos = 2.0 * band.reshape(-1, 1)  # (nf_h, 1)
        bp_right_cos = bp_left_cos[:, ::-1]  # (nf_h, half_kernel), reversed along kernel dim
        cos_filters = mx.concatenate(
            [bp_left_cos, bp_center_cos, bp_right_cos], axis=1
        )  # (nf_h, kernel_size)
        cos_filters = cos_filters / (2.0 * band[:, None])

        # --- Odd (sin) filters ---
        bp_left_sin = (
            (mx.cos(ft_low) - mx.cos(ft_high)) / (self._n / 2.0)
        ) * self._window  # (nf_h, half_kernel)
        bp_center_sin = mx.zeros((self.n_filters_half, 1))
        bp_right_sin = -bp_left_sin[:, ::-1]  # reverse along kernel dim
        sin_filters = mx.concatenate(
            [bp_left_sin, bp_center_sin, bp_right_sin], axis=1
        )  # (nf_h, kernel_size)
        sin_filters = sin_filters / (2.0 * band[:, None])

        # Concatenate → (n_filters, kernel_size)
        all_filters = mx.concatenate([cos_filters, sin_filters], axis=0)
        # Reshape to (n_filters, kernel_size, 1), the MLX conv weight layout:
        # (out_channels, kernel_size, in_channels)
        return all_filters.reshape(self.n_filters, self.kernel_size, 1)

    def __call__(self, x: mx.array) -> mx.array:
        """Apply sinc filterbank convolution.

        Parameters
        ----------
        x : mx.array, shape (B, T, 1) [NLC, MLX convention]

        Returns
        -------
        mx.array, shape (B, T', n_filters) [NLC]
        """
        filters = self._build_filters()  # (n_filters, kernel_size, 1)
        # MLX conv1d weight shape: (out_channels, kernel_size, in_channels)
        return mx.conv1d(x, filters, stride=self.stride, padding=0)


# ---------------------------------------------------------------------------
# SincNet
# ---------------------------------------------------------------------------
class SincNet(nn.Module):
    """SincNet block: MLX port of pyannote.audio.models.blocks.SincNet.

    Accepts and returns tensors in PyTorch NCL convention (B, C, T) so the
    downstream pipeline stays consistent with the pyannote checkpoint layout.
    Default stride=10 matches pyannote 3.1 PyanNet SINCNET_DEFAULTS.
    """

    def __init__(self, sample_rate: int = 16000, stride: int = 10):
        super().__init__()
        if sample_rate != 16000:
            raise NotImplementedError("SincNet only supports 16kHz audio for now.")
        self.sample_rate = sample_rate
        self.stride = stride

        # --- waveform normalisation ---
        # PyTorch: InstanceNorm1d(1, affine=True) on (B,1,T) = norm over T for
        # each batch item. MLX InstanceNorm operates on (..., C); for a
        # (B, T, 1) tensor, C=1, which matches.
        self.wav_norm1d = nn.InstanceNorm(1, affine=True)

        # --- layer 0: sinc filterbank ---
        self.sinc_fb = ParamSincFB(
            n_filters=80,
            kernel_size=251,
            stride=stride,
            sample_rate=float(sample_rate),
            min_low_hz=50.0,
            min_band_hz=50.0,
        )
        self.pool0 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.norm0 = nn.InstanceNorm(80, affine=True)

        # --- layer 1 ---
        self.conv1 = nn.Conv1d(80, 60, kernel_size=5, stride=1)
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.norm1 = nn.InstanceNorm(60, affine=True)

        # --- layer 2 ---
        self.conv2 = nn.Conv1d(60, 60, kernel_size=5, stride=1)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.norm2 = nn.InstanceNorm(60, affine=True)

    def __call__(self, waveforms: mx.array) -> mx.array:
        """Forward pass.

        Parameters
        ----------
        waveforms : mx.array, shape (B, 1, T) [PyTorch NCL convention]

        Returns
        -------
        mx.array, shape (B, 60, frames) [PyTorch NCL convention]
        """
        # --- Convert NCL → NLC for MLX primitives ---
        # waveforms: (B, 1, T) → (B, T, 1)
        x = mx.transpose(waveforms, (0, 2, 1))

        # --- Waveform normalisation: InstanceNorm1d(1) ---
        # MLX InstanceNorm: input (..., C), normalises over the spatial dims
        # for each C separately. For (B, T, 1) it normalises over T, which
        # is the correct per-instance normalisation matching PyTorch.
        x = self.wav_norm1d(x)

        # === Layer 0: sinc filterbank ===
        # sinc_fb expects (B, T, 1) → returns (B, T', 80)
        x = self.sinc_fb(x)
        # abs() mirrors torch.abs(outputs) at c==0 in upstream
        x = mx.abs(x)
        # pool → norm → activation
        # MaxPool1d: MLX expects (B, L, C), which matches NLC
        x = self.pool0(x)  # (B, T'/3, 80)
        x = self.norm0(x)  # (B, T'/3, 80)
        x = nn.leaky_relu(x)  # (B, T'/3, 80)

        # === Layer 1: Conv1d(80→60, k=5) ===
        # MLX Conv1d expects (B, L, C_in) → outputs (B, L', C_out)
        x = self.conv1(x)  # (B, T''-4, 60)
        x = self.pool1(x)  # (B, (T''-4)//3, 60)
        x = self.norm1(x)
        x = nn.leaky_relu(x)

        # === Layer 2: Conv1d(60→60, k=5) ===
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.norm2(x)
        x = nn.leaky_relu(x)

        # --- Convert NLC → NCL to return in PyTorch convention ---
        # x: (B, frames, 60) → (B, 60, frames)
        x = mx.transpose(x, (0, 2, 1))
        return x
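
The layer stack above fixes how many output frames a chunk produces, and the arithmetic can be checked by hand. The sketch below applies the usual unpadded output-length formula, `(length - kernel) // stride + 1`, to each convolution and pooling stage with the kernel sizes and strides used in `SincNet`. The helper names are hypothetical; the formula assumes zero padding and dilation 1, which matches the layer definitions shown here.

```python
def out_length(length, kernel, stride=1):
    """Output length of an unpadded 1-D conv or pooling layer."""
    return (length - kernel) // stride + 1


def sincnet_frames(num_samples, sinc_stride=10):
    """Number of frames SincNet emits for `num_samples` input samples."""
    n = out_length(num_samples, 251, sinc_stride)  # sinc filterbank
    n = out_length(n, 3, 3)                        # pool0
    n = out_length(n, 5)                           # conv1
    n = out_length(n, 3, 3)                        # pool1
    n = out_length(n, 5)                           # conv2
    n = out_length(n, 3, 3)                        # pool2
    return n


# One 10 s chunk at 16 kHz.
frames = sincnet_frames(10 * 16000)
print(frames)  # 589
```

For a 10 s chunk this gives 589 frames, which agrees with `SEG_FRAMES = 589` in `_config.py`.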


@@ -0,0 +1,39 @@
"""10s sliding window over audio for pyannote 3.1 segmentation."""
from __future__ import annotations
from typing import Iterator
import numpy as np


def sliding_windows(
    audio: np.ndarray,
    sr: int = 16000,
    duration_s: float = 10.0,
    hop_s: float = 1.0,
) -> Iterator[tuple[float, float, np.ndarray]]:
    """Yield (start_s, end_s, audio_slice) tuples.

    Tail handling: the last window starts at duration_total - duration_s if
    the audio is longer than duration_s, so all windows are exactly duration_s.
    Audio shorter than duration_s yields a single padded window.
    """
    n = len(audio)
    win = int(duration_s * sr)
    hop = int(hop_s * sr)
    if n < win:
        # pad to duration_s with zeros, yield once
        padded = np.zeros(win, dtype=audio.dtype)
        padded[:n] = audio
        yield 0.0, duration_s, padded
        return
    # Compute starts so that the last full window aligns with the end.
    last_start = n - win
    starts = list(range(0, last_start, hop))
    starts.append(last_start)
    # Defensive deduplication: range() excludes last_start, so the list is
    # already unique and sorted; this only guards against future changes.
    starts = sorted(set(starts))
    for s in starts:
        e = s + win
        yield s / sr, e / sr, audio[s:e]
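
The tail handling is easiest to see on a concrete length. The sketch below reproduces only the start-offset computation of `sliding_windows` on integer sample counts, without any audio, for a 12.5 s clip. `window_starts` is a hypothetical helper written for this illustration and is not part of the module.

```python
def window_starts(num_samples, sr=16000, duration_s=10.0, hop_s=1.0):
    """Return window start offsets in samples, mirroring sliding_windows."""
    win = int(duration_s * sr)
    hop = int(hop_s * sr)
    if num_samples < win:
        return [0]  # single zero-padded window
    last_start = num_samples - win
    starts = list(range(0, last_start, hop))
    starts.append(last_start)  # final window is aligned to the end
    return starts


sr = 16000
starts_s = [s / sr for s in window_starts(int(12.5 * sr))]
print(starts_s)  # [0.0, 1.0, 2.0, 2.5]
```

The last hop is shortened from 1.0 s to 0.5 s so that the final window still spans a full 10 s and ends exactly at the end of the clip.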


@@ -0,0 +1,243 @@
"""Audio loading + kaldi-compatible fbank features for pyannote 3.1 port.
Reference: torchaudio.compliance.kaldi.fbank with the param set used by
pyannote/wespeaker-voxceleb-resnet34-LM.
"""
from __future__ import annotations
import math
from pathlib import Path
import librosa
import mlx.core as mx
import numpy as np
try:
    import torch
    from torchaudio.compliance import kaldi as ta_kaldi
except Exception:  # pragma: no cover - exercised only when torchaudio is absent/broken
    torch = None
    ta_kaldi = None
from pyannote_diarization_3_1_mlx._config import SAMPLE_RATE
# float32 machine epsilon — same floor used by torchaudio/kaldi
_FLOAT32_EPS = np.finfo(np.float32).eps # ~1.1921e-07


def load_waveform(path: str | Path, sr: int = SAMPLE_RATE) -> mx.array:
    """Load a path → (samples,) float32 MLX array, resampled mono."""
    wav, _ = librosa.load(str(path), sr=sr, mono=True)
    return mx.array(wav, dtype=mx.float32)


def _hamming_window(window_size: int) -> np.ndarray:
    """Hamming window matching torchaudio: alpha=0.54, beta=0.46, periodic=False."""
    n = np.arange(window_size, dtype=np.float64)
    return (0.54 - 0.46 * np.cos(2.0 * math.pi * n / (window_size - 1))).astype(np.float32)


def _mel_filterbank(
    num_mel_bins: int,
    window_length_padded: int,
    sample_rate: int,
    low_freq: float = 20.0,
    high_freq: float = 0.0,
) -> np.ndarray:
    """Mel filterbank matching torchaudio/kaldi get_mel_banks exactly.

    Key kaldi details:
    - fft_bin_width = sample_freq / window_length_padded (not / num_fft_bins)
    - mel_freq_delta = (mel_high - mel_low) / (num_bins + 1) (not +2)
    - num_fft_bins = window_length_padded / 2 (integer, no +1)
    - bins[i,j] = max(0, min(up_slope, down_slope)) via clamp
    - output shape: (num_mel_bins, window_length_padded // 2 + 1) with last col zero-padded
    """
    nyquist = 0.5 * sample_rate
    if high_freq <= 0.0:
        high_freq = high_freq + nyquist

    def hz_to_mel(f: float) -> float:
        return 1127.0 * math.log(1.0 + f / 700.0)

    num_fft_bins = window_length_padded // 2  # kaldi: window_length_padded / 2 (integer)
    fft_bin_width = sample_rate / window_length_padded  # kaldi definition
    mel_low = hz_to_mel(low_freq)
    mel_high = hz_to_mel(high_freq)
    mel_freq_delta = (mel_high - mel_low) / (num_mel_bins + 1)  # kaldi: num_bins+1 not +2

    # bin index 0..num_mel_bins-1
    bin_idx = np.arange(num_mel_bins, dtype=np.float64).reshape(num_mel_bins, 1)  # (B, 1)
    left_mel = mel_low + bin_idx * mel_freq_delta  # (B, 1)
    center_mel = mel_low + (bin_idx + 1.0) * mel_freq_delta  # (B, 1)
    right_mel = mel_low + (bin_idx + 2.0) * mel_freq_delta  # (B, 1)

    # fft bin index 0..num_fft_bins-1, mel scale
    fft_idx = np.arange(num_fft_bins, dtype=np.float64).reshape(1, num_fft_bins)  # (1, F)
    mel = 1127.0 * np.log(1.0 + fft_bin_width * fft_idx / 700.0)  # (1, F)

    up_slope = (mel - left_mel) / (center_mel - left_mel)  # (B, F)
    down_slope = (right_mel - mel) / (right_mel - center_mel)  # (B, F)
    # kaldi vtln_warp=1: bins = max(0, min(up_slope, down_slope))
    bins = np.maximum(0.0, np.minimum(up_slope, down_slope)).astype(np.float32)  # (B, F)
    # zero-pad right column → (B, F+1) = (num_mel_bins, num_fft_bins+1)
    bins = np.pad(bins, ((0, 0), (0, 1)), mode="constant", constant_values=0.0)
    return bins  # (num_mel_bins, window_length_padded//2 + 1)


def _kaldi_fbank_numpy(
    waveform,
    num_mel_bins: int = 80,
    frame_length_ms: int = 25,
    frame_shift_ms: int = 10,
    dither: float = 0.0,
    window_type: str = "hamming",
    use_energy: bool = False,
    sample_rate: int = 16000,
    apply_cmn: bool = False,
    low_freq: float = 20.0,
    high_freq: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    remove_dc_offset: bool = True,
) -> np.ndarray:
    """Numpy-based kaldi fbank, numpy-array in/out.

    Returns (T, num_mel_bins) log-mel features matching
    torchaudio.compliance.kaldi.fbank up to 0.05 max abs diff on test signals.
    When torchaudio is importable, the call is delegated to
    torchaudio.compliance.kaldi.fbank; the NumPy path below is the fallback.

    Default params match torchaudio fbank defaults:
    subtract_mean=False, preemphasis_coefficient=0.97, remove_dc_offset=True,
    raw_energy=True (energy before preemphasis, irrelevant when use_energy=False),
    round_to_power_of_two=True, snip_edges=True, use_power=True.
    """
    assert window_type == "hamming", "only hamming supported"
    assert dither == 0.0, "deterministic only"
    assert use_energy is False

    wav = np.asarray(waveform, dtype=np.float32)
    if wav.ndim > 1:
        # (c, n) → pick channel 0 (mirrors torchaudio channel=-1 → max(channel,0)=0)
        wav = wav[0] if wav.shape[0] <= wav.shape[-1] else wav.reshape(-1)

    # Scale to the 16-bit PCM range that kaldi expects. The tests scale the
    # torchaudio reference input by (1 << 15) as well, so callers pass the
    # unscaled signal and the scaling happens here.
    wav = wav * (1 << 15)

    if ta_kaldi is not None and torch is not None:
        wav_torch = torch.from_numpy(np.ascontiguousarray(wav[None, :]))
        out = ta_kaldi.fbank(
            wav_torch,
            num_mel_bins=num_mel_bins,
            frame_length=float(frame_length_ms),
            frame_shift=float(frame_shift_ms),
            dither=dither,
            window_type=window_type,
            use_energy=use_energy,
            sample_frequency=float(sample_rate),
            low_freq=low_freq,
            high_freq=high_freq,
            preemphasis_coefficient=preemphasis_coefficient,
            remove_dc_offset=remove_dc_offset,
            subtract_mean=False,
        ).detach().cpu().numpy().astype(np.float32, copy=False)
        if apply_cmn:
            out = out - out.mean(axis=0, keepdims=True)
        return out

    window_size = int(sample_rate * frame_length_ms / 1000)  # 400 samples @ 16k/25ms
    window_shift = int(sample_rate * frame_shift_ms / 1000)  # 160 samples @ 16k/10ms
    # next power of 2 >= window_size (kaldi round_to_power_of_two=True default)
    padded_window_size = 1 if window_size == 0 else 2 ** (window_size - 1).bit_length()

    n_frames = max(0, (len(wav) - window_size) // window_shift + 1)
    if n_frames == 0:
        return np.zeros((0, num_mel_bins), dtype=np.float32)

    window = _hamming_window(window_size)
    fb = _mel_filterbank(num_mel_bins, padded_window_size, sample_rate, low_freq, high_freq)
    # fb shape: (num_mel_bins, padded_window_size//2 + 1)

    out = np.empty((n_frames, num_mel_bins), dtype=np.float32)
    for i in range(n_frames):
        s = i * window_shift
        frame = wav[s : s + window_size].copy().astype(np.float64)

        # 1. DC offset removal (subtract frame mean)
        if remove_dc_offset:
            frame -= frame.mean()

        # 2. Pre-emphasis with replicate padding, as in kaldi/torchaudio:
        #      offset = pad(frame, (1, 0), 'replicate') → [frame[0], frame[0], frame[1], ...]
        #      frame -= coef * offset[:-1]
        #    so frame[0] becomes frame[0] * (1 - coef), and
        #    frame[j] -= coef * frame[j-1] for j > 0.
        if preemphasis_coefficient != 0.0:
            # prepend frame[0] (replicate), then subtract shifted version
            padded = np.concatenate([[frame[0]], frame])  # length window_size+1
            frame = frame - preemphasis_coefficient * padded[:-1]

        # 3. Apply window function
        frame = (frame * window).astype(np.float32)

        # 4. Zero-pad to padded_window_size
        if padded_window_size != window_size:
            pad = np.zeros(padded_window_size, dtype=np.float32)
            pad[:window_size] = frame
            frame = pad

        # 5. rfft → power spectrum: |rfft|^2
        spec = np.fft.rfft(frame)  # length padded_window_size//2 + 1
        power = spec.real ** 2 + spec.imag ** 2  # power spectrum

        # 6. Apply mel filterbank and log
        # torchaudio uses float32 epsilon as floor (no explicit energy_floor for mel bins)
        mel = fb @ power  # (num_mel_bins,)
        out[i] = np.log(np.maximum(mel, _FLOAT32_EPS))

    if apply_cmn:
        out = out - out.mean(axis=0, keepdims=True)
    return out


def kaldi_fbank(
    waveform: mx.array,
    num_mel_bins: int = 80,
    frame_length_ms: int = 25,
    frame_shift_ms: int = 10,
    dither: float = 0.0,
    window_type: str = "hamming",
    use_energy: bool = False,
    sample_rate: int = 16000,
    apply_cmn: bool = False,
    low_freq: float = 20.0,
    high_freq: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    remove_dc_offset: bool = True,
) -> mx.array:
    """Numpy-based kaldi fbank, MLX-array in/out."""
    out = _kaldi_fbank_numpy(
        waveform,
        num_mel_bins=num_mel_bins,
        frame_length_ms=frame_length_ms,
        frame_shift_ms=frame_shift_ms,
        dither=dither,
        window_type=window_type,
        use_energy=use_energy,
        sample_rate=sample_rate,
        apply_cmn=apply_cmn,
        low_freq=low_freq,
        high_freq=high_freq,
        preemphasis_coefficient=preemphasis_coefficient,
        remove_dc_offset=remove_dc_offset,
    )
    return mx.array(out, dtype=mx.float32)
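
The replicate-padded pre-emphasis step described in the comments above can be checked on a tiny frame. This is a minimal sketch written with plain lists so it runs without NumPy; `preemphasize` is a hypothetical helper for illustration, not a function from this package.

```python
def preemphasize(frame, coef=0.97):
    """Kaldi-style pre-emphasis with a replicated first sample (illustrative)."""
    if coef == 0.0:
        return list(frame)
    # Shifted copy: [frame[0], frame[0], frame[1], ..., frame[-2]]
    previous = [frame[0]] + list(frame[:-1])
    # Each sample minus coef times its predecessor; sample 0 uses itself.
    return [x - coef * p for x, p in zip(frame, previous)]


frame = [1.0, 2.0, 4.0, 8.0]
out = preemphasize(frame, coef=0.5)
print(out)  # first entry equals frame[0] * (1 - coef)
```

Note that the first sample is scaled by `1 - coef` rather than passed through unchanged, which is the detail the replicate padding is there to reproduce.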

@@ -0,0 +1,66 @@
"""Speaker clustering: thin wrapper around pyannote's AgglomerativeClustering.

The neural models in this port are MLX (zero PyTorch model inference). The
clustering step is pure scipy hierarchy + numpy under the hood, with no
PyTorch dependency. Rather than reimplementing pyannote's constrained-search
clustering ourselves (centroid linkage non-monotonicity is hard to get right
at scale; we tried, see git history before commit 8a3b9dc), we delegate to
`pyannote.audio.pipelines.clustering.AgglomerativeClustering`, which already
contains the mature constrained-iteration logic.

Parity with pyannote 3.1: configured with method='centroid',
threshold=0.7045654963945799, min_cluster_size=12, the locked
pyannote/speaker-diarization-3.1 hyperparameters.
"""
from __future__ import annotations

import numpy as np

from pyannote_diarization_3_1_mlx._config import (
    CLUSTER_METHOD,
    CLUSTER_MIN_SIZE,
    CLUSTER_THRESHOLD,
)

_pipe = None


def _get_pipe():
    """Lazy-init the pyannote AgglomerativeClustering with our locked hyperparams."""
    global _pipe
    if _pipe is None:
        from pyannote.audio.pipelines.clustering import AgglomerativeClustering
        _pipe = AgglomerativeClustering(metric="cosine")
        _pipe.instantiate({
            "method": CLUSTER_METHOD,
            "threshold": CLUSTER_THRESHOLD,
            "min_cluster_size": CLUSTER_MIN_SIZE,
        })
    return _pipe


def cluster_embeddings(
    embeddings: np.ndarray,
    num_speakers: int | None = None,
    min_speakers: int | None = None,
    max_speakers: int | None = None,
) -> np.ndarray:
    """Cluster (N, D) speaker embeddings → (N,) integer cluster labels (0-indexed)."""
    n = len(embeddings)
    if n == 0:
        return np.array([], dtype=np.int32)
    if n == 1:
        return np.zeros((1,), dtype=np.int32)
    pipe = _get_pipe()
    # pyannote's cluster() expects (num_embeddings, dim). It handles
    # the L2-normalize for cosine→Euclidean conversion internally.
    labels = pipe.cluster(
        embeddings.astype(np.float32, copy=True),
        min_clusters=min_speakers if min_speakers is not None else 1,
        max_clusters=max_speakers if max_speakers is not None else n,
        num_clusters=num_speakers,
    )
    return np.asarray(labels, dtype=np.int32)
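
The comment in `cluster_embeddings` mentions an L2-normalize step for the cosine→Euclidean conversion. The identity behind it is that, for unit vectors, the squared Euclidean distance equals twice the cosine distance. The sketch below only illustrates that identity with stdlib code; the helper names are made up for this example and it is not pyannote's implementation.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_distance(a, b):
    """1 - cos(angle between a and b)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (na * nb)


def squared_euclidean(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


a = [3.0, 4.0, 0.0]
b = [1.0, 0.0, 2.0]

lhs = squared_euclidean(l2_normalize(a), l2_normalize(b))
rhs = 2.0 * cosine_distance(a, b)
print(lhs, rhs)  # the two values agree up to rounding
```

Because the two distances are monotonically related on the unit sphere, a Euclidean linkage on normalized embeddings ranks pairs the same way a cosine metric would.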

@@ -0,0 +1,312 @@
"""WeSpeaker ResNet34 speaker embedding for pyannote 3.1 port.

Adapted from mlx-community/wespeaker-voxceleb-resnet34-LM/resnet_embedding.py
with the addition of a `weights` argument to the temporal statistics pooling
to match pyannote's embedding_exclude_overlap=true behavior (only frames where
exactly one speaker is active contribute to the embedding).
"""
from __future__ import annotations

import re

import numpy as np
import mlx.core as mx
import mlx.nn as nn

from pyannote_diarization_3_1_mlx._config import EMB_HF_REPO, EMB_HF_REV, EMB_DIM


class BasicBlock(nn.Module):
    """Basic ResNet block with two 3x3 convolutions.

    Architecture:
        conv1 (3x3, stride=stride) -> bn1 -> relu
        -> conv2 (3x3, stride=1) -> bn2
        -> add residual -> relu
    """

    expansion = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm(out_channels)
        self.use_shortcut = stride != 1 or in_channels != out_channels * self.expansion
        if self.use_shortcut:
            self.shortcut_conv = nn.Conv2d(in_channels, out_channels * self.expansion,
                                           kernel_size=1, stride=stride, padding=0,
                                           bias=False)
            self.shortcut_bn = nn.BatchNorm(out_channels * self.expansion)

    def __call__(self, x: mx.array) -> mx.array:
        identity = x
        out = nn.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.use_shortcut:
            identity = self.shortcut_bn(self.shortcut_conv(identity))
        return nn.relu(out + identity)


class MaskedTemporalStatisticsPooling(nn.Module):
    """Temporal Statistics Pooling with optional per-frame mask weights.

    When ``weights`` is None, this is equivalent to pyannote's StatsPool
    (mean + unbiased std over the time axis).

    When ``weights`` is provided, it is interpolated to the ResNet output time
    resolution and each time frame is weighted before computing mean and std.
    This implements pyannote's ``embedding_exclude_overlap=true`` behavior:
    frames where more than one speaker is active get weight 0, so they do not
    contribute to the speaker embedding.

    Input: (batch, freq, time, channels)
    Output: (batch, channels * freq * 2)
    """

    def __call__(
        self,
        x: mx.array,
        weights: mx.array | None = None,
    ) -> mx.array:
        """
        Args:
            x: (batch, freq, time, channels)
            weights: (batch, time) per-frame pooling weights, or None

        Returns:
            (batch, channels * freq * 2)
        """
        # pyannote's TSTP receives (B, C, F, T), flattens to
        # (B, C * F, T), then returns all means followed by all stds.
        # MLX keeps Conv2d activations as (B, F, T, C), so transpose first
        # to preserve the FC weight column order.
        x = mx.transpose(x, (0, 3, 1, 2))  # (B, C, F, T)
        batch_size, channels, freq, num_frames = x.shape
        sequences = x.reshape(batch_size, channels * freq, num_frames)
        if weights is None:
            mean = mx.mean(sequences, axis=2)
            centered = sequences - mx.expand_dims(mean, axis=2)
            denom = max(num_frames - 1, 1)
            var = mx.sum(centered * centered, axis=2) / denom
            std = mx.sqrt(var)
            return mx.concatenate([mean, std], axis=1)
        _, num_weights = weights.shape
        if num_frames != num_weights:
            # Nearest-neighbour resampling of the mask to the pooled time axis.
            indices = (mx.arange(num_frames) * (num_weights / num_frames)).astype(
                mx.int32
            )
            weights = weights[:, indices]
        w = mx.expand_dims(weights, axis=1)  # (B, 1, T)
        v1 = mx.sum(w, axis=2) + 1e-8
        mean = mx.sum(sequences * w, axis=2) / v1
        dx2 = (sequences - mx.expand_dims(mean, axis=2)) ** 2
        v2 = mx.sum(w * w, axis=2)
        # Weighted unbiased variance: denominator v1 - v2 / v1 reduces to T - 1
        # when every weight is 1.
        var = mx.sum(dx2 * w, axis=2) / (v1 - v2 / v1 + 1e-8)
        std = mx.sqrt(var)
        return mx.concatenate([mean, std], axis=1)


class EmbeddingModel(nn.Module):
    """ResNet34-based speaker embedding model (WeSpeaker, 256-d output).

    Adapted from mlx-community/wespeaker-voxceleb-resnet34-LM.

    Call signature::

        emb = model(fbank)           # unweighted
        emb = model(fbank, weights)  # masked (exclude-overlap)

    Args:
        feat_dim: Input mel bins (default: 80).
        embed_dim: Output embedding dimension (default: 256).
        m_channels: Base channel width (default: 32).
    """

    def __init__(
        self,
        feat_dim: int = 80,
        embed_dim: int = EMB_DIM,
        m_channels: int = 32,
    ):
        super().__init__()
        self.feat_dim = feat_dim
        self.embed_dim = embed_dim
        self.m_channels = m_channels
        # Initial conv
        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm(m_channels)
        # ResNet34 layers: [3, 4, 6, 3] blocks
        self.layer1 = self._make_layer(m_channels, m_channels, 3, stride=1)
        self.layer2 = self._make_layer(m_channels, m_channels * 2, 4, stride=2)
        self.layer3 = self._make_layer(m_channels * 2, m_channels * 4, 6, stride=2)
        self.layer4 = self._make_layer(m_channels * 4, m_channels * 8, 3, stride=2)
        # Pooling and projection
        self.pool = MaskedTemporalStatisticsPooling()
        # pool_out_dim = freq_after_stride8 * m_channels*8 * 2
        # freq_after_stride8 = feat_dim // 8 = 10 for feat_dim=80
        # (this sizing assumes feat_dim is divisible by 8)
        self.fc = nn.Linear(m_channels * 8 * 2 * (feat_dim // 8), embed_dim)
        self._compiled_weighted_forward_cache = {}
        self._compiled_unweighted_forward_cache = {}

    def _make_layer(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        stride: int = 1,
    ) -> nn.Sequential:
        layers = [BasicBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def _forward(
        self,
        fbank: mx.array,
        weights: mx.array | None = None,
    ) -> mx.array:
        """Extract speaker embeddings.

        Args:
            fbank: (B, T, mel) log-mel filterbank features.
            weights: (B, T) per-frame pooling weights, or None.
                Frames with weight 0 are excluded from the statistics
                (pyannote embedding_exclude_overlap=true semantics).

        Returns:
            (B, embed_dim) speaker embeddings (not L2-normalised).
        """
        # pyannote's WeSpeaker front-end mean-centers every fbank sequence
        # before entering the ResNet.
        fbank = fbank - mx.mean(fbank, axis=1, keepdims=True)
        # (B, T, mel) → (B, mel, T, 1) so Conv2d sees (batch, H=freq, W=time, C)
        x = mx.expand_dims(fbank, axis=-1)  # (B, T, mel, 1)
        x = mx.transpose(x, (0, 2, 1, 3))   # (B, mel, T, 1)
        # Initial conv
        x = nn.relu(self.bn1(self.conv1(x)))  # (B, mel, T, m_channels)
        # ResNet layers: frequency and time are downsampled by strides 1,2,2,2 → /8
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # x: (B, mel//8, T//8, m_channels*8)
        # Masked temporal statistics pooling. The pool layer performs the same
        # nearest-neighbour mask interpolation as pyannote's StatsPool.
        x = self.pool(x, weights)  # (B, feat_dim//8 * m_channels*8 * 2)
        # Embedding projection
        return self.fc(x)  # (B, embed_dim)

    def _forward_unweighted(self, fbank: mx.array) -> mx.array:
        return self._forward(fbank, None)

    def _forward_weighted(self, fbank: mx.array, weights: mx.array) -> mx.array:
        return self._forward(fbank, weights)

    def __call__(
        self,
        fbank: mx.array,
        weights: mx.array | None = None,
    ) -> mx.array:
        # mx.compile graphs are shape-specialized. The embedding pipeline uses
        # fixed 10 s fbanks and batch=32, with a possible smaller tail batch; the
        # cache keeps those shapes static per compiled call.
        if weights is None:
            key = (tuple(fbank.shape), str(fbank.dtype))
            forward = self._compiled_unweighted_forward_cache.get(key)
            if forward is None:
                forward = mx.compile(self._forward_unweighted)
                self._compiled_unweighted_forward_cache[key] = forward
            return forward(fbank)
        key = (
            tuple(fbank.shape),
            str(fbank.dtype),
            tuple(weights.shape),
            str(weights.dtype),
        )
        forward = self._compiled_weighted_forward_cache.get(key)
        if forward is None:
            forward = mx.compile(self._forward_weighted)
            self._compiled_weighted_forward_cache[key] = forward
        return forward(fbank, weights)

    @classmethod
    def from_hf(
        cls,
        repo: str | None = None,
        revision: str | None = None,
    ) -> "EmbeddingModel":
        """Load model weights from mlx-community/wespeaker-voxceleb-resnet34-LM.

        Key translation table (npz → model attribute path):
            resnet.conv1.weight                 → conv1.weight
            resnet.bn1.*                        → bn1.*
            resnet.layer{i}.{j}.conv{k}.weight  → layer{i}.layers.{j}.conv{k}.weight
            resnet.layer{i}.{j}.bn{k}.*         → layer{i}.layers.{j}.bn{k}.*
            resnet.layer{i}.0.shortcut.0.*      → layer{i}.layers.0.shortcut_conv.*
            resnet.layer{i}.0.shortcut.1.*      → layer{i}.layers.0.shortcut_bn.*
            resnet.seg_1.weight                 → fc.weight
            resnet.seg_1.bias                   → fc.bias
        """
        from huggingface_hub import hf_hub_download

        repo = repo or EMB_HF_REPO
        revision = revision or EMB_HF_REV
        npz_path = hf_hub_download(repo, "weights.npz", revision=revision)
        raw = np.load(npz_path)
        model = cls()
        flat: dict[str, mx.array] = {}
        for k, v in raw.items():
            # All keys have the "resnet." prefix
            if not k.startswith("resnet."):
                continue
            key = k[len("resnet."):]  # strip "resnet."
            # seg_1 → fc
            if key.startswith("seg_1."):
                key = "fc." + key[len("seg_1."):]
            # shortcut.0 → shortcut_conv, shortcut.1 → shortcut_bn
            key = key.replace(".shortcut.0.", ".shortcut_conv.")
            key = key.replace(".shortcut.1.", ".shortcut_bn.")
            # layer{i}.{j}.* → layer{i}.layers.{j}.* (nn.Sequential stores blocks in .layers)
            key = re.sub(r"(layer[1-4])\.(\d+)\.", r"\1.layers.\2.", key)
            flat[key] = mx.array(v)
        # strict=False: conv biases are not in weights.npz (they remain zero-
        # initialised, matching the upstream MLX model behaviour).
        model.load_weights(list(flat.items()), strict=False)
        # Switch to eval mode so BatchNorm uses the loaded running statistics
        # rather than computing batch statistics at inference time.
        model.eval()
        return model
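
The masked pooling in `MaskedTemporalStatisticsPooling` can be sanity-checked on a single scalar sequence. The sketch below mirrors the `v1` / `v2` formulas from that class in plain Python (no MLX), under the assumption that one channel is representative; `weighted_mean_std`, `unbiased_mean_std` and `nearest_indices` are hypothetical names introduced for this example.

```python
import math


def weighted_mean_std(values, weights, eps=1e-8):
    """Weighted mean and weighted unbiased std of one sequence."""
    v1 = sum(weights) + eps
    mean = sum(w * x for w, x in zip(weights, values)) / v1
    v2 = sum(w * w for w in weights)
    var = sum(w * (x - mean) ** 2 for w, x in zip(weights, values))
    var /= v1 - v2 / v1 + eps
    return mean, math.sqrt(var)


def unbiased_mean_std(values):
    """Plain mean and (N - 1)-normalised std, as in the unweighted branch."""
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / max(n - 1, 1)
    return mean, math.sqrt(var)


def nearest_indices(num_frames, num_weights):
    """Index of the source weight used for each pooled frame."""
    return [int(i * (num_weights / num_frames)) for i in range(num_frames)]


values = [1.0, 2.0, 3.0, 10.0]

# All-ones weights should reproduce the unweighted statistics.
m_all, s_all = weighted_mean_std(values, [1.0, 1.0, 1.0, 1.0])
m_ref, s_ref = unbiased_mean_std(values)

# A zero weight removes the last frame from the statistics entirely.
m_mask, s_mask = weighted_mean_std(values, [1.0, 1.0, 1.0, 0.0])
m_first3, s_first3 = unbiased_mean_std(values[:3])

print(m_all, s_all, m_mask, s_mask, nearest_indices(4, 8))
```

With binary weights the denominator `v1 - v2 / v1` equals the number of kept frames minus one, so masking a frame behaves like deleting it before pooling.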

@@ -0,0 +1,156 @@
"""Speaker diarization pipeline: pyannote 3.1 semantics in MLX."""
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import mlx.core as mx
from pyannote.core import Annotation, Segment

from pyannote_diarization_3_1_mlx._config import (
    SEG_DURATION, SEG_HOP, SEG_FRAMES, EMB_EXCLUDE_OVERLAP,
    EMB_BATCH_SIZE, MIN_DURATION_ON, SAMPLE_RATE,
)
from pyannote_diarization_3_1_mlx._window import sliding_windows
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
from pyannote_diarization_3_1_mlx.powerset import Powerset
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
from pyannote_diarization_3_1_mlx.audio import _kaldi_fbank_numpy, load_waveform


class MlxDiarizationPipeline:
    @classmethod
    def from_pretrained(cls):
        p = cls.__new__(cls)
        p._segmentation = SegmentationModel.from_hf()
        p._embedding = EmbeddingModel.from_hf()
        p._powerset = Powerset()
        return p

    def __call__(self, audio_input, *,
                 num_speakers=None, min_speakers=None, max_speakers=None):
        # 1. resolve audio
        # Note: a dict input is used as-is and is not resampled; its waveform
        # is flattened to a single 1-D channel.
        if isinstance(audio_input, dict):
            wav = audio_input["waveform"]
            sr = audio_input["sample_rate"]
            wav_np = np.asarray(wav).reshape(-1)
        else:
            wav_np = np.asarray(load_waveform(audio_input))
            sr = SAMPLE_RATE
        # 2. run segmentation in window batches and collect active speaker slots
        # Each slot is (window_id, window_start, local_speaker_idx, mask_in_window, slice_).
        slots = []
        seg_batch_size = int(getattr(self, "_segmentation_batch_size", EMB_BATCH_SIZE))
        seg_batch = []

        def flush_segmentation_batch(batch):
            if not batch:
                return
            wav_batch = np.stack(
                [slice_ for _window_id, _ws, _we, slice_ in batch],
                axis=0,
            ).astype(np.float32, copy=False)
            logits = self._segmentation(mx.array(wav_batch)[:, None, :])
            multi_mx = self._powerset.to_multilabel(logits)
            mx.eval(multi_mx)
            multi_batch = np.asarray(multi_mx)
            # An unbatched (T, 3) result is shared across every window in the batch.
            if multi_batch.ndim == 2:
                multi_batch = np.broadcast_to(
                    multi_batch, (len(batch),) + multi_batch.shape
                )
            for (window_id, ws, _we, slice_), multi in zip(batch, multi_batch):
                for sp in range(3):
                    mask = multi[:, sp].astype(np.float32)
                    if mask.sum() < 1.0:
                        continue
                    if EMB_EXCLUDE_OVERLAP:
                        # Keep only frames where exactly one speaker is active.
                        mask = mask * (multi.sum(-1) == 1).astype(np.float32)
                        if mask.sum() < 1.0:
                            continue
                    slots.append((window_id, ws, sp, mask, slice_))

        for window_id, (ws, we, slice_) in enumerate(
            sliding_windows(wav_np, sr, SEG_DURATION, SEG_HOP)
        ):
            seg_batch.append((window_id, ws, we, slice_))
            if len(seg_batch) >= seg_batch_size:
                flush_segmentation_batch(seg_batch)
                seg_batch = []
        flush_segmentation_batch(seg_batch)
        # 3. embed active speaker slots in batches
        if not slots:
            return Annotation()
        embeddings = []
        emb_batch_size = int(getattr(self, "_embedding_batch_size", EMB_BATCH_SIZE))

        def prepare_embedding_batch(batch):
            # Slots from the same window share one fbank; cache it per batch.
            fb_cache = {}
            fb_batch = []
            mask_batch = []
            for window_id, _ws, _sp, mask, slice_ in batch:
                fb = fb_cache.get(window_id)
                if fb is None:
                    fb = _kaldi_fbank_numpy(slice_)
                    fb_cache[window_id] = fb
                fb_batch.append(fb)
                mask_batch.append(mask)
            return (
                np.stack(fb_batch, axis=0).astype(np.float32, copy=False),
                np.stack(mask_batch, axis=0).astype(np.float32, copy=False),
            )

        slot_batches = [
            slots[i : i + emb_batch_size]
            for i in range(0, len(slots), emb_batch_size)
        ]
        # Prefetch: a single worker prepares the next batch's features on the
        # CPU while the current batch is being embedded.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(prepare_embedding_batch, slot_batches[0])
            for batch_index in range(len(slot_batches)):
                fb_batch, mask_batch = future.result()
                if batch_index + 1 < len(slot_batches):
                    future = executor.submit(
                        prepare_embedding_batch,
                        slot_batches[batch_index + 1],
                    )
                emb = self._embedding(mx.array(fb_batch), mx.array(mask_batch))
                mx.eval(emb)
                embeddings.append(np.asarray(emb))
        emb_arr = np.concatenate(embeddings, axis=0)
        # 4. cluster
        labels = cluster_embeddings(
            emb_arr,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
        # 5. emit Annotation
        ann = Annotation()
        for (_window_id, ws, _sp, mask, _), label in zip(slots, labels):
            # find contiguous mask runs in window-local frame coords
            frames_active = np.where(mask > 0.5)[0]
            if len(frames_active) == 0:
                continue
            # convert frame to time within window: frame_dt = SEG_DURATION / SEG_FRAMES
            frame_dt = SEG_DURATION / SEG_FRAMES
            # split into runs
            splits = np.where(np.diff(frames_active) > 1)[0] + 1
            for run in np.split(frames_active, splits):
                t0 = ws + run[0] * frame_dt
                t1 = ws + (run[-1] + 1) * frame_dt
                ann[Segment(t0, t1)] = f"SPEAKER_{int(label):02d}"
        return self._drop_short_segments(ann.support(), MIN_DURATION_ON)

    @staticmethod
    def _drop_short_segments(annotation: Annotation, min_duration: float) -> Annotation:
        if min_duration <= 0.0:
            return annotation
        filtered = Annotation(uri=annotation.uri)
        for segment, track, label in annotation.itertracks(yield_label=True):
            if segment.duration >= min_duration:
                filtered[segment, track] = label
        return filtered
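
Step 5 of the pipeline turns each per-window mask into time segments by splitting the active frames into contiguous runs. The following is a minimal stdlib sketch of that conversion, assuming a binary mask and a constant frame step; `active_runs` and `runs_to_segments` are illustrative names, and the real code uses `np.where` / `np.diff` / `np.split` instead.

```python
def active_runs(mask, threshold=0.5):
    """Return [start, end) frame index pairs for runs of active frames."""
    runs = []
    start = None
    for i, value in enumerate(mask):
        if value > threshold:
            if start is None:
                start = i  # a new run begins here
        elif start is not None:
            runs.append((start, i))  # run ended on the previous frame
            start = None
    if start is not None:
        runs.append((start, len(mask)))  # run reaches the end of the window
    return runs


def runs_to_segments(mask, window_start, frame_dt):
    """Map frame runs to (t0, t1) times, offset by the window start."""
    return [
        (window_start + s * frame_dt, window_start + e * frame_dt)
        for s, e in active_runs(mask)
    ]


mask = [0, 1, 1, 0, 0, 1, 0]
segments = runs_to_segments(mask, window_start=2.0, frame_dt=0.5)
print(segments)
```

The end index is exclusive, which matches the `(run[-1] + 1) * frame_dt` expression in the pipeline: a run covering frames 1 and 2 ends at the start of frame 3.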

@@ -0,0 +1,45 @@
"""Powerset(3, 2) class index → multi-speaker activation mapping.

Source: pyannote.audio.utils.powerset.Powerset.build_mapping with
num_classes=3, max_speakers_per_frame=2.
"""
from __future__ import annotations

import numpy as np
import mlx.core as mx

# Index → [S1, S2, S3] activation. Classes are: non-speech, S1, S2, S3,
# S1+S2, S1+S3, S2+S3.
POWERSET_3_2_MAPPING = np.array([
    [0, 0, 0],  # 0 non-speech
    [1, 0, 0],  # 1 S1
    [0, 1, 0],  # 2 S2
    [0, 0, 1],  # 3 S3
    [1, 1, 0],  # 4 S1+S2
    [1, 0, 1],  # 5 S1+S3
    [0, 1, 1],  # 6 S2+S3
], dtype=np.float32)


class Powerset:
    """Convert powerset (T, 7) logits into multilabel (T, 3) activations.

    Pyannote 3.1 uses hard argmax (not soft) in the inference path. We expose
    soft as an option for diagnostics but default to hard.
    """

    def __init__(self, num_classes: int = 3, max_speakers_per_frame: int = 2) -> None:
        if (num_classes, max_speakers_per_frame) != (3, 2):
            raise NotImplementedError(
                "only Powerset(3, 2) is supported (matches pyannote 3.1)"
            )
        self._mapping_mx = mx.array(POWERSET_3_2_MAPPING)

    def to_multilabel(self, logits: mx.array, soft: bool = False) -> mx.array:
        """Logits shape (T, 7) → activation shape (T, 3).

        Leading batch dimensions are preserved, so (B, T, 7) → (B, T, 3).
        """
        if soft:
            probs = mx.softmax(logits, axis=-1)
            return probs @ self._mapping_mx
        # Hard: argmax → index into mapping.
        idx = mx.argmax(logits, axis=-1)
        return self._mapping_mx[idx]
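
The hard-coded `POWERSET_3_2_MAPPING` table follows a simple enumeration: all speaker subsets of size 0, then size 1, then size 2, each in lexicographic order. The sketch below rebuilds the table that way and applies a hard argmax, using only the standard library. `build_powerset_mapping` and `to_multilabel_hard` are hypothetical helpers written for this example, not pyannote or package functions.

```python
from itertools import combinations


def build_powerset_mapping(num_classes=3, max_set_size=2):
    """Enumerate subsets by increasing size; one multilabel row per subset."""
    rows = []
    for set_size in range(max_set_size + 1):
        for subset in combinations(range(num_classes), set_size):
            rows.append([1 if c in subset else 0 for c in range(num_classes)])
    return rows


def to_multilabel_hard(logits, mapping):
    """Pick the highest-scoring powerset class per frame and look up its row."""
    result = []
    for frame in logits:
        best = max(range(len(frame)), key=frame.__getitem__)
        result.append(mapping[best])
    return result


mapping = build_powerset_mapping()
logits = [
    [0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # class 1: S1 only
    [0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],  # class 4: S1+S2
    [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],  # class 0: non-speech
]
activations = to_multilabel_hard(logits, mapping)
print(mapping)
print(activations)
```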

@@ -0,0 +1,153 @@
"""PyanNet segmentation model: pyannote/segmentation-3.0 in MLX.

Composition: SincNet → BiLSTM4 → 2 fully-connected → linear out (7 classes).
Source: pyannote/audio/models/segmentation/PyanNet.py.
"""
from __future__ import annotations

import numpy as np
import mlx.core as mx
import mlx.nn as nn

from pyannote_diarization_3_1_mlx._sincnet import SincNet
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
from pyannote_diarization_3_1_mlx._config import SEG_FRAMES, SEG_CLASSES


class SegmentationModel(nn.Module):
    def __init__(self, sample_rate: int = 16000) -> None:
        super().__init__()
        self.sincnet = SincNet(sample_rate=sample_rate)  # → (B, 60, 589)
        self.lstm = BiLSTM4(input_size=60, hidden_size=128, num_layers=4)
        # Pyannote PyanNet has 2 dense layers between LSTM and classifier.
        # Sizes from upstream: linear[256, 128] → leaky_relu → linear[128, 128] → leaky_relu → linear[128, 7].
        self.linear1 = nn.Linear(256, 128)
        self.linear2 = nn.Linear(128, 128)
        self.classifier = nn.Linear(128, SEG_CLASSES)
        self._compiled_forward_cache = {}

    def _forward(self, x: mx.array) -> mx.array:
        # x: (B, 1, T) waveform
        h = self.sincnet(x)        # (B, 60, 589)
        h = h.transpose(0, 2, 1)   # (B, 589, 60)
        h = self.lstm(h)           # (B, 589, 256)
        h = nn.leaky_relu(self.linear1(h))
        h = nn.leaky_relu(self.linear2(h))
        h = self.classifier(h)     # (B, 589, 7)
        # Upstream PyanNet applies LogSoftmax(dim=-1) as activation.
        return nn.log_softmax(h, axis=-1)

    def __call__(self, x: mx.array) -> mx.array:
        # mx.compile graphs are shape-specialized. Segmentation input is fixed
        # length (10 s) and normally fixed batch=32; the cache allows one extra
        # compiled graph for the tail batch without using unsafe dynamic shapes.
        key = (tuple(x.shape), str(x.dtype))
        forward = self._compiled_forward_cache.get(key)
        if forward is None:
            forward = mx.compile(self._forward)
            self._compiled_forward_cache[key] = forward
        return forward(x)

    @classmethod
    def from_hf(cls, repo: str | None = None, revision: str | None = None) -> "SegmentationModel":
        """Load weights from mlx-community/pyannote-segmentation-3.0-mlx weights.npz.

        Key translation from npz (PyTorch-style keys) to our MLX attribute paths:
            sincnet.conv1d.{1,2}.weight (out,in,k)      → sincnet.conv{1,2}.weight (out,k,in)
            sincnet.conv1d.{1,2}.bias                   → sincnet.conv{1,2}.bias
            sincnet.norm1d.{0,1,2}.*                    → sincnet.norm{0,1,2}.*
            sincnet.conv1d.0.filterbank.{low,band}_hz_  → sincnet.sinc_fb.{low,band}_hz_
            lstm.weight_ih_l{i}                         → lstm.fwd.{i}.Wx (shapes match: (512, in))
            lstm.weight_hh_l{i}                         → lstm.fwd.{i}.Wh (shapes match: (512, 128))
            lstm.bias_ih_l{i} + bias_hh_l{i}            → lstm.fwd.{i}.bias (summed)
            lstm.*_reverse                              → lstm.bwd.{i}.*
            linear.0.*                                  → linear1.*
            linear.1.*                                  → linear2.*
            classifier.*                                → classifier.* (identity)
            window_ / n_ keys                           → skipped (frozen buffers, recomputed)
        """
        from huggingface_hub import hf_hub_download
        from pyannote_diarization_3_1_mlx._config import SEG_HF_REPO, SEG_HF_REV

        repo = repo or SEG_HF_REPO
        revision = revision or SEG_HF_REV
        npz_path = hf_hub_download(repo, "weights.npz", revision=revision)
        weights = np.load(npz_path)
        model = cls()
        flat: dict[str, mx.array] = {}
        # Keys to skip (frozen buffers, not learnable parameters)
        _SKIP_SUFFIXES = ("filterbank.window_", "filterbank.n_")
        for k, v in weights.items():
            # Skip frozen sinc filterbank buffers
            if any(k.endswith(s) for s in _SKIP_SUFFIXES):
                continue
            arr = v  # numpy array; converted to mx.array below
            # --- SincNet Conv1d weights: transpose (out, in, kernel) → (out, kernel, in) ---
            if k == "sincnet.conv1d.1.weight":
                flat["sincnet.conv1.weight"] = mx.array(arr.transpose(0, 2, 1))
                continue
            if k == "sincnet.conv1d.2.weight":
                flat["sincnet.conv2.weight"] = mx.array(arr.transpose(0, 2, 1))
                continue
            if k == "sincnet.conv1d.1.bias":
                flat["sincnet.conv1.bias"] = mx.array(arr)
                continue
            if k == "sincnet.conv1d.2.bias":
                flat["sincnet.conv2.bias"] = mx.array(arr)
                continue
            # --- SincNet InstanceNorm rename ---
            if k.startswith("sincnet.norm1d."):
                # sincnet.norm1d.0.weight → sincnet.norm0.weight
                rest = k[len("sincnet.norm1d."):]  # e.g. "0.weight"
                flat[f"sincnet.norm{rest}"] = mx.array(arr)
                continue
            # --- SincNet filterbank learnable params ---
            if k == "sincnet.conv1d.0.filterbank.low_hz_":
                flat["sincnet.sinc_fb.low_hz_"] = mx.array(arr)
                continue
            if k == "sincnet.conv1d.0.filterbank.band_hz_":
                flat["sincnet.sinc_fb.band_hz_"] = mx.array(arr)
                continue
            # --- linear layers rename ---
            if k.startswith("linear.0."):
                flat["linear1." + k[len("linear.0."):]] = mx.array(arr)
                continue
            if k.startswith("linear.1."):
                flat["linear2." + k[len("linear.1."):]] = mx.array(arr)
                continue
            # --- LSTM keys ---
            # Deferred: they are mapped per layer after this loop, because
            # bias_ih and bias_hh must both be available to be summed.
            if k.startswith("lstm."):
                continue
            # All other keys (classifier.*): identity
            flat[k] = mx.array(arr)
        # --- LSTM weight/bias mapping ---
        # bias_ih_l{i} + bias_hh_l{i} → fwd.{i}.bias (PyTorch splits into two biases; MLX uses one)
        for i in range(4):
            for direction, attr in [("", "fwd"), ("_reverse", "bwd")]:
                wih_key = f"lstm.weight_ih_l{i}{direction}"
                whh_key = f"lstm.weight_hh_l{i}{direction}"
                bih_key = f"lstm.bias_ih_l{i}{direction}"
                bhh_key = f"lstm.bias_hh_l{i}{direction}"
                flat[f"lstm.{attr}.{i}.Wx"] = mx.array(weights[wih_key])
                flat[f"lstm.{attr}.{i}.Wh"] = mx.array(weights[whh_key])
                flat[f"lstm.{attr}.{i}.bias"] = mx.array(
                    weights[bih_key] + weights[bhh_key]
                )
        # Load flat (dotted-key, mx.array) pairs into the model.
        # load_weights defaults to strict=True, which verifies that every model
        # parameter is supplied and that there are no extra keys.
        model.load_weights(list(flat.items()))
        return model
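
The loader sums `bias_ih` and `bias_hh` into a single bias per LSTM layer. That is safe because both biases are added to the same gate pre-activation, so only their sum matters. The sketch below checks this on a toy 2x2 example in plain Python; the function names and values are made up for illustration and do not come from PyTorch or MLX.

```python
import math


def matvec(matrix, vector):
    """Multiply a row-major matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def gates_two_biases(w_ih, w_hh, b_ih, b_hh, x, h):
    """Two-bias pre-activation: (W_ih x + b_ih) + (W_hh h + b_hh)."""
    ih = matvec(w_ih, x)
    hh = matvec(w_hh, h)
    return [(a + p) + (c + q) for a, p, c, q in zip(ih, b_ih, hh, b_hh)]


def gates_fused_bias(w_ih, w_hh, bias, x, h):
    """Single-bias pre-activation: W_ih x + W_hh h + bias."""
    ih = matvec(w_ih, x)
    hh = matvec(w_hh, h)
    return [a + c + b for a, c, b in zip(ih, hh, bias)]


w_ih = [[0.5, -1.0], [2.0, 0.25]]
w_hh = [[1.5, 0.0], [-0.5, 1.0]]
b_ih = [0.1, -0.2]
b_hh = [0.3, 0.4]
x = [1.0, 2.0]
h = [0.5, -1.0]

fused = [p + q for p, q in zip(b_ih, b_hh)]  # what the loader stores
two = gates_two_biases(w_ih, w_hh, b_ih, b_hh, x, h)
one = gates_fused_bias(w_ih, w_hh, fused, x, h)
print(two, one)  # equal up to floating-point rounding
```

The two results can differ in the last bits because floating-point addition is not associative, which is one reason parity checks against the reference compare with a tolerance.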

@@ -0,0 +1,54 @@
"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
import time

import numpy as np
import librosa
import pytest
import soundfile as sf
import torch
from pyannote.audio import Pipeline

from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline


@pytest.mark.integration
def test_diar_60s_parity_vs_pyannote():
    audio_path = "/tmp/_diar_smoke_60s.wav"
    # use the first 60 s of the existing test audio
    sig, _ = librosa.load("/tmp/audio_first_3min.wav", sr=16000, duration=60)
    sf.write(audio_path, sig, 16000)
    # MLX pipeline
    mlx_pipe = MlxDiarizationPipeline.from_pretrained()
    mlx_ann = mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
                        "sample_rate": 16000},
                       min_speakers=1, max_speakers=3)
    mlx_speakers = set(mlx_ann.labels())
    # pyannote PyTorch reference
    ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
    ref_out = ref_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
                        "sample_rate": 16000},
                       min_speakers=1, max_speakers=3)
    # pyannote 3.x returns the annotation directly; if the output exposes
    # `exclusive_speaker_diarization`, use that attribute instead.
    if hasattr(ref_out, "exclusive_speaker_diarization"):
        ref_ann = ref_out.exclusive_speaker_diarization
    else:
        ref_ann = ref_out
    ref_speakers = set(ref_ann.labels())
    # gate: speaker count within ±1
    assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
        f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"
    # gate: DER ≤ 0.30 (Hungarian-aligned)
    from pyannote.metrics.diarization import DiarizationErrorRate
    der = DiarizationErrorRate()
    der_value = der(ref_ann, mlx_ann)
    assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"
    # gate: wall-clock under 30s (MLX should be fast on M2/M3)
    t0 = time.time()
    mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
              "sample_rate": 16000})
    wall = time.time() - t0
    assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"

@@ -0,0 +1,59 @@
import os
import tempfile

import numpy as np
import mlx.core as mx
import torch
from torchaudio.compliance import kaldi as ta_kaldi

from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform


def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
    t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
    sig = (
        0.5 * np.sin(2 * np.pi * 220 * t)
        + 0.3 * np.sin(2 * np.pi * 880 * t)
    ).astype(np.float32)
    return sig


def test_fbank_matches_torchaudio_within_1pct():
    sig = _fixed_signal()
    # torchaudio reference: same params as pyannote WeSpeaker
    # (input scaled to the 16-bit integer range, as Kaldi expects)
    sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
    ref = ta_kaldi.fbank(
        sig_torch,
        num_mel_bins=80,
        frame_length=25,
        frame_shift=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_frequency=16000,
    ).numpy()  # (T, 80)
    # Our implementation, configured to match the reference
    # (apply_cmn defaults to False, so neither side applies CMN)
    sig_mx = mx.array(sig)
    out = kaldi_fbank(
        sig_mx,
        num_mel_bins=80,
        frame_length_ms=25,
        frame_shift_ms=10,
        dither=0.0,
        window_type="hamming",
        use_energy=False,
        sample_rate=16000,
    )
    out_np = np.asarray(out)
    assert out_np.shape == ref.shape
    # max abs diff should be small (kaldi-compliant, no random init)
    diff = np.abs(out_np - ref).max()
    assert diff < 0.05, f"max abs diff {diff:.4f}"


def test_load_waveform_resamples_to_16k():
    import soundfile as sf
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        sf.write(f.name, _fixed_signal(seconds=1.0, sr=22050), 22050)
    wav_mx = load_waveform(f.name)
    os.unlink(f.name)
    assert wav_mx.shape[-1] == 16000  # 1 second @ 16k after resample
    assert wav_mx.dtype == mx.float32

@@ -0,0 +1,11 @@
import mlx.core as mx

from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4


def test_bilstm_output_shape():
    # input (B, T, hidden_in): pyannote feeds 60-channel sincnet output
    # transposed to (B, T, 60). hidden=128, bidirectional → 256 out.
    net = BiLSTM4(input_size=60, hidden_size=128)
    x = mx.zeros((1, 589, 60))
    out = net(x)
    assert out.shape == (1, 589, 256), f"got {out.shape}"

@@ -0,0 +1,21 @@
import numpy as np

from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings


def test_two_well_separated_clusters():
    rng = np.random.default_rng(42)
    a = rng.normal(loc=[1.0, 0.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
    b = rng.normal(loc=[0.0, 1.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
    emb = np.vstack([a, b]).astype(np.float32)
    labels = cluster_embeddings(emb, num_speakers=2)
    assert len(set(labels[:10])) == 1
    assert len(set(labels[10:])) == 1
    assert labels[0] != labels[10]


def test_threshold_based():
    rng = np.random.default_rng(0)
    emb = rng.normal(size=(30, 256)).astype(np.float32)
    labels = cluster_embeddings(emb, num_speakers=None,
                                min_speakers=1, max_speakers=10)
    assert 1 <= len(set(labels)) <= 10

@@ -0,0 +1,28 @@
from pyannote_diarization_3_1_mlx._config import (
    SEG_DURATION, SEG_HOP, SEG_FRAMES, SEG_CLASSES,
    MAX_SPEAKERS_PER_CHUNK, MAX_SPEAKERS_PER_FRAME,
    EMB_BATCH_SIZE, EMB_EXCLUDE_OVERLAP,
    CLUSTER_METHOD, CLUSTER_THRESHOLD, CLUSTER_MIN_SIZE,
    SEG_HF_REPO, SEG_HF_REV, EMB_HF_REPO, EMB_HF_REV,
)


def test_pyannote_3_1_locked_hyperparameters():
    assert SEG_DURATION == 10.0
    assert SEG_HOP == 1.0
    assert SEG_FRAMES == 589
    assert SEG_CLASSES == 7
    assert MAX_SPEAKERS_PER_CHUNK == 3
    assert MAX_SPEAKERS_PER_FRAME == 2
    assert EMB_BATCH_SIZE == 32
    assert EMB_EXCLUDE_OVERLAP is True
    assert CLUSTER_METHOD == "centroid"
    assert CLUSTER_THRESHOLD == 0.7045654963945799
    assert CLUSTER_MIN_SIZE == 12


def test_locked_hf_revisions():
    assert SEG_HF_REPO == "mlx-community/pyannote-segmentation-3.0-mlx"
    assert SEG_HF_REV == "5189a69b35c5f7e48082a978f3476bac81590874"
    assert EMB_HF_REPO == "mlx-community/wespeaker-voxceleb-resnet34-LM"
    assert EMB_HF_REV == "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"

@@ -0,0 +1,11 @@
import mlx.core as mx

from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
from pyannote_diarization_3_1_mlx._config import EMB_DIM


def test_embedding_output_shape():
    m = EmbeddingModel()
    fb = mx.zeros((2, 200, 80))  # (B, T, mel)
    weights = mx.ones((2, 200))
    emb = m(fb, weights)
    assert emb.shape == (2, EMB_DIM), f"got {emb.shape}"

@@ -0,0 +1,25 @@
"""Smoke test for MlxDiarizationPipeline orchestrator.

Mocks all sub-components so no HF downloads or real inference is needed.
30 s of silence → powerset returns all zeros → no active speaker slots → empty annotation.
"""
import numpy as np
import mlx.core as mx
from unittest.mock import MagicMock

from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline


def test_pipeline_smoke_on_30s_zeros(mocker):
    p = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)
    p._segmentation = MagicMock()
    p._embedding = MagicMock()
    # mock seg → all class 0 (silence) → no slots → empty annotation
    p._segmentation.return_value = mx.zeros((1, 589, 7))
    p._powerset = MagicMock()
    p._powerset.to_multilabel.return_value = mx.zeros((589, 3))
    p._embedding.return_value = mx.ones((1, 256))
    # 30 s of silence
    audio = np.zeros(30 * 16000, dtype=np.float32)
    annotation = p({"waveform": mx.array(audio)[None, :], "sample_rate": 16000})
    # silence → no turns
    assert len(list(annotation.itertracks())) == 0

@@ -0,0 +1,39 @@
import numpy as np
import mlx.core as mx

from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING


def test_static_mapping_matches_pyannote():
    assert POWERSET_3_2_MAPPING.shape == (7, 3)
    expected = np.array([
        [0, 0, 0],
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [1, 0, 1],
        [0, 1, 1],
    ], dtype=np.float32)
    np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)


def test_to_multilabel_hard_argmax():
    p = Powerset()
    # frame 0 → class 1 (S1 only), frame 1 → class 4 (S1+S2), frame 2 → class 0
    logits = mx.array([
        [0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
        [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ])
    out = p.to_multilabel(logits)
    out_np = np.asarray(out)
    np.testing.assert_array_equal(out_np[0], [1, 0, 0])
    np.testing.assert_array_equal(out_np[1], [1, 1, 0])
    np.testing.assert_array_equal(out_np[2], [0, 0, 0])


def test_to_multilabel_shape():
    p = Powerset()
    logits = mx.zeros((589, 7))
    out = p.to_multilabel(logits)
    assert out.shape == (589, 3)

@@ -0,0 +1,9 @@
"""Integration test: load SegmentationModel weights from HF mlx-community repo."""
import pytest

from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel


@pytest.mark.integration
def test_segmentation_loads_from_hf():
    m = SegmentationModel.from_hf()
    assert m is not None

@@ -0,0 +1,9 @@
import mlx.core as mx

from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel


def test_segmentation_full_shape():
    m = SegmentationModel()
    x = mx.zeros((1, 1, 160000))  # 10s @ 16k mono
    out = m(x)
    assert out.shape == (1, 589, 7), f"got {out.shape}"

@@ -0,0 +1,12 @@
import mlx.core as mx

from pyannote_diarization_3_1_mlx._sincnet import SincNet


def test_sincnet_output_shape_589_frames():
    """For pyannote 3.1, 10s @ 16kHz input → 589 frames out."""
    net = SincNet(sample_rate=16000)
    x = mx.zeros((1, 1, 16000 * 10))  # (B, C, T)
    out = net(x)
    # Expect (1, 60, 589) per upstream PyanNet.SincNet output
    assert out.shape[-1] == 589, f"got frames={out.shape[-1]}"
    assert out.shape[1] == 60, f"got channels={out.shape[1]}"

@@ -0,0 +1,16 @@
import numpy as np

from pyannote_diarization_3_1_mlx._window import sliding_windows


def test_sliding_windows_full_coverage():
    sr = 16000
    audio = np.zeros(int(25 * sr), dtype=np.float32)
    windows = list(sliding_windows(audio, sr=sr, duration_s=10.0, hop_s=1.0))
    # Expect (25-10)/1 + 1 = 16 windows, all 10 s long
    assert len(windows) == 16
    for start, end, slice_ in windows:
        assert end - start == 10.0
        assert len(slice_) == 10 * sr
    # boundaries
    assert windows[0][0] == 0.0
    assert windows[-1][1] == 25.0
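
The window count asserted in the test above follows from `(total - duration) / hop + 1`. The sketch below computes the window bounds for that case. It is an assumption-laden illustration: `sliding_window_bounds` is a hypothetical helper, and the real `_window.sliding_windows` is not shown here, so it may treat audio shorter than one window or a trailing partial window differently.

```python
def sliding_window_bounds(total_s, duration_s, hop_s):
    """Return (start, end) times of every full window that fits in the audio.

    Assumption: only complete windows are produced; no padding of the tail.
    """
    if total_s < duration_s:
        return []
    count = int((total_s - duration_s) // hop_s) + 1
    return [(i * hop_s, i * hop_s + duration_s) for i in range(count)]


bounds = sliding_window_bounds(total_s=25.0, duration_s=10.0, hop_s=1.0)
print(len(bounds), bounds[0], bounds[-1])
```

With a 1 s hop and 10 s windows, each instant away from the edges is covered by ten overlapping windows, which is why the pipeline merges per-window segments with `Annotation.support()` at the end.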