feat(audit): add JSONL event types + streaming AuditWriter

This commit is contained in:
transcrilive
2026-05-10 02:43:10 +02:00
parent 40bd38c2c5
commit b688c4ef77
2 changed files with 235 additions and 0 deletions

View File

@@ -0,0 +1,155 @@
"""JSONL audit events + writer.
Each RSA orchestrator state change emits one event, written as a single JSON
object on its own line. The schema is documented in
docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
"""
from __future__ import annotations
import dataclasses
import io
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from markovian_rsa_mlx.config import RSAConfig
@dataclass
class _BaseEvent:
    """Base for all audit events. Subclasses set EVENT_NAME.

    ``EVENT_NAME`` is a non-init dataclass field so that each subclass can
    override its default value. It is serialized under the ``"event"`` key by
    :meth:`to_dict` instead of under its own field name.
    """

    EVENT_NAME: str = field(init=False, repr=False, default="")

    def to_dict(self) -> dict[str, Any]:
        """Return this event as a plain mapping.

        The ``"event"`` key comes first and holds ``EVENT_NAME``; every other
        dataclass field follows in field order. A field value that is itself
        a dataclass instance (for example the run configuration) is expanded
        into a nested dict with ``dataclasses.asdict``. All other values are
        passed through unchanged, so the result is only JSON-serializable if
        those values are.
        """
        payload: dict[str, Any] = {"event": self.EVENT_NAME}
        for spec in dataclasses.fields(self):
            if spec.name == "EVENT_NAME":
                # Already emitted above under the "event" key.
                continue
            value = getattr(self, spec.name)
            # is_dataclass() is also true for dataclass *types*, but asdict()
            # only accepts instances, hence the extra type check.
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                value = dataclasses.asdict(value)
            payload[spec.name] = value
        return payload
@dataclass
class RunStartEvent(_BaseEvent):
    """Payload of the ``run_start`` audit event.

    ``config`` holds an ``RSAConfig``; ``_BaseEvent.to_dict`` expands it into
    a plain dict when the event is serialized.
    """

    # NOTE(review): presumably emitted once when a run begins — confirm
    # against the orchestrator.
    run_id: str
    model_id: str
    quantization: str
    config: RSAConfig
    prompt: str
    created_at: str  # timestamp as text; the format is not fixed here
    EVENT_NAME: str = field(default="run_start", init=False)
@dataclass
class GenerationStartEvent(_BaseEvent):
    """Payload of the ``generation_start`` audit event.

    Identifies one trace within a round, together with the seed, prompt token
    count, and token limit recorded for it.
    """

    run_id: str
    round: int
    trace_id: str
    seed: int
    prompt_token_count: int
    max_tokens: int
    # NOTE(review): presumably the traces whose output fed this generation —
    # confirm against the orchestrator.
    parent_trace_ids: list[str]
    EVENT_NAME: str = field(default="generation_start", init=False)
@dataclass
class TraceCompleteEvent(_BaseEvent):
    """Payload of the ``trace_complete`` audit event.

    Records a trace's text and token ids, together with its token count,
    finish reason, and elapsed time.
    """

    run_id: str
    round: int
    trace_id: str
    text: str
    token_ids: list[int]
    generated_tokens: int
    # Annotated as exactly one of "eos", "max_tokens" or "error".
    finish_reason: Literal["eos", "max_tokens", "error"]
    elapsed_s: float  # the "_s" suffix suggests seconds — TODO confirm
    EVENT_NAME: str = field(default="trace_complete", init=False)
@dataclass
class TailExtractedEvent(_BaseEvent):
    """Payload of the ``tail_extracted`` audit event.

    Records the tail of a trace as token ids and as text, together with a
    token count.
    """

    run_id: str
    round: int
    trace_id: str
    tail_token_ids: list[int]
    tail_text: str
    # NOTE(review): presumably len(tail_token_ids) — confirm against caller.
    tail_tokens: int
    EVENT_NAME: str = field(default="tail_extracted", init=False)
@dataclass
class AggregationPromptEvent(_BaseEvent):
    """Payload of the ``aggregation_prompt`` audit event.

    Records a prompt as text and as token ids, together with the ids of the
    traces whose tails were selected.
    """

    run_id: str
    round: int
    trace_id: str
    selected_tail_trace_ids: list[str]
    prompt_text: str
    prompt_token_ids: list[int]
    EVENT_NAME: str = field(default="aggregation_prompt", init=False)
@dataclass
class RoundCompleteEvent(_BaseEvent):
    """Payload of the ``round_complete`` audit event.

    Lists the trace ids of a round, together with a memory estimate in bytes
    and an elapsed time.
    """

    run_id: str
    round: int
    trace_ids: list[str]
    memory_estimate_bytes: int
    elapsed_s: float  # the "_s" suffix suggests seconds — TODO confirm
    EVENT_NAME: str = field(default="round_complete", init=False)
@dataclass
class FinalEvent(_BaseEvent):
    """Payload of the ``final`` audit event.

    Records the chosen final trace id and its text, the ids of all final
    traces, and an answer-selection label.
    """

    run_id: str
    final_trace_id: str
    final_text: str
    all_final_trace_ids: list[str]
    # NOTE(review): presumably names the strategy used to pick the final
    # trace — allowed values are not visible here.
    answer_selection: str
    EVENT_NAME: str = field(default="final", init=False)
@dataclass
class RunEndEvent(_BaseEvent):
    """Payload of the ``run_end`` audit event.

    Records run-level totals: elapsed time, total generated tokens, and peak
    memory in bytes.
    """

    run_id: str
    elapsed_s: float  # the "_s" suffix suggests seconds — TODO confirm
    total_generated_tokens: int
    peak_memory_bytes: int
    EVENT_NAME: str = field(default="run_end", init=False)
@dataclass
class ErrorEvent(_BaseEvent):
    """Payload of the ``error`` audit event.

    Records the stage in which an error was reported, its message, and a flag
    saying whether it was considered recoverable.
    """

    run_id: str
    stage: str
    message: str
    recoverable: bool
    EVENT_NAME: str = field(default="error", init=False)
class AuditWriter:
    """Write audit events to a JSONL file, one event per line.

    Passing ``path=None`` disables the writer: no file is opened and
    :meth:`write` / :meth:`close` do nothing. The writer is also a context
    manager that closes the file on exit.
    """

    def __init__(self, path: Path | str | None, mode: str = "w") -> None:
        """Open ``path`` in ``mode`` as UTF-8 text.

        Missing parent directories are created first. With ``path=None``
        nothing is opened.
        """
        self._path = None if path is None else Path(path)
        self._fp: io.TextIOBase | None = None
        if self._path is None:
            return
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._fp = self._path.open(mode, encoding="utf-8")

    def write(self, event: _BaseEvent) -> None:
        """Serialize ``event.to_dict()`` as one JSON line and flush.

        Does nothing when no file is open (disabled or already closed).
        Non-ASCII characters are written as-is rather than escaped.
        """
        stream = self._fp
        if stream is None:
            return
        line = json.dumps(event.to_dict(), ensure_ascii=False)
        stream.write(line)
        stream.write("\n")
        # Flush per event so each line reaches the file as it is written.
        stream.flush()

    def close(self) -> None:
        """Close the file if one is open; later calls are no-ops."""
        if self._fp is None:
            return
        self._fp.close()
        self._fp = None

    def __enter__(self) -> "AuditWriter":
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args) -> None:
        """Close the writer; exception details are ignored, not suppressed."""
        self.close()