diff --git a/src/markovian_rsa_mlx/audit.py b/src/markovian_rsa_mlx/audit.py index 0476401..db8c673 100644 --- a/src/markovian_rsa_mlx/audit.py +++ b/src/markovian_rsa_mlx/audit.py @@ -7,6 +7,7 @@ from __future__ import annotations import dataclasses import io import json +import weakref from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal @@ -128,22 +129,36 @@ class ErrorEvent(_BaseEvent): class AuditWriter: - """Streaming JSONL writer. No-op when path is None.""" - def __init__(self, path: Path | str | None, mode: str = "w") -> None: + """Streaming JSONL writer. No-op when path is None. + + Once `close()` is called, subsequent `write()` calls raise ValueError (a writer built with path=None stays a silent no-op). + Use as a context manager (`with AuditWriter(p) as w:`) for guaranteed close. + """ + def __init__(self, path: Path | str | None, mode: Literal["w", "a"] = "w") -> None: self._path = Path(path) if path is not None else None - self._fp: io.TextIOBase | None = None + self._enabled = self._path is not None + self._closed = False + self._fp: io.TextIOWrapper | None = None if self._path is not None: self._path.parent.mkdir(parents=True, exist_ok=True) self._fp = self._path.open(mode, encoding="utf-8") + if self._fp is not None: + weakref.finalize(self, self._fp.close) def write(self, event: _BaseEvent) -> None: - if self._fp is None: + if not self._enabled: return + if self._closed: + raise ValueError("AuditWriter is closed") + assert self._fp is not None self._fp.write(json.dumps(event.to_dict(), ensure_ascii=False)) self._fp.write("\n") - self._fp.flush() + self._fp.flush() # flush per event so a crash mid-run preserves the audit trail def close(self) -> None: + if self._closed: + return + self._closed = True if self._fp is not None: self._fp.close() self._fp = None @@ -151,5 +166,5 @@ class AuditWriter: def __enter__(self) -> "AuditWriter": return self - def __exit__(self, *args) -> None: + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() diff 
--git a/tests/test_audit.py b/tests/test_audit.py index e93372c..7255876 100644 --- a/tests/test_audit.py +++ b/tests/test_audit.py @@ -78,3 +78,78 @@ def test_writer_no_op_when_path_none(): w.write(RunEndEvent(run_id="x", elapsed_s=1.0, total_generated_tokens=0, peak_memory_bytes=0)) w.close() # no exception + + +import pytest +from markovian_rsa_mlx.audit import ( + GenerationStartEvent, ErrorEvent, +) + +EVENT_CASES = [ + (RunStartEvent, dict(run_id="r", model_id="m", quantization="q", + config=RSAConfig(), prompt="p", created_at="t"), "run_start"), + (GenerationStartEvent, dict(run_id="r", round=0, trace_id="t", seed=1, + prompt_token_count=2, max_tokens=3, + parent_trace_ids=[]), "generation_start"), + (TraceCompleteEvent, dict(run_id="r", round=0, trace_id="t", + text="x", token_ids=[1], generated_tokens=1, + finish_reason="eos", elapsed_s=0.1), "trace_complete"), + (TailExtractedEvent, dict(run_id="r", round=0, trace_id="t", + tail_token_ids=[1], tail_text="x", tail_tokens=1), + "tail_extracted"), + (AggregationPromptEvent, dict(run_id="r", round=0, trace_id="t", + selected_tail_trace_ids=[], prompt_text="p", + prompt_token_ids=[1]), "aggregation_prompt"), + (RoundCompleteEvent, dict(run_id="r", round=0, trace_ids=[], + memory_estimate_bytes=0, elapsed_s=0.1), "round_complete"), + (FinalEvent, dict(run_id="r", final_trace_id="t", final_text="x", + all_final_trace_ids=[], answer_selection="first_final_candidate"), + "final"), + (RunEndEvent, dict(run_id="r", elapsed_s=0.1, total_generated_tokens=1, + peak_memory_bytes=1), "run_end"), + (ErrorEvent, dict(run_id="r", stage="x", message="m", recoverable=False), + "error"), +] + + +@pytest.mark.parametrize("event_cls,kwargs,expected_event_name", EVENT_CASES) +def test_event_round_trip_through_writer(tmp_path, event_cls, kwargs, expected_event_name): + """All 9 event types serialize through AuditWriter with the right `event` field.""" + audit_path = tmp_path / f"{expected_event_name}.jsonl" + with 
AuditWriter(audit_path) as w: + w.write(event_cls(**kwargs)) + line = audit_path.read_text(encoding="utf-8").strip() + payload = json.loads(line) + assert payload["event"] == expected_event_name + + +def test_with_context_manager_closes_on_exit(tmp_path): + """Context manager closes the writer when the with block exits.""" + audit_path = tmp_path / "ctx.jsonl" + w = AuditWriter(audit_path) + with w: + w.write(RunEndEvent(run_id="r", elapsed_s=0.0, + total_generated_tokens=0, peak_memory_bytes=0)) + # Writer must be closed after exiting the with block + with pytest.raises(ValueError, match="closed"): + w.write(RunEndEvent(run_id="r", elapsed_s=0.0, + total_generated_tokens=0, peak_memory_bytes=0)) + assert audit_path.read_text(encoding="utf-8").count("\n") == 1 # exactly one newline-terminated event written (an empty file fails) + + +def test_write_after_close_raises(tmp_path): + """Calling write() after close() raises a clear error (not silent no-op).""" + audit_path = tmp_path / "post_close.jsonl" + w = AuditWriter(audit_path) + w.close() + with pytest.raises(ValueError, match="closed"): + w.write(RunEndEvent(run_id="r", elapsed_s=0.0, + total_generated_tokens=0, peak_memory_bytes=0)) + + +def test_close_is_idempotent(tmp_path): + """Calling close() twice is fine.""" + audit_path = tmp_path / "double_close.jsonl" + w = AuditWriter(audit_path) + w.close() + w.close() # no exception