feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)

This commit is contained in:
transcrilive
2026-05-10 02:54:58 +02:00
parent db710cc157
commit 4b55163a5c
2 changed files with 251 additions and 0 deletions

View File

@@ -0,0 +1,165 @@
"""Thin abstraction over mlx-lm generation primitives.
Exports :
- GenerationRequest / GenerationResult dataclasses with everything an audit
event needs.
- run_batch(...) : dispatches between serial and batched paths.
The default `single_generate` and `batch_generate` callables resolve
mlx-lm's primitives lazily so that `import markovian_rsa_mlx.batching`
doesn't pull mlx-lm at module load time (useful for unit tests with mocks).
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any, Callable, Literal
@dataclass(frozen=True)
class GenerationRequest:
prompt_token_ids: list[int]
seed: int
max_tokens: int
@dataclass(frozen=True)
class GenerationResult:
token_ids: list[int] # full output (excluding prompt)
text: str # decoded output
generated_tokens: int
finish_reason: Literal["eos", "max_tokens", "error"]
elapsed_s: float
SingleGenerateFn = Callable[..., GenerationResult]
BatchGenerateFn = Callable[..., list[GenerationResult]]
def _default_single_generate(
model: Any, tokenizer: Any, prompt_token_ids: list[int], *,
max_tokens: int, seed: int, temperature: float, top_p: float, top_k: int,
) -> GenerationResult:
"""Real mlx-lm single-prompt generation. Imported lazily."""
import mlx.core as mx
from mlx_lm.generate import generate
mx.random.seed(seed)
t0 = time.time()
text = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_token_ids,
max_tokens=max_tokens,
sampler=_make_sampler(temperature, top_p, top_k),
verbose=False,
)
elapsed = time.time() - t0
out_ids = tokenizer.encode(text)
finish = "eos" if (out_ids and out_ids[-1] in _eos_ids(tokenizer)) else "max_tokens"
return GenerationResult(
token_ids=out_ids,
text=text,
generated_tokens=len(out_ids),
finish_reason=finish,
elapsed_s=elapsed,
)
def _default_batch_generate(
model: Any, tokenizer: Any, requests: list[GenerationRequest], *,
temperature: float, top_p: float, top_k: int,
) -> list[GenerationResult]:
"""Real mlx-lm batched generation via BatchGenerator. Imported lazily."""
import mlx.core as mx
try:
from mlx_lm.batch_generate import BatchGenerator
except ImportError as e:
raise RuntimeError(
"mlx_lm.batch_generate not available — install kyr0/mlx-lm fork "
"(feat/zaya-support) or pass serial=True"
) from e
sampler = _make_sampler(temperature, top_p, top_k)
gen = BatchGenerator(model, tokenizer, sampler=sampler)
t0 = time.time()
raw = gen.generate(
prompts=[r.prompt_token_ids for r in requests],
max_tokens=[r.max_tokens for r in requests],
seeds=[r.seed for r in requests],
)
elapsed = time.time() - t0
per_request = elapsed / max(len(requests), 1)
results: list[GenerationResult] = []
eos = _eos_ids(tokenizer)
for req_idx, item in enumerate(raw):
token_ids = list(item.token_ids)
text = tokenizer.decode(token_ids) if hasattr(tokenizer, "decode") else item.text
finish = "eos" if (token_ids and token_ids[-1] in eos) else "max_tokens"
results.append(GenerationResult(
token_ids=token_ids,
text=text,
generated_tokens=len(token_ids),
finish_reason=finish,
elapsed_s=per_request,
))
return results
def _make_sampler(temperature: float, top_p: float, top_k: int):
from mlx_lm.sample_utils import make_sampler
return make_sampler(temp=temperature, top_p=top_p, top_k=top_k if top_k > 0 else 0)
def _eos_ids(tokenizer: Any) -> set[int]:
ids: set[int] = set()
eos_id = getattr(tokenizer, "eos_token_id", None)
if isinstance(eos_id, int):
ids.add(eos_id)
extra = getattr(tokenizer, "all_special_ids", None) or []
for x in extra:
if isinstance(x, int):
ids.add(x)
return ids
def run_batch(
model: Any,
tokenizer: Any,
requests: list[GenerationRequest],
*,
temperature: float,
top_p: float,
top_k: int,
serial: bool,
single_generate: SingleGenerateFn | None = None,
batch_generate: BatchGenerateFn | None = None,
) -> list[GenerationResult]:
"""Run N generation requests. Use batched path unless serial=True or N==1."""
sg = single_generate or _default_single_generate
bg = batch_generate or _default_batch_generate
n = len(requests)
if n == 0:
return []
if serial or n == 1 or bg is None:
return [
sg(
model=model,
tokenizer=tokenizer,
prompt_token_ids=r.prompt_token_ids,
max_tokens=r.max_tokens,
seed=r.seed,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
for r in requests
]
return bg(
model=model,
tokenizer=tokenizer,
requests=requests,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)

86
tests/test_batching.py Normal file
View File

@@ -0,0 +1,86 @@
from unittest.mock import MagicMock
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
def test_run_batch_serial_path_calls_per_request():
# Mock per-request generator function : returns deterministic token IDs.
def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
return GenerationResult(
token_ids=list(range(10, 10 + min(max_tokens, 5))),
text=f"text-seed-{seed}",
generated_tokens=min(max_tokens, 5),
finish_reason="eos",
elapsed_s=0.01,
)
requests = [
GenerationRequest(prompt_token_ids=[1, 2, 3], seed=42, max_tokens=5),
GenerationRequest(prompt_token_ids=[1, 2, 3], seed=43, max_tokens=5),
]
results = run_batch(
model=MagicMock(),
tokenizer=MagicMock(),
requests=requests,
temperature=1.0,
top_p=0.95,
top_k=-1,
serial=True,
single_generate=fake_gen,
)
assert len(results) == 2
assert results[0].text == "text-seed-42"
assert results[1].text == "text-seed-43"
assert all(r.generated_tokens == 5 for r in results)
assert all(r.finish_reason == "eos" for r in results)
def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
# Patch BatchGenerator-like callable : returns N results in one shot.
def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
return [
GenerationResult(
token_ids=[10, 11, 12],
text=f"batched-{r.seed}",
generated_tokens=3,
finish_reason="max_tokens",
elapsed_s=0.02,
)
for r in requests
]
requests = [
GenerationRequest(prompt_token_ids=[1], seed=1, max_tokens=3),
GenerationRequest(prompt_token_ids=[1], seed=2, max_tokens=3),
GenerationRequest(prompt_token_ids=[1], seed=3, max_tokens=3),
]
results = run_batch(
model=MagicMock(),
tokenizer=MagicMock(),
requests=requests,
temperature=1.0,
top_p=1.0,
top_k=-1,
serial=False,
batch_generate=fake_batch_gen,
)
assert len(results) == 3
assert results[0].text == "batched-1"
assert results[2].text == "batched-3"
def test_run_batch_single_request_uses_single_path():
def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
return GenerationResult(token_ids=[1], text="single", generated_tokens=1, finish_reason="eos", elapsed_s=0.0)
requests = [GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)]
results = run_batch(
model=MagicMock(),
tokenizer=MagicMock(),
requests=requests,
temperature=1.0,
top_p=0.95,
top_k=-1,
serial=False, # batched-by-default but N=1 → serial under hood
single_generate=fake_gen,
batch_generate=None, # not provided
)
assert results[0].text == "single"