"""JSONL audit events + writer. All RSA orchestrator state changes emit one event per JSON line. Schema is documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md. """ from __future__ import annotations import dataclasses import io import json import weakref from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal from markovian_rsa_mlx.config import RSAConfig @dataclass class _BaseEvent: """Base for all audit events. Subclasses set EVENT_NAME.""" EVENT_NAME: str = field(init=False, repr=False, default="") def to_dict(self) -> dict[str, Any]: d: dict[str, Any] = {"event": self.EVENT_NAME} for f in dataclasses.fields(self): if f.name == "EVENT_NAME": continue v = getattr(self, f.name) if isinstance(v, RSAConfig): v = dataclasses.asdict(v) d[f.name] = v return d @dataclass class RunStartEvent(_BaseEvent): run_id: str model_id: str quantization: str config: RSAConfig prompt: str created_at: str EVENT_NAME: str = field(init=False, default="run_start") @dataclass class GenerationStartEvent(_BaseEvent): run_id: str round: int trace_id: str seed: int prompt_token_count: int max_tokens: int parent_trace_ids: list[str] EVENT_NAME: str = field(init=False, default="generation_start") @dataclass class TraceCompleteEvent(_BaseEvent): run_id: str round: int trace_id: str text: str token_ids: list[int] generated_tokens: int finish_reason: Literal["eos", "max_tokens", "error"] elapsed_s: float EVENT_NAME: str = field(init=False, default="trace_complete") @dataclass class TailExtractedEvent(_BaseEvent): run_id: str round: int trace_id: str tail_token_ids: list[int] tail_text: str tail_tokens: int EVENT_NAME: str = field(init=False, default="tail_extracted") @dataclass class AggregationPromptEvent(_BaseEvent): run_id: str round: int trace_id: str selected_tail_trace_ids: list[str] prompt_text: str prompt_token_ids: list[int] EVENT_NAME: str = field(init=False, default="aggregation_prompt") @dataclass class RoundCompleteEvent(_BaseEvent): run_id: str round: int trace_ids: list[str] memory_estimate_bytes: int elapsed_s: float EVENT_NAME: str = field(init=False, default="round_complete") @dataclass class FinalEvent(_BaseEvent): run_id: str final_trace_id: str final_text: str all_final_trace_ids: list[str] answer_selection: str EVENT_NAME: str = field(init=False, default="final") @dataclass class RunEndEvent(_BaseEvent): run_id: str elapsed_s: float total_generated_tokens: int peak_memory_bytes: int EVENT_NAME: str = field(init=False, default="run_end") @dataclass class ErrorEvent(_BaseEvent): run_id: str stage: str message: str recoverable: bool EVENT_NAME: str = field(init=False, default="error") class AuditWriter: """Streaming JSONL writer. No-op when path is None. Once `close()` is called, subsequent `write()` calls raise ValueError. Use as a context manager (`with AuditWriter(p) as w:`) for guaranteed close. """ def __init__(self, path: Path | str | None, mode: Literal["w", "a"] = "w") -> None: self._path = Path(path) if path is not None else None self._enabled = self._path is not None self._closed = False self._fp: io.TextIOWrapper | None = None if self._path is not None: self._path.parent.mkdir(parents=True, exist_ok=True) self._fp = self._path.open(mode, encoding="utf-8") if self._fp is not None: weakref.finalize(self, self._fp.close) def write(self, event: _BaseEvent) -> None: if not self._enabled: return if self._closed: raise ValueError("AuditWriter is closed") assert self._fp is not None self._fp.write(json.dumps(event.to_dict(), ensure_ascii=False)) self._fp.write("\n") self._fp.flush() # flush per event so a crash mid-run preserves the audit trail def close(self) -> None: if self._closed: return self._closed = True if self._fp is not None: self._fp.close() self._fp = None def __enter__(self) -> "AuditWriter": return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close()