feat(bench): HMMT/AIME small-subset harness + answer extraction tests
This commit is contained in:
7
conftest.py
Normal file
7
conftest.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Root conftest : add repo root to sys.path so `from scripts.* import ...` works."""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Directory containing this conftest; the module docstring identifies it as
# the repo root.
_ROOT = Path(__file__).resolve().parent

# Prepend the root only when it is absent, so the entry is never duplicated.
if sys.path.count(str(_ROOT)) == 0:
    sys.path.insert(0, str(_ROOT))
|
||||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
151
scripts/bench_hmmt.py
Normal file
151
scripts/bench_hmmt.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
"""Small-subset HMMT/AIME bench : vanilla mlx-lm vs Markovian RSA orchestrator.
|
||||||
|
|
||||||
|
Usage :
|
||||||
|
uv run python scripts/bench_hmmt.py \\
|
||||||
|
--subset hmmt_2025_subset \\
|
||||||
|
--n-problems 5 \\
|
||||||
|
--rounds 2 --parallel 4 \\
|
||||||
|
--output bench-out/hmmt_2025_subset.json
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Inline 5-problem HMMT'25-style subset (placeholder mini-set ; expand via --dataset later).
# Each "answer" is a plain decimal string because scoring compares strings exactly.
_HMMT_2025_SUBSET = [
    {
        "id": "hmmt-1",
        "question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.",
        # n^2 + n = n(n + 1) is always even, so only divisibility by 3 matters:
        # n % 3 in {0, 2} gives 33 + 33 = 66 values in 1..100.
        "answer": "66",
    },
    {
        "id": "hmmt-2",
        "question": "Compute the smallest positive integer x such that 7^x ≡ 1 (mod 100).",
        # 7^2 = 49 and 7^4 = 2401, so the order of 7 modulo 100 is 4.
        "answer": "4",
    },
    {
        "id": "hmmt-3",
        "question": "If f(x) = x^3 - 3x + 1 has roots a, b, c, compute a^2 + b^2 + c^2.",
        # Vieta: a + b + c = 0 and ab + bc + ca = -3, hence 0^2 - 2 * (-3) = 6.
        "answer": "6",
    },
    {
        "id": "hmmt-4",
        "question": "How many ways can 4 distinct objects be split into 2 non-empty unordered groups?",
        # Stirling number of the second kind S(4, 2) = 7.
        "answer": "7",
    },
    {
        "id": "hmmt-5",
        "question": "What is the remainder when 2^100 is divided by 125?",
        # gcd(2, 125) = 1 and phi(125) = 100, so 2^100 == 1 (mod 125) by Euler's
        # theorem. (76 is the residue modulo 100, not modulo 125.)
        "answer": "1",
    },
]
|
||||||
|
|
||||||
|
# Captures the payload of a \boxed{...} group whose payload contains no braces.
_BOXED_RE = re.compile(
    r"\\boxed\{"   # literal "\boxed{"
    r"([^{}]+)"    # one or more characters that are not braces (captured)
    r"\}"          # closing brace
)

# An optionally negative integer or decimal literal, e.g. "-3", "17" or "2.5".
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SubsetScore:
    """Tally produced by scoring a batch of predictions against reference answers."""

    # Count of predictions whose extracted answer equalled the reference answer.
    correct: int
    # Count of reference items that were scored.
    total: int
    # Ratio supplied by the producer; score_subset passes correct / max(total, 1).
    accuracy: float
