diff --git a/src/markovian_rsa_mlx/__init__.py b/src/markovian_rsa_mlx/__init__.py index e5ae407..a8c3d37 100644 --- a/src/markovian_rsa_mlx/__init__.py +++ b/src/markovian_rsa_mlx/__init__.py @@ -3,5 +3,6 @@ __version__ = "0.1.0" from markovian_rsa_mlx.config import RSAConfig from markovian_rsa_mlx.loader import load_zaya_model +from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator -__all__ = ["__version__", "RSAConfig", "load_zaya_model"] +__all__ = ["__version__", "RSAConfig", "load_zaya_model", "MarkovianRSAOrchestrator"] diff --git a/src/markovian_rsa_mlx/orchestrator.py b/src/markovian_rsa_mlx/orchestrator.py new file mode 100644 index 0000000..48824d6 --- /dev/null +++ b/src/markovian_rsa_mlx/orchestrator.py @@ -0,0 +1,235 @@ +"""MarkovianRSAOrchestrator — drives N parallel traces + aggregation rounds.""" +from __future__ import annotations +import datetime as _dt +import hashlib +import time +import uuid +from pathlib import Path +from typing import Any + +from markovian_rsa_mlx.audit import ( + AuditWriter, RunStartEvent, GenerationStartEvent, + TraceCompleteEvent, TailExtractedEvent, AggregationPromptEvent, + RoundCompleteEvent, FinalEvent, RunEndEvent, +) +from markovian_rsa_mlx.batching import GenerationRequest, run_batch, GenerationResult +from markovian_rsa_mlx.config import RSAConfig +from markovian_rsa_mlx.prompts import ( + build_round_0_messages, + build_aggregation_messages, +) +from markovian_rsa_mlx.results import RSAResult, RSARound, RSAStats, TraceRecord + + +def _trace_seed(base_seed: int | None, run_id: str, round_index: int, trace_index: int) -> int: + """Deterministic seed if base_seed set, else stable from run_id.""" + key = f"{base_seed}|{run_id}|{round_index}|{trace_index}" + h = hashlib.sha256(key.encode()).hexdigest() + return int(h[:8], 16) + + +def _now_iso() -> str: + return _dt.datetime.now(tz=_dt.timezone.utc).isoformat().replace("+00:00", "Z") + + +class MarkovianRSAOrchestrator: + """Drives Markovian RSA rounds over a loaded mlx-lm model + tokenizer.""" + + def __init__( + self, + model: Any, + tokenizer: Any, + *, + model_id: str = "kyr0/zaya1-base-8b-MLX", + quantization: str = "q4_g64", + default_config: RSAConfig | None = None, + single_generate=None, + batch_generate=None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.model_id = model_id + self.quantization = quantization + self.default_config = default_config or RSAConfig() + self._single_generate = single_generate + self._batch_generate = batch_generate + + @classmethod + def from_pretrained( + cls, + model_id: str = "kyr0/zaya1-base-8b-MLX", + *, + quantization: str = "q4_g64", + default_config: RSAConfig | None = None, + ) -> "MarkovianRSAOrchestrator": + from markovian_rsa_mlx.loader import load_zaya_model + model, tokenizer = load_zaya_model(model_id) + return cls( + model=model, tokenizer=tokenizer, + model_id=model_id, quantization=quantization, + default_config=default_config, + ) + + def solve( + self, + prompt: str, + *, + config: RSAConfig | None = None, + return_audit: bool = False, + audit_path: str | Path | None = None, + ): + cfg = config or self.default_config + run_id = uuid.uuid4().hex[:12] + t0 = time.time() + with AuditWriter(audit_path) as aud: + aud.write(RunStartEvent( + run_id=run_id, model_id=self.model_id, quantization=self.quantization, + config=cfg, prompt=prompt, created_at=_now_iso(), + )) + rounds_records: list[RSARound] = [] + previous_traces: list[TraceRecord] = [] + for round_idx in range(cfg.rounds): + round_traces, round_elapsed = self._run_round( + run_id=run_id, round_idx=round_idx, original_prompt=prompt, + previous_traces=previous_traces, cfg=cfg, audit=aud, + ) + rounds_records.append(RSARound( + round=round_idx, traces=round_traces, elapsed_s=round_elapsed, + memory_estimate_bytes=0, + )) + aud.write(RoundCompleteEvent( + run_id=run_id, round=round_idx, + trace_ids=[t.trace_id for t in round_traces], + memory_estimate_bytes=0, elapsed_s=round_elapsed, + )) + previous_traces = round_traces + + final_trace = previous_traces[0] + aud.write(FinalEvent( + run_id=run_id, final_trace_id=final_trace.trace_id, + final_text=final_trace.text, + all_final_trace_ids=[t.trace_id for t in previous_traces], + answer_selection=cfg.answer_selection, + )) + elapsed = time.time() - t0 + total_tokens = sum( + t.generated_tokens for r in rounds_records for t in r.traces + ) + aud.write(RunEndEvent( + run_id=run_id, elapsed_s=elapsed, + total_generated_tokens=total_tokens, peak_memory_bytes=0, + )) + + result = RSAResult( + run_id=run_id, prompt=prompt, final_text=final_trace.text, + final_trace_id=final_trace.trace_id, model_id=self.model_id, + quantization=self.quantization, config=cfg, rounds=rounds_records, + stats=RSAStats( + total_generated_tokens=total_tokens, elapsed_s=elapsed, + peak_memory_bytes=0, + ), + audit_path=Path(audit_path) if audit_path is not None else None, + ) + if return_audit: + return result.final_text, result + return result.final_text + + def _run_round( + self, *, run_id: str, round_idx: int, original_prompt: str, + previous_traces: list[TraceRecord], cfg: RSAConfig, audit: AuditWriter, + ) -> tuple[list[TraceRecord], float]: + round_t0 = time.time() + is_round_0 = round_idx == 0 + max_tokens = cfg.chunk_tokens if round_idx < cfg.rounds - 1 else cfg.effective_final_tokens() + + prompts_token_ids: list[list[int]] = [] + parent_ids_per_trace: list[list[str]] = [] + if is_round_0: + messages = build_round_0_messages(original_prompt) + prompt_ids = self._render_chat(messages) + prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)] + parent_ids_per_trace = [[] for _ in range(cfg.parallel)] + else: + import random as _random + rng = _random.Random(_trace_seed(cfg.seed, run_id, round_idx, -1)) + for trace_idx in range(cfg.parallel): + K = min(cfg.aggregation_subsample, len(previous_traces)) + selected = rng.sample(previous_traces, K) + tails = [self._extract_tail_text(t.token_ids, cfg.tail_tokens) for t in selected] + tail_token_ids_list = [self._extract_tail_token_ids(t.token_ids, cfg.tail_tokens) for t in selected] + for sel, tail_ids, tail_text in zip(selected, tail_token_ids_list, tails): + audit.write(TailExtractedEvent( + run_id=run_id, round=round_idx, trace_id=sel.trace_id, + tail_token_ids=tail_ids, tail_text=tail_text, + tail_tokens=len(tail_ids), + )) + messages = build_aggregation_messages( + original_prompt=original_prompt, tails=tails, + template=cfg.aggregation_template, + ) + prompt_ids = self._render_chat(messages) + child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}" + audit.write(AggregationPromptEvent( + run_id=run_id, round=round_idx, trace_id=child_trace_id, + selected_tail_trace_ids=[s.trace_id for s in selected], + prompt_text=messages[0]["content"], prompt_token_ids=prompt_ids, + )) + prompts_token_ids.append(prompt_ids) + parent_ids_per_trace.append([s.trace_id for s in selected]) + + seeds = [_trace_seed(cfg.seed, run_id, round_idx, i) for i in range(cfg.parallel)] + trace_ids = [f"r{round_idx}-t{i}-{run_id[:6]}" for i in range(cfg.parallel)] + for i, tid in enumerate(trace_ids): + audit.write(GenerationStartEvent( + run_id=run_id, round=round_idx, trace_id=tid, + seed=seeds[i], prompt_token_count=len(prompts_token_ids[i]), + max_tokens=max_tokens, parent_trace_ids=parent_ids_per_trace[i], + )) + + requests = [ + GenerationRequest(prompt_token_ids=prompts_token_ids[i], seed=seeds[i], max_tokens=max_tokens) + for i in range(cfg.parallel) + ] + results: list[GenerationResult] = run_batch( + model=self.model, tokenizer=self.tokenizer, + requests=requests, temperature=cfg.temperature, top_p=cfg.top_p, top_k=cfg.top_k, + serial=cfg.serial, single_generate=self._single_generate, batch_generate=self._batch_generate, + ) + + records: list[TraceRecord] = [] + for i, (tid, gen) in enumerate(zip(trace_ids, results)): + audit.write(TraceCompleteEvent( + run_id=run_id, round=round_idx, trace_id=tid, + text=gen.text, token_ids=gen.token_ids, + generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason, + elapsed_s=gen.elapsed_s, + )) + records.append(TraceRecord( + trace_id=tid, text=gen.text, token_ids=gen.token_ids, + generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason, + elapsed_s=gen.elapsed_s, seed=seeds[i], + parent_trace_ids=parent_ids_per_trace[i], + )) + round_elapsed = time.time() - round_t0 + return records, round_elapsed + + def _render_chat(self, messages: list[dict[str, str]]) -> list[int]: + """Apply ZAYA chat template and return token ids.""" + rendered = self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, enable_thinking=True, + ) + if isinstance(rendered, str): + return self.tokenizer.encode(rendered) + return list(rendered) + + @staticmethod + def _extract_tail_token_ids(ids: list[int], tail_tokens: int) -> list[int]: + if tail_tokens <= 0 or not ids: + return [] + return ids[-tail_tokens:] + + def _extract_tail_text(self, ids: list[int], tail_tokens: int) -> str: + tail_ids = self._extract_tail_token_ids(ids, tail_tokens) + if not tail_ids: + return "" + return self.tokenizer.decode(tail_ids) diff --git a/src/markovian_rsa_mlx/results.py b/src/markovian_rsa_mlx/results.py new file mode 100644 index 0000000..07f384e --- /dev/null +++ b/src/markovian_rsa_mlx/results.py @@ -0,0 +1,47 @@ +"""Public result types returned by MarkovianRSAOrchestrator.""" +from __future__ import annotations +from dataclasses import dataclass, field +from pathlib import Path + +from markovian_rsa_mlx.config import RSAConfig + + +@dataclass +class TraceRecord: + trace_id: str + text: str + token_ids: list[int] + generated_tokens: int + finish_reason: str + elapsed_s: float + seed: int + parent_trace_ids: list[str] = field(default_factory=list) + + +@dataclass +class RSARound: + round: int + traces: list[TraceRecord] + elapsed_s: float + memory_estimate_bytes: int + + +@dataclass +class RSAStats: + total_generated_tokens: int + elapsed_s: float + peak_memory_bytes: int + + +@dataclass +class RSAResult: + run_id: str + prompt: str + final_text: str + final_trace_id: str + model_id: str + quantization: str + config: RSAConfig + rounds: list[RSARound] + stats: RSAStats + audit_path: Path | None diff --git a/tests/test_orchestrator_t1.py b/tests/test_orchestrator_t1.py new file mode 100644 index 0000000..1ad71c4 --- /dev/null +++ b/tests/test_orchestrator_t1.py @@ -0,0 +1,60 @@ +from unittest.mock import MagicMock + +from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult +from markovian_rsa_mlx.config import RSAConfig +from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator + + +def _fake_tokenizer(eos_id: int = 999): + tok = MagicMock() + tok.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1] + tok.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122) + tok.eos_token_id = eos_id + tok.all_special_ids = [eos_id] + tok.apply_chat_template.side_effect = lambda messages, **kw: \ + " ".join(m["content"] for m in messages).encode().hex() + return tok + + +def _fake_single_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k): + text = f"trace-{seed}-final-answer" + ids = [ord(c) for c in text] + return GenerationResult( + token_ids=ids, text=text, generated_tokens=len(ids), + finish_reason="eos", elapsed_s=0.01, + ) + + +def test_t1_single_round_produces_final_text(tmp_path): + cfg = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, + chunk_tokens=64, tail_tokens=8, serial=True, seed=123) + orch = MarkovianRSAOrchestrator( + model=MagicMock(), + tokenizer=_fake_tokenizer(), + model_id="test-model", + quantization="bf16", + single_generate=_fake_single_gen, + batch_generate=None, + ) + audit_path = tmp_path / "audit.jsonl" + text, result = orch.solve("What is 2+2?", config=cfg, return_audit=True, audit_path=audit_path) + assert isinstance(text, str) + assert text == result.final_text + assert result.config.rounds == 1 + assert len(result.rounds) == 1 + assert len(result.rounds[0].traces) == 2 + assert audit_path.exists() + lines = audit_path.read_text().strip().split("\n") + # at minimum: run_start, 2 trace_complete, final, run_end + assert len(lines) >= 5 + + +def test_t1_returns_string_when_return_audit_false(tmp_path): + cfg = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, serial=True) + orch = MarkovianRSAOrchestrator( + model=MagicMock(), tokenizer=_fake_tokenizer(), + model_id="m", quantization="bf16", + single_generate=_fake_single_gen, + ) + out = orch.solve("X", config=cfg) + assert isinstance(out, str)