feat(audit): add JSONL event types + streaming AuditWriter

This commit is contained in:
transcrilive
2026-05-10 02:43:10 +02:00
parent 40bd38c2c5
commit b688c4ef77
2 changed files with 235 additions and 0 deletions

View File

@@ -0,0 +1,155 @@
"""JSONL audit events + writer.
All RSA orchestrator state changes emit one event per JSON line. Schema is
documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
"""
from __future__ import annotations
import dataclasses
import io
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from markovian_rsa_mlx.config import RSAConfig
@dataclass
class _BaseEvent:
    """Base for all audit events. Subclasses set EVENT_NAME.

    ``EVENT_NAME`` is declared with ``init=False`` so it is never an
    ``__init__`` parameter; that lets subclasses declare required
    (non-default) fields even though this base field carries a default.
    """

    EVENT_NAME: str = field(init=False, repr=False, default="")

    def to_dict(self) -> dict[str, Any]:
        """Return this event as a plain mapping for serialization.

        The ``"event"`` key (holding ``EVENT_NAME``) is inserted first,
        followed by every other dataclass field in field order.

        A field whose value is a dataclass *instance* (``RSAConfig`` or any
        other) is expanded recursively with ``dataclasses.asdict``. All other
        values are passed through unchanged, so the result is only
        JSON-serializable if those values are.
        """
        payload: dict[str, Any] = {"event": self.EVENT_NAME}
        for spec in dataclasses.fields(self):
            if spec.name == "EVENT_NAME":
                # Already emitted under the "event" key.
                continue
            value = getattr(self, spec.name)
            # is_dataclass() is also true for dataclass *types*, but asdict()
            # only accepts instances, so exclude types explicitly.
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                value = dataclasses.asdict(value)
            payload[spec.name] = value
        return payload
@dataclass
class RunStartEvent(_BaseEvent):
    """Audit record serialized with the event name ``run_start``.

    Presumably emitted once when a run begins (inferred from the name;
    confirm against the orchestrator).
    """

    run_id: str
    model_id: str
    quantization: str
    config: RSAConfig  # expanded to a plain dict by _BaseEvent.to_dict()
    prompt: str
    created_at: str  # presumably an ISO-8601 timestamp; verify against caller
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="run_start", init=False)
@dataclass
class GenerationStartEvent(_BaseEvent):
    """Audit record serialized with the event name ``generation_start``.

    Judging by the field names, it describes one trace's generation request
    (seed, prompt size, token budget, parent traces); confirm against the
    orchestrator.
    """

    run_id: str
    round: int
    trace_id: str
    seed: int
    prompt_token_count: int
    max_tokens: int
    parent_trace_ids: list[str]
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="generation_start", init=False)
@dataclass
class TraceCompleteEvent(_BaseEvent):
    """Audit record serialized with the event name ``trace_complete``.

    Carries the generated text and token ids of one trace together with
    why and how quickly generation ended.
    """

    run_id: str
    round: int
    trace_id: str
    text: str
    token_ids: list[int]
    generated_tokens: int  # presumably len(token_ids); verify against caller
    # Annotation only: the dataclass does not validate the value at runtime.
    finish_reason: Literal["eos", "max_tokens", "error"]
    elapsed_s: float
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="trace_complete", init=False)
@dataclass
class TailExtractedEvent(_BaseEvent):
    """Audit record serialized with the event name ``tail_extracted``.

    Holds the tail of a trace both as token ids and as text.
    """

    run_id: str
    round: int
    trace_id: str
    tail_token_ids: list[int]
    tail_text: str
    tail_tokens: int  # presumably len(tail_token_ids); verify against caller
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="tail_extracted", init=False)
@dataclass
class AggregationPromptEvent(_BaseEvent):
    """Audit record serialized with the event name ``aggregation_prompt``.

    Records a prompt (text and token ids) alongside the ids of the traces
    whose tails were selected for it.
    """

    run_id: str
    round: int
    trace_id: str
    selected_tail_trace_ids: list[str]
    prompt_text: str
    prompt_token_ids: list[int]
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="aggregation_prompt", init=False)
@dataclass
class RoundCompleteEvent(_BaseEvent):
    """Audit record serialized with the event name ``round_complete``.

    Summarizes one round: its trace ids, a memory estimate and elapsed time.
    """

    run_id: str
    round: int
    trace_ids: list[str]
    memory_estimate_bytes: int
    elapsed_s: float
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="round_complete", init=False)
@dataclass
class FinalEvent(_BaseEvent):
    """Audit record serialized with the event name ``final``.

    Identifies the chosen final trace and its text, plus the ids of all
    final-round traces.
    """

    run_id: str
    final_trace_id: str
    final_text: str
    all_final_trace_ids: list[str]
    # Presumably names the strategy used to pick the answer; verify against caller.
    answer_selection: str
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="final", init=False)
@dataclass
class RunEndEvent(_BaseEvent):
    """Audit record serialized with the event name ``run_end``.

    Carries run-level totals: elapsed time, generated tokens and peak memory.
    """

    run_id: str
    elapsed_s: float
    total_generated_tokens: int
    peak_memory_bytes: int
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="run_end", init=False)
@dataclass
class ErrorEvent(_BaseEvent):
    """Audit record serialized with the event name ``error``.

    Describes a failure: the stage it occurred in, a message, and whether
    the caller considered it recoverable.
    """

    run_id: str
    stage: str
    message: str
    recoverable: bool
    # Fixed tag for this event type; excluded from __init__.
    EVENT_NAME: str = field(default="error", init=False)
