feat(guards): memory/context budget checks + multi-round T>=2 orchestrator tests
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
72
src/markovian_rsa_mlx/guards.py
Normal file
72
src/markovian_rsa_mlx/guards.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Memory and context budget guards.
|
||||
|
||||
These run before kicking off batched generation so we can either bail out
|
||||
loudly (auto_serial=False) or transparently downgrade to serial decoding.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class BudgetDecision:
|
||||
ok: bool
|
||||
fallback_to_serial: bool
|
||||
estimated_bytes: int
|
||||
limit_bytes: int
|
||||
threshold_bytes: int
|
||||
|
||||
|
||||
def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int:
|
||||
return zaya_per_token_bytes * tokens * parallel
|
||||
|
||||
|
||||
def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int:
|
||||
return weights_bytes + kv_bytes + workspace_bytes
|
||||
|
||||
|
||||
def check_memory_budget(
|
||||
*,
|
||||
cfg: RSAConfig,
|
||||
weights_bytes: int,
|
||||
per_token_kv_bytes: int,
|
||||
workspace_bytes: int,
|
||||
metal_limit_bytes: int,
|
||||
) -> BudgetDecision:
|
||||
"""Decide whether to run batched, serial, or refuse."""
|
||||
kv = estimate_kv_bytes(
|
||||
zaya_per_token_bytes=per_token_kv_bytes,
|
||||
tokens=cfg.chunk_tokens, parallel=cfg.parallel,
|
||||
)
|
||||
estimated = estimate_total_bytes(weights_bytes=weights_bytes, kv_bytes=kv, workspace_bytes=workspace_bytes)
|
||||
threshold = int(metal_limit_bytes * cfg.memory_fraction)
|
||||
if estimated <= threshold:
|
||||
return BudgetDecision(ok=True, fallback_to_serial=False,
|
||||
estimated_bytes=estimated, limit_bytes=metal_limit_bytes,
|
||||
threshold_bytes=threshold)
|
||||
if cfg.auto_serial:
|
||||
return BudgetDecision(ok=False, fallback_to_serial=True,
|
||||
estimated_bytes=estimated, limit_bytes=metal_limit_bytes,
|
||||
threshold_bytes=threshold)
|
||||
raise MemoryError(
|
||||
f"Estimated {estimated/1e9:.2f} GB exceeds {threshold/1e9:.2f} GB "
|
||||
f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). "
|
||||
"Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial."
|
||||
)
|
||||
|
||||
|
||||
def check_context_budget(*, cfg: RSAConfig, prompt_tokens: int, max_context_tokens: int) -> None:
|
||||
formatting_overhead = 200 # chat template wrappers
|
||||
needed = (
|
||||
prompt_tokens
|
||||
+ cfg.aggregation_subsample * cfg.tail_tokens
|
||||
+ cfg.chunk_tokens
|
||||
+ formatting_overhead
|
||||
)
|
||||
if needed > max_context_tokens:
|
||||
raise ValueError(
|
||||
f"Required context {needed} tokens would exceed model max "
|
||||
f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, "
|
||||
"or chunk_tokens."
|
||||
)
|
||||
Reference in New Issue
Block a user