Initial Granite Speech Plus MLX package

This commit is contained in:
transcrilive
2026-05-09 20:00:57 +02:00
commit c6a20cb79f
21 changed files with 2002 additions and 0 deletions

30
.gitignore vendored Normal file
View File

@@ -0,0 +1,30 @@
__pycache__/
*.py[cod]
*$py.class
.Python
.venv/
venv/
ENV/
env/
build/
dist/
*.egg-info/
.eggs/
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
htmlcov/
.DS_Store
.env
.env.*
*.log
*.tmp
transcripts/
bench/

22
LICENSE Normal file
View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2026 Olivier Dupont
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

42
README.md Normal file
View File

@@ -0,0 +1,42 @@
# granite-speech-4.1-2b-plus-mlx
Standalone Python package for the MLX port of IBM Granite Speech 4.1-2b-plus.
The default model is
[`mlx-community/granite-speech-4.1-2b-plus-mlx`](https://huggingface.co/mlx-community/granite-speech-4.1-2b-plus-mlx).
## Quickstart
```bash
uv add "granite-speech-4.1-2b-plus-mlx @ git+https://gitea.tavportal.com/olivier/granite-speech-4.1-2b-plus-mlx.git"
python -c "from granite_speech_plus_mlx import GraniteSpeechPlusPipeline as P; p=P.from_pretrained(); print(p.transcribe('audio.wav'))"
python scripts/transcribe.py audio.wav --prompt-mode asr --output transcript.txt
python scripts/transcribe.py meeting.wav --prompt-mode saa
python scripts/benchmark.py audio.wav --results bench
```
## Prompt Modes
- `asr`: standard transcription.
- `saa`: speaker-attributed ASR with `[Speaker N]:` turn labels.
- `ts`: word-level timestamp tags like `word [T:45]`.
See [docs/prompt-modes.md](docs/prompt-modes.md) for examples.
## Benchmark Hints
Granite Speech 4.1 allocates substantial encoder memory for long audio. Start
with `--chunk-seconds 300 --repetition-penalty 1.2` for ASR and reduce chunks
to 60 or 180 seconds if memory is tight. Timestamp mode (`ts`) often needs a
larger `--max-tokens` budget because every word carries a timestamp tag.
## Provenance
This package was extracted from the local `MLX_CONVERTOR` project, including
the Granite Speech patch bundle at
`external/patches/granite-speech-idempotent-sanitize.patch`. The vendored
Granite implementation is based on `mlx-audio` commit
`f7c11556eda88731be5cc75ddbdf4a4cb9eeaafc` plus that local patch.
Package code is MIT licensed. Model weights remain under the IBM Granite model
license; review the model card and license terms before redistribution or use.

37
docs/prompt-modes.md Normal file
View File

@@ -0,0 +1,37 @@
# Prompt Modes
Granite Speech Plus supports three prompt modes in this package.
## `asr`
Standard speech transcription.
```python
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
pipe = GraniteSpeechPlusPipeline.from_pretrained()
text = pipe.transcribe("audio.wav", prompt_mode="asr")
```
## `saa`
Speaker-attributed ASR. The prompt asks the model to add speaker turn labels
such as `[Speaker 1]:` and `[Speaker 2]:`.
```python
text = pipe.transcribe("meeting.wav", prompt_mode="saa")
```
## `ts`
Word-level timestamps. The prompt asks the model to append centisecond tags
after words, for example `hello [T:45] world [T:82]`.
```python
text = pipe.transcribe("clip.wav", prompt_mode="ts")
```
For long audio, the pipeline chunks the waveform and feeds a short previous
transcript prefix into later chunks for continuity. The prefix is context only;
the model is instructed not to repeat it.

27
pyproject.toml Normal file
View File

@@ -0,0 +1,27 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "granite-speech-4.1-2b-plus-mlx"
version = "0.1.0"
description = "Standalone MLX pipeline for the Granite Speech 4.1-2b-plus port."
readme = "README.md"
requires-python = ">=3.10"
license = "MIT"
authors = [
{ name = "Olivier Dupont" }
]
dependencies = [
"mlx>=0.22.0",
"mlx-lm>=0.19.0",
"numpy>=1.26",
"transformers>=4.45",
"huggingface-hub>=0.24",
"soundfile>=0.12",
"librosa>=0.10",
]
[tool.hatch.build.targets.wheel]
packages = ["src/granite_speech_plus_mlx"]

111
scripts/benchmark.py Executable file
View File

@@ -0,0 +1,111 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import sys
import time
from collections import Counter
from pathlib import Path
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
from granite_speech_plus_mlx.prompts import PROMPT_MODES
GRID = [
(60, 1.0),
(60, 1.2),
(180, 1.0),
(180, 1.2),
(300, 1.0),
(300, 1.2),
(300, 1.4),
]
HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde")
def analyze(text: str) -> dict:
words = text.split()
lower_words = text.lower().split()
trigrams = Counter(
" ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2)
)
top = trigrams.most_common(5)
lower = text.lower()
return {
"n_words": len(words),
"max_trigram_count": top[0][1] if top else 0,
"max_trigram_text": top[0][0] if top else "",
"halluc": {m: lower.count(m) for m in HALLUCINATION_MARKERS},
}
def main() -> int:
parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.")
parser.add_argument("audio")
parser.add_argument("--model", default=DEFAULT_MODEL)
parser.add_argument("--results", default="bench")
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
parser.add_argument("--overlap-seconds", type=float, default=2.0)
parser.add_argument("--max-tokens", type=int, default=4096)
args = parser.parse_args()
results_dir = Path(args.results)
results_dir.mkdir(parents=True, exist_ok=True)
pipe = GraniteSpeechPlusPipeline.from_pretrained(
args.model,
overlap_seconds=args.overlap_seconds,
max_tokens=args.max_tokens,
verbose=True,
)
rows = []
for chunk_seconds, repetition_penalty in GRID:
out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt"
pipe.chunk_seconds = float(chunk_seconds)
pipe.repetition_penalty = repetition_penalty
if out.exists():
print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr)
elapsed = float("nan")
text = out.read_text(encoding="utf-8")
else:
print(
f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}",
file=sys.stderr,
)
t0 = time.time()
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
elapsed = time.time() - t0
out.write_text(text + "\n", encoding="utf-8")
rows.append(
{
"chunk": chunk_seconds,
"rp": repetition_penalty,
"elapsed": elapsed,
**analyze(text),
}
)
print()
print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |")
print("|---:|---:|---:|---:|:---|:---|")
for row in rows:
halluc = ", ".join(
f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value
) or "-"
trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)"
wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}"
print(
f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} "
f"| {trigram} | {halluc} |"
)
print()
print(f"Per-config transcripts in: {results_dir}")
return 0
if __name__ == "__main__":
sys.exit(main())

