feat(guards): memory/context budget checks + multi-round T>=2 orchestrator tests
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
72
src/markovian_rsa_mlx/guards.py
Normal file
72
src/markovian_rsa_mlx/guards.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Memory and context budget guards.
|
||||
|
||||
These run before kicking off batched generation so we can either bail out
|
||||
loudly (auto_serial=False) or transparently downgrade to serial decoding.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BudgetDecision:
|
||||
ok: bool
|
||||
fallback_to_serial: bool
|
||||
estimated_bytes: int
|
||||
limit_bytes: int
|
||||
threshold_bytes: int
|
||||
|
||||
|
||||
def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int:
|
||||
return zaya_per_token_bytes * tokens * parallel
|
||||
|
||||
|
||||
def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int:
|
||||
return weights_bytes + kv_bytes + workspace_bytes
|
||||
|
||||
|
||||
def check_memory_budget(
|
||||
*,
|
||||
cfg: RSAConfig,
|
||||
weights_bytes: int,
|
||||
per_token_kv_bytes: int,
|
||||
workspace_bytes: int,
|
||||
metal_limit_bytes: int,
|
||||
) -> BudgetDecision:
|
||||
"""Decide whether to run batched, serial, or refuse."""
|
||||
kv = estimate_kv_bytes(
|
||||
zaya_per_token_bytes=per_token_kv_bytes,
|
||||
tokens=cfg.chunk_tokens, parallel=cfg.parallel,
|
||||
)
|
||||
estimated = estimate_total_bytes(weights_bytes=weights_bytes, kv_bytes=kv, workspace_bytes=workspace_bytes)
|
||||
threshold = int(metal_limit_bytes * cfg.memory_fraction)
|
||||
if estimated <= threshold:
|
||||
return BudgetDecision(ok=True, fallback_to_serial=False,
|
||||
estimated_bytes=estimated, limit_bytes=metal_limit_bytes,
|
||||
threshold_bytes=threshold)
|
||||
if cfg.auto_serial:
|
||||
return BudgetDecision(ok=False, fallback_to_serial=True,
|
||||
estimated_bytes=estimated, limit_bytes=metal_limit_bytes,
|
||||
threshold_bytes=threshold)
|
||||
raise MemoryError(
|
||||
f"Estimated {estimated/1e9:.2f} GB exceeds {threshold/1e9:.2f} GB "
|
||||
f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). "
|
||||
"Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial."
|
||||
)
|
||||
|
||||
|
||||
def check_context_budget(*, cfg: RSAConfig, prompt_tokens: int, max_context_tokens: int) -> None:
|
||||
formatting_overhead = 200 # chat template wrappers
|
||||
needed = (
|
||||
prompt_tokens
|
||||
+ cfg.aggregation_subsample * cfg.tail_tokens
|
||||
+ cfg.chunk_tokens
|
||||
+ formatting_overhead
|
||||
)
|
||||
if needed > max_context_tokens:
|
||||
raise ValueError(
|
||||
f"Required context {needed} tokens would exceed model max "
|
||||
f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, "
|
||||
"or chunk_tokens."
|
||||
)
|
||||
@@ -22,8 +22,12 @@ from markovian_rsa_mlx.results import RSAResult, RSARound, RSAStats, TraceRecord
|
||||
|
||||
|
||||
def _trace_seed(base_seed: int | None, run_id: str, round_index: int, trace_index: int) -> int:
|
||||
"""Deterministic seed if base_seed set, else stable from run_id."""
|
||||
key = f"{base_seed}|{run_id}|{round_index}|{trace_index}"
|
||||
"""Deterministic seed when base_seed is set (ignores run_id);
|
||||
otherwise uses run_id to provide stable per-call uniqueness."""
|
||||
if base_seed is not None:
|
||||
key = f"{base_seed}|{round_index}|{trace_index}"
|
||||
else:
|
||||
key = f"{run_id}|{round_index}|{trace_index}"
|
||||
h = hashlib.sha256(key.encode()).hexdigest()
|
||||
return int(h[:8], 16)
|
||||
|
||||
|
||||
65
tests/test_guards.py
Normal file
65
tests/test_guards.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
from markovian_rsa_mlx.guards import (
|
||||
estimate_kv_bytes, estimate_total_bytes, check_memory_budget,
|
||||
check_context_budget,
|
||||
)
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
|
||||
|
||||
def test_estimate_kv_bytes_for_known_zaya():
|
||||
bytes_per_token = estimate_kv_bytes(zaya_per_token_bytes=40_000, tokens=16384, parallel=4)
|
||||
assert bytes_per_token == 40_000 * 16384 * 4
|
||||
|
||||
|
||||
def test_total_bytes_includes_weights():
|
||||
total = estimate_total_bytes(weights_bytes=4_700_000_000, kv_bytes=2_600_000_000, workspace_bytes=500_000_000)
|
||||
assert total == 4_700_000_000 + 2_600_000_000 + 500_000_000
|
||||
|
||||
|
||||
def test_memory_check_returns_ok_when_under_limit():
|
||||
cfg = RSAConfig(parallel=4, chunk_tokens=16384, memory_fraction=0.80)
|
||||
decision = check_memory_budget(
|
||||
cfg=cfg,
|
||||
weights_bytes=4_700_000_000,
|
||||
per_token_kv_bytes=40_000,
|
||||
workspace_bytes=500_000_000,
|
||||
metal_limit_bytes=24_000_000_000,
|
||||
)
|
||||
assert decision.ok is True
|
||||
assert decision.fallback_to_serial is False
|
||||
|
||||
|
||||
def test_memory_check_recommends_serial_when_over():
|
||||
cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=True)
|
||||
decision = check_memory_budget(
|
||||
cfg=cfg,
|
||||
weights_bytes=4_700_000_000,
|
||||
per_token_kv_bytes=40_000,
|
||||
workspace_bytes=4_000_000_000,
|
||||
metal_limit_bytes=16_000_000_000,
|
||||
)
|
||||
assert decision.ok is False
|
||||
assert decision.fallback_to_serial is True
|
||||
|
||||
|
||||
def test_memory_check_raises_when_auto_serial_disabled():
|
||||
cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=False)
|
||||
with pytest.raises(MemoryError):
|
||||
check_memory_budget(
|
||||
cfg=cfg,
|
||||
weights_bytes=4_700_000_000,
|
||||
per_token_kv_bytes=40_000,
|
||||
workspace_bytes=4_000_000_000,
|
||||
metal_limit_bytes=16_000_000_000,
|
||||
)
|
||||
|
||||
|
||||
def test_context_check_passes_within_limit():
|
||||
cfg = RSAConfig(chunk_tokens=8000, tail_tokens=2000, aggregation_subsample=4)
|
||||
check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)
|
||||
|
||||
|
||||
def test_context_check_raises_when_over():
|
||||
cfg = RSAConfig(chunk_tokens=80000, tail_tokens=20000, aggregation_subsample=4)
|
||||
with pytest.raises(ValueError, match="exceed"):
|
||||
check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)
|
||||
54
tests/test_orchestrator_t2.py
Normal file
54
tests/test_orchestrator_t2.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from markovian_rsa_mlx.batching import GenerationResult
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||
|
||||
|
||||
def _tok(eos_id=999):
|
||||
t = MagicMock()
|
||||
t.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
|
||||
t.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
|
||||
t.eos_token_id = eos_id
|
||||
t.all_special_ids = [eos_id]
|
||||
t.apply_chat_template.side_effect = lambda messages, **kw: \
|
||||
" ".join(m["content"] for m in messages).encode()
|
||||
return t
|
||||
|
||||
|
||||
def _gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
|
||||
text = f"r-trace-{seed}"
|
||||
ids = [ord(c) for c in text]
|
||||
return GenerationResult(token_ids=ids, text=text, generated_tokens=len(ids),
|
||||
finish_reason="eos", elapsed_s=0.01)
|
||||
|
||||
|
||||
def test_t2_two_rounds_aggregation_sees_K_tails(tmp_path):
|
||||
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
|
||||
chunk_tokens=64, tail_tokens=4, serial=True, seed=7)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(), tokenizer=_tok(),
|
||||
model_id="m", quantization="bf16",
|
||||
single_generate=_gen, batch_generate=None,
|
||||
)
|
||||
audit = tmp_path / "audit.jsonl"
|
||||
text, result = orch.solve("Solve this", config=cfg, return_audit=True, audit_path=audit)
|
||||
assert len(result.rounds) == 2
|
||||
assert len(result.rounds[1].traces) == 2
|
||||
# round 1 traces have parent links
|
||||
assert all(t.parent_trace_ids for t in result.rounds[1].traces)
|
||||
|
||||
|
||||
def test_deterministic_seeds_are_stable(tmp_path):
|
||||
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
|
||||
chunk_tokens=8, tail_tokens=2, serial=True, seed=42)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(), tokenizer=_tok(),
|
||||
model_id="m", quantization="bf16",
|
||||
single_generate=_gen, batch_generate=None,
|
||||
)
|
||||
_, r1 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "a.jsonl")
|
||||
_, r2 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "b.jsonl")
|
||||
seeds_1 = [t.seed for r in r1.rounds for t in r.traces]
|
||||
seeds_2 = [t.seed for r in r2.rounds for t in r.traces]
|
||||
assert seeds_1 == seeds_2
|
||||
Reference in New Issue
Block a user