feat(audit): add JSONL event types + streaming AuditWriter
This commit is contained in:
155
src/markovian_rsa_mlx/audit.py
Normal file
155
src/markovian_rsa_mlx/audit.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""JSONL audit events + writer.
|
||||
|
||||
All RSA orchestrator state changes emit one event per JSON line. Schema is
|
||||
documented in docs/superpowers/specs/2026-05-10-markovian-rsa-mlx-design.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
|
||||
|
||||
@dataclass
class _BaseEvent:
    """Base for all audit events. Subclasses set EVENT_NAME.

    ``EVENT_NAME`` is declared as a dataclass field with ``init=False`` so
    that subclasses can override its default without it becoming a
    constructor argument. It is emitted under the ``"event"`` key by
    :meth:`to_dict` instead of under its own name.
    """

    EVENT_NAME: str = field(init=False, repr=False, default="")

    def to_dict(self) -> dict[str, Any]:
        """Return the event as a plain mapping ready for ``json.dumps``.

        The ``"event"`` key comes first, followed by every dataclass field
        (except ``EVENT_NAME``) in declaration order. Field values that are
        themselves dataclass instances (for example a run configuration) are
        expanded recursively with ``dataclasses.asdict``; all other values
        are passed through unchanged.
        """
        payload: dict[str, Any] = {"event": self.EVENT_NAME}
        for spec in dataclasses.fields(self):
            if spec.name == "EVENT_NAME":
                # Already emitted above under the "event" key.
                continue
            value = getattr(self, spec.name)
            # ``is_dataclass`` is also true for dataclass *types*; only
            # instances can be handed to ``asdict``.
            if dataclasses.is_dataclass(value) and not isinstance(value, type):
                value = dataclasses.asdict(value)
            payload[spec.name] = value
        return payload
|
||||
|
||||
|
||||
@dataclass
class RunStartEvent(_BaseEvent):
    """Audit event emitted with ``"event": "run_start"``.

    Carries the run identifier, the model/quantization labels, the
    ``RSAConfig`` in effect, the prompt text and a creation timestamp string.
    """

    run_id: str
    model_id: str
    quantization: str
    config: RSAConfig
    prompt: str
    # Stored as a string; NOTE(review): format is not enforced here —
    # presumably an ISO-8601 timestamp, confirm against the emitter.
    created_at: str
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="run_start", init=False)
|
||||
|
||||
|
||||
@dataclass
class GenerationStartEvent(_BaseEvent):
    """Audit event emitted with ``"event": "generation_start"``.

    Identifies one trace within a round of a run, together with the seed,
    prompt token count, token limit and the ids of its parent traces.
    """

    run_id: str
    round: int
    trace_id: str
    seed: int
    prompt_token_count: int
    max_tokens: int
    parent_trace_ids: list[str]
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="generation_start", init=False)
|
||||
|
||||
|
||||
@dataclass
class TraceCompleteEvent(_BaseEvent):
    """Audit event emitted with ``"event": "trace_complete"``.

    Holds the text and token ids of one trace, how many tokens were
    generated, why generation stopped and the elapsed time in seconds.
    """

    run_id: str
    round: int
    trace_id: str
    text: str
    token_ids: list[int]
    generated_tokens: int
    # Annotation only; the value is not validated at runtime.
    finish_reason: Literal["eos", "max_tokens", "error"]
    elapsed_s: float
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="trace_complete", init=False)
|
||||
|
||||
|
||||
@dataclass
class TailExtractedEvent(_BaseEvent):
    """Audit event emitted with ``"event": "tail_extracted"``.

    Holds the tail of a trace as token ids and as text, plus a token count.
    """

    run_id: str
    round: int
    trace_id: str
    tail_token_ids: list[int]
    tail_text: str
    # NOTE(review): presumably len(tail_token_ids); not enforced here.
    tail_tokens: int
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="tail_extracted", init=False)
|
||||
|
||||
|
||||
@dataclass
class AggregationPromptEvent(_BaseEvent):
    """Audit event emitted with ``"event": "aggregation_prompt"``.

    Holds the ids of the traces whose tails were selected, and the resulting
    prompt both as text and as token ids.
    """

    run_id: str
    round: int
    trace_id: str
    selected_tail_trace_ids: list[str]
    prompt_text: str
    prompt_token_ids: list[int]
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="aggregation_prompt", init=False)
|
||||
|
||||
|
||||
@dataclass
class RoundCompleteEvent(_BaseEvent):
    """Audit event emitted with ``"event": "round_complete"``.

    Lists the trace ids of a round along with a memory estimate in bytes and
    the elapsed time in seconds.
    """

    run_id: str
    round: int
    trace_ids: list[str]
    memory_estimate_bytes: int
    elapsed_s: float
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="round_complete", init=False)
|
||||
|
||||
|
||||
@dataclass
class FinalEvent(_BaseEvent):
    """Audit event emitted with ``"event": "final"``.

    Holds the id and text of the chosen final trace, the ids of all final
    traces, and a string naming how the answer was selected.
    """

    run_id: str
    final_trace_id: str
    final_text: str
    all_final_trace_ids: list[str]
    # Free-form string here; NOTE(review): allowed values are not visible
    # in this module.
    answer_selection: str
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="final", init=False)
|
||||
|
||||
|
||||
@dataclass
class RunEndEvent(_BaseEvent):
    """Audit event emitted with ``"event": "run_end"``.

    Summarises a run: elapsed seconds, total generated tokens and peak
    memory in bytes.
    """

    run_id: str
    elapsed_s: float
    total_generated_tokens: int
    peak_memory_bytes: int
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="run_end", init=False)
|
||||
|
||||
|
||||
@dataclass
class ErrorEvent(_BaseEvent):
    """Audit event emitted with ``"event": "error"``.

    Records the stage an error occurred in, its message, and a flag saying
    whether it was recoverable.
    """

    run_id: str
    stage: str
    message: str
    recoverable: bool
    # Not a constructor argument; fixed per event type.
    EVENT_NAME: str = field(default="error", init=False)
