66 lines
2.3 KiB
Python
66 lines
2.3 KiB
Python
import pytest
|
|
from markovian_rsa_mlx.guards import (
|
|
estimate_kv_bytes, estimate_total_bytes, check_memory_budget,
|
|
check_context_budget,
|
|
)
|
|
from markovian_rsa_mlx.config import RSAConfig
|
|
|
|
|
|
def test_estimate_kv_bytes_for_known_zaya():
    """KV estimate equals per-token bytes times token count times parallel width."""
    per_token, token_count, width = 40_000, 16384, 4
    kv_total = estimate_kv_bytes(
        zaya_per_token_bytes=per_token,
        tokens=token_count,
        parallel=width,
    )
    assert kv_total == per_token * token_count * width
|
|
|
|
|
|
def test_total_bytes_includes_weights():
    """Total estimate equals the sum of the weights, KV and workspace components."""
    components = {
        "weights_bytes": 4_700_000_000,
        "kv_bytes": 2_600_000_000,
        "workspace_bytes": 500_000_000,
    }
    expected = sum(components.values())
    assert estimate_total_bytes(**components) == expected
|
|
|
|
|
|
def test_memory_check_returns_ok_when_under_limit():
    """With parallel=4 and a 24e9-byte limit the check reports ok and no serial fallback."""
    settings = RSAConfig(parallel=4, chunk_tokens=16384, memory_fraction=0.80)
    budget_inputs = {
        "weights_bytes": 4_700_000_000,
        "per_token_kv_bytes": 40_000,
        "workspace_bytes": 500_000_000,
        "metal_limit_bytes": 24_000_000_000,
    }
    outcome = check_memory_budget(cfg=settings, **budget_inputs)
    assert outcome.ok is True
    assert outcome.fallback_to_serial is False
|
|
|
|
|
|
def test_memory_check_recommends_serial_when_over():
    """With parallel=8, a 16e9-byte limit and auto_serial on, the check is not ok and asks for serial fallback."""
    settings = RSAConfig(
        parallel=8,
        chunk_tokens=16384,
        memory_fraction=0.80,
        auto_serial=True,
    )
    outcome = check_memory_budget(
        cfg=settings,
        weights_bytes=4_700_000_000,
        per_token_kv_bytes=40_000,
        workspace_bytes=4_000_000_000,
        metal_limit_bytes=16_000_000_000,
    )
    # Both flags are checked by identity so a merely truthy/falsy value fails.
    assert outcome.ok is False
    assert outcome.fallback_to_serial is True
|
|
|
|
|
|
def test_memory_check_raises_when_auto_serial_disabled():
    """With the over-budget inputs and auto_serial off, the check must raise MemoryError."""
    settings = RSAConfig(
        parallel=8,
        chunk_tokens=16384,
        memory_fraction=0.80,
        auto_serial=False,
    )
    budget_inputs = dict(
        weights_bytes=4_700_000_000,
        per_token_kv_bytes=40_000,
        workspace_bytes=4_000_000_000,
        metal_limit_bytes=16_000_000_000,
    )
    with pytest.raises(MemoryError):
        check_memory_budget(cfg=settings, **budget_inputs)
|
|
|
|
|
|
def test_context_check_passes_within_limit():
    """Small chunk/tail settings against a 131072-token limit must not raise."""
    settings = RSAConfig(
        chunk_tokens=8000,
        tail_tokens=2000,
        aggregation_subsample=4,
    )
    # No assertion on purpose: the test passes only if this call returns
    # without raising.
    check_context_budget(
        cfg=settings,
        prompt_tokens=1000,
        max_context_tokens=131072,
    )
|
|
|
|
|
|
def test_context_check_raises_when_over():
    """Large chunk/tail settings against a 131072-token limit must raise ValueError mentioning "exceed"."""
    settings = RSAConfig(
        chunk_tokens=80000,
        tail_tokens=20000,
        aggregation_subsample=4,
    )
    with pytest.raises(ValueError, match="exceed"):
        check_context_budget(
            cfg=settings,
            prompt_tokens=1000,
            max_context_tokens=131072,
        )
|