feat(orchestrator): add T=1 path with audit JSONL + tail extraction

This commit is contained in:
transcrilive
2026-05-10 03:02:31 +02:00
parent c5584b6396
commit 86ccbe53e1
4 changed files with 344 additions and 1 deletions

View File

@@ -3,5 +3,6 @@ __version__ = "0.1.0"
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.loader import load_zaya_model
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
__all__ = ["__version__", "RSAConfig", "load_zaya_model"]
__all__ = ["__version__", "RSAConfig", "load_zaya_model", "MarkovianRSAOrchestrator"]

View File

@@ -0,0 +1,235 @@
"""MarkovianRSAOrchestrator — drives N parallel traces + aggregation rounds."""
from __future__ import annotations
import datetime as _dt
import hashlib
import time
import uuid
from pathlib import Path
from typing import Any
from markovian_rsa_mlx.audit import (
AuditWriter, RunStartEvent, GenerationStartEvent,
TraceCompleteEvent, TailExtractedEvent, AggregationPromptEvent,
RoundCompleteEvent, FinalEvent, RunEndEvent,
)
from markovian_rsa_mlx.batching import GenerationRequest, run_batch, GenerationResult
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.prompts import (
build_round_0_messages,
build_aggregation_messages,
)
from markovian_rsa_mlx.results import RSAResult, RSARound, RSAStats, TraceRecord
def _trace_seed(base_seed: int | None, run_id: str, round_index: int, trace_index: int) -> int:
"""Deterministic seed if base_seed set, else stable from run_id."""
key = f"{base_seed}|{run_id}|{round_index}|{trace_index}"
h = hashlib.sha256(key.encode()).hexdigest()
return int(h[:8], 16)
def _now_iso() -> str:
return _dt.datetime.now(tz=_dt.timezone.utc).isoformat().replace("+00:00", "Z")
class MarkovianRSAOrchestrator:
    """Drives Markovian RSA rounds over a loaded mlx-lm model + tokenizer.

    Every round generates ``cfg.parallel`` traces. Round 0 prompts each trace
    with the original prompt; each later round prompts every trace with the
    decoded tails of a random subsample of the previous round's traces. All
    steps are reported to an ``AuditWriter``.
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        *,
        model_id: str = "kyr0/zaya1-base-8b-MLX",
        quantization: str = "q4_g64",
        default_config: RSAConfig | None = None,
        single_generate=None,
        batch_generate=None,
    ) -> None:
        """Store the model, tokenizer and generation hooks.

        Args:
            model: Loaded model object, passed through to ``run_batch``.
            tokenizer: Tokenizer providing ``apply_chat_template``, ``encode``
                and ``decode``.
            model_id: Identifier recorded in audit events and results.
            quantization: Quantization label recorded in audit events and
                results.
            default_config: Configuration used by ``solve`` when none is
                given; a default ``RSAConfig()`` when omitted.
            single_generate: Optional hook forwarded unchanged to
                ``run_batch``.
            batch_generate: Optional hook forwarded unchanged to
                ``run_batch``.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.model_id = model_id
        self.quantization = quantization
        self.default_config = RSAConfig() if default_config is None else default_config
        self._single_generate = single_generate
        self._batch_generate = batch_generate

    @classmethod
    def from_pretrained(
        cls,
        model_id: str = "kyr0/zaya1-base-8b-MLX",
        *,
        quantization: str = "q4_g64",
        default_config: RSAConfig | None = None,
    ) -> "MarkovianRSAOrchestrator":
        """Load model and tokenizer via ``load_zaya_model`` and wrap them.

        ``quantization`` is only recorded on the orchestrator; it is not
        passed to the loader.
        """
        # Imported lazily so constructing an orchestrator from an already
        # loaded model does not need the loader module.
        from markovian_rsa_mlx.loader import load_zaya_model

        model, tokenizer = load_zaya_model(model_id)
        return cls(
            model=model,
            tokenizer=tokenizer,
            model_id=model_id,
            quantization=quantization,
            default_config=default_config,
        )

    def solve(
        self,
        prompt: str,
        *,
        config: RSAConfig | None = None,
        return_audit: bool = False,
        audit_path: str | Path | None = None,
    ) -> str | tuple[str, RSAResult]:
        """Run all configured rounds for ``prompt`` and return the final text.

        Args:
            prompt: The original user prompt.
            config: Per-call configuration; falls back to the instance's
                ``default_config``.
            return_audit: When true, also return the full ``RSAResult``.
            audit_path: Destination handed to ``AuditWriter``; may be ``None``.

        Returns:
            The final text, or ``(final_text, RSAResult)`` when
            ``return_audit`` is true.

        Raises:
            ValueError: If the last round produced no traces (for example
                ``rounds`` or ``parallel`` below 1).
        """
        cfg = self.default_config if config is None else config
        run_id = uuid.uuid4().hex[:12]
        # perf_counter is monotonic, so durations cannot go negative when the
        # wall clock is adjusted mid-run.
        run_start = time.perf_counter()

        with AuditWriter(audit_path) as aud:
            aud.write(RunStartEvent(
                run_id=run_id, model_id=self.model_id, quantization=self.quantization,
                config=cfg, prompt=prompt, created_at=_now_iso(),
            ))

            rounds_records: list[RSARound] = []
            previous_traces: list[TraceRecord] = []
            for round_idx in range(cfg.rounds):
                round_traces, round_elapsed = self._run_round(
                    run_id=run_id, round_idx=round_idx, original_prompt=prompt,
                    previous_traces=previous_traces, cfg=cfg, audit=aud,
                )
                # Memory figures are reported as 0: nothing on this path
                # measures them.
                rounds_records.append(RSARound(
                    round=round_idx, traces=round_traces, elapsed_s=round_elapsed,
                    memory_estimate_bytes=0,
                ))
                aud.write(RoundCompleteEvent(
                    run_id=run_id, round=round_idx,
                    trace_ids=[t.trace_id for t in round_traces],
                    memory_estimate_bytes=0, elapsed_s=round_elapsed,
                ))
                previous_traces = round_traces

            # NOTE(review): cfg.answer_selection is recorded below but the
            # answer is always the first trace of the last round — confirm
            # that this is the intended behaviour for this path.
            final_trace = self._pick_final_trace(previous_traces)
            aud.write(FinalEvent(
                run_id=run_id, final_trace_id=final_trace.trace_id,
                final_text=final_trace.text,
                all_final_trace_ids=[t.trace_id for t in previous_traces],
                answer_selection=cfg.answer_selection,
            ))

            elapsed = time.perf_counter() - run_start
            total_tokens = sum(
                t.generated_tokens for r in rounds_records for t in r.traces
            )
            aud.write(RunEndEvent(
                run_id=run_id, elapsed_s=elapsed,
                total_generated_tokens=total_tokens, peak_memory_bytes=0,
            ))

        result = RSAResult(
            run_id=run_id, prompt=prompt, final_text=final_trace.text,
            final_trace_id=final_trace.trace_id, model_id=self.model_id,
            quantization=self.quantization, config=cfg, rounds=rounds_records,
            stats=RSAStats(
                total_generated_tokens=total_tokens, elapsed_s=elapsed,
                peak_memory_bytes=0,
            ),
            audit_path=Path(audit_path) if audit_path is not None else None,
        )
        if return_audit:
            return result.final_text, result
        return result.final_text

    @staticmethod
    def _pick_final_trace(traces: list[TraceRecord]) -> TraceRecord:
        """Return the trace used as the final answer (the first one).

        Raises:
            ValueError: If ``traces`` is empty.
        """
        if not traces:
            raise ValueError(
                "no traces were generated; config.rounds and config.parallel "
                "must both be at least 1"
            )
        return traces[0]

    def _run_round(
        self, *, run_id: str, round_idx: int, original_prompt: str,
        previous_traces: list[TraceRecord], cfg: RSAConfig, audit: AuditWriter,
    ) -> tuple[list[TraceRecord], float]:
        """Generate one round of traces and audit every step.

        Returns:
            The round's trace records and the elapsed time in seconds.
        """
        round_start = time.perf_counter()

        # Only the last round uses the (possibly larger) final token budget.
        is_last_round = round_idx >= cfg.rounds - 1
        max_tokens = cfg.effective_final_tokens() if is_last_round else cfg.chunk_tokens

        trace_ids = [f"r{round_idx}-t{i}-{run_id[:6]}" for i in range(cfg.parallel)]
        seeds = [_trace_seed(cfg.seed, run_id, round_idx, i) for i in range(cfg.parallel)]

        if round_idx == 0:
            # Every round-0 trace starts from the same rendered prompt and
            # has no parents.
            prompt_ids = self._render_chat(build_round_0_messages(original_prompt))
            prompts_token_ids = [prompt_ids for _ in trace_ids]
            parent_ids_per_trace: list[list[str]] = [[] for _ in trace_ids]
        else:
            prompts_token_ids, parent_ids_per_trace = self._build_aggregation_prompts(
                run_id=run_id, round_idx=round_idx, original_prompt=original_prompt,
                previous_traces=previous_traces, trace_ids=trace_ids, cfg=cfg, audit=audit,
            )

        for tid, seed, prompt_ids, parents in zip(
            trace_ids, seeds, prompts_token_ids, parent_ids_per_trace
        ):
            audit.write(GenerationStartEvent(
                run_id=run_id, round=round_idx, trace_id=tid,
                seed=seed, prompt_token_count=len(prompt_ids),
                max_tokens=max_tokens, parent_trace_ids=parents,
            ))

        requests = [
            GenerationRequest(prompt_token_ids=prompt_ids, seed=seed, max_tokens=max_tokens)
            for prompt_ids, seed in zip(prompts_token_ids, seeds)
        ]
        results: list[GenerationResult] = run_batch(
            model=self.model, tokenizer=self.tokenizer,
            requests=requests, temperature=cfg.temperature, top_p=cfg.top_p, top_k=cfg.top_k,
            serial=cfg.serial, single_generate=self._single_generate,
            batch_generate=self._batch_generate,
        )

        records: list[TraceRecord] = []
        for tid, seed, parents, gen in zip(trace_ids, seeds, parent_ids_per_trace, results):
            audit.write(TraceCompleteEvent(
                run_id=run_id, round=round_idx, trace_id=tid,
                text=gen.text, token_ids=gen.token_ids,
                generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason,
                elapsed_s=gen.elapsed_s,
            ))
            records.append(TraceRecord(
                trace_id=tid, text=gen.text, token_ids=gen.token_ids,
                generated_tokens=gen.generated_tokens, finish_reason=gen.finish_reason,
                elapsed_s=gen.elapsed_s, seed=seed,
                parent_trace_ids=parents,
            ))

        return records, time.perf_counter() - round_start

    def _build_aggregation_prompts(
        self, *, run_id: str, round_idx: int, original_prompt: str,
        previous_traces: list[TraceRecord], trace_ids: list[str],
        cfg: RSAConfig, audit: AuditWriter,
    ) -> tuple[list[list[int]], list[list[str]]]:
        """Build one aggregation prompt per child trace from parent tails.

        Returns:
            Parallel lists: prompt token ids and parent trace ids, one entry
            per id in ``trace_ids``.
        """
        # Function-scope import: this module does not import random at top level.
        import random

        # One RNG per round (trace index -1) drives all subsampling.
        rng = random.Random(_trace_seed(cfg.seed, run_id, round_idx, -1))
        sample_size = min(cfg.aggregation_subsample, len(previous_traces))

        prompts_token_ids: list[list[int]] = []
        parent_ids_per_trace: list[list[str]] = []
        for child_trace_id in trace_ids:
            selected = rng.sample(previous_traces, sample_size)
            # Slice each tail once and decode from the slice.
            tail_ids_list = [
                self._extract_tail_token_ids(parent.token_ids, cfg.tail_tokens)
                for parent in selected
            ]
            tails = [self._decode_tail(tail_ids) for tail_ids in tail_ids_list]
            for parent, tail_ids, tail_text in zip(selected, tail_ids_list, tails):
                audit.write(TailExtractedEvent(
                    run_id=run_id, round=round_idx, trace_id=parent.trace_id,
                    tail_token_ids=tail_ids, tail_text=tail_text,
                    tail_tokens=len(tail_ids),
                ))

            messages = build_aggregation_messages(
                original_prompt=original_prompt, tails=tails,
                template=cfg.aggregation_template,
            )
            prompt_ids = self._render_chat(messages)
            parent_ids = [parent.trace_id for parent in selected]
            audit.write(AggregationPromptEvent(
                run_id=run_id, round=round_idx, trace_id=child_trace_id,
                selected_tail_trace_ids=parent_ids,
                prompt_text=messages[0]["content"], prompt_token_ids=prompt_ids,
            ))
            prompts_token_ids.append(prompt_ids)
            # Separate list so the record does not alias the audit event's.
            parent_ids_per_trace.append(list(parent_ids))

        return prompts_token_ids, parent_ids_per_trace

    def _render_chat(self, messages: list[dict[str, str]]) -> list[int]:
        """Apply the tokenizer's chat template and return token ids.

        The template may return either a rendered string, which is then
        encoded, or an iterable of token ids, which is copied into a list.
        """
        rendered = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, enable_thinking=True,
        )
        if isinstance(rendered, str):
            # NOTE(review): if the template already emits special tokens,
            # encode() with default settings may add them a second time —
            # confirm against the tokenizer in use.
            return self.tokenizer.encode(rendered)
        return list(rendered)

    @staticmethod
    def _extract_tail_token_ids(ids: list[int], tail_tokens: int) -> list[int]:
        """Return the last ``tail_tokens`` ids, or ``[]`` if none are requested."""
        if tail_tokens <= 0 or not ids:
            return []
        return ids[-tail_tokens:]

    def _decode_tail(self, tail_ids: list[int]) -> str:
        """Decode tail ids to text; an empty tail decodes to ``""``."""
        if not tail_ids:
            return ""
        return self.tokenizer.decode(tail_ids)

    def _extract_tail_text(self, ids: list[int], tail_tokens: int) -> str:
        """Return the decoded text of the last ``tail_tokens`` ids."""
        return self._decode_tail(self._extract_tail_token_ids(ids, tail_tokens))

