# File listing header: 87 lines, 3.0 KiB, Python
from unittest.mock import MagicMock
|
|
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
|
|
|
|
|
|
def test_run_batch_serial_path_calls_per_request():
    """With ``serial=True`` and a stub ``single_generate``, expect one result per request.

    The stub derives its output text from the ``seed`` it receives, so the
    per-index text checks below also pin the order of the returned results
    to the order of the submitted requests.
    """

    def stub_generate(
        model,
        tokenizer,
        prompt_token_ids,
        *,
        max_tokens,
        seed,
        temperature,
        top_p,
        top_k,
    ):
        # Deterministic stand-in for the per-request generator: emits at most
        # five consecutive token IDs starting at 10 and tags the text with
        # the seed it was called with.
        produced = min(max_tokens, 5)
        return GenerationResult(
            token_ids=list(range(10, 10 + produced)),
            text=f"text-seed-{seed}",
            generated_tokens=produced,
            finish_reason="eos",
            elapsed_s=0.01,
        )

    # Two requests that differ only by seed.
    batch = [
        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=request_seed, max_tokens=5)
        for request_seed in (42, 43)
    ]

    outputs = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=batch,
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        serial=True,
        single_generate=stub_generate,
    )

    # Exactly two results, in request order.
    assert [out.text for out in outputs] == ["text-seed-42", "text-seed-43"]
    # Every result carries the values the stub produced.
    for out in outputs:
        assert out.generated_tokens == 5
    for out in outputs:
        assert out.finish_reason == "eos"
|
|
|
|
|
|
def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
    """With ``serial=False`` and a stub ``batch_generate``, results come from the stub.

    The stub tags each result's text with the originating request's seed; the
    test checks the result count and the first and last entries.
    """
    # NOTE: the ``monkeypatch`` fixture is requested but not used in this test.

    def stub_batch_generate(model, tokenizer, requests, *, temperature, top_p, top_k):
        # Stand-in for a batch generator: builds one fixed-shape result per
        # incoming request in a single call.
        produced = []
        for req in requests:
            produced.append(
                GenerationResult(
                    token_ids=[10, 11, 12],
                    text=f"batched-{req.seed}",
                    generated_tokens=3,
                    finish_reason="max_tokens",
                    elapsed_s=0.02,
                )
            )
        return produced

    # Three requests that differ only by seed.
    batch = [
        GenerationRequest(prompt_token_ids=[1], seed=request_seed, max_tokens=3)
        for request_seed in (1, 2, 3)
    ]

    outputs = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=batch,
        temperature=1.0,
        top_p=1.0,
        top_k=-1,
        serial=False,
        batch_generate=stub_batch_generate,
    )

    assert len(outputs) == 3
    head, tail = outputs[0], outputs[2]
    assert head.text == "batched-1"
    assert tail.text == "batched-3"
|
|
|
|
|
|
def test_run_batch_single_request_uses_single_path():
    """A lone request with ``serial=False`` and no ``batch_generate`` yields the stub's result.

    Only ``single_generate`` is supplied (``batch_generate`` is ``None``), so
    the test expects the single-request stub's output to come back.
    """

    def stub_generate(
        model,
        tokenizer,
        prompt_token_ids,
        *,
        max_tokens,
        seed,
        temperature,
        top_p,
        top_k,
    ):
        # Constant result; the text "single" is the marker the assertion checks.
        return GenerationResult(
            token_ids=[1],
            text="single",
            generated_tokens=1,
            finish_reason="eos",
            elapsed_s=0.0,
        )

    only_request = GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)

    outputs = run_batch(
        model=MagicMock(),
        tokenizer=MagicMock(),
        requests=[only_request],
        temperature=1.0,
        top_p=0.95,
        top_k=-1,
        # Batched mode is requested, but with one request the single-request
        # generator is expected to serve it.
        serial=False,
        single_generate=stub_generate,
        batch_generate=None,  # deliberately not provided
    )

    first = outputs[0]
    assert first.text == "single"
|