feat(audit): add JSONL event types + streaming AuditWriter
This commit is contained in:
155
src/markovian_rsa_mlx/audit.py
Normal file
155
src/markovian_rsa_mlx/audit.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""JSONL audit events + writer.
|
||||||
|
|
||||||
|
All RSA orchestrator state changes emit one event per JSON line. Schema is
|
||||||
|
documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
import dataclasses
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BaseEvent:
|
||||||
|
"""Base for all audit events. Subclasses set EVENT_NAME."""
|
||||||
|
EVENT_NAME: str = field(init=False, repr=False, default="")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
d: dict[str, Any] = {"event": self.EVENT_NAME}
|
||||||
|
for f in dataclasses.fields(self):
|
||||||
|
if f.name == "EVENT_NAME":
|
||||||
|
continue
|
||||||
|
v = getattr(self, f.name)
|
||||||
|
if isinstance(v, RSAConfig):
|
||||||
|
v = dataclasses.asdict(v)
|
||||||
|
d[f.name] = v
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RunStartEvent(_BaseEvent):
    """Audit event serialized with ``"event": "run_start"``.

    All fields are required constructor arguments. ``config`` holds an
    ``RSAConfig``, which the inherited ``to_dict`` converts to a plain dict.
    """

    # Identity of the run and of the model it uses.
    run_id: str
    model_id: str
    quantization: str

    # Run configuration object (expanded to a dict on serialization).
    config: RSAConfig

    # Input prompt and a creation timestamp, both carried as strings.
    prompt: str
    created_at: str

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="run_start")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GenerationStartEvent(_BaseEvent):
    """Audit event serialized with ``"event": "generation_start"``.

    All fields are required constructor arguments.
    """

    # Run / round / trace the generation belongs to.
    run_id: str
    round: int
    trace_id: str

    # Sampling seed and token budget figures for this generation.
    seed: int
    prompt_token_count: int
    max_tokens: int

    # IDs of parent traces (presumably empty for a first round; verify
    # against the orchestrator).
    parent_trace_ids: list[str]

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="generation_start")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TraceCompleteEvent(_BaseEvent):
    """Audit event serialized with ``"event": "trace_complete"``.

    All fields are required constructor arguments.
    """

    # Run / round / trace this record refers to.
    run_id: str
    round: int
    trace_id: str

    # Generated output as text and as token IDs, plus a token count.
    text: str
    token_ids: list[int]
    generated_tokens: int

    # Why generation stopped; annotated as one of three literal values
    # (the annotation is not enforced at runtime).
    finish_reason: Literal["eos", "max_tokens", "error"]

    # Elapsed time (the name suggests seconds).
    elapsed_s: float

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="trace_complete")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TailExtractedEvent(_BaseEvent):
    """Audit event serialized with ``"event": "tail_extracted"``.

    All fields are required constructor arguments.
    """

    # Run / round / trace the tail was taken from.
    run_id: str
    round: int
    trace_id: str

    # The extracted tail as token IDs and as text, plus a token count.
    tail_token_ids: list[int]
    tail_text: str
    tail_tokens: int

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="tail_extracted")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AggregationPromptEvent(_BaseEvent):
    """Audit event serialized with ``"event": "aggregation_prompt"``.

    All fields are required constructor arguments.
    """

    # Run / round / trace the aggregation prompt belongs to.
    run_id: str
    round: int
    trace_id: str

    # IDs of the traces whose tails were selected for this prompt.
    selected_tail_trace_ids: list[str]

    # The prompt as text and as token IDs.
    prompt_text: str
    prompt_token_ids: list[int]

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="aggregation_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RoundCompleteEvent(_BaseEvent):
    """Audit event serialized with ``"event": "round_complete"``.

    All fields are required constructor arguments.
    """

    # Run and round this summary refers to.
    run_id: str
    round: int

    # IDs of the traces associated with the round.
    trace_ids: list[str]

    # Memory estimate (bytes, per the name) and elapsed time
    # (the name suggests seconds).
    memory_estimate_bytes: int
    elapsed_s: float

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="round_complete")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FinalEvent(_BaseEvent):
    """Audit event serialized with ``"event": "final"``.

    All fields are required constructor arguments.
    """

    run_id: str

    # The chosen final trace and its text.
    final_trace_id: str
    final_text: str

    # IDs of every final-round trace.
    all_final_trace_ids: list[str]

    # How the final answer was selected (free-form string here; allowed
    # values are not visible in this file).
    answer_selection: str

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="final")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RunEndEvent(_BaseEvent):
    """Audit event serialized with ``"event": "run_end"``.

    All fields are required constructor arguments.
    """

    run_id: str

    # Run totals: elapsed time (the name suggests seconds), generated
    # token count, and peak memory (bytes, per the name).
    elapsed_s: float
    total_generated_tokens: int
    peak_memory_bytes: int

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="run_end")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ErrorEvent(_BaseEvent):
    """Audit event serialized with ``"event": "error"``.

    All fields are required constructor arguments.
    """

    run_id: str

    # Where the error occurred and what it said (both free-form strings).
    stage: str
    message: str

    # Whether the error is flagged as recoverable.
    recoverable: bool

    # Event discriminator; not accepted by __init__.
    EVENT_NAME: str = field(init=False, default="error")
