Initial Granite Speech Plus MLX package
This commit is contained in:
30
.gitignore
vendored
Normal file
30
.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
.Python
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
env/
|
||||
|
||||
build/
|
||||
dist/
|
||||
*.egg-info/
|
||||
.eggs/
|
||||
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
.DS_Store
|
||||
.env
|
||||
.env.*
|
||||
|
||||
*.log
|
||||
*.tmp
|
||||
transcripts/
|
||||
bench/
|
||||
|
||||
22
LICENSE
Normal file
22
LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 Olivier Dupont
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
42
README.md
Normal file
42
README.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# granite-speech-4.1-2b-plus-mlx
|
||||
|
||||
Standalone Python package for the MLX port of IBM Granite Speech 4.1-2b-plus.
The default model is
[`mlx-community/granite-speech-4.1-2b-plus-mlx`](https://huggingface.co/mlx-community/granite-speech-4.1-2b-plus-mlx).
|
||||
|
||||
## Quickstart
|
||||
|
||||
```bash
|
||||
uv add "granite-speech-4.1-2b-plus-mlx @ git+https://gitea.tavportal.com/olivier/granite-speech-4.1-2b-plus-mlx.git"
|
||||
python -c "from granite_speech_plus_mlx import GraniteSpeechPlusPipeline as P; p=P.from_pretrained(); print(p.transcribe('audio.wav'))"
|
||||
python scripts/transcribe.py audio.wav --prompt-mode asr --output transcript.txt
|
||||
python scripts/transcribe.py meeting.wav --prompt-mode saa
|
||||
python scripts/benchmark.py audio.wav --results bench
|
||||
```
|
||||
|
||||
## Prompt Modes
|
||||
|
||||
- `asr`: standard transcription.
|
||||
- `saa`: speaker-attributed ASR with `[Speaker N]:` turn labels.
|
||||
- `ts`: word-level timestamp tags like `word [T:45]`.
|
||||
|
||||
See [docs/prompt-modes.md](docs/prompt-modes.md) for examples.
|
||||
|
||||
## Benchmark Hints
|
||||
|
||||
Granite Speech 4.1 allocates substantial encoder memory for long audio. When
running `scripts/transcribe.py`, start with
`--chunk-seconds 300 --repetition-penalty 1.2` for ASR and reduce the chunk
length to 180 or 60 seconds if memory is tight (`scripts/benchmark.py` sweeps
these chunk lengths and penalties itself). Timestamp mode (`ts`) often needs a
larger `--max-tokens` budget because every word carries a timestamp tag.
|
||||
|
||||
## Provenance
|
||||
|
||||
This package was extracted from the local `MLX_CONVERTOR` project, including
the Granite Speech patch bundle at
`external/patches/granite-speech-idempotent-sanitize.patch`. The vendored
Granite implementation is based on `mlx-audio` commit
`f7c11556eda88731be5cc75ddbdf4a4cb9eeaafc` plus that local patch.
|
||||
|
||||
Package code is MIT licensed. Model weights remain under the IBM Granite model
license; review the model card and license terms before redistribution or use.
|
||||
|
||||
37
docs/prompt-modes.md
Normal file
37
docs/prompt-modes.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Prompt Modes
|
||||
|
||||
Granite Speech Plus supports three prompt modes in this package.
|
||||
|
||||
## `asr`
|
||||
|
||||
Standard speech transcription.
|
||||
|
||||
```python
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
|
||||
pipe = GraniteSpeechPlusPipeline.from_pretrained()
|
||||
text = pipe.transcribe("audio.wav", prompt_mode="asr")
|
||||
```
|
||||
|
||||
## `saa`
|
||||
|
||||
Speaker-attributed ASR. The prompt asks the model to add speaker turn labels
such as `[Speaker 1]:` and `[Speaker 2]:`.
|
||||
|
||||
```python
|
||||
text = pipe.transcribe("meeting.wav", prompt_mode="saa")
|
||||
```
|
||||
|
||||
## `ts`
|
||||
|
||||
Word-level timestamps. The prompt asks the model to append centisecond tags
after words, for example `hello [T:45] world [T:82]`.
|
||||
|
||||
```python
|
||||
text = pipe.transcribe("clip.wav", prompt_mode="ts")
|
||||
```
|
||||
|
||||
For long audio, the pipeline chunks the waveform and feeds a short prefix of
the previous transcript into later chunks for continuity. The prefix is
context only; the model is instructed not to repeat it.
|
||||
|
||||
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "granite-speech-4.1-2b-plus-mlx"
|
||||
version = "0.1.0"
|
||||
description = "Standalone MLX pipeline for the Granite Speech 4.1-2b-plus port."
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
license = "MIT"
|
||||
authors = [
|
||||
{ name = "Olivier Dupont" }
|
||||
]
|
||||
dependencies = [
|
||||
"mlx>=0.22.0",
|
||||
"mlx-lm>=0.19.0",
|
||||
"numpy>=1.26",
|
||||
"transformers>=4.45",
|
||||
"huggingface-hub>=0.24",
|
||||
"soundfile>=0.12",
|
||||
"librosa>=0.10",
|
||||
]
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/granite_speech_plus_mlx"]
|
||||
|
||||
111
scripts/benchmark.py
Executable file
111
scripts/benchmark.py
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||
from granite_speech_plus_mlx.prompts import PROMPT_MODES
|
||||
|
||||
# Benchmark sweep: (chunk length in seconds, repetition penalty) pairs that
# main() runs in this order.
GRID = [
    (60, 1.0), (60, 1.2),
    (180, 1.0), (180, 1.2),
    (300, 1.0), (300, 1.2), (300, 1.4),
]
|
||||
|
||||
# Phrases counted as likely hallucinations. They are stored lowercase and
# without accents; analyze() folds the transcript the same way before counting.
HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde")


def _fold_for_matching(text: str) -> str:
    """Lowercase *text*, strip combining accents and normalize apostrophes."""
    import unicodedata

    decomposed = unicodedata.normalize("NFKD", text.lower())
    unaccented = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Transcripts may use the typographic apostrophe; the markers use "'".
    return unaccented.replace("\u2019", "'")


def analyze(text: str) -> dict:
    """Compute simple repetition / hallucination statistics for a transcript.

    Args:
        text: Transcript text to inspect.

    Returns:
        A dict with:
        ``n_words`` - number of whitespace-separated words;
        ``max_trigram_count`` / ``max_trigram_text`` - count and text of the
        most frequent lowercase word trigram (``0`` / ``""`` when the text has
        fewer than three words);
        ``halluc`` - mapping of each entry of ``HALLUCINATION_MARKERS`` to its
        number of occurrences in the accent-folded, lowercased text.
    """
    words = text.split()
    lower_words = text.lower().split()
    trigrams = Counter(
        " ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2)
    )
    # Only the single most frequent trigram is reported.
    top = trigrams.most_common(1)
    # Fold accents so the ASCII marker "regarde" also matches "regardé".
    folded = _fold_for_matching(text)
    return {
        "n_words": len(words),
        "max_trigram_count": top[0][1] if top else 0,
        "max_trigram_text": top[0][0] if top else "",
        "halluc": {m: folded.count(m) for m in HALLUCINATION_MARKERS},
    }
|
||||
|
||||
|
||||
def main() -> int:
    """Run the benchmark grid over one audio file and print a Markdown table.

    For every ``(chunk_seconds, repetition_penalty)`` pair in ``GRID`` the
    transcript is written to ``<results>/chunk<N>_rp<P>.txt``. A pair whose
    file already exists is not re-run: the stored transcript is analyzed and
    its wall time is reported as ``nan``.

    Returns:
        Process exit code (always ``0``).
    """
    parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.")
    parser.add_argument("audio")
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--results", default="bench")
    parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
    parser.add_argument("--overlap-seconds", type=float, default=2.0)
    parser.add_argument("--max-tokens", type=int, default=4096)
    args = parser.parse_args()

    results_dir = Path(args.results)
    results_dir.mkdir(parents=True, exist_ok=True)
    pipe = GraniteSpeechPlusPipeline.from_pretrained(
        args.model,
        overlap_seconds=args.overlap_seconds,
        max_tokens=args.max_tokens,
        verbose=True,
    )

    rows = []
    for chunk_seconds, repetition_penalty in GRID:
        # NOTE(review): the cache file name encodes only chunk size and
        # penalty, so a results directory reused with a different audio file
        # or --prompt-mode silently reports stale transcripts.
        out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt"
        pipe.chunk_seconds = float(chunk_seconds)
        pipe.repetition_penalty = repetition_penalty

        if out.exists():
            print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr)
            elapsed = float("nan")
            text = out.read_text(encoding="utf-8")
        else:
            print(
                f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}",
                file=sys.stderr,
            )
            # perf_counter() is monotonic, so the measured interval is not
            # distorted by system clock adjustments during a long run.
            t0 = time.perf_counter()
            text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
            elapsed = time.perf_counter() - t0
            out.write_text(text + "\n", encoding="utf-8")

        rows.append(
            {
                "chunk": chunk_seconds,
                "rp": repetition_penalty,
                "elapsed": elapsed,
                **analyze(text),
            }
        )

    print()
    print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |")
    print("|---:|---:|---:|---:|:---|:---|")
    for row in rows:
        # Show only markers that occurred, labelled by their first word.
        halluc = ", ".join(
            f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value
        ) or "-"
        trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)"
        # NaN is the only float that differs from itself: marks cached rows.
        wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}"
        print(
            f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} "
            f"| {trigram} | {halluc} |"
        )
    print()
    print(f"Per-config transcripts in: {results_dir}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
|
||||
47
scripts/transcribe.py
Executable file
47
scripts/transcribe.py
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||
from granite_speech_plus_mlx.prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES
|
||||
|
||||
|
||||
def main() -> int:
    """Transcribe one audio file from the command line.

    The transcript is written to ``--output`` (UTF-8, with a trailing newline)
    when that option is given, and printed to stdout otherwise.

    Returns:
        Process exit code (always ``0``).
    """
    cli = argparse.ArgumentParser(description="Transcribe audio with Granite Speech Plus MLX.")
    cli.add_argument("audio")
    cli.add_argument("--model", default=DEFAULT_MODEL)
    cli.add_argument("--output", default=None)
    cli.add_argument("--chunk-seconds", type=float, default=300.0)
    cli.add_argument("--overlap-seconds", type=float, default=2.0)
    cli.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
    cli.add_argument("--repetition-penalty", type=float, default=1.2)
    cli.add_argument("--max-tokens", type=int, default=4096)
    cli.add_argument("--system-prompt", default=GRANITE_SYSTEM_PROMPT)
    cli.add_argument("--verbose", action="store_true")
    opts = cli.parse_args()

    pipeline = GraniteSpeechPlusPipeline.from_pretrained(
        opts.model,
        chunk_seconds=opts.chunk_seconds,
        overlap_seconds=opts.overlap_seconds,
        repetition_penalty=opts.repetition_penalty,
        max_tokens=opts.max_tokens,
        # An empty --system-prompt is forwarded as None.
        system_prompt=opts.system_prompt or None,
        verbose=opts.verbose,
    )
    transcript = pipeline.transcribe(opts.audio, prompt_mode=opts.prompt_mode)

    if not opts.output:
        print(transcript)
        return 0

    Path(opts.output).write_text(transcript + "\n", encoding="utf-8")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
|
||||
66
scripts/upload_to_hf.py
Executable file
66
scripts/upload_to_hf.py
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
# Local Hugging Face hub cache entry that main() searches for weights.
SOURCE_CACHE = Path.home().joinpath(
    ".cache/huggingface/hub/models--ibm-granite--granite-speech-4.1-2b-plus"
)
# Hub repository that receives the uploaded folder.
DEST_REPO = "mlx-community/granite-speech-4.1-2b-plus-mlx"
|
||||
|
||||
|
||||
def find_weights_dir(root: Path) -> Path | None:
|
||||
if not root.exists():
|
||||
return None
|
||||
if list(root.glob("*.safetensors")) or (root / "config.json").exists():
|
||||
return root
|
||||
snapshots = root / "snapshots"
|
||||
if snapshots.exists():
|
||||
candidates = [
|
||||
path
|
||||
for path in snapshots.iterdir()
|
||||
if path.is_dir() and (list(path.glob("*.safetensors")) or (path / "config.json").exists())
|
||||
]
|
||||
if candidates:
|
||||
return sorted(candidates, key=lambda p: p.stat().st_mtime)[-1]
|
||||
return None
|
||||
|
||||
|
||||
def print_manual_commands() -> None:
    """Print the missing-weights notice and the commands that create them."""
    message_lines = (
        f"MLX weights not found at {SOURCE_CACHE}",
        "Create them first with:",
        "mlxconv ibm-granite/granite-speech-4.1-2b-plus",
        "mlxconv ibm-granite/granite-speech-4.1-2b-plus --dtype q4_k_4",
    )
    for line in message_lines:
        print(line)
|
||||
|
||||
|
||||
def main() -> int:
    """Upload the locally found weights folder to ``DEST_REPO``.

    Returns:
        ``0`` after a successful upload, ``1`` when no weights directory is
        found under ``SOURCE_CACHE``, ``2`` when ``HF_TOKEN`` is not set.
    """
    if (weights_dir := find_weights_dir(SOURCE_CACHE)) is None:
        print_manual_commands()
        return 1

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        print("HF_TOKEN is required to upload.", file=sys.stderr)
        return 2

    hub = HfApi(token=hf_token)
    # exist_ok=True: an already existing repository is not an error.
    hub.create_repo(DEST_REPO, repo_type="model", exist_ok=True)
    hub.upload_folder(
        repo_id=DEST_REPO,
        repo_type="model",
        folder_path=str(weights_dir),
        commit_message="Upload Granite Speech 4.1-2b-plus MLX weights",
    )
    print(f"Uploaded {weights_dir} to {DEST_REPO}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
|
||||
4
src/granite_speech_plus_mlx/__init__.py
Normal file
4
src/granite_speech_plus_mlx/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .pipeline import GraniteSpeechPlusPipeline
|
||||
|
||||
__all__ = ["GraniteSpeechPlusPipeline"]
|
||||
|
||||
1
src/granite_speech_plus_mlx/_vendored/__init__.py
Normal file
1
src/granite_speech_plus_mlx/_vendored/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
15
src/granite_speech_plus_mlx/_vendored/audio.py
Normal file
15
src/granite_speech_plus_mlx/_vendored/audio.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
|
||||
# Default sample rate requested from librosa when loading audio.
SAMPLE_RATE = 16000


def load_audio(file: str | Path, sr: int = SAMPLE_RATE):
    """Load an audio file as a mono float32 ``mx.array``.

    Args:
        file: Path of the audio file; converted to ``str`` for librosa.
        sr: Sample rate passed to ``librosa.load``.

    Returns:
        The samples returned by ``librosa.load(..., mono=True)`` wrapped in an
        ``mx.array`` with dtype ``mx.float32``.
    """
    # mlx is imported inside the function, as in the original module.
    import mlx.core as mx

    samples = librosa.load(str(file), sr=sr, mono=True)[0]
    return mx.array(samples, dtype=mx.float32)
|
||||
|
||||
18
src/granite_speech_plus_mlx/_vendored/base.py
Normal file
18
src/granite_speech_plus_mlx/_vendored/base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTOutput:
|
||||
text: str
|
||||
segments: List[dict] | None = None
|
||||
language: str | None = None
|
||||
prompt_tokens: int = 0
|
||||
generation_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
prompt_tps: float = 0.0
|
||||
generation_tps: float = 0.0
|
||||
total_time: float = 0.0
|
||||
|
||||
161
src/granite_speech_plus_mlx/_vendored/dsp.py
Normal file
161
src/granite_speech_plus_mlx/_vendored/dsp.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def hanning(size: int, periodic: bool = False):
    """Return a Hann window of ``size`` samples as an ``mx.array``.

    Args:
        size: Number of window samples.
        periodic: When true the cosine period is ``size``; otherwise it is
            ``size - 1`` (symmetric window).

    Results are memoized per ``(size, periodic)`` by ``lru_cache``.
    """
    if size == 1 and not periodic:
        # The symmetric formula divides by size - 1 == 0 here; a one-sample
        # window is [1.0], which is what numpy.hanning(1) returns.
        return mx.array([1.0])
    denom = size if periodic else size - 1
    return mx.array(
        [0.5 * (1 - math.cos(2 * math.pi * n / denom)) for n in range(size)]
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def hamming(size: int, periodic: bool = False):
    """Return a Hamming window of ``size`` samples as an ``mx.array``.

    Args:
        size: Number of window samples.
        periodic: When true the cosine period is ``size``; otherwise it is
            ``size - 1`` (symmetric window).

    Results are memoized per ``(size, periodic)`` by ``lru_cache``.
    """
    if size == 1 and not periodic:
        # The symmetric formula divides by size - 1 == 0 here; a one-sample
        # window is [1.0], which is what numpy.hamming(1) returns.
        return mx.array([1.0])
    denom = size if periodic else size - 1
    return mx.array(
        [0.54 - 0.46 * math.cos(2 * math.pi * n / denom) for n in range(size)]
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def blackman(size: int, periodic: bool = False):
    """Return a Blackman window of ``size`` samples as an ``mx.array``.

    Args:
        size: Number of window samples.
        periodic: When true the cosine period is ``size``; otherwise it is
            ``size - 1`` (symmetric window).

    Results are memoized per ``(size, periodic)`` by ``lru_cache``.
    """
    if size == 1 and not periodic:
        # The symmetric formula divides by size - 1 == 0 here; a one-sample
        # window is [1.0], which is what numpy.blackman(1) returns.
        return mx.array([1.0])
    denom = size if periodic else size - 1
    return mx.array(
        [
            0.42
            - 0.5 * math.cos(2 * math.pi * n / denom)
            + 0.08 * math.cos(4 * math.pi * n / denom)
            for n in range(size)
        ]
    )
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def bartlett(size: int, periodic: bool = False):
    """Return a Bartlett (triangular) window of ``size`` samples as an ``mx.array``.

    Args:
        size: Number of window samples.
        periodic: When true the triangle spans ``size`` samples; otherwise it
            spans ``size - 1`` (symmetric window).

    Results are memoized per ``(size, periodic)`` by ``lru_cache``.
    """
    if size == 1 and not periodic:
        # The symmetric formula divides by size - 1 == 0 here; a one-sample
        # window is [1.0], which is what numpy.bartlett(1) returns.
        return mx.array([1.0])
    denom = size if periodic else size - 1
    return mx.array([1 - 2 * abs(n - denom / 2) / denom for n in range(size)])
|
||||
|
||||
|
||||
# Window names accepted by stft(), mapped to their generator functions.
STR_TO_WINDOW_FN = {
    "hann": hanning,
    "hanning": hanning,
    "hamming": hamming,
    "blackman": blackman,
    "bartlett": bartlett,
}


def stft(
    x,
    n_fft: int = 800,
    hop_length: int | None = None,
    win_length: int | None = None,
    window: mx.array | str = "hann",
    center: bool = True,
    pad_mode: str = "reflect",
):
    """Compute a short-time Fourier transform with ``mx.fft.rfft``.

    Args:
        x: Signal to transform; indexed and padded along its first axis.
        n_fft: Frame length handed to the FFT.
        hop_length: Step between frame starts; defaults to ``n_fft // 4``.
        win_length: Length of a named window; defaults to ``n_fft``.
        window: Window name (a key of ``STR_TO_WINDOW_FN``, case-insensitive)
            or a ready-made window array.
        center: When true, pad ``n_fft // 2`` samples on both ends first.
        pad_mode: ``"reflect"`` or ``"constant"``; used only when centering.

    Returns:
        ``mx.fft.rfft`` of the windowed frames, one row per frame.

    Raises:
        ValueError: Unknown window name, invalid ``pad_mode``, or an input too
            short to yield a single frame.
    """

    def _pad(signal, padding: int, mode: str = "reflect"):
        # Extend the signal by `padding` samples on each side.
        if mode == "constant":
            return mx.pad(signal, [(padding, padding)])
        if mode == "reflect":
            # Mirror around the first / last sample without repeating it.
            head = signal[1 : padding + 1][::-1]
            tail = signal[-(padding + 1) : -1][::-1]
            return mx.concatenate([head, signal, tail])
        raise ValueError(f"Invalid pad_mode {mode}")

    hop = n_fft // 4 if hop_length is None else hop_length
    win_len = n_fft if win_length is None else win_length

    if not isinstance(window, str):
        taper = window
    else:
        make_window = STR_TO_WINDOW_FN.get(window.lower())
        if make_window is None:
            raise ValueError(f"Unknown window function: {window}")
        taper = make_window(win_len)

    # A window shorter than the frame is zero-padded on the right to n_fft.
    shortfall = n_fft - taper.shape[0]
    if shortfall > 0:
        taper = mx.concatenate([taper, mx.zeros((shortfall,))], axis=0)

    if center:
        x = _pad(x, n_fft // 2, pad_mode)

    num_frames = 1 + (x.shape[0] - n_fft) // hop
    if num_frames <= 0:
        raise ValueError(
            f"Input is too short for n_fft={n_fft}, hop_length={hop}, "
            f"center={center}."
        )

    # Overlapping frames as a strided view: row i starts at sample i * hop.
    frames = mx.as_strided(x, shape=(num_frames, n_fft), strides=(hop, 1))
    return mx.fft.rfft(frames * taper)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
def mel_filters(
    sample_rate: int,
    n_fft: int,
    n_mels: int,
    f_min: float = 0,
    f_max: Optional[float] = None,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> mx.array:
    """Build a triangular mel filterbank.

    Args:
        sample_rate: Sample rate of the signal.
        n_fft: FFT size; the bank covers ``n_fft // 2 + 1`` frequency bins.
        n_mels: Number of mel filters.
        f_min: Lowest filter edge in Hz.
        f_max: Highest filter edge in Hz; a falsy value means
            ``sample_rate / 2``.
        norm: ``"slaney"`` scales each filter by ``2 / (upper - lower edge)``;
            any other value leaves the bank unscaled.
        mel_scale: ``"htk"`` selects the HTK formula; any other value selects
            the piecewise linear/log formula.

    Returns:
        Array of shape ``(n_mels, n_fft // 2 + 1)``. Results are memoized per
        argument tuple by ``lru_cache``.
    """
    # Constants of the non-HTK scale: linear below 1 kHz, logarithmic above.
    f_sp = 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = min_log_hz / f_sp
    logstep = math.log(6.4) / 27.0

    def hz_to_mel(freq, scale="htk"):
        # Scalar Hz -> mel conversion.
        if scale == "htk":
            return 2595.0 * math.log10(1.0 + freq / 700.0)
        if freq >= min_log_hz:
            return min_log_mel + math.log(freq / min_log_hz) / logstep
        return freq / f_sp

    def mel_to_hz(mels, scale="htk"):
        # Array mel -> Hz conversion.
        if scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
        linear_part = f_sp * mels
        log_part = min_log_hz * mx.exp(logstep * (mels - min_log_mel))
        return mx.where(mels >= min_log_mel, log_part, linear_part)

    f_max = f_max or sample_rate / 2
    n_freqs = n_fft // 2 + 1
    all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)

    # n_mels + 2 band edges, equally spaced on the mel axis.
    mel_low = hz_to_mel(f_min, mel_scale)
    mel_high = hz_to_mel(f_max, mel_scale)
    mel_edges = mx.linspace(mel_low, mel_high, n_mels + 2)
    hz_edges = mel_to_hz(mel_edges, mel_scale)

    edge_gaps = hz_edges[1:] - hz_edges[:-1]
    # offsets[i, j] = hz_edges[j] - all_freqs[i]
    offsets = mx.expand_dims(hz_edges, 0) - mx.expand_dims(all_freqs, 1)
    rising = (-offsets[:, :-2]) / edge_gaps[:-1]
    falling = offsets[:, 2:] / edge_gaps[1:]
    filterbank = mx.maximum(mx.zeros_like(rising), mx.minimum(rising, falling))

    if norm == "slaney":
        band_scale = 2.0 / (hz_edges[2 : n_mels + 2] - hz_edges[:n_mels])
        filterbank = filterbank * mx.expand_dims(band_scale, 0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    return filterbank.moveaxis(0, 1)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .config import EncoderConfig, ModelConfig, ProjectorConfig, TextConfig
|
||||
from .granite_speech import Model
|
||||
|
||||
|
||||
@dataclass
class GraniteSpeechPlusModelConfig(ModelConfig):
    """``ModelConfig`` whose ``model_type`` defaults to ``"granite_speech_plus"``."""

    model_type: str = "granite_speech_plus"
|
||||
|
||||
|
||||
# Hints for recognising this model family, grouped by where the hint applies:
# top-level config keys, architecture names and path fragments.
# NOTE(review): the consumer of this table is not visible here; confirm how
# each group is matched.
DETECTION_HINTS = {
    "config_keys": {"encoder_config", "projector_config", "audio_token_index"},
    "architectures": {
        "GraniteSpeechForConditionalGeneration",
        "GraniteSpeechPlusForConditionalGeneration",
    },
    "path_patterns": {
        "granite_speech_plus",
        "granitespeechplus",
        "granite-speech-4.1-2b-plus",
    },
}
|
||||
|
||||
# Public names re-exported by this package.
__all__ = [
    "EncoderConfig", "ProjectorConfig", "TextConfig", "ModelConfig",
    "GraniteSpeechPlusModelConfig",
    "Model",
    "DETECTION_HINTS",
]
|
||||
|
||||
128
src/granite_speech_plus_mlx/_vendored/granite_speech/config.py
Normal file
128
src/granite_speech_plus_mlx/_vendored/granite_speech/config.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class EncoderConfig:
    """Hyper-parameters of the speech encoder, with their default values."""

    input_dim: int = 160
    num_layers: int = 10
    hidden_dim: int = 1024
    feedforward_mult: int = 4
    num_heads: int = 8
    dim_head: int = 128
    output_dim: int = 42
    context_size: int = 200
    max_pos_emb: int = 512
    dropout: float = 0.1
    conv_kernel_size: int = 15
    conv_expansion_factor: int = 2
    # Plus variant: indices of intermediate encoder layers whose hidden state
    # is concatenated with the final-layer hidden state along the channel
    # axis before the projector. None / empty keeps the base behavior.
    cat_hidden_layers: Optional[List[int]] = None
    model_type: str = "granite_speech_encoder"

    @classmethod
    def from_dict(cls, params):
        """Build a config from *params*, dropping keys that are not fields."""
        accepted = inspect.signature(cls).parameters
        known = {key: value for key, value in params.items() if key in accepted}
        return cls(**known)
|
||||
|
||||
|
||||
@dataclass
class ProjectorConfig:
    """Hyper-parameters of the projector, with their default values."""

    hidden_size: int = 1024
    num_hidden_layers: int = 2
    num_attention_heads: int = 16
    intermediate_size: int = 4096
    hidden_act: str = "gelu"
    layer_norm_eps: float = 1e-12
    encoder_hidden_size: int = 1024
    cross_attention_frequency: int = 1
    model_type: str = "blip_2_qformer"

    @classmethod
    def from_dict(cls, params):
        """Build a config from *params*, dropping keys that are not fields."""
        accepted = inspect.signature(cls).parameters
        known = {key: value for key, value in params.items() if key in accepted}
        return cls(**known)
|
||||
|
||||
|
||||
@dataclass
class TextConfig:
    """Hyper-parameters of the text (language model) part, with defaults."""

    model_type: str = "granite"
    vocab_size: int = 100353
    hidden_size: int = 2048
    intermediate_size: int = 4096
    num_hidden_layers: int = 40
    num_attention_heads: int = 16
    num_key_value_heads: int = 4
    hidden_act: str = "silu"
    max_position_embeddings: int = 4096
    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    rope_scaling: Optional[Dict] = None
    attention_bias: bool = False
    mlp_bias: bool = False
    attention_multiplier: float = 0.0078125
    embedding_multiplier: float = 12.0
    residual_multiplier: float = 0.22
    logits_scaling: float = 8.0
    tie_word_embeddings: bool = False

    @classmethod
    def from_dict(cls, params):
        """Build a config from *params*, dropping keys that are not fields."""
        accepted = inspect.signature(cls).parameters
        known = {key: value for key, value in params.items() if key in accepted}
        return cls(**known)
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Top-level model configuration bundling the three sub-configs.

    Each sub-config may be supplied as an instance, as a plain dict (converted
    with the sub-config's ``from_dict``) or left as ``None`` (replaced by the
    sub-config's defaults).
    """

    model_type: str = "granite_speech"
    encoder_config: EncoderConfig = None
    projector_config: ProjectorConfig = None
    text_config: TextConfig = None
    audio_token_index: int = 100352
    downsample_rate: int = 5
    window_size: int = 15
    has_lora_adapter: bool = False

    def __post_init__(self):
        """Normalize the sub-config fields to config instances."""
        sub_configs = (
            ("encoder_config", EncoderConfig),
            ("projector_config", ProjectorConfig),
            ("text_config", TextConfig),
        )
        for attr, config_cls in sub_configs:
            current = getattr(self, attr)
            if isinstance(current, dict):
                setattr(self, attr, config_cls.from_dict(current))
            elif current is None:
                setattr(self, attr, config_cls())

    @classmethod
    def from_dict(cls, params):
        """Build a config from *params*, dropping keys that are not fields."""
        accepted = inspect.signature(cls).parameters
        known = {key: value for key, value in params.items() if key in accepted}
        return cls(**known)
|
||||
@@ -0,0 +1,850 @@
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx_lm.models.base import create_attention_mask
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.granite import Model as GraniteLM
|
||||
from mlx_lm.models.granite import ModelArgs as GraniteModelArgs
|
||||
|
||||
from ..base import STTOutput
|
||||
|
||||
from .config import EncoderConfig, ModelConfig, ProjectorConfig
|
||||
|
||||
# Two-letter language code -> English language name.
LANGUAGE_CODES = {
    "en": "English", "fr": "French", "de": "German",
    "es": "Spanish", "pt": "Portuguese", "ja": "Japanese",
}
|
||||
|
||||
|
||||
@dataclass
class StreamingResult:
    """One result record of a streaming transcription.

    NOTE(review): the producer of these records is not visible here; the units
    of ``start_time`` / ``end_time`` should be confirmed against it.
    """

    # Required fields.
    text: str
    is_final: bool
    start_time: float
    end_time: float
    # Optional fields.
    language: str = "en"
    prompt_tokens: int = 0
    generation_tokens: int = 0
|
||||
|
||||
|
||||
class BatchNorm1d(nn.Module):
    """Batch normalization that always applies the stored running statistics.

    No batch statistics are computed; ``running_mean`` and ``running_var`` are
    used as they are.
    """

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        shape = (num_features,)
        self.weight = mx.ones(shape)
        self.bias = mx.zeros(shape)
        self.running_mean = mx.zeros(shape)
        self.running_var = mx.ones(shape)
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize *x* with the running statistics, then scale and shift."""
        centered = x - self.running_mean
        std = mx.sqrt(self.running_var + self.eps)
        return centered / std * self.weight + self.bias
|
||||
|
||||
|
||||
class ConformerFeedForward(nn.Module):
    """Pre-normalized feed-forward module: LayerNorm, Linear, SiLU, Linear."""

    def __init__(self, config: EncoderConfig):
        super().__init__()
        model_dim = config.hidden_dim
        self.pre_norm = nn.LayerNorm(model_dim)
        self.up_proj = nn.Linear(model_dim, model_dim * config.feedforward_mult)
        self.down_proj = nn.Linear(model_dim * config.feedforward_mult, model_dim)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the feed-forward output for *x* (no residual is added here)."""
        expanded = self.up_proj(self.pre_norm(x))
        return self.down_proj(nn.silu(expanded))
|
||||
|
||||
|
||||
class ConformerAttention(nn.Module):
    """Block-local multi-head attention with a learned relative-position bias.

    The sequence is cut into blocks of ``context_size`` frames and attention is
    computed independently inside each block.
    """

    def __init__(self, config: EncoderConfig):
        super().__init__()
        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = config.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)

    def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
        """Attend within blocks of ``context_size`` frames.

        Args:
            x: Input of shape ``(batch, frames, hidden_dim)``.
            attention_dists: Integer indices into ``rel_pos_emb``; rank 2
                (presumably ``(context_size, context_size)``, to be confirmed
                against the caller).

        Returns:
            Array with the same batch and frame dimensions as *x*.
        """
        x = self.pre_norm(x)
        batch, seq_len, _ = x.shape
        ctx = self.context_size
        heads = self.num_heads

        num_blocks = math.ceil(seq_len / ctx)
        remainder = seq_len % ctx
        if remainder > 0:
            # Zero-pad the frame axis to a whole number of blocks.
            x = mx.pad(x, [(0, 0), (0, ctx - remainder), (0, 0)])

        def to_blocks(t: mx.array) -> mx.array:
            # (B, frames, inner) -> (B, blocks, heads, ctx, dim_head)
            return t.reshape(batch, num_blocks, ctx, heads, -1).transpose(0, 1, 3, 2, 4)

        queries = self.to_q(x)
        keys, values = mx.split(self.to_kv(x), 2, axis=-1)
        q = to_blocks(queries)
        k = to_blocks(keys)
        v = to_blocks(values)

        # Relative-position bias: dot product of each query with the embedding
        # of its distance to every key position of the block.
        rel_pos_emb = self.rel_pos_emb(attention_dists)
        pos_attn = (
            mx.sum(
                q[:, :, :, :, None, :] * rel_pos_emb[None, None, None, :, :, :],
                axis=-1,
            )
            * self.scale
        )

        if remainder > 0:
            # In the last block only the first `remainder` frames are real;
            # every query/key pair touching padding gets the most negative
            # finite value as bias.
            positions = mx.arange(ctx)
            row_valid = positions[:, None] < remainder
            col_valid = positions[None, :] < remainder
            mask = ~(row_valid & col_valid)
            mask_value = mx.array(mx.finfo(pos_attn.dtype).min)
            masked_last = mx.where(
                mask[None, None, None], mask_value, pos_attn[:, -1:, :, :, :]
            )
            pos_attn = mx.concatenate(
                [pos_attn[:, :-1, :, :, :], masked_last], axis=1
            )

        scores = (q @ k.transpose(0, 1, 2, 4, 3)) * self.scale + pos_attn
        probs = mx.softmax(scores, axis=-1)

        attended = probs @ v
        # (B, blocks, heads, ctx, dim_head) -> (B, blocks * ctx, heads * dim_head)
        attended = attended.transpose(0, 1, 3, 2, 4)
        attended = attended.reshape(batch, -1, heads * self.dim_head)
        # Drop the frames that were added as padding.
        attended = attended[:, :seq_len, :]
        return self.to_out(attended)
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
    """Grouped 1-D convolution (``groups=chan_in``) with explicit padding.

    The input is padded on axis 1 by ``kernel_size // 2`` on the left and by
    the same amount on the right, minus one when ``kernel_size`` is even.
    """

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
        super().__init__()
        left = kernel_size // 2
        # (kernel_size + 1) % 2 is 1 for even kernels and 0 for odd ones.
        right = left - (kernel_size + 1) % 2
        self.padding = (left, right)
        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Pad axis 1 of *x* and apply the convolution."""
        left, right = self.padding
        padded = mx.pad(x, [(0, 0), (left, right), (0, 0)])
        return self.conv(padded)
|
||||
|
||||
|
||||
class ConformerConvModule(nn.Module):
    """Convolution module of a conformer block.

    LayerNorm, a kernel-size-1 convolution producing ``2 * inner_dim``
    channels, a sigmoid gate between the two channel halves, a depthwise
    convolution, batch normalization with SiLU, and a kernel-size-1
    convolution back to ``hidden_dim``.
    """

    def __init__(self, config: EncoderConfig):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.depth_conv = DepthWiseConv1d(inner_dim, inner_dim, config.conv_kernel_size)
        self.batch_norm = BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)

    def __call__(self, x: mx.array) -> mx.array:
        """Return the module output for *x* (no residual is added here)."""
        expanded = self.up_conv(self.norm(x))
        # Gate: first channel half multiplied by the sigmoid of the second.
        content, gate = mx.split(expanded, 2, axis=-1)
        gated = content * mx.sigmoid(gate)
        mixed = self.depth_conv(gated)
        activated = nn.silu(self.batch_norm(mixed))
        return self.down_conv(activated)
|
||||
|
||||
|
||||
class ConformerBlock(nn.Module):
    """One Conformer layer: half-step FFN, attention, convolution, half-step FFN.

    Every sub-module is applied residually; the two feed-forward modules are
    scaled by 0.5 before being added back. A final LayerNorm closes the block.
    """

    def __init__(self, config: EncoderConfig):
        super().__init__()
        self.ff1 = ConformerFeedForward(config)
        self.attn = ConformerAttention(config)
        self.conv = ConformerConvModule(config)
        self.ff2 = ConformerFeedForward(config)
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
        """Apply the block to ``x``.

        ``attention_dists`` is forwarded unchanged to the attention module.
        """
        after_ff1 = 0.5 * self.ff1(x) + x
        after_attn = self.attn(after_ff1, attention_dists) + after_ff1
        after_conv = self.conv(after_attn) + after_attn
        after_ff2 = 0.5 * self.ff2(after_conv) + after_conv
        return self.post_norm(after_ff2)
|
||||
|
||||
|
||||
class CTCEncoder(nn.Module):
    """Stack of Conformer blocks with a mid-stack self-conditioning step.

    After layer ``num_layers // 2`` the hidden state is projected with
    ``out``, soft-maxed, projected back with ``out_mid`` and added to the
    hidden state. Hidden states listed in ``config.cat_hidden_layers``
    (index 0 is the output of the input projection) are captured and
    concatenated in front of the final output along the last axis.
    """

    def __init__(self, config: EncoderConfig):
        super().__init__()
        self.config = config
        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim)
        self.layers = [ConformerBlock(config) for _ in range(config.num_layers)]
        self.out = nn.Linear(config.hidden_dim, config.output_dim)
        self.out_mid = nn.Linear(config.output_dim, config.hidden_dim)
        self.num_layers = config.num_layers

        # Pairwise relative offsets inside one context window, shifted by
        # max_pos_emb; handed to every layer as ``attention_dists``.
        positions = mx.arange(config.context_size)
        offsets = positions[:, None] - positions[None, :]
        clipped = mx.clip(offsets, -config.context_size, config.context_size)
        self._attention_dists = clipped + config.max_pos_emb

    def __call__(self, x: mx.array) -> mx.array:
        """Encode ``x`` and return the (possibly concatenated) hidden states."""
        x = self.input_linear(x)
        cat_layers = set(self.config.cat_hidden_layers or [])
        mid_layer = self.num_layers // 2
        captured = [x] if 0 in cat_layers else []

        for idx, layer in enumerate(self.layers, start=1):
            x = layer(x, attention_dists=self._attention_dists)
            if idx in cat_layers:
                captured.append(x)
            if idx == mid_layer:
                mid_logits = self.out(x)
                x = x + self.out_mid(mx.softmax(mid_logits, axis=-1))

        if not captured:
            return x
        # Plus variant: intermediates first, then the final output, joined
        # on the channel axis (ordering taken from the upstream
        # Transformers implementation according to the original comment).
        return mx.concatenate(captured + [x], axis=-1)
|
||||
|
||||
|
||||
class QFormerMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (self or cross).

    Keys and values are computed from ``encoder_hidden_states`` when given,
    otherwise from the query input itself. No output projection is applied
    here; the caller is expected to provide one.
    """

    def __init__(self, hidden_size: int, num_heads: int, kv_hidden_size: int = None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Fall back to the query width when no separate key/value width is given.
        kv_dim = kv_hidden_size or hidden_size

        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(kv_dim, hidden_size)
        self.value = nn.Linear(kv_dim, hidden_size)

    def __call__(
        self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
    ) -> mx.array:
        """Attend from ``hidden_states`` to the key/value source."""
        B, L, _ = hidden_states.shape
        source = (
            hidden_states if encoder_hidden_states is None else encoder_hidden_states
        )

        def split_heads(tensor: mx.array, length: int) -> mx.array:
            # (B, length, H * Dh) -> (B, H, length, Dh)
            return tensor.reshape(
                B, length, self.num_heads, self.head_dim
            ).transpose(0, 2, 1, 3)

        q = split_heads(self.query(hidden_states), L)
        k = split_heads(self.key(source), -1)
        v = split_heads(self.value(source), -1)

        scale = self.head_dim**-0.5
        scores = (q * scale) @ k.transpose(0, 1, 3, 2)
        probs = mx.softmax(scores, axis=-1)
        context = probs @ v
        # Merge the heads back into the last axis.
        return context.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
|
||||
class QFormerSelfOutput(nn.Module):
    """Dense projection followed by a residual add and LayerNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-12):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        projected = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
|
||||
|
||||
|
||||
class QFormerAttention(nn.Module):
    """Attention plus its residual output block.

    Wraps :class:`QFormerMultiHeadAttention` and :class:`QFormerSelfOutput`;
    the residual is taken from the query input.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        kv_hidden_size: int = None,
        eps: float = 1e-12,
    ):
        super().__init__()
        self.attention = QFormerMultiHeadAttention(
            hidden_size,
            num_heads,
            kv_hidden_size,
        )
        self.output = QFormerSelfOutput(hidden_size, eps)

    def __call__(
        self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
    ) -> mx.array:
        """Attend, then project and normalise with a residual connection."""
        attended = self.attention(hidden_states, encoder_hidden_states)
        return self.output(attended, hidden_states)
|
||||
|
||||
|
||||
class QFormerIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer followed by GELU."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``gelu(dense(x))``."""
        expanded = self.dense(x)
        return nn.gelu(expanded)
|
||||
|
||||
|
||||
class QFormerOutput(nn.Module):
    """Feed-forward contraction with a residual add and LayerNorm."""

    def __init__(self, intermediate_size: int, hidden_size: int, eps: float = 1e-12):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)

    def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        contracted = self.dense(hidden_states)
        return self.LayerNorm(contracted + input_tensor)
|
||||
|
||||
|
||||
class QFormerLayer(nn.Module):
    """One Q-Former layer: self-attention, cross-attention, feed-forward."""

    def __init__(self, config: ProjectorConfig):
        super().__init__()
        hidden_size = config.hidden_size
        num_heads = config.num_attention_heads
        eps = config.layer_norm_eps

        self.attention = QFormerAttention(hidden_size, num_heads, eps=eps)
        # Cross-attention reads keys/values of width encoder_hidden_size.
        self.crossattention = QFormerAttention(
            hidden_size,
            num_heads,
            kv_hidden_size=config.encoder_hidden_size,
            eps=eps,
        )
        self.intermediate_query = QFormerIntermediate(
            hidden_size, config.intermediate_size
        )
        self.output_query = QFormerOutput(
            config.intermediate_size, hidden_size, eps=eps
        )

    def __call__(
        self, hidden_states: mx.array, encoder_hidden_states: mx.array
    ) -> mx.array:
        """Update the queries using themselves and the encoder states."""
        queries = self.attention(hidden_states)
        queries = self.crossattention(queries, encoder_hidden_states)
        expanded = self.intermediate_query(queries)
        return self.output_query(expanded, queries)
|
||||
|
||||
|
||||
class QFormerEncoder(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` Q-Former layers."""

    def __init__(self, config: ProjectorConfig):
        super().__init__()
        self.layer = [QFormerLayer(config) for _ in range(config.num_hidden_layers)]

    def __call__(
        self, hidden_states: mx.array, encoder_hidden_states: mx.array
    ) -> mx.array:
        """Feed the queries through every layer in order."""
        output = hidden_states
        for block in self.layer:
            output = block(output, encoder_hidden_states)
        return output
|
||||
|
||||
|
||||
class QFormerModel(nn.Module):
    """Q-Former: LayerNorm on the query embeddings, then the layer stack."""

    def __init__(self, config: ProjectorConfig):
        super().__init__()
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = QFormerEncoder(config)

    def __call__(
        self, query_embeds: mx.array, encoder_hidden_states: mx.array
    ) -> mx.array:
        """Normalise ``query_embeds`` and run them through the encoder."""
        normed_queries = self.layernorm(query_embeds)
        return self.encoder(normed_queries, encoder_hidden_states)
|
||||
|
||||
|
||||
class EncoderProjector(nn.Module):
    """Window-wise Q-Former projector from encoder states to LLM width.

    The encoder sequence is zero-padded to a multiple of ``window_size`` and
    split into windows. Each window is summarised by
    ``window_size // downsample_rate`` learned queries, and the result is
    projected to ``text_config.hidden_size``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        projector_cfg = config.projector_config
        self.hidden_size = projector_cfg.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        # Query embeddings; zero-initialised here and filled from the checkpoint.
        self.query = mx.zeros((1, self.num_queries, projector_cfg.hidden_size))
        self.qformer = QFormerModel(projector_cfg)
        self.linear = nn.Linear(
            projector_cfg.hidden_size, config.text_config.hidden_size
        )

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Project ``hidden_states`` of shape (B, L, D) to audio embeddings."""
        B, L, D = hidden_states.shape
        window = self.window_size
        nblocks = math.ceil(L / window)
        missing = nblocks * window - L
        if missing > 0:
            # Zero-pad the time axis so it divides evenly into windows.
            hidden_states = mx.pad(hidden_states, [(0, 0), (0, missing), (0, 0)])

        # Fold the windows into the batch axis.
        windows = hidden_states.reshape(B * nblocks, window, D)
        queries = mx.broadcast_to(
            self.query, (B * nblocks, self.num_queries, self.hidden_size)
        )

        summarised = self.qformer(queries, windows)
        # Unfold the windows back into one sequence per batch item.
        sequence = summarised.reshape(B, nblocks * self.num_queries, -1)
        return self.linear(sequence)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
    """Build the encoder, projector and language model from ``config``.

    Raises:
        ValueError: if the projector's ``encoder_hidden_size`` does not
            match the width of the concatenated encoder output.
    """
    super().__init__()
    self.config = config

    # Plus variant invariant: the encoder emits len(cat_hidden_layers) + 1
    # hidden states joined on the channel axis, so the projector input must
    # be that many times the encoder width.
    encoder_cfg = config.encoder_config
    cat_layers = encoder_cfg.cat_hidden_layers or []
    num_streams = len(cat_layers) + 1
    expected_proj_in = encoder_cfg.hidden_dim * num_streams
    actual_proj_in = config.projector_config.encoder_hidden_size
    if actual_proj_in != expected_proj_in:
        raise ValueError(
            f"projector_config.encoder_hidden_size ({actual_proj_in}) "
            f"must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1) "
            f"({encoder_cfg.hidden_dim} * {num_streams} = {expected_proj_in})"
        )

    self.encoder = CTCEncoder(encoder_cfg)
    self.projector = EncoderProjector(config)

    # text_config may be an object with attributes or already a plain mapping.
    if hasattr(config.text_config, "__dict__"):
        text_dict = config.text_config.__dict__
    else:
        text_dict = config.text_config
    self.language_model = GraniteLM(GraniteModelArgs.from_dict(text_dict))

    self.audio_token_id = config.audio_token_index
    # Populated later by post_load_hook.
    self._tokenizer = None
|
||||
|
||||
@property
def layers(self):
    """Decoder layers of the wrapped language model."""
    decoder = self.language_model.model
    return decoder.layers
|
||||
|
||||
def make_cache(self) -> List[KVCache]:
    """Return one fresh ``KVCache`` per decoder layer."""
    num_layers = len(self.layers)
    caches = []
    for _ in range(num_layers):
        caches.append(KVCache())
    return caches
|
||||
|
||||
def __call__(
    self,
    input_ids: mx.array,
    cache: Optional[List[KVCache]] = None,
    input_embeddings: Optional[mx.array] = None,
) -> mx.array:
    """Run the language model and return scaled logits.

    When ``input_embeddings`` is given it replaces the token-embedding
    lookup for ``input_ids``; either way the embeddings are multiplied by
    the model's ``embedding_multiplier`` before the decoder layers run.
    """
    lm = self.language_model
    decoder = lm.model

    if input_embeddings is None:
        h = decoder.embed_tokens(input_ids)
    else:
        h = input_embeddings
    h = h * decoder.embedding_multiplier

    if cache is None:
        cache = [None] * len(decoder.layers)
    mask = create_attention_mask(h, cache[0])

    for layer, layer_cache in zip(decoder.layers, cache):
        h = layer(h, mask, cache=layer_cache)
    h = decoder.norm(h)

    # Tied embeddings reuse the input embedding matrix as the output head.
    if lm.args.tie_word_embeddings:
        logits = decoder.embed_tokens.as_linear(h)
    else:
        logits = lm.lm_head(h)
    return logits / lm.logits_scaling
|
||||
|
||||
def get_audio_features(self, input_features: mx.array) -> mx.array:
    """Encode ``input_features`` and project them for the language model."""
    return self.projector(self.encoder(input_features))
|
||||
|
||||
def model_quant_predicate(self, p: str, m: nn.Module) -> bool:
    """Return ``False`` for encoder/projector paths, ``True`` otherwise.

    ``p`` is the dotted module path; ``m`` is accepted for interface
    compatibility and is not inspected.
    """
    excluded_prefixes = ("encoder", "projector")
    if p.startswith(excluded_prefixes):
        return False
    return True
|
||||
|
||||
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
    """Filter and re-layout checkpoint weights to match this model.

    Shapes of incoming tensors are compared with the shapes of the model's
    already-initialised parameters, which keeps the method idempotent across
    convert-time (PyTorch layout) and inference-time (MLX layout) loads and
    also works for Conv1d layers with kernel_size=1. Pattern adapted from
    cohere_asr.Model.sanitize.
    """
    reference = dict(tree_flatten(self.parameters()))
    cleaned = {}
    for name, tensor in weights.items():
        if "num_batches_tracked" in name:
            continue
        # granite-speech-4.1 ships a separate out_llm.safetensors with
        # top-level "weight"/"bias" keys (likely an audio CTC head). This
        # class defines no such layer, so the keys are dropped to let the
        # rest of the load succeed; behaviour depending on that head is
        # not supported.
        if name in ("weight", "bias"):
            continue

        target = reference.get(name)
        shape_mismatch = (
            target is not None
            and hasattr(target, "shape")
            and tensor.shape != target.shape
            and tensor.ndim == 3
        )
        if shape_mismatch:
            # Try swapping the last two axes; keep it only if it fits.
            candidate = mx.transpose(tensor, (0, 2, 1))
            if candidate.shape == target.shape:
                tensor = candidate

        cleaned[name] = tensor
    return cleaned
|
||||
|
||||
@classmethod
def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
    """Attach the tokenizer found at ``model_path`` to ``model``.

    Transformers logging is switched to error-only while the tokenizer is
    loaded and restored afterwards, even if loading fails.
    """
    import transformers
    from transformers import AutoTokenizer

    saved_verbosity = transformers.logging.get_verbosity()
    transformers.logging.set_verbosity_error()
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(model_path),
            trust_remote_code=True,
        )
        model._tokenizer = tokenizer
    finally:
        transformers.logging.set_verbosity(saved_verbosity)

    return model
|
||||
|
||||
def _extract_features(
    self, audio: Union[mx.array, np.ndarray]
) -> Tuple[mx.array, int]:
    """Turn a waveform into stacked log-mel features for the encoder.

    Returns:
        A tuple ``(input_features, num_audio_tokens)``: the features with a
        leading batch axis of size 1, and the number of audio placeholder
        tokens the projector will emit for them.
    """
    from ..dsp import hanning, mel_filters, stft

    sample_rate = 16000
    n_fft, win_length, hop_length = 512, 400, 160
    n_mels = 80

    # Flatten to a 1-D waveform regardless of the input container.
    if isinstance(audio, mx.array):
        waveform = audio.reshape(-1)
    else:
        waveform = mx.array(audio.flatten(), dtype=mx.float32)

    # Centre the analysis window inside an n_fft-long frame of zeros.
    left = (n_fft - win_length) // 2
    right = n_fft - win_length - left
    window = mx.concatenate(
        [mx.zeros((left,)), hanning(win_length, periodic=True), mx.zeros((right,))]
    )

    spec = stft(
        waveform,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,
        pad_mode="reflect",
    )

    mel_fb = mel_filters(sample_rate, n_fft, n_mels, mel_scale="htk")
    mel_spec = (mx.abs(spec) ** 2) @ mel_fb.T

    # Log compression, floor at (global max - 8), then rescale.
    logmel = mx.log10(mx.clip(mel_spec, 1e-10, None))
    floor = mx.max(logmel) - 8.0
    logmel = mx.maximum(logmel, floor) / 4.0 + 1.0

    # Frames are stacked in pairs, so drop a trailing odd frame.
    if logmel.shape[0] % 2 == 1:
        logmel = logmel[:-1]
    stacked = logmel.reshape(-1, 2 * n_mels)

    window_size = self.config.window_size
    nblocks = math.ceil(stacked.shape[0] / window_size)
    num_audio_tokens = nblocks * (window_size // self.config.downsample_rate)

    return stacked[None, :, :], num_audio_tokens
|
||||
|
||||
def _build_prompt(
    self,
    num_audio_tokens: int,
    user_prompt: str = None,
    system_prompt: str = None,
) -> mx.array:
    """Tokenise the chat prompt containing the audio placeholders.

    The user turn is ``num_audio_tokens`` copies of ``<|audio|>`` followed
    by ``user_prompt`` (a default transcription request when ``None``). A
    system turn is added only when ``system_prompt`` is truthy.
    """
    if user_prompt is None:
        user_prompt = "can you transcribe the speech into a written format?"
    content = "<|audio|>" * num_audio_tokens + user_prompt

    turns = []
    if system_prompt:
        turns.append(("system", system_prompt))
    turns.append(("user", content))

    if getattr(self._tokenizer, "chat_template", None):
        chat = [{"role": role, "content": text} for role, text in turns]
        prompt_str = self._tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True
        )
    else:
        # Granite-3 chat format (granite-speech tokenizer ships without a
        # chat_template attribute, but its vocab includes the role tokens).
        sor, eor, eot = "<|start_of_role|>", "<|end_of_role|>", "<|end_of_text|>"
        rendered = [f"{sor}{role}{eor}{text}{eot}\n" for role, text in turns]
        rendered.append(f"{sor}assistant{eor}")
        prompt_str = "".join(rendered)

    return mx.array(self._tokenizer.encode(prompt_str))
|
||||
|
||||
def _build_inputs_embeds(
    self, input_ids: mx.array, audio_features: mx.array
) -> mx.array:
    """Embed ``input_ids`` and splice audio features over the placeholders.

    Audio placeholder ids are replaced by id 0 for the embedding lookup,
    then their rows are overwritten with the projected audio features. If
    the placeholder count and the feature count differ, only the smaller
    number of rows is written.
    """
    audio_mask = input_ids == self.audio_token_id
    text_ids = mx.where(audio_mask, 0, input_ids)
    token_embeds = self.language_model.model.embed_tokens(text_ids[None])

    placeholder_idx = np.where(np.array(audio_mask))[0]

    # The scatter is done in NumPy on float32 copies, then cast back.
    target_dtype = token_embeds.dtype
    merged = np.array(token_embeds.astype(mx.float32))
    audio = np.array(audio_features.astype(mx.float32))

    count = min(len(placeholder_idx), audio.shape[1])
    merged[0, placeholder_idx[:count]] = audio[0, :count]

    return mx.array(merged).astype(target_dtype)
|
||||
|
||||
def generate(
    self,
    audio: Union[str, mx.array, np.ndarray],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    top_p: float = 1.0,
    top_k: int = 0,
    min_p: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: int = 100,
    prompt: str = None,
    system_prompt: str = None,
    language: str = None,
    prefill_step_size: int = 2048,
    verbose: bool = False,
    stream: bool = False,
    **kwargs,
) -> Union[STTOutput, Generator[StreamingResult, None, None]]:
    """Transcribe or translate ``audio`` with the speech-conditioned LLM.

    Args:
        audio: File path, waveform array, or a list whose first item is one
            of those (see ``_load_audio``).
        max_tokens: Upper bound on generated tokens.
        temperature, top_p, top_k, min_p: Sampler settings.
        repetition_penalty, repetition_context_size: Logits-processor
            settings.
        prompt: User instruction; when ``None`` and ``language`` is set, a
            translation instruction for that language is used.
        system_prompt: Optional system turn. It is honoured in both the
            blocking and the streaming path.
        language: Language code or name used to build the default prompt.
        prefill_step_size: Forwarded to ``generate_step``.
        verbose: Print progress and timing information.
        stream: When true, return a generator of ``StreamingResult``.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        An ``STTOutput`` when ``stream`` is false, otherwise a generator of
        ``StreamingResult`` items.
    """
    if prompt is None and language is not None:
        lang_name = LANGUAGE_CODES.get(language.lower(), language)
        prompt = f"Translate the speech to {lang_name}."

    if stream:
        return self._stream_generate(
            audio,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            repetition_penalty=repetition_penalty,
            repetition_context_size=repetition_context_size,
            prompt=prompt,
            # Previously omitted, so streaming silently ignored the system turn.
            system_prompt=system_prompt,
            prefill_step_size=prefill_step_size,
            verbose=verbose,
        )

    start_time = time.time()

    from mlx_lm.generate import generate_step
    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    audio_data = self._load_audio(audio)
    input_features, num_audio_tokens = self._extract_features(audio_data)

    if verbose:
        print("Encoding audio...")
    audio_features = self.get_audio_features(input_features)
    mx.eval(audio_features)

    prompt_ids = self._build_prompt(
        num_audio_tokens, prompt, system_prompt=system_prompt
    )
    inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
    mx.eval(inputs_embeds)

    prompt_tokens = len(prompt_ids)

    sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
    logits_processors = make_logits_processors(
        repetition_penalty=repetition_penalty,
        repetition_context_size=repetition_context_size,
    )

    eos_token_id = self._tokenizer.eos_token_id
    tokens = []

    for token, logprobs in generate_step(
        prompt=prompt_ids,
        input_embeddings=inputs_embeds.squeeze(0),
        model=self,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prefill_step_size=prefill_step_size,
    ):
        if token == eos_token_id:
            break
        tokens.append(token)

    text = self._tokenizer.decode(tokens, skip_special_tokens=True)
    elapsed = time.time() - start_time
    gen_tokens = len(tokens)

    if verbose:
        print(f"Prompt tokens: {prompt_tokens}")
        print(f"Generation tokens: {gen_tokens}")
        print(f"Total time: {elapsed:.2f}s")
        if gen_tokens > 0 and elapsed > 0:
            print(f"Generation TPS: {gen_tokens / elapsed:.1f}")

    # Both rates are measured over the whole call (feature extraction,
    # encoding, prefill and decoding), not over their own phase.
    return STTOutput(
        text=text,
        segments=[],
        prompt_tokens=prompt_tokens,
        generation_tokens=gen_tokens,
        total_tokens=prompt_tokens + gen_tokens,
        total_time=elapsed,
        prompt_tps=prompt_tokens / elapsed if elapsed > 0 else 0,
        generation_tps=gen_tokens / elapsed if elapsed > 0 else 0,
    )

def _stream_generate(
    self,
    audio: Union[str, mx.array, np.ndarray],
    *,
    max_tokens: int = 4096,
    temperature: float = 0.0,
    top_p: float = 1.0,
    top_k: int = 0,
    min_p: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: int = 100,
    prompt: str = None,
    system_prompt: str = None,
    prefill_step_size: int = 2048,
    verbose: bool = False,
) -> Generator[StreamingResult, None, None]:
    """Yield one ``StreamingResult`` per generated token, then a final one.

    Parameters mirror :meth:`generate`. ``system_prompt`` defaults to
    ``None`` so existing callers are unaffected. ``verbose`` is accepted
    but currently unused. Each token is decoded on its own, and the final
    item has empty text with ``is_final=True``.
    """
    from mlx_lm.generate import generate_step
    from mlx_lm.sample_utils import make_logits_processors, make_sampler

    audio_data = self._load_audio(audio)
    input_features, num_audio_tokens = self._extract_features(audio_data)

    audio_features = self.get_audio_features(input_features)
    mx.eval(audio_features)

    prompt_ids = self._build_prompt(
        num_audio_tokens, prompt, system_prompt=system_prompt
    )
    inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
    mx.eval(inputs_embeds)

    prompt_token_count = len(prompt_ids)

    sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
    logits_processors = make_logits_processors(
        repetition_penalty=repetition_penalty,
        repetition_context_size=repetition_context_size,
    )

    eos_token_id = self._tokenizer.eos_token_id
    gen_tokens = 0

    for token, _ in generate_step(
        prompt=prompt_ids,
        input_embeddings=inputs_embeds.squeeze(0),
        model=self,
        max_tokens=max_tokens,
        sampler=sampler,
        logits_processors=logits_processors,
        prefill_step_size=prefill_step_size,
    ):
        if token == eos_token_id:
            break
        gen_tokens += 1
        text = self._tokenizer.decode([token], skip_special_tokens=True)
        yield StreamingResult(
            text=text,
            is_final=False,
            start_time=0.0,
            end_time=0.0,
            prompt_tokens=prompt_token_count,
            generation_tokens=gen_tokens,
        )

    yield StreamingResult(
        text="",
        is_final=True,
        start_time=0.0,
        end_time=0.0,
        prompt_tokens=prompt_token_count,
        generation_tokens=gen_tokens,
    )
|
||||
|
||||
def _load_audio(self, audio: Union[str, mx.array, np.ndarray]) -> mx.array:
    """Normalise the supported audio inputs to an ``mx.array`` waveform.

    A string is loaded with ``load_audio``; a NumPy array is converted to
    float32; an ``mx.array`` is returned as is. For a list only the first
    item is used.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(audio, list):
        first = audio[0]
        if not isinstance(first, str):
            return mx.array(np.array(first), dtype=mx.float32)
        # A list of paths: fall through and load the first one.
        audio = first

    if isinstance(audio, str):
        from ..audio import load_audio

        return load_audio(audio)
    if isinstance(audio, np.ndarray):
        return mx.array(audio, dtype=mx.float32)
    if isinstance(audio, mx.array):
        return audio
    raise TypeError(f"Unsupported audio type: {type(audio)}")
|
||||
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .granite_speech import Model, ModelConfig
|
||||
|
||||
# File patterns fetched from the hub when the caller supplies none.
DEFAULT_ALLOW_PATTERNS = [
    "*.json", "*.safetensors", "*.py",
    "*.model", "*.tiktoken", "*.txt",
    "*.jsonl", "*.yaml", "*.npz",
]
|
||||
|
||||
|
||||
def _is_local_path(path: str) -> bool:
|
||||
return (
|
||||
path.startswith(".")
|
||||
or path.startswith("/")
|
||||
or path.startswith("~")
|
||||
or (len(path) > 1 and path[1] == ":")
|
||||
)
|
||||
|
||||
|
||||
def get_model_path(
    path_or_hf_repo: str | Path,
    *,
    revision: str | None = None,
    force_download: bool = False,
    allow_patterns: list[str] | None = None,
) -> Path:
    """Resolve a local directory or download a hub snapshot.

    An existing path (after ``~`` expansion) is returned directly. A missing
    ``Path`` object, or a missing string that is spelled like a local path,
    raises. Any other string is treated as a hub repo id and downloaded.

    Raises:
        FileNotFoundError: if a local path was given but does not exist.
    """
    candidate = Path(path_or_hf_repo).expanduser()
    if candidate.exists():
        return candidate

    # Path objects are always local; strings only when they look like paths.
    if isinstance(path_or_hf_repo, Path) or _is_local_path(path_or_hf_repo):
        raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")

    patterns = allow_patterns or DEFAULT_ALLOW_PATTERNS
    snapshot_dir = snapshot_download(
        path_or_hf_repo,
        revision=revision,
        allow_patterns=patterns,
        force_download=force_download,
    )
    return Path(snapshot_dir)
|
||||
|
||||
|
||||
def load_config(model_path: str | Path) -> dict[str, Any]:
    """Read and parse ``config.json`` from ``model_path``.

    Raises:
        FileNotFoundError: if the directory has no ``config.json``.
    """
    model_path = Path(model_path)
    config_file = model_path / "config.json"
    if not config_file.exists():
        raise FileNotFoundError(f"Config not found at {model_path}")
    with config_file.open(encoding="utf-8") as handle:
        return json.load(handle)
|
||||
|
||||
|
||||
def load_weights(model_path: Path) -> dict[str, mx.array]:
    """Load every weight file in ``model_path`` into one dictionary.

    ``*.safetensors`` files are preferred; ``*.npz`` files are used only
    when no safetensors file exists. Files are loaded in sorted order and
    later files overwrite duplicate keys from earlier ones.

    Raises:
        FileNotFoundError: if the directory has neither kind of file.
    """
    # Escape glob metacharacters in the directory itself so a path such as
    # "models/granite[v2]" is matched literally and not as a character class.
    base_dir = Path(glob.escape(str(model_path)))

    weight_files: list[str] = []
    for pattern in ("*.safetensors", "*.npz"):
        weight_files = sorted(glob.glob(str(base_dir / pattern)))
        if weight_files:
            break
    if not weight_files:
        raise FileNotFoundError(
            f"No weight files (safetensors or npz) found in {model_path}"
        )

    weights = {}
    for weight_file in weight_files:
        weights.update(mx.load(weight_file))
    return weights
|
||||
|
||||
|
||||
def apply_quantization(
    model: nn.Module,
    config: dict[str, Any],
    weights: dict[str, mx.array],
    model_quant_predicate=None,
) -> None:
    """Quantize ``model`` in place according to the checkpoint config.

    Does nothing when the config has neither a ``quantization`` nor a
    ``quantization_config`` entry. Otherwise each module is quantized only
    if it supports quantization, its last weight dimension is divisible by
    the group size, the optional model predicate allows it, and either the
    config has a per-path entry or the checkpoint holds ``<path>.scales``.
    """
    settings = config.get("quantization") or config.get("quantization_config")
    if settings is None:
        return

    group_size = settings.get("group_size", 64)

    def should_quantize(path, module):
        if not hasattr(module, "to_quantized"):
            return False
        if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0:
            return False
        if model_quant_predicate is not None:
            verdict = model_quant_predicate(path, module)
            # A dict verdict carries per-module quantization parameters.
            if isinstance(verdict, dict):
                return verdict
            if not verdict:
                return False
        if path in settings:
            return settings[path]
        # Fall back to what the checkpoint actually contains.
        return f"{path}.scales" in weights

    nn.quantize(
        model,
        group_size=group_size,
        bits=settings["bits"],
        mode=settings.get("mode", "affine"),
        class_predicate=should_quantize,
    )
|
||||
|
||||
|
||||
def load_model(
    model_path: str | Path,
    *,
    lazy: bool = False,
    strict: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """Build a ``Model`` from a local directory or hub repo and load weights.

    ``revision``, ``force_download`` and ``allow_patterns`` are taken from
    ``kwargs`` and forwarded to ``get_model_path``. Unless ``lazy`` is set
    the parameters are evaluated immediately. The model is switched to eval
    mode and ``Model.post_load_hook`` is run when it exists.
    """
    revision = kwargs.pop("revision", None)
    force_download = kwargs.pop("force_download", False)
    allow_patterns = kwargs.pop("allow_patterns", None)
    path = get_model_path(
        model_path,
        revision=revision,
        force_download=force_download,
        allow_patterns=allow_patterns,
    )

    config = load_config(path)
    model = Model(ModelConfig.from_dict(config))

    weights = load_weights(path)
    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    # Quantized layers must exist before the quantized weights are loaded.
    apply_quantization(model, config, weights, model.model_quant_predicate)
    model.load_weights(list(weights.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())
    model.eval()

    if hasattr(Model, "post_load_hook"):
        return Model.post_load_hook(model, path)
    return model
|
||||
|
||||
56
src/granite_speech_plus_mlx/chunking.py
Normal file
56
src/granite_speech_plus_mlx/chunking.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
from typing import Any, Iterable, Iterator
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AudioChunk:
    """A slice of audio together with its position in the source."""

    index: int  # 1-based position of the chunk in the sequence
    start: float  # start offset in seconds
    end: float  # end offset in seconds
    samples: Any  # the sliced audio, same container type as the input


def chunk_audio(
    audio: Any,
    sr: int,
    chunk_seconds: float,
    overlap_seconds: float = 2.0,
) -> Iterator[AudioChunk]:
    """Yield overlapping fixed-length chunks of ``audio``.

    Consecutive chunks start ``chunk_seconds - overlap_seconds`` apart. The
    last chunk may be shorter and iteration stops as soon as a chunk reaches
    the end of the audio. Because this is a generator, arguments are
    validated when iteration starts.

    Raises:
        ValueError: if ``chunk_seconds`` is not positive, ``overlap_seconds``
            is negative, or the overlap is not smaller than the chunk.
    """
    if chunk_seconds <= 0:
        raise ValueError("chunk_seconds must be positive")
    if overlap_seconds < 0:
        raise ValueError("overlap_seconds cannot be negative")

    window = int(chunk_seconds * sr)
    overlap = int(overlap_seconds * sr)
    if overlap >= window:
        raise ValueError("overlap_seconds must be smaller than chunk_seconds")

    total = len(audio)
    starts = range(0, total, window - overlap)
    for index, start in enumerate(starts, start=1):
        stop = min(start + window, total)
        yield AudioChunk(index, start / sr, stop / sr, audio[start:stop])
        if stop == total:
            # The tail has been covered; later starts would only repeat it.
            return
|
||||
|
||||
|
||||
def prefix_text(transcripts: Iterable[str], max_chars: int = 800) -> str:
    """Build a short context string from the tail of earlier transcripts.

    Blank entries are skipped, ``## [...]`` header lines are removed and all
    whitespace runs are collapsed to single spaces. When the result is longer
    than ``max_chars`` only the last ``max_chars`` characters are kept, minus
    a leading partial word.

    Args:
        transcripts: Earlier transcript strings, oldest first.
        max_chars: Maximum length of the returned text; values of zero or
            below yield an empty string.

    Returns:
        The cleaned, possibly truncated text.
    """
    text = "\n".join(t.strip() for t in transcripts if t and t.strip())
    text = re.sub(r"^## \[[^\n]+\]\n", "", text, flags=re.MULTILINE)
    text = re.sub(r"\s+", " ", text).strip()
    if len(text) <= max_chars:
        return text
    if max_chars <= 0:
        # text[-0:] would be the whole string, not an empty tail.
        return ""

    tail = text[-max_chars:]
    first_space = tail.find(" ")
    if first_space > 0:
        # Drop the partial word the cut may have left at the start.
        tail = tail[first_space + 1 :]
    return tail.strip()
|
||||
126
src/granite_speech_plus_mlx/pipeline.py
Normal file
126
src/granite_speech_plus_mlx/pipeline.py
Normal file
@@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from .chunking import chunk_audio, prefix_text
|
||||
from .prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES, build_prompt
|
||||
|
||||
# Model identifier used when the caller does not name one; it is the default
# ``repo_id`` of the pipeline and is handed to the loader as-is.
DEFAULT_MODEL = "mlx-community/granite-speech-4.1-2b-plus-mlx"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraniteSpeechPlusPipeline:
|
||||
model: Any
|
||||
repo_id: str = DEFAULT_MODEL
|
||||
chunk_seconds: float = 300.0
|
||||
overlap_seconds: float = 2.0
|
||||
repetition_penalty: float = 1.2
|
||||
max_tokens: int = 4096
|
||||
system_prompt: str | None = GRANITE_SYSTEM_PROMPT
|
||||
verbose: bool = False
|
||||
|
||||
@classmethod
def from_pretrained(
    cls,
    repo_id: str = DEFAULT_MODEL,
    *,
    chunk_seconds: float = 300.0,
    overlap_seconds: float = 2.0,
    repetition_penalty: float = 1.2,
    max_tokens: int = 4096,
    system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
    verbose: bool = False,
    **load_kwargs: Any,
) -> "GraniteSpeechPlusPipeline":
    """Load the model for ``repo_id`` and wrap it in a pipeline.

    ``load_kwargs`` are forwarded to the loader unchanged; every other
    keyword argument becomes the pipeline field of the same name.
    """
    # Imported lazily so the loader's dependencies are only needed here.
    from ._vendored.loader import load_model

    settings = {
        "chunk_seconds": chunk_seconds,
        "overlap_seconds": overlap_seconds,
        "repetition_penalty": repetition_penalty,
        "max_tokens": max_tokens,
        "system_prompt": system_prompt,
        "verbose": verbose,
    }
    loaded = load_model(repo_id, **load_kwargs)
    return cls(model=loaded, repo_id=repo_id, **settings)
|
||||
|
||||
def transcribe(self, audio_path: str | Path, prompt_mode: str = "asr") -> str:
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
if prompt_mode not in PROMPT_MODES:
|
||||
modes = ", ".join(sorted(PROMPT_MODES))
|
||||
raise ValueError(f"prompt_mode must be one of: {modes}")
|
||||
|
||||
audio_file = Path(audio_path).expanduser().resolve()
|
||||
audio, sr = librosa.load(str(audio_file), sr=16000, mono=True)
|
||||
audio = np.asarray(audio, dtype=np.float32)
|
||||
duration = len(audio) / sr if sr else 0.0
|
||||
chunks = list(
|
||||
chunk_audio(
|
||||
audio,
|
||||
sr,
|
||||
self.chunk_seconds,
|
||||
overlap_seconds=self.overlap_seconds,
|
||||
)
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(
|
||||
f"Loaded {audio_file} ({duration:.1f}s, {len(chunks)} chunks)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
rendered: list[str] = []
|
||||
plain_texts: list[str] = []
|
||||
t_start = time.time()
|
||||
|
||||
for chunk in chunks:
|
||||
prompt = build_prompt(
|
||||
prompt_mode,
|
||||
prefix_text=prefix_text(plain_texts),
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"max_tokens": self.max_tokens,
|
||||
}
|
||||
if self.system_prompt:
|
||||
kwargs["system_prompt"] = self.system_prompt
|
||||
if self.repetition_penalty and self.repetition_penalty > 1.0:
|
||||
kwargs["repetition_penalty"] = self.repetition_penalty
|
||||
|
||||
t0 = time.time()
|
||||
result = self.model.generate(chunk.samples, **kwargs)
|
||||
text = getattr(result, "text", result)
|
||||
if not isinstance(text, str):
|
||||
text = str(text)
|
||||
text = text.strip()
|
||||
plain_texts.append(text)
|
||||
|
||||
if len(chunks) > 1:
|
||||
rendered.append(f"## [{chunk.start:.1f}s - {chunk.end:.1f}s]\n{text}")
|
||||
else:
|
||||
rendered.append(text)
|
||||
|
||||
if self.verbose:
|
||||
elapsed = time.time() - t0
|
||||
rtf = (chunk.end - chunk.start) / elapsed if elapsed > 0 else 0.0
|
||||
print(
|
||||
f"[{chunk.index:>3}/{len(chunks)}] "
|
||||
f"{chunk.start:>6.1f}s-{chunk.end:<6.1f}s "
|
||||
f"gen={elapsed:>5.1f}s rtf={rtf:>4.1f}x {text[:80]}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
elapsed = time.time() - t_start
|
||||
rtf = duration / elapsed if elapsed > 0 else 0.0
|
||||
print(f"Total: {elapsed:.1f}s, rtf={rtf:.1f}x", file=sys.stderr)
|
||||
|
||||
return "\n\n".join(rendered).strip()
|
||||
65
src/granite_speech_plus_mlx/prompts.py
Normal file
65
src/granite_speech_plus_mlx/prompts.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# Default system message: three lines joined by newlines, with no trailing
# newline after the last one.
GRANITE_SYSTEM_PROMPT = "\n".join(
    [
        "Knowledge Cutoff Date: April 2024.",
        "Today's Date: December 19, 2024.",
        "You are Granite, developed by IBM. You are a helpful AI assistant",
    ]
)

# Instruction text for each supported prompt mode, keyed by mode name.
PROMPT_MODES = {
    # Plain transcription.
    "asr": "can you transcribe the speech into a written format?",
    # Transcription with [Speaker N]: tags before speaker turns.
    "saa": (
        "Speaker attribution: "
        "Transcribe and denote who is speaking by adding "
        "[Speaker 1]: and [Speaker 2]: tags "
        "before speaker turns."
    ),
    # Transcription with a [T:nn] end-time tag after each word.
    "ts": (
        "Timestamps: Transcribe the speech. "
        "After each word, add a timestamp tag "
        "showing the end time in centiseconds, "
        "e.g. hello [T:45] world [T:82]"
    ),
}
|
||||
|
||||
|
||||
def granite3_chat_template(
    content: str,
    *,
    system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
    add_generation_prompt: bool = True,
) -> str:
    """Render a single user message in the Granite-3 chat format.

    Intended for tokenizers that ship without a chat template of their own.

    Args:
        content: Text of the user turn.
        system_prompt: Text of the system turn; the turn is left out
            entirely when this is empty or ``None``.
        add_generation_prompt: When true, append an opening assistant role
            header after the user turn.

    Returns:
        The concatenated turns as one string.
    """

    def role_header(role: str) -> str:
        # Role marker that opens every turn.
        return f"<|start_of_role|>{role}<|end_of_role|>"

    def closed_turn(role: str, body: str) -> str:
        # A complete turn: header, body, end-of-text marker, newline.
        return f"{role_header(role)}{body}<|end_of_text|>\n"

    rendered = closed_turn("system", system_prompt) if system_prompt else ""
    rendered += closed_turn("user", content)
    if add_generation_prompt:
        # Left open (no end marker) so the model continues from here.
        rendered += role_header("assistant")
    return rendered
|
||||
|
||||
|
||||
def build_prompt(
    prompt_mode: str = "asr",
    *,
    prefix_text: str | None = None,
    custom_prompt: str | None = None,
) -> str:
    """Build the user instruction for one transcription request.

    Args:
        prompt_mode: Key of ``PROMPT_MODES``; ignored when ``custom_prompt``
            is non-empty.
        prefix_text: Earlier transcript text to include as continuity
            context. ``None``, empty and whitespace-only values add nothing.
        custom_prompt: Instruction text that replaces the mode's built-in
            prompt when non-empty.

    Returns:
        The base instruction, followed by a context section when usable
        ``prefix_text`` was supplied.

    Raises:
        ValueError: If no custom prompt is given and ``prompt_mode`` is not
            a key of ``PROMPT_MODES``.
    """
    if custom_prompt:
        base = custom_prompt
    else:
        try:
            base = PROMPT_MODES[prompt_mode]
        except KeyError as exc:
            modes = ", ".join(sorted(PROMPT_MODES))
            raise ValueError(f"prompt_mode must be one of: {modes}") from exc

    # Strip before testing: a whitespace-only prefix would otherwise emit the
    # context header with nothing underneath it.
    context = prefix_text.strip() if prefix_text else ""
    if not context:
        return base

    return (
        f"{base}\n\n"
        "Previous transcript context for continuity only. Do not repeat it:\n"
        f"{context}"
    )
|
||||
|
||||
6
tests/test_smoke.py
Normal file
6
tests/test_smoke.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
|
||||
|
||||
def test_pipeline_symbol_exists():
    """Smoke-check that the package exports a usable pipeline class."""
    # An ``is not None`` check cannot fail once the module-level import has
    # succeeded, so check the public entry points instead.
    assert isinstance(GraniteSpeechPlusPipeline, type)
    assert callable(getattr(GraniteSpeechPlusPipeline, "from_pretrained", None))
    assert callable(getattr(GraniteSpeechPlusPipeline, "transcribe", None))
|
||||
|
||||
Reference in New Issue
Block a user