feat(guards): memory/context budget checks + multi-round T>=2 orchestrator tests

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
transcrilive
2026-05-10 03:12:43 +02:00
parent 86ccbe53e1
commit d4c241f91a
4 changed files with 197 additions and 2 deletions

65
tests/test_guards.py Normal file
View File

@@ -0,0 +1,65 @@
import pytest
from markovian_rsa_mlx.guards import (
estimate_kv_bytes, estimate_total_bytes, check_memory_budget,
check_context_budget,
)
from markovian_rsa_mlx.config import RSAConfig
def test_estimate_kv_bytes_for_known_zaya():
bytes_per_token = estimate_kv_bytes(zaya_per_token_bytes=40_000, tokens=16384, parallel=4)
assert bytes_per_token == 40_000 * 16384 * 4
def test_total_bytes_includes_weights():
total = estimate_total_bytes(weights_bytes=4_700_000_000, kv_bytes=2_600_000_000, workspace_bytes=500_000_000)
assert total == 4_700_000_000 + 2_600_000_000 + 500_000_000
def test_memory_check_returns_ok_when_under_limit():
cfg = RSAConfig(parallel=4, chunk_tokens=16384, memory_fraction=0.80)
decision = check_memory_budget(
cfg=cfg,
weights_bytes=4_700_000_000,
per_token_kv_bytes=40_000,
workspace_bytes=500_000_000,
metal_limit_bytes=24_000_000_000,
)
assert decision.ok is True
assert decision.fallback_to_serial is False
def test_memory_check_recommends_serial_when_over():
cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=True)
decision = check_memory_budget(
cfg=cfg,
weights_bytes=4_700_000_000,
per_token_kv_bytes=40_000,
workspace_bytes=4_000_000_000,
metal_limit_bytes=16_000_000_000,
)
assert decision.ok is False
assert decision.fallback_to_serial is True
def test_memory_check_raises_when_auto_serial_disabled():
cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=False)
with pytest.raises(MemoryError):
check_memory_budget(
cfg=cfg,
weights_bytes=4_700_000_000,
per_token_kv_bytes=40_000,
workspace_bytes=4_000_000_000,
metal_limit_bytes=16_000_000_000,
)
def test_context_check_passes_within_limit():
cfg = RSAConfig(chunk_tokens=8000, tail_tokens=2000, aggregation_subsample=4)
check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)
def test_context_check_raises_when_over():
cfg = RSAConfig(chunk_tokens=80000, tail_tokens=20000, aggregation_subsample=4)
with pytest.raises(ValueError, match="exceed"):
check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)

View File

@@ -0,0 +1,54 @@
from unittest.mock import MagicMock
from markovian_rsa_mlx.batching import GenerationResult
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
def _tok(eos_id=999):
t = MagicMock()
t.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
t.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
t.eos_token_id = eos_id
t.all_special_ids = [eos_id]
t.apply_chat_template.side_effect = lambda messages, **kw: \
" ".join(m["content"] for m in messages).encode()
return t
def _gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
text = f"r-trace-{seed}"
ids = [ord(c) for c in text]
return GenerationResult(token_ids=ids, text=text, generated_tokens=len(ids),
finish_reason="eos", elapsed_s=0.01)
def test_t2_two_rounds_aggregation_sees_K_tails(tmp_path):
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
chunk_tokens=64, tail_tokens=4, serial=True, seed=7)
orch = MarkovianRSAOrchestrator(
model=MagicMock(), tokenizer=_tok(),
model_id="m", quantization="bf16",
single_generate=_gen, batch_generate=None,
)
audit = tmp_path / "audit.jsonl"
text, result = orch.solve("Solve this", config=cfg, return_audit=True, audit_path=audit)
assert len(result.rounds) == 2
assert len(result.rounds[1].traces) == 2
# round 1 traces have parent links
assert all(t.parent_trace_ids for t in result.rounds[1].traces)
def test_deterministic_seeds_are_stable(tmp_path):
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
chunk_tokens=8, tail_tokens=2, serial=True, seed=42)
orch = MarkovianRSAOrchestrator(
model=MagicMock(), tokenizer=_tok(),
model_id="m", quantization="bf16",
single_generate=_gen, batch_generate=None,
)
_, r1 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "a.jsonl")
_, r2 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "b.jsonl")
seeds_1 = [t.seed for r in r1.rounds for t in r.traces]
seeds_2 = [t.seed for r in r2.rounds for t in r.traces]
assert seeds_1 == seeds_2