|
||||
|
||||
|
||||
class AuditWriter:
    """Streaming JSONL writer. No-op when path is None.

    Each call to :meth:`write` emits one JSON object followed by a newline
    and flushes immediately. The writer can be used as a context manager;
    leaving the ``with`` block closes the file.
    """

    def __init__(self, path: Path | str | None, mode: str = "w") -> None:
        """Open ``path`` as UTF-8 text using ``mode``.

        Missing parent directories are created first. When ``path`` is
        ``None`` nothing is opened and the writer silently discards events.
        """
        self._fp: io.TextIOBase | None = None
        if path is None:
            self._path = None
            return
        self._path = Path(path)
        self._path.parent.mkdir(parents=True, exist_ok=True)
        self._fp = self._path.open(mode, encoding="utf-8")

    def write(self, event: _BaseEvent) -> None:
        """Serialize ``event.to_dict()`` as one JSON line and flush it.

        Does nothing when no file is open (``path`` was ``None`` or the
        writer has been closed).
        """
        stream = self._fp
        if stream is None:
            return
        # ensure_ascii=False keeps non-ASCII text readable in the log file.
        record = json.dumps(event.to_dict(), ensure_ascii=False)
        stream.write(record + "\n")
        stream.flush()

    def close(self) -> None:
        """Close the underlying file if one is open; safe to call repeatedly."""
        if self._fp is None:
            return
        self._fp.close()
        self._fp = None

    def __enter__(self) -> "AuditWriter":
        """Return the writer itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args) -> None:
        """Close the writer; exceptions from the ``with`` body propagate."""
        self.close()
|
||||
80
tests/test_audit.py
Normal file
80
tests/test_audit.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from markovian_rsa_mlx.audit import (
|
||||
AuditWriter,
|
||||
RunStartEvent,
|
||||
TraceCompleteEvent,
|
||||
TailExtractedEvent,
|
||||
AggregationPromptEvent,
|
||||
RoundCompleteEvent,
|
||||
FinalEvent,
|
||||
RunEndEvent,
|
||||
)
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
|
||||
|
||||
def test_writer_emits_jsonl(tmp_path: Path):
    """Two events written through AuditWriter come back as two JSON lines."""
    audit_path = tmp_path / "audit.jsonl"
    # Use the context manager so the file handle is released even when a
    # write raises part-way through the test.
    with AuditWriter(audit_path) as writer:
        writer.write(RunStartEvent(
            run_id="run-1",
            model_id="kyr0/zaya1-base-8b-MLX",
            quantization="q4_g64",
            config=RSAConfig(),
            prompt="2+2",
            created_at="2026-05-10T00:00:00Z",
        ))
        writer.write(TraceCompleteEvent(
            run_id="run-1",
            round=0,
            trace_id="t0",
            text="The answer is 4.",
            token_ids=[1, 2, 3],
            generated_tokens=3,
            finish_reason="eos",
            elapsed_s=0.5,
        ))

    # The writer always encodes UTF-8; read back with the same encoding
    # instead of relying on the platform's locale default.
    lines = audit_path.read_text(encoding="utf-8").strip().split("\n")
    assert len(lines) == 2

    e0 = json.loads(lines[0])
    assert e0["event"] == "run_start"
    assert e0["model_id"] == "kyr0/zaya1-base-8b-MLX"
    # The RSAConfig value is serialized as a nested dict.
    assert e0["config"]["rounds"] == 2

    e1 = json.loads(lines[1])
    assert e1["event"] == "trace_complete"
    assert e1["text"] == "The answer is 4."
    assert e1["finish_reason"] == "eos"
|
||||
|
||||
|
||||
def test_event_serialization_includes_event_field():
    """to_dict() exposes the event type under the "event" key."""
    event = TailExtractedEvent(
        run_id="r",
        round=0,
        trace_id="t",
        tail_token_ids=[1, 2, 3],
        tail_text="hi",
        tail_tokens=3,
    )
    assert event.to_dict()["event"] == "tail_extracted"
|
||||
|
||||
|
||||
def test_writer_appends_subsequent_events(tmp_path: Path):
    """Re-opening with mode="a" keeps earlier lines and adds new ones after."""
    audit_path = tmp_path / "audit.jsonl"

    # Context managers release the handle even if a write raises.
    with AuditWriter(audit_path) as w:
        w.write(RunEndEvent(run_id="r", elapsed_s=1.0,
                            total_generated_tokens=10, peak_memory_bytes=1000))

    with AuditWriter(audit_path, mode="a") as w2:
        w2.write(RunEndEvent(run_id="r2", elapsed_s=2.0,
                             total_generated_tokens=20, peak_memory_bytes=2000))

    # Read back as UTF-8 to match the writer rather than the locale default.
    lines = audit_path.read_text(encoding="utf-8").strip().split("\n")
    assert len(lines) == 2
    # Check order and content, not just the count, so a truncating or
    # reordering append would be caught.
    assert [json.loads(line)["run_id"] for line in lines] == ["r", "r2"]
|
||||
|
||||
|
||||
def test_writer_no_op_when_path_none():
    """A writer built with path=None accepts write() and close() silently."""
    event = RunEndEvent(
        run_id="x",
        elapsed_s=1.0,
        total_generated_tokens=0,
        peak_memory_bytes=0,
    )
    w = AuditWriter(None)
    w.write(event)
    w.close()  # no exception
|
||||
Reference in New Issue
Block a user