Files
granite-speech-4.1-2b-plus-mlx/scripts/benchmark.py
2026-05-09 20:00:57 +02:00

112 lines
3.5 KiB
Python
Executable File

#!/usr/bin/env python
from __future__ import annotations

import argparse
import math
import sys
import time
import unicodedata
from collections import Counter
from pathlib import Path

from granite_speech_plus_mlx import GraniteSpeechPlusPipeline
from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL
from granite_speech_plus_mlx.prompts import PROMPT_MODES
# (chunk_seconds, repetition_penalty) pairs swept by the benchmark; one
# transcription run per pair.
GRID = [
    (60, 1.0),
    (60, 1.2),
    (180, 1.0),
    (180, 1.2),
    (300, 1.0),
    (300, 1.2),
    (300, 1.4),
]

# Phrases counted as suspected hallucinations in each transcript. Matching is
# case-, accent- and apostrophe-style-insensitive, so the ASCII spelling
# "regarde" below also matches an accented "regardé" or a typographic ’.
HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde")


def _fold_for_matching(text: str) -> str:
    """Return *text* lowercased, with diacritics stripped and curly apostrophes straightened."""
    # NFKD splits accented letters into base letter + combining mark; dropping
    # the combining marks turns "é" into "e".
    decomposed = unicodedata.normalize("NFKD", text.lower())
    without_marks = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    return without_marks.replace("\u2019", "'").replace("\u2018", "'")


def analyze(text: str) -> dict:
    """Compute simple statistics for one transcript.

    Args:
        text: The transcript to inspect.

    Returns:
        A dict with:
          * ``n_words`` -- number of whitespace-separated tokens;
          * ``max_trigram_count`` / ``max_trigram_text`` -- the most frequent
            lowercase word trigram and its count (``0`` / ``""`` when the text
            has fewer than three words; ties go to the first trigram seen);
          * ``halluc`` -- for each entry of ``HALLUCINATION_MARKERS`` (used
            verbatim as the key), the number of non-overlapping occurrences in
            *text*, matched case-, accent- and apostrophe-insensitively.
    """
    # Lowercasing never adds or removes whitespace, so this split also gives
    # the word count.
    lower_words = text.lower().split()
    trigrams = Counter(
        " ".join(triple)
        for triple in zip(lower_words, lower_words[1:], lower_words[2:])
    )
    top = trigrams.most_common(1)
    folded = _fold_for_matching(text)
    return {
        "n_words": len(lower_words),
        "max_trigram_count": top[0][1] if top else 0,
        "max_trigram_text": top[0][0] if top else "",
        "halluc": {
            marker: folded.count(_fold_for_matching(marker))
            for marker in HALLUCINATION_MARKERS
        },
    }
def _parse_args() -> argparse.Namespace:
    """Parse the benchmark's command-line options."""
    parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.")
    parser.add_argument("audio")
    parser.add_argument("--model", default=DEFAULT_MODEL)
    parser.add_argument("--results", default="bench")
    parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr")
    parser.add_argument("--overlap-seconds", type=float, default=2.0)
    parser.add_argument("--max-tokens", type=int, default=4096)
    return parser.parse_args()


def _run_grid(pipe, audio: str, prompt_mode: str, results_dir: Path) -> list[dict]:
    """Transcribe *audio* once per ``GRID`` config and return one stats row per config.

    Each transcript is saved as ``chunk{C}_rp{R:.1f}.txt`` under *results_dir*.
    If that file already exists it is read back instead of re-running the
    model, and its wall time is recorded as NaN.
    """
    rows = []
    for chunk_seconds, repetition_penalty in GRID:
        out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt"
        pipe.chunk_seconds = float(chunk_seconds)
        pipe.repetition_penalty = repetition_penalty
        if out.exists():
            print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr)
            elapsed = float("nan")
            text = out.read_text(encoding="utf-8")
        else:
            print(
                f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}",
                file=sys.stderr,
            )
            # perf_counter is monotonic; time.time() can jump if the system
            # clock is adjusted during a long run, corrupting the measurement.
            t0 = time.perf_counter()
            text = pipe.transcribe(audio, prompt_mode=prompt_mode)
            elapsed = time.perf_counter() - t0
            out.write_text(text + "\n", encoding="utf-8")
        rows.append(
            {
                "chunk": chunk_seconds,
                "rp": repetition_penalty,
                "elapsed": elapsed,
                **analyze(text),
            }
        )
    return rows


def _md_cell(value: str) -> str:
    """Escape ``|`` so *value* cannot split a Markdown table cell."""
    return value.replace("|", "\\|")


def _print_table(rows: list[dict]) -> None:
    """Print *rows* (as produced by ``_run_grid``) as a Markdown table on stdout."""
    print()
    print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |")
    print("|---:|---:|---:|---:|:---|:---|")
    for row in rows:
        # Show only markers that occurred, abbreviated to their first word.
        halluc = ", ".join(
            f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value
        ) or "-"
        # The trigram comes from the transcript and could contain "|".
        trigram = _md_cell(f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)")
        wall = "nan" if math.isnan(row["elapsed"]) else f"{row['elapsed']:.0f}"
        print(
            f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} "
            f"| {trigram} | {halluc} |"
        )


def main() -> int:
    """Run the chunk-size / repetition-penalty sweep and print a Markdown summary.

    Progress messages go to stderr; the summary table goes to stdout.

    Returns:
        Process exit status, 0 on completion (errors propagate as exceptions).
    """
    args = _parse_args()
    results_dir = Path(args.results)
    results_dir.mkdir(parents=True, exist_ok=True)
    pipe = GraniteSpeechPlusPipeline.from_pretrained(
        args.model,
        overlap_seconds=args.overlap_seconds,
        max_tokens=args.max_tokens,
        verbose=True,
    )
    rows = _run_grid(pipe, args.audio, args.prompt_mode, results_dir)
    _print_table(rows)
    print()
    print(f"Per-config transcripts in: {results_dir}")
    return 0
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())