"""Memory and context budget guards. These run before kicking off batched generation so we can either bail out loudly (auto_serial=False) or transparently downgrade to serial decoding. """ from __future__ import annotations from dataclasses import dataclass from markovian_rsa_mlx.config import RSAConfig @dataclass class BudgetDecision: ok: bool fallback_to_serial: bool estimated_bytes: int limit_bytes: int threshold_bytes: int def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int: return zaya_per_token_bytes * tokens * parallel def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int: return weights_bytes + kv_bytes + workspace_bytes def check_memory_budget( *, cfg: RSAConfig, weights_bytes: int, per_token_kv_bytes: int, workspace_bytes: int, metal_limit_bytes: int, ) -> BudgetDecision: """Decide whether to run batched, serial, or refuse.""" kv = estimate_kv_bytes( zaya_per_token_bytes=per_token_kv_bytes, tokens=cfg.chunk_tokens, parallel=cfg.parallel, ) estimated = estimate_total_bytes(weights_bytes=weights_bytes, kv_bytes=kv, workspace_bytes=workspace_bytes) threshold = int(metal_limit_bytes * cfg.memory_fraction) if estimated <= threshold: return BudgetDecision(ok=True, fallback_to_serial=False, estimated_bytes=estimated, limit_bytes=metal_limit_bytes, threshold_bytes=threshold) if cfg.auto_serial: return BudgetDecision(ok=False, fallback_to_serial=True, estimated_bytes=estimated, limit_bytes=metal_limit_bytes, threshold_bytes=threshold) raise MemoryError( f"Estimated {estimated/1e9:.2f} GB exceeds {threshold/1e9:.2f} GB " f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). " "Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial." ) def check_context_budget(*, cfg: RSAConfig, prompt_tokens: int, max_context_tokens: int) -> None: formatting_overhead = 200 # chat template wrappers needed = ( prompt_tokens + cfg.aggregation_subsample * cfg.tail_tokens + cfg.chunk_tokens + formatting_overhead ) if needed > max_context_tokens: raise ValueError( f"Required context {needed} tokens would exceed model max " f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, " "or chunk_tokens." )