From 4b55163a5c3ff1b3c8281b3ac24468ca3a629393 Mon Sep 17 00:00:00 2001
From: transcrilive
Date: Sun, 10 May 2026 02:54:58 +0200
Subject: [PATCH] feat(batching): GenerationRequest/Result + run_batch dispatch (serial vs batched)

---
 src/markovian_rsa_mlx/batching.py | 199 ++++++++++++++++++++++++++++++
 tests/test_batching.py            | 101 +++++++++++++++
 2 files changed, 300 insertions(+)
 create mode 100644 src/markovian_rsa_mlx/batching.py
 create mode 100644 tests/test_batching.py

diff --git a/src/markovian_rsa_mlx/batching.py b/src/markovian_rsa_mlx/batching.py
new file mode 100644
index 0000000..851d87b
--- /dev/null
+++ b/src/markovian_rsa_mlx/batching.py
@@ -0,0 +1,199 @@
+"""Thin abstraction over mlx-lm generation primitives.
+
+Exports :
+- GenerationRequest / GenerationResult dataclasses with everything an audit
+  event needs.
+- run_batch(...) : dispatches between serial and batched paths.
+
+The default `single_generate` and `batch_generate` callables resolve
+mlx-lm's primitives lazily so that `import markovian_rsa_mlx.batching`
+doesn't pull mlx-lm at module load time (useful for unit tests with mocks).
+"""
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import Any, Callable, Literal
+
+
+@dataclass(frozen=True)
+class GenerationRequest:
+    """One generation job: prompt token ids, sampling seed and token budget."""
+
+    prompt_token_ids: list[int]
+    seed: int
+    max_tokens: int
+
+
+@dataclass(frozen=True)
+class GenerationResult:
+    """Outcome of one generation job, as returned by the generate callables."""
+
+    token_ids: list[int]  # full output (excluding prompt)
+    text: str  # decoded output
+    generated_tokens: int
+    finish_reason: Literal["eos", "max_tokens", "error"]
+    elapsed_s: float
+
+
+SingleGenerateFn = Callable[..., GenerationResult]
+BatchGenerateFn = Callable[..., list[GenerationResult]]
+
+
+def _default_single_generate(
+    model: Any, tokenizer: Any, prompt_token_ids: list[int], *,
+    max_tokens: int, seed: int, temperature: float, top_p: float, top_k: int,
+) -> GenerationResult:
+    """Real mlx-lm single-prompt generation. Imported lazily.
+
+    The generation is streamed so that the sampled token ids and the stop
+    condition come straight from mlx-lm. Re-encoding the decoded text does not
+    round-trip: the tokenizer may insert special tokens, and the EOS token is
+    not part of the decoded text, so an EOS stop could not be detected.
+    """
+    import mlx.core as mx
+    from mlx_lm.generate import stream_generate
+
+    mx.random.seed(seed)
+    token_ids: list[int] = []
+    pieces: list[str] = []
+    stop_reason = None
+    t0 = time.perf_counter()
+    for response in stream_generate(
+        model,
+        tokenizer,
+        prompt_token_ids,
+        max_tokens=max_tokens,
+        sampler=_make_sampler(temperature, top_p, top_k),
+    ):
+        token_ids.append(int(response.token))
+        pieces.append(response.text)
+        stop_reason = response.finish_reason
+    elapsed = time.perf_counter() - t0
+    return GenerationResult(
+        token_ids=token_ids,
+        text="".join(pieces),
+        generated_tokens=len(token_ids),
+        # mlx-lm reports "stop" when an EOS token ended the generation.
+        finish_reason="eos" if stop_reason == "stop" else "max_tokens",
+        elapsed_s=elapsed,
+    )
+
+
+def _default_batch_generate(
+    model: Any, tokenizer: Any, requests: list[GenerationRequest], *,
+    temperature: float, top_p: float, top_k: int,
+) -> list[GenerationResult]:
+    """Real mlx-lm batched generation via BatchGenerator. Imported lazily.
+
+    `elapsed_s` is the batch wall time divided evenly across the requests,
+    i.e. an average, not a per-request measurement.
+
+    Raises:
+        RuntimeError: if `mlx_lm.batch_generate` cannot be imported.
+    """
+    try:
+        from mlx_lm.batch_generate import BatchGenerator
+    except ImportError as e:
+        raise RuntimeError(
+            "mlx_lm.batch_generate not available — install kyr0/mlx-lm fork "
+            "(feat/zaya-support) or pass serial=True"
+        ) from e
+
+    sampler = _make_sampler(temperature, top_p, top_k)
+    gen = BatchGenerator(model, tokenizer, sampler=sampler)
+    t0 = time.perf_counter()
+    raw = gen.generate(
+        prompts=[r.prompt_token_ids for r in requests],
+        max_tokens=[r.max_tokens for r in requests],
+        seeds=[r.seed for r in requests],
+    )
+    elapsed = time.perf_counter() - t0
+    per_request = elapsed / max(len(requests), 1)
+    eos = _eos_ids(tokenizer)
+    can_decode = hasattr(tokenizer, "decode")
+    results: list[GenerationResult] = []
+    for item in raw:
+        token_ids = list(item.token_ids)
+        text = tokenizer.decode(token_ids) if can_decode else item.text
+        finish = "eos" if (token_ids and token_ids[-1] in eos) else "max_tokens"
+        results.append(GenerationResult(
+            token_ids=token_ids,
+            text=text,
+            generated_tokens=len(token_ids),
+            finish_reason=finish,
+            elapsed_s=per_request,
+        ))
+    return results
+
+
+def _make_sampler(temperature: float, top_p: float, top_k: int):
+    """Build an mlx-lm sampler; a non-positive `top_k` is passed on as 0."""
+    from mlx_lm.sample_utils import make_sampler
+
+    return make_sampler(temp=temperature, top_p=top_p, top_k=top_k if top_k > 0 else 0)
+
+
+def _eos_ids(tokenizer: Any) -> set[int]:
+    """Collect the token ids treated as end-of-sequence for `tokenizer`.
+
+    Takes `eos_token_id` plus the integer entries of `eos_token_ids` and
+    `all_special_ids` when the tokenizer exposes them. NOTE: this is lenient,
+    any special token id listed there counts as EOS.
+    """
+    ids: set[int] = set()
+    eos_id = getattr(tokenizer, "eos_token_id", None)
+    if isinstance(eos_id, int):
+        ids.add(eos_id)
+    for attr in ("eos_token_ids", "all_special_ids"):
+        for x in getattr(tokenizer, attr, None) or ():
+            if isinstance(x, int):
+                ids.add(x)
+    return ids
+
+
+def run_batch(
+    model: Any,
+    tokenizer: Any,
+    requests: list[GenerationRequest],
+    *,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+    serial: bool,
+    single_generate: SingleGenerateFn | None = None,
+    batch_generate: BatchGenerateFn | None = None,
+) -> list[GenerationResult]:
+    """Run N generation requests. Use batched path unless serial=True or N==1.
+
+    `single_generate` / `batch_generate` override the mlx-lm backed defaults
+    (handy for tests); only the callable for the chosen path is resolved.
+    Returns what the chosen path produces; an empty request list yields [].
+    """
+    if not requests:
+        return []
+    if serial or len(requests) == 1:
+        sg = single_generate or _default_single_generate
+        return [
+            sg(
+                model=model,
+                tokenizer=tokenizer,
+                prompt_token_ids=r.prompt_token_ids,
+                max_tokens=r.max_tokens,
+                seed=r.seed,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+            )
+            for r in requests
+        ]
+    # A missing `batch_generate` always falls back to the default: no None check.
+    bg = batch_generate or _default_batch_generate
+    return bg(
+        model=model,
+        tokenizer=tokenizer,
+        requests=requests,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
diff --git a/tests/test_batching.py b/tests/test_batching.py
new file mode 100644
index 0000000..582689d
--- /dev/null
+++ b/tests/test_batching.py
@@ -0,0 +1,101 @@
+from unittest.mock import MagicMock
+
+from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult, run_batch
+
+
+def test_run_batch_serial_path_calls_per_request():
+    # Mock per-request generator function : returns deterministic token IDs.
+    def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
+        return GenerationResult(
+            token_ids=list(range(10, 10 + min(max_tokens, 5))),
+            text=f"text-seed-{seed}",
+            generated_tokens=min(max_tokens, 5),
+            finish_reason="eos",
+            elapsed_s=0.01,
+        )
+
+    requests = [
+        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=42, max_tokens=5),
+        GenerationRequest(prompt_token_ids=[1, 2, 3], seed=43, max_tokens=5),
+    ]
+    results = run_batch(
+        model=MagicMock(),
+        tokenizer=MagicMock(),
+        requests=requests,
+        temperature=1.0,
+        top_p=0.95,
+        top_k=-1,
+        serial=True,
+        single_generate=fake_gen,
+    )
+    assert len(results) == 2
+    assert results[0].text == "text-seed-42"
+    assert results[1].text == "text-seed-43"
+    assert all(r.generated_tokens == 5 for r in results)
+    assert all(r.finish_reason == "eos" for r in results)
+
+
+def test_run_batch_batched_path_uses_batch_generate():
+    # Patch BatchGenerator-like callable : returns N results in one shot.
+    def fake_batch_gen(model, tokenizer, requests, *, temperature, top_p, top_k):
+        return [
+            GenerationResult(
+                token_ids=[10, 11, 12],
+                text=f"batched-{r.seed}",
+                generated_tokens=3,
+                finish_reason="max_tokens",
+                elapsed_s=0.02,
+            )
+            for r in requests
+        ]
+
+    requests = [
+        GenerationRequest(prompt_token_ids=[1], seed=1, max_tokens=3),
+        GenerationRequest(prompt_token_ids=[1], seed=2, max_tokens=3),
+        GenerationRequest(prompt_token_ids=[1], seed=3, max_tokens=3),
+    ]
+    results = run_batch(
+        model=MagicMock(),
+        tokenizer=MagicMock(),
+        requests=requests,
+        temperature=1.0,
+        top_p=1.0,
+        top_k=-1,
+        serial=False,
+        batch_generate=fake_batch_gen,
+    )
+    assert len(results) == 3
+    assert results[0].text == "batched-1"
+    assert results[2].text == "batched-3"
+
+
+def test_run_batch_single_request_uses_single_path():
+    def fake_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
+        return GenerationResult(token_ids=[1], text="single", generated_tokens=1, finish_reason="eos", elapsed_s=0.0)
+
+    requests = [GenerationRequest(prompt_token_ids=[42], seed=0, max_tokens=1)]
+    results = run_batch(
+        model=MagicMock(),
+        tokenizer=MagicMock(),
+        requests=requests,
+        temperature=1.0,
+        top_p=0.95,
+        top_k=-1,
+        serial=False,  # batched-by-default but N=1 → serial under hood
+        single_generate=fake_gen,
+        batch_generate=None,  # not provided
+    )
+    assert results[0].text == "single"
+
+
+def test_run_batch_empty_requests_returns_empty_list():
+    results = run_batch(
+        model=MagicMock(),
+        tokenizer=MagicMock(),
+        requests=[],
+        temperature=1.0,
+        top_p=1.0,
+        top_k=-1,
+        serial=False,
+    )
+    assert results == []