feat(config): add RSAConfig + named profiles (default-16gb, paper-16k, paper-headline-40k)
This commit is contained in:
71
src/markovian_rsa_mlx/config.py
Normal file
71
src/markovian_rsa_mlx/config.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
"""RSAConfig dataclass + named profiles. Zero MLX dependencies — purely declarative."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass, replace
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RSAConfig:
|
||||||
|
"""Markovian RSA hyperparameters.
|
||||||
|
|
||||||
|
Defaults target the 16 GB practical profile (N=4, T=2, chunk=16K),
|
||||||
|
not the paper headline (N=16, chunk=40K). See named constructors for
|
||||||
|
paper-aligned profiles.
|
||||||
|
"""
|
||||||
|
rounds: int = 2
|
||||||
|
parallel: int = 4
|
||||||
|
aggregation_subsample: int = 4
|
||||||
|
chunk_tokens: int = 16384
|
||||||
|
tail_tokens: int = 4096
|
||||||
|
final_tokens: int | None = None
|
||||||
|
temperature: float = 1.0
|
||||||
|
top_p: float = 0.95
|
||||||
|
top_k: int = -1
|
||||||
|
seed: int | None = None
|
||||||
|
serial: bool = False
|
||||||
|
auto_serial: bool = True
|
||||||
|
memory_fraction: float = 0.80
|
||||||
|
max_context_tokens: int | None = None
|
||||||
|
aggregation_template: Literal["zaya_v1"] = "zaya_v1"
|
||||||
|
answer_selection: Literal["first_final_candidate"] = "first_final_candidate"
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
if self.chunk_tokens <= 0:
|
||||||
|
raise ValueError("chunk_tokens must be positive")
|
||||||
|
if self.tail_tokens <= 0:
|
||||||
|
raise ValueError("tail_tokens must be positive")
|
||||||
|
if self.parallel < 1:
|
||||||
|
raise ValueError("parallel must be >= 1")
|
||||||
|
if self.rounds < 1:
|
||||||
|
raise ValueError("rounds must be >= 1")
|
||||||
|
if self.aggregation_subsample > self.parallel:
|
||||||
|
raise ValueError(
|
||||||
|
f"aggregation_subsample ({self.aggregation_subsample}) "
|
||||||
|
f"must be <= parallel ({self.parallel})"
|
||||||
|
)
|
||||||
|
if not 0.0 < self.memory_fraction <= 1.0:
|
||||||
|
raise ValueError("memory_fraction must be in (0, 1]")
|
||||||
|
|
||||||
|
def effective_final_tokens(self) -> int:
|
||||||
|
return self.final_tokens if self.final_tokens is not None else self.chunk_tokens
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_16gb(cls) -> "RSAConfig":
|
||||||
|
return cls(parallel=2, chunk_tokens=16384, aggregation_subsample=2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def paper_16k(cls) -> "RSAConfig":
|
||||||
|
return cls(parallel=4, chunk_tokens=16384, top_p=1.0)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def paper_headline_40k(cls) -> "RSAConfig":
|
||||||
|
return cls(
|
||||||
|
parallel=16,
|
||||||
|
aggregation_subsample=4,
|
||||||
|
chunk_tokens=40960,
|
||||||
|
final_tokens=40960,
|
||||||
|
top_p=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def replace(self, **kwargs) -> "RSAConfig":
|
||||||
|
return replace(self, **kwargs)
|
||||||
60
tests/test_config.py
Normal file
60
tests/test_config.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import pytest
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_config_values():
|
||||||
|
cfg = RSAConfig()
|
||||||
|
assert cfg.rounds == 2
|
||||||
|
assert cfg.parallel == 4
|
||||||
|
assert cfg.aggregation_subsample == 4
|
||||||
|
assert cfg.chunk_tokens == 16384
|
||||||
|
assert cfg.tail_tokens == 4096
|
||||||
|
assert cfg.final_tokens is None
|
||||||
|
assert cfg.temperature == 1.0
|
||||||
|
assert cfg.top_p == 0.95
|
||||||
|
assert cfg.top_k == -1
|
||||||
|
assert cfg.seed is None
|
||||||
|
assert cfg.serial is False
|
||||||
|
assert cfg.auto_serial is True
|
||||||
|
assert cfg.memory_fraction == 0.80
|
||||||
|
assert cfg.max_context_tokens is None
|
||||||
|
assert cfg.aggregation_template == "zaya_v1"
|
||||||
|
assert cfg.answer_selection == "first_final_candidate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_rejects_negative_tokens():
|
||||||
|
with pytest.raises(ValueError, match="chunk_tokens must be positive"):
|
||||||
|
RSAConfig(chunk_tokens=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validation_rejects_K_greater_than_N():
|
||||||
|
with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
|
||||||
|
RSAConfig(parallel=2, aggregation_subsample=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_16gb_profile():
|
||||||
|
cfg = RSAConfig.default_16gb()
|
||||||
|
assert cfg.parallel == 2
|
||||||
|
assert cfg.chunk_tokens == 16384
|
||||||
|
|
||||||
|
|
||||||
|
def test_paper_16k_profile():
|
||||||
|
cfg = RSAConfig.paper_16k()
|
||||||
|
assert cfg.parallel == 4
|
||||||
|
assert cfg.chunk_tokens == 16384
|
||||||
|
assert cfg.top_p == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_paper_headline_40k_profile():
|
||||||
|
cfg = RSAConfig.paper_headline_40k()
|
||||||
|
assert cfg.parallel == 16
|
||||||
|
assert cfg.chunk_tokens == 40960
|
||||||
|
assert cfg.final_tokens == 40960
|
||||||
|
assert cfg.top_p == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_effective_final_tokens_falls_back_to_chunk():
|
||||||
|
cfg = RSAConfig(chunk_tokens=8192, final_tokens=None)
|
||||||
|
assert cfg.effective_final_tokens() == 8192
|
||||||
|
cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
|
||||||
|
assert cfg2.effective_final_tokens() == 20000
|
||||||
Reference in New Issue
Block a user