171 lines
4.5 KiB
Python
171 lines
4.5 KiB
Python
"""JSONL audit events + writer.
|
|
|
|
All RSA orchestrator state changes emit one event per JSON line. The schema is
documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
|
|
"""
|
|
from __future__ import annotations
|
|
import dataclasses
|
|
import io
|
|
import json
|
|
import weakref
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
from markovian_rsa_mlx.config import RSAConfig
|
|
|
|
|
|
@dataclass
class _BaseEvent:
    """Base for all audit events. Subclasses set EVENT_NAME.

    Concrete events are dataclasses that inherit from this class, declare
    their payload fields, and override ``EVENT_NAME`` with their own default.
    """

    # Not an __init__ parameter and hidden from repr() on the base class;
    # subclasses redeclare it with the event name as the default.
    EVENT_NAME: str = field(init=False, repr=False, default="")

    def to_dict(self) -> dict[str, Any]:
        """Return the event as a plain mapping for serialization.

        The mapping starts with an ``"event"`` key holding ``EVENT_NAME``,
        followed by one key per dataclass field in declaration order. The
        ``EVENT_NAME`` field itself is not repeated under its own name.

        Any field whose value is a dataclass *instance* (for example the
        ``RSAConfig`` carried by ``RunStartEvent``) is expanded recursively
        with ``dataclasses.asdict``. All other values are passed through
        unchanged, so they must already be serializable by the caller.

        Returns:
            A new ``dict`` keyed by ``"event"`` plus the field names.
        """
        payload: dict[str, Any] = {"event": self.EVENT_NAME}
        for spec in dataclasses.fields(self):
            if spec.name == "EVENT_NAME":
                # Already emitted under the "event" key above.
                continue
            value = getattr(self, spec.name)
            # is_dataclass() is also true for dataclass *classes*; asdict()
            # only accepts instances, so exclude types explicitly.
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                value = dataclasses.asdict(value)
            payload[spec.name] = value
        return payload
|
|
|
|
|
|
@dataclass
class RunStartEvent(_BaseEvent):
    """Audit event serialized with ``"event": "run_start"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    model_id: str
    quantization: str
    # Expanded to a plain dict by _BaseEvent.to_dict() when serialized.
    config: RSAConfig
    prompt: str
    created_at: str
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="run_start", init=False)
|
|
|
|
|
|
@dataclass
class GenerationStartEvent(_BaseEvent):
    """Audit event serialized with ``"event": "generation_start"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    round: int
    trace_id: str
    seed: int
    prompt_token_count: int
    max_tokens: int
    parent_trace_ids: list[str]
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="generation_start", init=False)
|
|
|
|
|
|
@dataclass
class TraceCompleteEvent(_BaseEvent):
    """Audit event serialized with ``"event": "trace_complete"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    round: int
    trace_id: str
    text: str
    token_ids: list[int]
    generated_tokens: int
    # Annotated as one of three literal values; not validated at runtime here.
    finish_reason: Literal["eos", "max_tokens", "error"]
    elapsed_s: float
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="trace_complete", init=False)
|
|
|
|
|
|
@dataclass
class TailExtractedEvent(_BaseEvent):
    """Audit event serialized with ``"event": "tail_extracted"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    round: int
    trace_id: str
    tail_token_ids: list[int]
    tail_text: str
    tail_tokens: int
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="tail_extracted", init=False)
|
|
|
|
|
|
@dataclass
class AggregationPromptEvent(_BaseEvent):
    """Audit event serialized with ``"event": "aggregation_prompt"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    round: int
    trace_id: str
    selected_tail_trace_ids: list[str]
    prompt_text: str
    prompt_token_ids: list[int]
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="aggregation_prompt", init=False)
|
|
|
|
|
|
@dataclass
class RoundCompleteEvent(_BaseEvent):
    """Audit event serialized with ``"event": "round_complete"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    round: int
    trace_ids: list[str]
    memory_estimate_bytes: int
    elapsed_s: float
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="round_complete", init=False)
|
|
|
|
|
|
@dataclass
class FinalEvent(_BaseEvent):
    """Audit event serialized with ``"event": "final"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    final_trace_id: str
    final_text: str
    all_final_trace_ids: list[str]
    answer_selection: str
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="final", init=False)
|
|
|
|
|
|
@dataclass
class RunEndEvent(_BaseEvent):
    """Audit event serialized with ``"event": "run_end"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    elapsed_s: float
    total_generated_tokens: int
    peak_memory_bytes: int
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="run_end", init=False)
|
|
|
|
|
|
@dataclass
class ErrorEvent(_BaseEvent):
    """Audit event serialized with ``"event": "error"``.

    Constructor arguments follow the field order below. Field meanings are
    defined by the schema document referenced in the module docstring.
    """

    run_id: str
    stage: str
    message: str
    recoverable: bool
    # Fixed event tag; not accepted by __init__.
    EVENT_NAME: str = field(default="error", init=False)
|
|
|
|
|
|
class AuditWriter:
    """Line-oriented JSON audit log writer; inert when constructed with ``None``.

    With a real path, the parent directory is created if needed and the file
    is opened as UTF-8 text. Each `write()` appends one JSON document followed
    by a newline and flushes immediately.

    After `close()`, `write()` on a file-backed writer raises ValueError. A
    writer built with ``path=None`` ignores every `write()`, including after
    `close()`. Prefer `with AuditWriter(p) as w:` so the file is always closed.
    """

    def __init__(self, path: Path | str | None, mode: Literal["w", "a"] = "w") -> None:
        """Open the audit file, or set up a disabled writer.

        Args:
            path: Destination file, or ``None`` to disable all output.
            mode: File open mode passed to ``Path.open``: ``"w"`` or ``"a"``.
        """
        target = None if path is None else Path(path)
        self._path = target
        self._enabled = target is not None
        self._closed = False
        self._fp: io.TextIOWrapper | None = None
        if target is None:
            return
        target.parent.mkdir(parents=True, exist_ok=True)
        handle = target.open(mode, encoding="utf-8")
        self._fp = handle
        # Safety net: close the file when this writer is garbage-collected,
        # in case the caller never calls close().
        weakref.finalize(self, handle.close)

    def write(self, event: _BaseEvent) -> None:
        """Append ``event.to_dict()`` as one JSON line and flush.

        Does nothing on a disabled writer.

        Raises:
            ValueError: If the writer is file-backed and already closed.
        """
        if not self._enabled:
            return
        if self._closed:
            raise ValueError("AuditWriter is closed")
        stream = self._fp
        assert stream is not None
        record = json.dumps(event.to_dict(), ensure_ascii=False)
        stream.write(record + "\n")
        # Flush per event so a crash mid-run preserves the audit trail.
        stream.flush()

    def close(self) -> None:
        """Close the underlying file, if any. Repeated calls are no-ops."""
        if not self._closed:
            self._closed = True
            stream = self._fp
            if stream is not None:
                stream.close()
                self._fp = None

    def __enter__(self) -> "AuditWriter":
        """Return this writer for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the writer on leaving the ``with`` block; exceptions propagate."""
        self.close()
|