View File

@@ -0,0 +1,47 @@
"""Public result types returned by MarkovianRSAOrchestrator."""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from markovian_rsa_mlx.config import RSAConfig
@dataclass
class TraceRecord:
    """One generated trace within a round."""

    # Identifier of this trace.
    trace_id: str
    # Generated text.
    text: str
    # Token ids reported by the generator for this trace.
    token_ids: list[int]
    # Number of generated tokens reported by the generator.
    generated_tokens: int
    # Why generation stopped, as reported by the generator.
    finish_reason: str
    # Generation time in seconds.
    elapsed_s: float
    # Sampling seed used for this trace.
    seed: int
    # Ids of the earlier traces this one was built from; empty by default.
    parent_trace_ids: list[str] = field(default_factory=list)
@dataclass
class RSARound:
    """All traces produced by a single round."""

    # Zero-based round index.
    round: int
    # Traces generated in this round.
    traces: list[TraceRecord]
    # Time spent on the round, in seconds.
    elapsed_s: float
    # Memory estimate for the round, in bytes.
    memory_estimate_bytes: int
@dataclass
class RSAStats:
    """Aggregate statistics for one run."""

    # Sum of generated tokens over the run's traces.
    total_generated_tokens: int
    # Total run time in seconds.
    elapsed_s: float
    # Peak memory in bytes.
    peak_memory_bytes: int
