feat(orchestrator): add T=1 path with audit JSONL + tail extraction
This commit is contained in:
60
tests/test_orchestrator_t1.py
Normal file
60
tests/test_orchestrator_t1.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||
|
||||
|
||||
def _fake_tokenizer(eos_id: int = 999):
|
||||
tok = MagicMock()
|
||||
tok.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
|
||||
tok.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
|
||||
tok.eos_token_id = eos_id
|
||||
tok.all_special_ids = [eos_id]
|
||||
tok.apply_chat_template.side_effect = lambda messages, **kw: \
|
||||
" ".join(m["content"] for m in messages).encode().hex()
|
||||
return tok
|
||||
|
||||
|
||||
def _fake_single_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
|
||||
text = f"trace-{seed}-final-answer"
|
||||
ids = [ord(c) for c in text]
|
||||
return GenerationResult(
|
||||
token_ids=ids, text=text, generated_tokens=len(ids),
|
||||
finish_reason="eos", elapsed_s=0.01,
|
||||
)
|
||||
|
||||
|
||||
def test_t1_single_round_produces_final_text(tmp_path):
|
||||
cfg = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2,
|
||||
chunk_tokens=64, tail_tokens=8, serial=True, seed=123)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(),
|
||||
tokenizer=_fake_tokenizer(),
|
||||
model_id="test-model",
|
||||
quantization="bf16",
|
||||
single_generate=_fake_single_gen,
|
||||
batch_generate=None,
|
||||
)
|
||||
audit_path = tmp_path / "audit.jsonl"
|
||||
text, result = orch.solve("What is 2+2?", config=cfg, return_audit=True, audit_path=audit_path)
|
||||
assert isinstance(text, str)
|
||||
assert text == result.final_text
|
||||
assert result.config.rounds == 1
|
||||
assert len(result.rounds) == 1
|
||||
assert len(result.rounds[0].traces) == 2
|
||||
assert audit_path.exists()
|
||||
lines = audit_path.read_text().strip().split("\n")
|
||||
# at minimum: run_start, 2 trace_complete, final, run_end
|
||||
assert len(lines) >= 5
|
||||
|
||||
|
||||
def test_t1_returns_string_when_return_audit_false(tmp_path):
|
||||
cfg = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, serial=True)
|
||||
orch = MarkovianRSAOrchestrator(
|
||||
model=MagicMock(), tokenizer=_fake_tokenizer(),
|
||||
model_id="m", quantization="bf16",
|
||||
single_generate=_fake_single_gen,
|
||||
)
|
||||
out = orch.solve("X", config=cfg)
|
||||
assert isinstance(out, str)
|
||||
Reference in New Issue
Block a user