Initial Granite Speech Plus MLX package
This commit is contained in:
111
scripts/benchmark.py
Executable file
111
scripts/benchmark.py
Executable file
@@ -0,0 +1,111 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||
from granite_speech_plus_mlx.prompts import PROMPT_MODES
|
||||
|
||||
GRID = [
|
||||
(60, 1.0),
|
||||
(60, 1.2),
|
||||
(180, 1.0),
|
||||
(180, 1.2),
|
||||
(300, 1.0),
|
||||
(300, 1.2),
|
||||
(300, 1.4),
|
||||
]
|
||||
|
||||
HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde")
|
||||
|
||||
|
||||
def analyze(text: str) -> dict:
|
||||
words = text.split()
|
||||
lower_words = text.lower().split()
|
||||
trigrams = Counter(
|
||||
" ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2)
|
||||
)
|
||||
top = trigrams.most_common(5)
|
||||
lower = text.lower()
|
||||
return {
|
||||
"n_words": len(words),
|
||||
"max_trigram_count": top[0][1] if top else 0,
|
||||
"max_trigram_text": top[0][0] if top else "",
|
||||
"halluc": {m: lower.count(m) for m in HALLUCINATION_MARKERS},
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.")
|
||||
parser.add_argument("audio")
|
||||
parser.add_argument("--model", default=DEFAULT_MODEL)
|
||||
parser.add_argument("--results", default="bench")
|
||||
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
|
||||
parser.add_argument("--overlap-seconds", type=float, default=2.0)
|
||||
parser.add_argument("--max-tokens", type=int, default=4096)
|
||||
args = parser.parse_args()
|
||||
|
||||
results_dir = Path(args.results)
|
||||
results_dir.mkdir(parents=True, exist_ok=True)
|
||||
pipe = GraniteSpeechPlusPipeline.from_pretrained(
|
||||
args.model,
|
||||
overlap_seconds=args.overlap_seconds,
|
||||
max_tokens=args.max_tokens,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
rows = []
|
||||
for chunk_seconds, repetition_penalty in GRID:
|
||||
out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt"
|
||||
pipe.chunk_seconds = float(chunk_seconds)
|
||||
pipe.repetition_penalty = repetition_penalty
|
||||
|
||||
if out.exists():
|
||||
print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr)
|
||||
elapsed = float("nan")
|
||||
text = out.read_text(encoding="utf-8")
|
||||
else:
|
||||
print(
|
||||
f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
t0 = time.time()
|
||||
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
|
||||
elapsed = time.time() - t0
|
||||
out.write_text(text + "\n", encoding="utf-8")
|
||||
|
||||
rows.append(
|
||||
{
|
||||
"chunk": chunk_seconds,
|
||||
"rp": repetition_penalty,
|
||||
"elapsed": elapsed,
|
||||
**analyze(text),
|
||||
}
|
||||
)
|
||||
|
||||
print()
|
||||
print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |")
|
||||
print("|---:|---:|---:|---:|:---|:---|")
|
||||
for row in rows:
|
||||
halluc = ", ".join(
|
||||
f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value
|
||||
) or "-"
|
||||
trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)"
|
||||
wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}"
|
||||
print(
|
||||
f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} "
|
||||
f"| {trigram} | {halluc} |"
|
||||
)
|
||||
print()
|
||||
print(f"Per-config transcripts in: {results_dir}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
47
scripts/transcribe.py
Executable file
47
scripts/transcribe.py
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
|
||||
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
|
||||
from granite_speech_plus_mlx.prompts import GRANITE_SYSTEM_PROMPT, PROMPT_MODES
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(description="Transcribe audio with Granite Speech Plus MLX.")
|
||||
parser.add_argument("audio")
|
||||
parser.add_argument("--model", default=DEFAULT_MODEL)
|
||||
parser.add_argument("--output", default=None)
|
||||
parser.add_argument("--chunk-seconds", type=float, default=300.0)
|
||||
parser.add_argument("--overlap-seconds", type=float, default=2.0)
|
||||
parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
|
||||
parser.add_argument("--repetition-penalty", type=float, default=1.2)
|
||||
parser.add_argument("--max-tokens", type=int, default=4096)
|
||||
parser.add_argument("--system-prompt", default=GRANITE_SYSTEM_PROMPT)
|
||||
parser.add_argument("--verbose", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
pipe = GraniteSpeechPlusPipeline.from_pretrained(
|
||||
args.model,
|
||||
chunk_seconds=args.chunk_seconds,
|
||||
overlap_seconds=args.overlap_seconds,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
max_tokens=args.max_tokens,
|
||||
system_prompt=args.system_prompt or None,
|
||||
verbose=args.verbose,
|
||||
)
|
||||
text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode)
|
||||
|
||||
if args.output:
|
||||
Path(args.output).write_text(text + "\n", encoding="utf-8")
|
||||
else:
|
||||
print(text)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
66
scripts/upload_to_hf.py
Executable file
66
scripts/upload_to_hf.py
Executable file
@@ -0,0 +1,66 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
SOURCE_CACHE = (
|
||||
Path.home()
|
||||
/ ".cache/huggingface/hub/models--ibm-granite--granite-speech-4.1-2b-plus"
|
||||
)
|
||||
DEST_REPO = "mlx-community/granite-speech-4.1-2b-plus-mlx"
|
||||
|
||||
|
||||
def find_weights_dir(root: Path) -> Path | None:
|
||||
if not root.exists():
|
||||
return None
|
||||
if list(root.glob("*.safetensors")) or (root / "config.json").exists():
|
||||
return root
|
||||
snapshots = root / "snapshots"
|
||||
if snapshots.exists():
|
||||
candidates = [
|
||||
path
|
||||
for path in snapshots.iterdir()
|
||||
if path.is_dir() and (list(path.glob("*.safetensors")) or (path / "config.json").exists())
|
||||
]
|
||||
if candidates:
|
||||
return sorted(candidates, key=lambda p: p.stat().st_mtime)[-1]
|
||||
return None
|
||||
|
||||
|
||||
def print_manual_commands() -> None:
|
||||
print(f"MLX weights not found at {SOURCE_CACHE}")
|
||||
print("Create them first with:")
|
||||
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus")
|
||||
print("mlxconv ibm-granite/granite-speech-4.1-2b-plus --dtype q4_k_4")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
weights_dir = find_weights_dir(SOURCE_CACHE)
|
||||
if weights_dir is None:
|
||||
print_manual_commands()
|
||||
return 1
|
||||
|
||||
token = os.environ.get("HF_TOKEN")
|
||||
if not token:
|
||||
print("HF_TOKEN is required to upload.", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
api = HfApi(token=token)
|
||||
api.create_repo(DEST_REPO, repo_type="model", exist_ok=True)
|
||||
api.upload_folder(
|
||||
repo_id=DEST_REPO,
|
||||
repo_type="model",
|
||||
folder_path=str(weights_dir),
|
||||
commit_message="Upload Granite Speech 4.1-2b-plus MLX weights",
|
||||
)
|
||||
print(f"Uploaded {weights_dir} to {DEST_REPO}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
Reference in New Issue
Block a user