From c6a20cb79f9cab476d638724193bf5bd5eb97e0f Mon Sep 17 00:00:00 2001 From: transcrilive Date: Sat, 9 May 2026 20:00:57 +0200 Subject: [PATCH] Initial Granite Speech Plus MLX package --- .gitignore | 30 + LICENSE | 22 + README.md | 42 + docs/prompt-modes.md | 37 + pyproject.toml | 27 + scripts/benchmark.py | 111 +++ scripts/transcribe.py | 47 + scripts/upload_to_hf.py | 66 ++ src/granite_speech_plus_mlx/__init__.py | 4 + .../_vendored/__init__.py | 1 + .../_vendored/audio.py | 15 + src/granite_speech_plus_mlx/_vendored/base.py | 18 + src/granite_speech_plus_mlx/_vendored/dsp.py | 161 ++++ .../_vendored/granite_speech/__init__.py | 36 + .../_vendored/granite_speech/config.py | 128 +++ .../granite_speech/granite_speech.py | 850 ++++++++++++++++++ .../_vendored/loader.py | 154 ++++ src/granite_speech_plus_mlx/chunking.py | 56 ++ src/granite_speech_plus_mlx/pipeline.py | 126 +++ src/granite_speech_plus_mlx/prompts.py | 65 ++ tests/test_smoke.py | 6 + 21 files changed, 2002 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 docs/prompt-modes.md create mode 100644 pyproject.toml create mode 100755 scripts/benchmark.py create mode 100755 scripts/transcribe.py create mode 100755 scripts/upload_to_hf.py create mode 100644 src/granite_speech_plus_mlx/__init__.py create mode 100644 src/granite_speech_plus_mlx/_vendored/__init__.py create mode 100644 src/granite_speech_plus_mlx/_vendored/audio.py create mode 100644 src/granite_speech_plus_mlx/_vendored/base.py create mode 100644 src/granite_speech_plus_mlx/_vendored/dsp.py create mode 100644 src/granite_speech_plus_mlx/_vendored/granite_speech/__init__.py create mode 100644 src/granite_speech_plus_mlx/_vendored/granite_speech/config.py create mode 100644 src/granite_speech_plus_mlx/_vendored/granite_speech/granite_speech.py create mode 100644 src/granite_speech_plus_mlx/_vendored/loader.py create mode 100644 src/granite_speech_plus_mlx/chunking.py create mode 100644 src/granite_speech_plus_mlx/pipeline.py create mode 100644 src/granite_speech_plus_mlx/prompts.py create mode 100644 tests/test_smoke.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..97d2bf1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +__pycache__/ +*.py[cod] +*$py.class + +.Python +.venv/ +venv/ +ENV/ +env/ + +build/ +dist/ +*.egg-info/ +.eggs/ + +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.coverage +htmlcov/ + +.DS_Store +.env +.env.* + +*.log +*.tmp +transcripts/ +bench/ + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a3605c7 --- /dev/null +++ b/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2026 Olivier Dupont + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..3833910 --- /dev/null +++ b/README.md @@ -0,0 +1,42 @@ +# granite-speech-4.1-2b-plus-mlx + +Standalone Python package for the MLX port of IBM Granite Speech 4.1-2b-plus. +The default model is +[`mlx-community/granite-speech-4.1-2b-plus-mlx`](https://huggingface.co/mlx-community/granite-speech-4.1-2b-plus-mlx). + +## Quickstart + +```bash +uv add "granite-speech-4.1-2b-plus-mlx @ git+https://gitea.tavportal.com/olivier/granite-speech-4.1-2b-plus-mlx.git" +python -c "from granite_speech_plus_mlx import GraniteSpeechPlusPipeline as P; p=P.from_pretrained(); print(p.transcribe('audio.wav'))" +python scripts/transcribe.py audio.wav --prompt-mode asr --output transcript.txt +python scripts/transcribe.py meeting.wav --prompt-mode saa +python scripts/benchmark.py audio.wav --results bench +``` + +## Prompt Modes + +- `asr`: standard transcription. +- `saa`: speaker-attributed ASR with `[Speaker N]:` turn labels. +- `ts`: word-level timestamp tags like `word [T:45]`. + +See [docs/prompt-modes.md](docs/prompt-modes.md) for examples. + +## Benchmark Hints + +Granite Speech 4.1 allocates substantial encoder memory for long audio. Start +with `--chunk-seconds 300 --repetition-penalty 1.2` for ASR and reduce chunks +to 60 or 180 seconds if memory is tight. Timestamp mode (`ts`) often needs a +larger `--max-tokens` budget because every word carries a timestamp tag. + +## Provenance + +This package was extracted from the local `MLX_CONVERTOR` project, including +the Granite Speech patch bundle at +`external/patches/granite-speech-idempotent-sanitize.patch`. The vendored +Granite implementation is based on `mlx-audio` commit +`f7c11556eda88731be5cc75ddbdf4a4cb9eeaafc` plus that local patch. + +Package code is MIT licensed. Model weights remain under the IBM Granite model +license; review the model card and license terms before redistribution or use. + diff --git a/docs/prompt-modes.md b/docs/prompt-modes.md new file mode 100644 index 0000000..39bec0f --- /dev/null +++ b/docs/prompt-modes.md @@ -0,0 +1,37 @@ +# Prompt Modes + +Granite Speech Plus supports three prompt modes in this package. + +## `asr` + +Standard speech transcription. + +```python +from granite_speech_plus_mlx import GraniteSpeechPlusPipeline + +pipe = GraniteSpeechPlusPipeline.from_pretrained() +text = pipe.transcribe("audio.wav", prompt_mode="asr") +``` + +## `saa` + +Speaker-attributed ASR. The prompt asks the model to add speaker turn labels +such as `[Speaker 1]:` and `[Speaker 2]:`. + +```python +text = pipe.transcribe("meeting.wav", prompt_mode="saa") +``` + +## `ts` + +Word-level timestamps. The prompt asks the model to append centisecond tags +after words, for example `hello [T:45] world [T:82]`. + +```python +text = pipe.transcribe("clip.wav", prompt_mode="ts") +``` + +For long audio, the pipeline chunks the waveform and feeds a short previous +transcript prefix into later chunks for continuity. The prefix is context only; +the model is instructed not to repeat it. + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6bf7132 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "granite-speech-4.1-2b-plus-mlx" +version = "0.1.0" +description = "Standalone MLX pipeline for the Granite Speech 4.1-2b-plus port." +readme = "README.md" +requires-python = ">=3.10" +license = "MIT" +authors = [ + { name = "Olivier Dupont" } +] +dependencies = [ + "mlx>=0.22.0", + "mlx-lm>=0.19.0", + "numpy>=1.26", + "transformers>=4.45", + "huggingface-hub>=0.24", + "soundfile>=0.12", + "librosa>=0.10", +] + +[tool.hatch.build.targets.wheel] +packages = ["src/granite_speech_plus_mlx"] + diff --git a/scripts/benchmark.py b/scripts/benchmark.py new file mode 100755 index 0000000..72e0f69 --- /dev/null +++ b/scripts/benchmark.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import sys +import time +from collections import Counter +from pathlib import Path + +from granite_speech_plus_mlx import GraniteSpeechPlusPipeline +from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL +from granite_speech_plus_mlx.prompts import PROMPT_MODES + +GRID = [ + (60, 1.0), + (60, 1.2), + (180, 1.0), + (180, 1.2), + (300, 1.0), + (300, 1.2), + (300, 1.4), +] + +HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde") + + +def analyze(text: str) -> dict: + words = text.split() + lower_words = text.lower().split() + trigrams = Counter( + " ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2) + ) + top = trigrams.most_common(5) + lower = text.lower() + return { + "n_words": len(words), + "max_trigram_count": top[0][1] if top else 0, + "max_trigram_text": top[0][0] if top else "", + "halluc": {m: lower.count(m) for m in HALLUCINATION_MARKERS}, + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.") + parser.add_argument("audio") + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--results", default="bench") + parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr") + parser.add_argument("--overlap-seconds", type=float, default=2.0) + parser.add_argument("--max-tokens", type=int, default=4096) + args = parser.parse_args() + + results_dir = Path(args.results) + results_dir.mkdir(parents=True, exist_ok=True) + pipe = GraniteSpeechPlusPipeline.from_pretrained( + args.model, + overlap_seconds=args.overlap_seconds, + max_tokens=args.max_tokens, + verbose=True, + ) + + rows = [] + for chunk_seconds, repetition_penalty in GRID: + out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt" + pipe.chunk_seconds = float(chunk_seconds) + pipe.repetition_penalty = repetition_penalty + + if out.exists(): + print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr) + elapsed = float("nan") + text = out.read_text(encoding="utf-8") + else: + print( + f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}", + file=sys.stderr, + ) + t0 = time.time() + text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode) + elapsed = time.time() - t0 + out.write_text(text + "\n", encoding="utf-8") + + rows.append( + { + "chunk": chunk_seconds, + "rp": repetition_penalty, + "elapsed": elapsed, + **analyze(text), + } + ) + + print() + print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |") + print("|---:|---:|---:|---:|:---|:---|") + for row in rows: + halluc = ", ".join( + f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value + ) or "-" + trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)" + wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}" + print( + f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} " + f"| {trigram} | {halluc} |" + ) + print() + print(f"Per-config transcripts in: {results_dir}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/scripts/transcribe.py b/scripts/transcribe.py new file mode 100755 index 0000000..1aee816 --- /dev/null +++ b/scripts/transcribe.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +from granite_speech_plus_mlx import GraniteSpeechPlusPipeline +from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL +from granite_speech_plus_mlx.prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES + + +def main() -> int: + parser = argparse.ArgumentParser(description="Transcribe audio with Granite Speech Plus MLX.") + parser.add_argument("audio") + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--output", default=None) + parser.add_argument("--chunk-seconds", type=float, default=300.0) + parser.add_argument("--overlap-seconds", type=float, default=2.0) + parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr") + parser.add_argument("--repetition-penalty", type=float, default=1.2) + parser.add_argument("--max-tokens", type=int, default=4096) + parser.add_argument("--system-prompt", default=GRANITE_SYSTEM_PROMPT) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + pipe = GraniteSpeechPlusPipeline.from_pretrained( + args.model, + chunk_seconds=args.chunk_seconds, + overlap_seconds=args.overlap_seconds, + repetition_penalty=args.repetition_penalty, + max_tokens=args.max_tokens, + system_prompt=args.system_prompt or None, + verbose=args.verbose, + ) + text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode) + + if args.output: + Path(args.output).write_text(text + "\n", encoding="utf-8") + else: + print(text) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/scripts/upload_to_hf.py b/scripts/upload_to_hf.py new file mode 100755 index 0000000..21dcedc --- /dev/null +++ b/scripts/upload_to_hf.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +from __future__ import annotations + +import os +import sys +from pathlib import Path + +from huggingface_hub import HfApi + +SOURCE_CACHE = ( + Path.home() + / ".cache/huggingface/hub/models--ibm-granite--granite-speech-4.1-2b-plus" +) +DEST_REPO = "mlx-community/granite-speech-4.1-2b-plus-mlx" + + +def find_weights_dir(root: Path) -> Path | None: + if not root.exists(): + return None + if list(root.glob("*.safetensors")) or (root / "config.json").exists(): + return root + snapshots = root / "snapshots" + if snapshots.exists(): + candidates = [ + path + for path in snapshots.iterdir() + if path.is_dir() and (list(path.glob("*.safetensors")) or (path / "config.json").exists()) + ] + if candidates: + return sorted(candidates, key=lambda p: p.stat().st_mtime)[-1] + return None + + +def print_manual_commands() -> None: + print(f"MLX weights not found at {SOURCE_CACHE}") + print("Create them first with:") + print("mlxconv ibm-granite/granite-speech-4.1-2b-plus") + print("mlxconv ibm-granite/granite-speech-4.1-2b-plus --dtype q4_k_4") + + +def main() -> int: + weights_dir = find_weights_dir(SOURCE_CACHE) + if weights_dir is None: + print_manual_commands() + return 1 + + token = os.environ.get("HF_TOKEN") + if not token: + print("HF_TOKEN is required to upload.", file=sys.stderr) + return 2 + + api = HfApi(token=token) + api.create_repo(DEST_REPO, repo_type="model", exist_ok=True) + api.upload_folder( + repo_id=DEST_REPO, + repo_type="model", + folder_path=str(weights_dir), + commit_message="Upload Granite Speech 4.1-2b-plus MLX weights", + ) + print(f"Uploaded {weights_dir} to {DEST_REPO}") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) + diff --git a/src/granite_speech_plus_mlx/__init__.py b/src/granite_speech_plus_mlx/__init__.py new file mode 100644 index 0000000..f978c3d --- /dev/null +++ b/src/granite_speech_plus_mlx/__init__.py @@ -0,0 +1,4 @@ +from .pipeline import GraniteSpeechPlusPipeline + +__all__ = ["GraniteSpeechPlusPipeline"] + diff --git a/src/granite_speech_plus_mlx/_vendored/__init__.py b/src/granite_speech_plus_mlx/_vendored/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/__init__.py @@ -0,0 +1 @@ + diff --git a/src/granite_speech_plus_mlx/_vendored/audio.py b/src/granite_speech_plus_mlx/_vendored/audio.py new file mode 100644 index 0000000..344b690 --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/audio.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from pathlib import Path + +import librosa + +SAMPLE_RATE = 16000 + + +def load_audio(file: str | Path, sr: int = SAMPLE_RATE): + import mlx.core as mx + + audio, _ = librosa.load(str(file), sr=sr, mono=True) + return mx.array(audio, dtype=mx.float32) + diff --git a/src/granite_speech_plus_mlx/_vendored/base.py b/src/granite_speech_plus_mlx/_vendored/base.py new file mode 100644 index 0000000..f17717d --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/base.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List + + +@dataclass +class STTOutput: + text: str + segments: List[dict] | None = None + language: str | None = None + prompt_tokens: int = 0 + generation_tokens: int = 0 + total_tokens: int = 0 + prompt_tps: float = 0.0 + generation_tps: float = 0.0 + total_time: float = 0.0 + diff --git a/src/granite_speech_plus_mlx/_vendored/dsp.py b/src/granite_speech_plus_mlx/_vendored/dsp.py new file mode 100644 index 0000000..7b901ff --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/dsp.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import math +from functools import lru_cache +from typing import Optional + +import mlx.core as mx + + +@lru_cache(maxsize=None) +def hanning(size: int, periodic: bool = False): + denom = size if periodic else size - 1 + return mx.array( + [0.5 * (1 - math.cos(2 * math.pi * n / denom)) for n in range(size)] + ) + + +@lru_cache(maxsize=None) +def hamming(size: int, periodic: bool = False): + denom = size if periodic else size - 1 + return mx.array( + [0.54 - 0.46 * math.cos(2 * math.pi * n / denom) for n in range(size)] + ) + + +@lru_cache(maxsize=None) +def blackman(size: int, periodic: bool = False): + denom = size if periodic else size - 1 + return mx.array( + [ + 0.42 + - 0.5 * math.cos(2 * math.pi * n / denom) + + 0.08 * math.cos(4 * math.pi * n / denom) + for n in range(size) + ] + ) + + +@lru_cache(maxsize=None) +def bartlett(size: int, periodic: bool = False): + denom = size if periodic else size - 1 + return mx.array([1 - 2 * abs(n - denom / 2) / denom for n in range(size)]) + + +STR_TO_WINDOW_FN = { + "hann": hanning, + "hanning": hanning, + "hamming": hamming, + "blackman": blackman, + "bartlett": bartlett, +} + + +def stft( + x, + n_fft: int = 800, + hop_length: int | None = None, + win_length: int | None = None, + window: mx.array | str = "hann", + center: bool = True, + pad_mode: str = "reflect", +): + if hop_length is None: + hop_length = n_fft // 4 + if win_length is None: + win_length = n_fft + + if isinstance(window, str): + window_fn = STR_TO_WINDOW_FN.get(window.lower()) + if window_fn is None: + raise ValueError(f"Unknown window function: {window}") + w = window_fn(win_length) + else: + w = window + + if w.shape[0] < n_fft: + pad_size = n_fft - w.shape[0] + w = mx.concatenate([w, mx.zeros((pad_size,))], axis=0) + + def _pad(signal, padding: int, mode: str = "reflect"): + if mode == "constant": + return mx.pad(signal, [(padding, padding)]) + if mode == "reflect": + prefix = signal[1 : padding + 1][::-1] + suffix = signal[-(padding + 1) : -1][::-1] + return mx.concatenate([prefix, signal, suffix]) + raise ValueError(f"Invalid pad_mode {mode}") + + if center: + x = _pad(x, n_fft // 2, pad_mode) + + num_frames = 1 + (x.shape[0] - n_fft) // hop_length + if num_frames <= 0: + raise ValueError( + f"Input is too short for n_fft={n_fft}, hop_length={hop_length}, " + f"center={center}." + ) + + frames = mx.as_strided(x, shape=(num_frames, n_fft), strides=(hop_length, 1)) + return mx.fft.rfft(frames * w) + + +@lru_cache(maxsize=None) +def mel_filters( + sample_rate: int, + n_fft: int, + n_mels: int, + f_min: float = 0, + f_max: Optional[float] = None, + norm: Optional[str] = None, + mel_scale: str = "htk", +) -> mx.array: + def hz_to_mel(freq, scale="htk"): + if scale == "htk": + return 2595.0 * math.log10(1.0 + freq / 700.0) + + f_sp = 200.0 / 3 + mels = freq / f_sp + min_log_hz = 1000.0 + min_log_mel = min_log_hz / f_sp + logstep = math.log(6.4) / 27.0 + if freq >= min_log_hz: + mels = min_log_mel + math.log(freq / min_log_hz) / logstep + return mels + + def mel_to_hz(mels, scale="htk"): + if scale == "htk": + return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) + + f_sp = 200.0 / 3 + freqs = f_sp * mels + min_log_hz = 1000.0 + min_log_mel = min_log_hz / f_sp + logstep = math.log(6.4) / 27.0 + return mx.where( + mels >= min_log_mel, + min_log_hz * mx.exp(logstep * (mels - min_log_mel)), + freqs, + ) + + f_max = f_max or sample_rate / 2 + n_freqs = n_fft // 2 + 1 + all_freqs = mx.linspace(0, sample_rate // 2, n_freqs) + m_min = hz_to_mel(f_min, mel_scale) + m_max = hz_to_mel(f_max, mel_scale) + m_pts = mx.linspace(m_min, m_max, n_mels + 2) + f_pts = mel_to_hz(m_pts, mel_scale) + f_diff = f_pts[1:] - f_pts[:-1] + slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1) + down_slopes = (-slopes[:, :-2]) / f_diff[:-1] + up_slopes = slopes[:, 2:] / f_diff[1:] + filterbank = mx.maximum( + mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes) + ) + + if norm == "slaney": + enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels]) + filterbank *= mx.expand_dims(enorm, 0) + + return filterbank.moveaxis(0, 1) + diff --git a/src/granite_speech_plus_mlx/_vendored/granite_speech/__init__.py b/src/granite_speech_plus_mlx/_vendored/granite_speech/__init__.py new file mode 100644 index 0000000..17a5875 --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/granite_speech/__init__.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from .config import EncoderConfig, ModelConfig, ProjectorConfig, TextConfig +from .granite_speech import Model + + +@dataclass +class GraniteSpeechPlusModelConfig(ModelConfig): + model_type: str = "granite_speech_plus" + + +DETECTION_HINTS = { + "config_keys": {"encoder_config", "projector_config", "audio_token_index"}, + "architectures": { + "GraniteSpeechForConditionalGeneration", + "GraniteSpeechPlusForConditionalGeneration", + }, + "path_patterns": { + "granite_speech_plus", + "granitespeechplus", + "granite-speech-4.1-2b-plus", + }, +} + +__all__ = [ + "EncoderConfig", + "ProjectorConfig", + "TextConfig", + "ModelConfig", + "GraniteSpeechPlusModelConfig", + "Model", + "DETECTION_HINTS", +] + diff --git a/src/granite_speech_plus_mlx/_vendored/granite_speech/config.py b/src/granite_speech_plus_mlx/_vendored/granite_speech/config.py new file mode 100644 index 0000000..cf79034 --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/granite_speech/config.py @@ -0,0 +1,128 @@ +import inspect +from dataclasses import dataclass, field +from typing import Dict, List, Optional + + +@dataclass +class EncoderConfig: + input_dim: int = 160 + num_layers: int = 10 + hidden_dim: int = 1024 + feedforward_mult: int = 4 + num_heads: int = 8 + dim_head: int = 128 + output_dim: int = 42 + context_size: int = 200 + max_pos_emb: int = 512 + dropout: float = 0.1 + conv_kernel_size: int = 15 + conv_expansion_factor: int = 2 + # Plus variant: indices of intermediate encoder layers whose hidden state + # gets concatenated with the final-layer hidden state along the channel + # axis, before being fed to the projector. None / empty = base behavior. + cat_hidden_layers: Optional[List[int]] = None + model_type: str = "granite_speech_encoder" + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +@dataclass +class ProjectorConfig: + hidden_size: int = 1024 + num_hidden_layers: int = 2 + num_attention_heads: int = 16 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + layer_norm_eps: float = 1e-12 + encoder_hidden_size: int = 1024 + cross_attention_frequency: int = 1 + model_type: str = "blip_2_qformer" + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +@dataclass +class TextConfig: + model_type: str = "granite" + vocab_size: int = 100353 + hidden_size: int = 2048 + intermediate_size: int = 4096 + num_hidden_layers: int = 40 + num_attention_heads: int = 16 + num_key_value_heads: int = 4 + hidden_act: str = "silu" + max_position_embeddings: int = 4096 + rms_norm_eps: float = 1e-5 + rope_theta: float = 10000.0 + rope_scaling: Optional[Dict] = None + attention_bias: bool = False + mlp_bias: bool = False + attention_multiplier: float = 0.0078125 + embedding_multiplier: float = 12.0 + residual_multiplier: float = 0.22 + logits_scaling: float = 8.0 + tie_word_embeddings: bool = False + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +@dataclass +class ModelConfig: + model_type: str = "granite_speech" + encoder_config: EncoderConfig = None + projector_config: ProjectorConfig = None + text_config: TextConfig = None + audio_token_index: int = 100352 + downsample_rate: int = 5 + window_size: int = 15 + has_lora_adapter: bool = False + + def __post_init__(self): + if isinstance(self.encoder_config, dict): + self.encoder_config = EncoderConfig.from_dict(self.encoder_config) + elif self.encoder_config is None: + self.encoder_config = EncoderConfig() + + if isinstance(self.projector_config, dict): + self.projector_config = ProjectorConfig.from_dict(self.projector_config) + elif self.projector_config is None: + self.projector_config = ProjectorConfig() + + if isinstance(self.text_config, dict): + self.text_config = TextConfig.from_dict(self.text_config) + elif self.text_config is None: + self.text_config = TextConfig() + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) diff --git a/src/granite_speech_plus_mlx/_vendored/granite_speech/granite_speech.py b/src/granite_speech_plus_mlx/_vendored/granite_speech/granite_speech.py new file mode 100644 index 0000000..194eee6 --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/granite_speech/granite_speech.py @@ -0,0 +1,850 @@ +import math +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Generator, List, Optional, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten +from mlx_lm.models.base import create_attention_mask +from mlx_lm.models.cache import KVCache +from mlx_lm.models.granite import Model as GraniteLM +from mlx_lm.models.granite import ModelArgs as GraniteModelArgs + +from ..base import STTOutput + +from .config import EncoderConfig, ModelConfig, ProjectorConfig + +LANGUAGE_CODES = { + "en": "English", + "fr": "French", + "de": "German", + "es": "Spanish", + "pt": "Portuguese", + "ja": "Japanese", +} + + +@dataclass +class StreamingResult: + text: str + is_final: bool + start_time: float + end_time: float + language: str = "en" + prompt_tokens: int = 0 + generation_tokens: int = 0 + + +class BatchNorm1d(nn.Module): + + def __init__(self, num_features: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((num_features,)) + self.bias = mx.zeros((num_features,)) + self.running_mean = mx.zeros((num_features,)) + self.running_var = mx.ones((num_features,)) + self.eps = eps + + def __call__(self, x: mx.array) -> mx.array: + return (x - self.running_mean) / mx.sqrt( + self.running_var + self.eps + ) * self.weight + self.bias + + +class ConformerFeedForward(nn.Module): + def __init__(self, config: EncoderConfig): + super().__init__() + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.up_proj = nn.Linear( + config.hidden_dim, config.hidden_dim * config.feedforward_mult + ) + self.down_proj = nn.Linear( + config.hidden_dim * config.feedforward_mult, config.hidden_dim + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.pre_norm(x) + x = nn.silu(self.up_proj(x)) + x = self.down_proj(x) + return x + + +class ConformerAttention(nn.Module): + + def __init__(self, config: EncoderConfig): + super().__init__() + inner_dim = config.dim_head * config.num_heads + self.max_pos_emb = config.max_pos_emb + self.context_size = config.context_size + self.num_heads = config.num_heads + self.dim_head = config.dim_head + self.scale = config.dim_head**-0.5 + self.pre_norm = nn.LayerNorm(config.hidden_dim) + self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False) + self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, config.hidden_dim) + self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head) + + def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array: + x = self.pre_norm(x) + B, N, _ = x.shape + + num_blocks = math.ceil(N / self.context_size) + remainder = N % self.context_size + + if remainder > 0: + pad_len = self.context_size - remainder + x = mx.pad(x, [(0, 0), (0, pad_len), (0, 0)]) + + q = self.to_q(x) + kv = self.to_kv(x) + k, v = mx.split(kv, 2, axis=-1) + + q = q.reshape(B, num_blocks, self.context_size, self.num_heads, -1) + k = k.reshape(B, num_blocks, self.context_size, self.num_heads, -1) + v = v.reshape(B, num_blocks, self.context_size, self.num_heads, -1) + + q = q.transpose(0, 1, 3, 2, 4) + k = k.transpose(0, 1, 3, 2, 4) + v = v.transpose(0, 1, 3, 2, 4) + + rel_pos_emb = self.rel_pos_emb(attention_dists) + + C = self.context_size + pos_attn = ( + mx.sum( + q[:, :, :, :, None, :] * rel_pos_emb[None, None, None, :, :, :], + axis=-1, + ) + * self.scale + ) + + if remainder > 0: + row_valid = mx.arange(C)[:, None] < remainder + col_valid = mx.arange(C)[None, :] < remainder + mask = ~(row_valid & col_valid) + mask_value = mx.array(mx.finfo(pos_attn.dtype).min) + pos_attn_last = mx.where( + mask[None, None, None], mask_value, pos_attn[:, -1:, :, :, :] + ) + pos_attn = mx.concatenate( + [pos_attn[:, :-1, :, :, :], pos_attn_last], axis=1 + ) + + attn_weights = (q @ k.transpose(0, 1, 2, 4, 3)) * self.scale + pos_attn + attn_weights = mx.softmax(attn_weights, axis=-1) + + out = attn_weights @ v + out = out.transpose(0, 1, 3, 2, 4) + out = out.reshape(B, -1, self.num_heads * self.dim_head) + out = out[:, :N, :] + out = self.to_out(out) + return out + + +class DepthWiseConv1d(nn.Module): + + def __init__(self, chan_in: int, chan_out: int, kernel_size: int): + super().__init__() + pad = kernel_size // 2 + pad_offset = (kernel_size + 1) % 2 + self.padding = (pad, pad - pad_offset) + self.conv = nn.Conv1d( + chan_in, chan_out, kernel_size, groups=chan_in, bias=False + ) + + def __call__(self, x: mx.array) -> mx.array: + x = mx.pad(x, [(0, 0), (self.padding[0], self.padding[1]), (0, 0)]) + return self.conv(x) + + +class ConformerConvModule(nn.Module): + + def __init__(self, config: EncoderConfig): + super().__init__() + inner_dim = config.hidden_dim * config.conv_expansion_factor + + self.norm = nn.LayerNorm(config.hidden_dim) + self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1) + self.depth_conv = DepthWiseConv1d(inner_dim, inner_dim, config.conv_kernel_size) + self.batch_norm = BatchNorm1d(inner_dim) + self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1) + + def __call__(self, x: mx.array) -> mx.array: + x = self.norm(x) + x = self.up_conv(x) + x1, x2 = mx.split(x, 2, axis=-1) + x = x1 * mx.sigmoid(x2) + x = self.depth_conv(x) + x = nn.silu(self.batch_norm(x)) + x = self.down_conv(x) + return x + + +class ConformerBlock(nn.Module): + + def __init__(self, config: EncoderConfig): + super().__init__() + self.ff1 = ConformerFeedForward(config) + self.attn = ConformerAttention(config) + self.conv = ConformerConvModule(config) + self.ff2 = ConformerFeedForward(config) + self.post_norm = nn.LayerNorm(config.hidden_dim) + + def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array: + x = 0.5 * self.ff1(x) + x + x = self.attn(x, attention_dists) + x + x = self.conv(x) + x + x = 0.5 * self.ff2(x) + x + x = self.post_norm(x) + return x + + +class CTCEncoder(nn.Module): + + def __init__(self, config: EncoderConfig): + super().__init__() + self.config = config + self.input_linear = nn.Linear(config.input_dim, config.hidden_dim) + self.layers = [ConformerBlock(config) for _ in range(config.num_layers)] + self.out = nn.Linear(config.hidden_dim, config.output_dim) + self.out_mid = nn.Linear(config.output_dim, config.hidden_dim) + self.num_layers = config.num_layers + self._attention_dists = None + + seq = mx.arange(config.context_size) + relpos_dist = seq[:, None] - seq[None, :] + self._attention_dists = ( + mx.clip(relpos_dist, -config.context_size, config.context_size) + + config.max_pos_emb + ) + + def __call__(self, x: mx.array) -> mx.array: + x = self.input_linear(x) + cat_layers = set(self.config.cat_hidden_layers or []) + exported_hidden_states = [] + if 0 in cat_layers: + exported_hidden_states.append(x) + + for idx, layer in enumerate(self.layers, start=1): + x = layer(x, attention_dists=self._attention_dists) + if idx in cat_layers: + exported_hidden_states.append(x) + if idx == self.num_layers // 2: + x_mid = self.out(x) + x = x + self.out_mid(mx.softmax(x_mid, axis=-1)) + + if exported_hidden_states: + # Plus variant: prepend captured intermediate hidden states to the + # final-layer output along the channel axis. Order matches the + # upstream Transformers implementation: intermediates first, then + # final. + x = mx.concatenate([*exported_hidden_states, x], axis=-1) + return x + + +class QFormerMultiHeadAttention(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, kv_hidden_size: int = None): + super().__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + kv_dim = kv_hidden_size or hidden_size + + self.query = nn.Linear(hidden_size, hidden_size) + self.key = nn.Linear(kv_dim, hidden_size) + self.value = nn.Linear(kv_dim, hidden_size) + + def __call__( + self, hidden_states: mx.array, encoder_hidden_states: mx.array = None + ) -> mx.array: + B, L, _ = hidden_states.shape + + q = self.query(hidden_states) + kv_input = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + k = self.key(kv_input) + v = self.value(kv_input) + + q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = k.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + v = v.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + + scale = self.head_dim**-0.5 + attn = (q * scale) @ k.transpose(0, 1, 3, 2) + attn = mx.softmax(attn, axis=-1) + out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, -1) + return out + + +class QFormerSelfOutput(nn.Module): + + def __init__(self, hidden_size: int, eps: float = 1e-12): + super().__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps) + + def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array: + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class QFormerAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + kv_hidden_size: int = None, + eps: float = 1e-12, + ): + super().__init__() + self.attention = QFormerMultiHeadAttention( + hidden_size, num_heads, kv_hidden_size + ) + self.output = QFormerSelfOutput(hidden_size, eps) + + def __call__( + self, hidden_states: mx.array, encoder_hidden_states: mx.array = None + ) -> mx.array: + attn_out = self.attention(hidden_states, encoder_hidden_states) + return self.output(attn_out, hidden_states) + + +class QFormerIntermediate(nn.Module): + + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + + def __call__(self, x: mx.array) -> mx.array: + return nn.gelu(self.dense(x)) + + +class QFormerOutput(nn.Module): + + def __init__(self, intermediate_size: int, hidden_size: int, eps: float = 1e-12): + super().__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps) + + def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array: + hidden_states = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class QFormerLayer(nn.Module): + + def __init__(self, config: ProjectorConfig): + super().__init__() + self.attention = QFormerAttention( + config.hidden_size, config.num_attention_heads, eps=config.layer_norm_eps + ) + self.crossattention = QFormerAttention( + config.hidden_size, + config.num_attention_heads, + kv_hidden_size=config.encoder_hidden_size, + eps=config.layer_norm_eps, + ) + self.intermediate_query = QFormerIntermediate( + config.hidden_size, config.intermediate_size + ) + self.output_query = QFormerOutput( + config.intermediate_size, config.hidden_size, eps=config.layer_norm_eps + ) + + def __call__( + self, hidden_states: mx.array, encoder_hidden_states: mx.array + ) -> mx.array: + hidden_states = self.attention(hidden_states) + hidden_states = self.crossattention(hidden_states, encoder_hidden_states) + intermediate = self.intermediate_query(hidden_states) + hidden_states = self.output_query(intermediate, hidden_states) + return hidden_states + + +class QFormerEncoder(nn.Module): + def __init__(self, config: ProjectorConfig): + super().__init__() + self.layer = [QFormerLayer(config) for _ in range(config.num_hidden_layers)] + + def __call__( + self, hidden_states: mx.array, encoder_hidden_states: mx.array + ) -> mx.array: + for layer in self.layer: + hidden_states = layer(hidden_states, encoder_hidden_states) + return hidden_states + + +class QFormerModel(nn.Module): + def __init__(self, config: ProjectorConfig): + super().__init__() + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.encoder = QFormerEncoder(config) + + def __call__( + self, query_embeds: mx.array, encoder_hidden_states: mx.array + ) -> mx.array: + hidden_states = self.layernorm(query_embeds) + return self.encoder(hidden_states, encoder_hidden_states) + + +class EncoderProjector(nn.Module): + + def __init__(self, config: ModelConfig): + super().__init__() + self.hidden_size = config.projector_config.hidden_size + self.downsample_rate = config.downsample_rate + self.window_size = config.window_size + self.num_queries = config.window_size // config.downsample_rate + + self.query = mx.zeros( + (1, self.num_queries, config.projector_config.hidden_size) + ) + self.qformer = QFormerModel(config.projector_config) + self.linear = nn.Linear( + config.projector_config.hidden_size, config.text_config.hidden_size + ) + + def __call__(self, hidden_states: mx.array) -> mx.array: + B, L, D = hidden_states.shape + nblocks = math.ceil(L / self.window_size) + pad = nblocks * self.window_size - L + if pad > 0: + hidden_states = mx.pad(hidden_states, [(0, 0), (0, pad), (0, 0)]) + + hidden_states = hidden_states.reshape(B * nblocks, self.window_size, D) + + query = mx.broadcast_to( + self.query, (B * nblocks, self.num_queries, self.hidden_size) + ) + + query_output = self.qformer(query, hidden_states) + query_proj = self.linear( + query_output.reshape(B, nblocks * self.num_queries, -1) + ) + return query_proj + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + # Plus variant invariant: encoder concats len(cat_hidden_layers)+1 + # hidden states channel-wise, so projector must accept that wider input. + cat_layers = config.encoder_config.cat_hidden_layers or [] + expected_proj_in = config.encoder_config.hidden_dim * (len(cat_layers) + 1) + if config.projector_config.encoder_hidden_size != expected_proj_in: + raise ValueError( + f"projector_config.encoder_hidden_size ({config.projector_config.encoder_hidden_size}) " + f"must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1) " + f"({config.encoder_config.hidden_dim} * {len(cat_layers) + 1} = {expected_proj_in})" + ) + + self.encoder = CTCEncoder(config.encoder_config) + self.projector = EncoderProjector(config) + text_args = GraniteModelArgs.from_dict( + config.text_config.__dict__ + if hasattr(config.text_config, "__dict__") + else config.text_config + ) + self.language_model = GraniteLM(text_args) + + self.audio_token_id = config.audio_token_index + self._tokenizer = None + + @property + def layers(self): + return self.language_model.model.layers + + def make_cache(self) -> List[KVCache]: + return [KVCache() for _ in range(len(self.layers))] + + def __call__( + self, + input_ids: mx.array, + cache: Optional[List[KVCache]] = None, + input_embeddings: Optional[mx.array] = None, + ) -> mx.array: + if input_embeddings is not None: + h = input_embeddings + else: + h = self.language_model.model.embed_tokens(input_ids) + + h = h * self.language_model.model.embedding_multiplier + + if cache is None: + cache = [None] * len(self.language_model.model.layers) + + mask = create_attention_mask(h, cache[0]) + + for layer, c in zip(self.language_model.model.layers, cache): + h = layer(h, mask, cache=c) + + h = self.language_model.model.norm(h) + + if self.language_model.args.tie_word_embeddings: + logits = self.language_model.model.embed_tokens.as_linear(h) + else: + logits = self.language_model.lm_head(h) + + return logits / self.language_model.logits_scaling + + def get_audio_features(self, input_features: mx.array) -> mx.array: + encoder_output = self.encoder(input_features) + projected = self.projector(encoder_output) + return projected + + def model_quant_predicate(self, p: str, m: nn.Module) -> bool: + return not (p.startswith("encoder") or p.startswith("projector")) + + def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]: + # Compare incoming weight shapes against the model's already-initialized + # parameter shapes. This is idempotent across convert-time (PyTorch source + # layout) and inference-time load (MLX-native layout) and correct even for + # Conv1d kernel_size=1 layers where prior shape-ordering heuristics failed. + # Pattern adapted from cohere_asr.Model.sanitize. + model_weights = dict(tree_flatten(self.parameters())) + sanitized = {} + for k, v in weights.items(): + if "num_batches_tracked" in k: + continue + # granite-speech-4.1 ships a separate out_llm.safetensors with + # top-level "weight"/"bias" keys (likely an audio CTC head). The + # standard Model class does not define this layer, so dropping these + # keys is required for the rest of the load to succeed. Inference + # behaviour that depends on this head is not yet supported. + if k in ("weight", "bias"): + continue + + expected = model_weights.get(k) + if expected is not None and hasattr(expected, "shape"): + if v.shape != expected.shape and v.ndim == 3: + transposed = mx.transpose(v, (0, 2, 1)) + if transposed.shape == expected.shape: + v = transposed + + sanitized[k] = v + return sanitized + + @classmethod + def post_load_hook(cls, model: "Model", model_path: Path) -> "Model": + import transformers + from transformers import AutoTokenizer + + prev = transformers.logging.get_verbosity() + transformers.logging.set_verbosity_error() + try: + model._tokenizer = AutoTokenizer.from_pretrained( + str(model_path), trust_remote_code=True + ) + finally: + transformers.logging.set_verbosity(prev) + + return model + + def _extract_features( + self, audio: Union[mx.array, np.ndarray] + ) -> Tuple[mx.array, int]: + from ..dsp import hanning, mel_filters, stft + + n_fft = 512 + win_length = 400 + hop_length = 160 + n_mels = 80 + sample_rate = 16000 + + if isinstance(audio, mx.array): + audio_1d = audio.reshape(-1) + else: + audio_1d = mx.array(audio.flatten(), dtype=mx.float32) + + win = hanning(win_length, periodic=True) + pad_left = (n_fft - win_length) // 2 + pad_right = n_fft - win_length - pad_left + win_padded = mx.concatenate( + [mx.zeros((pad_left,)), win, mx.zeros((pad_right,))] + ) + + spec = stft( + audio_1d, + n_fft=n_fft, + hop_length=hop_length, + window=win_padded, + center=True, + pad_mode="reflect", + ) + + power = mx.abs(spec) ** 2 + mel_fb = mel_filters(sample_rate, n_fft, n_mels, mel_scale="htk") + mel_spec = power @ mel_fb.T + + logmel = mx.log10(mx.clip(mel_spec, 1e-10, None)) + mx_val = mx.max(logmel) + logmel = mx.maximum(logmel, mx_val - 8.0) / 4.0 + 1.0 + + if logmel.shape[0] % 2 == 1: + logmel = logmel[:-1] + + encoder_input = logmel.reshape(-1, 2 * n_mels) + + encoder_length = encoder_input.shape[0] + nblocks = math.ceil(encoder_length / self.config.window_size) + num_audio_tokens = nblocks * ( + self.config.window_size // self.config.downsample_rate + ) + + input_features = encoder_input[None, :, :] + return input_features, num_audio_tokens + + def _build_prompt( + self, + num_audio_tokens: int, + user_prompt: str = None, + system_prompt: str = None, + ) -> mx.array: + if user_prompt is None: + user_prompt = "can you transcribe the speech into a written format?" + + audio_placeholder = "<|audio|>" * num_audio_tokens + content = f"{audio_placeholder}{user_prompt}" + + if getattr(self._tokenizer, "chat_template", None): + chat = [] + if system_prompt: + chat.append({"role": "system", "content": system_prompt}) + chat.append({"role": "user", "content": content}) + prompt_str = self._tokenizer.apply_chat_template( + chat, tokenize=False, add_generation_prompt=True + ) + else: + # Granite-3 chat format (granite-speech tokenizer ships without a + # chat_template attribute, but its vocab includes the role tokens). + sor, eor, eot = "<|start_of_role|>", "<|end_of_role|>", "<|end_of_text|>" + parts = [] + if system_prompt: + parts.append(f"{sor}system{eor}{system_prompt}{eot}\n") + parts.append(f"{sor}user{eor}{content}{eot}\n") + parts.append(f"{sor}assistant{eor}") + prompt_str = "".join(parts) + + prompt_ids = self._tokenizer.encode(prompt_str) + + return mx.array(prompt_ids) + + def _build_inputs_embeds( + self, input_ids: mx.array, audio_features: mx.array + ) -> mx.array: + is_audio = input_ids == self.audio_token_id + llm_ids = mx.where(is_audio, 0, input_ids) + + inputs_embeds = self.language_model.model.embed_tokens(llm_ids[None]) + + is_audio_np = np.array(is_audio) + audio_positions = np.where(is_audio_np)[0] + + orig_dtype = inputs_embeds.dtype + embeds_np = np.array(inputs_embeds.astype(mx.float32)) + audio_np = np.array(audio_features.astype(mx.float32)) + + num_audio = min(len(audio_positions), audio_np.shape[1]) + embeds_np[0, audio_positions[:num_audio]] = audio_np[0, :num_audio] + + return mx.array(embeds_np).astype(orig_dtype) + + def generate( + self, + audio: Union[str, mx.array, np.ndarray], + *, + max_tokens: int = 4096, + temperature: float = 0.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + repetition_penalty: Optional[float] = None, + repetition_context_size: int = 100, + prompt: str = None, + system_prompt: str = None, + language: str = None, + prefill_step_size: int = 2048, + verbose: bool = False, + stream: bool = False, + **kwargs, + ) -> Union[STTOutput, Generator[StreamingResult, None, None]]: + if prompt is None and language is not None: + lang_name = LANGUAGE_CODES.get(language.lower(), language) + prompt = f"Translate the speech to {lang_name}." + + if stream: + return self._stream_generate( + audio, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + repetition_penalty=repetition_penalty, + repetition_context_size=repetition_context_size, + prompt=prompt, + prefill_step_size=prefill_step_size, + verbose=verbose, + ) + + start_time = time.time() + + from mlx_lm.generate import generate_step + from mlx_lm.sample_utils import make_logits_processors, make_sampler + + audio_data = self._load_audio(audio) + input_features, num_audio_tokens = self._extract_features(audio_data) + + if verbose: + print("Encoding audio...") + audio_features = self.get_audio_features(input_features) + mx.eval(audio_features) + + prompt_ids = self._build_prompt(num_audio_tokens, prompt, system_prompt=system_prompt) + inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features) + mx.eval(inputs_embeds) + + prompt_tokens = len(prompt_ids) + + sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k) + logits_processors = make_logits_processors( + repetition_penalty=repetition_penalty, + repetition_context_size=repetition_context_size, + ) + + eos_token_id = self._tokenizer.eos_token_id + tokens = [] + + for token, logprobs in generate_step( + prompt=prompt_ids, + input_embeddings=inputs_embeds.squeeze(0), + model=self, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prefill_step_size=prefill_step_size, + ): + if token == eos_token_id: + break + tokens.append(token) + + text = self._tokenizer.decode(tokens, skip_special_tokens=True) + elapsed = time.time() - start_time + gen_tokens = len(tokens) + + if verbose: + print(f"Prompt tokens: {prompt_tokens}") + print(f"Generation tokens: {gen_tokens}") + print(f"Total time: {elapsed:.2f}s") + if gen_tokens > 0: + print(f"Generation TPS: {gen_tokens / elapsed:.1f}") + + return STTOutput( + text=text, + segments=[], + prompt_tokens=prompt_tokens, + generation_tokens=gen_tokens, + total_tokens=prompt_tokens + gen_tokens, + total_time=elapsed, + prompt_tps=prompt_tokens / elapsed if elapsed > 0 else 0, + generation_tps=gen_tokens / elapsed if elapsed > 0 else 0, + ) + + def _stream_generate( + self, + audio: Union[str, mx.array, np.ndarray], + *, + max_tokens: int = 4096, + temperature: float = 0.0, + top_p: float = 1.0, + top_k: int = 0, + min_p: float = 0.0, + repetition_penalty: Optional[float] = None, + repetition_context_size: int = 100, + prompt: str = None, + prefill_step_size: int = 2048, + verbose: bool = False, + ) -> Generator[StreamingResult, None, None]: + from mlx_lm.generate import generate_step + from mlx_lm.sample_utils import make_logits_processors, make_sampler + + audio_data = self._load_audio(audio) + input_features, num_audio_tokens = self._extract_features(audio_data) + + audio_features = self.get_audio_features(input_features) + mx.eval(audio_features) + + prompt_ids = self._build_prompt(num_audio_tokens, prompt) + inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features) + mx.eval(inputs_embeds) + + prompt_token_count = len(prompt_ids) + + sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k) + logits_processors = make_logits_processors( + repetition_penalty=repetition_penalty, + repetition_context_size=repetition_context_size, + ) + + eos_token_id = self._tokenizer.eos_token_id + gen_tokens = 0 + + for token, _ in generate_step( + prompt=prompt_ids, + input_embeddings=inputs_embeds.squeeze(0), + model=self, + max_tokens=max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prefill_step_size=prefill_step_size, + ): + if token == eos_token_id: + break + gen_tokens += 1 + text = self._tokenizer.decode([token], skip_special_tokens=True) + yield StreamingResult( + text=text, + is_final=False, + start_time=0.0, + end_time=0.0, + prompt_tokens=prompt_token_count, + generation_tokens=gen_tokens, + ) + + yield StreamingResult( + text="", + is_final=True, + start_time=0.0, + end_time=0.0, + prompt_tokens=prompt_token_count, + generation_tokens=gen_tokens, + ) + + def _load_audio(self, audio: Union[str, mx.array, np.ndarray]) -> mx.array: + if isinstance(audio, str): + from ..audio import load_audio + + return load_audio(audio) + elif isinstance(audio, np.ndarray): + return mx.array(audio, dtype=mx.float32) + elif isinstance(audio, mx.array): + return audio + elif isinstance(audio, list): + audio_item = audio[0] + if isinstance(audio_item, str): + from ..audio import load_audio + + return load_audio(audio_item) + return mx.array(np.array(audio_item), dtype=mx.float32) + raise TypeError(f"Unsupported audio type: {type(audio)}") diff --git a/src/granite_speech_plus_mlx/_vendored/loader.py b/src/granite_speech_plus_mlx/_vendored/loader.py new file mode 100644 index 0000000..23f1dec --- /dev/null +++ b/src/granite_speech_plus_mlx/_vendored/loader.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +import glob +import json +from pathlib import Path +from typing import Any + +from huggingface_hub import snapshot_download +import mlx.core as mx +import mlx.nn as nn + +from .granite_speech import Model, ModelConfig + +DEFAULT_ALLOW_PATTERNS = [ + "*.json", + "*.safetensors", + "*.py", + "*.model", + "*.tiktoken", + "*.txt", + "*.jsonl", + "*.yaml", + "*.npz", +] + + +def _is_local_path(path: str) -> bool: + return ( + path.startswith(".") + or path.startswith("/") + or path.startswith("~") + or (len(path) > 1 and path[1] == ":") + ) + + +def get_model_path( + path_or_hf_repo: str | Path, + *, + revision: str | None = None, + force_download: bool = False, + allow_patterns: list[str] | None = None, +) -> Path: + if isinstance(path_or_hf_repo, Path): + path = path_or_hf_repo.expanduser() + if path.exists(): + return path + raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}") + + path = Path(path_or_hf_repo).expanduser() + if path.exists(): + return path + if _is_local_path(path_or_hf_repo): + raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}") + + return Path( + snapshot_download( + path_or_hf_repo, + revision=revision, + allow_patterns=allow_patterns or DEFAULT_ALLOW_PATTERNS, + force_download=force_download, + ) + ) + + +def load_config(model_path: str | Path) -> dict[str, Any]: + model_path = Path(model_path) + config_file = model_path / "config.json" + if not config_file.exists(): + raise FileNotFoundError(f"Config not found at {model_path}") + return json.loads(config_file.read_text(encoding="utf-8")) + + +def load_weights(model_path: Path) -> dict[str, mx.array]: + weight_files = sorted(glob.glob(str(model_path / "*.safetensors"))) + if not weight_files: + weight_files = sorted(glob.glob(str(model_path / "*.npz"))) + if not weight_files: + raise FileNotFoundError( + f"No weight files (safetensors or npz) found in {model_path}" + ) + + weights = {} + for weight_file in weight_files: + weights.update(mx.load(weight_file)) + return weights + + +def apply_quantization( + model: nn.Module, + config: dict[str, Any], + weights: dict[str, mx.array], + model_quant_predicate=None, +) -> None: + quantization = config.get("quantization") or config.get("quantization_config") + if quantization is None: + return + + group_size = quantization.get("group_size", 64) + + def class_predicate(path, module): + if not hasattr(module, "to_quantized"): + return False + if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0: + return False + if model_quant_predicate is not None: + pred = model_quant_predicate(path, module) + if isinstance(pred, dict): + return pred + if not pred: + return False + if path in quantization: + return quantization[path] + return f"{path}.scales" in weights + + nn.quantize( + model, + group_size=group_size, + bits=quantization["bits"], + mode=quantization.get("mode", "affine"), + class_predicate=class_predicate, + ) + + +def load_model( + model_path: str | Path, + *, + lazy: bool = False, + strict: bool = False, + **kwargs: Any, +) -> nn.Module: + path = get_model_path( + model_path, + revision=kwargs.pop("revision", None), + force_download=kwargs.pop("force_download", False), + allow_patterns=kwargs.pop("allow_patterns", None), + ) + config = load_config(path) + model = Model(ModelConfig.from_dict(config)) + weights = load_weights(path) + + if hasattr(model, "sanitize"): + weights = model.sanitize(weights) + + apply_quantization(model, config, weights, model.model_quant_predicate) + model.load_weights(list(weights.items()), strict=strict) + + if not lazy: + mx.eval(model.parameters()) + model.eval() + + if hasattr(Model, "post_load_hook"): + model = Model.post_load_hook(model, path) + return model + diff --git a/src/granite_speech_plus_mlx/chunking.py b/src/granite_speech_plus_mlx/chunking.py new file mode 100644 index 0000000..ec7d6ca --- /dev/null +++ b/src/granite_speech_plus_mlx/chunking.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import dataclass +import re +from typing import Any, Iterable, Iterator + + +@dataclass(frozen=True) +class AudioChunk: + index: int + start: float + end: float + samples: Any + + +def chunk_audio( + audio: Any, + sr: int, + chunk_seconds: float, + overlap_seconds: float = 2.0, +) -> Iterator[AudioChunk]: + if chunk_seconds <= 0: + raise ValueError("chunk_seconds must be positive") + if overlap_seconds < 0: + raise ValueError("overlap_seconds cannot be negative") + + chunk_samples = int(chunk_seconds * sr) + overlap_samples = int(overlap_seconds * sr) + if overlap_samples >= chunk_samples: + raise ValueError("overlap_seconds must be smaller than chunk_seconds") + + step = chunk_samples - overlap_samples + n = len(audio) + pos = 0 + index = 1 + while pos < n: + end = min(pos + chunk_samples, n) + yield AudioChunk(index, pos / sr, end / sr, audio[pos:end]) + if end == n: + break + pos += step + index += 1 + + +def prefix_text(transcripts: Iterable[str], max_chars: int = 800) -> str: + text = "\n".join(t.strip() for t in transcripts if t and t.strip()) + text = re.sub(r"^## \[[^\n]+\]\n", "", text, flags=re.MULTILINE) + text = re.sub(r"\s+", " ", text).strip() + if len(text) <= max_chars: + return text + + tail = text[-max_chars:] + first_space = tail.find(" ") + if first_space > 0: + tail = tail[first_space + 1 :] + return tail.strip() diff --git a/src/granite_speech_plus_mlx/pipeline.py b/src/granite_speech_plus_mlx/pipeline.py new file mode 100644 index 0000000..819df41 --- /dev/null +++ b/src/granite_speech_plus_mlx/pipeline.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +import sys +import time +from typing import Any + +from .chunking import chunk_audio, prefix_text +from .prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES, build_prompt + +DEFAULT_MODEL = "mlx-community/granite-speech-4.1-2b-plus-mlx" + + +@dataclass +class GraniteSpeechPlusPipeline: + model: Any + repo_id: str = DEFAULT_MODEL + chunk_seconds: float = 300.0 + overlap_seconds: float = 2.0 + repetition_penalty: float = 1.2 + max_tokens: int = 4096 + system_prompt: str | None = GRANITE_SYSTEM_PROMPT + verbose: bool = False + + @classmethod + def from_pretrained( + cls, + repo_id: str = DEFAULT_MODEL, + *, + chunk_seconds: float = 300.0, + overlap_seconds: float = 2.0, + repetition_penalty: float = 1.2, + max_tokens: int = 4096, + system_prompt: str | None = GRANITE_SYSTEM_PROMPT, + verbose: bool = False, + **load_kwargs: Any, + ) -> "GraniteSpeechPlusPipeline": + from ._vendored.loader import load_model + + model = load_model(repo_id, **load_kwargs) + return cls( + model=model, + repo_id=repo_id, + chunk_seconds=chunk_seconds, + overlap_seconds=overlap_seconds, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + system_prompt=system_prompt, + verbose=verbose, + ) + + def transcribe(self, audio_path: str | Path, prompt_mode: str = "asr") -> str: + import librosa + import numpy as np + + if prompt_mode not in PROMPT_MODES: + modes = ", ".join(sorted(PROMPT_MODES)) + raise ValueError(f"prompt_mode must be one of: {modes}") + + audio_file = Path(audio_path).expanduser().resolve() + audio, sr = librosa.load(str(audio_file), sr=16000, mono=True) + audio = np.asarray(audio, dtype=np.float32) + duration = len(audio) / sr if sr else 0.0 + chunks = list( + chunk_audio( + audio, + sr, + self.chunk_seconds, + overlap_seconds=self.overlap_seconds, + ) + ) + + if self.verbose: + print( + f"Loaded {audio_file} ({duration:.1f}s, {len(chunks)} chunks)", + file=sys.stderr, + ) + + rendered: list[str] = [] + plain_texts: list[str] = [] + t_start = time.time() + + for chunk in chunks: + prompt = build_prompt( + prompt_mode, + prefix_text=prefix_text(plain_texts), + ) + kwargs: dict[str, Any] = { + "prompt": prompt, + "max_tokens": self.max_tokens, + } + if self.system_prompt: + kwargs["system_prompt"] = self.system_prompt + if self.repetition_penalty and self.repetition_penalty > 1.0: + kwargs["repetition_penalty"] = self.repetition_penalty + + t0 = time.time() + result = self.model.generate(chunk.samples, **kwargs) + text = getattr(result, "text", result) + if not isinstance(text, str): + text = str(text) + text = text.strip() + plain_texts.append(text) + + if len(chunks) > 1: + rendered.append(f"## [{chunk.start:.1f}s - {chunk.end:.1f}s]\n{text}") + else: + rendered.append(text) + + if self.verbose: + elapsed = time.time() - t0 + rtf = (chunk.end - chunk.start) / elapsed if elapsed > 0 else 0.0 + print( + f"[{chunk.index:>3}/{len(chunks)}] " + f"{chunk.start:>6.1f}s-{chunk.end:<6.1f}s " + f"gen={elapsed:>5.1f}s rtf={rtf:>4.1f}x {text[:80]}", + file=sys.stderr, + ) + + if self.verbose: + elapsed = time.time() - t_start + rtf = duration / elapsed if elapsed > 0 else 0.0 + print(f"Total: {elapsed:.1f}s, rtf={rtf:.1f}x", file=sys.stderr) + + return "\n\n".join(rendered).strip() diff --git a/src/granite_speech_plus_mlx/prompts.py b/src/granite_speech_plus_mlx/prompts.py new file mode 100644 index 0000000..a44aa34 --- /dev/null +++ b/src/granite_speech_plus_mlx/prompts.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +GRANITE_SYSTEM_PROMPT = ( + "Knowledge Cutoff Date: April 2024.\n" + "Today's Date: December 19, 2024.\n" + "You are Granite, developed by IBM. You are a helpful AI assistant" +) + +PROMPT_MODES = { + "asr": "can you transcribe the speech into a written format?", + "saa": ( + "Speaker attribution: Transcribe and denote who is speaking by adding " + "[Speaker 1]: and [Speaker 2]: tags before speaker turns." + ), + "ts": ( + "Timestamps: Transcribe the speech. After each word, add a timestamp tag " + "showing the end time in centiseconds, e.g. hello [T:45] world [T:82]" + ), +} + + +def granite3_chat_template( + content: str, + *, + system_prompt: str | None = GRANITE_SYSTEM_PROMPT, + add_generation_prompt: bool = True, +) -> str: + """Build the Granite-3 chat template used when tokenizers omit one.""" + start_role = "<|start_of_role|>" + end_role = "<|end_of_role|>" + end_text = "<|end_of_text|>" + + parts: list[str] = [] + if system_prompt: + parts.append(f"{start_role}system{end_role}{system_prompt}{end_text}\n") + parts.append(f"{start_role}user{end_role}{content}{end_text}\n") + if add_generation_prompt: + parts.append(f"{start_role}assistant{end_role}") + return "".join(parts) + + +def build_prompt( + prompt_mode: str = "asr", + *, + prefix_text: str | None = None, + custom_prompt: str | None = None, +) -> str: + if custom_prompt: + base = custom_prompt + else: + try: + base = PROMPT_MODES[prompt_mode] + except KeyError as exc: + modes = ", ".join(sorted(PROMPT_MODES)) + raise ValueError(f"prompt_mode must be one of: {modes}") from exc + + if not prefix_text: + return base + + return ( + f"{base}\n\n" + "Previous transcript context for continuity only. Do not repeat it:\n" + f"{prefix_text.strip()}" + ) + diff --git a/tests/test_smoke.py b/tests/test_smoke.py new file mode 100644 index 0000000..5f4c152 --- /dev/null +++ b/tests/test_smoke.py @@ -0,0 +1,6 @@ +from granite_speech_plus_mlx import GraniteSpeechPlusPipeline + + +def test_pipeline_symbol_exists(): + assert GraniteSpeechPlusPipeline is not None +