# ---------------------------------------------------------------------------
# File: src/markovian_rsa_mlx/config.py
# ---------------------------------------------------------------------------
"""RSAConfig dataclass + named profiles. Zero MLX dependencies — purely declarative."""
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Literal


@dataclass(frozen=True)
class RSAConfig:
    """Markovian RSA hyperparameters.

    Defaults target the 16 GB practical profile (N=4, T=2, chunk=16K),
    not the paper headline (N=16, chunk=40K). See named constructors for
    paper-aligned profiles.

    Instances are frozen; use :meth:`replace` to derive a modified copy.
    All values are validated in ``__post_init__`` and a ``ValueError`` is
    raised for an inconsistent combination.

    NOTE(review): ``default_16gb()`` builds N=2 / K=2, which differs from
    the bare defaults described above (N=4 / K=4) -- confirm which one is
    the intended 16 GB profile.
    """

    rounds: int = 2  # T; must be >= 1
    parallel: int = 4  # N; must be >= 1
    aggregation_subsample: int = 4  # K; must satisfy 1 <= K <= N
    chunk_tokens: int = 16384  # must be positive
    tail_tokens: int = 4096  # must be positive
    final_tokens: int | None = None  # None -> falls back to chunk_tokens
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = -1  # presumably -1 disables top-k filtering -- TODO confirm
    seed: int | None = None
    serial: bool = False
    auto_serial: bool = True
    memory_fraction: float = 0.80  # must be in (0, 1]
    max_context_tokens: int | None = None
    aggregation_template: Literal["zaya_v1"] = "zaya_v1"
    answer_selection: Literal["first_final_candidate"] = "first_final_candidate"

    def __post_init__(self) -> None:
        """Validate field values.

        Raises:
            ValueError: if any token budget is non-positive, ``parallel`` or
                ``rounds`` is below 1, ``aggregation_subsample`` is outside
                ``[1, parallel]``, or ``memory_fraction`` is outside ``(0, 1]``.
        """
        if self.chunk_tokens <= 0:
            raise ValueError("chunk_tokens must be positive")
        if self.tail_tokens <= 0:
            raise ValueError("tail_tokens must be positive")
        # An explicit final budget must obey the same rule as chunk_tokens,
        # otherwise effective_final_tokens() could return a non-positive value.
        if self.final_tokens is not None and self.final_tokens <= 0:
            raise ValueError("final_tokens must be positive when set")
        if self.parallel < 1:
            raise ValueError("parallel must be >= 1")
        if self.rounds < 1:
            raise ValueError("rounds must be >= 1")
        # Lower bound: the upper-bound check below lets 0 and negatives through.
        if self.aggregation_subsample < 1:
            raise ValueError("aggregation_subsample must be >= 1")
        if self.aggregation_subsample > self.parallel:
            raise ValueError(
                f"aggregation_subsample ({self.aggregation_subsample}) "
                f"must be <= parallel ({self.parallel})"
            )
        if not 0.0 < self.memory_fraction <= 1.0:
            raise ValueError("memory_fraction must be in (0, 1]")

    def effective_final_tokens(self) -> int:
        """Return ``final_tokens`` if it is set, otherwise ``chunk_tokens``."""
        return self.final_tokens if self.final_tokens is not None else self.chunk_tokens

    @classmethod
    def default_16gb(cls) -> "RSAConfig":
        """Build the profile with N=2, K=2 and 16K-token chunks."""
        return cls(parallel=2, chunk_tokens=16384, aggregation_subsample=2)

    @classmethod
    def paper_16k(cls) -> "RSAConfig":
        """Build the profile with N=4, 16K-token chunks and ``top_p=1.0``."""
        return cls(parallel=4, chunk_tokens=16384, top_p=1.0)

    @classmethod
    def paper_headline_40k(cls) -> "RSAConfig":
        """Build the profile with N=16, K=4, 40K chunk/final budgets and ``top_p=1.0``."""
        return cls(
            parallel=16,
            aggregation_subsample=4,
            chunk_tokens=40960,
            final_tokens=40960,
            top_p=1.0,
        )

    def replace(self, **kwargs) -> "RSAConfig":
        """Return a copy of this config with the given fields overridden.

        The copy is built through the normal constructor, so the overrides
        are validated by ``__post_init__`` and may raise ``ValueError``.
        """
        # Inside the method body the bare name resolves to the module-level
        # dataclasses.replace, not to this method.
        return replace(self, **kwargs)


# ---------------------------------------------------------------------------
# File: tests/test_config.py
# ---------------------------------------------------------------------------
import pytest

# When this section lives in its own file (tests/test_config.py) the class
# under test is imported with:
#     from markovian_rsa_mlx.config import RSAConfig
# In this combined listing RSAConfig is already defined above.


def test_default_config_values():
    cfg = RSAConfig()
    assert cfg.rounds == 2
    assert cfg.parallel == 4
    assert cfg.aggregation_subsample == 4
    assert cfg.chunk_tokens == 16384
    assert cfg.tail_tokens == 4096
    assert cfg.final_tokens is None
    assert cfg.temperature == 1.0
    assert cfg.top_p == 0.95
    assert cfg.top_k == -1
    assert cfg.seed is None
    assert cfg.serial is False
    assert cfg.auto_serial is True
    assert cfg.memory_fraction == 0.80
    assert cfg.max_context_tokens is None
    assert cfg.aggregation_template == "zaya_v1"
    assert cfg.answer_selection == "first_final_candidate"


def test_validation_rejects_negative_tokens():
    with pytest.raises(ValueError, match="chunk_tokens must be positive"):
        RSAConfig(chunk_tokens=0)
    # The test name promises negatives; cover one as well as the zero boundary.
    with pytest.raises(ValueError, match="chunk_tokens must be positive"):
        RSAConfig(chunk_tokens=-1)


def test_validation_rejects_K_greater_than_N():
    with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
        RSAConfig(parallel=2, aggregation_subsample=4)


def test_validation_rejects_K_below_one():
    with pytest.raises(ValueError, match="aggregation_subsample must be >= 1"):
        RSAConfig(aggregation_subsample=0)
    with pytest.raises(ValueError, match="aggregation_subsample must be >= 1"):
        RSAConfig(aggregation_subsample=-3)


def test_validation_rejects_nonpositive_final_tokens():
    with pytest.raises(ValueError, match="final_tokens must be positive"):
        RSAConfig(final_tokens=0)
    with pytest.raises(ValueError, match="final_tokens must be positive"):
        RSAConfig(final_tokens=-1)


def test_default_16gb_profile():
    cfg = RSAConfig.default_16gb()
    assert cfg.parallel == 2
    assert cfg.chunk_tokens == 16384


def test_paper_16k_profile():
    cfg = RSAConfig.paper_16k()
    assert cfg.parallel == 4
    assert cfg.chunk_tokens == 16384
    assert cfg.top_p == 1.0


def test_paper_headline_40k_profile():
    cfg = RSAConfig.paper_headline_40k()
    assert cfg.parallel == 16
    assert cfg.chunk_tokens == 40960
    assert cfg.final_tokens == 40960
    assert cfg.top_p == 1.0


def test_effective_final_tokens_falls_back_to_chunk():
    cfg = RSAConfig(chunk_tokens=8192, final_tokens=None)
    assert cfg.effective_final_tokens() == 8192
    cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
    assert cfg2.effective_final_tokens() == 20000