from unittest.mock import MagicMock

from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator


def _fake_tokenizer(eos_id: int = 999):
    """Return a ``MagicMock`` that stands in for a tokenizer.

    The mock is wired with deterministic, character-code based behaviour:

    * ``encode(s)`` yields the code points of the first 32 characters of
      ``s``, or ``[1]`` when that list would be empty.
    * ``decode(ids)`` joins the characters for every id in the inclusive
      range 32..122 and drops everything else.
    * ``apply_chat_template(messages, **kw)`` joins the ``"content"`` of each
      message with single spaces and returns the hex form of the encoded
      bytes.
    * ``eos_token_id`` is ``eos_id`` and ``all_special_ids`` is ``[eos_id]``.
    """

    def _encode(s):
        # Cap at 32 ids; never hand back an empty prompt.
        code_points = [ord(c) for c in s][:32]
        return code_points or [1]

    def _decode(ids):
        return "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)

    def _apply_chat_template(messages, **kw):
        joined = " ".join(m["content"] for m in messages)
        return joined.encode().hex()

    fake = MagicMock()
    fake.eos_token_id = eos_id
    fake.all_special_ids = [eos_id]
    fake.encode.side_effect = _encode
    fake.decode.side_effect = _decode
    fake.apply_chat_template.side_effect = _apply_chat_template
    return fake


def _fake_single_gen(
    model,
    tokenizer,
    prompt_token_ids,
    *,
    max_tokens,
    seed,
    temperature,
    top_p,
    top_k,
):
    """Canned single-sequence generator used in place of a real model call.

    Only ``seed`` influences the output: the generated text is
    ``"trace-<seed>-final-answer"`` and the token ids are that text's
    character codes. All other arguments are accepted and ignored.
    """
    canned_text = f"trace-{seed}-final-answer"
    canned_ids = list(map(ord, canned_text))
    return GenerationResult(
        token_ids=canned_ids,
        text=canned_text,
        generated_tokens=len(canned_ids),
        finish_reason="eos",
        elapsed_s=0.01,
    )


def test_t1_single_round_produces_final_text(tmp_path):
    """One serial round with two traces returns final text and an audit log."""
    audit_path = tmp_path / "audit.jsonl"
    cfg = RSAConfig(
        rounds=1,
        parallel=2,
        aggregation_subsample=2,
        chunk_tokens=64,
        tail_tokens=8,
        serial=True,
        seed=123,
    )
    orch = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="test-model",
        quantization="bf16",
        single_generate=_fake_single_gen,
        batch_generate=None,
    )

    text, result = orch.solve(
        "What is 2+2?",
        config=cfg,
        return_audit=True,
        audit_path=audit_path,
    )

    # The returned string must be the audited final text.
    assert isinstance(text, str)
    assert text == result.final_text

    # Exactly one round holding the two requested traces.
    assert result.config.rounds == 1
    assert len(result.rounds) == 1
    assert len(result.rounds[0].traces) == 2

    # The audit file must have been written with at least five records:
    # run_start, 2 trace_complete, final, run_end.
    assert audit_path.exists()
    lines = audit_path.read_text().strip().split("\n")
    assert len(lines) >= 5


def test_t1_returns_string_when_return_audit_false(tmp_path):
    """Without ``return_audit`` the orchestrator returns a bare string."""
    cfg = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, serial=True)
    orch = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="m",
        quantization="bf16",
        single_generate=_fake_single_gen,
    )

    out = orch.solve("X", config=cfg)

    assert isinstance(out, str)