feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)
This commit is contained in:
165
src/markovian_rsa_mlx/batching.py
Normal file
165
src/markovian_rsa_mlx/batching.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Thin abstraction over mlx-lm generation primitives.
|
||||
|
||||
Exports:

- GenerationRequest / GenerationResult: dataclasses carrying everything an
  audit event needs.
- run_batch(...): dispatches between the serial and batched paths.

The default `single_generate` and `batch_generate` callables import
mlx-lm's primitives lazily, so that `import markovian_rsa_mlx.batching`
does not pull in mlx-lm at module load time (useful for unit tests with mocks).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GenerationRequest:
    """Immutable description of a single generation job.

    The dataclass is frozen, so assigning to a field after construction
    raises. Field order is part of the positional constructor signature.
    """

    # Prompt, already tokenized into ids.
    prompt_token_ids: list[int]
    # Per-request random seed forwarded to the generation backend.
    seed: int
    # Token budget forwarded to the generation backend as ``max_tokens``.
    max_tokens: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GenerationResult:
    """Immutable outcome of a single generation request.

    The dataclass is frozen, so assigning to a field after construction
    raises. Field order is part of the positional constructor signature.
    """

    # Full generated output as token ids; the prompt is excluded.
    token_ids: list[int]
    # Decoded form of the generated output.
    text: str
    # Count of generated tokens reported for this request.
    generated_tokens: int
    # Why generation stopped.
    finish_reason: Literal["eos", "max_tokens", "error"]
    # Seconds of wall-clock time attributed to this request.
    elapsed_s: float
|
||||
|
||||
|
||||
# Signature of the serial-path callable: any arguments, one GenerationResult.
SingleGenerateFn = Callable[..., GenerationResult]
# Signature of the batched-path callable: any arguments, a list of results.
BatchGenerateFn = Callable[..., list[GenerationResult]]
|
||||
|
||||
|
||||
def _default_single_generate(
    model: Any, tokenizer: Any, prompt_token_ids: list[int], *,
    max_tokens: int, seed: int, temperature: float, top_p: float, top_k: int,
) -> GenerationResult:
    """Generate a completion for one prompt with mlx-lm.

    mlx and mlx-lm are imported inside the function so that importing this
    module does not require them.

    Args:
        model: Model object handed to mlx-lm's ``generate``.
        tokenizer: Tokenizer handed to ``generate``; its ``encode`` method is
            also used to turn the returned text back into token ids.
        prompt_token_ids: Tokenized prompt, passed as ``prompt``.
        max_tokens: Token budget passed to ``generate``.
        seed: Value used to seed ``mx.random`` before generating.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k cutoff; non-positive values are passed on as 0.

    Returns:
        A GenerationResult whose ``elapsed_s`` covers the ``generate`` call.
    """
    import mlx.core as mx
    from mlx_lm.generate import generate

    mx.random.seed(seed)
    sampler = _make_sampler(temperature, top_p, top_k)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()
    text = generate(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt_token_ids,
        max_tokens=max_tokens,
        sampler=sampler,
        verbose=False,
    )
    elapsed = time.perf_counter() - started

    # NOTE(review): the ids are recovered by re-encoding the decoded text, not
    # taken from the sampled tokens. They may therefore differ from what the
    # model emitted (e.g. if encode() adds special tokens or the text carries
    # no EOS marker), which also affects finish_reason — TODO confirm.
    out_ids = tokenizer.encode(text)
    finish = "eos" if (out_ids and out_ids[-1] in _eos_ids(tokenizer)) else "max_tokens"
    return GenerationResult(
        token_ids=out_ids,
        text=text,
        generated_tokens=len(out_ids),
        finish_reason=finish,
        elapsed_s=elapsed,
    )
|
||||
|
||||
|
||||
def _default_batch_generate(
    model: Any, tokenizer: Any, requests: list[GenerationRequest], *,
    temperature: float, top_p: float, top_k: int,
) -> list[GenerationResult]:
    """Generate completions for several requests via mlx-lm's BatchGenerator.

    mlx and mlx-lm are imported inside the function so that importing this
    module does not require them.

    Args:
        model: Model object handed to ``BatchGenerator``.
        tokenizer: Tokenizer handed to ``BatchGenerator``; its ``decode``
            method, when present, is used to produce the result text.
        requests: Requests whose prompts, token budgets and seeds are passed
            to the generator as parallel lists.
        temperature: Sampling temperature shared by all requests.
        top_p: Nucleus-sampling threshold shared by all requests.
        top_k: Top-k cutoff shared by all requests.

    Returns:
        One GenerationResult per item returned by the generator. Each
        ``elapsed_s`` is the total batch time divided by the request count.

    Raises:
        RuntimeError: If ``mlx_lm.batch_generate`` cannot be imported.
    """
    # Not referenced below; kept so that a missing mlx install still surfaces
    # as a plain ImportError instead of the fork hint raised further down.
    import mlx.core as mx  # noqa: F401
    try:
        from mlx_lm.batch_generate import BatchGenerator
    except ImportError as e:
        raise RuntimeError(
            "mlx_lm.batch_generate not available — install kyr0/mlx-lm fork "
            "(feat/zaya-support) or pass serial=True"
        ) from e

    sampler = _make_sampler(temperature, top_p, top_k)
    gen = BatchGenerator(model, tokenizer, sampler=sampler)

    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()
    raw = gen.generate(
        prompts=[r.prompt_token_ids for r in requests],
        max_tokens=[r.max_tokens for r in requests],
        seeds=[r.seed for r in requests],
    )
    elapsed = time.perf_counter() - started

    # The batch is timed as a whole, so each request is charged an equal share.
    per_request = elapsed / max(len(requests), 1)
    eos = _eos_ids(tokenizer)

    # NOTE(review): assumes gen.generate yields one item per request, in
    # request order, each exposing token_ids (and text) — confirm against the
    # fork's BatchGenerator.
    results: list[GenerationResult] = []
    for item in raw:
        token_ids = list(item.token_ids)
        text = tokenizer.decode(token_ids) if hasattr(tokenizer, "decode") else item.text
        finish = "eos" if (token_ids and token_ids[-1] in eos) else "max_tokens"
        results.append(GenerationResult(
            token_ids=token_ids,
            text=text,
            generated_tokens=len(token_ids),
            finish_reason=finish,
            elapsed_s=per_request,
        ))
    return results
