From 67454162282267eed939f82a942553f315bdd0ba Mon Sep 17 00:00:00 2001
From: transcrilive
Date: Sun, 10 May 2026 03:20:33 +0200
Subject: [PATCH] feat(bench): HMMT/AIME small-subset harness + answer extraction tests

---
 conftest.py                 |   7 ++
 scripts/__init__.py         |   0
 scripts/bench_hmmt.py       | 165 ++++++++++++++++++++++++++++++++++++
 tests/test_bench_harness.py |  27 +++++++
 4 files changed, 199 insertions(+)
 create mode 100644 conftest.py
 create mode 100644 scripts/__init__.py
 create mode 100644 scripts/bench_hmmt.py
 create mode 100644 tests/test_bench_harness.py

diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000..2e64631
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,7 @@
+"""Root conftest : add repo root to sys.path so `from scripts.* import ...` works."""
+import sys
+from pathlib import Path
+
+_ROOT = Path(__file__).resolve().parent
+if str(_ROOT) not in sys.path:
+    sys.path.insert(0, str(_ROOT))
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/scripts/bench_hmmt.py b/scripts/bench_hmmt.py
new file mode 100644
index 0000000..9be1b68
--- /dev/null
+++ b/scripts/bench_hmmt.py
@@ -0,0 +1,165 @@
+"""Small-subset HMMT/AIME bench : vanilla mlx-lm vs Markovian RSA orchestrator.
+
+Usage :
+    uv run python scripts/bench_hmmt.py \\
+        --subset hmmt_2025_subset \\
+        --n-problems 5 \\
+        --rounds 2 --parallel 4 \\
+        --output bench-out/hmmt_2025_subset.json
+"""
+from __future__ import annotations
+import argparse
+import json
+import re
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+
+# Inline 5-problem HMMT'25-style subset (placeholder mini-set ; expand via --dataset later)
+_HMMT_2025_SUBSET = [
+    {
+        "id": "hmmt-1",
+        "question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.",
+        # n(n+1) is always even ; 3 divides it iff n = 0 or 2 (mod 3) : 33 + 33 = 66.
+        "answer": "66",
+    },
+    {
+        "id": "hmmt-2",
+        "question": "Compute the smallest positive integer x such that 7^x ≡ 1 (mod 100).",
+        "answer": "4",
+    },
+    {
+        "id": "hmmt-3",
+        "question": "If f(x) = x^3 - 3x + 1 has roots a, b, c, compute a^2 + b^2 + c^2.",
+        "answer": "6",
+    },
+    {
+        "id": "hmmt-4",
+        "question": "How many ways can 4 distinct objects be split into 2 non-empty unordered groups?",
+        "answer": "7",
+    },
+    {
+        "id": "hmmt-5",
+        "question": "What is the remainder when 2^100 is divided by 125?",
+        # phi(125) = 100 and gcd(2, 125) = 1, so 2^100 = 1 (mod 125) ; 76 is the residue mod 100.
+        "answer": "1",
+    },
+]
+
+_BOXED_RE = re.compile(r"\\boxed\{([^{}]+)\}")
+_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
+
+
+@dataclass
+class SubsetScore:
+    """Exact-match tally for one run over a subset (accuracy = correct / total)."""
+    correct: int
+    total: int
+    accuracy: float
+
+
+def extract_final_answer(text: str) -> str:
+    """Return the final answer of a completion, or the empty string if none is found.
+
+    The last brace-free LaTeX boxed payload wins ; otherwise the last number.
+    """
+    matches = _BOXED_RE.findall(text)
+    if matches:
+        return matches[-1].strip()
+    nums = _NUMBER_RE.findall(text)
+    if nums:
+        return nums[-1].strip()
+    return ""
+
+
+def score_subset(items: list[dict], predictions: list[str]) -> SubsetScore:
+    """Score predictions by exact string match against each item's "answer" field.
+
+    ``total`` is ``len(items)``, so an item without a prediction counts as wrong.
+    """
+    correct = 0
+    for item, pred in zip(items, predictions):
+        if extract_final_answer(pred) == item["answer"].strip():
+            correct += 1
+    total = len(items)
+    return SubsetScore(correct=correct, total=total, accuracy=correct / max(total, 1))
+
+
+def _vanilla_predict(orch, prompt: str, max_tokens: int) -> str:
+    """One-shot decode with no aggregation : T=1, N=1."""
+    from markovian_rsa_mlx.config import RSAConfig
+    cfg = RSAConfig(rounds=1, parallel=1, aggregation_subsample=1,
+                    chunk_tokens=max_tokens, tail_tokens=64, serial=True)
+    return orch.solve(prompt, config=cfg)
+
+
+def _rsa_predict(orch, prompt: str, *, rounds: int, parallel: int, chunk: int) -> str:
+    """Solve with the RSA orchestrator : T=rounds, N=parallel, fixed seed 0."""
+    from markovian_rsa_mlx.config import RSAConfig
+    cfg = RSAConfig(rounds=rounds, parallel=parallel,
+                    aggregation_subsample=min(parallel, 4),
+                    chunk_tokens=chunk, tail_tokens=4096,
+                    serial=parallel <= 2, seed=0)
+    return orch.solve(prompt, config=cfg)
+
+
+def main() -> int:
+    """Run both decoders on the subset, print a JSON summary, optionally save it."""
+    p = argparse.ArgumentParser(description=__doc__.splitlines()[0])
+    p.add_argument("--subset", default="hmmt_2025_subset",
+                   choices=["hmmt_2025_subset"])
+    p.add_argument("--n-problems", type=int, default=5)
+    p.add_argument("--rounds", type=int, default=2)
+    p.add_argument("--parallel", type=int, default=4)
+    p.add_argument("--chunk-tokens", type=int, default=8192)
+    p.add_argument("--model", default="kyr0/zaya1-base-8b-MLX")
+    p.add_argument("--output", type=Path, default=None)
+    args = p.parse_args()
+
+    items = _HMMT_2025_SUBSET[: args.n_problems]
+    from markovian_rsa_mlx import MarkovianRSAOrchestrator
+    print(f"[bench] loading {args.model} ...", file=sys.stderr)
+    orch = MarkovianRSAOrchestrator.from_pretrained(args.model)
+
+    print(f"[bench] vanilla decode on {len(items)} problems ...", file=sys.stderr)
+    t0 = time.time()
+    vanilla = [_vanilla_predict(orch, it["question"], args.chunk_tokens) for it in items]
+    vanilla_elapsed = time.time() - t0
+    vanilla_score = score_subset(items, vanilla)
+
+    print(f"[bench] RSA T={args.rounds} N={args.parallel} ...", file=sys.stderr)
+    t0 = time.time()
+    rsa = [_rsa_predict(orch, it["question"], rounds=args.rounds,
+                        parallel=args.parallel, chunk=args.chunk_tokens) for it in items]
+    rsa_elapsed = time.time() - t0
+    rsa_score = score_subset(items, rsa)
+
+    summary = {
+        "subset": args.subset, "n_problems": len(items),
+        "model": args.model,
+        "config": {"rounds": args.rounds, "parallel": args.parallel,
+                   "chunk_tokens": args.chunk_tokens},
+        "vanilla": {"correct": vanilla_score.correct, "total": vanilla_score.total,
+                    "accuracy": vanilla_score.accuracy, "elapsed_s": vanilla_elapsed},
+        "rsa": {"correct": rsa_score.correct, "total": rsa_score.total,
+                "accuracy": rsa_score.accuracy, "elapsed_s": rsa_elapsed},
+        "lift_pp": (rsa_score.accuracy - vanilla_score.accuracy) * 100,
+        "predictions": [
+            {"id": it["id"], "answer": it["answer"],
+             "vanilla": v[:200] + "..." if len(v) > 200 else v,
+             "rsa": r[:200] + "..." if len(r) > 200 else r}
+            for it, v, r in zip(items, vanilla, rsa)
+        ],
+    }
+    out = json.dumps(summary, indent=2, ensure_ascii=False)
+    print(out)
+    if args.output is not None:
+        args.output.parent.mkdir(parents=True, exist_ok=True)
+        # ensure_ascii=False may emit non-ASCII ; do not depend on the locale encoding.
+        args.output.write_text(out, encoding="utf-8")
+    return 0
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/tests/test_bench_harness.py b/tests/test_bench_harness.py
new file mode 100644
index 0000000..fe6bd25
--- /dev/null
+++ b/tests/test_bench_harness.py
@@ -0,0 +1,27 @@
+from scripts.bench_hmmt import extract_final_answer, score_subset
+
+
+def test_extract_final_answer_picks_last_boxed():
+    text = "Long reasoning... \\boxed{42} done."
+    assert extract_final_answer(text) == "42"
+
+
+def test_extract_final_answer_falls_back_to_last_number():
+    text = "...therefore the answer is 17."
+    assert extract_final_answer(text) == "17"
+
+
+def test_extract_final_answer_returns_empty_when_no_number():
+    assert extract_final_answer("no answer here") == ""
+
+
+def test_score_subset_counts_correct():
+    items = [
+        {"question": "q1", "answer": "42"},
+        {"question": "q2", "answer": "100"},
+    ]
+    predictions = ["The answer is 42.", "Final: 99"]
+    score = score_subset(items, predictions)
+    assert score.correct == 1
+    assert score.total == 2
+    assert abs(score.accuracy - 0.5) < 1e-6