#!/usr/bin/env python from __future__ import annotations import argparse import sys import time from collections import Counter from pathlib import Path from granite_speech_plus_mlx import GraniteSpeechPlusPipeline from granite_speech_plus_mlx.pipeline import DEFAULT_MODEL from granite_speech_plus_mlx.prompts import PROMPT_MODES GRID = [ (60, 1.0), (60, 1.2), (180, 1.0), (180, 1.2), (300, 1.0), (300, 1.2), (300, 1.4), ] HALLUCINATION_MARKERS = ("thank you very much", "merci d'avoir regarde") def analyze(text: str) -> dict: words = text.split() lower_words = text.lower().split() trigrams = Counter( " ".join(lower_words[i : i + 3]) for i in range(len(lower_words) - 2) ) top = trigrams.most_common(5) lower = text.lower() return { "n_words": len(words), "max_trigram_count": top[0][1] if top else 0, "max_trigram_text": top[0][0] if top else "", "halluc": {m: lower.count(m) for m in HALLUCINATION_MARKERS}, } def main() -> int: parser = argparse.ArgumentParser(description="Benchmark Granite Speech Plus MLX settings.") parser.add_argument("audio") parser.add_argument("--model", default=DEFAULT_MODEL) parser.add_argument("--results", default="bench") parser.add_argument("--prompt-mode", choices=sorted(PROMPT_MODES), default="asr") parser.add_argument("--overlap-seconds", type=float, default=2.0) parser.add_argument("--max-tokens", type=int, default=4096) args = parser.parse_args() results_dir = Path(args.results) results_dir.mkdir(parents=True, exist_ok=True) pipe = GraniteSpeechPlusPipeline.from_pretrained( args.model, overlap_seconds=args.overlap_seconds, max_tokens=args.max_tokens, verbose=True, ) rows = [] for chunk_seconds, repetition_penalty in GRID: out = results_dir / f"chunk{chunk_seconds}_rp{repetition_penalty:.1f}.txt" pipe.chunk_seconds = float(chunk_seconds) pipe.repetition_penalty = repetition_penalty if out.exists(): print(f"# skipping {out.name} (already exists, delete to rerun)", file=sys.stderr) elapsed = float("nan") text = out.read_text(encoding="utf-8") else: print( f"# running chunk={chunk_seconds}s rep_penalty={repetition_penalty}", file=sys.stderr, ) t0 = time.time() text = pipe.transcribe(args.audio, prompt_mode=args.prompt_mode) elapsed = time.time() - t0 out.write_text(text + "\n", encoding="utf-8") rows.append( { "chunk": chunk_seconds, "rp": repetition_penalty, "elapsed": elapsed, **analyze(text), } ) print() print("| chunk(s) | rp | wall(s) | words | max_trigram(N) | hallucinations |") print("|---:|---:|---:|---:|:---|:---|") for row in rows: halluc = ", ".join( f"{key.split()[0]}x{value}" for key, value in row["halluc"].items() if value ) or "-" trigram = f"{row['max_trigram_text']!r} ({row['max_trigram_count']}x)" wall = "nan" if row["elapsed"] != row["elapsed"] else f"{row['elapsed']:.0f}" print( f"| {row['chunk']} | {row['rp']:.1f} | {wall} | {row['n_words']} " f"| {trigram} | {halluc} |" ) print() print(f"Per-config transcripts in: {results_dir}") return 0 if __name__ == "__main__": sys.exit(main())