|
||||||
|
|
||||||
|
|
||||||
|
def _boxed_groups(text: str) -> list[str]:
    r"""Return the payload of every brace-balanced ``\boxed{...}`` in *text*, in order.

    Braces are counted by hand because a regular expression cannot balance
    arbitrarily nested groups such as ``\boxed{\frac{\sqrt{3}}{2}}``.
    A ``\boxed{`` that is never closed (e.g. truncated output) is ignored.
    """
    marker = "\\boxed{"
    groups: list[str] = []
    start = text.find(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        for pos in range(body_start, len(text)):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    groups.append(text[body_start:pos])
                    break
        # Resume just past this marker so an unclosed group cannot hide later ones.
        start = text.find(marker, body_start)
    return groups


def extract_final_answer(text: str) -> str:
    r"""Extract the final answer string from a model completion.

    Resolution order:

    1. The last non-blank, brace-balanced ``\boxed{...}`` payload, stripped of
       surrounding whitespace. Nested braces are kept intact, so
       ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}``.
    2. Otherwise, the last number matched by ``_NUMBER_RE`` anywhere in *text*.
    3. Otherwise, the empty string.

    Args:
        text: Raw completion text.

    Returns:
        The extracted answer, or ``""`` when nothing usable is found.
    """
    for payload in reversed(_boxed_groups(text)):
        candidate = payload.strip()
        if candidate:
            return candidate

    numbers = _NUMBER_RE.findall(text)
    return numbers[-1] if numbers else ""
|
||||||
|
|
||||||
|
|
||||||
|
def score_subset(items: list[dict], predictions: list[str]) -> SubsetScore:
    """Score raw model outputs against reference answers by exact string match.

    Items and predictions are paired positionally; pairing stops at the shorter
    of the two sequences, while the reported total is always ``len(items)``.

    Args:
        items: Reference records; each must carry an ``"answer"`` string.
        predictions: Raw completions, one per item.

    Returns:
        A ``SubsetScore`` whose accuracy is ``correct / max(total, 1)``.
    """
    hits = sum(
        1
        for reference, output in zip(items, predictions)
        if extract_final_answer(output) == reference["answer"].strip()
    )
    n_items = len(items)
    # "or 1" keeps the division defined when no items were supplied.
    return SubsetScore(correct=hits, total=n_items, accuracy=hits / (n_items or 1))
|
||||||
|
|
||||||
|
|
||||||
|
def _vanilla_predict(orch, prompt: str, max_tokens: int) -> str:
    """Baseline solve: a single round with a single sample (T=1, N=1), no aggregation.

    Args:
        orch: Object exposing ``solve(prompt, config=...)``.
        prompt: Problem statement passed to the solver.
        max_tokens: Forwarded as ``chunk_tokens``
            (presumably the generation budget; confirm against RSAConfig).

    Returns:
        Whatever ``orch.solve`` returns for this configuration.
    """
    # Imported lazily so importing this module does not require the package.
    from markovian_rsa_mlx.config import RSAConfig

    settings = {
        "rounds": 1,
        "parallel": 1,
        "aggregation_subsample": 1,
        "chunk_tokens": max_tokens,
        "tail_tokens": 64,
        "serial": True,
    }
    return orch.solve(prompt, config=RSAConfig(**settings))
|
||||||
|
|
||||||
|
|
||||||
|
def _rsa_predict(orch, prompt: str, *, rounds: int, parallel: int, chunk: int) -> str:
    """Solve *prompt* with the multi-round, multi-sample RSA configuration.

    Args:
        orch: Object exposing ``solve(prompt, config=...)``.
        prompt: Problem statement passed to the solver.
        rounds: Forwarded as ``rounds``.
        parallel: Forwarded as ``parallel``; also bounds ``aggregation_subsample``
            (capped at 4) and selects serial execution when it is 2 or less.
        chunk: Forwarded as ``chunk_tokens``.

    Returns:
        Whatever ``orch.solve`` returns for this configuration.
    """
    # Imported lazily so importing this module does not require the package.
    from markovian_rsa_mlx.config import RSAConfig

    subsample = min(parallel, 4)
    run_serially = parallel <= 2
    settings = RSAConfig(
        rounds=rounds,
        parallel=parallel,
        aggregation_subsample=subsample,
        chunk_tokens=chunk,
        tail_tokens=4096,
        serial=run_serially,
        seed=0,  # fixed seed, as in the configuration this harness always used
    )
    return orch.solve(prompt, config=settings)
|
||||||
|
|
||||||
|
|
||||||
|
def _preview(text: str, limit: int = 200) -> str:
    """Return *text* unchanged when short, else its first *limit* characters plus "..."."""
    return text[:limit] + "..." if len(text) > limit else text


def main() -> int:
    """Run the vanilla-vs-RSA comparison from the command line.

    Parses the CLI flags, loads the model, runs both prediction modes over the
    selected problems, prints a JSON summary to stdout and, when ``--output``
    is given, also writes it to that path (creating parent directories).

    Returns:
        ``0`` on success. Invalid arguments terminate through argparse.
    """
    # The module docstring may be absent (e.g. under ``python -OO``); fall back
    # to no description instead of crashing.
    doc_lines = (__doc__ or "").splitlines()
    p = argparse.ArgumentParser(description=doc_lines[0] if doc_lines else None)
    p.add_argument("--subset", default="hmmt_2025_subset",
                   choices=["hmmt_2025_subset"])
    p.add_argument("--n-problems", type=int, default=5)
    p.add_argument("--rounds", type=int, default=2)
    p.add_argument("--parallel", type=int, default=4)
    p.add_argument("--chunk-tokens", type=int, default=8192)
    p.add_argument("--model", default="kyr0/zaya1-base-8b-MLX")
    p.add_argument("--output", type=Path, default=None)
    args = p.parse_args()

    # A zero or negative count would silently slice to an empty or truncated
    # problem list; reject it before the (expensive) model load.
    if args.n_problems < 1:
        p.error("--n-problems must be a positive integer")

    items = _HMMT_2025_SUBSET[: args.n_problems]

    # Imported after argument parsing so ``--help`` and argument errors do not
    # need the model package.
    from markovian_rsa_mlx import MarkovianRSAOrchestrator

    print(f"[bench] loading {args.model} ...", file=sys.stderr)
    orch = MarkovianRSAOrchestrator.from_pretrained(args.model)

    print(f"[bench] vanilla decode on {len(items)} problems ...", file=sys.stderr)
    # perf_counter is monotonic, so system clock adjustments cannot skew timings.
    t0 = time.perf_counter()
    vanilla = [_vanilla_predict(orch, it["question"], args.chunk_tokens) for it in items]
    vanilla_elapsed = time.perf_counter() - t0
    vanilla_score = score_subset(items, vanilla)

    print(f"[bench] RSA T={args.rounds} N={args.parallel} ...", file=sys.stderr)
    t0 = time.perf_counter()
    rsa = [
        _rsa_predict(orch, it["question"], rounds=args.rounds,
                     parallel=args.parallel, chunk=args.chunk_tokens)
        for it in items
    ]
    rsa_elapsed = time.perf_counter() - t0
    rsa_score = score_subset(items, rsa)

    summary = {
        "subset": args.subset,
        "n_problems": len(items),
        "model": args.model,
        "config": {
            "rounds": args.rounds,
            "parallel": args.parallel,
            "chunk_tokens": args.chunk_tokens,
        },
        "vanilla": {
            "correct": vanilla_score.correct,
            "total": vanilla_score.total,
            "accuracy": vanilla_score.accuracy,
            "elapsed_s": vanilla_elapsed,
        },
        "rsa": {
            "correct": rsa_score.correct,
            "total": rsa_score.total,
            "accuracy": rsa_score.accuracy,
            "elapsed_s": rsa_elapsed,
        },
        # Accuracy difference expressed in percentage points.
        "lift_pp": (rsa_score.accuracy - vanilla_score.accuracy) * 100,
        "predictions": [
            {
                "id": it["id"],
                "answer": it["answer"],
                "vanilla": _preview(v),
                "rsa": _preview(r),
            }
            for it, v, r in zip(items, vanilla, rsa)
        ],
    }
    out = json.dumps(summary, indent=2, ensure_ascii=False)
    print(out)
    if args.output is not None:
        args.output.parent.mkdir(parents=True, exist_ok=True)
        # ensure_ascii=False can leave non-ASCII characters in the payload, so
        # pin UTF-8 instead of relying on the platform default encoding.
        args.output.write_text(out, encoding="utf-8")
    return 0
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||||
27
tests/test_bench_harness.py
Normal file
27
tests/test_bench_harness.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from scripts.bench_hmmt import extract_final_answer, score_subset
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_final_answer_picks_last_boxed():
    """With several boxed groups present, the final one is the answer."""
    # A single boxed group is still extracted.
    assert extract_final_answer("Long reasoning... \\boxed{42} done.") == "42"

    # Two groups are required to observe that the *last* one wins.
    text = "First guess \\boxed{7}, corrected later... \\boxed{42} done."
    assert extract_final_answer(text) == "42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_final_answer_falls_back_to_last_number():
    """Without a boxed group, the last number in the text is the answer."""
    # An earlier number (3) must be present to show that the last one (17) wins.
    text = "Tried 3 cases... therefore the answer is 17."
    assert extract_final_answer(text) == "17"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_final_answer_returns_empty_when_no_number():
    """Text with neither a boxed group nor any digit yields the empty string."""
    extracted = extract_final_answer("no answer here")
    assert extracted == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_score_subset_counts_correct():
    """One matching and one non-matching prediction give 1 of 2 correct."""
    references = [
        {"question": "q1", "answer": "42"},
        {"question": "q2", "answer": "100"},
    ]
    outputs = ["The answer is 42.", "Final: 99"]

    result = score_subset(references, outputs)

    assert (result.correct, result.total) == (1, 2)
    assert abs(result.accuracy - 0.5) < 1e-6
|
||||||
Reference in New Issue
Block a user