From 40bd38c2c5bda82c11a12aaba9d1d913aeaab1e9 Mon Sep 17 00:00:00 2001 From: transcrilive Date: Sun, 10 May 2026 02:40:59 +0200 Subject: [PATCH] fix(config): validate final_tokens > 0 + add replace() revalidates regression test --- src/markovian_rsa_mlx/config.py | 2 ++ tests/test_config.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/markovian_rsa_mlx/config.py b/src/markovian_rsa_mlx/config.py index e6f0954..3d10bf2 100644 --- a/src/markovian_rsa_mlx/config.py +++ b/src/markovian_rsa_mlx/config.py @@ -34,6 +34,8 @@ class RSAConfig: raise ValueError("chunk_tokens must be positive") if self.tail_tokens <= 0: raise ValueError("tail_tokens must be positive") + if self.final_tokens is not None and self.final_tokens <= 0: + raise ValueError("final_tokens must be positive when set") if self.parallel < 1: raise ValueError("parallel must be >= 1") if self.rounds < 1: diff --git a/tests/test_config.py b/tests/test_config.py index 2ad27c7..6db91e3 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -58,3 +58,19 @@ def test_effective_final_tokens_falls_back_to_chunk(): assert cfg.effective_final_tokens() == 8192 cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000) assert cfg2.effective_final_tokens() == 20000 + + +def test_replace_revalidates(): + cfg = RSAConfig() + with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"): + cfg.replace(parallel=2, aggregation_subsample=10) + + +def test_final_tokens_zero_rejected(): + with pytest.raises(ValueError, match="final_tokens must be positive"): + RSAConfig(final_tokens=0) + + +def test_final_tokens_negative_rejected(): + with pytest.raises(ValueError, match="final_tokens must be positive"): + RSAConfig(final_tokens=-1)