feat(guards): memory/context budget checks + multi-round T>=2 orchestrator tests
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
54
tests/test_orchestrator_t2.py
Normal file
54
tests/test_orchestrator_t2.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from markovian_rsa_mlx.batching import GenerationResult
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||
|
||||
|
||||
def _tok(eos_id=999):
|
||||
t = MagicMock()
|
||||
t.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
|
||||
t.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
|
||||
t.eos_token_id = eos_id
|
||||
t.all_special_ids = [eos_id]
|
||||
t.apply_chat_template.side_effect = lambda messages, **kw: \
|
||||
" ".join(m["content"] for m in messages).encode()
|
||||
return t
|
||||
|
||||
|
||||
def _gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
|
||||
text = f"r-trace-{seed}"
|
||||
ids = [ord(c) for c in text]
|
||||
return GenerationResult(token_ids=ids, text=text, generated_tokens=len(ids),
|
||||
finish_reason="eos", elapsed_s=0.01)
|
||||
|
||||
|
||||
def test_t2_two_rounds_aggregation_sees_K_tails(tmp_path):
|
||||
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
|
||||
chunk_tokens=64, tail_tokens=4, serial=True, seed=7)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(), tokenizer=_tok(),
|
||||
model_id="m", quantization="bf16",
|
||||
single_generate=_gen, batch_generate=None,
|
||||
)
|
||||
audit = tmp_path / "audit.jsonl"
|
||||
text, result = orch.solve("Solve this", config=cfg, return_audit=True, audit_path=audit)
|
||||
assert len(result.rounds) == 2
|
||||
assert len(result.rounds[1].traces) == 2
|
||||
# round 1 traces have parent links
|
||||
assert all(t.parent_trace_ids for t in result.rounds[1].traces)
|
||||
|
||||
|
||||
def test_deterministic_seeds_are_stable(tmp_path):
|
||||
cfg = RSAConfig(rounds=2, parallel=2, aggregation_subsample=2,
|
||||
chunk_tokens=8, tail_tokens=2, serial=True, seed=42)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(), tokenizer=_tok(),
|
||||
model_id="m", quantization="bf16",
|
||||
single_generate=_gen, batch_generate=None,
|
||||
)
|
||||
_, r1 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "a.jsonl")
|
||||
_, r2 = orch.solve("Q", config=cfg, return_audit=True, audit_path=tmp_path / "b.jsonl")
|
||||
seeds_1 = [t.seed for r in r1.rounds for t in r.traces]
|
||||
seeds_2 = [t.seed for r in r2.rounds for t in r.traces]
|
||||
assert seeds_1 == seeds_2
|
||||
Reference in New Issue
Block a user