47
scripts/transcribe.py Executable file
View File

@@ -0,0 +1,47 @@
#!/usr/bin/env python
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
from granite_speech_plus_mlx.prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES
def main() -> int:
parser = argparse.ArgumentParser(description="Transcribe audio with Granite Speech Plus MLX.")
parser.add_argument("audio")
parser.add_argument("--model", default=DEFAULT_MODEL)
parser.add_argument("--output", default=None)
parser.add_argument("--chunk-seconds", type=float, default=300.0)
parser.add_argument("--overlap-seconds", type=float, default=2.0)
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
parser.add_argument("--repetition-penalty", type=float, default=1.2)
parser.add_argument("--max-tokens", type=int, default=4096)
parser.add_argument("--system-prompt", default=GRANITE_SYSTEM_PROMPT)
parser.add_argument("--verbose", action="store_true")
args = parser.parse_args()
pipe = GraniteSpeechPlusPipeline.from_pretrained(
args.model,
chunk_seconds=args.chunk_seconds,
overlap_seconds=args.overlap_seconds,
repetition_penalty=args.repetition_penalty,
max_tokens=args.max_tokens,
system_prompt=args.system_prompt or None,
verbose=args.verbose,
)
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
if args.output:
Path(args.output).write_text(text + "\n", encoding="utf-8")
else:
print(text)
return 0
if __name__ == "__main__":
sys.exit(main())

66
scripts/upload_to_hf.py Executable file
View File

