2 Commits

Author SHA1 Message Date
transcrilive
b65bf91e37 release: v0.1.1 — enable_thinking=False default + corrected bench gold + CHANGELOG 2026-05-10 14:38:27 +02:00
transcrilive
81e8ac88cc feat(config): add enable_thinking flag (default False) + fix HMMT bench gold answers
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-10 13:08:41 +02:00
10 changed files with 51 additions and 11 deletions

18
CHANGELOG.md Normal file
View File

@@ -0,0 +1,18 @@
# Changelog
## v0.1.1 — 2026-05-10
### Added
- `RSAConfig.enable_thinking` field (default `False`). Toggling `<think>` mode in the chat template substantially affects output quality on math problems.
- Bench `scripts/bench_hmmt.py` now uses corrected gold answers for the placeholder HMMT-1 (66, was 100) and HMMT-5 (1, was 76).
### Changed
- Default `enable_thinking` flipped to `False`. Empirical testing shows `<think>` mode causes the model to narrate the aggregation prompt (`"We have a user message: ..."`) instead of solving. Direct mode produces math reasoning immediately.
- `_render_chat(messages, *, enable_thinking)` signature now takes an explicit kwarg (was hardcoded to `True`).
### Bench results
- 5/5 vanilla + 5/5 RSA on corrected HMMT subset. lift_pp +0.00pp (ceiling effect — vanilla already at 100%).
## v0.1.0 — 2026-05-10
Initial public release. T=2 N=4 RSA orchestrator with audit JSONL + CLI + HMMT bench harness.

View File

@@ -2,7 +2,7 @@
First MLX implementation of Zyphra's **Markovian RSA** test-time compute methodology, targeting **ZAYA1-8B** on Apple Silicon. Boosts reasoning accuracy by sampling N parallel reasoning traces, extracting their tails, and feeding aggregation prompts back to the model. First MLX implementation of Zyphra's **Markovian RSA** test-time compute methodology, targeting **ZAYA1-8B** on Apple Silicon. Boosts reasoning accuracy by sampling N parallel reasoning traces, extracting their tails, and feeding aggregation prompts back to the model.
> **Status :** v0.1.0. Aggregation prompt is `zaya_v1` (reverse-engineered ; paper does not publish the co-trained format). HMMT'25 5-problem smoke shows ≥ 0 pp lift on M2 Pro. > **Status :** v0.1.1. `enable_thinking=False` default ; aggregation `zaya_v1` template (reverse-engineered ; paper does not publish co-trained format). Both vanilla and RSA score 100% on the 5-problem corrected HMMT subset (ceiling effect — needs harder set for real lift measurement).
## Install ## Install
@@ -43,6 +43,17 @@ markovian-rsa-mlx solve "Compute the integral of x^2 from 0 to 5" \
| `paper-16k` | 2 | 4 | 16 K | ~ 16-24 GB | paper "deployment" profile | | `paper-16k` | 2 | 4 | 16 K | ~ 16-24 GB | paper "deployment" profile |
| `paper-headline-40k` | 2 | 16 | 40 K | 32+ GB | paper headline (HMMT'25 89.6) | | `paper-headline-40k` | 2 | 16 | 40 K | 32+ GB | paper headline (HMMT'25 89.6) |
## Bench results (HMMT'25 5-problem subset)
With the corrected placeholder dataset and `enable_thinking=False` default :
| Backend | Score | Wall time | Per-problem avg |
|---|---:|---:|---:|
| Vanilla (T=1 N=1) | 5/5 = 100% | 1085 s | 217 s |
| RSA T=2 N=2 (default-16gb) | 5/5 = 100% | 3974 s | 795 s |
`lift_pp = +0.00pp` on this subset due to ceiling effect (vanilla already hits 100%). Larger HMMT'25 / AIME'26 datasets needed to measure the real lift. The system is mechanically correct (RSA outputs reference "Approach 1, Approach 2" from aggregation prompts) ; just needs harder problems to differentiate.
## Audit JSONL ## Audit JSONL
Every event of the run is one line. Schema in Every event of the run is one line. Schema in

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "markovian-rsa-mlx" name = "markovian-rsa-mlx"
version = "0.1.0" version = "0.1.1"
description = "Markovian RSA test-time compute methodology on MLX for ZAYA1-8B and future co-trained models" description = "Markovian RSA test-time compute methodology on MLX for ZAYA1-8B and future co-trained models"
readme = "README.md" readme = "README.md"
requires-python = ">=3.12,<3.14" requires-python = ">=3.12,<3.14"

View File

