fix(config): validate final_tokens > 0 + add replace() revalidates regression test
This commit is contained in:
@@ -34,6 +34,8 @@ class RSAConfig:
|
|||||||
raise ValueError("chunk_tokens must be positive")
|
raise ValueError("chunk_tokens must be positive")
|
||||||
if self.tail_tokens <= 0:
|
if self.tail_tokens <= 0:
|
||||||
raise ValueError("tail_tokens must be positive")
|
raise ValueError("tail_tokens must be positive")
|
||||||
|
if self.final_tokens is not None and self.final_tokens <= 0:
|
||||||
|
raise ValueError("final_tokens must be positive when set")
|
||||||
if self.parallel < 1:
|
if self.parallel < 1:
|
||||||
raise ValueError("parallel must be >= 1")
|
raise ValueError("parallel must be >= 1")
|
||||||
if self.rounds < 1:
|
if self.rounds < 1:
|
||||||
|
|||||||
@@ -58,3 +58,19 @@ def test_effective_final_tokens_falls_back_to_chunk():
|
|||||||
assert cfg.effective_final_tokens() == 8192
|
assert cfg.effective_final_tokens() == 8192
|
||||||
cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
|
cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
|
||||||
assert cfg2.effective_final_tokens() == 20000
|
assert cfg2.effective_final_tokens() == 20000
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_revalidates():
|
||||||
|
cfg = RSAConfig()
|
||||||
|
with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
|
||||||
|
cfg.replace(parallel=2, aggregation_subsample=10)
|
||||||
|
|
||||||
|
|
||||||
|
def test_final_tokens_zero_rejected():
|
||||||
|
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
||||||
|
RSAConfig(final_tokens=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_final_tokens_negative_rejected():
|
||||||
|
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
||||||
|
RSAConfig(final_tokens=-1)
|
||||||
|
|||||||
Reference in New Issue
Block a user