@@ -0,0 +1,66 @@
#!/usr/bin/env python
from __future__ import annotations
import os
import sys
from pathlib import Path
from huggingface_hub import HfApi
SOURCE_CACHE = (
Path.home()
/ ".cache/huggingface/hub/models--ibm-granite--granite-speech-4.1-2b-plus"
)
DEST_REPO = "mlx-community/granite-speech-4.1-2b-plus-mlx"
def find_weights_dir(root: Path) -> Path | None:
if not root.exists():
return None
if list(root.glob("*.safetensors")) or (root / "config.json").exists():
return root
snapshots = root / "snapshots"
if snapshots.exists():
candidates = [
path
for path in snapshots.iterdir()
if path.is_dir() and (list(path.glob("*.safetensors")) or (path / "config.json").exists())
]
if candidates:
return sorted(candidates, key=lambda p: p.stat().st_mtime)[-1]
return None
def print_manual_commands() -> None:
print(f"MLX weights not found at {SOURCE_CACHE}")
print("Create them first with:")
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus")
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus --dtype q4_k_4")
def main() -> int:
weights_dir = find_weights_dir(SOURCE_CACHE)
if weights_dir is None:
print_manual_commands()
return 1
token = os.environ.get("HF_TOKEN")
if not token:
print("HF_TOKEN is required to upload.", file=sys.stderr)
return 2
api = HfApi(token=token)
api.create_repo(DEST_REPO, repo_type="model", exist_ok=True)
api.upload_folder(
repo_id=DEST_REPO,
repo_type="model",
folder_path=str(weights_dir),
commit_message="Upload Granite Speech 4.1-2b-plus MLX weights",
)
print(f"Uploaded {weights_dir} to {DEST_REPO}")
return 0
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,4 @@
from .pipeline import GraniteSpeechPlusPipeline
__all__ = ["GraniteSpeechPlusPipeline"]

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,15 @@
from __future__ import annotations
from pathlib import Path
import librosa
SAMPLE_RATE = 16000
def load_audio(file: str | Path, sr: int = SAMPLE_RATE):
import mlx.core as mx
audio, _ = librosa.load(str(file), sr=sr, mono=True)
return mx.array(audio, dtype=mx.float32)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
@dataclass
class STTOutput:
text: str
segments: List[dict] | None = None
language: str | None = None
prompt_tokens: int = 0
generation_tokens: int = 0
total_tokens: int = 0
prompt_tps: float = 0.0
generation_tps: float = 0.0
total_time: float = 0.0

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
import math
from functools import lru_cache
from typing import Optional
import mlx.core as mx
@lru_cache(maxsize=None)
def hanning(size: int, periodic: bool = False):
denom = size if periodic else size - 1
return mx.array(
[0.5 * (1 - math.cos(2 * math.pi * n / denom)) for n in range(size)]
)
@lru_cache(maxsize=None)
def hamming(size: int, periodic: bool = False):
denom = size if periodic else size - 1
return mx.array(
[0.54 - 0.46 * math.cos(2 * math.pi * n / denom) for n in range(size)]
)
@lru_cache(maxsize=None)
def blackman(size: int, periodic: bool = False):
denom = size if periodic else size - 1
return mx.array(
[
0.42
- 0.5 * math.cos(2 * math.pi * n / denom)
+ 0.08 * math.cos(4 * math.pi * n / denom)
for n in range(size)
]
)
@lru_cache(maxsize=None)
def bartlett(size: int, periodic: bool = False):
denom = size if periodic else size - 1
return mx.array([1 - 2 * abs(n - denom / 2) / denom for n in range(size)])
STR_TO_WINDOW_FN = {
"hann": hanning,
"hanning": hanning,
"hamming": hamming,
"blackman": blackman,
"bartlett": bartlett,
}
def stft(
x,
n_fft: int = 800,
hop_length: int | None = None,
win_length: int | None = None,
window: mx.array | str = "hann",
center: bool = True,
pad_mode: str = "reflect",
):
if hop_length is None:
hop_length = n_fft // 4
if win_length is None:
win_length = n_fft
if isinstance(window, str):
window_fn = STR_TO_WINDOW_FN.get(window.lower())
if window_fn is None:
raise ValueError(f"Unknown window function: {window}")
w = window_fn(win_length)
else:
w = window
if w.shape[0] < n_fft:
pad_size = n_fft - w.shape[0]
w = mx.concatenate([w, mx.zeros((pad_size,))], axis=0)
def _pad(signal, padding: int, mode: str = "reflect"):
if mode == "constant":
return mx.pad(signal, [(padding, padding)])
if mode == "reflect":
prefix = signal[1 : padding + 1][::-1]
suffix = signal[-(padding + 1) : -1][::-1]
return mx.concatenate([prefix, signal, suffix])
raise ValueError(f"Invalid pad_mode {mode}")
if center:
x = _pad(x, n_fft // 2, pad_mode)
num_frames = 1 + (x.shape[0] - n_fft) // hop_length
if num_frames <= 0:
raise ValueError(
f"Input is too short for n_fft={n_fft}, hop_length={hop_length}, "
f"center={center}."
)
frames = mx.as_strided(x, shape=(num_frames, n_fft), strides=(hop_length, 1))
return mx.fft.rfft(frames * w)
@lru_cache(maxsize=None)
def mel_filters(
sample_rate: int,
n_fft: int,
n_mels: int,
f_min: float = 0,
f_max: Optional[float] = None,
norm: Optional[str] = None,
mel_scale: str = "htk",
) -> mx.array:
def hz_to_mel(freq, scale="htk"):
if scale == "htk":
return 2595.0 * math.log10(1.0 + freq / 700.0)
f_sp = 200.0 / 3
mels = freq / f_sp
min_log_hz = 1000.0
min_log_mel = min_log_hz / f_sp
logstep = math.log(6.4) / 27.0
if freq >= min_log_hz:
mels = min_log_mel + math.log(freq / min_log_hz) / logstep
return mels
def mel_to_hz(mels, scale="htk"):
if scale == "htk":
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
f_sp = 200.0 / 3
freqs = f_sp * mels
min_log_hz = 1000.0
min_log_mel = min_log_hz / f_sp
logstep = math.log(6.4) / 27.0
return mx.where(
mels >= min_log_mel,
min_log_hz * mx.exp(logstep * (mels - min_log_mel)),
freqs,
)
f_max = f_max or sample_rate / 2
n_freqs = n_fft // 2 + 1
all_freqs = mx.linspace(0, sample_rate // 2, n_freqs)
m_min = hz_to_mel(f_min, mel_scale)
m_max = hz_to_mel(f_max, mel_scale)
m_pts = mx.linspace(m_min, m_max, n_mels + 2)
f_pts = mel_to_hz(m_pts, mel_scale)
f_diff = f_pts[1:] - f_pts[:-1]
slopes = mx.expand_dims(f_pts, 0) - mx.expand_dims(all_freqs, 1)
down_slopes = (-slopes[:, :-2]) / f_diff[:-1]
up_slopes = slopes[:, 2:] / f_diff[1:]
filterbank = mx.maximum(
mx.zeros_like(down_slopes), mx.minimum(down_slopes, up_slopes)
)
if norm == "slaney":
enorm = 2.0 / (f_pts[2 : n_mels + 2] - f_pts[:n_mels])
filterbank *= mx.expand_dims(enorm, 0)
return filterbank.moveaxis(0, 1)

View File

@@ -0,0 +1,36 @@
from __future__ import annotations
from dataclasses import dataclass
from .config import EncoderConfig, ModelConfig, ProjectorConfig, TextConfig
from .granite_speech import Model
@dataclass
class GraniteSpeechPlusModelConfig(ModelConfig):
model_type: str = "granite_speech_plus"
DETECTION_HINTS = {
"config_keys": {"encoder_config", "projector_config", "audio_token_index"},
"architectures": {
"GraniteSpeechForConditionalGeneration",
"GraniteSpeechPlusForConditionalGeneration",
},
"path_patterns": {
"granite_speech_plus",
"granitespeechplus",
"granite-speech-4.1-2b-plus",
},
}
__all__ = [
"EncoderConfig",
"ProjectorConfig",
"TextConfig",
"ModelConfig",
"GraniteSpeechPlusModelConfig",
"Model",
"DETECTION_HINTS",
]

View File

@@ -0,0 +1,128 @@
import inspect
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@dataclass
class EncoderConfig:
input_dim: int = 160
num_layers: int = 10
hidden_dim: int = 1024
feedforward_mult: int = 4
num_heads: int = 8
dim_head: int = 128
output_dim: int = 42
context_size: int = 200
max_pos_emb: int = 512
dropout: float = 0.1
conv_kernel_size: int = 15
conv_expansion_factor: int = 2
# Plus variant: indices of intermediate encoder layers whose hidden state
# gets concatenated with the final-layer hidden state along the channel
# axis, before being fed to the projector. None / empty = base behavior.
cat_hidden_layers: Optional[List[int]] = None
model_type: str = "granite_speech_encoder"
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
@dataclass
class ProjectorConfig:
hidden_size: int = 1024
num_hidden_layers: int = 2
num_attention_heads: int = 16
intermediate_size: int = 4096
hidden_act: str = "gelu"
layer_norm_eps: float = 1e-12
encoder_hidden_size: int = 1024
cross_attention_frequency: int = 1
model_type: str = "blip_2_qformer"
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
@dataclass
class TextConfig:
model_type: str = "granite"
vocab_size: int = 100353
hidden_size: int = 2048
intermediate_size: int = 4096
num_hidden_layers: int = 40
num_attention_heads: int = 16
num_key_value_heads: int = 4
hidden_act: str = "silu"
max_position_embeddings: int = 4096
rms_norm_eps: float = 1e-5
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
attention_bias: bool = False
mlp_bias: bool = False
attention_multiplier: float = 0.0078125
embedding_multiplier: float = 12.0
residual_multiplier: float = 0.22
logits_scaling: float = 8.0
tie_word_embeddings: bool = False
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)
@dataclass
class ModelConfig:
model_type: str = "granite_speech"
encoder_config: EncoderConfig = None
projector_config: ProjectorConfig = None
text_config: TextConfig = None
audio_token_index: int = 100352
downsample_rate: int = 5
window_size: int = 15
has_lora_adapter: bool = False
def __post_init__(self):
if isinstance(self.encoder_config, dict):
self.encoder_config = EncoderConfig.from_dict(self.encoder_config)
elif self.encoder_config is None:
self.encoder_config = EncoderConfig()
if isinstance(self.projector_config, dict):
self.projector_config = ProjectorConfig.from_dict(self.projector_config)
elif self.projector_config is None:
self.projector_config = ProjectorConfig()
if isinstance(self.text_config, dict):
self.text_config = TextConfig.from_dict(self.text_config)
elif self.text_config is None:
self.text_config = TextConfig()
@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)

View File

