From d4c241f91a00634f383c8e967ebe0f74442008b2 Mon Sep 17 00:00:00 2001
From: transcrilive
Date: Sun, 10 May 2026 03:12:43 +0200
Subject: [PATCH] feat(guards): memory/context budget checks + multi-round T>=2 orchestrator tests

check_context_budget accepts an optional formatting_overhead override
(default 200 tokens, unchanged behaviour for existing callers).

Co-Authored-By: Claude Opus 4.7 (1M context)
---
Review note: the hunks for guards.py, test_guards.py and
test_orchestrator_t2.py were edited during review, so the commit id above
and the blob ids on their "index" lines predate those edits. Regenerate
with `git format-patch` after committing if exact ids are needed.

 src/markovian_rsa_mlx/guards.py       | 116 ++++++++++++++++++++++++++
 src/markovian_rsa_mlx/orchestrator.py |   8 +-
 tests/test_guards.py                  |  76 +++++++++++++++++
 tests/test_orchestrator_t2.py         |  54 ++++++++++++
 4 files changed, 252 insertions(+), 2 deletions(-)
 create mode 100644 src/markovian_rsa_mlx/guards.py
 create mode 100644 tests/test_guards.py
 create mode 100644 tests/test_orchestrator_t2.py

diff --git a/src/markovian_rsa_mlx/guards.py b/src/markovian_rsa_mlx/guards.py
new file mode 100644
index 0000000..6c0763a
--- /dev/null
+++ b/src/markovian_rsa_mlx/guards.py
@@ -0,0 +1,116 @@
+"""Memory and context budget guards.
+
+These run before kicking off batched generation so we can either bail out
+loudly (auto_serial=False) or transparently downgrade to serial decoding.
+"""
+from __future__ import annotations
+
+from dataclasses import dataclass
+
+from markovian_rsa_mlx.config import RSAConfig
+
+# Default token allowance for chat-template wrappers around the prompt.
+DEFAULT_FORMATTING_OVERHEAD_TOKENS = 200
+
+
+@dataclass
+class BudgetDecision:
+    """Outcome of a memory budget check."""
+
+    ok: bool  # estimate fits under the threshold
+    fallback_to_serial: bool  # caller should downgrade to serial decoding
+    estimated_bytes: int  # weights + KV cache + workspace
+    limit_bytes: int  # Metal limit passed to the check
+    threshold_bytes: int  # limit scaled by cfg.memory_fraction
+
+
+def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int:
+    """Return the KV-cache bytes for ``parallel`` traces of ``tokens`` tokens."""
+    return zaya_per_token_bytes * tokens * parallel
+
+
+def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int:
+    """Return the sum of the weight, KV-cache and workspace byte counts."""
+    return weights_bytes + kv_bytes + workspace_bytes
+
+
+def check_memory_budget(
+    *,
+    cfg: RSAConfig,
+    weights_bytes: int,
+    per_token_kv_bytes: int,
+    workspace_bytes: int,
+    metal_limit_bytes: int,
+) -> BudgetDecision:
+    """Decide whether to run batched, serial, or refuse.
+
+    The estimate adds the weights, the workspace and the KV cache for
+    ``cfg.parallel`` traces of ``cfg.chunk_tokens`` tokens, and is compared
+    against ``metal_limit_bytes * cfg.memory_fraction``.
+
+    Returns:
+        A decision with ``ok=True`` when the estimate fits, or with
+        ``fallback_to_serial=True`` when it does not fit and
+        ``cfg.auto_serial`` is enabled.
+
+    Raises:
+        MemoryError: If the estimate does not fit and ``cfg.auto_serial``
+            is disabled.
+    """
+    kv = estimate_kv_bytes(
+        zaya_per_token_bytes=per_token_kv_bytes,
+        tokens=cfg.chunk_tokens,
+        parallel=cfg.parallel,
+    )
+    estimated = estimate_total_bytes(
+        weights_bytes=weights_bytes,
+        kv_bytes=kv,
+        workspace_bytes=workspace_bytes,
+    )
+    threshold = int(metal_limit_bytes * cfg.memory_fraction)
+    fits = estimated <= threshold
+    if not fits and not cfg.auto_serial:
+        raise MemoryError(
+            f"Estimated {estimated/1e9:.2f} GB exceeds {threshold/1e9:.2f} GB "
+            f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). "
+            "Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial."
+        )
+    # NOTE(review): the serial fallback is recommended without re-estimating
+    # a single-trace footprint; confirm serial decoding fits the budget.
+    return BudgetDecision(
+        ok=fits,
+        fallback_to_serial=not fits,
+        estimated_bytes=estimated,
+        limit_bytes=metal_limit_bytes,
+        threshold_bytes=threshold,
+    )
+
+
+def check_context_budget(
+    *,
+    cfg: RSAConfig,
+    prompt_tokens: int,
+    max_context_tokens: int,
+    formatting_overhead: int = DEFAULT_FORMATTING_OVERHEAD_TOKENS,
+) -> None:
+    """Raise if the estimated context requirement exceeds the model maximum.
+
+    The requirement is ``prompt_tokens`` plus ``cfg.aggregation_subsample``
+    tails of ``cfg.tail_tokens`` tokens, ``cfg.chunk_tokens`` and
+    ``formatting_overhead`` (allowance for chat-template wrappers).
+
+    Raises:
+        ValueError: If the requirement exceeds ``max_context_tokens``.
+    """
+    needed = (
+        prompt_tokens
+        + cfg.aggregation_subsample * cfg.tail_tokens
+        + cfg.chunk_tokens
+        + formatting_overhead
+    )
+    if needed > max_context_tokens:
+        raise ValueError(
+            f"Required context {needed} tokens would exceed model max "
+            f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, "
+            "or chunk_tokens."
+        )
diff --git a/src/markovian_rsa_mlx/orchestrator.py b/src/markovian_rsa_mlx/orchestrator.py
index 48824d6..c7e39e1 100644
--- a/src/markovian_rsa_mlx/orchestrator.py
+++ b/src/markovian_rsa_mlx/orchestrator.py
@@ -22,8 +22,12 @@ from markovian_rsa_mlx.results import RSAResult, RSARound, RSAStats, TraceRecord
 
 
 def _trace_seed(base_seed: int | None, run_id: str, round_index: int, trace_index: int) -> int:
