feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)
This commit is contained in:
165
src/markovian_rsa_mlx/batching.py
Normal file
165
src/markovian_rsa_mlx/batching.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Thin abstraction over mlx-lm generation primitives.
|
||||
|
||||
Exports :
|
||||
- GenerationRequest / GenerationResult dataclasses with everything an audit
|
||||
event needs.
|
||||
- run_batch(...) : dispatches between serial and batched paths.
|
||||
|
||||
The default `single_generate` and `batch_generate` callables resolve
|
||||
mlx-lm's primitives lazily so that `import markovian_rsa_mlx.batching`
|
||||
doesn't pull mlx-lm at module load time (useful for unit tests with mocks).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationRequest:
|
||||
prompt_token_ids: list[int]
|
||||
seed: int
|
||||
max_tokens: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationResult:
|
||||
token_ids: list[int] # full output (excluding prompt)
|
||||
text: str # decoded output
|
||||
generated_tokens: int
|
||||
finish_reason: Literal["eos", "max_tokens", "error"]
|
||||
elapsed_s: float
|
||||
|
||||
|
||||
SingleGenerateFn = Callable[..., GenerationResult]
|
||||
BatchGenerateFn = Callable[..., list[GenerationResult]]
|
||||
|
||||
|
||||
def _default_single_generate(
|
||||
model: Any, tokenizer: Any, prompt_token_ids: list[int], *,
|
||||
max_tokens: int, seed: int, temperature: float, top_p: float, top_k: int,
|
||||
) -> GenerationResult:
|
||||
"""Real mlx-lm single-prompt generation. Imported lazily."""
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import generate
|
||||
|
||||
mx.random.seed(seed)
|
||||
t0 = time.time()
|
||||
text = generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_token_ids,
|
||||
max_tokens=max_tokens,
|
||||
sampler=_make_sampler(temperature, top_p, top_k),
|
||||
verbose=False,
|
||||
)
|
||||
elapsed = time.time() - t0
|
||||
out_ids = tokenizer.encode(text)
|
||||
finish = "eos" if (out_ids and out_ids[-1] in _eos_ids(tokenizer)) else "max_tokens"
|
||||
return GenerationResult(
|
||||
token_ids=out_ids,
|
||||
text=text,
|
||||
generated_tokens=len(out_ids),
|
||||
finish_reason=finish,
|
||||
elapsed_s=elapsed,
|
||||
)
|
||||
|
||||
|
||||
def _default_batch_generate(
|
||||
model: Any, tokenizer: Any, requests: list[GenerationRequest], *,
|
||||
temperature: float, top_p: float, top_k: int,
|
||||
) -> list[GenerationResult]:
|
||||
"""Real mlx-lm batched generation via BatchGenerator. Imported lazily."""
|
||||
import mlx.core as mx
|
||||
try:
|
||||
from mlx_lm.batch_generate import BatchGenerator
|
||||
except ImportError as e:
|
||||
raise RuntimeError(
|
||||
"mlx_lm.batch_generate not available — install kyr0/mlx-lm fork "
|
||||
"(feat/zaya-support) or pass serial=True"
|
||||
) from e
|
||||
|
||||
sampler = _make_sampler(temperature, top_p, top_k)
|
||||
gen = BatchGenerator(model, tokenizer, sampler=sampler)
|
||||
t0 = time.time()
|
||||
raw = gen.generate(
|
||||
prompts=[r.prompt_token_ids for r in requests],
|
||||
max_tokens=[r.max_tokens for r in requests],
|
||||
seeds=[r.seed for r in requests],
|
||||
)
|
||||
elapsed = time.time() - t0
|
||||
per_request = elapsed / max(len(requests), 1)
|
||||
results: list[GenerationResult] = []
|
||||
eos = _eos_ids(tokenizer)
|
||||
for req_idx, item in enumerate(raw):
|
||||
token_ids = list(item.token_ids)
|
||||
text = tokenizer.decode(token_ids) if hasattr(tokenizer, "decode") else item.text
|
||||
finish = "eos" if (token_ids and token_ids[-1] in eos) else "max_tokens"
|
||||
results.append(GenerationResult(
|
||||
token_ids=token_ids,
|
||||
text=text,
|
||||
generated_tokens=len(token_ids),
|
||||
finish_reason=finish,
|
||||
elapsed_s=per_request,
|
||||
))
|
||||
return results
|
||||
|
||||
|
||||
def _make_sampler(temperature: float, top_p: float, top_k: int):
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
return make_sampler(temp=temperature, top_p=top_p, top_k=top_k if top_k > 0 else 0)
|
||||
|
||||
|
||||
def _eos_ids(tokenizer: Any) -> set[int]:
|
||||
ids: set[int] = set()
|
||||
eos_id = getattr(tokenizer, "eos_token_id", None)
|
||||
if isinstance(eos_id, int):
|
||||
ids.add(eos_id)
|
||||
extra = getattr(tokenizer, "all_special_ids", None) or []
|
||||
for x in extra:
|
||||
if isinstance(x, int):
|
||||
ids.add(x)
|
||||
return ids
|
||||
|
||||
|
||||
def run_batch(
|
||||
model: Any,
|
||||
tokenizer: Any,
|
||||
requests: list[GenerationRequest],
|
||||
*,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
serial: bool,
|
||||
single_generate: SingleGenerateFn | None = None,
|
||||
batch_generate: BatchGenerateFn | None = None,
|
||||
) -> list[GenerationResult]:
|
||||
"""Run N generation requests. Use batched path unless serial=True or N==1."""
|
||||
sg = single_generate or _default_single_generate
|
||||
bg = batch_generate or _default_batch_generate
|
||||
n = len(requests)
|
||||
|
||||
if n == 0:
|
||||
return []
|
||||
if serial or n == 1 or bg is None:
|
||||
return [
|
||||
sg(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt_token_ids=r.prompt_token_ids,
|
||||
max_tokens=r.max_tokens,
|
||||
seed=r.seed,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
for r in requests
|
||||
]
|
||||
return bg(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requests=requests,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
86
tests/test_batching.py
Normal file
86
tests/test_batching.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from unittest.mock import MagicMock
|
||||
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
|
||||
|
||||
|
||||
def test_run_batch_serial_path_calls_per_request():
|
||||
# Mock per-request generator function : returns deterministic token IDs.
|
||||
def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
|
||||
return GenerationResult(
|
||||
token_ids=list(range(10, 10 + min(max_tokens, 5))),
|
||||
text=f"text-seed-{seed}",
|
||||
generated_tokens=min(max_tokens, 5),
|
||||
finish_reason="eos",
|
||||
elapsed_s=0.01,
|
||||
)
|
||||
|
||||
requests = [
|
||||
GenerationRequest(prompt_token_ids=[1, 2, 3], seed=42, max_tokens=5),
|
||||
GenerationRequest(prompt_token_ids=[1, 2, 3], seed=43, max_tokens=5),
|
||||
]
|
||||
results = run_batch(
|
||||
model=MagicMock(),
|
||||
tokenizer=MagicMock(),
|
||||
requests=requests,
|
||||
temperature=1.0,
|
||||
top_p=0.95,
|
||||
top_k=-1,
|
||||
serial=True,
|
||||
single_generate=fake_gen,
|
||||
)
|
||||
assert len(results) == 2
|
||||
assert results[0].text == "text-seed-42"
|
||||
assert results[1].text == "text-seed-43"
|
||||
assert all(r.generated_tokens == 5 for r in results)
|
||||
assert all(r.finish_reason == "eos" for r in results)
|
||||
|
||||
|
||||
def test_run_batch_batched_path_uses_batch_generate(monkeypatch):
|
||||
# Patch BatchGenerator-like callable : returns N results in one shot.
|
||||
def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
|
||||
return [
|
||||
GenerationResult(
|
||||
token_ids=[10, 11, 12],
|
||||
text=f"batched-{r.seed}",
|
||||
generated_tokens=3,
|
||||
finish_reason="max_tokens",
|
||||
elapsed_s=0.02,
|
||||
)
|
||||
for r in requests
|
||||
]
|
||||
|
||||
requests = [
|
||||
GenerationRequest(prompt_token_ids=[1], seed=1, max_tokens=3),
|
||||
GenerationRequest(prompt_token_ids=[1], seed=2, max_tokens=3),
|
||||
GenerationRequest(prompt_token_ids=[1], seed=3, max_tokens=3),
|
||||
]
|
||||
results = run_batch(
|
||||
model=MagicMock(),
|
||||
tokenizer=MagicMock(),
|
||||
requests=requests,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
top_k=-1,
|
||||
serial=False,
|
||||
batch_generate=fake_batch_gen,
|
||||
)
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "batched-1"
|
||||
assert results[2].text == "batched-3"
|
||||
|
||||
|
||||
def test_run_batch_single_request_uses_single_path():
|
||||
def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
|
||||
return GenerationResult(token_ids=[1], text="single", generated_tokens=1, finish_reason="eos", elapsed_s=0.0)
|
||||
requests = [GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)]
|
||||
results = run_batch(
|
||||
model=MagicMock(),
|
||||
tokenizer=MagicMock(),
|
||||
requests=requests,
|
||||
temperature=1.0,
|
||||
top_p=0.95,
|
||||
top_k=-1,
|
||||
serial=False, # batched-by-default but N=1 → serial under hood
|
||||
single_generate=fake_gen,
|
||||
batch_generate=None, # not provided
|
||||
)
|
||||
assert results[0].text == "single"
|
||||
Reference in New Issue
Block a user