From b688c4ef77a3997aac237e49c57d3c10a08e309c Mon Sep 17 00:00:00 2001
From: transcrilive
Date: Sun, 10 May 2026 02:43:10 +0200
Subject: [PATCH] feat(audit): add JSONL event types + streaming AuditWriter

Add src/markovian_rsa_mlx/audit.py:

* One dataclass per audit event (run_start, generation_start,
  trace_complete, tail_extracted, aggregation_prompt, round_complete,
  final, run_end, error). Each subclass overrides the non-init
  EVENT_NAME field; _BaseEvent.to_dict() emits it under the "event" key
  and expands an RSAConfig value with dataclasses.asdict().
* AuditWriter, which writes one JSON object per line
  (ensure_ascii=False, UTF-8), flushes after every event, creates the
  parent directory on open, supports use as a context manager, and is a
  no-op when constructed with path=None.

Add tests/test_audit.py covering JSONL output, the "event" field,
append mode (mode="a") and the path=None no-op.
---
Review notes (text in this section is not part of the commit message):

* Line structure and Python indentation were reconstructed from a
  whitespace-collapsed copy. The hunk sizes match the headers (155 and
  80 added lines); the blob ids on the "index" lines were not
  re-verified and intra-line spacing (continuation-line alignment,
  spacing before the trailing comment) is a best-effort reconstruction.
* NOTE(review): the From: header carries no e-mail address; confirm
  that `git am` accepts it or restore the address before applying.
* tests/test_audit.py imports tempfile, AggregationPromptEvent,
  RoundCompleteEvent and FinalEvent without using them.
* The tests call read_text() without encoding= although AuditWriter
  writes UTF-8 with ensure_ascii=False; pass encoding="utf-8" so the
  read does not depend on the platform default.
* AuditWriter.close() resets the handle to None, so write() after
  close() silently drops the event exactly as in the path=None mode.
  Confirm that this is intended for an audit log.
* GenerationStartEvent, ErrorEvent and the __enter__/__exit__ methods
  are not exercised by the tests in this patch.

 src/markovian_rsa_mlx/audit.py | 155 +++++++++++++++++++++++++++++++++
 tests/test_audit.py            |  80 +++++++++++++++++
 2 files changed, 235 insertions(+)
 create mode 100644 src/markovian_rsa_mlx/audit.py
 create mode 100644 tests/test_audit.py