|
||||
|
||||
|
||||
def _make_sampler(temperature: float, top_p: float, top_k: int):
    """Build an mlx-lm sampler from the given sampling settings.

    mlx-lm is imported here rather than at module level so that importing
    this module does not require it. A ``top_k`` that is not strictly
    positive is passed to mlx-lm as 0.
    """
    from mlx_lm.sample_utils import make_sampler

    if top_k > 0:
        effective_top_k = top_k
    else:
        effective_top_k = 0
    return make_sampler(temp=temperature, top_p=top_p, top_k=effective_top_k)
|
||||
|
||||
|
||||
def _eos_ids(tokenizer: Any) -> set[int]:
    """Collect the token ids this module treats as end-of-sequence markers.

    The set is built from three optional tokenizer attributes; any that is
    missing or ``None`` is skipped:

    - ``eos_token_id``: a single int, or a list/tuple/set of ints.
    - ``eos_token_ids``: a list/tuple/set of ints (presumably what mlx-lm's
      tokenizer wrapper exposes — confirm against the installed version).
    - ``all_special_ids``: any iterable; its int members are included.

    Non-int members are ignored.

    Args:
        tokenizer: Object to inspect; only the attributes above are read.

    Returns:
        The collected ids, possibly empty.
    """
    collections_ok = (list, tuple, set, frozenset)
    ids: set[int] = set()

    eos_id = getattr(tokenizer, "eos_token_id", None)
    if isinstance(eos_id, int):
        ids.add(eos_id)
    elif isinstance(eos_id, collections_ok):
        # Some tokenizers/configs declare several EOS ids under this name.
        ids.update(x for x in eos_id if isinstance(x, int))

    eos_many = getattr(tokenizer, "eos_token_ids", None)
    if isinstance(eos_many, collections_ok):
        ids.update(x for x in eos_many if isinstance(x, int))

    # NOTE(review): every special id is counted, which presumably includes
    # non-terminating ones such as BOS or PAD — confirm this is intended.
    specials = getattr(tokenizer, "all_special_ids", None) or []
    ids.update(x for x in specials if isinstance(x, int))
    return ids
|
||||
|
||||
|
||||
def run_batch(
    model: Any,
    tokenizer: Any,
    requests: list[GenerationRequest],
    *,
    temperature: float,
    top_p: float,
    top_k: int,
    serial: bool,
    single_generate: SingleGenerateFn | None = None,
    batch_generate: BatchGenerateFn | None = None,
) -> list[GenerationResult]:
    """Run N generation requests, serially or as one batch.

    The batched path is used unless ``serial`` is true or there is exactly
    one request. An empty request list returns an empty list without calling
    either backend.

    Args:
        model: Model object forwarded to the backend callable.
        tokenizer: Tokenizer forwarded to the backend callable.
        requests: Requests to run.
        temperature: Sampling temperature shared by all requests.
        top_p: Nucleus-sampling threshold shared by all requests.
        top_k: Top-k cutoff shared by all requests.
        serial: Force one backend call per request.
        single_generate: Callable for the serial path; defaults to
            ``_default_single_generate``.
        batch_generate: Callable for the batched path; defaults to
            ``_default_batch_generate``.

    Returns:
        On the serial path, one result per request in request order; on the
        batched path, whatever the batch callable returns.
    """
    n = len(requests)
    if n == 0:
        return []

    sampling = {"temperature": temperature, "top_p": top_p, "top_k": top_k}

    if serial or n == 1:
        # Resolve the default only on the path that actually needs it.
        generate_one = single_generate or _default_single_generate
        return [
            generate_one(
                model=model,
                tokenizer=tokenizer,
                prompt_token_ids=r.prompt_token_ids,
                max_tokens=r.max_tokens,
                seed=r.seed,
                **sampling,
            )
            for r in requests
        ]

    # A batch callable is always available here: the default is substituted
    # whenever none is supplied, so no "missing backend" fallback is needed.
    generate_many = batch_generate or _default_batch_generate
    return generate_many(
        model=model,
        tokenizer=tokenizer,
        requests=requests,
        **sampling,
    )
|
||||
Reference in New Issue
Block a user