Initial Granite Speech Plus MLX package
This commit is contained in:
30
.gitignore
vendored
Normal file
30
.gitignore
vendored
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
.Python
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env/
|
||||||
|
|
||||||
|
build/
|
||||||
|
dist/
|
||||||
|
*.egg-info/
|
||||||
|
.eggs/
|
||||||
|
|
||||||
|
.pytest_cache/
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
|
||||||
|
.DS_Store
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
|
||||||
|
*.log
|
||||||
|
*.tmp
|
||||||
|
transcripts/
|
||||||
|
bench/
|
||||||
|
|
||||||
22
LICENSE
Normal file
22
LICENSE
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2026 Olivier Dupont
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
42
README.md
Normal file
42
README.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# granite-speech-4.1-2b-plus-mlx
|
||||||
|
|
||||||
|
Standalone Python package for the MLX port of IBM Granite Speech 4.1-2b-plus.
|
||||||
|
The default model is
|
||||||
|
[`mlx-community/granite-speech-4.1-2b-plus-mlx`](https://huggingface.co/mlx-community/granite-speech-4.1-2b-plus-mlx).
|
||||||
|
|
||||||
|
## Quickstart
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv add "granite-speech-4.1-2b-plus-mlx @ git+https://gitea.tavportal.com/olivier/granite-speech-4.1-2b-plus-mlx.git"
|
||||||
|
python -c "from granite_speech_plus_mlx import GraniteSpeechPlusPipeline as P; p=P.from_pretrained(); print(p.transcribe('audio.wav'))"
|
||||||
|
python scripts/transcribe.py audio.wav --prompt-mode asr --output transcript.txt
|
||||||
|
python scripts/transcribe.py meeting.wav --prompt-mode saa
|
||||||
|
python scripts/benchmark.py audio.wav --results bench
|
||||||
|
```
|
||||||
|
|
||||||
|
## Prompt Modes
|
||||||
|
|
||||||
|
- `asr`: standard transcription.
|
||||||
|
- `saa`: speaker-attributed ASR with `[Speaker N]:` turn labels.
|
||||||
|
- `ts`: word-level timestamp tags like `word [T:45]`.
|
||||||
|
|
||||||
|
See [docs/prompt-modes.md](docs/prompt-modes.md) for examples.
|
||||||
|
|
||||||
|
## Benchmark Hints
|
||||||
|
|
||||||
|
Granite Speech 4.1 allocates substantial encoder memory for long audio. Start
|
||||||
|
with `--chunk-seconds 300 --repetition-penalty 1.2` for ASR and reduce chunks
|
||||||
|
to 60 or 180 seconds if memory is tight. Timestamp mode (`ts`) often needs a
|
||||||
|
larger `--max-tokens` budget because every word carries a timestamp tag.
|
||||||
|
|
||||||
|
## Provenance
|
||||||
|
|
||||||
|
This package was extracted from the local `MLX_CONVERTOR` project, including
|
||||||
|
the Granite Speech patch bundle at
|
||||||
|
`external/patches/granite-speech-idempotent-sanitize.patch`. The vendored
|
||||||
|
Granite implementation is based on `mlx-audio` commit
|
||||||
|
`f7c11556eda88731be5cc75ddbdf4a4cb9eeaafc` plus that local patch.
|
||||||
|
|
||||||
|
Package code is MIT licensed. Model weights remain under the IBM Granite model
|
||||||
|
license; review the model card and license terms before redistribution or use.
|
||||||
|
|
||||||
37
docs/prompt-modes.md
Normal file
37
docs/prompt-modes.md
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# Prompt Modes
|
||||||
|
|
||||||
|
Granite Speech Plus supports three prompt modes in this package.
|
||||||
|
|
||||||
|
## `asr`
|
||||||
|
|
||||||
|
Standard speech transcription.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||||
|
|
||||||
|
pipe = GraniteSpeechPlusPipeline.from_pretrained()
|
||||||
|
text = pipe.transcribe("audio.wav", prompt_mode="asr")
|
||||||
|
```
|
||||||
|
|
||||||
|
## `saa`
|
||||||
|
|
||||||
|
Speaker-attributed ASR. The prompt asks the model to add speaker turn labels
|
||||||
|
such as `[Speaker 1]:` and `[Speaker 2]:`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
text = pipe.transcribe("meeting.wav", prompt_mode="saa")
|
||||||
|
```
|
||||||
|
|
||||||
|
## `ts`
|
||||||
|
|
||||||
|
Word-level timestamps. The prompt asks the model to append centisecond tags
|
||||||
|
after words, for example `hello [T:45] world [T:82]`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
text = pipe.transcribe("clip.wav", prompt_mode="ts")
|
||||||
|
```
|
||||||
|
|
||||||
|
For long audio, the pipeline chunks the waveform and feeds a short previous
|
||||||
|
transcript prefix into later chunks for continuity. The prefix is context only;
|
||||||
|
the model is instructed not to repeat it.
|
||||||
|
|
||||||
27
pyproject.toml
Normal file
27
pyproject.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "granite-speech-4.1-2b-plus-mlx"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Standalone MLX pipeline for the Granite Speech 4.1-2b-plus port."
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
license = "MIT"
|
||||||
|
authors = [
|
||||||
|
{ name = "Olivier Dupont" }
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"mlx>=0.22.0",
|
||||||
|
"mlx-lm>=0.19.0",
|
||||||
|
"numpy>=1.26",
|
||||||
|
"transformers>=4.45",
|
||||||
|
"huggingface-hub>=0.24",
|
||||||
|
"soundfile>=0.12",
|
||||||
|
"librosa>=0.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src/granite_speech_plus_mlx"]
|
||||||
|
|
||||||
111
scripts/benchmark.py
Executable file
111
scripts/benchmark.py
Executable file
@@ -0,0 +1,111 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||||
|
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||||
|
from granite_speech_plus_mlx.prompts import PROMPT_MODES
|
||||||
|
|
||||||
|
GRID = [
|
||||||
|
(60, 1.0),
|
||||||
|
(60, 1.2),
|
||||||
|
(180, 1.0),
|
||||||
|
(180, 1.2),
|
||||||
|
(300, 1.0),
|
||||||
|
(300, 1.2),
|
||||||
|
(300, 1.4),
|
||||||
|
]
|
||||||
|
|
||||||
|
HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde")
|
||||||
|
|
||||||
|
|
||||||
|
def analyze(text: str) -> dict:
|
||||||
|
words = text.split()
|
||||||
|
lower_words = text.lower().split()
|
||||||
|
trigrams = Counter(
|
||||||
|
" ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2)
|
||||||
|
)
|
||||||
|
top = trigrams.most_common(5)
|
||||||
|
lower = text.lower()
|
||||||
|
return {
|
||||||
|
"n_words": len(words),
|
||||||
|
"max_trigram_count": top[0][1] if top else 0,
|
||||||
|
"max_trigram_text": top[0][0] if top else "",
|
||||||
|
"halluc": {m: lower.count(m) for m in HALLUCINATION_MARKERS},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.")
|
||||||
|
parser.add_argument("audio")
|
||||||
|
parser.add_argument("--model", default=DEFAULT_MODEL)
|
||||||
|
parser.add_argument("--results", default="bench")
|
||||||
|
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
|
||||||
|
parser.add_argument("--overlap-seconds", type=float, default=2.0)
|
||||||
|
parser.add_argument("--max-tokens", type=int, default=4096)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
results_dir = Path(args.results)
|
||||||
|
results_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
pipe = GraniteSpeechPlusPipeline.from_pretrained(
|
||||||
|
args.model,
|
||||||
|
overlap_seconds=args.overlap_seconds,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for chunk_seconds, repetition_penalty in GRID:
|
||||||
|
out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt"
|
||||||
|
pipe.chunk_seconds = float(chunk_seconds)
|
||||||
|
pipe.repetition_penalty = repetition_penalty
|
||||||
|
|
||||||
|
if out.exists():
|
||||||
|
print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr)
|
||||||
|
elapsed = float("nan")
|
||||||
|
text = out.read_text(encoding="utf-8")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
t0 = time.time()
|
||||||
|
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
|
||||||
|
elapsed = time.time() - t0
|
||||||
|
out.write_text(text + "\n", encoding="utf-8")
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"chunk": chunk_seconds,
|
||||||
|
"rp": repetition_penalty,
|
||||||
|
"elapsed": elapsed,
|
||||||
|
**analyze(text),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |")
|
||||||
|
print("|---:|---:|---:|---:|:---|:---|")
|
||||||
|
for row in rows:
|
||||||
|
halluc = ", ".join(
|
||||||
|
f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value
|
||||||
|
) or "-"
|
||||||
|
trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)"
|
||||||
|
wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}"
|
||||||
|
print(
|
||||||
|
f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} "
|
||||||
|
f"| {trigram} | {halluc} |"
|
||||||
|
)
|
||||||
|
print()
|
||||||
|
print(f"Per-config transcripts in: {results_dir}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
|
|
||||||
47
scripts/transcribe.py
Executable file
47
scripts/transcribe.py
Executable file
@@ -0,0 +1,47 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||||
|
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||||
|
from granite_speech_plus_mlx.prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = argparse.ArgumentParser(description="Transcribe audio with Granite Speech Plus MLX.")
|
||||||
|
parser.add_argument("audio")
|
||||||
|
parser.add_argument("--model", default=DEFAULT_MODEL)
|
||||||
|
parser.add_argument("--output", default=None)
|
||||||
|
parser.add_argument("--chunk-seconds", type=float, default=300.0)
|
||||||
|
parser.add_argument("--overlap-seconds", type=float, default=2.0)
|
||||||
|
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
|
||||||
|
parser.add_argument("--repetition-penalty", type=float, default=1.2)
|
||||||
|
parser.add_argument("--max-tokens", type=int, default=4096)
|
||||||
|
parser.add_argument("--system-prompt", default=GRANITE_SYSTEM_PROMPT)
|
||||||
|
parser.add_argument("--verbose", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
pipe = GraniteSpeechPlusPipeline.from_pretrained(
|
||||||
|
args.model,
|
||||||
|
chunk_seconds=args.chunk_seconds,
|
||||||
|
overlap_seconds=args.overlap_seconds,
|
||||||
|
repetition_penalty=args.repetition_penalty,
|
||||||
|
max_tokens=args.max_tokens,
|
||||||
|
system_prompt=args.system_prompt or None,
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
|
||||||
|
|
||||||
|
if args.output:
|
||||||
|
Path(args.output).write_text(text + "\n", encoding="utf-8")
|
||||||
|
else:
|
||||||
|
print(text)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
|
|
||||||
66
scripts/upload_to_hf.py
Executable file
66
scripts/upload_to_hf.py
Executable file
@@ -0,0 +1,66 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi
|
||||||
|
|
||||||
|
SOURCE_CACHE = (
|
||||||
|
Path.home()
|
||||||
|
/ ".cache/huggingface/hub/models--ibm-granite--granite-speech-4.1-2b-plus"
|
||||||
|
)
|
||||||
|
DEST_REPO = "mlx-community/granite-speech-4.1-2b-plus-mlx"
|
||||||
|
|
||||||
|
|
||||||
|
def find_weights_dir(root: Path) -> Path | None:
|
||||||
|
if not root.exists():
|
||||||
|
return None
|
||||||
|
if list(root.glob("*.safetensors")) or (root / "config.json").exists():
|
||||||
|
return root
|
||||||
|
snapshots = root / "snapshots"
|
||||||
|
if snapshots.exists():
|
||||||
|
candidates = [
|
||||||
|
path
|
||||||
|
for path in snapshots.iterdir()
|
||||||
|
if path.is_dir() and (list(path.glob("*.safetensors")) or (path / "config.json").exists())
|
||||||
|
]
|
||||||
|
if candidates:
|
||||||
|
return sorted(candidates, key=lambda p: p.stat().st_mtime)[-1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def print_manual_commands() -> None:
|
||||||
|
print(f"MLX weights not found at {SOURCE_CACHE}")
|
||||||
|
print("Create them first with:")
|
||||||
|
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus")
|
||||||
|
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus --dtype q4_k_4")
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
weights_dir = find_weights_dir(SOURCE_CACHE)
|
||||||
|
if weights_dir is None:
|
||||||
|
print_manual_commands()
|
||||||
|
return 1
|
||||||
|
|
||||||
|
token = os.environ.get("HF_TOKEN")
|
||||||
|
if not token:
|
||||||
|
print("HF_TOKEN is required to upload.", file=sys.stderr)
|
||||||
|
return 2
|
||||||
|
|
||||||
|
api = HfApi(token=token)
|
||||||
|
api.create_repo(DEST_REPO, repo_type="model", exist_ok=True)
|
||||||
|
api.upload_folder(
|
||||||
|
repo_id=DEST_REPO,
|
||||||
|
repo_type="model",
|
||||||
|
folder_path=str(weights_dir),
|
||||||
|
commit_message="Upload Granite Speech 4.1-2b-plus MLX weights",
|
||||||
|
)
|
||||||
|
print(f"Uploaded {weights_dir} to {DEST_REPO}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
|
|
||||||
4
src/granite_speech_plus_mlx/__init__.py
Normal file
4
src/granite_speech_plus_mlx/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .pipeline import GraniteSpeechPlusPipeline
|
||||||
|
|
||||||
|
__all__ = ["GraniteSpeechPlusPipeline"]
|
||||||
|
|
||||||
1
src/granite_speech_plus_mlx/_vendored/__init__.py
Normal file
1
src/granite_speech_plus_mlx/_vendored/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
15
src/granite_speech_plus_mlx/_vendored/audio.py
Normal file
15
src/granite_speech_plus_mlx/_vendored/audio.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(file: str | Path, sr: int = SAMPLE_RATE):
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
audio, _ = librosa.load(str(file), sr=sr, mono=True)
|
||||||
|
return mx.array(audio, dtype=mx.float32)
|
||||||
|
|
||||||
18
src/granite_speech_plus_mlx/_vendored/base.py
Normal file
18
src/granite_speech_plus_mlx/_vendored/base.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class STTOutput:
|
||||||
|
text: str
|
||||||
|
segments: List[dict] | None = None
|
||||||
|
language: str | None = None
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
generation_tokens: int = 0
|
||||||
|
total_tokens: int = 0
|
||||||
|
prompt_tps: float = 0.0
|
||||||
|
generation_tps: float = 0.0
|
||||||
|
total_time: float = 0.0
|
||||||
|
|
||||||
161
src/granite_speech_plus_mlx/_vendored/dsp.py
Normal file
161
src/granite_speech_plus_mlx/_vendored/dsp.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def hanning(size: int, periodic: bool = False):
|
||||||
|
denom = size if periodic else size - 1
|
||||||
|
return mx.array(
|
||||||
|
[0.5 * (1 - math.cos(2 * math.pi * n / denom)) for n in range(size)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def hamming(size: int, periodic: bool = False):
|
||||||
|
denom = size if periodic else size - 1
|
||||||
|
return mx.array(
|
||||||
|
[0.54 - 0.46 * math.cos(2 * math.pi * n / denom) for n in range(size)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def blackman(size: int, periodic: bool = False):
|
||||||
|
denom = size if periodic else size - 1
|
||||||
|
return mx.array(
|
||||||
|
[
|
||||||
|
0.42
|
||||||
|
- 0.5 * math.cos(2 * math.pi * n / denom)
|
||||||
|
+ 0.08 * math.cos(4 * math.pi * n / denom)
|
||||||
|
for n in range(size)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def bartlett(size: int, periodic: bool = False):
|
||||||
|
denom = size if periodic else size - 1
|
||||||
|
return mx.array([1 - 2 * abs(n - denom / 2) / denom for n in range(size)])
|
||||||
|
|
||||||
|
|
||||||
|
STR_TO_WINDOW_FN = {
|
||||||
|
"hann": hanning,
|
||||||
|
"hanning": hanning,
|
||||||
|
"hamming": hamming,
|
||||||
|
"blackman": blackman,
|
||||||
|
"bartlett": bartlett,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def stft(
|
||||||
|
x,
|
||||||
|
n_fft: int = 800,
|
||||||
|
hop_length: int | None = None,
|
||||||
|
win_length: int | None = None,
|
||||||
|
window: mx.array | str = "hann",
|
||||||
|
center: bool = True,
|
||||||
|
pad_mode: str = "reflect",
|
||||||
|
):
|
||||||
|
if hop_length is None:
|
||||||
|
hop_length = n_fft // 4
|
||||||
|
if win_length is None:
|
||||||
|
win_length = n_fft
|
||||||
|
|
||||||
|
if isinstance(window, str):
|
||||||
|
window_fn = STR_TO_WINDOW_FN.get(window.lower())
|
||||||
|
if window_fn is None:
|
||||||
|
raise ValueError(f"Unknown window function: {window}")
|
||||||
|
w = window_fn(win_length)
|
||||||
|
else:
|
||||||
|
w = window
|
||||||
|
|
||||||
|
if w.shape[0] < n_fft:
|
||||||
|
pad_size = n_fft - w.shape[0]
|
||||||
|
w = mx.concatenate([w, mx.zeros((pad_size,))], axis=0)
|
||||||
|
|
||||||
|
def _pad(signal, padding: int, mode: str = "reflect"):
|
||||||
|
if mode == "constant":
|
||||||
|
return mx.pad(signal, [(padding, padding)])
|
||||||
|
if mode == "reflect":
|
||||||
|
prefix = signal[1 : padding + 1][::-1]
|
||||||
|
suffix = signal[-(padding + 1) : -1][::-1]
|
||||||
|
return mx.concatenate([prefix, signal, suffix])
|
||||||
|
raise ValueError(f"Invalid pad_mode {mode}")
|
||||||
|
|
||||||
|
if center:
|
||||||
|
x = _pad(x, n_fft // 2, pad_mode)
|
||||||
|
|
||||||
|
num_frames = 1 + (x.shape[0] - n_fft) // hop_length
|
||||||
|
if num_frames <= 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Input is too short for n_fft={n_fft}, hop_length={hop_length}, "
|
||||||
|
f"center={center}."
|
||||||
|
)
|
||||||
|
|
||||||
|
frames = mx.as_strided(x, shape=(num_frames, n_fft), strides=(hop_length, 1))
|
||||||
|
return mx.fft.rfft(frames * w)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def mel_filters(
|
||||||
|
sample_rate: int,
|
||||||
|
n_fft: int,
|
||||||
|
n_mels: int,
|
||||||
|
f_min: float = 0,
|
||||||
|
f_max: Optional[float] = None,
|
||||||
|
norm: Optional[str] = None,
|
||||||
|
mel_scale: str = "htk",
|
||||||
|
) -> mx.array:
|
||||||
|
def hz_to_mel(freq, scale="htk"):
|
||||||
|
if scale == "htk":
|
||||||
|
return 2595.0 * math.log10(1.0 + freq / 700.0)
|
||||||
|
|
||||||
|
f_sp = 200.0 / 3
|
||||||
|
mels = freq / f_sp
|
||||||
|
min_log_hz = 1000.0
|
||||||
|
min_log_mel = min_log_hz / f_sp
|
||||||
|
logstep = math.log(6.4) / 27.0
|
||||||
|
if freq >= min_log_hz:
|
||||||
|
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
|
||||||
|
return mels
|
||||||
|
|
||||||
|
def mel_to_hz(mels, scale="htk"):
|
||||||
|
if scale == "htk":
|
||||||
|
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
|
||||||
|
|
||||||
|
f_sp = 200.0 / 3
|
||||||
|
freqs = f_sp * mels
|
||||||
|
min_log_hz = 1000.0
|
||||||
|
min_log_mel = min_log_hz / f_sp
|
||||||
|
logstep = math.log(6.4) / 27.0
|
||||||
|
return mx.where(
|
||||||
|
mels >= min_log_mel,
|
||||||
|
min_log_hz * mx.exp(logstep * (mels - min_log_mel)),
|
||||||
|
freqs,
|
||||||
|
)
|
||||||
|
|
||||||
|
f_max = f_max or sample_rate / 2
|
||||||
|
n_freqs = n_fft // 2 + 1
|
||||||
|
all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)
|
||||||
|
m_min = hz_to_mel(f_min, mel_scale)
|
||||||
|
m_max = hz_to_mel(f_max, mel_scale)
|
||||||
|
m_pts = mx.linspace(m_min, m_max, n_mels + 2)
|
||||||
|
f_pts = mel_to_hz(m_pts, mel_scale)
|
||||||
|
f_diff = f_pts[1:] - f_pts[:-1]
|
||||||
|
slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)
|
||||||
|
down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
|
||||||
|
up_slopes = slopes[:, 2:] / f_diff[1:]
|
||||||
|
filterbank = mx.maximum(
|
||||||
|
mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
|
||||||
|
)
|
||||||
|
|
||||||
|
if norm == "slaney":
|
||||||
|
enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
|
||||||
|
filterbank *= mx.expand_dims(enorm, 0)
|
||||||
|
|
||||||
|
return filterbank.moveaxis(0, 1)
|
||||||
|
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .config import EncoderConfig, ModelConfig, ProjectorConfig, TextConfig
|
||||||
|
from .granite_speech import Model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GraniteSpeechPlusModelConfig(ModelConfig):
|
||||||
|
model_type: str = "granite_speech_plus"
|
||||||
|
|
||||||
|
|
||||||
|
DETECTION_HINTS = {
|
||||||
|
"config_keys": {"encoder_config", "projector_config", "audio_token_index"},
|
||||||
|
"architectures": {
|
||||||
|
"GraniteSpeechForConditionalGeneration",
|
||||||
|
"GraniteSpeechPlusForConditionalGeneration",
|
||||||
|
},
|
||||||
|
"path_patterns": {
|
||||||
|
"granite_speech_plus",
|
||||||
|
"granitespeechplus",
|
||||||
|
"granite-speech-4.1-2b-plus",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EncoderConfig",
|
||||||
|
"ProjectorConfig",
|
||||||
|
"TextConfig",
|
||||||
|
"ModelConfig",
|
||||||
|
"GraniteSpeechPlusModelConfig",
|
||||||
|
"Model",
|
||||||
|
"DETECTION_HINTS",
|
||||||
|
]
|
||||||
|
|
||||||
128
src/granite_speech_plus_mlx/_vendored/granite_speech/config.py
Normal file
128
src/granite_speech_plus_mlx/_vendored/granite_speech/config.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EncoderConfig:
|
||||||
|
input_dim: int = 160
|
||||||
|
num_layers: int = 10
|
||||||
|
hidden_dim: int = 1024
|
||||||
|
feedforward_mult: int = 4
|
||||||
|
num_heads: int = 8
|
||||||
|
dim_head: int = 128
|
||||||
|
output_dim: int = 42
|
||||||
|
context_size: int = 200
|
||||||
|
max_pos_emb: int = 512
|
||||||
|
dropout: float = 0.1
|
||||||
|
conv_kernel_size: int = 15
|
||||||
|
conv_expansion_factor: int = 2
|
||||||
|
# Plus variant: indices of intermediate encoder layers whose hidden state
|
||||||
|
# gets concatenated with the final-layer hidden state along the channel
|
||||||
|
# axis, before being fed to the projector. None / empty = base behavior.
|
||||||
|
cat_hidden_layers: Optional[List[int]] = None
|
||||||
|
model_type: str = "granite_speech_encoder"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in params.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProjectorConfig:
|
||||||
|
hidden_size: int = 1024
|
||||||
|
num_hidden_layers: int = 2
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
intermediate_size: int = 4096
|
||||||
|
hidden_act: str = "gelu"
|
||||||
|
layer_norm_eps: float = 1e-12
|
||||||
|
encoder_hidden_size: int = 1024
|
||||||
|
cross_attention_frequency: int = 1
|
||||||
|
model_type: str = "blip_2_qformer"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in params.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TextConfig:
|
||||||
|
model_type: str = "granite"
|
||||||
|
vocab_size: int = 100353
|
||||||
|
hidden_size: int = 2048
|
||||||
|
intermediate_size: int = 4096
|
||||||
|
num_hidden_layers: int = 40
|
||||||
|
num_attention_heads: int = 16
|
||||||
|
num_key_value_heads: int = 4
|
||||||
|
hidden_act: str = "silu"
|
||||||
|
max_position_embeddings: int = 4096
|
||||||
|
rms_norm_eps: float = 1e-5
|
||||||
|
rope_theta: float = 10000.0
|
||||||
|
rope_scaling: Optional[Dict] = None
|
||||||
|
attention_bias: bool = False
|
||||||
|
mlp_bias: bool = False
|
||||||
|
attention_multiplier: float = 0.0078125
|
||||||
|
embedding_multiplier: float = 12.0
|
||||||
|
residual_multiplier: float = 0.22
|
||||||
|
logits_scaling: float = 8.0
|
||||||
|
tie_word_embeddings: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in params.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
model_type: str = "granite_speech"
|
||||||
|
encoder_config: EncoderConfig = None
|
||||||
|
projector_config: ProjectorConfig = None
|
||||||
|
text_config: TextConfig = None
|
||||||
|
audio_token_index: int = 100352
|
||||||
|
downsample_rate: int = 5
|
||||||
|
window_size: int = 15
|
||||||
|
has_lora_adapter: bool = False
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if isinstance(self.encoder_config, dict):
|
||||||
|
self.encoder_config = EncoderConfig.from_dict(self.encoder_config)
|
||||||
|
elif self.encoder_config is None:
|
||||||
|
self.encoder_config = EncoderConfig()
|
||||||
|
|
||||||
|
if isinstance(self.projector_config, dict):
|
||||||
|
self.projector_config = ProjectorConfig.from_dict(self.projector_config)
|
||||||
|
elif self.projector_config is None:
|
||||||
|
self.projector_config = ProjectorConfig()
|
||||||
|
|
||||||
|
if isinstance(self.text_config, dict):
|
||||||
|
self.text_config = TextConfig.from_dict(self.text_config)
|
||||||
|
elif self.text_config is None:
|
||||||
|
self.text_config = TextConfig()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, params):
|
||||||
|
return cls(
|
||||||
|
**{
|
||||||
|
k: v
|
||||||
|
for k, v in params.items()
|
||||||
|
if k in inspect.signature(cls).parameters
|
||||||
|
}
|
||||||
|
)
|
||||||
@@ -0,0 +1,850 @@
|
|||||||
|
import math
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
import numpy as np
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
|
from mlx_lm.models.base import create_attention_mask
|
||||||
|
from mlx_lm.models.cache import KVCache
|
||||||
|
from mlx_lm.models.granite import Model as GraniteLM
|
||||||
|
from mlx_lm.models.granite import ModelArgs as GraniteModelArgs
|
||||||
|
|
||||||
|
from ..base import STTOutput
|
||||||
|
|
||||||
|
from .config import EncoderConfig, ModelConfig, ProjectorConfig
|
||||||
|
|
||||||
|
LANGUAGE_CODES = {
|
||||||
|
"en": "English",
|
||||||
|
"fr": "French",
|
||||||
|
"de": "German",
|
||||||
|
"es": "Spanish",
|
||||||
|
"pt": "Portuguese",
|
||||||
|
"ja": "Japanese",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamingResult:
|
||||||
|
text: str
|
||||||
|
is_final: bool
|
||||||
|
start_time: float
|
||||||
|
end_time: float
|
||||||
|
language: str = "en"
|
||||||
|
prompt_tokens: int = 0
|
||||||
|
generation_tokens: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class BatchNorm1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, num_features: int, eps: float = 1e-5):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = mx.ones((num_features,))
|
||||||
|
self.bias = mx.zeros((num_features,))
|
||||||
|
self.running_mean = mx.zeros((num_features,))
|
||||||
|
self.running_var = mx.ones((num_features,))
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return (x - self.running_mean) / mx.sqrt(
|
||||||
|
self.running_var + self.eps
|
||||||
|
) * self.weight + self.bias
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerFeedForward(nn.Module):
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.pre_norm = nn.LayerNorm(config.hidden_dim)
|
||||||
|
self.up_proj = nn.Linear(
|
||||||
|
config.hidden_dim, config.hidden_dim * config.feedforward_mult
|
||||||
|
)
|
||||||
|
self.down_proj = nn.Linear(
|
||||||
|
config.hidden_dim * config.feedforward_mult, config.hidden_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
x = nn.silu(self.up_proj(x))
|
||||||
|
x = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = config.dim_head * config.num_heads
|
||||||
|
self.max_pos_emb = config.max_pos_emb
|
||||||
|
self.context_size = config.context_size
|
||||||
|
self.num_heads = config.num_heads
|
||||||
|
self.dim_head = config.dim_head
|
||||||
|
self.scale = config.dim_head**-0.5
|
||||||
|
self.pre_norm = nn.LayerNorm(config.hidden_dim)
|
||||||
|
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
|
||||||
|
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
|
||||||
|
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
|
||||||
|
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
|
||||||
|
x = self.pre_norm(x)
|
||||||
|
B, N, _ = x.shape
|
||||||
|
|
||||||
|
num_blocks = math.ceil(N / self.context_size)
|
||||||
|
remainder = N % self.context_size
|
||||||
|
|
||||||
|
if remainder > 0:
|
||||||
|
pad_len = self.context_size - remainder
|
||||||
|
x = mx.pad(x, [(0, 0), (0, pad_len), (0, 0)])
|
||||||
|
|
||||||
|
q = self.to_q(x)
|
||||||
|
kv = self.to_kv(x)
|
||||||
|
k, v = mx.split(kv, 2, axis=-1)
|
||||||
|
|
||||||
|
q = q.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
|
||||||
|
k = k.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
|
||||||
|
v = v.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
|
||||||
|
|
||||||
|
q = q.transpose(0, 1, 3, 2, 4)
|
||||||
|
k = k.transpose(0, 1, 3, 2, 4)
|
||||||
|
v = v.transpose(0, 1, 3, 2, 4)
|
||||||
|
|
||||||
|
rel_pos_emb = self.rel_pos_emb(attention_dists)
|
||||||
|
|
||||||
|
C = self.context_size
|
||||||
|
pos_attn = (
|
||||||
|
mx.sum(
|
||||||
|
q[:, :, :, :, None, :] * rel_pos_emb[None, None, None, :, :, :],
|
||||||
|
axis=-1,
|
||||||
|
)
|
||||||
|
* self.scale
|
||||||
|
)
|
||||||
|
|
||||||
|
if remainder > 0:
|
||||||
|
row_valid = mx.arange(C)[:, None] < remainder
|
||||||
|
col_valid = mx.arange(C)[None, :] < remainder
|
||||||
|
mask = ~(row_valid & col_valid)
|
||||||
|
mask_value = mx.array(mx.finfo(pos_attn.dtype).min)
|
||||||
|
pos_attn_last = mx.where(
|
||||||
|
mask[None, None, None], mask_value, pos_attn[:, -1:, :, :, :]
|
||||||
|
)
|
||||||
|
pos_attn = mx.concatenate(
|
||||||
|
[pos_attn[:, :-1, :, :, :], pos_attn_last], axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_weights = (q @ k.transpose(0, 1, 2, 4, 3)) * self.scale + pos_attn
|
||||||
|
attn_weights = mx.softmax(attn_weights, axis=-1)
|
||||||
|
|
||||||
|
out = attn_weights @ v
|
||||||
|
out = out.transpose(0, 1, 3, 2, 4)
|
||||||
|
out = out.reshape(B, -1, self.num_heads * self.dim_head)
|
||||||
|
out = out[:, :N, :]
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class DepthWiseConv1d(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
|
||||||
|
super().__init__()
|
||||||
|
pad = kernel_size // 2
|
||||||
|
pad_offset = (kernel_size + 1) % 2
|
||||||
|
self.padding = (pad, pad - pad_offset)
|
||||||
|
self.conv = nn.Conv1d(
|
||||||
|
chan_in, chan_out, kernel_size, groups=chan_in, bias=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = mx.pad(x, [(0, 0), (self.padding[0], self.padding[1]), (0, 0)])
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerConvModule(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
inner_dim = config.hidden_dim * config.conv_expansion_factor
|
||||||
|
|
||||||
|
self.norm = nn.LayerNorm(config.hidden_dim)
|
||||||
|
self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
|
||||||
|
self.depth_conv = DepthWiseConv1d(inner_dim, inner_dim, config.conv_kernel_size)
|
||||||
|
self.batch_norm = BatchNorm1d(inner_dim)
|
||||||
|
self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.up_conv(x)
|
||||||
|
x1, x2 = mx.split(x, 2, axis=-1)
|
||||||
|
x = x1 * mx.sigmoid(x2)
|
||||||
|
x = self.depth_conv(x)
|
||||||
|
x = nn.silu(self.batch_norm(x))
|
||||||
|
x = self.down_conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ConformerBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.ff1 = ConformerFeedForward(config)
|
||||||
|
self.attn = ConformerAttention(config)
|
||||||
|
self.conv = ConformerConvModule(config)
|
||||||
|
self.ff2 = ConformerFeedForward(config)
|
||||||
|
self.post_norm = nn.LayerNorm(config.hidden_dim)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
|
||||||
|
x = 0.5 * self.ff1(x) + x
|
||||||
|
x = self.attn(x, attention_dists) + x
|
||||||
|
x = self.conv(x) + x
|
||||||
|
x = 0.5 * self.ff2(x) + x
|
||||||
|
x = self.post_norm(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CTCEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.input_linear = nn.Linear(config.input_dim, config.hidden_dim)
|
||||||
|
self.layers = [ConformerBlock(config) for _ in range(config.num_layers)]
|
||||||
|
self.out = nn.Linear(config.hidden_dim, config.output_dim)
|
||||||
|
self.out_mid = nn.Linear(config.output_dim, config.hidden_dim)
|
||||||
|
self.num_layers = config.num_layers
|
||||||
|
self._attention_dists = None
|
||||||
|
|
||||||
|
seq = mx.arange(config.context_size)
|
||||||
|
relpos_dist = seq[:, None] - seq[None, :]
|
||||||
|
self._attention_dists = (
|
||||||
|
mx.clip(relpos_dist, -config.context_size, config.context_size)
|
||||||
|
+ config.max_pos_emb
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
x = self.input_linear(x)
|
||||||
|
cat_layers = set(self.config.cat_hidden_layers or [])
|
||||||
|
exported_hidden_states = []
|
||||||
|
if 0 in cat_layers:
|
||||||
|
exported_hidden_states.append(x)
|
||||||
|
|
||||||
|
for idx, layer in enumerate(self.layers, start=1):
|
||||||
|
x = layer(x, attention_dists=self._attention_dists)
|
||||||
|
if idx in cat_layers:
|
||||||
|
exported_hidden_states.append(x)
|
||||||
|
if idx == self.num_layers // 2:
|
||||||
|
x_mid = self.out(x)
|
||||||
|
x = x + self.out_mid(mx.softmax(x_mid, axis=-1))
|
||||||
|
|
||||||
|
if exported_hidden_states:
|
||||||
|
# Plus variant: prepend captured intermediate hidden states to the
|
||||||
|
# final-layer output along the channel axis. Order matches the
|
||||||
|
# upstream Transformers implementation: intermediates first, then
|
||||||
|
# final.
|
||||||
|
x = mx.concatenate([*exported_hidden_states, x], axis=-1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerMultiHeadAttention(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, num_heads: int, kv_hidden_size: int = None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = hidden_size // num_heads
|
||||||
|
kv_dim = kv_hidden_size or hidden_size
|
||||||
|
|
||||||
|
self.query = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.key = nn.Linear(kv_dim, hidden_size)
|
||||||
|
self.value = nn.Linear(kv_dim, hidden_size)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
|
||||||
|
) -> mx.array:
|
||||||
|
B, L, _ = hidden_states.shape
|
||||||
|
|
||||||
|
q = self.query(hidden_states)
|
||||||
|
kv_input = (
|
||||||
|
encoder_hidden_states
|
||||||
|
if encoder_hidden_states is not None
|
||||||
|
else hidden_states
|
||||||
|
)
|
||||||
|
k = self.key(kv_input)
|
||||||
|
v = self.value(kv_input)
|
||||||
|
|
||||||
|
q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
k = k.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
v = v.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
|
||||||
|
|
||||||
|
scale = self.head_dim**-0.5
|
||||||
|
attn = (q * scale) @ k.transpose(0, 1, 3, 2)
|
||||||
|
attn = mx.softmax(attn, axis=-1)
|
||||||
|
out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerSelfOutput(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-12):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
kv_hidden_size: int = None,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = QFormerMultiHeadAttention(
|
||||||
|
hidden_size, num_heads, kv_hidden_size
|
||||||
|
)
|
||||||
|
self.output = QFormerSelfOutput(hidden_size, eps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
|
||||||
|
) -> mx.array:
|
||||||
|
attn_out = self.attention(hidden_states, encoder_hidden_states)
|
||||||
|
return self.output(attn_out, hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerIntermediate(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, hidden_size: int, intermediate_size: int):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(hidden_size, intermediate_size)
|
||||||
|
|
||||||
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
|
return nn.gelu(self.dense(x))
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerOutput(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, intermediate_size: int, hidden_size: int, eps: float = 1e-12):
|
||||||
|
super().__init__()
|
||||||
|
self.dense = nn.Linear(intermediate_size, hidden_size)
|
||||||
|
self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
|
||||||
|
hidden_states = self.dense(hidden_states)
|
||||||
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: ProjectorConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = QFormerAttention(
|
||||||
|
config.hidden_size, config.num_attention_heads, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
self.crossattention = QFormerAttention(
|
||||||
|
config.hidden_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
kv_hidden_size=config.encoder_hidden_size,
|
||||||
|
eps=config.layer_norm_eps,
|
||||||
|
)
|
||||||
|
self.intermediate_query = QFormerIntermediate(
|
||||||
|
config.hidden_size, config.intermediate_size
|
||||||
|
)
|
||||||
|
self.output_query = QFormerOutput(
|
||||||
|
config.intermediate_size, config.hidden_size, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, encoder_hidden_states: mx.array
|
||||||
|
) -> mx.array:
|
||||||
|
hidden_states = self.attention(hidden_states)
|
||||||
|
hidden_states = self.crossattention(hidden_states, encoder_hidden_states)
|
||||||
|
intermediate = self.intermediate_query(hidden_states)
|
||||||
|
hidden_states = self.output_query(intermediate, hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerEncoder(nn.Module):
|
||||||
|
def __init__(self, config: ProjectorConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.layer = [QFormerLayer(config) for _ in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, hidden_states: mx.array, encoder_hidden_states: mx.array
|
||||||
|
) -> mx.array:
|
||||||
|
for layer in self.layer:
|
||||||
|
hidden_states = layer(hidden_states, encoder_hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class QFormerModel(nn.Module):
|
||||||
|
def __init__(self, config: ProjectorConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
self.encoder = QFormerEncoder(config)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, query_embeds: mx.array, encoder_hidden_states: mx.array
|
||||||
|
) -> mx.array:
|
||||||
|
hidden_states = self.layernorm(query_embeds)
|
||||||
|
return self.encoder(hidden_states, encoder_hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderProjector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, config: ModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.projector_config.hidden_size
|
||||||
|
self.downsample_rate = config.downsample_rate
|
||||||
|
self.window_size = config.window_size
|
||||||
|
self.num_queries = config.window_size // config.downsample_rate
|
||||||
|
|
||||||
|
self.query = mx.zeros(
|
||||||
|
(1, self.num_queries, config.projector_config.hidden_size)
|
||||||
|
)
|
||||||
|
self.qformer = QFormerModel(config.projector_config)
|
||||||
|
self.linear = nn.Linear(
|
||||||
|
config.projector_config.hidden_size, config.text_config.hidden_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, hidden_states: mx.array) -> mx.array:
|
||||||
|
B, L, D = hidden_states.shape
|
||||||
|
nblocks = math.ceil(L / self.window_size)
|
||||||
|
pad = nblocks * self.window_size - L
|
||||||
|
if pad > 0:
|
||||||
|
hidden_states = mx.pad(hidden_states, [(0, 0), (0, pad), (0, 0)])
|
||||||
|
|
||||||
|
hidden_states = hidden_states.reshape(B * nblocks, self.window_size, D)
|
||||||
|
|
||||||
|
query = mx.broadcast_to(
|
||||||
|
self.query, (B * nblocks, self.num_queries, self.hidden_size)
|
||||||
|
)
|
||||||
|
|
||||||
|
query_output = self.qformer(query, hidden_states)
|
||||||
|
query_proj = self.linear(
|
||||||
|
query_output.reshape(B, nblocks * self.num_queries, -1)
|
||||||
|
)
|
||||||
|
return query_proj
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, config: ModelConfig):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Plus variant invariant: encoder concats len(cat_hidden_layers)+1
|
||||||
|
# hidden states channel-wise, so projector must accept that wider input.
|
||||||
|
cat_layers = config.encoder_config.cat_hidden_layers or []
|
||||||
|
expected_proj_in = config.encoder_config.hidden_dim * (len(cat_layers) + 1)
|
||||||
|
if config.projector_config.encoder_hidden_size != expected_proj_in:
|
||||||
|
raise ValueError(
|
||||||
|
f"projector_config.encoder_hidden_size ({config.projector_config.encoder_hidden_size}) "
|
||||||
|
f"must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1) "
|
||||||
|
f"({config.encoder_config.hidden_dim} * {len(cat_layers) + 1} = {expected_proj_in})"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.encoder = CTCEncoder(config.encoder_config)
|
||||||
|
self.projector = EncoderProjector(config)
|
||||||
|
text_args = GraniteModelArgs.from_dict(
|
||||||
|
config.text_config.__dict__
|
||||||
|
if hasattr(config.text_config, "__dict__")
|
||||||
|
else config.text_config
|
||||||
|
)
|
||||||
|
self.language_model = GraniteLM(text_args)
|
||||||
|
|
||||||
|
self.audio_token_id = config.audio_token_index
|
||||||
|
self._tokenizer = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layers(self):
|
||||||
|
return self.language_model.model.layers
|
||||||
|
|
||||||
|
def make_cache(self) -> List[KVCache]:
|
||||||
|
return [KVCache() for _ in range(len(self.layers))]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
input_ids: mx.array,
|
||||||
|
cache: Optional[List[KVCache]] = None,
|
||||||
|
input_embeddings: Optional[mx.array] = None,
|
||||||
|
) -> mx.array:
|
||||||
|
if input_embeddings is not None:
|
||||||
|
h = input_embeddings
|
||||||
|
else:
|
||||||
|
h = self.language_model.model.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
h = h * self.language_model.model.embedding_multiplier
|
||||||
|
|
||||||
|
if cache is None:
|
||||||
|
cache = [None] * len(self.language_model.model.layers)
|
||||||
|
|
||||||
|
mask = create_attention_mask(h, cache[0])
|
||||||
|
|
||||||
|
for layer, c in zip(self.language_model.model.layers, cache):
|
||||||
|
h = layer(h, mask, cache=c)
|
||||||
|
|
||||||
|
h = self.language_model.model.norm(h)
|
||||||
|
|
||||||
|
if self.language_model.args.tie_word_embeddings:
|
||||||
|
logits = self.language_model.model.embed_tokens.as_linear(h)
|
||||||
|
else:
|
||||||
|
logits = self.language_model.lm_head(h)
|
||||||
|
|
||||||
|
return logits / self.language_model.logits_scaling
|
||||||
|
|
||||||
|
def get_audio_features(self, input_features: mx.array) -> mx.array:
|
||||||
|
encoder_output = self.encoder(input_features)
|
||||||
|
projected = self.projector(encoder_output)
|
||||||
|
return projected
|
||||||
|
|
||||||
|
def model_quant_predicate(self, p: str, m: nn.Module) -> bool:
|
||||||
|
return not (p.startswith("encoder") or p.startswith("projector"))
|
||||||
|
|
||||||
|
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
|
||||||
|
# Compare incoming weight shapes against the model's already-initialized
|
||||||
|
# parameter shapes. This is idempotent across convert-time (PyTorch source
|
||||||
|
# layout) and inference-time load (MLX-native layout) and correct even for
|
||||||
|
# Conv1d kernel_size=1 layers where prior shape-ordering heuristics failed.
|
||||||
|
# Pattern adapted from cohere_asr.Model.sanitize.
|
||||||
|
model_weights = dict(tree_flatten(self.parameters()))
|
||||||
|
sanitized = {}
|
||||||
|
for k, v in weights.items():
|
||||||
|
if "num_batches_tracked" in k:
|
||||||
|
continue
|
||||||
|
# granite-speech-4.1 ships a separate out_llm.safetensors with
|
||||||
|
# top-level "weight"/"bias" keys (likely an audio CTC head). The
|
||||||
|
# standard Model class does not define this layer, so dropping these
|
||||||
|
# keys is required for the rest of the load to succeed. Inference
|
||||||
|
# behaviour that depends on this head is not yet supported.
|
||||||
|
if k in ("weight", "bias"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
expected = model_weights.get(k)
|
||||||
|
if expected is not None and hasattr(expected, "shape"):
|
||||||
|
if v.shape != expected.shape and v.ndim == 3:
|
||||||
|
transposed = mx.transpose(v, (0, 2, 1))
|
||||||
|
if transposed.shape == expected.shape:
|
||||||
|
v = transposed
|
||||||
|
|
||||||
|
sanitized[k] = v
|
||||||
|
return sanitized
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
|
||||||
|
import transformers
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
prev = transformers.logging.get_verbosity()
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
try:
|
||||||
|
model._tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
str(model_path), trust_remote_code=True
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
transformers.logging.set_verbosity(prev)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _extract_features(
|
||||||
|
self, audio: Union[mx.array, np.ndarray]
|
||||||
|
) -> Tuple[mx.array, int]:
|
||||||
|
from ..dsp import hanning, mel_filters, stft
|
||||||
|
|
||||||
|
n_fft = 512
|
||||||
|
win_length = 400
|
||||||
|
hop_length = 160
|
||||||
|
n_mels = 80
|
||||||
|
sample_rate = 16000
|
||||||
|
|
||||||
|
if isinstance(audio, mx.array):
|
||||||
|
audio_1d = audio.reshape(-1)
|
||||||
|
else:
|
||||||
|
audio_1d = mx.array(audio.flatten(), dtype=mx.float32)
|
||||||
|
|
||||||
|
win = hanning(win_length, periodic=True)
|
||||||
|
pad_left = (n_fft - win_length) // 2
|
||||||
|
pad_right = n_fft - win_length - pad_left
|
||||||
|
win_padded = mx.concatenate(
|
||||||
|
[mx.zeros((pad_left,)), win, mx.zeros((pad_right,))]
|
||||||
|
)
|
||||||
|
|
||||||
|
spec = stft(
|
||||||
|
audio_1d,
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
window=win_padded,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
)
|
||||||
|
|
||||||
|
power = mx.abs(spec) ** 2
|
||||||
|
mel_fb = mel_filters(sample_rate, n_fft, n_mels, mel_scale="htk")
|
||||||
|
mel_spec = power @ mel_fb.T
|
||||||
|
|
||||||
|
logmel = mx.log10(mx.clip(mel_spec, 1e-10, None))
|
||||||
|
mx_val = mx.max(logmel)
|
||||||
|
logmel = mx.maximum(logmel, mx_val - 8.0) / 4.0 + 1.0
|
||||||
|
|
||||||
|
if logmel.shape[0] % 2 == 1:
|
||||||
|
logmel = logmel[:-1]
|
||||||
|
|
||||||
|
encoder_input = logmel.reshape(-1, 2 * n_mels)
|
||||||
|
|
||||||
|
encoder_length = encoder_input.shape[0]
|
||||||
|
nblocks = math.ceil(encoder_length / self.config.window_size)
|
||||||
|
num_audio_tokens = nblocks * (
|
||||||
|
self.config.window_size // self.config.downsample_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
input_features = encoder_input[None, :, :]
|
||||||
|
return input_features, num_audio_tokens
|
||||||
|
|
||||||
|
def _build_prompt(
|
||||||
|
self,
|
||||||
|
num_audio_tokens: int,
|
||||||
|
user_prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
) -> mx.array:
|
||||||
|
if user_prompt is None:
|
||||||
|
user_prompt = "can you transcribe the speech into a written format?"
|
||||||
|
|
||||||
|
audio_placeholder = "<|audio|>" * num_audio_tokens
|
||||||
|
content = f"{audio_placeholder}{user_prompt}"
|
||||||
|
|
||||||
|
if getattr(self._tokenizer, "chat_template", None):
|
||||||
|
chat = []
|
||||||
|
if system_prompt:
|
||||||
|
chat.append({"role": "system", "content": system_prompt})
|
||||||
|
chat.append({"role": "user", "content": content})
|
||||||
|
prompt_str = self._tokenizer.apply_chat_template(
|
||||||
|
chat, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Granite-3 chat format (granite-speech tokenizer ships without a
|
||||||
|
# chat_template attribute, but its vocab includes the role tokens).
|
||||||
|
sor, eor, eot = "<|start_of_role|>", "<|end_of_role|>", "<|end_of_text|>"
|
||||||
|
parts = []
|
||||||
|
if system_prompt:
|
||||||
|
parts.append(f"{sor}system{eor}{system_prompt}{eot}\n")
|
||||||
|
parts.append(f"{sor}user{eor}{content}{eot}\n")
|
||||||
|
parts.append(f"{sor}assistant{eor}")
|
||||||
|
prompt_str = "".join(parts)
|
||||||
|
|
||||||
|
prompt_ids = self._tokenizer.encode(prompt_str)
|
||||||
|
|
||||||
|
return mx.array(prompt_ids)
|
||||||
|
|
||||||
|
def _build_inputs_embeds(
|
||||||
|
self, input_ids: mx.array, audio_features: mx.array
|
||||||
|
) -> mx.array:
|
||||||
|
is_audio = input_ids == self.audio_token_id
|
||||||
|
llm_ids = mx.where(is_audio, 0, input_ids)
|
||||||
|
|
||||||
|
inputs_embeds = self.language_model.model.embed_tokens(llm_ids[None])
|
||||||
|
|
||||||
|
is_audio_np = np.array(is_audio)
|
||||||
|
audio_positions = np.where(is_audio_np)[0]
|
||||||
|
|
||||||
|
orig_dtype = inputs_embeds.dtype
|
||||||
|
embeds_np = np.array(inputs_embeds.astype(mx.float32))
|
||||||
|
audio_np = np.array(audio_features.astype(mx.float32))
|
||||||
|
|
||||||
|
num_audio = min(len(audio_positions), audio_np.shape[1])
|
||||||
|
embeds_np[0, audio_positions[:num_audio]] = audio_np[0, :num_audio]
|
||||||
|
|
||||||
|
return mx.array(embeds_np).astype(orig_dtype)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
audio: Union[str, mx.array, np.ndarray],
|
||||||
|
*,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: int = 100,
|
||||||
|
prompt: str = None,
|
||||||
|
system_prompt: str = None,
|
||||||
|
language: str = None,
|
||||||
|
prefill_step_size: int = 2048,
|
||||||
|
verbose: bool = False,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[STTOutput, Generator[StreamingResult, None, None]]:
|
||||||
|
if prompt is None and language is not None:
|
||||||
|
lang_name = LANGUAGE_CODES.get(language.lower(), language)
|
||||||
|
prompt = f"Translate the speech to {lang_name}."
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_generate(
|
||||||
|
audio,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
min_p=min_p,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
repetition_context_size=repetition_context_size,
|
||||||
|
prompt=prompt,
|
||||||
|
prefill_step_size=prefill_step_size,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
from mlx_lm.generate import generate_step
|
||||||
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
|
|
||||||
|
audio_data = self._load_audio(audio)
|
||||||
|
input_features, num_audio_tokens = self._extract_features(audio_data)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print("Encoding audio...")
|
||||||
|
audio_features = self.get_audio_features(input_features)
|
||||||
|
mx.eval(audio_features)
|
||||||
|
|
||||||
|
prompt_ids = self._build_prompt(num_audio_tokens, prompt, system_prompt=system_prompt)
|
||||||
|
inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
|
||||||
|
mx.eval(inputs_embeds)
|
||||||
|
|
||||||
|
prompt_tokens = len(prompt_ids)
|
||||||
|
|
||||||
|
sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
|
||||||
|
logits_processors = make_logits_processors(
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
repetition_context_size=repetition_context_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
eos_token_id = self._tokenizer.eos_token_id
|
||||||
|
tokens = []
|
||||||
|
|
||||||
|
for token, logprobs in generate_step(
|
||||||
|
prompt=prompt_ids,
|
||||||
|
input_embeddings=inputs_embeds.squeeze(0),
|
||||||
|
model=self,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
sampler=sampler,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
prefill_step_size=prefill_step_size,
|
||||||
|
):
|
||||||
|
if token == eos_token_id:
|
||||||
|
break
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
text = self._tokenizer.decode(tokens, skip_special_tokens=True)
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
gen_tokens = len(tokens)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"Prompt tokens: {prompt_tokens}")
|
||||||
|
print(f"Generation tokens: {gen_tokens}")
|
||||||
|
print(f"Total time: {elapsed:.2f}s")
|
||||||
|
if gen_tokens > 0:
|
||||||
|
print(f"Generation TPS: {gen_tokens / elapsed:.1f}")
|
||||||
|
|
||||||
|
return STTOutput(
|
||||||
|
text=text,
|
||||||
|
segments=[],
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
generation_tokens=gen_tokens,
|
||||||
|
total_tokens=prompt_tokens + gen_tokens,
|
||||||
|
total_time=elapsed,
|
||||||
|
prompt_tps=prompt_tokens / elapsed if elapsed > 0 else 0,
|
||||||
|
generation_tps=gen_tokens / elapsed if elapsed > 0 else 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stream_generate(
|
||||||
|
self,
|
||||||
|
audio: Union[str, mx.array, np.ndarray],
|
||||||
|
*,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: int = 100,
|
||||||
|
prompt: str = None,
|
||||||
|
prefill_step_size: int = 2048,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> Generator[StreamingResult, None, None]:
|
||||||
|
from mlx_lm.generate import generate_step
|
||||||
|
from mlx_lm.sample_utils import make_logits_processors, make_sampler
|
||||||
|
|
||||||
|
audio_data = self._load_audio(audio)
|
||||||
|
input_features, num_audio_tokens = self._extract_features(audio_data)
|
||||||
|
|
||||||
|
audio_features = self.get_audio_features(input_features)
|
||||||
|
mx.eval(audio_features)
|
||||||
|
|
||||||
|
prompt_ids = self._build_prompt(num_audio_tokens, prompt)
|
||||||
|
inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
|
||||||
|
mx.eval(inputs_embeds)
|
||||||
|
|
||||||
|
prompt_token_count = len(prompt_ids)
|
||||||
|
|
||||||
|
sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
|
||||||
|
logits_processors = make_logits_processors(
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
repetition_context_size=repetition_context_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
eos_token_id = self._tokenizer.eos_token_id
|
||||||
|
gen_tokens = 0
|
||||||
|
|
||||||
|
for token, _ in generate_step(
|
||||||
|
prompt=prompt_ids,
|
||||||
|
input_embeddings=inputs_embeds.squeeze(0),
|
||||||
|
model=self,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
sampler=sampler,
|
||||||
|
logits_processors=logits_processors,
|
||||||
|
prefill_step_size=prefill_step_size,
|
||||||
|
):
|
||||||
|
if token == eos_token_id:
|
||||||
|
break
|
||||||
|
gen_tokens += 1
|
||||||
|
text = self._tokenizer.decode([token], skip_special_tokens=True)
|
||||||
|
yield StreamingResult(
|
||||||
|
text=text,
|
||||||
|
is_final=False,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=0.0,
|
||||||
|
prompt_tokens=prompt_token_count,
|
||||||
|
generation_tokens=gen_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield StreamingResult(
|
||||||
|
text="",
|
||||||
|
is_final=True,
|
||||||
|
start_time=0.0,
|
||||||
|
end_time=0.0,
|
||||||
|
prompt_tokens=prompt_token_count,
|
||||||
|
generation_tokens=gen_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_audio(self, audio: Union[str, mx.array, np.ndarray]) -> mx.array:
|
||||||
|
if isinstance(audio, str):
|
||||||
|
from ..audio import load_audio
|
||||||
|
|
||||||
|
return load_audio(audio)
|
||||||
|
elif isinstance(audio, np.ndarray):
|
||||||
|
return mx.array(audio, dtype=mx.float32)
|
||||||
|
elif isinstance(audio, mx.array):
|
||||||
|
return audio
|
||||||
|
elif isinstance(audio, list):
|
||||||
|
audio_item = audio[0]
|
||||||
|
if isinstance(audio_item, str):
|
||||||
|
from ..audio import load_audio
|
||||||
|
|
||||||
|
return load_audio(audio_item)
|
||||||
|
return mx.array(np.array(audio_item), dtype=mx.float32)
|
||||||
|
raise TypeError(f"Unsupported audio type: {type(audio)}")
|
||||||
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
import mlx.core as mx
|
||||||
|
import mlx.nn as nn
|
||||||
|
|
||||||
|
from .granite_speech import Model, ModelConfig
|
||||||
|
|
||||||
|
DEFAULT_ALLOW_PATTERNS = [
|
||||||
|
"*.json",
|
||||||
|
"*.safetensors",
|
||||||
|
"*.py",
|
||||||
|
"*.model",
|
||||||
|
"*.tiktoken",
|
||||||
|
"*.txt",
|
||||||
|
"*.jsonl",
|
||||||
|
"*.yaml",
|
||||||
|
"*.npz",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_local_path(path: str) -> bool:
|
||||||
|
return (
|
||||||
|
path.startswith(".")
|
||||||
|
or path.startswith("/")
|
||||||
|
or path.startswith("~")
|
||||||
|
or (len(path) > 1 and path[1] == ":")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(
|
||||||
|
path_or_hf_repo: str | Path,
|
||||||
|
*,
|
||||||
|
revision: str | None = None,
|
||||||
|
force_download: bool = False,
|
||||||
|
allow_patterns: list[str] | None = None,
|
||||||
|
) -> Path:
|
||||||
|
if isinstance(path_or_hf_repo, Path):
|
||||||
|
path = path_or_hf_repo.expanduser()
|
||||||
|
if path.exists():
|
||||||
|
return path
|
||||||
|
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
|
||||||
|
|
||||||
|
path = Path(path_or_hf_repo).expanduser()
|
||||||
|
if path.exists():
|
||||||
|
return path
|
||||||
|
if _is_local_path(path_or_hf_repo):
|
||||||
|
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
|
||||||
|
|
||||||
|
return Path(
|
||||||
|
snapshot_download(
|
||||||
|
path_or_hf_repo,
|
||||||
|
revision=revision,
|
||||||
|
allow_patterns=allow_patterns or DEFAULT_ALLOW_PATTERNS,
|
||||||
|
force_download=force_download,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(model_path: str | Path) -> dict[str, Any]:
|
||||||
|
model_path = Path(model_path)
|
||||||
|
config_file = model_path / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
raise FileNotFoundError(f"Config not found at {model_path}")
|
||||||
|
return json.loads(config_file.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(model_path: Path) -> dict[str, mx.array]:
|
||||||
|
weight_files = sorted(glob.glob(str(model_path / "*.safetensors")))
|
||||||
|
if not weight_files:
|
||||||
|
weight_files = sorted(glob.glob(str(model_path / "*.npz")))
|
||||||
|
if not weight_files:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"No weight files (safetensors or npz) found in {model_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = {}
|
||||||
|
for weight_file in weight_files:
|
||||||
|
weights.update(mx.load(weight_file))
|
||||||
|
return weights
|
||||||
|
|
||||||
|
|
||||||
|
def apply_quantization(
|
||||||
|
model: nn.Module,
|
||||||
|
config: dict[str, Any],
|
||||||
|
weights: dict[str, mx.array],
|
||||||
|
model_quant_predicate=None,
|
||||||
|
) -> None:
|
||||||
|
quantization = config.get("quantization") or config.get("quantization_config")
|
||||||
|
if quantization is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
group_size = quantization.get("group_size", 64)
|
||||||
|
|
||||||
|
def class_predicate(path, module):
|
||||||
|
if not hasattr(module, "to_quantized"):
|
||||||
|
return False
|
||||||
|
if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0:
|
||||||
|
return False
|
||||||
|
if model_quant_predicate is not None:
|
||||||
|
pred = model_quant_predicate(path, module)
|
||||||
|
if isinstance(pred, dict):
|
||||||
|
return pred
|
||||||
|
if not pred:
|
||||||
|
return False
|
||||||
|
if path in quantization:
|
||||||
|
return quantization[path]
|
||||||
|
return f"{path}.scales" in weights
|
||||||
|
|
||||||
|
nn.quantize(
|
||||||
|
model,
|
||||||
|
group_size=group_size,
|
||||||
|
bits=quantization["bits"],
|
||||||
|
mode=quantization.get("mode", "affine"),
|
||||||
|
class_predicate=class_predicate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(
|
||||||
|
model_path: str | Path,
|
||||||
|
*,
|
||||||
|
lazy: bool = False,
|
||||||
|
strict: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> nn.Module:
|
||||||
|
path = get_model_path(
|
||||||
|
model_path,
|
||||||
|
revision=kwargs.pop("revision", None),
|
||||||
|
force_download=kwargs.pop("force_download", False),
|
||||||
|
allow_patterns=kwargs.pop("allow_patterns", None),
|
||||||
|
)
|
||||||
|
config = load_config(path)
|
||||||
|
model = Model(ModelConfig.from_dict(config))
|
||||||
|
weights = load_weights(path)
|
||||||
|
|
||||||
|
if hasattr(model, "sanitize"):
|
||||||
|
weights = model.sanitize(weights)
|
||||||
|
|
||||||
|
apply_quantization(model, config, weights, model.model_quant_predicate)
|
||||||
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
|
|
||||||
|
if not lazy:
|
||||||
|
mx.eval(model.parameters())
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if hasattr(Model, "post_load_hook"):
|
||||||
|
model = Model.post_load_hook(model, path)
|
||||||
|
return model
|
||||||
|
|
||||||
56
src/granite_speech_plus_mlx/chunking.py
Normal file
56
src/granite_speech_plus_mlx/chunking.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import re
|
||||||
|
from typing import Any, Iterable, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AudioChunk:
|
||||||
|
index: int
|
||||||
|
start: float
|
||||||
|
end: float
|
||||||
|
samples: Any
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_audio(
|
||||||
|
audio: Any,
|
||||||
|
sr: int,
|
||||||
|
chunk_seconds: float,
|
||||||
|
overlap_seconds: float = 2.0,
|
||||||
|
) -> Iterator[AudioChunk]:
|
||||||
|
if chunk_seconds <= 0:
|
||||||
|
raise ValueError("chunk_seconds must be positive")
|
||||||
|
if overlap_seconds < 0:
|
||||||
|
raise ValueError("overlap_seconds cannot be negative")
|
||||||
|
|
||||||
|
chunk_samples = int(chunk_seconds * sr)
|
||||||
|
overlap_samples = int(overlap_seconds * sr)
|
||||||
|
if overlap_samples >= chunk_samples:
|
||||||
|
raise ValueError("overlap_seconds must be smaller than chunk_seconds")
|
||||||
|
|
||||||
|
step = chunk_samples - overlap_samples
|
||||||
|
n = len(audio)
|
||||||
|
pos = 0
|
||||||
|
index = 1
|
||||||
|
while pos < n:
|
||||||
|
end = min(pos + chunk_samples, n)
|
||||||
|
yield AudioChunk(index, pos / sr, end / sr, audio[pos:end])
|
||||||
|
if end == n:
|
||||||
|
break
|
||||||
|
pos += step
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
|
||||||
|
def prefix_text(transcripts: Iterable[str], max_chars: int = 800) -> str:
|
||||||
|
text = "\n".join(t.strip() for t in transcripts if t and t.strip())
|
||||||
|
text = re.sub(r"^## \[[^\n]+\]\n", "", text, flags=re.MULTILINE)
|
||||||
|
text = re.sub(r"\s+", " ", text).strip()
|
||||||
|
if len(text) <= max_chars:
|
||||||
|
return text
|
||||||
|
|
||||||
|
tail = text[-max_chars:]
|
||||||
|
first_space = tail.find(" ")
|
||||||
|
if first_space > 0:
|
||||||
|
tail = tail[first_space + 1 :]
|
||||||
|
return tail.strip()
|
||||||
126
src/granite_speech_plus_mlx/pipeline.py
Normal file
126
src/granite_speech_plus_mlx/pipeline.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .chunking import chunk_audio, prefix_text
|
||||||
|
from .prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES, build_prompt
|
||||||
|
|
||||||
|
DEFAULT_MODEL = "mlx-community/granite-speech-4.1-2b-plus-mlx"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GraniteSpeechPlusPipeline:
|
||||||
|
model: Any
|
||||||
|
repo_id: str = DEFAULT_MODEL
|
||||||
|
chunk_seconds: float = 300.0
|
||||||
|
overlap_seconds: float = 2.0
|
||||||
|
repetition_penalty: float = 1.2
|
||||||
|
max_tokens: int = 4096
|
||||||
|
system_prompt: str | None = GRANITE_SYSTEM_PROMPT
|
||||||
|
verbose: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(
|
||||||
|
cls,
|
||||||
|
repo_id: str = DEFAULT_MODEL,
|
||||||
|
*,
|
||||||
|
chunk_seconds: float = 300.0,
|
||||||
|
overlap_seconds: float = 2.0,
|
||||||
|
repetition_penalty: float = 1.2,
|
||||||
|
max_tokens: int = 4096,
|
||||||
|
system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
|
||||||
|
verbose: bool = False,
|
||||||
|
**load_kwargs: Any,
|
||||||
|
) -> "GraniteSpeechPlusPipeline":
|
||||||
|
from ._vendored.loader import load_model
|
||||||
|
|
||||||
|
model = load_model(repo_id, **load_kwargs)
|
||||||
|
return cls(
|
||||||
|
model=model,
|
||||||
|
repo_id=repo_id,
|
||||||
|
chunk_seconds=chunk_seconds,
|
||||||
|
overlap_seconds=overlap_seconds,
|
||||||
|
repetition_penalty=repetition_penalty,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
def transcribe(self, audio_path: str | Path, prompt_mode: str = "asr") -> str:
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if prompt_mode not in PROMPT_MODES:
|
||||||
|
modes = ", ".join(sorted(PROMPT_MODES))
|
||||||
|
raise ValueError(f"prompt_mode must be one of: {modes}")
|
||||||
|
|
||||||
|
audio_file = Path(audio_path).expanduser().resolve()
|
||||||
|
audio, sr = librosa.load(str(audio_file), sr=16000, mono=True)
|
||||||
|
audio = np.asarray(audio, dtype=np.float32)
|
||||||
|
duration = len(audio) / sr if sr else 0.0
|
||||||
|
chunks = list(
|
||||||
|
chunk_audio(
|
||||||
|
audio,
|
||||||
|
sr,
|
||||||
|
self.chunk_seconds,
|
||||||
|
overlap_seconds=self.overlap_seconds,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print(
|
||||||
|
f"Loaded {audio_file} ({duration:.1f}s, {len(chunks)} chunks)",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
rendered: list[str] = []
|
||||||
|
plain_texts: list[str] = []
|
||||||
|
t_start = time.time()
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
prompt = build_prompt(
|
||||||
|
prompt_mode,
|
||||||
|
prefix_text=prefix_text(plain_texts),
|
||||||
|
)
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
}
|
||||||
|
if self.system_prompt:
|
||||||
|
kwargs["system_prompt"] = self.system_prompt
|
||||||
|
if self.repetition_penalty and self.repetition_penalty > 1.0:
|
||||||
|
kwargs["repetition_penalty"] = self.repetition_penalty
|
||||||
|
|
||||||
|
t0 = time.time()
|
||||||
|
result = self.model.generate(chunk.samples, **kwargs)
|
||||||
|
text = getattr(result, "text", result)
|
||||||
|
if not isinstance(text, str):
|
||||||
|
text = str(text)
|
||||||
|
text = text.strip()
|
||||||
|
plain_texts.append(text)
|
||||||
|
|
||||||
|
if len(chunks) > 1:
|
||||||
|
rendered.append(f"## [{chunk.start:.1f}s - {chunk.end:.1f}s]\n{text}")
|
||||||
|
else:
|
||||||
|
rendered.append(text)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
elapsed = time.time() - t0
|
||||||
|
rtf = (chunk.end - chunk.start) / elapsed if elapsed > 0 else 0.0
|
||||||
|
print(
|
||||||
|
f"[{chunk.index:>3}/{len(chunks)}] "
|
||||||
|
f"{chunk.start:>6.1f}s-{chunk.end:<6.1f}s "
|
||||||
|
f"gen={elapsed:>5.1f}s rtf={rtf:>4.1f}x {text[:80]}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
elapsed = time.time() - t_start
|
||||||
|
rtf = duration / elapsed if elapsed > 0 else 0.0
|
||||||
|
print(f"Total: {elapsed:.1f}s, rtf={rtf:.1f}x", file=sys.stderr)
|
||||||
|
|
||||||
|
return "\n\n".join(rendered).strip()
|
||||||
65
src/granite_speech_plus_mlx/prompts.py
Normal file
65
src/granite_speech_plus_mlx/prompts.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
GRANITE_SYSTEM_PROMPT = (
|
||||||
|
"Knowledge Cutoff Date: April 2024.\n"
|
||||||
|
"Today's Date: December 19, 2024.\n"
|
||||||
|
"You are Granite, developed by IBM. You are a helpful AI assistant"
|
||||||
|
)
|
||||||
|
|
||||||
|
PROMPT_MODES = {
|
||||||
|
"asr": "can you transcribe the speech into a written format?",
|
||||||
|
"saa": (
|
||||||
|
"Speaker attribution: Transcribe and denote who is speaking by adding "
|
||||||
|
"[Speaker 1]: and [Speaker 2]: tags before speaker turns."
|
||||||
|
),
|
||||||
|
"ts": (
|
||||||
|
"Timestamps: Transcribe the speech. After each word, add a timestamp tag "
|
||||||
|
"showing the end time in centiseconds, e.g. hello [T:45] world [T:82]"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def granite3_chat_template(
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
|
||||||
|
add_generation_prompt: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""Build the Granite-3 chat template used when tokenizers omit one."""
|
||||||
|
start_role = "<|start_of_role|>"
|
||||||
|
end_role = "<|end_of_role|>"
|
||||||
|
end_text = "<|end_of_text|>"
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
if system_prompt:
|
||||||
|
parts.append(f"{start_role}system{end_role}{system_prompt}{end_text}\n")
|
||||||
|
parts.append(f"{start_role}user{end_role}{content}{end_text}\n")
|
||||||
|
if add_generation_prompt:
|
||||||
|
parts.append(f"{start_role}assistant{end_role}")
|
||||||
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def build_prompt(
|
||||||
|
prompt_mode: str = "asr",
|
||||||
|
*,
|
||||||
|
prefix_text: str | None = None,
|
||||||
|
custom_prompt: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
if custom_prompt:
|
||||||
|
base = custom_prompt
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
base = PROMPT_MODES[prompt_mode]
|
||||||
|
except KeyError as exc:
|
||||||
|
modes = ", ".join(sorted(PROMPT_MODES))
|
||||||
|
raise ValueError(f"prompt_mode must be one of: {modes}") from exc
|
||||||
|
|
||||||
|
if not prefix_text:
|
||||||
|
return base
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"{base}\n\n"
|
||||||
|
"Previous transcript context for continuity only. Do not repeat it:\n"
|
||||||
|
f"{prefix_text.strip()}"
|
||||||
|
)
|
||||||
|
|
||||||
6
tests/test_smoke.py
Normal file
6
tests/test_smoke.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_symbol_exists():
|
||||||
|
assert GraniteSpeechPlusPipeline is not None
|
||||||
|
|
||||||
Reference in New Issue
Block a user