"""Small-subset HMMT/AIME bench : vanilla mlx-lm vs Markovian RSA orchestrator.

Usage :

    uv run python scripts/bench_hmmt.py \\
        --subset hmmt_2025_subset \\
        --n-problems 5 \\
        --rounds 2 --parallel 4 \\
        --output bench-out/hmmt_2025_subset.json
"""

from __future__ import annotations

import argparse
import json
import re
import sys
import time
from dataclasses import dataclass
from pathlib import Path

# Inline 5-problem HMMT'25-style subset (placeholder mini-set).
# TODO: load a full dataset (e.g. via a future --dataset flag).
# Each "answer" is the exact string the scorer compares against after
# normalization, so it must be the true answer to "question" as written.
_HMMT_2025_SUBSET = [
    {
        "id": "hmmt-1",
        "question": (
            "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6."
        ),
        # n(n+1) is always even, so only 3 | n(n+1) matters; that fails exactly
        # when n ≡ 1 (mod 3), i.e. for the 34 values 1, 4, ..., 100.
        "answer": "66",
    },
    {
        "id": "hmmt-2",
        "question": "Compute the smallest positive integer x such that 7^x ≡ 1 (mod 100).",
        "answer": "4",
    },
    {
        "id": "hmmt-3",
        "question": "If f(x) = x^3 - 3x + 1 has roots a, b, c, compute a^2 + b^2 + c^2.",
        "answer": "6",
    },
    {
        "id": "hmmt-4",
        "question": (
            "How many ways can 4 distinct objects be split into 2 non-empty unordered groups?"
        ),
        "answer": "7",
    },
    {
        "id": "hmmt-5",
        "question": "What is the remainder when 2^100 is divided by 125?",
        # gcd(2, 125) = 1 and phi(125) = 100, so 2^100 ≡ 1 (mod 125) by Euler.
        # (76 is the residue modulo 100, not 125.)
        "answer": "1",
    },
]

# Contents of a \boxed{...} with no nested braces, e.g. "\boxed{42}" -> "42".
_BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
# A signed integer or decimal literal, e.g. "-3", "4.0".
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
# An integer-valued decimal such as "4.0" or "-12.00"; group 1 is the integer part.
_INT_VALUED_DECIMAL_RE = re.compile(r"(-?\d+)\.0+")


@dataclass
class SubsetScore:
    """Exact-match accuracy of one decoding arm over the problem subset."""

    correct: int  # problems whose extracted answer matched the key
    total: int  # problems scored (len(items))
    accuracy: float  # correct / total, or 0.0 when total == 0


def extract_final_answer(text: str) -> str:
    r"""Return the model's final answer from a completion.

    Prefers the content of the last ``\boxed{...}`` (without nested braces);
    otherwise falls back to the last number literal in *text*; otherwise "".
    The result is whitespace-stripped.
    """
    matches = _BOXED_RE.findall(text)
    if matches:
        return matches[-1].strip()
    nums = _NUMBER_RE.findall(text)
    if nums:
        return nums[-1].strip()
    return ""


def _normalize_answer(answer: str) -> str:
    """Canonicalize an answer string for exact-match comparison.

    Strips surrounding whitespace and TeX ``$`` delimiters, and rewrites
    integer-valued decimals ("4.0", "4.00") as plain integers ("4").
    """
    text = answer.strip().strip("$").strip()
    match = _INT_VALUED_DECIMAL_RE.fullmatch(text)
    return match.group(1) if match else text


def _is_correct(item: dict, prediction: str) -> bool:
    """True when *prediction*'s extracted answer equals ``item["answer"]``."""
    predicted = _normalize_answer(extract_final_answer(prediction))
    return predicted == _normalize_answer(item["answer"])


def score_subset(items: list[dict], predictions: list[str]) -> SubsetScore:
    """Score *predictions* against the answer keys of *items*, pairwise.

    Pairs are formed with ``zip``, so surplus predictions are ignored and any
    item without a prediction counts as incorrect (``total`` is ``len(items)``).
    """
    correct = sum(_is_correct(item, pred) for item, pred in zip(items, predictions))
    total = len(items)
    # max(total, 1) keeps an empty subset at accuracy 0.0 instead of raising.
    return SubsetScore(correct=correct, total=total, accuracy=correct / max(total, 1))


def _vanilla_predict(orch, prompt: str, max_tokens: int) -> str:
    """One-shot decode with no aggregation : T=1, N=1."""
    from markovian_rsa_mlx.config import RSAConfig

    # NOTE(review): tail_tokens=64 differs from the RSA arm's 4096 and no seed is
    # set here; confirm against RSAConfig that this is the intended baseline.
    cfg = RSAConfig(
        rounds=1,
        parallel=1,
        aggregation_subsample=1,
        chunk_tokens=max_tokens,
        tail_tokens=64,
        serial=True,
    )
    return orch.solve(prompt, config=cfg)