@@ -21,7 +21,7 @@ _HMMT_2025_SUBSET = [
{ {
"id": "hmmt-1", "id": "hmmt-1",
"question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.", "question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.",
"answer": "100", "answer": "66",
}, },
{ {
"id": "hmmt-2", "id": "hmmt-2",
@@ -41,7 +41,7 @@ _HMMT_2025_SUBSET = [
{ {
"id": "hmmt-5", "id": "hmmt-5",
"question": "What is the remainder when 2^100 is divided by 125?", "question": "What is the remainder when 2^100 is divided by 125?",
"answer": "76", "answer": "1",
}, },
] ]

View File

@@ -1,5 +1,5 @@
"""Markovian RSA test-time compute methodology on MLX.""" """Markovian RSA test-time compute methodology on MLX."""
__version__ = "0.1.0" __version__ = "0.1.1"
from markovian_rsa_mlx.config import RSAConfig from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.loader import load_zaya_model from markovian_rsa_mlx.loader import load_zaya_model

View File

@@ -27,6 +27,7 @@ class RSAConfig:
memory_fraction: float = 0.80 memory_fraction: float = 0.80
max_context_tokens: int | None = None max_context_tokens: int | None = None
aggregation_template: Literal["zaya_v1"] = "zaya_v1" aggregation_template: Literal["zaya_v1"] = "zaya_v1"
enable_thinking: bool = False
answer_selection: Literal["first_final_candidate"] = "first_final_candidate" answer_selection: Literal["first_final_candidate"] = "first_final_candidate"
def __post_init__(self) -> None: def __post_init__(self) -> None:

View File

@@ -150,7 +150,7 @@ class MarkovianRSAOrchestrator:
parent_ids_per_trace: list[list[str]] = [] parent_ids_per_trace: list[list[str]] = []
if is_round_0: if is_round_0:
messages = build_round_0_messages(original_prompt) messages = build_round_0_messages(original_prompt)
prompt_ids = self._render_chat(messages) prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking)
prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)] prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)]
parent_ids_per_trace = [[] for _ in range(cfg.parallel)] parent_ids_per_trace = [[] for _ in range(cfg.parallel)]
else: else:
@@ -171,7 +171,7 @@ class MarkovianRSAOrchestrator:
original_prompt=original_prompt, tails=tails, original_prompt=original_prompt, tails=tails,
template=cfg.aggregation_template, template=cfg.aggregation_template,
) )
prompt_ids = self._render_chat(messages) prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking)
child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}" child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}"
audit.write(AggregationPromptEvent( audit.write(AggregationPromptEvent(
run_id=run_id, round=round_idx, trace_id=child_trace_id, run_id=run_id, round=round_idx, trace_id=child_trace_id,
@@ -217,10 +217,10 @@ class MarkovianRSAOrchestrator:
round_elapsed = time.time() - round_t0 round_elapsed = time.time() - round_t0
return records, round_elapsed return records, round_elapsed
def _render_chat(self, messages: list[dict[str, str]]) -> list[int]: def _render_chat(self, messages: list[dict[str, str]], *, enable_thinking: bool) -> list[int]:
"""Apply ZAYA chat template and return token ids.""" """Apply ZAYA chat template and return token ids."""
rendered = self.tokenizer.apply_chat_template( rendered = self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, enable_thinking=True, messages, add_generation_prompt=True, enable_thinking=enable_thinking,
) )
if isinstance(rendered, str): if isinstance(rendered, str):
return self.tokenizer.encode(rendered) return self.tokenizer.encode(rendered)

View File

@@ -7,7 +7,7 @@ runner = CliRunner()
def test_version_command_prints_version(): def test_version_command_prints_version():
result = runner.invoke(app, ["version"]) result = runner.invoke(app, ["version"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "0.1.0" in result.stdout assert "0.1.1" in result.stdout
def test_solve_help_shows_required_flags(): def test_solve_help_shows_required_flags():

View File

@@ -74,3 +74,13 @@ def test_final_tokens_zero_rejected():
def test_final_tokens_negative_rejected(): def test_final_tokens_negative_rejected():
with pytest.raises(ValueError, match="final_tokens must be positive"): with pytest.raises(ValueError, match="final_tokens must be positive"):
RSAConfig(final_tokens=-1) RSAConfig(final_tokens=-1)
def test_enable_thinking_default_false():
cfg = RSAConfig()
assert cfg.enable_thinking is False
def test_enable_thinking_can_be_enabled():
cfg = RSAConfig(enable_thinking=True)
assert cfg.enable_thinking is True

2
uv.lock generated
View File

@@ -421,7 +421,7 @@ wheels = [
[[package]] [[package]]
name = "markovian-rsa-mlx" name = "markovian-rsa-mlx"
version = "0.1.0" version = "0.1.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "huggingface-hub" }, { name = "huggingface-hub" },