Files
markovian-rsa-mlx/tests/test_batching.py

87 lines
3.0 KiB
Python

from unittest.mock import MagicMock
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
def test_run_batch_serial_path_calls_per_request():
    """Check that ``serial=True`` produces one ``single_generate`` result per request."""

    def stub_generate(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
        # Deterministic stand-in for the per-request generator: emits at most
        # five token IDs and embeds the seed in the text so every result can
        # be traced back to the request that produced it.
        produced = min(max_tokens, 5)
        return GenerationResult(
            token_ids=list(range(10, 10 + produced)),
            text=f"text-seed-{seed}",
            generated_tokens=produced,
            finish_reason="eos",
            elapsed_s=0.01,
        )

    # Two requests that differ only in their seed.
    batch = [
        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=request_seed, max_tokens=5)
        for request_seed in (42, 43)
    ]

    outputs = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=batch,
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        serial=True,
        single_generate=stub_generate,
    )

    assert len(outputs) == 2
    first, second = outputs[0], outputs[1]
    assert first.text == "text-seed-42"
    assert second.text == "text-seed-43"
    for output in outputs:
        assert output.generated_tokens == 5
    for output in outputs:
        assert output.finish_reason == "eos"
def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
    """Check that ``serial=False`` with several requests returns the ``batch_generate`` results in order.

    ``monkeypatch`` is requested from pytest but is not used by this test; it
    is kept so the test's signature stays unchanged.
    """

    def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
        # Stand-in for the batched generator: returns one result per request
        # in a single call, with the seed embedded in the text so each result
        # can be matched to its request.
        return [
            GenerationResult(
                token_ids=[10, 11, 12],
                text=f"batched-{r.seed}",
                generated_tokens=3,
                finish_reason="max_tokens",
                elapsed_s=0.02,
            )
            for r in requests
        ]

    requests = [
        GenerationRequest(prompt_token_ids=[1], seed=1, max_tokens=3),
        GenerationRequest(prompt_token_ids=[1], seed=2, max_tokens=3),
        GenerationRequest(prompt_token_ids=[1], seed=3, max_tokens=3),
    ]
    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=requests,
        temperature=1.0,
        top_p=1.0,
        top_k=-1,
        serial=False,
        batch_generate=fake_batch_gen,
    )

    assert len(results) == 3
    # Check every position, not only the first and last, so a dropped,
    # duplicated or reordered middle result is caught as well.
    assert [r.text for r in results] == ["batched-1", "batched-2", "batched-3"]
def test_run_batch_single_request_uses_single_path():
    """Check that a lone request is served by ``single_generate`` even when ``serial=False``."""

    def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
        # Stand-in for the per-request generator; the fixed text identifies
        # which path produced the result.
        return GenerationResult(
            token_ids=[1],
            text="single",
            generated_tokens=1,
            finish_reason="eos",
            elapsed_s=0.0,
        )

    requests = [GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)]
    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=requests,
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        serial=False,  # batching requested, but N=1 is expected to take the per-request path
        single_generate=fake_gen,
        batch_generate=None,  # deliberately not provided: the batched path must not be needed
    )

    # Pin the result count as the other run_batch tests do, so extra or
    # missing results for a single request cannot slip through.
    assert len(results) == 1
    assert results[0].text == "single"