feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)

This commit is contained in:
transcrilive
2026-05-10 02:54:58 +02:00
parent db710cc157
commit 4b55163a5c
2 changed files with 251 additions and 0 deletions

View File

@@ -0,0 +1,165 @@
"""Thin abstraction over mlx-lm generation primitives.

Exports:
- GenerationRequest / GenerationResult: dataclasses carrying everything an
  audit event needs.
- run_batch(...): dispatches between the serial and batched paths.

The default `single_generate` and `batch_generate` callables import
mlx-lm's primitives lazily, so `import markovian_rsa_mlx.batching`
doesn't pull in mlx-lm at module load time (useful for unit tests with mocks).
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Any, Callable, Literal
@dataclass(frozen=True)
class GenerationRequest:
    """Immutable description of a single generation job.

    Bundles the already-tokenized prompt with the per-request seed and the
    cap on how many tokens may be generated for it.
    """

    # Prompt as token ids; handed to the generation backend as-is.
    prompt_token_ids: list[int]
    # Seed associated with this request.
    seed: int
    # Upper bound on the number of tokens to generate.
    max_tokens: int
@dataclass(frozen=True)
class GenerationResult:
    """Immutable outcome of one generation request."""

    # Every generated token id; the prompt is not included.
    token_ids: list[int]
    # The generated output decoded to text.
    text: str
    # Count of generated tokens.
    generated_tokens: int
    # Why generation ended; one of "eos", "max_tokens" or "error".
    finish_reason: Literal["eos", "max_tokens", "error"]
    # Time attributed to this request, in seconds.
    elapsed_s: float
# Callable handling one request at a time. run_batch invokes it with the
# keyword arguments model, tokenizer, prompt_token_ids, max_tokens, seed,
# temperature, top_p and top_k, and expects a single GenerationResult back.
SingleGenerateFn = Callable[..., GenerationResult]
# Callable handling a whole list of requests at once. run_batch invokes it
# with the keyword arguments model, tokenizer, requests, temperature, top_p
# and top_k, and expects a list of GenerationResult back.
BatchGenerateFn = Callable[..., list[GenerationResult]]
def _default_single_generate(
    model: Any, tokenizer: Any, prompt_token_ids: list[int], *,
    max_tokens: int, seed: int, temperature: float, top_p: float, top_k: int,
) -> GenerationResult:
    """Generate a completion for one prompt with mlx-lm (imported lazily).

    The generated token ids are collected directly from mlx-lm's
    ``stream_generate`` instead of re-encoding the decoded text. Re-encoding
    is not guaranteed to reproduce the sampled ids (the tokenizer may add
    special tokens or split the text differently), which made ``token_ids``,
    ``generated_tokens`` and ``finish_reason`` unreliable and inconsistent
    with the batched path.

    Args:
        model: Model object forwarded to mlx-lm.
        tokenizer: Tokenizer object forwarded to mlx-lm.
        prompt_token_ids: Prompt as token ids.
        max_tokens: Upper bound on the number of generated tokens.
        seed: Value passed to ``mx.random.seed`` before generating.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff; non-positive values are passed on as 0.

    Returns:
        A GenerationResult whose ``token_ids`` are the sampled ids, whose
        ``text`` is the concatenation of the streamed text segments, and
        whose ``elapsed_s`` is the duration of the generation loop.
    """
    import mlx.core as mx
    from mlx_lm.generate import stream_generate

    mx.random.seed(seed)
    sampler = _make_sampler(temperature, top_p, top_k)

    token_ids: list[int] = []
    segments: list[str] = []
    reported_reason = None
    # perf_counter is monotonic; time.time() can jump if the clock is adjusted.
    t0 = time.perf_counter()
    for response in stream_generate(
        model,
        tokenizer,
        prompt_token_ids,
        max_tokens=max_tokens,
        sampler=sampler,
    ):
        token_ids.append(int(response.token))
        segments.append(response.text)
        reported_reason = getattr(response, "finish_reason", None)
    elapsed = time.perf_counter() - t0

    # mlx-lm reports "stop" (EOS hit) or "length" (max_tokens hit) on the last
    # response; fall back to inspecting the last token if neither is reported.
    if reported_reason == "stop":
        finish = "eos"
    elif reported_reason == "length":
        finish = "max_tokens"
    elif token_ids and token_ids[-1] in _eos_ids(tokenizer):
        finish = "eos"
    else:
        finish = "max_tokens"

    return GenerationResult(
        token_ids=token_ids,
        text="".join(segments),
        generated_tokens=len(token_ids),
        finish_reason=finish,
        elapsed_s=elapsed,
    )
def _default_batch_generate(
    model: Any, tokenizer: Any, requests: list[GenerationRequest], *,
    temperature: float, top_p: float, top_k: int,
) -> list[GenerationResult]:
    """Generate completions for several prompts in one BatchGenerator call.

    mlx-lm is imported lazily. All requests share one sampler built from
    ``temperature``, ``top_p`` and ``top_k``; prompts, token limits and seeds
    are taken per request.

    Args:
        model: Model object forwarded to ``BatchGenerator``.
        tokenizer: Tokenizer object forwarded to ``BatchGenerator``; also
            used to decode the outputs when it has a ``decode`` attribute.
        requests: Requests to run together.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff; non-positive values are passed on as 0.

    Returns:
        One GenerationResult per item returned by ``BatchGenerator.generate``.
        ``elapsed_s`` is the total batch duration divided evenly by the
        number of requests, not a per-request measurement.

    Raises:
        RuntimeError: If ``mlx_lm.batch_generate`` cannot be imported.
    """
    # NOTE(review): `mx` is not used below. The import is kept so that a
    # missing mlx install still surfaces as ImportError before the
    # fork-specific check — confirm whether that is intended.
    import mlx.core as mx
    try:
        from mlx_lm.batch_generate import BatchGenerator
    except ImportError as e:
        raise RuntimeError(
            "mlx_lm.batch_generate not available — install kyr0/mlx-lm fork "
            "(feat/zaya-support) or pass serial=True"
        ) from e

    sampler = _make_sampler(temperature, top_p, top_k)
    gen = BatchGenerator(model, tokenizer, sampler=sampler)

    # perf_counter is monotonic; time.time() can jump if the clock is adjusted.
    t0 = time.perf_counter()
    raw = gen.generate(
        prompts=[r.prompt_token_ids for r in requests],
        max_tokens=[r.max_tokens for r in requests],
        seeds=[r.seed for r in requests],
    )
    elapsed = time.perf_counter() - t0
    # Amortize the batch duration; max(..., 1) guards against an empty list.
    per_request = elapsed / max(len(requests), 1)

    eos = _eos_ids(tokenizer)
    can_decode = hasattr(tokenizer, "decode")
    results: list[GenerationResult] = []
    for item in raw:
        token_ids = list(item.token_ids)
        text = tokenizer.decode(token_ids) if can_decode else item.text
        # A trailing EOS/special token means the model stopped on its own.
        finish = "eos" if (token_ids and token_ids[-1] in eos) else "max_tokens"
        results.append(GenerationResult(
            token_ids=token_ids,
            text=text,
            generated_tokens=len(token_ids),
            finish_reason=finish,
            elapsed_s=per_request,
        ))
    return results
def _make_sampler(temperature: float, top_p: float, top_k: int):
    """Build an mlx-lm sampler from the sampling settings (lazy import).

    A ``top_k`` that is not strictly positive is passed to mlx-lm as 0.
    """
    from mlx_lm.sample_utils import make_sampler

    effective_top_k = top_k
    if not top_k > 0:
        effective_top_k = 0
    return make_sampler(temp=temperature, top_p=top_p, top_k=effective_top_k)
def _eos_ids(tokenizer: Any) -> set[int]:
ids: set[int] = set()
eos_id = getattr(tokenizer, "eos_token_id", None)
if isinstance(eos_id, int):
ids.add(eos_id)
extra = getattr(tokenizer, "all_special_ids", None) or []
for x in extra:
if isinstance(x, int):
ids.add(x)
return ids
def run_batch(
    model: Any,
    tokenizer: Any,
    requests: list[GenerationRequest],
    *,
    temperature: float,
    top_p: float,
    top_k: int,
    serial: bool,
    single_generate: SingleGenerateFn | None = None,
    batch_generate: BatchGenerateFn | None = None,
) -> list[GenerationResult]:
    """Run the given generation requests and return their results.

    The batched path is used unless ``serial`` is true or there is exactly
    one request; in those cases each request is generated on its own, in
    order. An empty request list returns an empty list without calling any
    generator.

    Args:
        model: Model object forwarded to the generate callable.
        tokenizer: Tokenizer object forwarded to the generate callable.
        requests: Requests to run.
        temperature: Sampling temperature, shared by all requests.
        top_p: Nucleus-sampling threshold, shared by all requests.
        top_k: Top-k cutoff, shared by all requests.
        serial: Force the one-request-at-a-time path.
        single_generate: Override for the single-request generator; the
            module default is used when omitted.
        batch_generate: Override for the batched generator; the module
            default is used when omitted.

    Returns:
        What the selected callable produced: a list with one entry per
        request on the serial path, or the batched callable's return value.
    """
    n = len(requests)
    if n == 0:
        return []

    if serial or n == 1:
        # Resolve the default only on the path that needs it.
        sg = single_generate or _default_single_generate
        return [
            sg(
                model=model,
                tokenizer=tokenizer,
                prompt_token_ids=r.prompt_token_ids,
                max_tokens=r.max_tokens,
                seed=r.seed,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            for r in requests
        ]

    bg = batch_generate or _default_batch_generate
    return bg(
        model=model,
        tokenizer=tokenizer,
        requests=requests,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
    )