def _rsa_predict(orch, prompt: str, *, rounds: int, parallel: int, chunk: int) -> str:
    """Decode *prompt* with the RSA orchestrator: T=*rounds*, N=*parallel*."""
    from markovian_rsa_mlx.config import RSAConfig

    cfg = RSAConfig(
        rounds=rounds,
        parallel=parallel,
        aggregation_subsample=min(parallel, 4),
        chunk_tokens=chunk,
        tail_tokens=4096,
        # presumably small N runs serially rather than batched — confirm in RSAConfig
        serial=parallel <= 2,
        seed=0,
    )
    return orch.solve(prompt, config=cfg)


def _preview(text: str, limit: int = 200) -> str:
    """Return the first *limit* characters of *text*, suffixed "..." if cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _score_record(score: SubsetScore, elapsed_s: float) -> dict:
    """Serialize one arm's score and wall-clock time for the JSON summary."""
    return {
        "correct": score.correct,
        "total": score.total,
        "accuracy": score.accuracy,
        "elapsed_s": elapsed_s,
    }


def main(argv: list[str] | None = None) -> int:
    """Benchmark vanilla decoding against RSA on the inline subset.

    Loads the model once, decodes every selected problem with both arms,
    writes the JSON summary to ``--output`` (if given) and prints it to stdout.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (useful in tests).

    Returns:
        Process exit code, 0 on success. Invalid arguments make argparse exit
        with status 2.
    """
    doc_lines = (__doc__ or "").strip().splitlines()
    # __doc__ is None under `python -OO`; fall back to no description then.
    p = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    p.add_argument("--subset", default="hmmt_2025_subset", choices=["hmmt_2025_subset"])
    p.add_argument("--n-problems", type=int, default=5)
    p.add_argument("--rounds", type=int, default=2)
    p.add_argument("--parallel", type=int, default=4)
    p.add_argument("--chunk-tokens", type=int, default=8192)
    p.add_argument("--model", default="kyr0/zaya1-base-8b-MLX")
    p.add_argument("--output", type=Path, default=None)
    args = p.parse_args(argv)
    if args.n_problems < 1:
        # A negative slice bound would silently drop problems from the END, and
        # 0 would load the model only to score nothing.
        p.error(f"--n-problems must be >= 1 (got {args.n_problems})")

    items = _HMMT_2025_SUBSET[: args.n_problems]

    from markovian_rsa_mlx import MarkovianRSAOrchestrator

    print(f"[bench] loading {args.model} ...", file=sys.stderr)
    orch = MarkovianRSAOrchestrator.from_pretrained(args.model)

    print(f"[bench] vanilla decode on {len(items)} problems ...", file=sys.stderr)
    t0 = time.perf_counter()  # monotonic clock for interval timing
    vanilla = [_vanilla_predict(orch, it["question"], args.chunk_tokens) for it in items]
    vanilla_elapsed = time.perf_counter() - t0
    vanilla_score = score_subset(items, vanilla)

    print(f"[bench] RSA T={args.rounds} N={args.parallel} ...", file=sys.stderr)
    t0 = time.perf_counter()
    rsa = [
        _rsa_predict(
            orch,
            it["question"],
            rounds=args.rounds,
            parallel=args.parallel,
            chunk=args.chunk_tokens,
        )
        for it in items
    ]
    rsa_elapsed = time.perf_counter() - t0
    rsa_score = score_subset(items, rsa)

    summary = {
        "subset": args.subset,
        "n_problems": len(items),
        "model": args.model,
        "config": {
            "rounds": args.rounds,
            "parallel": args.parallel,
            "chunk_tokens": args.chunk_tokens,
        },
        "vanilla": _score_record(vanilla_score, vanilla_elapsed),
        "rsa": _score_record(rsa_score, rsa_elapsed),
        "lift_pp": (rsa_score.accuracy - vanilla_score.accuracy) * 100,
        "predictions": [
            {
                "id": it["id"],
                "answer": it["answer"],
                "vanilla": _preview(v),
                "rsa": _preview(r),
                # The previews show only the head of each completion; record what
                # the scorer actually extracted so misses are debuggable.
                "vanilla_extracted": extract_final_answer(v),
                "rsa_extracted": extract_final_answer(r),
            }
            for it, v, r in zip(items, vanilla, rsa)
        ],
    }
    out = json.dumps(summary, indent=2, ensure_ascii=False)
    # Persist before printing: a console that cannot encode non-ASCII model
    # output would otherwise raise in print() and lose the whole run.
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        # Explicit UTF-8: ensure_ascii=False keeps characters such as "≡",
        # which the platform default encoding (e.g. cp1252) may not support.
        args.output.write_text(out, encoding="utf-8")
    print(out)
    return 0


if __name__ == "__main__":
    sys.exit(main())