fix(config): validate final_tokens > 0 + add replace() revalidates regression test
This commit is contained in:
@@ -34,6 +34,8 @@ class RSAConfig:
|
||||
raise ValueError("chunk_tokens must be positive")
|
||||
if self.tail_tokens <= 0:
|
||||
raise ValueError("tail_tokens must be positive")
|
||||
if self.final_tokens is not None and self.final_tokens <= 0:
|
||||
raise ValueError("final_tokens must be positive when set")
|
||||
if self.parallel < 1:
|
||||
raise ValueError("parallel must be >= 1")
|
||||
if self.rounds < 1:
|
||||
|
||||
@@ -58,3 +58,19 @@ def test_effective_final_tokens_falls_back_to_chunk():
|
||||
assert cfg.effective_final_tokens() == 8192
|
||||
cfg2 = RSAConfig(chunk_tokens=8192, final_tokens=20000)
|
||||
assert cfg2.effective_final_tokens() == 20000
|
||||
|
||||
|
||||
def test_replace_revalidates():
|
||||
cfg = RSAConfig()
|
||||
with pytest.raises(ValueError, match="aggregation_subsample.*<= parallel"):
|
||||
cfg.replace(parallel=2, aggregation_subsample=10)
|
||||
|
||||
|
||||
def test_final_tokens_zero_rejected():
|
||||
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
||||
RSAConfig(final_tokens=0)
|
||||
|
||||
|
||||
def test_final_tokens_negative_rejected():
|
||||
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
||||
RSAConfig(final_tokens=-1)
|
||||
|
||||
Reference in New Issue
Block a user