"""Tests for ``run_batch`` using injected fake generator callables.

Each test passes stand-in ``single_generate`` / ``batch_generate`` callables
to ``run_batch`` and checks the results that come back.
"""

from unittest.mock import MagicMock

from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch


def test_run_batch_serial_path_calls_per_request():
    """With ``serial=True`` each request's result comes from ``single_generate``."""

    def fake_gen(
        model,
        tokenizer,
        prompt_token_ids,
        *,
        max_tokens,
        seed,
        temperature,
        top_p,
        top_k,
    ):
        # Deterministic stand-in for the per-request generator: it emits at
        # most five consecutive token ids starting at 10 and tags the text
        # with the seed so results can be traced back to their request.
        produced = min(max_tokens, 5)
        return GenerationResult(
            token_ids=list(range(10, 10 + produced)),
            text=f"text-seed-{seed}",
            generated_tokens=produced,
            finish_reason="eos",
            elapsed_s=0.01,
        )

    requests = [
        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=request_seed, max_tokens=5)
        for request_seed in (42, 43)
    ]

    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=requests,
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        serial=True,
        single_generate=fake_gen,
    )

    assert len(results) == 2
    # Results must line up with the requests, in order.
    assert [results[0].text, results[1].text] == ["text-seed-42", "text-seed-43"]
    token_counts = [r.generated_tokens for r in results]
    assert token_counts == [5] * len(token_counts)
    finish_reasons = [r.finish_reason for r in results]
    assert finish_reasons == ["eos"] * len(finish_reasons)


def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
    """With ``serial=False`` the results come from ``batch_generate``."""
    # NOTE(review): the ``monkeypatch`` fixture is requested but never used
    # in this test body.

    def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
        # Stand-in for a batch generator: produces one result per request in
        # a single call, tagging the text with the request's seed.
        batch_results = []
        for r in requests:
            batch_results.append(
                GenerationResult(
                    token_ids=[10, 11, 12],
                    text=f"batched-{r.seed}",
                    generated_tokens=3,
                    finish_reason="max_tokens",
                    elapsed_s=0.02,
                )
            )
        return batch_results

    requests = [
        GenerationRequest(prompt_token_ids=[1], seed=request_seed, max_tokens=3)
        for request_seed in (1, 2, 3)
    ]

    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=requests,
        temperature=1.0,
        top_p=1.0,
        top_k=-1,
        serial=False,
        batch_generate=fake_batch_gen,
    )

    assert len(results) == 3
    # Only the first and last entries are checked, as in the original test.
    first_text = results[0].text
    last_text = results[2].text
    assert first_text == "batched-1"
    assert last_text == "batched-3"


def test_run_batch_single_request_uses_single_path():
    """A lone request with no ``batch_generate`` is served by ``single_generate``."""

    def fake_gen(
        model,
        tokenizer,
        prompt_token_ids,
        *,
        max_tokens,
        seed,
        temperature,
        top_p,
        top_k,
    ):
        # Fixed result, so the test can tell that this callable was the one used.
        return GenerationResult(
            token_ids=[1],
            text="single",
            generated_tokens=1,
            finish_reason="eos",
            elapsed_s=0.0,
        )

    only_request = GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)

    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=[only_request],
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        serial=False,  # batched-by-default but N=1 → serial under the hood
        single_generate=fake_gen,
        batch_generate=None,  # not provided
    )

    assert results[0].text == "single"