73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
"""Memory and context budget guards.

These run before kicking off batched generation so we can either bail out
loudly (auto_serial=False) or transparently downgrade to serial decoding.
"""
|
|
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
|
|
from markovian_rsa_mlx.config import RSAConfig
|
|
|
|
|
|
@dataclass
class BudgetDecision:
    """Result of a memory budget check, as returned by ``check_memory_budget``."""

    # True when the estimate is at or below the threshold.
    ok: bool
    # True when the estimate is over the threshold and the caller should
    # downgrade to serial decoding instead of failing.
    fallback_to_serial: bool
    # Estimated total footprint (weights + KV + workspace), in bytes.
    estimated_bytes: int
    # The Metal memory limit the check was given, in bytes.
    limit_bytes: int
    # The portion of the limit that the estimate is compared against, in bytes.
    threshold_bytes: int
|
|
|
|
|
|
def estimate_kv_bytes(*, zaya_per_token_bytes: int, tokens: int, parallel: int) -> int:
    """Return the KV byte estimate for ``parallel`` sequences of ``tokens`` tokens.

    The estimate is the plain product of the per-token byte cost, the token
    count, and the number of parallel sequences.
    """
    bytes_per_sequence = zaya_per_token_bytes * tokens
    return bytes_per_sequence * parallel
|
|
|
|
|
|
def estimate_total_bytes(*, weights_bytes: int, kv_bytes: int, workspace_bytes: int) -> int:
    """Return the combined footprint: weights plus KV plus workspace bytes."""
    components = (weights_bytes, kv_bytes, workspace_bytes)
    return sum(components)
|
|
|
|
|
|
def check_memory_budget(
    *,
    cfg: RSAConfig,
    weights_bytes: int,
    per_token_kv_bytes: int,
    workspace_bytes: int,
    metal_limit_bytes: int,
) -> BudgetDecision:
    """Decide whether to run batched, fall back to serial, or refuse.

    The estimated footprint is weights + KV + workspace, with KV sized for
    ``cfg.chunk_tokens`` tokens across ``cfg.parallel`` sequences. It is
    compared against ``metal_limit_bytes * cfg.memory_fraction`` (truncated
    to an int).

    Returns:
        A ``BudgetDecision`` with ``ok=True`` when the estimate fits, or with
        ``ok=False, fallback_to_serial=True`` when it does not fit and
        ``cfg.auto_serial`` is truthy.

    Raises:
        MemoryError: The estimate exceeds the threshold and ``cfg.auto_serial``
            is falsy.
    """
    kv_bytes = estimate_kv_bytes(
        zaya_per_token_bytes=per_token_kv_bytes,
        tokens=cfg.chunk_tokens,
        parallel=cfg.parallel,
    )
    projected = estimate_total_bytes(
        weights_bytes=weights_bytes,
        kv_bytes=kv_bytes,
        workspace_bytes=workspace_bytes,
    )
    budget = int(metal_limit_bytes * cfg.memory_fraction)

    # Both non-raising outcomes report the same numbers; only the flags differ.
    report = {
        "estimated_bytes": projected,
        "limit_bytes": metal_limit_bytes,
        "threshold_bytes": budget,
    }

    if projected <= budget:
        return BudgetDecision(ok=True, fallback_to_serial=False, **report)

    if not cfg.auto_serial:
        raise MemoryError(
            f"Estimated {projected/1e9:.2f} GB exceeds {budget/1e9:.2f} GB "
            f"({cfg.memory_fraction*100:.0f}% of {metal_limit_bytes/1e9:.2f} GB Metal limit). "
            "Reduce parallel/chunk_tokens, set serial=True, or enable auto_serial."
        )

    return BudgetDecision(ok=False, fallback_to_serial=True, **report)
|
|
|
|
|
|
def check_context_budget(
    *,
    cfg: RSAConfig,
    prompt_tokens: int,
    max_context_tokens: int,
    formatting_overhead: int = 200,
) -> None:
    """Raise if the required context would exceed the model's maximum.

    The required token count is::

        prompt_tokens
        + cfg.aggregation_subsample * cfg.tail_tokens
        + cfg.chunk_tokens
        + formatting_overhead

    Args:
        cfg: Configuration supplying ``aggregation_subsample``,
            ``tail_tokens`` and ``chunk_tokens``.
        prompt_tokens: Number of tokens in the prompt.
        max_context_tokens: Maximum context length to check against.
        formatting_overhead: Token allowance for chat template wrappers.
            Defaults to 200; override it when a different allowance is
            needed.

    Raises:
        ValueError: The required token count is greater than
            ``max_context_tokens``.
    """
    aggregated_tail_tokens = cfg.aggregation_subsample * cfg.tail_tokens
    needed = (
        prompt_tokens
        + aggregated_tail_tokens
        + cfg.chunk_tokens
        + formatting_overhead
    )
    if needed > max_context_tokens:
        raise ValueError(
            f"Required context {needed} tokens would exceed model max "
            f"{max_context_tokens} tokens. Reduce tail_tokens, aggregation_subsample, "
            "or chunk_tokens."
        )
|