|
||||||
|
|
||||||
|
|
||||||
|
class AuditWriter:
    """Streaming JSONL writer. No-op when path is None.

    Each ``write`` call emits one JSON object followed by a newline and
    flushes immediately. The writer can be used as a context manager;
    leaving the ``with`` block closes the file.
    """

    def __init__(self, path: Path | str | None, mode: str = "w") -> None:
        """Open ``path`` for writing, creating parent directories as needed.

        Args:
            path: Destination file, or ``None`` to make every method a no-op.
            mode: Text mode handed to ``Path.open`` (``"w"`` truncates,
                ``"a"`` appends).
        """
        self._path = Path(path) if path is not None else None
        self._fp: io.TextIOBase | None = None
        if self._path is not None:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            # newline="\n" turns off platform newline translation so the
            # file always uses bare LF line terminators.
            self._fp = self._path.open(mode, encoding="utf-8", newline="\n")

    def write(self, event: _BaseEvent) -> None:
        """Serialize ``event.to_dict()`` as one JSON line and flush.

        Does nothing when the writer has no open file (path was ``None`` or
        ``close`` has already been called).
        """
        if self._fp is None:
            return
        # Serialize first, then emit record and terminator in a single
        # write so a serialization error leaves the file untouched.
        line = json.dumps(event.to_dict(), ensure_ascii=False)
        self._fp.write(line + "\n")
        self._fp.flush()

    def close(self) -> None:
        """Close the underlying file if open; safe to call repeatedly."""
        if self._fp is not None:
            self._fp.close()
            self._fp = None

    def __enter__(self) -> "AuditWriter":
        return self

    def __exit__(self, *args: object) -> None:
        # Exceptions are not suppressed (implicit None return).
        self.close()
|
||||||
80
tests/test_audit.py
Normal file
80
tests/test_audit.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.audit import (
|
||||||
|
AuditWriter,
|
||||||
|
RunStartEvent,
|
||||||
|
TraceCompleteEvent,
|
||||||
|
TailExtractedEvent,
|
||||||
|
AggregationPromptEvent,
|
||||||
|
RoundCompleteEvent,
|
||||||
|
FinalEvent,
|
||||||
|
RunEndEvent,
|
||||||
|
)
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_emits_jsonl(tmp_path: Path):
    """Two events written through one writer produce two parseable JSON lines."""
    audit_path = tmp_path / "audit.jsonl"

    start_event = RunStartEvent(
        run_id="run-1",
        model_id="kyr0/zaya1-base-8b-MLX",
        quantization="q4_g64",
        config=RSAConfig(),
        prompt="2+2",
        created_at="2026-05-10T00:00:00Z",
    )
    trace_event = TraceCompleteEvent(
        run_id="run-1",
        round=0,
        trace_id="t0",
        text="The answer is 4.",
        token_ids=[1, 2, 3],
        generated_tokens=3,
        finish_reason="eos",
        elapsed_s=0.5,
    )

    # Leaving the with-block closes the writer before the file is read back.
    with AuditWriter(audit_path) as writer:
        writer.write(start_event)
        writer.write(trace_event)

    lines = audit_path.read_text().strip().split("\n")
    assert len(lines) == 2

    run_start = json.loads(lines[0])
    assert run_start["event"] == "run_start"
    assert run_start["model_id"] == "kyr0/zaya1-base-8b-MLX"
    assert run_start["config"]["rounds"] == 2

    trace_complete = json.loads(lines[1])
    assert trace_complete["event"] == "trace_complete"
    assert trace_complete["text"] == "The answer is 4."
    assert trace_complete["finish_reason"] == "eos"
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_serialization_includes_event_field():
    """``to_dict`` exposes the event name under the ``"event"`` key."""
    tail_event = TailExtractedEvent(
        run_id="r",
        round=0,
        trace_id="t",
        tail_token_ids=[1, 2, 3],
        tail_text="hi",
        tail_tokens=3,
    )
    assert tail_event.to_dict()["event"] == "tail_extracted"
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_appends_subsequent_events(tmp_path: Path):
    """A second writer opened with mode="a" adds to the file, not replaces it."""
    audit_path = tmp_path / "audit.jsonl"

    first_event = RunEndEvent(
        run_id="r",
        elapsed_s=1.0,
        total_generated_tokens=10,
        peak_memory_bytes=1000,
    )
    second_event = RunEndEvent(
        run_id="r2",
        elapsed_s=2.0,
        total_generated_tokens=20,
        peak_memory_bytes=2000,
    )

    # Default mode creates (or truncates) the file.
    with AuditWriter(audit_path) as w:
        w.write(first_event)

    # Append mode must preserve the line written above.
    with AuditWriter(audit_path, mode="a") as w2:
        w2.write(second_event)

    lines = audit_path.read_text().strip().split("\n")
    assert len(lines) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_no_op_when_path_none():
    """With ``path=None`` both ``write`` and ``close`` complete without raising."""
    end_event = RunEndEvent(
        run_id="x",
        elapsed_s=1.0,
        total_generated_tokens=0,
        peak_memory_bytes=0,
    )
    w = AuditWriter(None)
    w.write(end_event)
    w.close()  # no exception
|
||||||
Reference in New Issue
Block a user