# markovian-rsa-mlx

First MLX implementation of Zyphra's **Markovian RSA** test-time compute methodology, targeting **ZAYA1-8B** on Apple Silicon. The method aims to boost reasoning accuracy by sampling N parallel reasoning traces, extracting the tail of each trace, and feeding those tails back to the model in aggregation prompts over T rounds.

> **Status:** v0.1.1. `enable_thinking=False` is the default. Aggregation uses the `zaya_v1` template, which is reverse-engineered because the paper does not publish the co-trained format. Both vanilla and RSA score 100% on the corrected 5-problem HMMT subset; this is a ceiling effect, so a harder problem set is needed to measure the real lift.
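The sample, tail, aggregate loop described above can be sketched in plain Python. This is a simplified illustration, not the package's orchestrator. Assumptions: `generate` is a hypothetical stand-in for a model call, tails are cut by characters rather than tokens, every round aggregates all N tails, `rounds` counts round 0, and the prompt format is illustrative rather than the real `zaya_v1` template.

```python
from typing import Callable, List

# Hypothetical stand-in for a model call: prompt in, one sampled completion out.
Generate = Callable[[str], str]


def tail(trace: str, n_chars: int) -> str:
    """Keep the last n_chars of a trace (character-level stand-in for a token tail)."""
    return trace[-n_chars:] if n_chars > 0 else ""


def aggregation_prompt(problem: str, tails: List[str]) -> str:
    """Illustrative aggregation prompt; NOT the package's `zaya_v1` template."""
    parts = [f"Problem: {problem}", ""]
    for i, t in enumerate(tails, start=1):
        parts.append(f"Approach {i}:\n{t}\n")
    parts.append("Combine the approaches above into a single, improved solution.")
    return "\n".join(parts)


def rsa_sketch(problem: str, generate: Generate, rounds: int, parallel: int,
               tail_chars: int) -> List[str]:
    """Simplified RSA loop; assumes `rounds` counts round 0."""
    # Round 0: N independent traces from the bare problem.
    traces = [generate(problem) for _ in range(parallel)]
    # Later rounds: aggregate the tails of the previous round, then resample N traces.
    for _ in range(1, rounds):
        prompt = aggregation_prompt(problem, [tail(t, tail_chars) for t in traces])
        traces = [generate(prompt) for _ in range(parallel)]
    return traces
```

In this sketch, `rounds=2, parallel=2` means four `generate` calls: two for round 0 and two aggregation calls.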
## Install

```bash
uv add "markovian-rsa-mlx @ git+https://gitea.tavportal.com/olivier/markovian-rsa-mlx.git"
```

This automatically pulls in `mlx-lm` from kyr0's `feat/zaya-support` branch, until upstream PR #1261 is merged.
## Quickstart

Python API:

```python
from markovian_rsa_mlx import MarkovianRSAOrchestrator, RSAConfig

# Load the MLX-converted ZAYA1 weights.
orch = MarkovianRSAOrchestrator.from_pretrained("kyr0/zaya1-base-8b-MLX")
cfg = RSAConfig.default_16gb()  # rounds=2, parallel=2, chunk=16K; fits a 16 GB Mac

# return_audit=True also returns the audit record; audit_path sets the JSONL log file.
text, audit = orch.solve(
    "Compute the integral of x^2 from 0 to 5",
    config=cfg,
    return_audit=True,
    audit_path="run.jsonl",
)
print(text)
```

CLI:

```bash
markovian-rsa-mlx solve "Compute the integral of x^2 from 0 to 5" \
  --profile default-16gb --audit run.jsonl
```
## Profiles

| Profile | Rounds (T) | Parallel (N) | Chunk | Memory | Notes |
|---|---:|---:|---:|---:|---|
| `default-16gb` | 2 | 2 | 16K | ~8 GB | safest choice on a 16 GB M2 |
| `paper-16k` | 2 | 4 | 16K | ~16-24 GB | paper "deployment" profile |
| `paper-headline-40k` | 2 | 16 | 40K | 32+ GB | paper headline result (HMMT'25: 89.6) |
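The table can also be read as data for choosing a profile from the memory you have available. This is a hypothetical helper, not part of the package: `Profile`, `PROFILES`, and `pick_profile` are illustrative names, and the thresholds are a conservative reading of the Memory column (the upper end of the `paper-16k` range, the stated floor for `paper-headline-40k`). The returned `name` is the value you would pass to `--profile`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Profile:
    name: str      # value for the CLI's --profile flag
    rounds: int    # T
    parallel: int  # N
    chunk: str     # as written in the table ("16K", "40K")
    mem_gb: float  # hypothetical conservative memory requirement


PROFILES = [
    Profile("default-16gb", 2, 2, "16K", 8),
    Profile("paper-16k", 2, 4, "16K", 24),
    Profile("paper-headline-40k", 2, 16, "40K", 32),
]


def pick_profile(available_gb: float) -> Optional[Profile]:
    """Return the most parallel profile whose memory estimate fits, or None."""
    fitting = [p for p in PROFILES if p.mem_gb <= available_gb]
    return max(fitting, key=lambda p: p.parallel, default=None)
```

For example, `pick_profile(16).name` gives `"default-16gb"` on a 16 GB machine.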
## Bench results (HMMT'25 5-problem subset)

Results with the corrected placeholder dataset and the default `enable_thinking=False`:

| Backend | Score | Wall time (s) | Avg per problem (s) |
|---|---:|---:|---:|
| Vanilla (T=1, N=1) | 5/5 = 100% | 1085 | 217 |
| RSA T=2, N=2 (`default-16gb`) | 5/5 = 100% | 3974 | 795 |

On this subset `lift_pp = +0.00pp` because of a ceiling effect: vanilla already scores 100%. Measuring the real lift requires larger HMMT'25 / AIME'26 problem sets. The pipeline appears mechanically correct (RSA outputs reference "Approach 1", "Approach 2" from the aggregation prompts); it just needs harder problems to separate the two backends.
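The headline numbers can be recomputed from the table. This sketch assumes `lift_pp` is RSA accuracy minus vanilla accuracy in percentage points; the bench script's exact definition may differ.

```python
N_PROBLEMS = 5
# (correct answers, total wall time in seconds), copied from the results table.
VANILLA = (5, 1085)
RSA = (5, 3974)


def accuracy_pct(correct: int) -> float:
    return 100.0 * correct / N_PROBLEMS


lift_pp = accuracy_pct(RSA[0]) - accuracy_pct(VANILLA[0])  # percentage points
slowdown = RSA[1] / VANILLA[1]                              # wall-time ratio

print(f"lift_pp = {lift_pp:+.2f}pp")  # lift_pp = +0.00pp
print(f"slowdown = {slowdown:.2f}x")  # slowdown = 3.66x
print(f"per problem: {VANILLA[1] / N_PROBLEMS:.0f} s vs {RSA[1] / N_PROBLEMS:.0f} s")  # per problem: 217 s vs 795 s
```

On this subset RSA costs roughly 3.66x the wall time of vanilla for zero measured lift, which is why a harder problem set is needed.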
## Audit JSONL

Every event of a run is written as one JSON line. The event schema is defined in Section 2 of [`docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md`](docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md).
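Because each line is a standalone JSON object, an audit log can be inspected with the standard library alone. This is a minimal sketch, not a tool shipped with the package; `summarize_audit` is a hypothetical name, and it assumes the event kind is stored under a `"type"` key, which you should check against the schema in the design spec.

```python
import json
from collections import Counter
from pathlib import Path


def summarize_audit(path) -> Counter:
    """Count events per type in a JSONL audit log (one JSON object per line)."""
    counts: Counter = Counter()
    with Path(path).open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            event = json.loads(line)
            # ASSUMPTION: the event kind lives under a "type" key.
            counts[event.get("type", "<unknown>")] += 1
    return counts
```

For example, `summarize_audit("run.jsonl").most_common()` lists event kinds from the Quickstart run, most frequent first.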
## Bench

```bash
uv run python scripts/bench_hmmt.py --n-problems 5 --rounds 2 --parallel 4 \
  --output bench-out/hmmt_smoke.json
```
## Architecture

- `orchestrator.py`: drives N parallel traces over T rounds.
- `prompts.py`: the round-0 prompt and the `zaya_v1` aggregation template.
- `batching.py`: dispatches generation between the serial path and the `BatchGenerator` path.
- `audit.py`: streaming JSONL writer and event types.
- `guards.py`: memory and context budget checks.
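The kind of context budget check that `guards.py` performs can be sketched as a pure function. Everything here is an assumption for illustration: the `ContextBudget` fields, the helper names, the model of an aggregation prompt as problem plus fixed template overhead plus N tails, and the example numbers. The real guard may compute its budget differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ContextBudget:
    context_window: int  # hypothetical max tokens the model accepts (prompt + generation)
    chunk: int           # tokens generated per trace per round
    tail: int            # tokens kept from each previous trace


def aggregation_prompt_tokens(problem: int, template: int, parallel: int,
                              b: ContextBudget) -> int:
    """Token count of an aggregation prompt holding one tail per parallel trace."""
    return problem + template + parallel * b.tail


def check_context_budget(problem: int, template: int, parallel: int,
                         b: ContextBudget) -> None:
    """Raise before generation if the prompt plus one chunk would overflow the window."""
    needed = aggregation_prompt_tokens(problem, template, parallel, b) + b.chunk
    if needed > b.context_window:
        raise ValueError(
            f"needs {needed} tokens but the context window is {b.context_window}; "
            "lower parallel, tail, or chunk"
        )
```

Failing fast like this, before any tokens are generated, avoids discovering an overflow only after a long round has already run.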
## License

MIT. See [LICENSE](LICENSE).

Model weights are governed by Zyphra's upstream license; see [`Zyphra/ZAYA1-8B`](https://huggingface.co/Zyphra/ZAYA1-8B).
## Provenance

The spec was produced through a two-round brainstorming session with Codex (gpt-5.5 xhigh). Implementation by Olivier Dupont, with code-review assistance.