feat(orchestrator): add T=1 path with audit JSONL + tail extraction
This commit is contained in:
@@ -3,5 +3,6 @@ __version__ = "0.1.0"
|
|||||||
|
|
||||||
from markovian_rsa_mlx.config import RSAConfig
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
from markovian_rsa_mlx.loader import load_zaya_model
|
from markovian_rsa_mlx.loader import load_zaya_model
|
||||||
|
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||||
|
|
||||||
__all__ = ["__version__", "RSAConfig", "load_zaya_model"]
|
__all__ = ["__version__", "RSAConfig", "load_zaya_model", "MarkovianRSAOrchestrator"]
|
||||||
|
|||||||
235
src/markovian_rsa_mlx/orchestrator.py
Normal file
235
src/markovian_rsa_mlx/orchestrator.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""MarkovianRSAOrchestrator — drives N parallel traces + aggregation rounds."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import datetime as _dt
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.audit import (
|
||||||
|
AuditWriter, RunStartEvent, GenerationStartEvent,
|
||||||
|
TraceCompleteEvent, TailExtractedEvent, AggregationPromptEvent,
|
||||||
|
RoundCompleteEvent, FinalEvent, RunEndEvent,
|
||||||
|
)
|
||||||
|
from markovian_rsa_mlx.batching import GenerationRequest, run_batch, GenerationResult
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
from markovian_rsa_mlx.prompts import (
|
||||||
|
build_round_0_messages,
|
||||||
|
build_aggregation_messages,
|
||||||
|
)
|
||||||
|
from markovian_rsa_mlx.results import RSAResult, RSARound, RSAStats, TraceRecord
|
||||||
|
|
||||||
|
|
||||||
|
def _trace_seed(base_seed: int | None, run_id: str, round_index: int, trace_index: int) -> int:
    """Derive a 32-bit sampling seed for one trace.

    When ``base_seed`` is given, the seed depends only on
    ``(base_seed, round_index, trace_index)`` so that repeating a run with the
    same configured seed reproduces the same per-trace seeds. When
    ``base_seed`` is ``None``, the seed is derived from ``run_id`` instead, so
    it is stable within a run but differs between runs.

    Args:
        base_seed: User-configured seed, or ``None`` for per-run seeding.
        run_id: Identifier of the current run; only used when ``base_seed``
            is ``None``.
        round_index: Index of the round the trace belongs to.
        trace_index: Index of the trace within the round.

    Returns:
        An integer in ``[0, 2**32)`` taken from the first 8 hex digits of a
        SHA-256 digest of the key.
    """
    if base_seed is None:
        # No configured seed: tie the seed to this run's id.
        key = f"{base_seed}|{run_id}|{round_index}|{trace_index}"
    else:
        # Configured seed: leave the (random) run id out of the key, otherwise
        # the result would change on every run and the seed would be useless
        # for reproducibility.
        key = f"{base_seed}|{round_index}|{trace_index}"
    digest = hashlib.sha256(key.encode()).hexdigest()
    return int(digest[:8], 16)
|
||||||
|
|
||||||
|
|
||||||
|
def _now_iso() -> str:
    """Return the current UTC time as ISO-8601 text with a ``Z`` suffix."""
    utc_now = _dt.datetime.now(tz=_dt.timezone.utc)
    stamp = utc_now.isoformat()
    # isoformat() renders the UTC offset as "+00:00"; swap it for "Z".
    return stamp.replace("+00:00", "Z")