class AuditWriter:
    """Line-oriented JSON writer for audit events.

    Every event becomes one JSON document followed by a newline, and the
    stream is flushed after each event. When constructed with ``path=None``
    the writer holds no file and every method silently does nothing.
    """

    def __init__(self, path: Path | str | None, mode: str = "w") -> None:
        """Open *path* for writing, creating parent directories as needed.

        Args:
            path: Destination file, or ``None`` to disable output.
            mode: Mode string handed to ``Path.open`` (the file is always
                opened with UTF-8 encoding).
        """
        self._path = None if path is None else Path(path)
        self._fp: io.TextIOBase | None = None
        if self._path is None:
            return
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._fp = self._path.open(mode, encoding="utf-8")

    def write(self, event: _BaseEvent) -> None:
        """Serialize ``event.to_dict()`` as one JSON line and flush.

        Does nothing when no file is open (disabled writer, or after
        ``close()``).
        """
        sink = self._fp
        if sink is None:
            return
        # ensure_ascii=False keeps non-ASCII text readable in the log.
        line = json.dumps(event.to_dict(), ensure_ascii=False)
        sink.write(line)
        sink.write("\n")
        sink.flush()

    def close(self) -> None:
        """Close the underlying file if one is open; safe to call repeatedly."""
        handle = self._fp
        if handle is None:
            return
        handle.close()
        self._fp = None

    def __enter__(self) -> "AuditWriter":
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args) -> None:
        """Close the writer on leaving the ``with`` block; exceptions propagate."""
        self.close()

80
tests/test_audit.py Normal file
View File

@@ -0,0 +1,80 @@
import json
import tempfile
from pathlib import Path
from markovian_rsa_mlx.audit import (
AuditWriter,
RunStartEvent,
TraceCompleteEvent,
TailExtractedEvent,
AggregationPromptEvent,
RoundCompleteEvent,
FinalEvent,
RunEndEvent,
)
from markovian_rsa_mlx.config import RSAConfig
def test_writer_emits_jsonl(tmp_path: Path):
    """Two events written through AuditWriter land as two JSON lines, in order."""
    audit_path = tmp_path / "audit.jsonl"
    # The context manager closes the file even if a write() raises.
    with AuditWriter(audit_path) as writer:
        writer.write(RunStartEvent(
            run_id="run-1",
            model_id="kyr0/zaya1-base-8b-MLX",
            quantization="q4_g64",
            config=RSAConfig(),
            prompt="2+2",
            created_at="2026-05-10T00:00:00Z",
        ))
        writer.write(TraceCompleteEvent(
            run_id="run-1",
            round=0,
            trace_id="t0",
            text="The answer is 4.",
            token_ids=[1, 2, 3],
            generated_tokens=3,
            finish_reason="eos",
            elapsed_s=0.5,
        ))

    # Decode explicitly as UTF-8 (the encoding AuditWriter opens the file
    # with) instead of relying on the locale default.
    lines = audit_path.read_text(encoding="utf-8").strip().split("\n")
    assert len(lines) == 2

    e0 = json.loads(lines[0])
    assert e0["event"] == "run_start"
    assert e0["model_id"] == "kyr0/zaya1-base-8b-MLX"
    # The RSAConfig instance must have been expanded into a nested mapping.
    assert e0["config"]["rounds"] == 2

    e1 = json.loads(lines[1])
    assert e1["event"] == "trace_complete"
    assert e1["text"] == "The answer is 4."
    assert e1["finish_reason"] == "eos"
def test_event_serialization_includes_event_field():
    """to_dict() tags a TailExtractedEvent with its event name."""
    tail_event = TailExtractedEvent(
        run_id="r",
        round=0,
        trace_id="t",
        tail_token_ids=[1, 2, 3],
        tail_text="hi",
        tail_tokens=3,
    )
    serialized = tail_event.to_dict()
    assert serialized["event"] == "tail_extracted"
def test_writer_appends_subsequent_events(tmp_path: Path):
    """Reopening with mode="a" keeps the earlier record and adds the new one."""
    audit_path = tmp_path / "audit.jsonl"

    # Context managers close each file even if a write() raises.
    with AuditWriter(audit_path) as w:
        w.write(RunEndEvent(run_id="r", elapsed_s=1.0,
                            total_generated_tokens=10, peak_memory_bytes=1000))
    with AuditWriter(audit_path, mode="a") as w2:
        w2.write(RunEndEvent(run_id="r2", elapsed_s=2.0,
                             total_generated_tokens=20, peak_memory_bytes=2000))

    # Decode explicitly as UTF-8 (the encoding AuditWriter opens the file
    # with) instead of relying on the locale default.
    lines = audit_path.read_text(encoding="utf-8").strip().split("\n")
    assert len(lines) == 2
    # Check content and order, not just the count: the first record must
    # survive the reopen and the appended one must follow it.
    assert [json.loads(line)["run_id"] for line in lines] == ["r", "r2"]
def test_writer_no_op_when_path_none():
    """With path=None, write() and close() complete without raising."""
    end_event = RunEndEvent(
        run_id="x",
        elapsed_s=1.0,
        total_generated_tokens=0,
        peak_memory_bytes=0,
    )
    silent_writer = AuditWriter(None)
    silent_writer.write(end_event)
    silent_writer.close()  # no exception