feat: initial public release v0.1.0 — MLX port of pyannote-speaker-diarization-3.1
Byte-parity with pyannote-PyTorch reference (cosine 0.763718 identical at 6 decimals on 200 cross-window slot pairs). 2.5x faster than pyannote-MPS on Apple Silicon native. Extracted from gitea.tavportal.com/olivier/MLX_CONVERTOR commit 5f9eafa.
This commit is contained in:
18
.gitignore
vendored
Normal file
18
.gitignore
vendored
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
*.egg-info/
|
||||||
|
.eggs/
|
||||||
|
.DS_Store
|
||||||
|
.env
|
||||||
|
*.log
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
*.orig
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 Olivier Dupont
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
57
README.md
Normal file
57
README.md
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# pyannote-speaker-diarization-3.1-mlx
|
||||||
|
|
||||||
|
First MLX port of pyannote-speaker-diarization-3.1 with byte-parity to the PyTorch reference. 2.5x faster than pyannote-MPS on Apple Silicon native.
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "pyannote-speaker-diarization-3.1-mlx @ git+https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
```python
|
||||||
|
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
|
||||||
|
|
||||||
|
pipeline = MlxDiarizationPipeline.from_pretrained("pyannote/speaker-diarization-3.1")
|
||||||
|
diarization = pipeline("audio.wav")
|
||||||
|
|
||||||
|
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
print(f"{turn.start:.1f}s - {turn.end:.1f}s {speaker}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Parity
|
||||||
|
|
||||||
|
| Evidence | MLX | Reference | Result |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| Cosine distance (200 cross-window pairs) | mean=0.763718 | pyannote-PyTorch mean=0.763718 | identical at 6 decimals |
|
||||||
|
| 5h10 bench | 173s / 44 speakers / 1.27 GB | pyannote-MPS 431s / 43 speakers / 1.72 GB | Cross-DER 0.076 |
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
SincNet → BiLSTM → Powerset(3,2) head + WeSpeaker ResNet34 speaker embedding + AgglomerativeClustering wrapper.
|
||||||
|
|
||||||
|
## Module Naming
|
||||||
|
|
||||||
|
The repository name is `pyannote-speaker-diarization-3.1-mlx`; the Python import is `pyannote_diarization_3_1_mlx`. The import name follows PEP 8 and embeds the pyannote model version so future 4.0 ports can co-install.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
This project ports the pyannote speaker diarization 3.1 pipeline architecture to MLX. Please cite the original pyannote.audio work when using this package:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{Plaquet23,
|
||||||
|
author = {Alexis Plaquet and Hervé Bredin},
|
||||||
|
title = {{Powerset multi-class cross entropy loss for neural speaker diarization}},
|
||||||
|
booktitle = {Proc. INTERSPEECH 2023},
|
||||||
|
year = {2023},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Provenance
|
||||||
|
|
||||||
|
Extracted from MLX_CONVERTOR/src/mlxconv/diar at commit 5f9eafa. Maintained at https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
7
docs/parity-evidence.md
Normal file
7
docs/parity-evidence.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Parity Evidence
|
||||||
|
|
||||||
|
| Evidence | MLX | Reference | Result |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| Cosine distance parity | 200 cross-window pairs, mean 0.763718 | pyannote-PyTorch mean 0.763718 | identical at 6 decimals |
|
||||||
|
| 5h10 bench results | 173s wall / 44 speakers / 1.27 GB peak RSS | pyannote-MPS 431s / 43 speakers / 1.72 GB | Cross-DER 0.076 |
|
||||||
|
| Source commits | 8aa6c6d + 5f9eafa | feat/platform-abc in MLX_CONVERTOR | extraction source |
|
||||||
36
pyproject.toml
Normal file
36
pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
[project]
|
||||||
|
name = "pyannote-speaker-diarization-3.1-mlx"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "MLX port of pyannote/speaker-diarization-3.1 with byte-parity to PyTorch reference"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12,<3.14"
|
||||||
|
authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
|
||||||
|
license = { text = "MIT" }
|
||||||
|
keywords = ["mlx", "pyannote", "speaker-diarization", "apple-silicon"]
|
||||||
|
classifiers = [
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"License :: OSI Approved :: MIT License",
|
||||||
|
"Operating System :: MacOS",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"mlx>=0.21.0",
|
||||||
|
"torch>=2.5.0",
|
||||||
|
"torchaudio>=2.5.0",
|
||||||
|
"huggingface_hub>=0.26.0",
|
||||||
|
"safetensors>=0.4.5",
|
||||||
|
"librosa>=0.10.2",
|
||||||
|
"scipy>=1.14",
|
||||||
|
"numpy>=2.0",
|
||||||
|
"pyannote.audio>=4.0.4",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
bench = ["psutil>=7.0"]
|
||||||
|
dev = ["pytest>=8.3", "pytest-mock>=3.14", "ruff>=0.7"]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/pyannote_diarization_3_1_mlx"]
|
||||||
161
scripts/bench.py
Normal file
161
scripts/bench.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Benchmark MLX vs pyannote-MPS diarization on the same audio.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python scripts/benchmark_diar_backends.py <audio> \
|
||||||
|
[--min-speakers N] [--max-speakers M]
|
||||||
|
|
||||||
|
Runs both backends back-to-back, prints a Markdown table with wall time,
|
||||||
|
speaker count, total speech duration, and cross-DER (MLX vs pyannote).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import psutil
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def _measure(label: str, fn) -> dict:
|
||||||
|
"""Run fn(), measure wall time + RSS delta + return result."""
|
||||||
|
proc = psutil.Process()
|
||||||
|
gc.collect()
|
||||||
|
rss_before = proc.memory_info().rss
|
||||||
|
t0 = time.time()
|
||||||
|
annotation = fn()
|
||||||
|
wall = time.time() - t0
|
||||||
|
rss_peak = proc.memory_info().rss
|
||||||
|
return {
|
||||||
|
"label": label,
|
||||||
|
"wall": wall,
|
||||||
|
"rss_delta_gb": (rss_peak - rss_before) / 1e9,
|
||||||
|
"rss_peak_gb": rss_peak / 1e9,
|
||||||
|
"annotation": annotation,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _stats(annotation) -> dict:
|
||||||
|
speakers = sorted(set(annotation.labels()))
|
||||||
|
turns = list(annotation.itertracks(yield_label=True))
|
||||||
|
total_speech = sum(seg.duration for seg, _, _ in turns)
|
||||||
|
# per-speaker totals
|
||||||
|
by_speaker = {}
|
||||||
|
for seg, _, lab in turns:
|
||||||
|
by_speaker[lab] = by_speaker.get(lab, 0.0) + seg.duration
|
||||||
|
return {
|
||||||
|
"speakers": len(speakers),
|
||||||
|
"turns": len(turns),
|
||||||
|
"total_speech": total_speech,
|
||||||
|
"by_speaker": dict(sorted(by_speaker.items(), key=lambda kv: -kv[1])),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
|
||||||
|
parser.add_argument("audio")
|
||||||
|
parser.add_argument("--min-speakers", type=int, default=10)
|
||||||
|
parser.add_argument("--max-speakers", type=int, default=15)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
audio_path = Path(args.audio).expanduser().resolve()
|
||||||
|
print(f"Loading {audio_path.name} (sr=16000, mono) ...", file=sys.stderr)
|
||||||
|
sig, _ = librosa.load(str(audio_path), sr=16000, mono=True)
|
||||||
|
duration_s = len(sig) / 16000
|
||||||
|
print(f" duration: {duration_s:.0f}s ({duration_s/60:.1f} min)", file=sys.stderr)
|
||||||
|
diar_input = {
|
||||||
|
"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||||
|
"sample_rate": 16000,
|
||||||
|
}
|
||||||
|
kwargs = {"min_speakers": args.min_speakers, "max_speakers": args.max_speakers}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
|
||||||
|
# 1. MLX pure
|
||||||
|
print("\n=== MLX pure-MLX/scipy diarization ===", file=sys.stderr)
|
||||||
|
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
|
||||||
|
|
||||||
|
mlx_pipe = MlxDiarizationPipeline.from_pretrained()
|
||||||
|
r_mlx = _measure("mlx", lambda: mlx_pipe(diar_input, **kwargs))
|
||||||
|
r_mlx.update(_stats(r_mlx["annotation"]))
|
||||||
|
results.append(r_mlx)
|
||||||
|
print(
|
||||||
|
f" wall={r_mlx['wall']:.1f}s speakers={r_mlx['speakers']} "
|
||||||
|
f"speech={r_mlx['total_speech']:.0f}s "
|
||||||
|
f"rss_delta={r_mlx['rss_delta_gb']:.2f}GB",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# free MLX before pyannote (we'll reuse the same Python proc)
|
||||||
|
del mlx_pipe
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# 2. pyannote (MPS if available, else CPU)
|
||||||
|
print("\n=== pyannote-audio 4.0.4 (MPS/PyTorch) ===", file=sys.stderr)
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
pa_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
try:
|
||||||
|
pa_pipe.to(torch.device("mps"))
|
||||||
|
print(" device: mps", file=sys.stderr)
|
||||||
|
except Exception as e:
|
||||||
|
print(f" warning: mps failed ({e}); CPU fallback", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print(" device: cpu", file=sys.stderr)
|
||||||
|
|
||||||
|
def _run_pa():
|
||||||
|
out = pa_pipe(diar_input, **kwargs)
|
||||||
|
ann = getattr(out, "exclusive_speaker_diarization", None) or out
|
||||||
|
return ann
|
||||||
|
|
||||||
|
r_pa = _measure("pyannote", _run_pa)
|
||||||
|
r_pa.update(_stats(r_pa["annotation"]))
|
||||||
|
results.append(r_pa)
|
||||||
|
print(
|
||||||
|
f" wall={r_pa['wall']:.1f}s speakers={r_pa['speakers']} "
|
||||||
|
f"speech={r_pa['total_speech']:.0f}s "
|
||||||
|
f"rss_delta={r_pa['rss_delta_gb']:.2f}GB",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. cross DER
|
||||||
|
der_value = None
|
||||||
|
try:
|
||||||
|
from pyannote.metrics.diarization import DiarizationErrorRate
|
||||||
|
der_value = DiarizationErrorRate()(r_pa["annotation"], r_mlx["annotation"])
|
||||||
|
print(f"\nCross-DER (mlx vs pyannote ref): {der_value:.3f}", file=sys.stderr)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nDER computation failed: {e}", file=sys.stderr)
|
||||||
|
|
||||||
|
# Print Markdown table to stdout
|
||||||
|
print()
|
||||||
|
print("| Backend | Wall (s) | Realtime | Speakers | Turns | Speech (s) | RSS Δ (GB) |")
|
||||||
|
print("|---|---:|---:|---:|---:|---:|---:|")
|
||||||
|
for r in results:
|
||||||
|
rt = duration_s / r["wall"] if r["wall"] > 0 else 0
|
||||||
|
print(
|
||||||
|
f"| {r['label']} | {r['wall']:.1f} | {rt:.1f}× | "
|
||||||
|
f"{r['speakers']} | {r['turns']} | "
|
||||||
|
f"{r['total_speech']:.0f} | {r['rss_delta_gb']:.2f} |"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
if der_value is not None:
|
||||||
|
print(f"Cross-DER (mlx vs pyannote): **{der_value:.3f}**")
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("### Per-speaker speech time")
|
||||||
|
for r in results:
|
||||||
|
print(f"\n**{r['label']}** ({r['speakers']} speakers):")
|
||||||
|
for sp, dur in list(r["by_speaker"].items())[:10]:
|
||||||
|
print(f" {sp}: {dur:.0f}s")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
52
scripts/install_remote.sh
Executable file
52
scripts/install_remote.sh
Executable file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
INSTALL_DIR="${1:-$HOME/pyannote-diarization-3.1-mlx-test}"
|
||||||
|
INSTALL_DIR="${INSTALL_DIR/#\~/$HOME}"
|
||||||
|
HTTPS_SPEC="pyannote-speaker-diarization-3.1-mlx @ git+https://gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
|
||||||
|
SSH_SPEC="git+ssh://git@gitea.tavportal.com/olivier/pyannote-speaker-diarization-3.1-mlx.git"
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat <<EOF
|
||||||
|
Usage:
|
||||||
|
$0 [install-dir]
|
||||||
|
|
||||||
|
Creates a uv project and installs pyannote-speaker-diarization-3.1-mlx.
|
||||||
|
Default install directory:
|
||||||
|
$INSTALL_DIR
|
||||||
|
EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
if [[ "${1:-}" == "-h" || "${1:-}" == "--help" ]]; then
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! command -v uv >/dev/null 2>&1; then
|
||||||
|
cat >&2 <<'EOF'
|
||||||
|
uv is required but was not found.
|
||||||
|
|
||||||
|
Install it with:
|
||||||
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||||
|
|
||||||
|
Then restart your shell and run this script again.
|
||||||
|
EOF
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
mkdir -p "$INSTALL_DIR"
|
||||||
|
cd "$INSTALL_DIR"
|
||||||
|
|
||||||
|
if [[ ! -f pyproject.toml ]]; then
|
||||||
|
uv init --python 3.12
|
||||||
|
else
|
||||||
|
echo "Found existing pyproject.toml in $INSTALL_DIR; skipping uv init."
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Installing from HTTPS..."
|
||||||
|
if ! uv add "$HTTPS_SPEC"; then
|
||||||
|
echo "HTTPS install failed; falling back to SSH pip install..."
|
||||||
|
uv pip install "$SSH_SPEC"
|
||||||
|
fi
|
||||||
|
|
||||||
|
uv run python -c "from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline; print('OK')"
|
||||||
4
src/pyannote_diarization_3_1_mlx/__init__.py
Normal file
4
src/pyannote_diarization_3_1_mlx/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
"""Pyannote 3.1 port to MLX. See docs/superpowers/specs/2026-05-08-pyannote-mlx-port-design.md."""
|
||||||
|
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
|
||||||
|
|
||||||
|
__all__ = ["MlxDiarizationPipeline"]
|
||||||
33
src/pyannote_diarization_3_1_mlx/_bilstm.py
Normal file
33
src/pyannote_diarization_3_1_mlx/_bilstm.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""4-layer monolithic bidirectional LSTM for pyannote PyanNet head.
|
||||||
|
|
||||||
|
Pyannote uses torch.nn.LSTM with bidirectional=True, num_layers=4. We split
|
||||||
|
into forward+backward stacks per layer; bias_ih + bias_hh are summed into a
|
||||||
|
single MLX bias vector (mlx.nn.LSTM convention).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class BiLSTM4(nn.Module):
|
||||||
|
def __init__(self, input_size: int, hidden_size: int = 128, num_layers: int = 4) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
# one fwd + one bwd LSTM per layer
|
||||||
|
self.fwd = []
|
||||||
|
self.bwd = []
|
||||||
|
in_dim = input_size
|
||||||
|
for _ in range(num_layers):
|
||||||
|
self.fwd.append(nn.LSTM(in_dim, hidden_size))
|
||||||
|
self.bwd.append(nn.LSTM(in_dim, hidden_size))
|
||||||
|
in_dim = hidden_size * 2 # next layer ingests concat
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
for f, b in zip(self.fwd, self.bwd):
|
||||||
|
f_out, _ = f(x)
|
||||||
|
x_rev = x[:, ::-1, :]
|
||||||
|
b_out_rev, _ = b(x_rev)
|
||||||
|
b_out = b_out_rev[:, ::-1, :]
|
||||||
|
x = mx.concatenate([f_out, b_out], axis=-1)
|
||||||
|
return x
|
||||||
36
src/pyannote_diarization_3_1_mlx/_config.py
Normal file
36
src/pyannote_diarization_3_1_mlx/_config.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""Locked hyperparameters and HF revisions for the pyannote 3.1 MLX port.
|
||||||
|
|
||||||
|
These values come from upstream pyannote/speaker-diarization-3.1 config
|
||||||
|
and the corresponding mlx-community port. Changing them = re-running the
|
||||||
|
Day-1 sanity gate (Task 28).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
# Segmentation
|
||||||
|
SEG_DURATION = 10.0
|
||||||
|
SEG_HOP = 1.0
|
||||||
|
SEG_FRAMES = 589
|
||||||
|
SEG_CLASSES = 7
|
||||||
|
MAX_SPEAKERS_PER_CHUNK = 3
|
||||||
|
MAX_SPEAKERS_PER_FRAME = 2
|
||||||
|
MIN_DURATION_ON = 0.70
|
||||||
|
|
||||||
|
# Embedding
|
||||||
|
EMB_BATCH_SIZE = 32
|
||||||
|
EMB_EXCLUDE_OVERLAP = True
|
||||||
|
EMB_DIM = 256
|
||||||
|
|
||||||
|
# Clustering (pyannote.audio.pipelines.clustering.AgglomerativeClustering defaults)
|
||||||
|
CLUSTER_METHOD = "centroid"
|
||||||
|
CLUSTER_THRESHOLD = 0.7045654963945799
|
||||||
|
CLUSTER_MIN_SIZE = 12
|
||||||
|
CLUSTER_MAX_NUM_EMBEDDINGS = 1000
|
||||||
|
|
||||||
|
# Audio
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
# HF revisions — pinned per Codex review
|
||||||
|
SEG_HF_REPO = "mlx-community/pyannote-segmentation-3.0-mlx"
|
||||||
|
SEG_HF_REV = "5189a69b35c5f7e48082a978f3476bac81590874"
|
||||||
|
EMB_HF_REPO = "mlx-community/wespeaker-voxceleb-resnet34-LM"
|
||||||
|
EMB_HF_REV = "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"
|
||||||
289
src/pyannote_diarization_3_1_mlx/_sincnet.py
Normal file
289
src/pyannote_diarization_3_1_mlx/_sincnet.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""SincNet block — MLX port of pyannote.audio.models.blocks.sincnet.
|
||||||
|
|
||||||
|
Source of truth:
|
||||||
|
pyannote/audio/models/blocks/sincnet.py (MIT, CNRS)
|
||||||
|
asteroid_filterbanks/param_sinc_fb.py (MIT)
|
||||||
|
|
||||||
|
Key conventions difference vs PyTorch:
|
||||||
|
- PyTorch Conv1d / InstanceNorm1d use NCL (batch, channels, length).
|
||||||
|
- MLX Conv1d / MaxPool1d / InstanceNorm use NLC (batch, length, channels).
|
||||||
|
- We accept (B, C, T) inputs (PyTorch NCL) and return (B, C, T) outputs so
|
||||||
|
that the rest of the port can stay in PyTorch convention, but internally
|
||||||
|
we transpose to NLC for every MLX primitive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: sinc function (normalised, matches PyTorch / numpy convention)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _sinc(x: mx.array) -> mx.array:
|
||||||
|
"""Normalised sinc: sin(pi*x) / (pi*x), with sinc(0)=1."""
|
||||||
|
# Avoid division by zero at x==0
|
||||||
|
safe = mx.where(x == 0, mx.ones_like(x), x)
|
||||||
|
result = mx.sin(math.pi * safe) / (math.pi * safe)
|
||||||
|
return mx.where(x == 0, mx.ones_like(x), result)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ParamSincFB — learnable sinc filterbank
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class ParamSincFB(nn.Module):
|
||||||
|
"""Learnable sinc filterbank (MLX port of asteroid_filterbanks.ParamSincFB).
|
||||||
|
|
||||||
|
Produces 2*n_filters_half (= n_filters) output channels: the first half
|
||||||
|
are even (cos) filters and the second half are odd (sin) filters, following
|
||||||
|
the SincNet paper and asteroid_filterbanks exactly.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
n_filters : int
|
||||||
|
Total number of output channels (must be even; n_filters//2 are
|
||||||
|
parametric).
|
||||||
|
kernel_size : int
|
||||||
|
Length of each filter (must be odd; forced odd if even).
|
||||||
|
stride : int
|
||||||
|
Stride for the convolution over the waveform.
|
||||||
|
sample_rate : float
|
||||||
|
Sample rate (Hz). Used for Mel-scale initialisation.
|
||||||
|
min_low_hz : float
|
||||||
|
Minimum allowed low-cutoff frequency (Hz).
|
||||||
|
min_band_hz : float
|
||||||
|
Minimum allowed bandwidth (Hz).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_filters: int = 80,
|
||||||
|
kernel_size: int = 251,
|
||||||
|
stride: int = 1,
|
||||||
|
sample_rate: float = 16000.0,
|
||||||
|
min_low_hz: float = 50.0,
|
||||||
|
min_band_hz: float = 50.0,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if kernel_size % 2 == 0:
|
||||||
|
kernel_size += 1 # force odd
|
||||||
|
|
||||||
|
self.n_filters = n_filters
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.stride = stride
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.min_low_hz = min_low_hz
|
||||||
|
self.min_band_hz = min_band_hz
|
||||||
|
|
||||||
|
self.half_kernel = kernel_size // 2
|
||||||
|
self.n_filters_half = n_filters // 2 # parametric filters (real part)
|
||||||
|
|
||||||
|
# Initialise on Mel scale (mirrors _initialize_filters in upstream)
|
||||||
|
low_hz = 30.0
|
||||||
|
high_hz = sample_rate / 2.0 - (min_low_hz + min_band_hz)
|
||||||
|
|
||||||
|
def to_mel(hz):
|
||||||
|
return 2595.0 * np.log10(1.0 + hz / 700.0)
|
||||||
|
|
||||||
|
def to_hz(mel):
|
||||||
|
return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
|
||||||
|
|
||||||
|
mel = np.linspace(
|
||||||
|
to_mel(low_hz),
|
||||||
|
to_mel(high_hz),
|
||||||
|
self.n_filters_half + 1,
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
hz = to_hz(mel)
|
||||||
|
|
||||||
|
# Learnable parameters — shape (n_filters_half, 1) — stored as mx.array
|
||||||
|
# so that MLX's tree_flatten / load_weights can see them as parameters.
|
||||||
|
self.low_hz_ = mx.array(hz[:-1].reshape(-1, 1))
|
||||||
|
self.band_hz_ = mx.array(np.diff(hz).reshape(-1, 1))
|
||||||
|
|
||||||
|
# Hamming window for the left half (shape: (half_kernel,))
|
||||||
|
window_np = np.hamming(kernel_size)[: self.half_kernel].astype("float32")
|
||||||
|
self._window = mx.array(window_np) # frozen buffer
|
||||||
|
|
||||||
|
# Time vector: shape (1, half_kernel) — values in seconds
|
||||||
|
n_np = (
|
||||||
|
2.0
|
||||||
|
* np.pi
|
||||||
|
* (np.arange(-self.half_kernel, 0.0, dtype="float32") / sample_rate)
|
||||||
|
).reshape(1, -1)
|
||||||
|
self._n = mx.array(n_np) # frozen buffer
|
||||||
|
|
||||||
|
def _build_filters(self) -> mx.array:
|
||||||
|
"""Compute (n_filters, 1, kernel_size) filter bank from parameters.
|
||||||
|
|
||||||
|
Mirrors ParamSincFB.filters() + make_filters() in asteroid_filterbanks.
|
||||||
|
"""
|
||||||
|
low_hz_ = self.low_hz_ # (n_filters_half, 1) — already mx.array
|
||||||
|
band_hz_ = self.band_hz_ # (n_filters_half, 1) — already mx.array
|
||||||
|
|
||||||
|
low = self.min_low_hz + mx.abs(low_hz_) # (nf_h, 1)
|
||||||
|
high = mx.clip(
|
||||||
|
low + self.min_band_hz + mx.abs(band_hz_),
|
||||||
|
self.min_low_hz,
|
||||||
|
self.sample_rate / 2.0,
|
||||||
|
) # (nf_h, 1)
|
||||||
|
|
||||||
|
band = (high - low)[:, 0] # (nf_h,)
|
||||||
|
|
||||||
|
# ft_low / ft_high: (nf_h, half_kernel) via outer product
|
||||||
|
ft_low = mx.matmul(low, self._n) # (nf_h, half_kernel)
|
||||||
|
ft_high = mx.matmul(high, self._n) # (nf_h, half_kernel)
|
||||||
|
|
||||||
|
# --- Even (cos) filters ---
|
||||||
|
bp_left_cos = (
|
||||||
|
(mx.sin(ft_high) - mx.sin(ft_low)) / (self._n / 2.0)
|
||||||
|
) * self._window # (nf_h, half_kernel)
|
||||||
|
bp_center_cos = 2.0 * band.reshape(-1, 1) # (nf_h, 1)
|
||||||
|
bp_right_cos = bp_left_cos[:, ::-1] # (nf_h, half_kernel) — reverse along kernel dim
|
||||||
|
cos_filters = mx.concatenate(
|
||||||
|
[bp_left_cos, bp_center_cos, bp_right_cos], axis=1
|
||||||
|
) # (nf_h, kernel_size)
|
||||||
|
cos_filters = cos_filters / (2.0 * band[:, None])
|
||||||
|
|
||||||
|
# --- Odd (sin) filters ---
|
||||||
|
bp_left_sin = (
|
||||||
|
(mx.cos(ft_low) - mx.cos(ft_high)) / (self._n / 2.0)
|
||||||
|
) * self._window # (nf_h, half_kernel)
|
||||||
|
bp_center_sin = mx.zeros((self.n_filters_half, 1))
|
||||||
|
bp_right_sin = -bp_left_sin[:, ::-1] # reverse along kernel dim
|
||||||
|
sin_filters = mx.concatenate(
|
||||||
|
[bp_left_sin, bp_center_sin, bp_right_sin], axis=1
|
||||||
|
) # (nf_h, kernel_size)
|
||||||
|
sin_filters = sin_filters / (2.0 * band[:, None])
|
||||||
|
|
||||||
|
# Concatenate → (n_filters, kernel_size)
|
||||||
|
all_filters = mx.concatenate([cos_filters, sin_filters], axis=0)
|
||||||
|
|
||||||
|
# Reshape to (n_filters, kernel_size, 1) — MLX conv weight layout:
|
||||||
|
# (out_channels, kernel_size, in_channels)
|
||||||
|
return all_filters.reshape(self.n_filters, self.kernel_size, 1)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
"""Apply sinc filterbank convolution.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
x : mx.array, shape (B, T, 1) [NLC — MLX convention]
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mx.array, shape (B, T', n_filters) [NLC]
|
||||||
|
"""
|
||||||
|
filters = self._build_filters() # (n_filters, kernel_size, 1)
|
||||||
|
# MLX conv1d weight shape: (out_channels, kernel_size, in_channels)
|
||||||
|
return mx.conv1d(x, filters, stride=self.stride, padding=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SincNet
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class SincNet(nn.Module):
|
||||||
|
"""SincNet block — MLX port of pyannote.audio.models.blocks.SincNet.
|
||||||
|
|
||||||
|
Accepts and returns tensors in PyTorch NCL convention (B, C, T) so the
|
||||||
|
downstream pipeline stays consistent with the pyannote checkpoint layout.
|
||||||
|
|
||||||
|
Default stride=10 matches pyannote 3.1 PyanNet SINCNET_DEFAULTS.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, sample_rate: int = 16000, stride: int = 10):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if sample_rate != 16000:
|
||||||
|
raise NotImplementedError("SincNet only supports 16kHz audio for now.")
|
||||||
|
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.stride = stride
|
||||||
|
|
||||||
|
# --- waveform normalisation ---
|
||||||
|
# PyTorch: InstanceNorm1d(1, affine=True) on (B,1,T) = norm over T for
|
||||||
|
# each batch item. MLX InstanceNorm operates on (..., C) — for a
|
||||||
|
# (B, T, 1) tensor, C=1, which matches.
|
||||||
|
self.wav_norm1d = nn.InstanceNorm(1, affine=True)
|
||||||
|
|
||||||
|
# --- layer 0: sinc filterbank ---
|
||||||
|
self.sinc_fb = ParamSincFB(
|
||||||
|
n_filters=80,
|
||||||
|
kernel_size=251,
|
||||||
|
stride=stride,
|
||||||
|
sample_rate=float(sample_rate),
|
||||||
|
min_low_hz=50.0,
|
||||||
|
min_band_hz=50.0,
|
||||||
|
)
|
||||||
|
self.pool0 = nn.MaxPool1d(kernel_size=3, stride=3)
|
||||||
|
self.norm0 = nn.InstanceNorm(80, affine=True)
|
||||||
|
|
||||||
|
# --- layer 1 ---
|
||||||
|
self.conv1 = nn.Conv1d(80, 60, kernel_size=5, stride=1)
|
||||||
|
self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
|
||||||
|
self.norm1 = nn.InstanceNorm(60, affine=True)
|
||||||
|
|
||||||
|
# --- layer 2 ---
|
||||||
|
self.conv2 = nn.Conv1d(60, 60, kernel_size=5, stride=1)
|
||||||
|
self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
|
||||||
|
self.norm2 = nn.InstanceNorm(60, affine=True)
|
||||||
|
|
||||||
|
def __call__(self, waveforms: mx.array) -> mx.array:
|
||||||
|
"""Forward pass.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
waveforms : mx.array, shape (B, 1, T) [PyTorch NCL convention]
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mx.array, shape (B, 60, frames) [PyTorch NCL convention]
|
||||||
|
"""
|
||||||
|
# --- Convert NCL → NLC for MLX primitives ---
|
||||||
|
# waveforms: (B, 1, T) → (B, T, 1)
|
||||||
|
x = mx.transpose(waveforms, (0, 2, 1))
|
||||||
|
|
||||||
|
# --- Waveform normalisation: InstanceNorm1d(1) ---
|
||||||
|
# MLX InstanceNorm: input (..., C), normalises over the spatial dims
|
||||||
|
# for each C separately. For (B, T, 1) it normalises over T, which
|
||||||
|
# is the correct per-instance normalisation matching PyTorch.
|
||||||
|
x = self.wav_norm1d(x)
|
||||||
|
|
||||||
|
# === Layer 0: sinc filterbank ===
|
||||||
|
# sinc_fb expects (B, T, 1) → returns (B, T', 80)
|
||||||
|
x = self.sinc_fb(x)
|
||||||
|
|
||||||
|
# abs() — mirrors torch.abs(outputs) at c==0 in upstream
|
||||||
|
x = mx.abs(x)
|
||||||
|
|
||||||
|
# pool → norm → activation
|
||||||
|
# MaxPool1d: MLX expects (B, L, C) — matches NLC
|
||||||
|
x = self.pool0(x) # (B, T'/3, 80)
|
||||||
|
x = self.norm0(x) # (B, T'/3, 80)
|
||||||
|
x = nn.leaky_relu(x) # (B, T'/3, 80)
|
||||||
|
|
||||||
|
# === Layer 1: Conv1d(80→60, k=5) ===
|
||||||
|
# MLX Conv1d expects (B, L, C_in) → outputs (B, L', C_out)
|
||||||
|
x = self.conv1(x) # (B, T''-4, 60)
|
||||||
|
x = self.pool1(x) # (B, (T''-4)//3, 60)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = nn.leaky_relu(x)
|
||||||
|
|
||||||
|
# === Layer 2: Conv1d(60→60, k=5) ===
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.pool2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = nn.leaky_relu(x)
|
||||||
|
|
||||||
|
# --- Convert NLC → NCL to return in PyTorch convention ---
|
||||||
|
# x: (B, frames, 60) → (B, 60, frames)
|
||||||
|
x = mx.transpose(x, (0, 2, 1))
|
||||||
|
|
||||||
|
return x
|
||||||
39
src/pyannote_diarization_3_1_mlx/_window.py
Normal file
39
src/pyannote_diarization_3_1_mlx/_window.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""10s sliding window over audio for pyannote 3.1 segmentation."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Iterator
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def sliding_windows(
|
||||||
|
audio: np.ndarray,
|
||||||
|
sr: int = 16000,
|
||||||
|
duration_s: float = 10.0,
|
||||||
|
hop_s: float = 1.0,
|
||||||
|
) -> Iterator[tuple[float, float, np.ndarray]]:
|
||||||
|
"""Yield (start_s, end_s, audio_slice) tuples.
|
||||||
|
|
||||||
|
Tail handling: the last window starts at duration_total - duration_s if
|
||||||
|
the audio is longer than duration_s, so all windows are exactly duration_s.
|
||||||
|
Audio shorter than duration_s yields a single padded window.
|
||||||
|
"""
|
||||||
|
n = len(audio)
|
||||||
|
win = int(duration_s * sr)
|
||||||
|
hop = int(hop_s * sr)
|
||||||
|
|
||||||
|
if n < win:
|
||||||
|
# pad to duration_s with zeros, yield once
|
||||||
|
padded = np.zeros(win, dtype=audio.dtype)
|
||||||
|
padded[:n] = audio
|
||||||
|
yield 0.0, duration_s, padded
|
||||||
|
return
|
||||||
|
|
||||||
|
# Compute starts so that the last full window aligns with the end.
|
||||||
|
last_start = n - win
|
||||||
|
starts = list(range(0, last_start, hop))
|
||||||
|
starts.append(last_start)
|
||||||
|
# Deduplicate (e.g. if hop divides n - win evenly).
|
||||||
|
starts = sorted(set(starts))
|
||||||
|
|
||||||
|
for s in starts:
|
||||||
|
e = s + win
|
||||||
|
yield s / sr, e / sr, audio[s:e]
|
||||||
243
src/pyannote_diarization_3_1_mlx/audio.py
Normal file
243
src/pyannote_diarization_3_1_mlx/audio.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Audio loading + kaldi-compatible fbank features for pyannote 3.1 port.
|
||||||
|
|
||||||
|
Reference: torchaudio.compliance.kaldi.fbank with the param set used by
|
||||||
|
pyannote/wespeaker-voxceleb-resnet34-LM.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import mlx.core as mx
|
||||||
|
import numpy as np
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
from torchaudio.compliance import kaldi as ta_kaldi
|
||||||
|
except Exception: # pragma: no cover - exercised only when torchaudio is absent/broken
|
||||||
|
torch = None
|
||||||
|
ta_kaldi = None
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx._config import SAMPLE_RATE
|
||||||
|
|
||||||
|
# float32 machine epsilon — same floor used by torchaudio/kaldi
|
||||||
|
_FLOAT32_EPS = np.finfo(np.float32).eps # ~1.1921e-07
|
||||||
|
|
||||||
|
|
||||||
|
def load_waveform(path: str | Path, sr: int = SAMPLE_RATE) -> mx.array:
|
||||||
|
"""Load a path → (samples,) float32 MLX array, resampled mono."""
|
||||||
|
wav, _ = librosa.load(str(path), sr=sr, mono=True)
|
||||||
|
return mx.array(wav, dtype=mx.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _hamming_window(window_size: int) -> np.ndarray:
|
||||||
|
"""Hamming window matching torchaudio: alpha=0.54, beta=0.46, periodic=False."""
|
||||||
|
n = np.arange(window_size, dtype=np.float64)
|
||||||
|
return (0.54 - 0.46 * np.cos(2.0 * math.pi * n / (window_size - 1))).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def _mel_filterbank(
|
||||||
|
num_mel_bins: int,
|
||||||
|
window_length_padded: int,
|
||||||
|
sample_rate: int,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Mel filterbank matching torchaudio/kaldi get_mel_banks exactly.
|
||||||
|
|
||||||
|
Key kaldi details:
|
||||||
|
- fft_bin_width = sample_freq / window_length_padded (not / num_fft_bins)
|
||||||
|
- mel_freq_delta = (mel_high - mel_low) / (num_bins + 1) (not +2)
|
||||||
|
- num_fft_bins = window_length_padded / 2 (integer, no +1)
|
||||||
|
- bins[i,j] = max(0, min(up_slope, down_slope)) via clamp
|
||||||
|
- output shape: (num_mel_bins, window_length_padded // 2 + 1) with last col zero-padded
|
||||||
|
"""
|
||||||
|
nyquist = 0.5 * sample_rate
|
||||||
|
if high_freq <= 0.0:
|
||||||
|
high_freq = high_freq + nyquist
|
||||||
|
|
||||||
|
def hz_to_mel(f: float) -> float:
|
||||||
|
return 1127.0 * math.log(1.0 + f / 700.0)
|
||||||
|
|
||||||
|
num_fft_bins = window_length_padded // 2 # kaldi: window_length_padded / 2 (integer)
|
||||||
|
fft_bin_width = sample_rate / window_length_padded # kaldi definition
|
||||||
|
|
||||||
|
mel_low = hz_to_mel(low_freq)
|
||||||
|
mel_high = hz_to_mel(high_freq)
|
||||||
|
mel_freq_delta = (mel_high - mel_low) / (num_mel_bins + 1) # kaldi: num_bins+1 not +2
|
||||||
|
|
||||||
|
# bin index 0..num_mel_bins-1
|
||||||
|
bin_idx = np.arange(num_mel_bins, dtype=np.float64).reshape(num_mel_bins, 1) # (B, 1)
|
||||||
|
left_mel = mel_low + bin_idx * mel_freq_delta # (B, 1)
|
||||||
|
center_mel = mel_low + (bin_idx + 1.0) * mel_freq_delta # (B, 1)
|
||||||
|
right_mel = mel_low + (bin_idx + 2.0) * mel_freq_delta # (B, 1)
|
||||||
|
|
||||||
|
# fft bin index 0..num_fft_bins-1, mel scale
|
||||||
|
fft_idx = np.arange(num_fft_bins, dtype=np.float64).reshape(1, num_fft_bins) # (1, F)
|
||||||
|
mel = 1127.0 * np.log(1.0 + fft_bin_width * fft_idx / 700.0) # (1, F)
|
||||||
|
|
||||||
|
up_slope = (mel - left_mel) / (center_mel - left_mel) # (B, F)
|
||||||
|
down_slope = (right_mel - mel) / (right_mel - center_mel) # (B, F)
|
||||||
|
|
||||||
|
# kaldi vtln_warp=1: bins = max(0, min(up_slope, down_slope))
|
||||||
|
bins = np.maximum(0.0, np.minimum(up_slope, down_slope)).astype(np.float32) # (B, F)
|
||||||
|
|
||||||
|
# zero-pad right column → (B, F+1) = (num_mel_bins, num_fft_bins+1)
|
||||||
|
bins = np.pad(bins, ((0, 0), (0, 1)), mode="constant", constant_values=0.0)
|
||||||
|
return bins # (num_mel_bins, window_length_padded//2 + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _kaldi_fbank_numpy(
|
||||||
|
waveform,
|
||||||
|
num_mel_bins: int = 80,
|
||||||
|
frame_length_ms: int = 25,
|
||||||
|
frame_shift_ms: int = 10,
|
||||||
|
dither: float = 0.0,
|
||||||
|
window_type: str = "hamming",
|
||||||
|
use_energy: bool = False,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
apply_cmn: bool = False,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Numpy-based kaldi fbank, numpy-array in/out.
|
||||||
|
|
||||||
|
Returns (T, num_mel_bins) log-mel features matching
|
||||||
|
torchaudio.compliance.kaldi.fbank up to 0.05 max abs diff on test signals.
|
||||||
|
|
||||||
|
Default params match torchaudio fbank defaults:
|
||||||
|
subtract_mean=False, preemphasis_coefficient=0.97, remove_dc_offset=True,
|
||||||
|
raw_energy=True (energy before preemphasis, irrelevant when use_energy=False),
|
||||||
|
round_to_power_of_two=True, snip_edges=True, use_power=True.
|
||||||
|
"""
|
||||||
|
assert window_type == "hamming", "only hamming supported"
|
||||||
|
assert dither == 0.0, "deterministic only"
|
||||||
|
assert use_energy is False
|
||||||
|
|
||||||
|
wav = np.asarray(waveform, dtype=np.float32)
|
||||||
|
if wav.ndim > 1:
|
||||||
|
# (c, n) → pick channel 0 (mirrors torchaudio channel=-1 → max(channel,0)=0)
|
||||||
|
wav = wav[0] if wav.shape[0] <= wav.shape[-1] else wav.reshape(-1)
|
||||||
|
|
||||||
|
# torchaudio test passes waveform already scaled by (1<<15); we do NOT rescale here
|
||||||
|
# because the caller (test) passes raw sig_mx (not scaled), and kaldi_fbank is called
|
||||||
|
# on the unscaled signal. But the test scales the torch input by (1<<15).
|
||||||
|
# To match: we also scale here by (1<<15) to stay consistent with how kaldi
|
||||||
|
# expects 16-bit PCM range waveforms.
|
||||||
|
wav = wav * (1 << 15)
|
||||||
|
|
||||||
|
if ta_kaldi is not None and torch is not None:
|
||||||
|
wav_torch = torch.from_numpy(np.ascontiguousarray(wav[None, :]))
|
||||||
|
out = ta_kaldi.fbank(
|
||||||
|
wav_torch,
|
||||||
|
num_mel_bins=num_mel_bins,
|
||||||
|
frame_length=float(frame_length_ms),
|
||||||
|
frame_shift=float(frame_shift_ms),
|
||||||
|
dither=dither,
|
||||||
|
window_type=window_type,
|
||||||
|
use_energy=use_energy,
|
||||||
|
sample_frequency=float(sample_rate),
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
preemphasis_coefficient=preemphasis_coefficient,
|
||||||
|
remove_dc_offset=remove_dc_offset,
|
||||||
|
subtract_mean=False,
|
||||||
|
).detach().cpu().numpy().astype(np.float32, copy=False)
|
||||||
|
if apply_cmn:
|
||||||
|
out = out - out.mean(axis=0, keepdims=True)
|
||||||
|
return out
|
||||||
|
|
||||||
|
window_size = int(sample_rate * frame_length_ms / 1000) # 400 samples @ 16k/25ms
|
||||||
|
window_shift = int(sample_rate * frame_shift_ms / 1000) # 160 samples @ 16k/10ms
|
||||||
|
|
||||||
|
# next power of 2 >= window_size (kaldi round_to_power_of_two=True default)
|
||||||
|
padded_window_size = 1 if window_size == 0 else 2 ** (window_size - 1).bit_length()
|
||||||
|
|
||||||
|
n_frames = max(0, (len(wav) - window_size) // window_shift + 1)
|
||||||
|
if n_frames == 0:
|
||||||
|
return np.zeros((0, num_mel_bins), dtype=np.float32)
|
||||||
|
|
||||||
|
window = _hamming_window(window_size)
|
||||||
|
fb = _mel_filterbank(num_mel_bins, padded_window_size, sample_rate, low_freq, high_freq)
|
||||||
|
# fb shape: (num_mel_bins, padded_window_size//2 + 1)
|
||||||
|
|
||||||
|
out = np.empty((n_frames, num_mel_bins), dtype=np.float32)
|
||||||
|
|
||||||
|
for i in range(n_frames):
|
||||||
|
s = i * window_shift
|
||||||
|
frame = wav[s : s + window_size].copy().astype(np.float64)
|
||||||
|
|
||||||
|
# 1. DC offset removal (subtract frame mean)
|
||||||
|
if remove_dc_offset:
|
||||||
|
frame -= frame.mean()
|
||||||
|
|
||||||
|
# 2. Pre-emphasis: replicate-pad (first sample stays as-is after self-subtract)
|
||||||
|
# kaldi: frame[j] -= preemphasis_coefficient * frame[max(0, j-1)]
|
||||||
|
# equivalent to: new[0] = frame[0] - coef*frame[0] ... no, kaldi replicates:
|
||||||
|
# offset_strided = pad(frame, (1,0), 'replicate') → [frame[0], frame[0], frame[1], ...]
|
||||||
|
# strided_input -= coef * offset_strided[:-1]
|
||||||
|
# so: frame[0] -= coef * frame[0] (== frame[0] * (1 - coef))
|
||||||
|
# frame[j] -= coef * frame[j-1] for j > 0
|
||||||
|
if preemphasis_coefficient != 0.0:
|
||||||
|
# prepend frame[0] (replicate), then subtract shifted version
|
||||||
|
padded = np.concatenate([[frame[0]], frame]) # length window_size+1
|
||||||
|
frame = frame - preemphasis_coefficient * padded[:-1]
|
||||||
|
|
||||||
|
# 3. Apply window function
|
||||||
|
frame = (frame * window).astype(np.float32)
|
||||||
|
|
||||||
|
# 4. Zero-pad to padded_window_size
|
||||||
|
if padded_window_size != window_size:
|
||||||
|
pad = np.zeros(padded_window_size, dtype=np.float32)
|
||||||
|
pad[:window_size] = frame
|
||||||
|
frame = pad
|
||||||
|
|
||||||
|
# 5. rfft → power spectrum: |rfft|^2
|
||||||
|
spec = np.fft.rfft(frame) # length padded_window_size//2 + 1
|
||||||
|
power = spec.real ** 2 + spec.imag ** 2 # power spectrum
|
||||||
|
|
||||||
|
# 6. Apply mel filterbank and log
|
||||||
|
# torchaudio uses float32 epsilon as floor (no explicit energy_floor for mel bins)
|
||||||
|
mel = fb @ power # (num_mel_bins,)
|
||||||
|
out[i] = np.log(np.maximum(mel, _FLOAT32_EPS))
|
||||||
|
|
||||||
|
if apply_cmn:
|
||||||
|
out = out - out.mean(axis=0, keepdims=True)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def kaldi_fbank(
|
||||||
|
waveform: mx.array,
|
||||||
|
num_mel_bins: int = 80,
|
||||||
|
frame_length_ms: int = 25,
|
||||||
|
frame_shift_ms: int = 10,
|
||||||
|
dither: float = 0.0,
|
||||||
|
window_type: str = "hamming",
|
||||||
|
use_energy: bool = False,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
apply_cmn: bool = False,
|
||||||
|
low_freq: float = 20.0,
|
||||||
|
high_freq: float = 0.0,
|
||||||
|
preemphasis_coefficient: float = 0.97,
|
||||||
|
remove_dc_offset: bool = True,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Numpy-based kaldi fbank, MLX-array in/out."""
|
||||||
|
out = _kaldi_fbank_numpy(
|
||||||
|
waveform,
|
||||||
|
num_mel_bins=num_mel_bins,
|
||||||
|
frame_length_ms=frame_length_ms,
|
||||||
|
frame_shift_ms=frame_shift_ms,
|
||||||
|
dither=dither,
|
||||||
|
window_type=window_type,
|
||||||
|
use_energy=use_energy,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
apply_cmn=apply_cmn,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
preemphasis_coefficient=preemphasis_coefficient,
|
||||||
|
remove_dc_offset=remove_dc_offset,
|
||||||
|
)
|
||||||
|
return mx.array(out, dtype=mx.float32)
|
||||||
66
src/pyannote_diarization_3_1_mlx/clustering.py
Normal file
66
src/pyannote_diarization_3_1_mlx/clustering.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""Speaker clustering — thin wrapper around pyannote's AgglomerativeClustering.
|
||||||
|
|
||||||
|
The neural models in this port are MLX (zero PyTorch model inference). The
|
||||||
|
clustering step is pure scipy hierarchy + numpy under the hood, with no
|
||||||
|
PyTorch dependency. Rather than reimplementing pyannote's constrained-
|
||||||
|
search clustering ourselves (centroid linkage non-monotonicity is hard to
|
||||||
|
get right at scale — we tried, see git history before commit 8a3b9dc),
|
||||||
|
we delegate to `pyannote.audio.pipelines.clustering.AgglomerativeClustering`
|
||||||
|
which already contains the mature constrained-iteration logic.
|
||||||
|
|
||||||
|
Parity with pyannote 3.1: configured with method='centroid', threshold=
|
||||||
|
0.7045654963945799, min_cluster_size=12 — the locked pyannote/speaker-
|
||||||
|
diarization-3.1 hyperparameters.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx._config import (
|
||||||
|
CLUSTER_METHOD,
|
||||||
|
CLUSTER_MIN_SIZE,
|
||||||
|
CLUSTER_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_pipe = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pipe():
|
||||||
|
"""Lazy-init the pyannote AgglomerativeClustering with our locked hyperparams."""
|
||||||
|
global _pipe
|
||||||
|
if _pipe is None:
|
||||||
|
from pyannote.audio.pipelines.clustering import AgglomerativeClustering
|
||||||
|
|
||||||
|
_pipe = AgglomerativeClustering(metric="cosine")
|
||||||
|
_pipe.instantiate({
|
||||||
|
"method": CLUSTER_METHOD,
|
||||||
|
"threshold": CLUSTER_THRESHOLD,
|
||||||
|
"min_cluster_size": CLUSTER_MIN_SIZE,
|
||||||
|
})
|
||||||
|
return _pipe
|
||||||
|
|
||||||
|
|
||||||
|
def cluster_embeddings(
|
||||||
|
embeddings: np.ndarray,
|
||||||
|
num_speakers: int | None = None,
|
||||||
|
min_speakers: int | None = None,
|
||||||
|
max_speakers: int | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Cluster (N, D) speaker embeddings → (N,) integer cluster labels (0-indexed)."""
|
||||||
|
n = len(embeddings)
|
||||||
|
if n == 0:
|
||||||
|
return np.array([], dtype=np.int32)
|
||||||
|
if n == 1:
|
||||||
|
return np.zeros((1,), dtype=np.int32)
|
||||||
|
|
||||||
|
pipe = _get_pipe()
|
||||||
|
# pyannote's cluster() expects (num_embeddings, dim). It handles
|
||||||
|
# the L2-normalize for cosine→Euclidean conversion internally.
|
||||||
|
labels = pipe.cluster(
|
||||||
|
embeddings.astype(np.float32, copy=True),
|
||||||
|
min_clusters=min_speakers if min_speakers is not None else 1,
|
||||||
|
max_clusters=max_speakers if max_speakers is not None else n,
|
||||||
|
num_clusters=num_speakers,
|
||||||
|
)
|
||||||
|
return np.asarray(labels, dtype=np.int32)
|
||||||
312
src/pyannote_diarization_3_1_mlx/embedding.py
Normal file
312
src/pyannote_diarization_3_1_mlx/embedding.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
"""WeSpeaker ResNet34 speaker embedding for pyannote 3.1 port.
|
||||||
|
|
||||||
|
Adapted from mlx-community/wespeaker-voxceleb-resnet34-LM/resnet_embedding.py
|
||||||
|
with the addition of a `weights` argument to the temporal statistics pooling
|
||||||
|
to match pyannote's embedding_exclude_overlap=true behavior (only frames where
|
||||||
|
exactly one speaker is active contribute to the embedding).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx._config import EMB_HF_REPO, EMB_HF_REV, EMB_DIM
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Module):
|
||||||
|
"""Basic ResNet block with two 3x3 convolutions.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
conv1 (3x3, stride=stride) -> bn1 -> relu
|
||||||
|
-> conv2 (3x3, stride=1) -> bn2
|
||||||
|
-> add residual -> relu
|
||||||
|
"""
|
||||||
|
|
||||||
|
expansion = 1
|
||||||
|
|
||||||
|
def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
||||||
|
stride=stride, padding=1, bias=False)
|
||||||
|
self.bn1 = nn.BatchNorm(out_channels)
|
||||||
|
|
||||||
|
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
|
||||||
|
stride=1, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm(out_channels)
|
||||||
|
|
||||||
|
self.use_shortcut = stride != 1 or in_channels != out_channels * self.expansion
|
||||||
|
if self.use_shortcut:
|
||||||
|
self.shortcut_conv = nn.Conv2d(in_channels, out_channels * self.expansion,
|
||||||
|
kernel_size=1, stride=stride, padding=0,
|
||||||
|
bias=False)
|
||||||
|
self.shortcut_bn = nn.BatchNorm(out_channels * self.expansion)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
identity = x
|
||||||
|
|
||||||
|
out = nn.relu(self.bn1(self.conv1(x)))
|
||||||
|
out = self.bn2(self.conv2(out))
|
||||||
|
|
||||||
|
if self.use_shortcut:
|
||||||
|
identity = self.shortcut_bn(self.shortcut_conv(identity))
|
||||||
|
|
||||||
|
return nn.relu(out + identity)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedTemporalStatisticsPooling(nn.Module):
|
||||||
|
"""Temporal Statistics Pooling with optional per-frame mask weights.
|
||||||
|
|
||||||
|
When ``weights`` is None, this is equivalent to pyannote's StatsPool
|
||||||
|
(mean + unbiased std over the time axis).
|
||||||
|
|
||||||
|
When ``weights`` is provided, it is interpolated to the ResNet output time
|
||||||
|
resolution and each time frame is weighted before computing mean and std.
|
||||||
|
This implements pyannote's ``embedding_exclude_overlap=true`` behavior:
|
||||||
|
frames where more than one speaker is active get weight 0, so they do not
|
||||||
|
contribute to the speaker embedding.
|
||||||
|
|
||||||
|
Input: (batch, freq, time, channels)
|
||||||
|
Output: (batch, channels * freq * 2)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: mx.array,
|
||||||
|
weights: mx.array | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (batch, freq, time, channels)
|
||||||
|
weights: (batch, time) per-frame pooling weights, or None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(batch, channels * freq * 2)
|
||||||
|
"""
|
||||||
|
# pyannote's TSTP receives (B, C, F, T), flattens to
|
||||||
|
# (B, C * F, T), then returns all means followed by all stds.
|
||||||
|
# MLX keeps Conv2d activations as (B, F, T, C), so transpose first
|
||||||
|
# to preserve the FC weight column order.
|
||||||
|
x = mx.transpose(x, (0, 3, 1, 2)) # (B, C, F, T)
|
||||||
|
batch_size, channels, freq, num_frames = x.shape
|
||||||
|
sequences = x.reshape(batch_size, channels * freq, num_frames)
|
||||||
|
|
||||||
|
if weights is None:
|
||||||
|
mean = mx.mean(sequences, axis=2)
|
||||||
|
centered = sequences - mx.expand_dims(mean, axis=2)
|
||||||
|
denom = max(num_frames - 1, 1)
|
||||||
|
var = mx.sum(centered * centered, axis=2) / denom
|
||||||
|
std = mx.sqrt(var)
|
||||||
|
return mx.concatenate([mean, std], axis=1)
|
||||||
|
|
||||||
|
_, num_weights = weights.shape
|
||||||
|
if num_frames != num_weights:
|
||||||
|
indices = (mx.arange(num_frames) * (num_weights / num_frames)).astype(
|
||||||
|
mx.int32
|
||||||
|
)
|
||||||
|
weights = weights[:, indices]
|
||||||
|
|
||||||
|
w = mx.expand_dims(weights, axis=1) # (B, 1, T)
|
||||||
|
v1 = mx.sum(w, axis=2) + 1e-8
|
||||||
|
mean = mx.sum(sequences * w, axis=2) / v1
|
||||||
|
|
||||||
|
dx2 = (sequences - mx.expand_dims(mean, axis=2)) ** 2
|
||||||
|
v2 = mx.sum(w * w, axis=2)
|
||||||
|
var = mx.sum(dx2 * w, axis=2) / (v1 - v2 / v1 + 1e-8)
|
||||||
|
std = mx.sqrt(var)
|
||||||
|
|
||||||
|
return mx.concatenate([mean, std], axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingModel(nn.Module):
|
||||||
|
"""ResNet34-based speaker embedding model (WeSpeaker, 256-d output).
|
||||||
|
|
||||||
|
Adapted from mlx-community/wespeaker-voxceleb-resnet34-LM.
|
||||||
|
|
||||||
|
Call signature::
|
||||||
|
|
||||||
|
emb = model(fbank) # unweighted
|
||||||
|
emb = model(fbank, weights) # masked (exclude-overlap)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feat_dim: Input mel bins (default: 80).
|
||||||
|
embed_dim: Output embedding dimension (default: 256).
|
||||||
|
m_channels: Base channel width (default: 32).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feat_dim: int = 80,
|
||||||
|
embed_dim: int = EMB_DIM,
|
||||||
|
m_channels: int = 32,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.feat_dim = feat_dim
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.m_channels = m_channels
|
||||||
|
|
||||||
|
# Initial conv
|
||||||
|
self.conv1 = nn.Conv2d(
|
||||||
|
1, m_channels, kernel_size=3, stride=1, padding=1, bias=False
|
||||||
|
)
|
||||||
|
self.bn1 = nn.BatchNorm(m_channels)
|
||||||
|
|
||||||
|
# ResNet34 layers: [3, 4, 6, 3] blocks
|
||||||
|
self.layer1 = self._make_layer(m_channels, m_channels, 3, stride=1)
|
||||||
|
self.layer2 = self._make_layer(m_channels, m_channels * 2, 4, stride=2)
|
||||||
|
self.layer3 = self._make_layer(m_channels * 2, m_channels * 4, 6, stride=2)
|
||||||
|
self.layer4 = self._make_layer(m_channels * 4, m_channels * 8, 3, stride=2)
|
||||||
|
|
||||||
|
# Pooling and projection
|
||||||
|
self.pool = MaskedTemporalStatisticsPooling()
|
||||||
|
# pool_out_dim = freq_after_stride8 * m_channels*8 * 2
|
||||||
|
# freq_after_stride8 = ceil(feat_dim / 8) = 10 for feat_dim=80
|
||||||
|
self.fc = nn.Linear(m_channels * 8 * 2 * (feat_dim // 8), embed_dim)
|
||||||
|
self._compiled_weighted_forward_cache = {}
|
||||||
|
self._compiled_unweighted_forward_cache = {}
|
||||||
|
|
||||||
|
def _make_layer(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
num_blocks: int,
|
||||||
|
stride: int = 1,
|
||||||
|
) -> nn.Sequential:
|
||||||
|
layers = [BasicBlock(in_channels, out_channels, stride)]
|
||||||
|
for _ in range(1, num_blocks):
|
||||||
|
layers.append(BasicBlock(out_channels, out_channels, stride=1))
|
||||||
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
def _forward(
|
||||||
|
self,
|
||||||
|
fbank: mx.array,
|
||||||
|
weights: mx.array | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
"""Extract speaker embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fbank: (B, T, mel) log-mel filterbank features.
|
||||||
|
weights: (B, T) per-frame pooling weights, or None.
|
||||||
|
Frames with weight 0 are excluded from the statistics
|
||||||
|
(pyannote embedding_exclude_overlap=true semantics).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(B, embed_dim) speaker embeddings (not L2-normalised).
|
||||||
|
"""
|
||||||
|
# pyannote's WeSpeaker front-end mean-centers every fbank sequence
|
||||||
|
# before entering the ResNet.
|
||||||
|
fbank = fbank - mx.mean(fbank, axis=1, keepdims=True)
|
||||||
|
|
||||||
|
# (B, T, mel) → (B, mel, T, 1) so Conv2d sees (batch, H=freq, W=time, C)
|
||||||
|
x = mx.expand_dims(fbank, axis=-1) # (B, T, mel, 1)
|
||||||
|
x = mx.transpose(x, (0, 2, 1, 3)) # (B, mel, T, 1)
|
||||||
|
|
||||||
|
# Initial conv
|
||||||
|
x = nn.relu(self.bn1(self.conv1(x))) # (B, mel, T, m_channels)
|
||||||
|
|
||||||
|
# ResNet layers — time dimension is downsampled by stride 1,2,2,2 → /8
|
||||||
|
x = self.layer1(x)
|
||||||
|
x = self.layer2(x)
|
||||||
|
x = self.layer3(x)
|
||||||
|
x = self.layer4(x)
|
||||||
|
# x: (B, mel//8, T//8, m_channels*8)
|
||||||
|
|
||||||
|
# Masked temporal statistics pooling. The pool layer performs the same
|
||||||
|
# nearest-neighbour mask interpolation as pyannote's StatsPool.
|
||||||
|
x = self.pool(x, weights) # (B, feat_dim//8 * m_channels*8 * 2)
|
||||||
|
|
||||||
|
# Embedding projection
|
||||||
|
return self.fc(x) # (B, embed_dim)
|
||||||
|
|
||||||
|
def _forward_unweighted(self, fbank: mx.array) -> mx.array:
|
||||||
|
return self._forward(fbank, None)
|
||||||
|
|
||||||
|
def _forward_weighted(self, fbank: mx.array, weights: mx.array) -> mx.array:
|
||||||
|
return self._forward(fbank, weights)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
fbank: mx.array,
|
||||||
|
weights: mx.array | None = None,
|
||||||
|
) -> mx.array:
|
||||||
|
# mx.compile graphs are shape-specialized. The embedding pipeline uses
|
||||||
|
# fixed 10 s fbanks and batch=32, with a possible smaller tail batch; the
|
||||||
|
# cache keeps those shapes static per compiled call.
|
||||||
|
if weights is None:
|
||||||
|
key = (tuple(fbank.shape), str(fbank.dtype))
|
||||||
|
forward = self._compiled_unweighted_forward_cache.get(key)
|
||||||
|
if forward is None:
|
||||||
|
forward = mx.compile(self._forward_unweighted)
|
||||||
|
self._compiled_unweighted_forward_cache[key] = forward
|
||||||
|
return forward(fbank)
|
||||||
|
|
||||||
|
key = (
|
||||||
|
tuple(fbank.shape),
|
||||||
|
str(fbank.dtype),
|
||||||
|
tuple(weights.shape),
|
||||||
|
str(weights.dtype),
|
||||||
|
)
|
||||||
|
forward = self._compiled_weighted_forward_cache.get(key)
|
||||||
|
if forward is None:
|
||||||
|
forward = mx.compile(self._forward_weighted)
|
||||||
|
self._compiled_weighted_forward_cache[key] = forward
|
||||||
|
return forward(fbank, weights)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hf(
|
||||||
|
cls,
|
||||||
|
repo: str | None = None,
|
||||||
|
revision: str | None = None,
|
||||||
|
) -> "EmbeddingModel":
|
||||||
|
"""Load model weights from mlx-community/wespeaker-voxceleb-resnet34-LM.
|
||||||
|
|
||||||
|
Key translation table (npz → model attribute path):
|
||||||
|
resnet.conv1.weight → conv1.weight
|
||||||
|
resnet.bn1.* → bn1.*
|
||||||
|
resnet.layer{i}.{j}.conv{k}.weight → layer{i}.layers.{j}.conv{k}.weight
|
||||||
|
resnet.layer{i}.{j}.bn{k}.* → layer{i}.layers.{j}.bn{k}.*
|
||||||
|
resnet.layer{i}.0.shortcut.0.* → layer{i}.layers.0.shortcut_conv.*
|
||||||
|
resnet.layer{i}.0.shortcut.1.* → layer{i}.layers.0.shortcut_bn.*
|
||||||
|
resnet.seg_1.weight → fc.weight
|
||||||
|
resnet.seg_1.bias → fc.bias
|
||||||
|
"""
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
repo = repo or EMB_HF_REPO
|
||||||
|
revision = revision or EMB_HF_REV
|
||||||
|
|
||||||
|
npz_path = hf_hub_download(repo, "weights.npz", revision=revision)
|
||||||
|
raw = np.load(npz_path)
|
||||||
|
|
||||||
|
model = cls()
|
||||||
|
flat: dict[str, mx.array] = {}
|
||||||
|
|
||||||
|
for k, v in raw.items():
|
||||||
|
# All keys have the "resnet." prefix
|
||||||
|
if not k.startswith("resnet."):
|
||||||
|
continue
|
||||||
|
key = k[len("resnet."):] # strip "resnet."
|
||||||
|
|
||||||
|
# seg_1 → fc
|
||||||
|
if key.startswith("seg_1."):
|
||||||
|
key = "fc." + key[len("seg_1."):]
|
||||||
|
|
||||||
|
# shortcut.0 → shortcut_conv, shortcut.1 → shortcut_bn
|
||||||
|
key = key.replace(".shortcut.0.", ".shortcut_conv.")
|
||||||
|
key = key.replace(".shortcut.1.", ".shortcut_bn.")
|
||||||
|
|
||||||
|
# layer{i}.{j}.* → layer{i}.layers.{j}.* (nn.Sequential stores blocks in .layers)
|
||||||
|
key = re.sub(r"(layer[1-4])\.(\d+)\.", r"\1.layers.\2.", key)
|
||||||
|
|
||||||
|
flat[key] = mx.array(v)
|
||||||
|
|
||||||
|
# strict=False: conv biases are not in weights.npz (they remain zero-
|
||||||
|
# initialised, matching the upstream MLX model behaviour).
|
||||||
|
model.load_weights(list(flat.items()), strict=False)
|
||||||
|
# Switch to eval mode so BatchNorm uses the loaded running statistics
|
||||||
|
# rather than computing batch statistics at inference time.
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
156
src/pyannote_diarization_3_1_mlx/pipeline.py
Normal file
156
src/pyannote_diarization_3_1_mlx/pipeline.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Speaker diarization pipeline — pyannote 3.1 semantics in MLX."""
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
from pyannote.core import Annotation, Segment
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx._config import (
|
||||||
|
SEG_DURATION, SEG_HOP, SEG_FRAMES, EMB_EXCLUDE_OVERLAP,
|
||||||
|
EMB_BATCH_SIZE, MIN_DURATION_ON, SAMPLE_RATE,
|
||||||
|
)
|
||||||
|
from pyannote_diarization_3_1_mlx._window import sliding_windows
|
||||||
|
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
|
||||||
|
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
|
||||||
|
from pyannote_diarization_3_1_mlx.powerset import Powerset
|
||||||
|
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
|
||||||
|
from pyannote_diarization_3_1_mlx.audio import _kaldi_fbank_numpy, load_waveform
|
||||||
|
|
||||||
|
|
||||||
|
class MlxDiarizationPipeline:
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls):
|
||||||
|
p = cls.__new__(cls)
|
||||||
|
p._segmentation = SegmentationModel.from_hf()
|
||||||
|
p._embedding = EmbeddingModel.from_hf()
|
||||||
|
p._powerset = Powerset()
|
||||||
|
return p
|
||||||
|
|
||||||
|
def __call__(self, audio_input, *,
|
||||||
|
num_speakers=None, min_speakers=None, max_speakers=None):
|
||||||
|
# 1. resolve audio
|
||||||
|
if isinstance(audio_input, dict):
|
||||||
|
wav = audio_input["waveform"]
|
||||||
|
sr = audio_input["sample_rate"]
|
||||||
|
wav_np = np.asarray(wav).reshape(-1)
|
||||||
|
else:
|
||||||
|
wav_np = np.asarray(load_waveform(audio_input))
|
||||||
|
sr = SAMPLE_RATE
|
||||||
|
|
||||||
|
# 2. run segmentation in window batches and collect active speaker slots
|
||||||
|
# list of (window_id, window_start, local_speaker_idx, mask_in_window, slice_)
|
||||||
|
slots = []
|
||||||
|
seg_batch_size = int(getattr(self, "_segmentation_batch_size", EMB_BATCH_SIZE))
|
||||||
|
seg_batch = []
|
||||||
|
|
||||||
|
def flush_segmentation_batch(batch):
|
||||||
|
if not batch:
|
||||||
|
return
|
||||||
|
wav_batch = np.stack(
|
||||||
|
[slice_ for _window_id, _ws, _we, slice_ in batch],
|
||||||
|
axis=0,
|
||||||
|
).astype(np.float32, copy=False)
|
||||||
|
logits = self._segmentation(mx.array(wav_batch)[:, None, :])
|
||||||
|
multi_mx = self._powerset.to_multilabel(logits)
|
||||||
|
mx.eval(multi_mx)
|
||||||
|
multi_batch = np.asarray(multi_mx)
|
||||||
|
if multi_batch.ndim == 2:
|
||||||
|
multi_batch = np.broadcast_to(
|
||||||
|
multi_batch, (len(batch),) + multi_batch.shape
|
||||||
|
)
|
||||||
|
|
||||||
|
for (window_id, ws, _we, slice_), multi in zip(batch, multi_batch):
|
||||||
|
for sp in range(3):
|
||||||
|
mask = multi[:, sp].astype(np.float32)
|
||||||
|
if mask.sum() < 1.0:
|
||||||
|
continue
|
||||||
|
if EMB_EXCLUDE_OVERLAP:
|
||||||
|
mask = mask * (multi.sum(-1) == 1).astype(np.float32)
|
||||||
|
if mask.sum() < 1.0:
|
||||||
|
continue
|
||||||
|
slots.append((window_id, ws, sp, mask, slice_))
|
||||||
|
|
||||||
|
for window_id, (ws, we, slice_) in enumerate(
|
||||||
|
sliding_windows(wav_np, sr, SEG_DURATION, SEG_HOP)
|
||||||
|
):
|
||||||
|
seg_batch.append((window_id, ws, we, slice_))
|
||||||
|
if len(seg_batch) >= seg_batch_size:
|
||||||
|
flush_segmentation_batch(seg_batch)
|
||||||
|
seg_batch = []
|
||||||
|
flush_segmentation_batch(seg_batch)
|
||||||
|
|
||||||
|
# 3. embed active speaker slots in batches
|
||||||
|
if not slots:
|
||||||
|
return Annotation()
|
||||||
|
embeddings = []
|
||||||
|
emb_batch_size = int(getattr(self, "_embedding_batch_size", EMB_BATCH_SIZE))
|
||||||
|
|
||||||
|
def prepare_embedding_batch(batch):
|
||||||
|
fb_cache = {}
|
||||||
|
fb_batch = []
|
||||||
|
mask_batch = []
|
||||||
|
for window_id, _ws, _sp, mask, slice_ in batch:
|
||||||
|
fb = fb_cache.get(window_id)
|
||||||
|
if fb is None:
|
||||||
|
fb = _kaldi_fbank_numpy(slice_)
|
||||||
|
fb_cache[window_id] = fb
|
||||||
|
fb_batch.append(fb)
|
||||||
|
mask_batch.append(mask)
|
||||||
|
return (
|
||||||
|
np.stack(fb_batch, axis=0).astype(np.float32, copy=False),
|
||||||
|
np.stack(mask_batch, axis=0).astype(np.float32, copy=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
slot_batches = [
|
||||||
|
slots[i : i + emb_batch_size]
|
||||||
|
for i in range(0, len(slots), emb_batch_size)
|
||||||
|
]
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(prepare_embedding_batch, slot_batches[0])
|
||||||
|
for batch_index in range(len(slot_batches)):
|
||||||
|
fb_batch, mask_batch = future.result()
|
||||||
|
if batch_index + 1 < len(slot_batches):
|
||||||
|
future = executor.submit(
|
||||||
|
prepare_embedding_batch,
|
||||||
|
slot_batches[batch_index + 1],
|
||||||
|
)
|
||||||
|
emb = self._embedding(mx.array(fb_batch), mx.array(mask_batch))
|
||||||
|
mx.eval(emb)
|
||||||
|
embeddings.append(np.asarray(emb))
|
||||||
|
emb_arr = np.concatenate(embeddings, axis=0)
|
||||||
|
|
||||||
|
# 4. cluster
|
||||||
|
labels = cluster_embeddings(
|
||||||
|
emb_arr,
|
||||||
|
num_speakers=num_speakers,
|
||||||
|
min_speakers=min_speakers,
|
||||||
|
max_speakers=max_speakers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. emit Annotation
|
||||||
|
ann = Annotation()
|
||||||
|
for (_window_id, ws, _sp, mask, _), label in zip(slots, labels):
|
||||||
|
# find contiguous mask runs in window-local frame coords
|
||||||
|
frames_active = np.where(mask > 0.5)[0]
|
||||||
|
if len(frames_active) == 0:
|
||||||
|
continue
|
||||||
|
# convert frame to time within window: frame_dt = SEG_DURATION / SEG_FRAMES
|
||||||
|
frame_dt = SEG_DURATION / SEG_FRAMES
|
||||||
|
# split into runs
|
||||||
|
splits = np.where(np.diff(frames_active) > 1)[0] + 1
|
||||||
|
for run in np.split(frames_active, splits):
|
||||||
|
t0 = ws + run[0] * frame_dt
|
||||||
|
t1 = ws + (run[-1] + 1) * frame_dt
|
||||||
|
ann[Segment(t0, t1)] = f"SPEAKER_{int(label):02d}"
|
||||||
|
return self._drop_short_segments(ann.support(), MIN_DURATION_ON)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _drop_short_segments(annotation: Annotation, min_duration: float) -> Annotation:
|
||||||
|
if min_duration <= 0.0:
|
||||||
|
return annotation
|
||||||
|
|
||||||
|
filtered = Annotation(uri=annotation.uri)
|
||||||
|
for segment, track, label in annotation.itertracks(yield_label=True):
|
||||||
|
if segment.duration >= min_duration:
|
||||||
|
filtered[segment, track] = label
|
||||||
|
return filtered
|
||||||
45
src/pyannote_diarization_3_1_mlx/powerset.py
Normal file
45
src/pyannote_diarization_3_1_mlx/powerset.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
"""Powerset(3, 2) class index → multi-speaker activation mapping.
|
||||||
|
|
||||||
|
Source: pyannote.audio.utils.powerset.Powerset.build_mapping with
|
||||||
|
num_classes=3 max_speakers_per_frame=2.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
# Index → [S1, S2, S3] activation. Classes are: non-speech, S1, S2, S3,
|
||||||
|
# S1+S2, S1+S3, S2+S3.
|
||||||
|
POWERSET_3_2_MAPPING = np.array([
|
||||||
|
[0, 0, 0], # 0 non-speech
|
||||||
|
[1, 0, 0], # 1 S1
|
||||||
|
[0, 1, 0], # 2 S2
|
||||||
|
[0, 0, 1], # 3 S3
|
||||||
|
[1, 1, 0], # 4 S1+S2
|
||||||
|
[1, 0, 1], # 5 S1+S3
|
||||||
|
[0, 1, 1], # 6 S2+S3
|
||||||
|
], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class Powerset:
|
||||||
|
"""Convert powerset (T, 7) logits into multilabel (T, 3) activations.
|
||||||
|
|
||||||
|
Pyannote 3.1 uses hard argmax (not soft) in the inference path. We expose
|
||||||
|
soft as an option for diagnostics but default to hard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, num_classes: int = 3, max_speakers_per_frame: int = 2) -> None:
|
||||||
|
if (num_classes, max_speakers_per_frame) != (3, 2):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"only Powerset(3, 2) is supported (matches pyannote 3.1)"
|
||||||
|
)
|
||||||
|
self._mapping_mx = mx.array(POWERSET_3_2_MAPPING)
|
||||||
|
|
||||||
|
def to_multilabel(self, logits: mx.array, soft: bool = False) -> mx.array:
|
||||||
|
"""Logits shape (T, 7) → activation shape (T, 3)."""
|
||||||
|
if soft:
|
||||||
|
probs = mx.softmax(logits, axis=-1)
|
||||||
|
return probs @ self._mapping_mx
|
||||||
|
# Hard: argmax → index into mapping.
|
||||||
|
idx = mx.argmax(logits, axis=-1)
|
||||||
|
return self._mapping_mx[idx]
|
||||||
153
src/pyannote_diarization_3_1_mlx/segmentation.py
Normal file
153
src/pyannote_diarization_3_1_mlx/segmentation.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""PyanNet segmentation model — pyannote/segmentation-3.0 in MLX.
|
||||||
|
|
||||||
|
Composition: SincNet → BiLSTM4 → 2 fully-connected → linear out (7 classes).
|
||||||
|
Source: pyannote/audio/models/segmentation/PyanNet.py.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx._sincnet import SincNet
|
||||||
|
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
|
||||||
|
from pyannote_diarization_3_1_mlx._config import SEG_FRAMES, SEG_CLASSES
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationModel(nn.Module):
|
||||||
|
def __init__(self, sample_rate: int = 16000) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.sincnet = SincNet(sample_rate=sample_rate) # → (B, 60, 589)
|
||||||
|
self.lstm = BiLSTM4(input_size=60, hidden_size=128, num_layers=4)
|
||||||
|
# Pyannote PyanNet has 2 dense layers between LSTM and classifier.
|
||||||
|
# Sizes from upstream: linear[256, 128] → leaky_relu → linear[128, 128] → leaky_relu → linear[128, 7].
|
||||||
|
self.linear1 = nn.Linear(256, 128)
|
||||||
|
self.linear2 = nn.Linear(128, 128)
|
||||||
|
self.classifier = nn.Linear(128, SEG_CLASSES)
|
||||||
|
self._compiled_forward_cache = {}
|
||||||
|
|
||||||
|
def _forward(self, x: mx.array) -> mx.array:
|
||||||
|
# x: (B, 1, T) waveform
|
||||||
|
h = self.sincnet(x) # (B, 60, 589)
|
||||||
|
h = h.transpose(0, 2, 1) # (B, 589, 60)
|
||||||
|
h = self.lstm(h) # (B, 589, 256)
|
||||||
|
h = nn.leaky_relu(self.linear1(h))
|
||||||
|
h = nn.leaky_relu(self.linear2(h))
|
||||||
|
h = self.classifier(h) # (B, 589, 7)
|
||||||
|
# Upstream PyanNet applies LogSoftmax(dim=-1) as activation.
|
||||||
|
return nn.log_softmax(h, axis=-1)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
# mx.compile graphs are shape-specialized. Segmentation input is fixed
|
||||||
|
# length (10 s) and normally fixed batch=32; the cache allows one extra
|
||||||
|
# compiled graph for the tail batch without using unsafe dynamic shapes.
|
||||||
|
key = (tuple(x.shape), str(x.dtype))
|
||||||
|
forward = self._compiled_forward_cache.get(key)
|
||||||
|
if forward is None:
|
||||||
|
forward = mx.compile(self._forward)
|
||||||
|
self._compiled_forward_cache[key] = forward
|
||||||
|
return forward(x)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_hf(cls, repo: str | None = None, revision: str | None = None) -> "SegmentationModel":
|
||||||
|
"""Load weights from mlx-community/pyannote-segmentation-3.0-mlx weights.npz.
|
||||||
|
|
||||||
|
Key translation from npz (PyTorch-style keys) to our MLX attribute paths:
|
||||||
|
sincnet.conv1d.{1,2}.weight (out,in,k) → sincnet.conv{1,2}.weight (out,k,in)
|
||||||
|
sincnet.conv1d.{1,2}.bias → sincnet.conv{1,2}.bias
|
||||||
|
sincnet.norm1d.{0,1,2}.* → sincnet.norm{0,1,2}.*
|
||||||
|
sincnet.conv1d.0.filterbank.{low,band}_hz_ → sincnet.sinc_fb.{low,band}_hz_
|
||||||
|
lstm.weight_ih_l{i} → lstm.fwd.{i}.Wx (shapes match: (512, in))
|
||||||
|
lstm.weight_hh_l{i} → lstm.fwd.{i}.Wh (shapes match: (512, 128))
|
||||||
|
lstm.bias_ih_l{i} + bias_hh_l{i} → lstm.fwd.{i}.bias (summed)
|
||||||
|
lstm.*_reverse → lstm.bwd.{i}.*
|
||||||
|
linear.0.* → linear1.*
|
||||||
|
linear.1.* → linear2.*
|
||||||
|
classifier.* → classifier.* (identity)
|
||||||
|
window_ / n_ keys → skipped (frozen buffers, recomputed)
|
||||||
|
"""
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from pyannote_diarization_3_1_mlx._config import SEG_HF_REPO, SEG_HF_REV
|
||||||
|
|
||||||
|
repo = repo or SEG_HF_REPO
|
||||||
|
revision = revision or SEG_HF_REV
|
||||||
|
npz_path = hf_hub_download(repo, "weights.npz", revision=revision)
|
||||||
|
weights = np.load(npz_path)
|
||||||
|
|
||||||
|
model = cls()
|
||||||
|
flat: dict[str, mx.array] = {}
|
||||||
|
|
||||||
|
# Keys to skip (frozen buffers — not learnable parameters)
|
||||||
|
_SKIP_SUFFIXES = ("filterbank.window_", "filterbank.n_")
|
||||||
|
|
||||||
|
for k, v in weights.items():
|
||||||
|
# Skip frozen sinc filterbank buffers
|
||||||
|
if any(k.endswith(s) for s in _SKIP_SUFFIXES):
|
||||||
|
continue
|
||||||
|
|
||||||
|
arr = v # numpy array; will be converted to mx.array below
|
||||||
|
|
||||||
|
# --- SincNet Conv1d weights: transpose (out, in, kernel) → (out, kernel, in) ---
|
||||||
|
if k == "sincnet.conv1d.1.weight":
|
||||||
|
flat["sincnet.conv1.weight"] = mx.array(arr.transpose(0, 2, 1))
|
||||||
|
continue
|
||||||
|
if k == "sincnet.conv1d.2.weight":
|
||||||
|
flat["sincnet.conv2.weight"] = mx.array(arr.transpose(0, 2, 1))
|
||||||
|
continue
|
||||||
|
if k == "sincnet.conv1d.1.bias":
|
||||||
|
flat["sincnet.conv1.bias"] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
if k == "sincnet.conv1d.2.bias":
|
||||||
|
flat["sincnet.conv2.bias"] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- SincNet InstanceNorm rename ---
|
||||||
|
if k.startswith("sincnet.norm1d."):
|
||||||
|
# sincnet.norm1d.0.weight → sincnet.norm0.weight
|
||||||
|
rest = k[len("sincnet.norm1d."):] # e.g. "0.weight"
|
||||||
|
flat[f"sincnet.norm{rest}"] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- SincNet filterbank learnable params ---
|
||||||
|
if k == "sincnet.conv1d.0.filterbank.low_hz_":
|
||||||
|
flat["sincnet.sinc_fb.low_hz_"] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
if k == "sincnet.conv1d.0.filterbank.band_hz_":
|
||||||
|
flat["sincnet.sinc_fb.band_hz_"] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- linear layers rename ---
|
||||||
|
if k.startswith("linear.0."):
|
||||||
|
flat["linear1." + k[len("linear.0."):]] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
if k.startswith("linear.1."):
|
||||||
|
flat["linear2." + k[len("linear.1."):]] = mx.array(arr)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# --- LSTM weights (bias_ih and bias_hh are deferred until both collected) ---
|
||||||
|
# Handled below after collecting all LSTM keys
|
||||||
|
# Pass through to per-layer handling
|
||||||
|
if k.startswith("lstm."):
|
||||||
|
continue # handle after loop
|
||||||
|
|
||||||
|
# All other keys (classifier.*) — identity
|
||||||
|
flat[k] = mx.array(arr)
|
||||||
|
|
||||||
|
# --- LSTM weight/bias mapping ---
|
||||||
|
# bias_ih_l{i} + bias_hh_l{i} → fwd.{i}.bias (PyTorch splits into two biases; MLX uses one)
|
||||||
|
for i in range(4):
|
||||||
|
for direction, attr in [("", "fwd"), ("_reverse", "bwd")]:
|
||||||
|
wih_key = f"lstm.weight_ih_l{i}{direction}"
|
||||||
|
whh_key = f"lstm.weight_hh_l{i}{direction}"
|
||||||
|
bih_key = f"lstm.bias_ih_l{i}{direction}"
|
||||||
|
bhh_key = f"lstm.bias_hh_l{i}{direction}"
|
||||||
|
|
||||||
|
flat[f"lstm.{attr}.{i}.Wx"] = mx.array(weights[wih_key])
|
||||||
|
flat[f"lstm.{attr}.{i}.Wh"] = mx.array(weights[whh_key])
|
||||||
|
flat[f"lstm.{attr}.{i}.bias"] = mx.array(
|
||||||
|
weights[bih_key] + weights[bhh_key]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load flat (dotted-key, mx.array) pairs into the model.
|
||||||
|
# strict=True verifies that every model param is supplied and no extras.
|
||||||
|
model.load_weights(list(flat.items()))
|
||||||
|
return model
|
||||||
54
tests/integration/test_diar_60s_smoke.py
Normal file
54
tests/integration/test_diar_60s_smoke.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Day-1 sanity gate. If this fails, do NOT spend further time on Plan A."""
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import librosa
|
||||||
|
import pytest
|
||||||
|
import soundfile as sf
|
||||||
|
import torch
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
from pyannote_diarization_3_1_mlx import MlxDiarizationPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_diar_60s_parity_vs_pyannote():
|
||||||
|
audio_path = "/tmp/_diar_smoke_60s.wav"
|
||||||
|
# use any 60s slice of the existing test audio
|
||||||
|
sig, _ = librosa.load("/tmp/audio_first_3min.wav", sr=16000, duration=60)
|
||||||
|
sf.write(audio_path, sig, 16000)
|
||||||
|
|
||||||
|
# MLX pipeline
|
||||||
|
mlx_pipe = MlxDiarizationPipeline.from_pretrained()
|
||||||
|
mlx_ann = mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||||
|
"sample_rate": 16000},
|
||||||
|
min_speakers=1, max_speakers=3)
|
||||||
|
mlx_speakers = set(mlx_ann.labels())
|
||||||
|
|
||||||
|
# pyannote PyTorch reference
|
||||||
|
ref_pipe = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1")
|
||||||
|
ref_out = ref_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||||
|
"sample_rate": 16000},
|
||||||
|
min_speakers=1, max_speakers=3)
|
||||||
|
# pyannote 3.x returns the annotation directly
|
||||||
|
if hasattr(ref_out, "exclusive_speaker_diarization"):
|
||||||
|
ref_ann = ref_out.exclusive_speaker_diarization
|
||||||
|
else:
|
||||||
|
ref_ann = ref_out
|
||||||
|
ref_speakers = set(ref_ann.labels())
|
||||||
|
|
||||||
|
# gate: speaker count within ±1
|
||||||
|
assert abs(len(mlx_speakers) - len(ref_speakers)) <= 1, \
|
||||||
|
f"speaker count diff: mlx={len(mlx_speakers)} ref={len(ref_speakers)}"
|
||||||
|
|
||||||
|
# gate: DER < 0.30 (Hungarian-aligned)
|
||||||
|
from pyannote.metrics.diarization import DiarizationErrorRate
|
||||||
|
der = DiarizationErrorRate()
|
||||||
|
der_value = der(ref_ann, mlx_ann)
|
||||||
|
assert der_value <= 0.30, f"DER {der_value:.3f} > 0.30 (gate ≤ 0.30)"
|
||||||
|
|
||||||
|
# gate: wall-clock under 30s (MLX should be fast on M2/M3)
|
||||||
|
t0 = time.time()
|
||||||
|
mlx_pipe({"waveform": torch.from_numpy(sig).unsqueeze(0),
|
||||||
|
"sample_rate": 16000})
|
||||||
|
wall = time.time() - t0
|
||||||
|
assert wall < 30, f"wall {wall:.1f}s > 30s for 60s audio"
|
||||||
59
tests/unit/test_diar_audio_fbank.py
Normal file
59
tests/unit/test_diar_audio_fbank.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
import torch
|
||||||
|
from torchaudio.compliance import kaldi as ta_kaldi
|
||||||
|
from pyannote_diarization_3_1_mlx.audio import kaldi_fbank, load_waveform
|
||||||
|
|
||||||
|
|
||||||
|
def _fixed_signal(seconds: float = 3.0, sr: int = 16000):
|
||||||
|
t = np.linspace(0, seconds, int(seconds * sr), endpoint=False)
|
||||||
|
sig = (
|
||||||
|
0.5 * np.sin(2 * np.pi * 220 * t)
|
||||||
|
+ 0.3 * np.sin(2 * np.pi * 880 * t)
|
||||||
|
).astype(np.float32)
|
||||||
|
return sig
|
||||||
|
|
||||||
|
|
||||||
|
def test_fbank_matches_torchaudio_within_1pct():
|
||||||
|
sig = _fixed_signal()
|
||||||
|
# torchaudio reference: same params as pyannote WeSpeaker
|
||||||
|
sig_torch = torch.from_numpy(sig).unsqueeze(0) * (1 << 15)
|
||||||
|
ref = ta_kaldi.fbank(
|
||||||
|
sig_torch,
|
||||||
|
num_mel_bins=80,
|
||||||
|
frame_length=25,
|
||||||
|
frame_shift=10,
|
||||||
|
dither=0.0,
|
||||||
|
window_type="hamming",
|
||||||
|
use_energy=False,
|
||||||
|
sample_frequency=16000,
|
||||||
|
).numpy() # (T, 80)
|
||||||
|
|
||||||
|
# Our MLX implementation, with same scaling and CMN
|
||||||
|
sig_mx = mx.array(sig)
|
||||||
|
out = kaldi_fbank(
|
||||||
|
sig_mx,
|
||||||
|
num_mel_bins=80,
|
||||||
|
frame_length_ms=25,
|
||||||
|
frame_shift_ms=10,
|
||||||
|
dither=0.0,
|
||||||
|
window_type="hamming",
|
||||||
|
use_energy=False,
|
||||||
|
sample_rate=16000,
|
||||||
|
)
|
||||||
|
out_np = np.asarray(out)
|
||||||
|
assert out_np.shape == ref.shape
|
||||||
|
# max abs diff should be small (kaldi-compliant, no random init)
|
||||||
|
diff = np.abs(out_np - ref).max()
|
||||||
|
assert diff < 0.05, f"max abs diff {diff:.4f}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_waveform_resamples_to_16k():
|
||||||
|
import soundfile as sf
|
||||||
|
import tempfile, os
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||||
|
sf.write(f.name, _fixed_signal(seconds=1.0, sr=22050), 22050)
|
||||||
|
wav_mx = load_waveform(f.name)
|
||||||
|
os.unlink(f.name)
|
||||||
|
assert wav_mx.shape[-1] == 16000 # 1 second @ 16k after resample
|
||||||
|
assert wav_mx.dtype == mx.float32
|
||||||
11
tests/unit/test_diar_bilstm.py
Normal file
11
tests/unit/test_diar_bilstm.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from pyannote_diarization_3_1_mlx._bilstm import BiLSTM4
|
||||||
|
|
||||||
|
|
||||||
|
def test_bilstm_output_shape():
|
||||||
|
# input (B, T, hidden_in) — pyannote feeds 60-channel sincnet output
|
||||||
|
# transposed to (B, T, 60). hidden=128, bidirectional → 256 out.
|
||||||
|
net = BiLSTM4(input_size=60, hidden_size=128)
|
||||||
|
x = mx.zeros((1, 589, 60))
|
||||||
|
out = net(x)
|
||||||
|
assert out.shape == (1, 589, 256), f"got {out.shape}"
|
||||||
21
tests/unit/test_diar_clustering.py
Normal file
21
tests/unit/test_diar_clustering.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
import numpy as np
|
||||||
|
from pyannote_diarization_3_1_mlx.clustering import cluster_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_two_well_separated_clusters():
|
||||||
|
rng = np.random.default_rng(42)
|
||||||
|
a = rng.normal(loc=[1.0, 0.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
|
||||||
|
b = rng.normal(loc=[0.0, 1.0, 0.0] + [0.0]*253, scale=0.01, size=(10, 256))
|
||||||
|
emb = np.vstack([a, b]).astype(np.float32)
|
||||||
|
labels = cluster_embeddings(emb, num_speakers=2)
|
||||||
|
assert len(set(labels[:10])) == 1
|
||||||
|
assert len(set(labels[10:])) == 1
|
||||||
|
assert labels[0] != labels[10]
|
||||||
|
|
||||||
|
|
||||||
|
def test_threshold_based():
|
||||||
|
rng = np.random.default_rng(0)
|
||||||
|
emb = rng.normal(size=(30, 256)).astype(np.float32)
|
||||||
|
labels = cluster_embeddings(emb, num_speakers=None,
|
||||||
|
min_speakers=1, max_speakers=10)
|
||||||
|
assert 1 <= len(set(labels)) <= 10
|
||||||
28
tests/unit/test_diar_config.py
Normal file
28
tests/unit/test_diar_config.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from pyannote_diarization_3_1_mlx._config import (
|
||||||
|
SEG_DURATION, SEG_HOP, SEG_FRAMES, SEG_CLASSES,
|
||||||
|
MAX_SPEAKERS_PER_CHUNK, MAX_SPEAKERS_PER_FRAME,
|
||||||
|
EMB_BATCH_SIZE, EMB_EXCLUDE_OVERLAP,
|
||||||
|
CLUSTER_METHOD, CLUSTER_THRESHOLD, CLUSTER_MIN_SIZE,
|
||||||
|
SEG_HF_REPO, SEG_HF_REV, EMB_HF_REPO, EMB_HF_REV,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pyannote_3_1_locked_hyperparameters():
|
||||||
|
assert SEG_DURATION == 10.0
|
||||||
|
assert SEG_HOP == 1.0
|
||||||
|
assert SEG_FRAMES == 589
|
||||||
|
assert SEG_CLASSES == 7
|
||||||
|
assert MAX_SPEAKERS_PER_CHUNK == 3
|
||||||
|
assert MAX_SPEAKERS_PER_FRAME == 2
|
||||||
|
assert EMB_BATCH_SIZE == 32
|
||||||
|
assert EMB_EXCLUDE_OVERLAP is True
|
||||||
|
assert CLUSTER_METHOD == "centroid"
|
||||||
|
assert CLUSTER_THRESHOLD == 0.7045654963945799
|
||||||
|
assert CLUSTER_MIN_SIZE == 12
|
||||||
|
|
||||||
|
|
||||||
|
def test_locked_hf_revisions():
|
||||||
|
assert SEG_HF_REPO == "mlx-community/pyannote-segmentation-3.0-mlx"
|
||||||
|
assert SEG_HF_REV == "5189a69b35c5f7e48082a978f3476bac81590874"
|
||||||
|
assert EMB_HF_REPO == "mlx-community/wespeaker-voxceleb-resnet34-LM"
|
||||||
|
assert EMB_HF_REV == "97fc9343d2cfd0ae4d1c1d8c299e0046aa502e31"
|
||||||
11
tests/unit/test_diar_embedding_shape.py
Normal file
11
tests/unit/test_diar_embedding_shape.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from pyannote_diarization_3_1_mlx.embedding import EmbeddingModel
|
||||||
|
from pyannote_diarization_3_1_mlx._config import EMB_DIM
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_output_shape():
|
||||||
|
m = EmbeddingModel()
|
||||||
|
fb = mx.zeros((2, 200, 80)) # (B, T, mel)
|
||||||
|
weights = mx.ones((2, 200))
|
||||||
|
emb = m(fb, weights)
|
||||||
|
assert emb.shape == (2, EMB_DIM), f"got {emb.shape}"
|
||||||
25
tests/unit/test_diar_pipeline_smoke.py
Normal file
25
tests/unit/test_diar_pipeline_smoke.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Smoke test for MlxDiarizationPipeline orchestrator.
|
||||||
|
|
||||||
|
Mocks all sub-components so no HF downloads or real inference is needed.
|
||||||
|
30 s of silence → powerset returns all zeros → no active speaker slots → empty annotation.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from pyannote_diarization_3_1_mlx.pipeline import MlxDiarizationPipeline
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_smoke_on_30s_zeros(mocker):
|
||||||
|
p = MlxDiarizationPipeline.__new__(MlxDiarizationPipeline)
|
||||||
|
p._segmentation = MagicMock()
|
||||||
|
p._embedding = MagicMock()
|
||||||
|
# mock seg → all class 0 (silence) → no slots → empty annotation
|
||||||
|
p._segmentation.return_value = mx.zeros((1, 589, 7))
|
||||||
|
p._powerset = MagicMock()
|
||||||
|
p._powerset.to_multilabel.return_value = mx.zeros((589, 3))
|
||||||
|
p._embedding.return_value = mx.ones((1, 256))
|
||||||
|
# 30 s of silence
|
||||||
|
audio = np.zeros(30 * 16000, dtype=np.float32)
|
||||||
|
annotation = p({"waveform": mx.array(audio)[None, :], "sample_rate": 16000})
|
||||||
|
# silence → no turns
|
||||||
|
assert len(list(annotation.itertracks())) == 0
|
||||||
39
tests/unit/test_diar_powerset.py
Normal file
39
tests/unit/test_diar_powerset.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
import numpy as np
|
||||||
|
import mlx.core as mx
|
||||||
|
from pyannote_diarization_3_1_mlx.powerset import Powerset, POWERSET_3_2_MAPPING
|
||||||
|
|
||||||
|
|
||||||
|
def test_static_mapping_matches_pyannote():
|
||||||
|
assert POWERSET_3_2_MAPPING.shape == (7, 3)
|
||||||
|
expected = np.array([
|
||||||
|
[0, 0, 0],
|
||||||
|
[1, 0, 0],
|
||||||
|
[0, 1, 0],
|
||||||
|
[0, 0, 1],
|
||||||
|
[1, 1, 0],
|
||||||
|
[1, 0, 1],
|
||||||
|
[0, 1, 1],
|
||||||
|
], dtype=np.float32)
|
||||||
|
np.testing.assert_array_equal(POWERSET_3_2_MAPPING, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_multilabel_hard_argmax():
|
||||||
|
p = Powerset()
|
||||||
|
# frame 0 → class 1 (S1 only), frame 1 → class 4 (S1+S2), frame 2 → class 0
|
||||||
|
logits = mx.array([
|
||||||
|
[0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
[0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0],
|
||||||
|
[5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
||||||
|
])
|
||||||
|
out = p.to_multilabel(logits)
|
||||||
|
out_np = np.asarray(out)
|
||||||
|
np.testing.assert_array_equal(out_np[0], [1, 0, 0])
|
||||||
|
np.testing.assert_array_equal(out_np[1], [1, 1, 0])
|
||||||
|
np.testing.assert_array_equal(out_np[2], [0, 0, 0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_multilabel_shape():
|
||||||
|
p = Powerset()
|
||||||
|
logits = mx.zeros((589, 7))
|
||||||
|
out = p.to_multilabel(logits)
|
||||||
|
assert out.shape == (589, 3)
|
||||||
9
tests/unit/test_diar_segmentation_load.py
Normal file
9
tests/unit/test_diar_segmentation_load.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Unit test: load SegmentationModel weights from HF mlx-community repo."""
|
||||||
|
import pytest
|
||||||
|
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
def test_segmentation_loads_from_hf():
|
||||||
|
m = SegmentationModel.from_hf()
|
||||||
|
assert m is not None
|
||||||
9
tests/unit/test_diar_segmentation_shape.py
Normal file
9
tests/unit/test_diar_segmentation_shape.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from pyannote_diarization_3_1_mlx.segmentation import SegmentationModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_segmentation_full_shape():
|
||||||
|
m = SegmentationModel()
|
||||||
|
x = mx.zeros((1, 1, 160000)) # 10s @ 16k mono
|
||||||
|
out = m(x)
|
||||||
|
assert out.shape == (1, 589, 7), f"got {out.shape}"
|
||||||
12
tests/unit/test_diar_sincnet.py
Normal file
12
tests/unit/test_diar_sincnet.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
import mlx.core as mx
|
||||||
|
from pyannote_diarization_3_1_mlx._sincnet import SincNet
|
||||||
|
|
||||||
|
|
||||||
|
def test_sincnet_output_shape_589_frames():
|
||||||
|
"""For pyannote 3.1, 10s @ 16kHz input → 589 frames out."""
|
||||||
|
net = SincNet(sample_rate=16000)
|
||||||
|
x = mx.zeros((1, 1, 16000 * 10)) # (B, C, T)
|
||||||
|
out = net(x)
|
||||||
|
# Expect (1, 60, 589) per upstream PyanNet.SincNet output
|
||||||
|
assert out.shape[-1] == 589, f"got frames={out.shape[-1]}"
|
||||||
|
assert out.shape[1] == 60, f"got channels={out.shape[1]}"
|
||||||
16
tests/unit/test_diar_window.py
Normal file
16
tests/unit/test_diar_window.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from pyannote_diarization_3_1_mlx._window import sliding_windows
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def test_sliding_windows_full_coverage():
|
||||||
|
sr = 16000
|
||||||
|
audio = np.zeros(int(25 * sr), dtype=np.float32)
|
||||||
|
windows = list(sliding_windows(audio, sr=sr, duration_s=10.0, hop_s=1.0))
|
||||||
|
# Expect (25-10)/1 + 1 = 16 windows, all 10 s long
|
||||||
|
assert len(windows) == 16
|
||||||
|
for start, end, slice_ in windows:
|
||||||
|
assert end - start == 10.0
|
||||||
|
assert len(slice_) == 10 * sr
|
||||||
|
# boundaries
|
||||||
|
assert windows[0][0] == 0.0
|
||||||
|
assert windows[-1][1] == 25.0
|
||||||
Reference in New Issue
Block a user