diff --git a/src/markovian_rsa_mlx/audit.py b/src/markovian_rsa_mlx/audit.py
new file mode 100644
index 0000000..0476401
--- /dev/null
+++ b/src/markovian_rsa_mlx/audit.py
@@ -0,0 +1,155 @@
+"""JSONL audit events + writer.
+
+All RSA orchestrator state changes emit one event per JSON line. Schema is
+documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
+"""
+from __future__ import annotations
+import dataclasses
+import io
+import json
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Literal
+
+from markovian_rsa_mlx.config import RSAConfig
+
+
+@dataclass
+class _BaseEvent:
+    """Base for all audit events. Subclasses set EVENT_NAME."""
+    EVENT_NAME: str = field(init=False, repr=False, default="")
+
+    def to_dict(self) -> dict[str, Any]:
+        d: dict[str, Any] = {"event": self.EVENT_NAME}
+        for f in dataclasses.fields(self):
+            if f.name == "EVENT_NAME":
+                continue
+            v = getattr(self, f.name)
+            if isinstance(v, RSAConfig):
+                v = dataclasses.asdict(v)
+            d[f.name] = v
+        return d
+
+
+@dataclass
+class RunStartEvent(_BaseEvent):
+    run_id: str
+    model_id: str
+    quantization: str
+    config: RSAConfig
+    prompt: str
+    created_at: str
+    EVENT_NAME: str = field(init=False, default="run_start")
+
+
+@dataclass
+class GenerationStartEvent(_BaseEvent):
+    run_id: str
+    round: int
+    trace_id: str
+    seed: int
+    prompt_token_count: int
+    max_tokens: int
+    parent_trace_ids: list[str]
+    EVENT_NAME: str = field(init=False, default="generation_start")
+
+
+@dataclass
+class TraceCompleteEvent(_BaseEvent):
+    run_id: str
+    round: int
+    trace_id: str
+    text: str
+    token_ids: list[int]
+    generated_tokens: int
+    finish_reason: Literal["eos", "max_tokens", "error"]
+    elapsed_s: float
+    EVENT_NAME: str = field(init=False, default="trace_complete")
+
+
+@dataclass
+class TailExtractedEvent(_BaseEvent):
+    run_id: str
+    round: int
+    trace_id: str
+    tail_token_ids: list[int]
+    tail_text: str
+    tail_tokens: int
+    EVENT_NAME: str = field(init=False, default="tail_extracted")
+
+
+@dataclass
+class AggregationPromptEvent(_BaseEvent):
+    run_id: str
+    round: int
+    trace_id: str
+    selected_tail_trace_ids: list[str]
+    prompt_text: str
+    prompt_token_ids: list[int]
+    EVENT_NAME: str = field(init=False, default="aggregation_prompt")
+
+
+@dataclass
+class RoundCompleteEvent(_BaseEvent):
+    run_id: str
+    round: int
+    trace_ids: list[str]
+    memory_estimate_bytes: int
+    elapsed_s: float
+    EVENT_NAME: str = field(init=False, default="round_complete")
+
+
+@dataclass
+class FinalEvent(_BaseEvent):
+    run_id: str
+    final_trace_id: str
+    final_text: str
+    all_final_trace_ids: list[str]
+    answer_selection: str
+    EVENT_NAME: str = field(init=False, default="final")
+
+
+@dataclass
+class RunEndEvent(_BaseEvent):
+    run_id: str
+    elapsed_s: float
+    total_generated_tokens: int
+    peak_memory_bytes: int
+    EVENT_NAME: str = field(init=False, default="run_end")
+
+
+@dataclass
+class ErrorEvent(_BaseEvent):
+    run_id: str
+    stage: str
+    message: str
+    recoverable: bool
+    EVENT_NAME: str = field(init=False, default="error")
+
+
+class AuditWriter:
+    """Streaming JSONL writer. No-op when path is None."""
+    def __init__(self, path: Path | str | None, mode: str = "w") -> None:
+        self._path = Path(path) if path is not None else None
+        self._fp: io.TextIOBase | None = None
+        if self._path is not None:
+            self._path.parent.mkdir(parents=True, exist_ok=True)
+            self._fp = self._path.open(mode, encoding="utf-8")
+
+    def write(self, event: _BaseEvent) -> None:
+        if self._fp is None:
+            return
+        self._fp.write(json.dumps(event.to_dict(), ensure_ascii=False))
+        self._fp.write("\n")
+        self._fp.flush()
+
+    def close(self) -> None:
+        if self._fp is not None:
+            self._fp.close()
+            self._fp = None
+
+    def __enter__(self) -> "AuditWriter":
+        return self
+
+    def __exit__(self, *args) -> None:
+        self.close()
diff --git a/tests/test_audit.py b/tests/test_audit.py
new file mode 100644
index 0000000..e93372c
--- /dev/null
+++ b/tests/test_audit.py
@@ -0,0 +1,80 @@
+import json
+import tempfile
+from pathlib import Path
+
+from markovian_rsa_mlx.audit import (
+    AuditWriter,
+    RunStartEvent,
+    TraceCompleteEvent,
+    TailExtractedEvent,
+    AggregationPromptEvent,
+    RoundCompleteEvent,
+    FinalEvent,
+    RunEndEvent,
+)
+from markovian_rsa_mlx.config import RSAConfig
+
+
+def test_writer_emits_jsonl(tmp_path: Path):
+    audit_path = tmp_path / "audit.jsonl"
+    writer = AuditWriter(audit_path)
+    writer.write(RunStartEvent(
+        run_id="run-1",
+        model_id="kyr0/zaya1-base-8b-MLX",
+        quantization="q4_g64",
+        config=RSAConfig(),
+        prompt="2+2",
+        created_at="2026-05-10T00:00:00Z",
+    ))
+    writer.write(TraceCompleteEvent(
+        run_id="run-1",
+        round=0,
+        trace_id="t0",
+        text="The answer is 4.",
+        token_ids=[1, 2, 3],
+        generated_tokens=3,
+        finish_reason="eos",
+        elapsed_s=0.5,
+    ))
+    writer.close()
+
+    lines = audit_path.read_text().strip().split("\n")
+    assert len(lines) == 2
+    e0 = json.loads(lines[0])
+    assert e0["event"] == "run_start"
+    assert e0["model_id"] == "kyr0/zaya1-base-8b-MLX"
+    assert e0["config"]["rounds"] == 2
+    e1 = json.loads(lines[1])
+    assert e1["event"] == "trace_complete"
+    assert e1["text"] == "The answer is 4."
+    assert e1["finish_reason"] == "eos"
+
+
+def test_event_serialization_includes_event_field():
+    ev = TailExtractedEvent(
+        run_id="r", round=0, trace_id="t",
+        tail_token_ids=[1, 2, 3], tail_text="hi", tail_tokens=3,
+    )
+    payload = ev.to_dict()
+    assert payload["event"] == "tail_extracted"
+
+
+def test_writer_appends_subsequent_events(tmp_path: Path):
+    audit_path = tmp_path / "audit.jsonl"
+    w = AuditWriter(audit_path)
+    w.write(RunEndEvent(run_id="r", elapsed_s=1.0,
+                        total_generated_tokens=10, peak_memory_bytes=1000))
+    w.close()
+    w2 = AuditWriter(audit_path, mode="a")
+    w2.write(RunEndEvent(run_id="r2", elapsed_s=2.0,
+                         total_generated_tokens=20, peak_memory_bytes=2000))
+    w2.close()
+    lines = audit_path.read_text().strip().split("\n")
+    assert len(lines) == 2
+
+
+def test_writer_no_op_when_path_none():
+    w = AuditWriter(None)
+    w.write(RunEndEvent(run_id="x", elapsed_s=1.0,
+                        total_generated_tokens=0, peak_memory_bytes=0))
+    w.close()  # no exception