feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)
This commit is contained in:
86
tests/test_batching.py
Normal file
86
tests/test_batching.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from unittest.mock import MagicMock
|
||||
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
|
||||
|
||||
|
||||
def test_run_batch_serial_path_calls_per_request():
    """Check ``run_batch(serial=True)`` with an injected ``single_generate``.

    The stub embeds each request's seed in the returned text, so the
    assertions can tell which request produced which result and in
    which order the results come back.
    """

    # Stand-in for the per-request generator: returns deterministic token IDs.
    def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
        # The stub never emits more than 5 tokens, whatever the budget.
        produced = min(max_tokens, 5)
        return GenerationResult(
            token_ids=list(range(10, 10 + produced)),
            text=f"text-seed-{seed}",
            generated_tokens=produced,
            finish_reason="eos",
            elapsed_s=0.01,
        )

    # Two requests that share a prompt and differ only by seed.
    requests = [
        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=request_seed, max_tokens=5)
        for request_seed in (42, 43)
    ]
    sampling = {"temperature": 1.0, "top_p": 0.95, "top_k": -1}

    results = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=requests,
        serial=True,
        single_generate=fake_gen,
        **sampling,
    )

    assert len(results) == 2
    first, second = results[0], results[1]
    assert first.text == "text-seed-42"
    assert second.text == "text-seed-43"
    for outcome in results:
        assert outcome.generated_tokens == 5
    for outcome in results:
        assert outcome.finish_reason == "eos"
|
||||
|
||||
|
||||
def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
    """Check ``run_batch(serial=False)`` with an injected ``batch_generate``.

    The stub answers all requests in a single call and tags each result
    with the originating seed; the first and last results are inspected.

    Note: the ``monkeypatch`` fixture is requested but not used here.
    """

    # Stand-in for a BatchGenerator-like callable: N results in one shot.
    def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
        outcomes = []
        for request in requests:
            outcomes.append(
                GenerationResult(
                    token_ids=[10, 11, 12],
                    text=f"batched-{request.seed}",
                    generated_tokens=3,
                    finish_reason="max_tokens",
                    elapsed_s=0.02,
                )
            )
        return outcomes

    # Three requests over the same one-token prompt, seeds 1..3.
    requests = [
        GenerationRequest(prompt_token_ids=[1], seed=request_seed, max_tokens=3)
        for request_seed in (1, 2, 3)
    ]

    model_stub = MagicMock()
    tokenizer_stub = MagicMock()
    results = run_batch(
        model=model_stub,
        tokenizer=tokenizer_stub,
        requests=requests,
        temperature=1.0,
        top_p=1.0,
        top_k=-1,
        serial=False,
        batch_generate=fake_batch_gen,
    )

    assert len(results) == 3
    head, tail = results[0], results[2]
    assert head.text == "batched-1"
    assert tail.text == "batched-3"
|
||||
|
||||
|
||||
def test_run_batch_single_request_uses_single_path():
    """Check that a lone request is served by ``single_generate``.

    ``serial=False`` is passed and ``batch_generate`` is ``None``; the
    test expects the result to come from the injected per-request stub.
    """

    def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
        return GenerationResult(
            token_ids=[1],
            text="single",
            generated_tokens=1,
            finish_reason="eos",
            elapsed_s=0.0,
        )

    only_request = GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)
    requests = [only_request]

    call_kwargs = {
        "model": MagicMock(),
        "tokenizer": MagicMock(),
        "requests": requests,
        "temperature": 1.0,
        "top_p": 0.95,
        "top_k": -1,
        # Batched mode is requested, but with N=1 the single path is expected.
        "serial": False,
        "single_generate": fake_gen,
        # Deliberately not provided.
        "batch_generate": None,
    }
    results = run_batch(**call_kwargs)

    assert results[0].text == "single"
|
||||
Reference in New Issue
Block a user