fix(config): validate final_tokens > 0 + add replace() revalidates regression test

This commit is contained in:
transcrilive
2026-05-10 02:40:59 +02:00
parent 08ae956986
commit 40bd38c2c5
2 changed files with 18 additions and 0 deletions

View File

@@ -34,6 +34,8 @@ class RSAConfig:
raise ValueError("chunk_tokens must be positive") raise ValueError("chunk_tokens must be positive")
if self.tail_tokens <= 0: if self.tail_tokens <= 0:
raise ValueError("tail_tokens must be positive") raise ValueError("tail_tokens must be positive")
if self.final_tokens is not None and self.final_tokens <= 0:
raise ValueError("final_tokens must be positive when set")
if self.parallel < 1: if self.parallel < 1:
raise ValueError("parallel must be >= 1") raise ValueError("parallel must be >= 1")
if self.rounds < 1: if self.rounds < 1:

View File

@@ -58,3 +58,19 @@ def test_effective_final_tokens_falls_back_to_chunk():
assert cfg.effective_final_tokens() == 8192 assert cfg.effective_final_tokens() == 8192
cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000) cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
assert cfg2.effective_final_tokens() == 20000 assert cfg2.effective_final_tokens() == 20000
def test_replace_revalidates():
cfg = RSAConfig()
with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
cfg.replace(parallel=2, aggregation_subsample=10)
def test_final_tokens_zero_rejected():
with pytest.raises(ValueError, match="final_tokens must be positive"):
RSAConfig(final_tokens=0)
def test_final_tokens_negative_rejected():
with pytest.raises(ValueError, match="final_tokens must be positive"):
RSAConfig(final_tokens=-1)