-    """Deterministic seed if base_seed set, else stable from run_id."""
-    key = f"{base_seed}|{run_id}|{round_index}|{trace_index}"
+    """Deterministic seed when base_seed is set (ignores run_id);
+    otherwise uses run_id to provide stable per-call uniqueness."""
+    if base_seed is not None:
+        key = f"{base_seed}|{round_index}|{trace_index}"
+    else:
+        key = f"{run_id}|{round_index}|{trace_index}"
     h = hashlib.sha256(key.encode()).hexdigest()
     return int(h[:8], 16)
 
diff --git a/tests/test_guards.py b/tests/test_guards.py
new file mode 100644
index 0000000..1c30332
--- /dev/null
+++ b/tests/test_guards.py
@@ -0,0 +1,76 @@
+import pytest
+from markovian_rsa_mlx.guards import (
+    estimate_kv_bytes, estimate_total_bytes, check_memory_budget,
+    check_context_budget,
+)
+from markovian_rsa_mlx.config import RSAConfig
+
+
+def test_estimate_kv_bytes_for_known_zaya():
+    kv_bytes = estimate_kv_bytes(zaya_per_token_bytes=40_000, tokens=16384, parallel=4)
+    assert kv_bytes == 40_000 * 16384 * 4
+
+
+def test_total_bytes_includes_weights():
+    total = estimate_total_bytes(weights_bytes=4_700_000_000, kv_bytes=2_600_000_000, workspace_bytes=500_000_000)
+    assert total == 4_700_000_000 + 2_600_000_000 + 500_000_000
+
+
+def test_memory_check_returns_ok_when_under_limit():
+    cfg = RSAConfig(parallel=4, chunk_tokens=16384, memory_fraction=0.80)
+    decision = check_memory_budget(
+        cfg=cfg,
+        weights_bytes=4_700_000_000,
+        per_token_kv_bytes=40_000,
+        workspace_bytes=500_000_000,
+        metal_limit_bytes=24_000_000_000,
+    )
+    assert decision.ok is True
+    assert decision.fallback_to_serial is False
+
+
+def test_memory_check_recommends_serial_when_over():
+    cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=True)
+    decision = check_memory_budget(
+        cfg=cfg,
+        weights_bytes=4_700_000_000,
+        per_token_kv_bytes=40_000,
+        workspace_bytes=4_000_000_000,
+        metal_limit_bytes=16_000_000_000,
+    )
+    assert decision.ok is False
+    assert decision.fallback_to_serial is True
+
+
+def test_memory_check_raises_when_auto_serial_disabled():
+    cfg = RSAConfig(parallel=8, chunk_tokens=16384, memory_fraction=0.80, auto_serial=False)
+    with pytest.raises(MemoryError):
+        check_memory_budget(
+            cfg=cfg,
+            weights_bytes=4_700_000_000,
+            per_token_kv_bytes=40_000,
+            workspace_bytes=4_000_000_000,
+            metal_limit_bytes=16_000_000_000,
+        )
+
+
+def test_context_check_passes_within_limit():
+    cfg = RSAConfig(chunk_tokens=8000, tail_tokens=2000, aggregation_subsample=4)
+    check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)
+
+
+def test_context_check_raises_when_over():
+    cfg = RSAConfig(chunk_tokens=80000, tail_tokens=20000, aggregation_subsample=4)
+    with pytest.raises(ValueError, match="exceed"):
+        check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=131072)
+
+
+def test_context_check_honours_custom_formatting_overhead():
+    cfg = RSAConfig(chunk_tokens=8000, tail_tokens=2000, aggregation_subsample=4)
+    # 1000 + 4 * 2000 + 8000 = 17000 tokens before the template overhead.
+    check_context_budget(cfg=cfg, prompt_tokens=1000, max_context_tokens=17200)
+    with pytest.raises(ValueError, match="exceed"):
+        check_context_budget(
+            cfg=cfg, prompt_tokens=1000, max_context_tokens=17200,
+            formatting_overhead=201,
+        )
diff --git a/tests/test_orchestrator_t2.py b/tests/test_orchestrator_t2.py
new file mode 100644
index 0000000..f455d71
--- /dev/null
+++ b/tests/test_orchestrator_t2.py
@@ -0,0 +1,54 @@
+from unittest.mock import MagicMock
+
+from markovian_rsa_mlx.batching import GenerationResult
+from markovian_rsa_mlx.config import RSAConfig
+from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
+
+
+def _tok(eos_id=999):
+    t = MagicMock()
+    t.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
+    t.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
+    t.eos_token_id = eos_id
+    t.all_special_ids = [eos_id]
+    t.apply_chat_template.side_effect = lambda messages, **kw: \
+        " ".join(m["content"] for m in messages).encode()
+    return t
+
+
+def _gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
+    text = f"r-trace-{seed}"
+    ids = [ord(c) for c in text]
+    return GenerationResult(token_ids=ids, text=text, generated_tokens=len(ids),
+                            finish_reason="eos", elapsed_s=0.01)
+
+
+def test_t2_two_rounds_aggregation_sees_K_tails(tmp_path):
+    cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
+                    chunk_tokens=64, tail_tokens=4, serial=True, seed=7)
+    orch = MarkovianRSAOrchestrator(
+        model=MagicMock(), tokenizer=_tok(),
+        model_id="m", quantization="bf16",
+        single_generate=_gen, batch_generate=None,
+    )
+    audit = tmp_path / "audit.jsonl"
+    _, result = orch.solve("Solve this", config=cfg, return_audit=True, audit_path=audit)
+    assert len(result.rounds) == 2
+    assert len(result.rounds[1].traces) == 2
+    # round 1 traces have parent links
+    assert all(t.parent_trace_ids for t in result.rounds[1].traces)
+
+
+def test_deterministic_seeds_are_stable(tmp_path):
+    cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
+                    chunk_tokens=8, tail_tokens=2, serial=True, seed=42)
+    orch = MarkovianRSAOrchestrator(
+        model=MagicMock(), tokenizer=_tok(),
+        model_id="m", quantization="bf16",
+        single_generate=_gen, batch_generate=None,
+    )
+    _, r1 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "a.jsonl")
+    _, r2 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "b.jsonl")
+    seeds_1 = [t.seed for r in r1.rounds for t in r.traces]
+    seeds_2 = [t.seed for r in r2.rounds for t in r.traces]
+    assert seeds_1 == seeds_2