fix(audit): raise on write-after-close + weakref.finalize + parametrized event coverage
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import weakref
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
@@ -128,22 +129,36 @@ class ErrorEvent(_BaseEvent):
|
|||||||
|
|
||||||
|
|
||||||
class AuditWriter:
|
class AuditWriter:
|
||||||
"""Streaming JSONL writer. No-op when path is None."""
|
"""Streaming JSONL writer. No-op when path is None.
|
||||||
def __init__(self, path: Path | str | None, mode: str = "w") -> None:
|
|
||||||
|
Once `close()` is called, subsequent `write()` calls raise ValueError (a writer created with path=None stays a silent no-op, even after close).
|
||||||
|
Use as a context manager (`with AuditWriter(p) as w:`) for guaranteed close.
|
||||||
|
"""
|
||||||
|
def __init__(self, path: Path | str | None, mode: Literal["w", "a"] = "w") -> None:
|
||||||
self._path = Path(path) if path is not None else None
|
self._path = Path(path) if path is not None else None
|
||||||
self._fp: io.TextIOBase | None = None
|
self._enabled = self._path is not None
|
||||||
|
self._closed = False
|
||||||
|
self._fp: io.TextIOWrapper | None = None
|
||||||
if self._path is not None:
|
if self._path is not None:
|
||||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
self._fp = self._path.open(mode, encoding="utf-8")
|
self._fp = self._path.open(mode, encoding="utf-8")
|
||||||
|
if self._fp is not None:
|
||||||
|
weakref.finalize(self, self._fp.close)
|
||||||
|
|
||||||
def write(self, event: _BaseEvent) -> None:
|
def write(self, event: _BaseEvent) -> None:
|
||||||
if self._fp is None:
|
if not self._enabled:
|
||||||
return
|
return
|
||||||
|
if self._closed:
|
||||||
|
raise ValueError("AuditWriter is closed")
|
||||||
|
assert self._fp is not None
|
||||||
self._fp.write(json.dumps(event.to_dict(), ensure_ascii=False))
|
self._fp.write(json.dumps(event.to_dict(), ensure_ascii=False))
|
||||||
self._fp.write("\n")
|
self._fp.write("\n")
|
||||||
self._fp.flush()
|
self._fp.flush() # flush per event so a crash mid-run preserves the audit trail
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
|
if self._closed:
|
||||||
|
return
|
||||||
|
self._closed = True
|
||||||
if self._fp is not None:
|
if self._fp is not None:
|
||||||
self._fp.close()
|
self._fp.close()
|
||||||
self._fp = None
|
self._fp = None
|
||||||
@@ -151,5 +166,5 @@ class AuditWriter:
|
|||||||
def __enter__(self) -> "AuditWriter":
|
def __enter__(self) -> "AuditWriter":
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *args) -> None:
|
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||||
self.close()
|
self.close()
|
||||||
|
|||||||
@@ -78,3 +78,78 @@ def test_writer_no_op_when_path_none():
|
|||||||
w.write(RunEndEvent(run_id="x", elapsed_s=1.0,
|
w.write(RunEndEvent(run_id="x", elapsed_s=1.0,
|
||||||
total_generated_tokens=0, peak_memory_bytes=0))
|
total_generated_tokens=0, peak_memory_bytes=0))
|
||||||
w.close() # no exception
|
w.close() # no exception
|
||||||
|
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from markovian_rsa_mlx.audit import (
|
||||||
|
GenerationStartEvent, ErrorEvent,
|
||||||
|
)
|
||||||
|
|
||||||
|
EVENT_CASES = [
|
||||||
|
(RunStartEvent, dict(run_id="r", model_id="m", quantization="q",
|
||||||
|
config=RSAConfig(), prompt="p", created_at="t"), "run_start"),
|
||||||
|
(GenerationStartEvent, dict(run_id="r", round=0, trace_id="t", seed=1,
|
||||||
|
prompt_token_count=2, max_tokens=3,
|
||||||
|
parent_trace_ids=[]), "generation_start"),
|
||||||
|
(TraceCompleteEvent, dict(run_id="r", round=0, trace_id="t",
|
||||||
|
text="x", token_ids=[1], generated_tokens=1,
|
||||||
|
finish_reason="eos", elapsed_s=0.1), "trace_complete"),
|
||||||
|
(TailExtractedEvent, dict(run_id="r", round=0, trace_id="t",
|
||||||
|
tail_token_ids=[1], tail_text="x", tail_tokens=1),
|
||||||
|
"tail_extracted"),
|
||||||
|
(AggregationPromptEvent, dict(run_id="r", round=0, trace_id="t",
|
||||||
|
selected_tail_trace_ids=[], prompt_text="p",
|
||||||
|
prompt_token_ids=[1]), "aggregation_prompt"),
|
||||||
|
(RoundCompleteEvent, dict(run_id="r", round=0, trace_ids=[],
|
||||||
|
memory_estimate_bytes=0, elapsed_s=0.1), "round_complete"),
|
||||||
|
(FinalEvent, dict(run_id="r", final_trace_id="t", final_text="x",
|
||||||
|
all_final_trace_ids=[], answer_selection="first_final_candidate"),
|
||||||
|
"final"),
|
||||||
|
(RunEndEvent, dict(run_id="r", elapsed_s=0.1, total_generated_tokens=1,
|
||||||
|
peak_memory_bytes=1), "run_end"),
|
||||||
|
(ErrorEvent, dict(run_id="r", stage="x", message="m", recoverable=False),
|
||||||
|
"error"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("event_cls,kwargs,expected_event_name", EVENT_CASES)
|
||||||
|
def test_event_round_trip_through_writer(tmp_path, event_cls, kwargs, expected_event_name):
|
||||||
|
"""All 9 event types serialize through AuditWriter with the right `event` field."""
|
||||||
|
audit_path = tmp_path / f"{expected_event_name}.jsonl"
|
||||||
|
with AuditWriter(audit_path) as w:
|
||||||
|
w.write(event_cls(**kwargs))
|
||||||
|
line = audit_path.read_text().strip()
|
||||||
|
payload = json.loads(line)
|
||||||
|
assert payload["event"] == expected_event_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_context_manager_closes_on_exit(tmp_path):
|
||||||
|
"""Context manager closes the writer on block exit, even on exception."""
|
||||||
|
audit_path = tmp_path / "ctx.jsonl"
|
||||||
|
w = AuditWriter(audit_path)
|
||||||
|
with w:
|
||||||
|
w.write(RunEndEvent(run_id="r", elapsed_s=0.0,
|
||||||
|
total_generated_tokens=0, peak_memory_bytes=0))
|
||||||
|
# Writer must be closed after exiting the with block
|
||||||
|
with pytest.raises(ValueError, match="closed"):
|
||||||
|
w.write(RunEndEvent(run_id="r", elapsed_s=0.0,
|
||||||
|
total_generated_tokens=0, peak_memory_bytes=0))
|
||||||
|
assert audit_path.read_text().strip().count("\n") == 0  # no second line: the rejected post-close write appended nothing
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_after_close_raises(tmp_path):
|
||||||
|
"""Calling write() after close() raises a clear error (not silent no-op)."""
|
||||||
|
audit_path = tmp_path / "post_close.jsonl"
|
||||||
|
w = AuditWriter(audit_path)
|
||||||
|
w.close()
|
||||||
|
with pytest.raises(ValueError, match="closed"):
|
||||||
|
w.write(RunEndEvent(run_id="r", elapsed_s=0.0,
|
||||||
|
total_generated_tokens=0, peak_memory_bytes=0))
|
||||||
|
|
||||||
|
|
||||||
|
def test_close_is_idempotent(tmp_path):
|
||||||
|
"""Calling close() twice is fine."""
|
||||||
|
audit_path = tmp_path / "double_close.jsonl"
|
||||||
|
w = AuditWriter(audit_path)
|
||||||
|
w.close()
|
||||||
|
w.close() # no exception
|
||||||
|
|||||||
Reference in New Issue
Block a user