diff --git a/src/markovian_rsa_mlx/config.py b/src/markovian_rsa_mlx/config.py index e6f0954..3d10bf2 100644 --- a/src/markovian_rsa_mlx/config.py +++ b/src/markovian_rsa_mlx/config.py @@ -34,6 +34,8 @@ class RSAConfig: raise ValueError("chunk_tokens must be positive") if self.tail_tokens <= 0: raise ValueError("tail_tokens must be positive") + if self.final_tokens is not None and self.final_tokens <= 0: + raise ValueError("final_tokens must be positive when set") if self.parallel < 1: raise ValueError("parallel must be >= 1") if self.rounds < 1: diff --git a/tests/test_config.py b/tests/test_config.py index 2ad27c7..6db91e3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -58,3 +58,19 @@ def test_effective_final_tokens_falls_back_to_chunk(): assert cfg.effective_final_tokens() == 8192 cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000) assert cfg2.effective_final_tokens() == 20000 + + +def test_replace_revalidates(): + cfg = RSAConfig() + with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"): + cfg.replace(parallel=2, aggregation_subsample=10) + + +def test_final_tokens_zero_rejected(): + with pytest.raises(ValueError, match="final_tokens must be positive"): + RSAConfig(final_tokens=0) + + +def test_final_tokens_negative_rejected(): + with pytest.raises(ValueError, match="final_tokens must be positive"): + RSAConfig(final_tokens=-1)