|
||||||
|
|
||||||
|
|
||||||
|
class MarkovianRSAOrchestrator:
    """Drives Markovian RSA rounds over a loaded mlx-lm model + tokenizer.

    Each :meth:`solve` call runs ``cfg.rounds`` rounds of ``cfg.parallel``
    traces. Round 0 prompts every trace with the original prompt. Every later
    round prompts each trace with an aggregation message built from the
    decoded tails of a random subsample of the previous round's traces.
    Each step is reported to an ``AuditWriter``.
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        *,
        model_id: str = "kyr0/zaya1-base-8b-MLX",
        quantization: str = "q4_g64",
        default_config: RSAConfig | None = None,
        single_generate=None,
        batch_generate=None,
    ) -> None:
        """Store the model, tokenizer and run defaults.

        Args:
            model: Loaded model object; passed through to ``run_batch``.
            tokenizer: Tokenizer providing ``apply_chat_template``, ``encode``
                and ``decode``.
            model_id: Model identifier recorded in audit events and results.
            quantization: Quantization label recorded in audit events and
                results.
            default_config: Config used when ``solve`` receives none; a fresh
                ``RSAConfig()`` is created when omitted.
            single_generate: Optional generation callable forwarded to
                ``run_batch`` (the tests inject a fake here).
            batch_generate: Optional batch generation callable forwarded to
                ``run_batch``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.model_id = model_id
        self.quantization = quantization
        self.default_config = default_config or RSAConfig()
        self._single_generate = single_generate
        self._batch_generate = batch_generate

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "kyr0/zaya1-base-8b-MLX",
        *,
        quantization: str = "q4_g64",
        default_config: RSAConfig | None = None,
    ) -> "MarkovianRSAOrchestrator":
        """Load ``model_id`` with ``load_zaya_model`` and wrap it.

        NOTE(review): ``quantization`` is only recorded on the instance; it is
        not passed to the loader here.
        """
        # Imported lazily so the loader is only needed on this path.
        from markovian_rsa_mlx.loader import load_zaya_model

        model, tokenizer = load_zaya_model(model_id)
        return cls(
            model=model,
            tokenizer=tokenizer,
            model_id=model_id,
            quantization=quantization,
            default_config=default_config,
        )

    def solve(
        self,
        prompt: str,
        *,
        config: RSAConfig | None = None,
        return_audit: bool = False,
        audit_path: str | Path | None = None,
    ):
        """Run all rounds for ``prompt`` and return the final text.

        Args:
            prompt: The original user prompt.
            config: Run configuration; falls back to ``self.default_config``.
            return_audit: When true, return ``(final_text, RSAResult)``
                instead of just the text.
            audit_path: Destination handed to ``AuditWriter``; ``None`` is
                passed through unchanged.

        Returns:
            The final text, or a ``(final_text, RSAResult)`` tuple when
            ``return_audit`` is true.

        Raises:
            ValueError: If ``cfg.rounds`` or ``cfg.parallel`` is less than 1,
                since no trace would exist to take the final answer from.
        """
        cfg = config or self.default_config
        # Fail early with a clear message instead of an IndexError on
        # ``previous_traces[0]`` after the run has already started.
        if cfg.rounds < 1 or cfg.parallel < 1:
            raise ValueError(
                "RSAConfig.rounds and RSAConfig.parallel must both be >= 1 "
                f"(got rounds={cfg.rounds!r}, parallel={cfg.parallel!r})"
            )

        run_id = uuid.uuid4().hex[:12]
        # perf_counter is monotonic; time.time() can jump when the system
        # clock is adjusted, which would corrupt the elapsed measurement.
        t0 = time.perf_counter()
        with AuditWriter(audit_path) as aud:
            aud.write(RunStartEvent(
                run_id=run_id, model_id=self.model_id, quantization=self.quantization,
                config=cfg, prompt=prompt, created_at=_now_iso(),
            ))
            rounds_records: list[RSARound] = []
            previous_traces: list[TraceRecord] = []
            for round_idx in range(cfg.rounds):
                round_traces, round_elapsed = self._run_round(
                    run_id=run_id, round_idx=round_idx, original_prompt=prompt,
                    previous_traces=previous_traces, cfg=cfg, audit=aud,
                )
                # memory_estimate_bytes is reported as a constant 0 here.
                rounds_records.append(RSARound(
                    round=round_idx, traces=round_traces, elapsed_s=round_elapsed,
                    memory_estimate_bytes=0,
                ))
                aud.write(RoundCompleteEvent(
                    run_id=run_id, round=round_idx,
                    trace_ids=[t.trace_id for t in round_traces],
                    memory_estimate_bytes=0, elapsed_s=round_elapsed,
                ))
                previous_traces = round_traces

            # NOTE(review): the first trace of the last round is always taken
            # as the answer; cfg.answer_selection is recorded but not applied.
            final_trace = previous_traces[0]
            aud.write(FinalEvent(
                run_id=run_id, final_trace_id=final_trace.trace_id,
                final_text=final_trace.text,
                all_final_trace_ids=[t.trace_id for t in previous_traces],
                answer_selection=cfg.answer_selection,
            ))
            elapsed = time.perf_counter() - t0
            total_tokens = sum(
                t.generated_tokens for r in rounds_records for t in r.traces
            )
            # peak_memory_bytes is reported as a constant 0 here.
            aud.write(RunEndEvent(
                run_id=run_id, elapsed_s=elapsed,
                total_generated_tokens=total_tokens, peak_memory_bytes=0,
            ))

        result = RSAResult(
            run_id=run_id, prompt=prompt, final_text=final_trace.text,
            final_trace_id=final_trace.trace_id, model_id=self.model_id,
            quantization=self.quantization, config=cfg, rounds=rounds_records,
            stats=RSAStats(
                total_generated_tokens=total_tokens, elapsed_s=elapsed,
                peak_memory_bytes=0,
            ),
            audit_path=Path(audit_path) if audit_path is not None else None,
        )
        if return_audit:
            return result.final_text, result
        return result.final_text

    def _run_round(
        self, *, run_id: str, round_idx: int, original_prompt: str,
        previous_traces: list[TraceRecord], cfg: RSAConfig, audit: AuditWriter,
    ) -> tuple[list[TraceRecord], float]:
        """Generate one round of ``cfg.parallel`` traces.

        Args:
            run_id: Identifier of the current run.
            round_idx: Zero-based round index.
            original_prompt: The user prompt given to ``solve``.
            previous_traces: Traces of the preceding round (unused in round 0).
            cfg: Active run configuration.
            audit: Writer receiving the per-trace audit events.

        Returns:
            ``(records, elapsed_seconds)`` for the round.

        Raises:
            RuntimeError: If ``run_batch`` returns a different number of
                results than requests.
        """
        round_t0 = time.perf_counter()
        # Every round but the last is limited to chunk_tokens; the last one
        # uses the config's effective final budget.
        max_tokens = cfg.chunk_tokens if round_idx < cfg.rounds - 1 else cfg.effective_final_tokens()
        trace_ids = [self._trace_id(run_id, round_idx, i) for i in range(cfg.parallel)]

        if round_idx == 0:
            prompt_ids = self._render_chat(build_round_0_messages(original_prompt))
            # Give every request its own copy so a backend that mutates its
            # prompt cannot affect the sibling requests.
            prompts_token_ids = [list(prompt_ids) for _ in range(cfg.parallel)]
            parent_ids_per_trace: list[list[str]] = [[] for _ in range(cfg.parallel)]
        else:
            prompts_token_ids, parent_ids_per_trace = self._build_aggregation_prompts(
                run_id=run_id, round_idx=round_idx, original_prompt=original_prompt,
                previous_traces=previous_traces, cfg=cfg, audit=audit, trace_ids=trace_ids,
            )

        seeds = [_trace_seed(cfg.seed, run_id, round_idx, i) for i in range(cfg.parallel)]
        for i, tid in enumerate(trace_ids):
            audit.write(GenerationStartEvent(
                run_id=run_id, round=round_idx, trace_id=tid,
                seed=seeds[i], prompt_token_count=len(prompts_token_ids[i]),
                max_tokens=max_tokens, parent_trace_ids=parent_ids_per_trace[i],
            ))

        requests = [
            GenerationRequest(prompt_token_ids=ids, seed=seed, max_tokens=max_tokens)
            for ids, seed in zip(prompts_token_ids, seeds)
        ]
        results: list[GenerationResult] = run_batch(
            model=self.model, tokenizer=self.tokenizer,
            requests=requests, temperature=cfg.temperature, top_p=cfg.top_p, top_k=cfg.top_k,
            serial=cfg.serial, single_generate=self._single_generate, batch_generate=self._batch_generate,
        )
        # zip() below would silently drop traces on a count mismatch.
        if len(results) != len(requests):
            raise RuntimeError(
                f"run_batch returned {len(results)} results for {len(requests)} requests"
            )

        records: list[TraceRecord] = []
        for i, (tid, gen) in enumerate(zip(trace_ids, results)):
            audit.write(TraceCompleteEvent(
                run_id=run_id, round=round_idx, trace_id=tid,
                text=gen.text, token_ids=gen.token_ids,
                generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason,
                elapsed_s=gen.elapsed_s,
            ))
            records.append(TraceRecord(
                trace_id=tid, text=gen.text, token_ids=gen.token_ids,
                generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason,
                elapsed_s=gen.elapsed_s, seed=seeds[i],
                parent_trace_ids=parent_ids_per_trace[i],
            ))
        round_elapsed = time.perf_counter() - round_t0
        return records, round_elapsed

    def _build_aggregation_prompts(
        self, *, run_id: str, round_idx: int, original_prompt: str,
        previous_traces: list[TraceRecord], cfg: RSAConfig, audit: AuditWriter,
        trace_ids: list[str],
    ) -> tuple[list[list[int]], list[list[str]]]:
        """Build one aggregation prompt per trace from sampled parent tails.

        Returns:
            ``(prompts_token_ids, parent_ids_per_trace)``, both indexed by
            trace position within the round.
        """
        import random as _random

        # One RNG per round, seeded with trace index -1 so it never collides
        # with a per-trace generation seed.
        rng = _random.Random(_trace_seed(cfg.seed, run_id, round_idx, -1))
        # Loop-invariant: cannot sample more parents than exist.
        subsample_size = min(cfg.aggregation_subsample, len(previous_traces))

        prompts_token_ids: list[list[int]] = []
        parent_ids_per_trace: list[list[str]] = []
        for trace_idx in range(cfg.parallel):
            selected = rng.sample(previous_traces, subsample_size)
            tails: list[str] = []
            for parent in selected:
                # Slice the tail once and reuse it for both ids and text.
                tail_ids = self._extract_tail_token_ids(parent.token_ids, cfg.tail_tokens)
                tail_text = self.tokenizer.decode(tail_ids) if tail_ids else ""
                audit.write(TailExtractedEvent(
                    run_id=run_id, round=round_idx, trace_id=parent.trace_id,
                    tail_token_ids=tail_ids, tail_text=tail_text,
                    tail_tokens=len(tail_ids),
                ))
                tails.append(tail_text)
            messages = build_aggregation_messages(
                original_prompt=original_prompt, tails=tails,
                template=cfg.aggregation_template,
            )
            prompt_ids = self._render_chat(messages)
            # NOTE(review): only the first message's content is audited as the
            # prompt text; confirm that is where the aggregation text lives.
            audit.write(AggregationPromptEvent(
                run_id=run_id, round=round_idx, trace_id=trace_ids[trace_idx],
                selected_tail_trace_ids=[s.trace_id for s in selected],
                prompt_text=messages[0]["content"], prompt_token_ids=prompt_ids,
            ))
            prompts_token_ids.append(prompt_ids)
            parent_ids_per_trace.append([s.trace_id for s in selected])
        return prompts_token_ids, parent_ids_per_trace

    @staticmethod
    def _trace_id(run_id: str, round_idx: int, trace_idx: int) -> str:
        """Format a trace id as ``r{round}-t{trace}-{first 6 chars of run_id}``."""
        return f"r{round_idx}-t{trace_idx}-{run_id[:6]}"

    def _render_chat(self, messages: list[dict[str, str]]) -> list[int]:
        """Apply ZAYA chat template and return token ids.

        The template is called with ``add_generation_prompt=True`` and
        ``enable_thinking=True``. If it returns a string, that string is
        encoded with the tokenizer; otherwise the result is converted to a
        list as-is.
        """
        rendered = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, enable_thinking=True,
        )
        if isinstance(rendered, str):
            return self.tokenizer.encode(rendered)
        return list(rendered)

    @staticmethod
    def _extract_tail_token_ids(ids: list[int], tail_tokens: int) -> list[int]:
        """Return the last ``tail_tokens`` ids, or ``[]`` if none are wanted.

        A non-positive ``tail_tokens`` or an empty ``ids`` yields ``[]``; when
        ``tail_tokens`` exceeds ``len(ids)`` all ids are returned.
        """
        if tail_tokens <= 0 or not ids:
            return []
        return ids[-tail_tokens:]

    def _extract_tail_text(self, ids: list[int], tail_tokens: int) -> str:
        """Decode the last ``tail_tokens`` ids to text; ``""`` for an empty tail."""
        tail_ids = self._extract_tail_token_ids(ids, tail_tokens)
        if not tail_ids:
            return ""
        return self.tokenizer.decode(tail_ids)
