fix(audit): raise on write-after-close + weakref.finalize + parametrized event coverage

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
transcrilive
2026-05-10 02:48:38 +02:00
parent b688c4ef77
commit 3d595e021f
2 changed files with 96 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import dataclasses import dataclasses
import io import io
import json import json
import weakref
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Literal from typing import Any, Literal
@@ -128,22 +129,36 @@ class ErrorEvent(_BaseEvent):
class AuditWriter:
    """Streaming JSONL writer. No-op when path is None.

    Each event is serialized as one JSON object per line (UTF-8) and flushed
    immediately. Once `close()` is called, subsequent `write()` calls raise
    ValueError. Use as a context manager (`with AuditWriter(p) as w:`) for
    guaranteed close.

    Args:
        path: Destination file. Parent directories are created on demand.
            ``None`` disables the writer: `write()` and `close()` do nothing.
        mode: ``"w"`` to truncate an existing file, ``"a"`` to append to it.
    """

    def __init__(self, path: Path | str | None, mode: Literal["w", "a"] = "w") -> None:
        self._path = Path(path) if path is not None else None
        self._enabled = self._path is not None
        self._closed = False
        self._fp: io.TextIOWrapper | None = None
        self._finalizer: weakref.finalize | None = None
        if self._path is not None:
            self._path.parent.mkdir(parents=True, exist_ok=True)
            self._fp = self._path.open(mode, encoding="utf-8")
            # Safety net: close the file if the writer is garbage-collected
            # without close(). The callback is bound to the file object, not
            # to `self`, so it does not keep the writer alive.
            self._finalizer = weakref.finalize(self, self._fp.close)

    def write(self, event: _BaseEvent) -> None:
        """Append `event.to_dict()` as one JSON line and flush it.

        Does nothing when the writer was created with ``path=None``.

        Raises:
            ValueError: If the writer has already been closed.
        """
        if not self._enabled:
            return
        if self._closed or self._fp is None:
            raise ValueError("AuditWriter is closed")
        record = json.dumps(event.to_dict(), ensure_ascii=False)
        # One write per record so a JSON line and its newline go out together.
        self._fp.write(record + "\n")
        self._fp.flush()  # flush per event so a crash mid-run preserves the audit trail

    def close(self) -> None:
        """Close the underlying file. Safe to call more than once."""
        if self._closed:
            return
        self._closed = True
        if self._finalizer is not None:
            # Calling the finalizer closes the file and marks the finalizer
            # dead, so nothing stays registered after an explicit close.
            self._finalizer()
        self._fp = None

    def __enter__(self) -> "AuditWriter":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returns None, so an exception raised inside the block propagates.
        self.close()

View File

@@ -78,3 +78,78 @@ def test_writer_no_op_when_path_none():
w.write(RunEndEvent(run_id="x", elapsed_s=1.0, w.write(RunEndEvent(run_id="x", elapsed_s=1.0,
total_generated_tokens=0, peak_memory_bytes=0)) total_generated_tokens=0, peak_memory_bytes=0))
w.close() # no exception w.close() # no exception
import pytest
from markovian_rsa_mlx.audit import (
GenerationStartEvent, ErrorEvent,
)
# One (event class, constructor kwargs, expected serialized "event" name)
# triple per audit event type; consumed by the parametrized round-trip test.
EVENT_CASES = [
    (
        RunStartEvent,
        {
            "run_id": "r",
            "model_id": "m",
            "quantization": "q",
            "config": RSAConfig(),
            "prompt": "p",
            "created_at": "t",
        },
        "run_start",
    ),
    (
        GenerationStartEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "seed": 1,
            "prompt_token_count": 2,
            "max_tokens": 3,
            "parent_trace_ids": [],
        },
        "generation_start",
    ),
    (
        TraceCompleteEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "text": "x",
            "token_ids": [1],
            "generated_tokens": 1,
            "finish_reason": "eos",
            "elapsed_s": 0.1,
        },
        "trace_complete",
    ),
    (
        TailExtractedEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "tail_token_ids": [1],
            "tail_text": "x",
            "tail_tokens": 1,
        },
        "tail_extracted",
    ),
    (
        AggregationPromptEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "selected_tail_trace_ids": [],
            "prompt_text": "p",
            "prompt_token_ids": [1],
        },
        "aggregation_prompt",
    ),
    (
        RoundCompleteEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_ids": [],
            "memory_estimate_bytes": 0,
            "elapsed_s": 0.1,
        },
        "round_complete",
    ),
    (
        FinalEvent,
        {
            "run_id": "r",
            "final_trace_id": "t",
            "final_text": "x",
            "all_final_trace_ids": [],
            "answer_selection": "first_final_candidate",
        },
        "final",
    ),
    (
        RunEndEvent,
        {
            "run_id": "r",
            "elapsed_s": 0.1,
            "total_generated_tokens": 1,
            "peak_memory_bytes": 1,
        },
        "run_end",
    ),
    (
        ErrorEvent,
        {
            "run_id": "r",
            "stage": "x",
            "message": "m",
            "recoverable": False,
        },
        "error",
    ),
]
@pytest.mark.parametrize("event_cls,kwargs,expected_event_name", EVENT_CASES)
def test_event_round_trip_through_writer(tmp_path, event_cls, kwargs, expected_event_name):
    """Each event type in EVENT_CASES serializes through AuditWriter as exactly
    one JSON line carrying the right `event` field."""
    audit_path = tmp_path / f"{expected_event_name}.jsonl"
    with AuditWriter(audit_path) as w:
        w.write(event_cls(**kwargs))
    # Read back as UTF-8 explicitly rather than relying on the platform's
    # locale encoding.
    lines = audit_path.read_text(encoding="utf-8").splitlines()
    assert len(lines) == 1  # one write() -> one JSONL record
    payload = json.loads(lines[0])
    assert payload["event"] == expected_event_name
def test_with_context_manager_closes_on_exit(tmp_path):
    """Context manager closes the writer on block exit, even on exception."""
    event = RunEndEvent(run_id="r", elapsed_s=0.0,
                        total_generated_tokens=0, peak_memory_bytes=0)

    # Normal exit: the writer is closed once the block ends.
    audit_path = tmp_path / "ctx.jsonl"
    w = AuditWriter(audit_path)
    with w:
        w.write(event)
    with pytest.raises(ValueError, match="closed"):
        w.write(event)
    # Count records directly; counting "\n" in stripped text would also
    # accept an empty file.
    assert len(audit_path.read_text(encoding="utf-8").splitlines()) == 1

    # Exceptional exit: the error propagates and the writer is still closed.
    failing = AuditWriter(tmp_path / "ctx_error.jsonl")
    with pytest.raises(RuntimeError, match="boom"):
        with failing:
            raise RuntimeError("boom")
    with pytest.raises(ValueError, match="closed"):
        failing.write(event)
def test_write_after_close_raises(tmp_path):
    """write() on an already-closed writer must raise, not silently drop the event."""
    writer = AuditWriter(tmp_path / "post_close.jsonl")
    writer.close()

    with pytest.raises(ValueError, match="closed"):
        late_event = RunEndEvent(
            run_id="r",
            elapsed_s=0.0,
            total_generated_tokens=0,
            peak_memory_bytes=0,
        )
        writer.write(late_event)
def test_close_is_idempotent(tmp_path):
    """A second close() on the same writer is accepted without error."""
    writer = AuditWriter(tmp_path / "double_close.jsonl")
    for _ in range(2):
        writer.close()  # neither call may raise