@dataclass
class RSAResult:
    """Complete outcome of one orchestrator run."""

    # Identifier of the run.
    run_id: str
    # The original prompt.
    prompt: str
    # Text of the trace chosen as the answer.
    final_text: str
    # Id of the trace chosen as the answer.
    final_trace_id: str
    # Model identifier recorded for the run.
    model_id: str
    # Quantization label recorded for the run.
    quantization: str
    # Configuration the run was executed with.
    config: RSAConfig
    # Per-round records, in execution order.
    rounds: list[RSARound]
    # Aggregate statistics for the run.
    stats: RSAStats
    # Location of the audit file, or None when no path was given.
    audit_path: Path | None

View File

@@ -0,0 +1,60 @@
from unittest.mock import MagicMock
from markovian_rsa_mlx.batching import GenerationRequest, GenerationResult
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
def _fake_tokenizer(eos_id: int = 999):
tok = MagicMock()
tok.encode.side_effect = lambda s: [ord(c) for c in s][:32] or [1]
tok.decode.side_effect = lambda ids: "".join(chr(min(i, 122)) for i in ids if 32 <= i <= 122)
tok.eos_token_id = eos_id
tok.all_special_ids = [eos_id]
tok.apply_chat_template.side_effect = lambda messages, **kw: \
" ".join(m["content"] for m in messages).encode().hex()
return tok
def _fake_single_gen(model, tokenizer, prompt_token_ids, *, max_tokens, seed, temperature, top_p, top_k):
    """Stand-in generator: returns a fixed-shape result whose text embeds the seed."""
    generated_text = f"trace-{seed}-final-answer"
    # One token id per character of the generated text.
    generated_ids = list(map(ord, generated_text))
    return GenerationResult(
        token_ids=generated_ids,
        text=generated_text,
        generated_tokens=len(generated_ids),
        finish_reason="eos",
        elapsed_s=0.01,
    )