|
||||||
47
src/markovian_rsa_mlx/results.py
Normal file
47
src/markovian_rsa_mlx/results.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Public result types returned by MarkovianRSAOrchestrator."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TraceRecord:
    """One generated trace produced within a round."""

    # Trace identifier; the orchestrator formats it as
    # "r{round}-t{index}-{first 6 chars of run_id}".
    trace_id: str
    # Text of the generation result.
    text: str
    # Token ids of the generation result; aggregation tails are sliced from
    # the end of this list.
    token_ids: list[int]
    # Generated-token count reported by the generation result.
    generated_tokens: int
    # Finish reason reported by the generation result.
    finish_reason: str
    # Generation time reported by the generation result, in seconds.
    elapsed_s: float
    # Seed that was sent with this trace's generation request.
    seed: int
    # Ids of the previous-round traces whose tails were used to build this
    # trace's prompt; empty for round 0.
    parent_trace_ids: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RSARound:
    """All traces produced in one round, plus round-level measurements."""

    # Zero-based round index.
    round: int
    # Traces generated in this round.
    traces: list[TraceRecord]
    # Time spent on the whole round, in seconds.
    elapsed_s: float
    # Memory estimate for the round in bytes (the orchestrator in this commit
    # always passes 0).
    memory_estimate_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RSAStats:
    """Aggregate statistics for a whole run."""

    # Sum of generated_tokens over every trace of every round.
    total_generated_tokens: int
    # Total run time in seconds.
    elapsed_s: float
    # Peak memory in bytes (the orchestrator in this commit always passes 0).
    peak_memory_bytes: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RSAResult:
    """Full outcome of one ``MarkovianRSAOrchestrator.solve`` call."""

    # Identifier of the run.
    run_id: str
    # The original prompt passed to solve().
    prompt: str
    # Text of the trace chosen as the final answer.
    final_text: str
    # trace_id of the trace chosen as the final answer.
    final_trace_id: str
    # Model identifier the orchestrator was configured with.
    model_id: str
    # Quantization label the orchestrator was configured with.
    quantization: str
    # Configuration the run used.
    config: RSAConfig
    # One entry per executed round, in order.
    rounds: list[RSARound]
    # Aggregate run statistics.
    stats: RSAStats
    # Audit path that was passed to solve(), as a Path; None if none was given.
    audit_path: Path | None
