Files
markovian-rsa-mlx/src/markovian_rsa_mlx/guards.py
2026-05-10 03:12:43 +02:00

73 lines
2.6 KiB
Python

"""Memory and context budget guards.
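These checks run before batched generation starts. If the estimated memory
footprint exceeds the budget, the memory guard either raises MemoryError
(auto_serial=False) or signals a transparent downgrade to serial decoding
(auto_serial=True). The context guard raises ValueError when the required
context would exceed the model's maximum.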
"""
from __future__ import annotations
from dataclasses import dataclass
from markovian_rsa_mlx.config import RSAConfig
@dataclass
class BudgetDecision:
    """Outcome of a memory budget check, as produced by ``check_memory_budget``.

    All byte quantities are plain integers. The field order is part of the
    positional constructor signature and must not be rearranged.
    """

    # True when the estimated footprint is within the threshold.
    ok: bool
    # True when the estimate is over the threshold and the caller should
    # decode serially instead of in a batch.
    fallback_to_serial: bool
    # Estimated footprint: weights + KV cache + workspace.
    estimated_bytes: int
    # The raw memory limit the check was run against.
    limit_bytes: int
    # The limit scaled by the configured memory fraction; the value the
    # estimate is actually compared with.
    threshold_bytes: int
def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int:
    """Return the KV-cache size estimate for ``parallel`` sequences.

    The estimate is the per-token cost multiplied by the number of tokens per
    sequence and by the number of sequences held at once.
    """
    bytes_per_sequence = zaya_per_token_bytes * tokens
    return bytes_per_sequence * parallel
def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int:
    """Return the combined footprint of weights, KV cache and workspace."""
    components = (weights_bytes, kv_bytes, workspace_bytes)
    return sum(components)
def check_memory_budget(
    *,
    cfg: RSAConfig,
    weights_bytes: int,
    per_token_kv_bytes: int,
    workspace_bytes: int,
    metal_limit_bytes: int,
) -> BudgetDecision:
    """Decide whether to run batched, fall back to serial, or refuse.

    The estimated footprint (weights + KV cache for ``cfg.parallel`` sequences
    of ``cfg.chunk_tokens`` tokens + workspace) is compared against
    ``metal_limit_bytes * cfg.memory_fraction``, truncated to an integer.

    Returns:
        A ``BudgetDecision`` with ``ok=True`` when the estimate fits, or with
        ``ok=False`` and ``fallback_to_serial=True`` when it does not fit and
        ``cfg.auto_serial`` is set.

    Raises:
        MemoryError: The estimate exceeds the threshold and
            ``cfg.auto_serial`` is not set.
    """
    kv_bytes = estimate_kv_bytes(
        zaya_per_token_bytes=per_token_kv_bytes,
        tokens=cfg.chunk_tokens,
        parallel=cfg.parallel,
    )
    estimated = estimate_total_bytes(
        weights_bytes=weights_bytes,
        kv_bytes=kv_bytes,
        workspace_bytes=workspace_bytes,
    )
    threshold = int(metal_limit_bytes * cfg.memory_fraction)
    fits = estimated <= threshold

    # Guard clause: over budget with no permission to downgrade is an error.
    # ``cfg.auto_serial`` is consulted only when the estimate does not fit.
    if not fits and not cfg.auto_serial:
        raise MemoryError(
            f"Estimated {estimated/1e9:.2f} GB exceeds {threshold/1e9:.2f} GB "
            f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). "
            "Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial."
        )

    # Either the estimate fits (batched run) or we are over budget with
    # auto_serial enabled (serial fallback); the two flags are complementary.
    return BudgetDecision(
        ok=fits,
        fallback_to_serial=not fits,
        estimated_bytes=estimated,
        limit_bytes=metal_limit_bytes,
        threshold_bytes=threshold,
    )
def check_context_budget(
    *,
    cfg: RSAConfig,
    prompt_tokens: int,
    max_context_tokens: int,
    formatting_overhead: int = 200,
) -> None:
    """Raise if the required context would exceed the model's maximum.

    The required context is the sum of the prompt, ``cfg.aggregation_subsample``
    tails of ``cfg.tail_tokens`` tokens each, one chunk of ``cfg.chunk_tokens``
    tokens, and a fixed formatting allowance.

    Args:
        cfg: Configuration supplying ``aggregation_subsample``, ``tail_tokens``
            and ``chunk_tokens``.
        prompt_tokens: Number of tokens in the prompt.
        max_context_tokens: Maximum context length supported by the model.
        formatting_overhead: Tokens reserved for chat template wrappers.
            Defaults to 200 (the previously hard-coded value); override it for
            models whose template is notably larger or smaller.

    Raises:
        ValueError: The required context exceeds ``max_context_tokens``.
    """
    tail_budget = cfg.aggregation_subsample * cfg.tail_tokens
    needed = prompt_tokens + tail_budget + cfg.chunk_tokens + formatting_overhead
    if needed > max_context_tokens:
        raise ValueError(
            f"Required context {needed} tokens would exceed model max "
            f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, "
            "or chunk_tokens."
        )