def test_t1_single_round_produces_final_text(tmp_path):
    """A single round with two traces returns text, a result and an audit file."""
    audit_file = tmp_path / "audit.jsonl"
    config = RSAConfig(
        rounds=1,
        parallel=2,
        aggregation_subsample=2,
        chunk_tokens=64,
        tail_tokens=8,
        serial=True,
        seed=123,
    )
    orchestrator = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="test-model",
        quantization="bf16",
        single_generate=_fake_single_gen,
        batch_generate=None,
    )

    final_text, result = orchestrator.solve(
        "What is 2+2?", config=config, return_audit=True, audit_path=audit_file,
    )

    # Returned text and result object agree.
    assert isinstance(final_text, str)
    assert final_text == result.final_text

    # Exactly one round with two traces was recorded.
    assert result.config.rounds == 1
    assert len(result.rounds) == 1
    assert len(result.rounds[0].traces) == 2

    # The audit file was written with at least:
    # run_start, 2 trace_complete, final, run_end.
    assert audit_file.exists()
    audit_lines = audit_file.read_text().strip().split("\n")
    assert len(audit_lines) >= 5
def test_t1_returns_string_when_return_audit_false(tmp_path):
    """Without return_audit, solve() yields just the final text as a str."""
    config = RSAConfig(rounds=1, parallel=2, aggregation_subsample=2, serial=True)
    orchestrator = MarkovianRSAOrchestrator(
        model=MagicMock(),
        tokenizer=_fake_tokenizer(),
        model_id="m",
        quantization="bf16",
        single_generate=_fake_single_gen,
    )

    # No audit_path is given, so the default (None) is used.
    answer = orchestrator.solve("X", config=config)

    assert isinstance(answer, str)