@@ -0,0 +1,850 @@
import math
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import KVCache
from mlx_lm.models.granite import Model as GraniteLM
from mlx_lm.models.granite import ModelArgs as GraniteModelArgs
from ..base import STTOutput
from .config import EncoderConfig, ModelConfig, ProjectorConfig
LANGUAGE_CODES = {
"en": "English",
"fr": "French",
"de": "German",
"es": "Spanish",
"pt": "Portuguese",
"ja": "Japanese",
}
@dataclass
class StreamingResult:
text: str
is_final: bool
start_time: float
end_time: float
language: str = "en"
prompt_tokens: int = 0
generation_tokens: int = 0
class BatchNorm1d(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((num_features,))
self.bias = mx.zeros((num_features,))
self.running_mean = mx.zeros((num_features,))
self.running_var = mx.ones((num_features,))
self.eps = eps
def __call__(self, x: mx.array) -> mx.array:
return (x - self.running_mean) / mx.sqrt(
self.running_var + self.eps
) * self.weight + self.bias
class ConformerFeedForward(nn.Module):
def __init__(self, config: EncoderConfig):
super().__init__()
self.pre_norm = nn.LayerNorm(config.hidden_dim)
self.up_proj = nn.Linear(
config.hidden_dim, config.hidden_dim * config.feedforward_mult
)
self.down_proj = nn.Linear(
config.hidden_dim * config.feedforward_mult, config.hidden_dim
)
def __call__(self, x: mx.array) -> mx.array:
x = self.pre_norm(x)
x = nn.silu(self.up_proj(x))
x = self.down_proj(x)
return x
class ConformerAttention(nn.Module):
def __init__(self, config: EncoderConfig):
super().__init__()
inner_dim = config.dim_head * config.num_heads
self.max_pos_emb = config.max_pos_emb
self.context_size = config.context_size
self.num_heads = config.num_heads
self.dim_head = config.dim_head
self.scale = config.dim_head**-0.5
self.pre_norm = nn.LayerNorm(config.hidden_dim)
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
x = self.pre_norm(x)
B, N, _ = x.shape
num_blocks = math.ceil(N / self.context_size)
remainder = N % self.context_size
if remainder > 0:
pad_len = self.context_size - remainder
x = mx.pad(x, [(0, 0), (0, pad_len), (0, 0)])
q = self.to_q(x)
kv = self.to_kv(x)
k, v = mx.split(kv, 2, axis=-1)
q = q.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
k = k.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
v = v.reshape(B, num_blocks, self.context_size, self.num_heads, -1)
q = q.transpose(0, 1, 3, 2, 4)
k = k.transpose(0, 1, 3, 2, 4)
v = v.transpose(0, 1, 3, 2, 4)
rel_pos_emb = self.rel_pos_emb(attention_dists)
C = self.context_size
pos_attn = (
mx.sum(
q[:, :, :, :, None, :] * rel_pos_emb[None, None, None, :, :, :],
axis=-1,
)
* self.scale
)
if remainder > 0:
row_valid = mx.arange(C)[:, None] < remainder
col_valid = mx.arange(C)[None, :] < remainder
mask = ~(row_valid & col_valid)
mask_value = mx.array(mx.finfo(pos_attn.dtype).min)
pos_attn_last = mx.where(
mask[None, None, None], mask_value, pos_attn[:, -1:, :, :, :]
)
pos_attn = mx.concatenate(
[pos_attn[:, :-1, :, :, :], pos_attn_last], axis=1
)
attn_weights = (q @ k.transpose(0, 1, 2, 4, 3)) * self.scale + pos_attn
attn_weights = mx.softmax(attn_weights, axis=-1)
out = attn_weights @ v
out = out.transpose(0, 1, 3, 2, 4)
out = out.reshape(B, -1, self.num_heads * self.dim_head)
out = out[:, :N, :]
out = self.to_out(out)
return out
class DepthWiseConv1d(nn.Module):
def __init__(self, chan_in: int, chan_out: int, kernel_size: int):
super().__init__()
pad = kernel_size // 2
pad_offset = (kernel_size + 1) % 2
self.padding = (pad, pad - pad_offset)
self.conv = nn.Conv1d(
chan_in, chan_out, kernel_size, groups=chan_in, bias=False
)
def __call__(self, x: mx.array) -> mx.array:
x = mx.pad(x, [(0, 0), (self.padding[0], self.padding[1]), (0, 0)])
return self.conv(x)
class ConformerConvModule(nn.Module):
def __init__(self, config: EncoderConfig):
super().__init__()
inner_dim = config.hidden_dim * config.conv_expansion_factor
self.norm = nn.LayerNorm(config.hidden_dim)
self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
self.depth_conv = DepthWiseConv1d(inner_dim, inner_dim, config.conv_kernel_size)
self.batch_norm = BatchNorm1d(inner_dim)
self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)
def __call__(self, x: mx.array) -> mx.array:
x = self.norm(x)
x = self.up_conv(x)
x1, x2 = mx.split(x, 2, axis=-1)
x = x1 * mx.sigmoid(x2)
x = self.depth_conv(x)
x = nn.silu(self.batch_norm(x))
x = self.down_conv(x)
return x
class ConformerBlock(nn.Module):
def __init__(self, config: EncoderConfig):
super().__init__()
self.ff1 = ConformerFeedForward(config)
self.attn = ConformerAttention(config)
self.conv = ConformerConvModule(config)
self.ff2 = ConformerFeedForward(config)
self.post_norm = nn.LayerNorm(config.hidden_dim)
def __call__(self, x: mx.array, attention_dists: mx.array) -> mx.array:
x = 0.5 * self.ff1(x) + x
x = self.attn(x, attention_dists) + x
x = self.conv(x) + x
x = 0.5 * self.ff2(x) + x
x = self.post_norm(x)
return x
class CTCEncoder(nn.Module):
def __init__(self, config: EncoderConfig):
super().__init__()
self.config = config
self.input_linear = nn.Linear(config.input_dim, config.hidden_dim)
self.layers = [ConformerBlock(config) for _ in range(config.num_layers)]
self.out = nn.Linear(config.hidden_dim, config.output_dim)
self.out_mid = nn.Linear(config.output_dim, config.hidden_dim)
self.num_layers = config.num_layers
self._attention_dists = None
seq = mx.arange(config.context_size)
relpos_dist = seq[:, None] - seq[None, :]
self._attention_dists = (
mx.clip(relpos_dist, -config.context_size, config.context_size)
+ config.max_pos_emb
)
def __call__(self, x: mx.array) -> mx.array:
x = self.input_linear(x)
cat_layers = set(self.config.cat_hidden_layers or [])
exported_hidden_states = []
if 0 in cat_layers:
exported_hidden_states.append(x)
for idx, layer in enumerate(self.layers, start=1):
x = layer(x, attention_dists=self._attention_dists)
if idx in cat_layers:
exported_hidden_states.append(x)
if idx == self.num_layers // 2:
x_mid = self.out(x)
x = x + self.out_mid(mx.softmax(x_mid, axis=-1))
if exported_hidden_states:
# Plus variant: prepend captured intermediate hidden states to the
# final-layer output along the channel axis. Order matches the
# upstream Transformers implementation: intermediates first, then
# final.
x = mx.concatenate([*exported_hidden_states, x], axis=-1)
return x
class QFormerMultiHeadAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, kv_hidden_size: int = None):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
kv_dim = kv_hidden_size or hidden_size
self.query = nn.Linear(hidden_size, hidden_size)
self.key = nn.Linear(kv_dim, hidden_size)
self.value = nn.Linear(kv_dim, hidden_size)
def __call__(
self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
) -> mx.array:
B, L, _ = hidden_states.shape
q = self.query(hidden_states)
kv_input = (
encoder_hidden_states
if encoder_hidden_states is not None
else hidden_states
)
k = self.key(kv_input)
v = self.value(kv_input)
q = q.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
k = k.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
v = v.reshape(B, -1, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
scale = self.head_dim**-0.5
attn = (q * scale) @ k.transpose(0, 1, 3, 2)
attn = mx.softmax(attn, axis=-1)
out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
return out
class QFormerSelfOutput(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-12):
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class QFormerAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
kv_hidden_size: int = None,
eps: float = 1e-12,
):
super().__init__()
self.attention = QFormerMultiHeadAttention(
hidden_size, num_heads, kv_hidden_size
)
self.output = QFormerSelfOutput(hidden_size, eps)
def __call__(
self, hidden_states: mx.array, encoder_hidden_states: mx.array = None
) -> mx.array:
attn_out = self.attention(hidden_states, encoder_hidden_states)
return self.output(attn_out, hidden_states)
class QFormerIntermediate(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
def __call__(self, x: mx.array) -> mx.array:
return nn.gelu(self.dense(x))
class QFormerOutput(nn.Module):
def __init__(self, intermediate_size: int, hidden_size: int, eps: float = 1e-12):
super().__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = nn.LayerNorm(hidden_size, eps=eps)
def __call__(self, hidden_states: mx.array, input_tensor: mx.array) -> mx.array:
hidden_states = self.dense(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class QFormerLayer(nn.Module):
def __init__(self, config: ProjectorConfig):
super().__init__()
self.attention = QFormerAttention(
config.hidden_size, config.num_attention_heads, eps=config.layer_norm_eps
)
self.crossattention = QFormerAttention(
config.hidden_size,
config.num_attention_heads,
kv_hidden_size=config.encoder_hidden_size,
eps=config.layer_norm_eps,
)
self.intermediate_query = QFormerIntermediate(
config.hidden_size, config.intermediate_size
)
self.output_query = QFormerOutput(
config.intermediate_size, config.hidden_size, eps=config.layer_norm_eps
)
def __call__(
self, hidden_states: mx.array, encoder_hidden_states: mx.array
) -> mx.array:
hidden_states = self.attention(hidden_states)
hidden_states = self.crossattention(hidden_states, encoder_hidden_states)
intermediate = self.intermediate_query(hidden_states)
hidden_states = self.output_query(intermediate, hidden_states)
return hidden_states
class QFormerEncoder(nn.Module):
def __init__(self, config: ProjectorConfig):
super().__init__()
self.layer = [QFormerLayer(config) for _ in range(config.num_hidden_layers)]
def __call__(
self, hidden_states: mx.array, encoder_hidden_states: mx.array
) -> mx.array:
for layer in self.layer:
hidden_states = layer(hidden_states, encoder_hidden_states)
return hidden_states
class QFormerModel(nn.Module):
def __init__(self, config: ProjectorConfig):
super().__init__()
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.encoder = QFormerEncoder(config)
def __call__(
self, query_embeds: mx.array, encoder_hidden_states: mx.array
) -> mx.array:
hidden_states = self.layernorm(query_embeds)
return self.encoder(hidden_states, encoder_hidden_states)
class EncoderProjector(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.hidden_size = config.projector_config.hidden_size
self.downsample_rate = config.downsample_rate
self.window_size = config.window_size
self.num_queries = config.window_size // config.downsample_rate
self.query = mx.zeros(
(1, self.num_queries, config.projector_config.hidden_size)
)
self.qformer = QFormerModel(config.projector_config)
self.linear = nn.Linear(
config.projector_config.hidden_size, config.text_config.hidden_size
)
def __call__(self, hidden_states: mx.array) -> mx.array:
B, L, D = hidden_states.shape
nblocks = math.ceil(L / self.window_size)
pad = nblocks * self.window_size - L
if pad > 0:
hidden_states = mx.pad(hidden_states, [(0, 0), (0, pad), (0, 0)])
hidden_states = hidden_states.reshape(B * nblocks, self.window_size, D)
query = mx.broadcast_to(
self.query, (B * nblocks, self.num_queries, self.hidden_size)
)
query_output = self.qformer(query, hidden_states)
query_proj = self.linear(
query_output.reshape(B, nblocks * self.num_queries, -1)
)
return query_proj
class Model(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
# Plus variant invariant: encoder concats len(cat_hidden_layers)+1
# hidden states channel-wise, so projector must accept that wider input.
cat_layers = config.encoder_config.cat_hidden_layers or []
expected_proj_in = config.encoder_config.hidden_dim * (len(cat_layers) + 1)
if config.projector_config.encoder_hidden_size != expected_proj_in:
raise ValueError(
f"projector_config.encoder_hidden_size ({config.projector_config.encoder_hidden_size}) "
f"must equal encoder_config.hidden_dim * (len(cat_hidden_layers) + 1) "
f"({config.encoder_config.hidden_dim} * {len(cat_layers) + 1} = {expected_proj_in})"
)
self.encoder = CTCEncoder(config.encoder_config)
self.projector = EncoderProjector(config)
text_args = GraniteModelArgs.from_dict(
config.text_config.__dict__
if hasattr(config.text_config, "__dict__")
else config.text_config
)
self.language_model = GraniteLM(text_args)
self.audio_token_id = config.audio_token_index
self._tokenizer = None
@property
def layers(self):
return self.language_model.model.layers
def make_cache(self) -> List[KVCache]:
return [KVCache() for _ in range(len(self.layers))]
def __call__(
self,
input_ids: mx.array,
cache: Optional[List[KVCache]] = None,
input_embeddings: Optional[mx.array] = None,
) -> mx.array:
if input_embeddings is not None:
h = input_embeddings
else:
h = self.language_model.model.embed_tokens(input_ids)
h = h * self.language_model.model.embedding_multiplier
if cache is None:
cache = [None] * len(self.language_model.model.layers)
mask = create_attention_mask(h, cache[0])
for layer, c in zip(self.language_model.model.layers, cache):
h = layer(h, mask, cache=c)
h = self.language_model.model.norm(h)
if self.language_model.args.tie_word_embeddings:
logits = self.language_model.model.embed_tokens.as_linear(h)
else:
logits = self.language_model.lm_head(h)
return logits / self.language_model.logits_scaling
def get_audio_features(self, input_features: mx.array) -> mx.array:
encoder_output = self.encoder(input_features)
projected = self.projector(encoder_output)
return projected
def model_quant_predicate(self, p: str, m: nn.Module) -> bool:
return not (p.startswith("encoder") or p.startswith("projector"))
def sanitize(self, weights: Dict[str, mx.array]) -> Dict[str, mx.array]:
# Compare incoming weight shapes against the model's already-initialized
# parameter shapes. This is idempotent across convert-time (PyTorch source
# layout) and inference-time load (MLX-native layout) and correct even for
# Conv1d kernel_size=1 layers where prior shape-ordering heuristics failed.
# Pattern adapted from cohere_asr.Model.sanitize.
model_weights = dict(tree_flatten(self.parameters()))
sanitized = {}
for k, v in weights.items():
if "num_batches_tracked" in k:
continue
# granite-speech-4.1 ships a separate out_llm.safetensors with
# top-level "weight"/"bias" keys (likely an audio CTC head). The
# standard Model class does not define this layer, so dropping these
# keys is required for the rest of the load to succeed. Inference
# behaviour that depends on this head is not yet supported.
if k in ("weight", "bias"):
continue
expected = model_weights.get(k)
if expected is not None and hasattr(expected, "shape"):
if v.shape != expected.shape and v.ndim == 3:
transposed = mx.transpose(v, (0, 2, 1))
if transposed.shape == expected.shape:
v = transposed
sanitized[k] = v
return sanitized
@classmethod
def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
import transformers
from transformers import AutoTokenizer
prev = transformers.logging.get_verbosity()
transformers.logging.set_verbosity_error()
try:
model._tokenizer = AutoTokenizer.from_pretrained(
str(model_path), trust_remote_code=True
)
finally:
transformers.logging.set_verbosity(prev)
return model
def _extract_features(
self, audio: Union[mx.array, np.ndarray]
) -> Tuple[mx.array, int]:
from ..dsp import hanning, mel_filters, stft
n_fft = 512
win_length = 400
hop_length = 160
n_mels = 80
sample_rate = 16000
if isinstance(audio, mx.array):
audio_1d = audio.reshape(-1)
else:
audio_1d = mx.array(audio.flatten(), dtype=mx.float32)
win = hanning(win_length, periodic=True)
pad_left = (n_fft - win_length) // 2
pad_right = n_fft - win_length - pad_left
win_padded = mx.concatenate(
[mx.zeros((pad_left,)), win, mx.zeros((pad_right,))]
)
spec = stft(
audio_1d,
n_fft=n_fft,
hop_length=hop_length,
window=win_padded,
center=True,
pad_mode="reflect",
)
power = mx.abs(spec) ** 2
mel_fb = mel_filters(sample_rate, n_fft, n_mels, mel_scale="htk")
mel_spec = power @ mel_fb.T
logmel = mx.log10(mx.clip(mel_spec, 1e-10, None))
mx_val = mx.max(logmel)
logmel = mx.maximum(logmel, mx_val - 8.0) / 4.0 + 1.0
if logmel.shape[0] % 2 == 1:
logmel = logmel[:-1]
encoder_input = logmel.reshape(-1, 2 * n_mels)
encoder_length = encoder_input.shape[0]
nblocks = math.ceil(encoder_length / self.config.window_size)
num_audio_tokens = nblocks * (
self.config.window_size // self.config.downsample_rate
)
input_features = encoder_input[None, :, :]
return input_features, num_audio_tokens
def _build_prompt(
self,
num_audio_tokens: int,
user_prompt: str = None,
system_prompt: str = None,
) -> mx.array:
if user_prompt is None:
user_prompt = "can you transcribe the speech into a written format?"
audio_placeholder = "<|audio|>" * num_audio_tokens
content = f"{audio_placeholder}{user_prompt}"
if getattr(self._tokenizer, "chat_template", None):
chat = []
if system_prompt:
chat.append({"role": "system", "content": system_prompt})
chat.append({"role": "user", "content": content})
prompt_str = self._tokenizer.apply_chat_template(
chat, tokenize=False, add_generation_prompt=True
)
else:
# Granite-3 chat format (granite-speech tokenizer ships without a
# chat_template attribute, but its vocab includes the role tokens).
sor, eor, eot = "<|start_of_role|>", "<|end_of_role|>", "<|end_of_text|>"
parts = []
if system_prompt:
parts.append(f"{sor}system{eor}{system_prompt}{eot}\n")
parts.append(f"{sor}user{eor}{content}{eot}\n")
parts.append(f"{sor}assistant{eor}")
prompt_str = "".join(parts)
prompt_ids = self._tokenizer.encode(prompt_str)
return mx.array(prompt_ids)
def _build_inputs_embeds(
self, input_ids: mx.array, audio_features: mx.array
) -> mx.array:
is_audio = input_ids == self.audio_token_id
llm_ids = mx.where(is_audio, 0, input_ids)
inputs_embeds = self.language_model.model.embed_tokens(llm_ids[None])
is_audio_np = np.array(is_audio)
audio_positions = np.where(is_audio_np)[0]
orig_dtype = inputs_embeds.dtype
embeds_np = np.array(inputs_embeds.astype(mx.float32))
audio_np = np.array(audio_features.astype(mx.float32))
num_audio = min(len(audio_positions), audio_np.shape[1])
embeds_np[0, audio_positions[:num_audio]] = audio_np[0, :num_audio]
return mx.array(embeds_np).astype(orig_dtype)
def generate(
self,
audio: Union[str, mx.array, np.ndarray],
*,
max_tokens: int = 4096,
temperature: float = 0.0,
top_p: float = 1.0,
top_k: int = 0,
min_p: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: int = 100,
prompt: str = None,
system_prompt: str = None,
language: str = None,
prefill_step_size: int = 2048,
verbose: bool = False,
stream: bool = False,
**kwargs,
) -> Union[STTOutput, Generator[StreamingResult, None, None]]:
if prompt is None and language is not None:
lang_name = LANGUAGE_CODES.get(language.lower(), language)
prompt = f"Translate the speech to {lang_name}."
if stream:
return self._stream_generate(
audio,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
prompt=prompt,
prefill_step_size=prefill_step_size,
verbose=verbose,
)
start_time = time.time()
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_logits_processors, make_sampler
audio_data = self._load_audio(audio)
input_features, num_audio_tokens = self._extract_features(audio_data)
if verbose:
print("Encoding audio...")
audio_features = self.get_audio_features(input_features)
mx.eval(audio_features)
prompt_ids = self._build_prompt(num_audio_tokens, prompt, system_prompt=system_prompt)
inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
mx.eval(inputs_embeds)
prompt_tokens = len(prompt_ids)
sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
logits_processors = make_logits_processors(
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
)
eos_token_id = self._tokenizer.eos_token_id
tokens = []
for token, logprobs in generate_step(
prompt=prompt_ids,
input_embeddings=inputs_embeds.squeeze(0),
model=self,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prefill_step_size=prefill_step_size,
):
if token == eos_token_id:
break
tokens.append(token)
text = self._tokenizer.decode(tokens, skip_special_tokens=True)
elapsed = time.time() - start_time
gen_tokens = len(tokens)
if verbose:
print(f"Prompt tokens: {prompt_tokens}")
print(f"Generation tokens: {gen_tokens}")
print(f"Total time: {elapsed:.2f}s")
if gen_tokens > 0:
print(f"Generation TPS: {gen_tokens / elapsed:.1f}")
return STTOutput(
text=text,
segments=[],
prompt_tokens=prompt_tokens,
generation_tokens=gen_tokens,
total_tokens=prompt_tokens + gen_tokens,
total_time=elapsed,
prompt_tps=prompt_tokens / elapsed if elapsed > 0 else 0,
generation_tps=gen_tokens / elapsed if elapsed > 0 else 0,
)
def _stream_generate(
self,
audio: Union[str, mx.array, np.ndarray],
*,
max_tokens: int = 4096,
temperature: float = 0.0,
top_p: float = 1.0,
top_k: int = 0,
min_p: float = 0.0,
repetition_penalty: Optional[float] = None,
repetition_context_size: int = 100,
prompt: str = None,
prefill_step_size: int = 2048,
verbose: bool = False,
) -> Generator[StreamingResult, None, None]:
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_logits_processors, make_sampler
audio_data = self._load_audio(audio)
input_features, num_audio_tokens = self._extract_features(audio_data)
audio_features = self.get_audio_features(input_features)
mx.eval(audio_features)
prompt_ids = self._build_prompt(num_audio_tokens, prompt)
inputs_embeds = self._build_inputs_embeds(prompt_ids, audio_features)
mx.eval(inputs_embeds)
prompt_token_count = len(prompt_ids)
sampler = make_sampler(temperature, top_p=top_p, min_p=min_p, top_k=top_k)
logits_processors = make_logits_processors(
repetition_penalty=repetition_penalty,
repetition_context_size=repetition_context_size,
)
eos_token_id = self._tokenizer.eos_token_id
gen_tokens = 0
for token, _ in generate_step(
prompt=prompt_ids,
input_embeddings=inputs_embeds.squeeze(0),
model=self,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prefill_step_size=prefill_step_size,
):
if token == eos_token_id:
break
gen_tokens += 1
text = self._tokenizer.decode([token], skip_special_tokens=True)
yield StreamingResult(
text=text,
is_final=False,
start_time=0.0,
end_time=0.0,
prompt_tokens=prompt_token_count,
generation_tokens=gen_tokens,
)
yield StreamingResult(
text="",
is_final=True,
start_time=0.0,
end_time=0.0,
prompt_tokens=prompt_token_count,
generation_tokens=gen_tokens,
)
def _load_audio(self, audio: Union[str, mx.array, np.ndarray]) -> mx.array:
if isinstance(audio, str):
from ..audio import load_audio
return load_audio(audio)
elif isinstance(audio, np.ndarray):
return mx.array(audio, dtype=mx.float32)
elif isinstance(audio, mx.array):
return audio
elif isinstance(audio, list):
audio_item = audio[0]
if isinstance(audio_item, str):
from ..audio import load_audio
return load_audio(audio_item)
return mx.array(np.array(audio_item), dtype=mx.float32)
raise TypeError(f"Unsupported audio type: {type(audio)}")

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
import glob
import json
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
from .granite_speech import Model, ModelConfig
DEFAULT_ALLOW_PATTERNS = [
"*.json",
"*.safetensors",
"*.py",
"*.model",
"*.tiktoken",
"*.txt",
"*.jsonl",
"*.yaml",
"*.npz",
]
def _is_local_path(path: str) -> bool:
return (
path.startswith(".")
or path.startswith("/")
or path.startswith("~")
or (len(path) > 1 and path[1] == ":")
)
def get_model_path(
path_or_hf_repo: str | Path,
*,
revision: str | None = None,
force_download: bool = False,
allow_patterns: list[str] | None = None,
) -> Path:
if isinstance(path_or_hf_repo, Path):
path = path_or_hf_repo.expanduser()
if path.exists():
return path
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
path = Path(path_or_hf_repo).expanduser()
if path.exists():
return path
if _is_local_path(path_or_hf_repo):
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
return Path(
snapshot_download(
path_or_hf_repo,
revision=revision,
allow_patterns=allow_patterns or DEFAULT_ALLOW_PATTERNS,
force_download=force_download,
)
)
def load_config(model_path: str | Path) -> dict[str, Any]:
model_path = Path(model_path)
config_file = model_path / "config.json"
if not config_file.exists():
raise FileNotFoundError(f"Config not found at {model_path}")
return json.loads(config_file.read_text(encoding="utf-8"))
def load_weights(model_path: Path) -> dict[str, mx.array]:
weight_files = sorted(glob.glob(str(model_path / "*.safetensors")))
if not weight_files:
weight_files = sorted(glob.glob(str(model_path / "*.npz")))
if not weight_files:
raise FileNotFoundError(
f"No weight files (safetensors or npz) found in {model_path}"
)
weights = {}
for weight_file in weight_files:
weights.update(mx.load(weight_file))
return weights
def apply_quantization(
model: nn.Module,
config: dict[str, Any],
weights: dict[str, mx.array],
model_quant_predicate=None,
) -> None:
quantization = config.get("quantization") or config.get("quantization_config")
if quantization is None:
return
group_size = quantization.get("group_size", 64)
def class_predicate(path, module):
if not hasattr(module, "to_quantized"):
return False
if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0:
return False
if model_quant_predicate is not None:
pred = model_quant_predicate(path, module)
if isinstance(pred, dict):
return pred
if not pred:
return False
if path in quantization:
return quantization[path]
return f"{path}.scales" in weights
nn.quantize(
model,
group_size=group_size,
bits=quantization["bits"],
mode=quantization.get("mode", "affine"),
class_predicate=class_predicate,
)
def load_model(
model_path: str | Path,
*,
lazy: bool = False,
strict: bool = False,
**kwargs: Any,
) -> nn.Module:
path = get_model_path(
model_path,
revision=kwargs.pop("revision", None),
force_download=kwargs.pop("force_download", False),
allow_patterns=kwargs.pop("allow_patterns", None),
)
config = load_config(path)
model = Model(ModelConfig.from_dict(config))
weights = load_weights(path)
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)
apply_quantization(model, config, weights, model.model_quant_predicate)
model.load_weights(list(weights.items()), strict=strict)
if not lazy:
mx.eval(model.parameters())
model.eval()
if hasattr(Model, "post_load_hook"):
model = Model.post_load_hook(model, path)
return model

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
from dataclasses import dataclass
import re
from typing import Any, Iterable, Iterator
@dataclass(frozen=True)
class AudioChunk:
index: int
start: float
end: float
samples: Any
def chunk_audio(
audio: Any,
sr: int,
chunk_seconds: float,
overlap_seconds: float = 2.0,
) -> Iterator[AudioChunk]:
if chunk_seconds <= 0:
raise ValueError("chunk_seconds must be positive")
if overlap_seconds < 0:
raise ValueError("overlap_seconds cannot be negative")
chunk_samples = int(chunk_seconds * sr)
overlap_samples = int(overlap_seconds * sr)
if overlap_samples >= chunk_samples:
raise ValueError("overlap_seconds must be smaller than chunk_seconds")
step = chunk_samples - overlap_samples
n = len(audio)
pos = 0
index = 1
while pos < n:
end = min(pos + chunk_samples, n)
yield AudioChunk(index, pos / sr, end / sr, audio[pos:end])
if end == n:
break
pos += step
index += 1
def prefix_text(transcripts: Iterable[str], max_chars: int = 800) -> str:
text = "\n".join(t.strip() for t in transcripts if t and t.strip())
text = re.sub(r"^## \[[^\n]+\]\n", "", text, flags=re.MULTILINE)
text = re.sub(r"\s+", " ", text).strip()
if len(text) <= max_chars:
return text
tail = text[-max_chars:]
first_space = tail.find(" ")
if first_space > 0:
tail = tail[first_space + 1 :]
return tail.strip()

View File

@@ -0,0 +1,126 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import sys
import time
from typing import Any
from .chunking import chunk_audio, prefix_text
from .prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES, build_prompt
DEFAULT_MODEL = "mlx-community/granite-speech-4.1-2b-plus-mlx"
@dataclass
class GraniteSpeechPlusPipeline:
model: Any
repo_id: str = DEFAULT_MODEL
chunk_seconds: float = 300.0
overlap_seconds: float = 2.0
repetition_penalty: float = 1.2
max_tokens: int = 4096
system_prompt: str | None = GRANITE_SYSTEM_PROMPT
verbose: bool = False
@classmethod
def from_pretrained(
cls,
repo_id: str = DEFAULT_MODEL,
*,
chunk_seconds: float = 300.0,
overlap_seconds: float = 2.0,
repetition_penalty: float = 1.2,
max_tokens: int = 4096,
system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
verbose: bool = False,
**load_kwargs: Any,
) -> "GraniteSpeechPlusPipeline":
from ._vendored.loader import load_model
model = load_model(repo_id, **load_kwargs)
return cls(
model=model,
repo_id=repo_id,
chunk_seconds=chunk_seconds,
overlap_seconds=overlap_seconds,
repetition_penalty=repetition_penalty,
max_tokens=max_tokens,
system_prompt=system_prompt,
verbose=verbose,
)
def transcribe(self, audio_path: str | Path, prompt_mode: str = "asr") -> str:
import librosa
import numpy as np
if prompt_mode not in PROMPT_MODES:
modes = ", ".join(sorted(PROMPT_MODES))
raise ValueError(f"prompt_mode must be one of: {modes}")
audio_file = Path(audio_path).expanduser().resolve()
audio, sr = librosa.load(str(audio_file), sr=16000, mono=True)
audio = np.asarray(audio, dtype=np.float32)
duration = len(audio) / sr if sr else 0.0
chunks = list(
chunk_audio(
audio,
sr,
self.chunk_seconds,
overlap_seconds=self.overlap_seconds,
)
)
if self.verbose:
print(
f"Loaded {audio_file} ({duration:.1f}s, {len(chunks)} chunks)",
file=sys.stderr,
)
rendered: list[str] = []
plain_texts: list[str] = []
t_start = time.time()
for chunk in chunks:
prompt = build_prompt(
prompt_mode,
prefix_text=prefix_text(plain_texts),
)
kwargs: dict[str, Any] = {
"prompt": prompt,
"max_tokens": self.max_tokens,
}
if self.system_prompt:
kwargs["system_prompt"] = self.system_prompt
if self.repetition_penalty and self.repetition_penalty > 1.0:
kwargs["repetition_penalty"] = self.repetition_penalty
t0 = time.time()
result = self.model.generate(chunk.samples, **kwargs)
text = getattr(result, "text", result)
if not isinstance(text, str):
text = str(text)
text = text.strip()
plain_texts.append(text)
if len(chunks) > 1:
rendered.append(f"## [{chunk.start:.1f}s - {chunk.end:.1f}s]\n{text}")
else:
rendered.append(text)
if self.verbose:
elapsed = time.time() - t0
rtf = (chunk.end - chunk.start) / elapsed if elapsed > 0 else 0.0
print(
f"[{chunk.index:>3}/{len(chunks)}] "
f"{chunk.start:>6.1f}s-{chunk.end:<6.1f}s "
f"gen={elapsed:>5.1f}s rtf={rtf:>4.1f}x {text[:80]}",
file=sys.stderr,
)
if self.verbose:
elapsed = time.time() - t_start
rtf = duration / elapsed if elapsed > 0 else 0.0
print(f"Total: {elapsed:.1f}s, rtf={rtf:.1f}x", file=sys.stderr)
return "\n\n".join(rendered).strip()

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
GRANITE_SYSTEM_PROMPT = (
"Knowledge Cutoff Date: April 2024.\n"
"Today's Date: December 19, 2024.\n"
"You are Granite, developed by IBM. You are a helpful AI assistant"
)
PROMPT_MODES = {
"asr": "can you transcribe the speech into a written format?",
"saa": (
"Speaker attribution: Transcribe and denote who is speaking by adding "
"[Speaker 1]: and [Speaker 2]: tags before speaker turns."
),
"ts": (
"Timestamps: Transcribe the speech. After each word, add a timestamp tag "
"showing the end time in centiseconds, e.g. hello [T:45] world [T:82]"
),
}
def granite3_chat_template(
content: str,
*,
system_prompt: str | None = GRANITE_SYSTEM_PROMPT,
add_generation_prompt: bool = True,
) -> str:
"""Build the Granite-3 chat template used when tokenizers omit one."""
start_role = "<|start_of_role|>"
end_role = "<|end_of_role|>"
end_text = "<|end_of_text|>"
parts: list[str] = []
if system_prompt:
parts.append(f"{start_role}system{end_role}{system_prompt}{end_text}\n")
parts.append(f"{start_role}user{end_role}{content}{end_text}\n")
if add_generation_prompt:
parts.append(f"{start_role}assistant{end_role}")
return "".join(parts)
def build_prompt(
prompt_mode: str = "asr",
*,
prefix_text: str | None = None,
custom_prompt: str | None = None,
) -> str:
if custom_prompt:
base = custom_prompt
else:
try:
base = PROMPT_MODES[prompt_mode]
except KeyError as exc:
modes = ", ".join(sorted(PROMPT_MODES))
raise ValueError(f"prompt_mode must be one of: {modes}") from exc
if not prefix_text:
return base
return (
f"{base}\n\n"
"Previous transcript context for continuity only. Do not repeat it:\n"
f"{prefix_text.strip()}"
)

6
tests/test_smoke.py Normal file
View File

@@ -0,0 +1,6 @@
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
def test_pipeline_symbol_exists():
assert GraniteSpeechPlusPipeline is not None