|
||||||
60
tests/test_orchestrator_t1.py
Normal file
60
tests/test_orchestrator_t1.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_tokenizer(eos_id: int = 999):
    """Build a MagicMock tokenizer with simple character-code behavior.

    * ``encode`` maps the first 32 characters to their code points, or
      returns ``[1]`` for empty input.
    * ``decode`` keeps only ids in ``32..122`` and joins them as characters.
    * ``apply_chat_template`` returns the hex of the space-joined message
      contents (a string).
    """
    def _encode(s):
        return [ord(ch) for ch in s][:32] or [1]

    def _decode(ids):
        # Ids above 122 are filtered out, so no clamping is needed.
        return "".join(chr(i) for i in ids if 32 <= i <= 122)

    def _apply_chat_template(messages, **kw):
        joined = " ".join(m["content"] for m in messages)
        return joined.encode().hex()

    tok = MagicMock()
    tok.encode.side_effect = _encode
    tok.decode.side_effect = _decode
    tok.eos_token_id = eos_id
    tok.all_special_ids = [eos_id]
    tok.apply_chat_template.side_effect = _apply_chat_template
    return tok
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_single_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
    """Stand-in generator whose output depends only on ``seed``.

    The text is ``"trace-{seed}-final-answer"`` and the token ids are its
    character code points.
    """
    fake_text = f"trace-{seed}-final-answer"
    code_points = list(map(ord, fake_text))
    return GenerationResult(
        token_ids=code_points,
        text=fake_text,
        generated_tokens=len(code_points),
        finish_reason="eos",
        elapsed_s=0.01,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_t1_single_round_produces_final_text(tmp_path):
    """A single round with two traces returns text and writes an audit file."""
    run_config = RSAConfig(
        rounds=1, parallel=2, aggregation_subsample=2,
        chunk_tokens=64, tail_tokens=8, serial=True, seed=123,
    )
    orchestrator = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="test-model",
        quantization="bf16",
        single_generate=_fake_single_gen,
        batch_generate=None,
    )
    audit_file = tmp_path / "audit.jsonl"

    final_text, result = orchestrator.solve(
        "What is 2+2?", config=run_config, return_audit=True, audit_path=audit_file,
    )

    assert isinstance(final_text, str)
    assert final_text == result.final_text
    assert result.config.rounds == 1
    assert len(result.rounds) == 1
    assert len(result.rounds[0].traces) == 2
    assert audit_file.exists()

    audit_lines = audit_file.read_text().strip().split("\n")
    # at minimum: run_start, 2 trace_complete, final, run_end
    assert len(audit_lines) >= 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_t1_returns_string_when_return_audit_false(tmp_path):
    """Without ``return_audit`` the orchestrator returns only the text."""
    run_config = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, serial=True)
    orchestrator = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="m",
        quantization="bf16",
        single_generate=_fake_single_gen,
    )
    answer = orchestrator.solve("X", config=run_config)
    assert isinstance(answer, str)
|
||||||
Reference in New Issue
Block a user