"""Tests for RSAConfig: default values, validation, preset profiles, and helpers."""
import pytest
from markovian_rsa_mlx.config import RSAConfig


def test_default_config_values():
    """A default-constructed RSAConfig exposes the documented default for every field."""
    cfg = RSAConfig()
    assert cfg.rounds == 2
    assert cfg.parallel == 4
    assert cfg.aggregation_subsample == 4
    assert cfg.chunk_tokens == 16384
    assert cfg.tail_tokens == 4096
    # None means effective_final_tokens() falls back to chunk_tokens.
    assert cfg.final_tokens is None
    assert cfg.temperature == 1.0
    assert cfg.top_p == 0.95
    assert cfg.top_k == -1
    assert cfg.seed is None
    assert cfg.serial is False
    assert cfg.auto_serial is True
    assert cfg.memory_fraction == 0.80
    assert cfg.max_context_tokens is None
    assert cfg.aggregation_template == "zaya_v1"
    assert cfg.answer_selection == "first_final_candidate"


def test_validation_rejects_negative_tokens():
    """chunk_tokens must be strictly positive: both zero and negative values are rejected."""
    # Cover both boundary (0) and truly negative (-1) inputs, matching the
    # zero/negative coverage used for final_tokens.
    for bad_value in (0, -1):
        with pytest.raises(ValueError, match="chunk_tokens must be positive"):
            RSAConfig(chunk_tokens=bad_value)


def test_validation_rejects_K_greater_than_N():
    """The aggregation subsample size (K) may not exceed the parallel count (N)."""
    with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
        RSAConfig(parallel=2, aggregation_subsample=4)


def test_default_16gb_profile():
    """The 16 GB preset uses parallel=2 (vs. the default 4) and the default chunk size."""
    cfg = RSAConfig.default_16gb()
    assert cfg.parallel == 2
    assert cfg.chunk_tokens == 16384


def test_paper_16k_profile():
    """The paper 16k preset keeps parallel/chunk defaults but sets top_p to 1.0 (default 0.95)."""
    cfg = RSAConfig.paper_16k()
    assert cfg.parallel == 4
    assert cfg.chunk_tokens == 16384
    assert cfg.top_p == 1.0


def test_paper_headline_40k_profile():
    """The headline 40k preset uses 16 parallel chains and explicit 40960-token chunk/final budgets."""
    cfg = RSAConfig.paper_headline_40k()
    assert cfg.parallel == 16
    assert cfg.chunk_tokens == 40960
    # Unlike the defaults, final_tokens is set explicitly rather than left as None.
    assert cfg.final_tokens == 40960
    assert cfg.top_p == 1.0


def test_effective_final_tokens_falls_back_to_chunk():
    """effective_final_tokens() uses chunk_tokens when final_tokens is None, else final_tokens."""
    # Unset final budget: falls back to the per-chunk budget.
    cfg = RSAConfig(chunk_tokens=8192, final_tokens=None)
    assert cfg.effective_final_tokens() == 8192

    # Explicit final budget wins over chunk_tokens.
    cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
    assert cfg2.effective_final_tokens() == 20000


def test_replace_revalidates():
    """replace() must re-run validation, so an invalid K > N combination is rejected."""
    cfg = RSAConfig()
    with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
        cfg.replace(parallel=2, aggregation_subsample=10)


def test_final_tokens_zero_rejected():
    """An explicit final_tokens of 0 is rejected (None is the only non-positive value allowed)."""
    with pytest.raises(ValueError, match="final_tokens must be positive"):
        RSAConfig(final_tokens=0)


def test_final_tokens_negative_rejected():
    """A negative final_tokens value is rejected with the same message as zero."""
    with pytest.raises(ValueError, match="final_tokens must be positive"):
        RSAConfig(final_tokens=-1)