"""Tests for the JSONL audit trail produced by ``markovian_rsa_mlx.audit``."""

import json
import tempfile  # NOTE(review): unused by the tests below; remove if nothing else needs it
from pathlib import Path

import pytest

from markovian_rsa_mlx.audit import (
    AggregationPromptEvent,
    AuditWriter,
    ErrorEvent,
    FinalEvent,
    GenerationStartEvent,
    RoundCompleteEvent,
    RunEndEvent,
    RunStartEvent,
    TailExtractedEvent,
    TraceCompleteEvent,
)
from markovian_rsa_mlx.config import RSAConfig


def _read_events(path: Path) -> list:
    """Parse each line of the JSONL file at *path* into a dict.

    ``splitlines()`` (unlike ``strip().split("\\n")``) returns an empty list
    for an empty file, so "nothing was written" shows up as ``len == 0``.
    """
    return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines()]


def _run_end_event(run_id: str = "r") -> RunEndEvent:
    """Build a minimal ``RunEndEvent`` with zeroed metrics for writer tests."""
    return RunEndEvent(
        run_id=run_id,
        elapsed_s=0.0,
        total_generated_tokens=0,
        peak_memory_bytes=0,
    )


def test_writer_emits_jsonl(tmp_path: Path):
    """Each write() produces one JSON object per line, in write order."""
    audit_path = tmp_path / "audit.jsonl"
    writer = AuditWriter(audit_path)
    writer.write(RunStartEvent(
        run_id="run-1",
        model_id="kyr0/zaya1-base-8b-MLX",
        quantization="q4_g64",
        config=RSAConfig(),
        prompt="2+2",
        created_at="2026-05-10T00:00:00Z",
    ))
    writer.write(TraceCompleteEvent(
        run_id="run-1",
        round=0,
        trace_id="t0",
        text="The answer is 4.",
        token_ids=[1, 2, 3],
        generated_tokens=3,
        finish_reason="eos",
        elapsed_s=0.5,
    ))
    writer.close()

    events = _read_events(audit_path)
    assert len(events) == 2

    run_start, trace = events
    assert run_start["event"] == "run_start"
    assert run_start["model_id"] == "kyr0/zaya1-base-8b-MLX"
    # The config is nested as a dict; 2 is assumed to be RSAConfig's default `rounds`.
    assert run_start["config"]["rounds"] == 2

    assert trace["event"] == "trace_complete"
    assert trace["text"] == "The answer is 4."
    assert trace["finish_reason"] == "eos"


def test_event_serialization_includes_event_field():
    """to_dict() tags the payload with the event's type name."""
    ev = TailExtractedEvent(
        run_id="r",
        round=0,
        trace_id="t",
        tail_token_ids=[1, 2, 3],
        tail_text="hi",
        tail_tokens=3,
    )
    payload = ev.to_dict()
    assert payload["event"] == "tail_extracted"


def test_writer_appends_subsequent_events(tmp_path: Path):
    """mode="a" keeps earlier events and appends new ones after them."""
    audit_path = tmp_path / "audit.jsonl"

    w = AuditWriter(audit_path)
    w.write(RunEndEvent(run_id="r", elapsed_s=1.0, total_generated_tokens=10, peak_memory_bytes=1000))
    w.close()

    w2 = AuditWriter(audit_path, mode="a")
    w2.write(RunEndEvent(run_id="r2", elapsed_s=2.0, total_generated_tokens=20, peak_memory_bytes=2000))
    w2.close()

    events = _read_events(audit_path)
    # Order matters: the first run's event must survive and precede the appended one.
    assert [e["run_id"] for e in events] == ["r", "r2"]


def test_writer_no_op_when_path_none():
    """A writer constructed with no path accepts writes and closes without error."""
    w = AuditWriter(None)
    w.write(RunEndEvent(run_id="x", elapsed_s=1.0, total_generated_tokens=0, peak_memory_bytes=0))
    w.close()  # no exception


# (event class, constructor kwargs, expected value of the serialized "event" field)
EVENT_CASES = [
    (RunStartEvent,
     dict(run_id="r", model_id="m", quantization="q", config=RSAConfig(),
          prompt="p", created_at="t"),
     "run_start"),
    (GenerationStartEvent,
     dict(run_id="r", round=0, trace_id="t", seed=1, prompt_token_count=2,
          max_tokens=3, parent_trace_ids=[]),
     "generation_start"),
    (TraceCompleteEvent,
     dict(run_id="r", round=0, trace_id="t", text="x", token_ids=[1],
          generated_tokens=1, finish_reason="eos", elapsed_s=0.1),
     "trace_complete"),
    (TailExtractedEvent,
     dict(run_id="r", round=0, trace_id="t", tail_token_ids=[1], tail_text="x",
          tail_tokens=1),
     "tail_extracted"),
    (AggregationPromptEvent,
     dict(run_id="r", round=0, trace_id="t", selected_tail_trace_ids=[],
          prompt_text="p", prompt_token_ids=[1]),
     "aggregation_prompt"),
    (RoundCompleteEvent,
     dict(run_id="r", round=0, trace_ids=[], memory_estimate_bytes=0, elapsed_s=0.1),
     "round_complete"),
    (FinalEvent,
     dict(run_id="r", final_trace_id="t", final_text="x", all_final_trace_ids=[],
          answer_selection="first_final_candidate"),
     "final"),
    (RunEndEvent,
     dict(run_id="r", elapsed_s=0.1, total_generated_tokens=1, peak_memory_bytes=1),
     "run_end"),
    (ErrorEvent,
     dict(run_id="r", stage="x", message="m", recoverable=False),
     "error"),
]


@pytest.mark.parametrize(
    "event_cls,kwargs,expected_event_name",
    EVENT_CASES,
    ids=[case[2] for case in EVENT_CASES],
)
def test_event_round_trip_through_writer(tmp_path, event_cls, kwargs, expected_event_name):
    """Every event type in EVENT_CASES serializes through AuditWriter as one line
    carrying the right `event` field and its run_id."""
    audit_path = tmp_path / f"{expected_event_name}.jsonl"
    with AuditWriter(audit_path) as w:
        w.write(event_cls(**kwargs))

    events = _read_events(audit_path)
    assert len(events) == 1
    payload = events[0]
    assert payload["event"] == expected_event_name
    assert payload["run_id"] == kwargs["run_id"]


def test_with_context_manager_closes_on_exit(tmp_path):
    """Context manager closes the writer when the block exits normally."""
    audit_path = tmp_path / "ctx.jsonl"
    w = AuditWriter(audit_path)
    with w:
        w.write(_run_end_event())

    # Writer must be closed after exiting the with block.
    with pytest.raises(ValueError, match="closed"):
        w.write(_run_end_event())

    # Exactly one event was written (an empty file must NOT pass this check).
    events = _read_events(audit_path)
    assert len(events) == 1
    assert events[0]["event"] == "run_end"


def test_with_context_manager_closes_on_exception(tmp_path):
    """Context manager closes the writer, and does not swallow the error,
    when the block raises."""
    audit_path = tmp_path / "ctx_exc.jsonl"
    w = AuditWriter(audit_path)
    with pytest.raises(RuntimeError, match="boom"):
        with w:
            w.write(_run_end_event())
            raise RuntimeError("boom")

    with pytest.raises(ValueError, match="closed"):
        w.write(_run_end_event())

    # The event written before the exception must still be persisted.
    assert len(_read_events(audit_path)) == 1


def test_write_after_close_raises(tmp_path):
    """Calling write() after close() raises a clear error (not silent no-op)."""
    audit_path = tmp_path / "post_close.jsonl"
    w = AuditWriter(audit_path)
    w.close()
    with pytest.raises(ValueError, match="closed"):
        w.write(_run_end_event())


def test_close_is_idempotent(tmp_path):
    """Calling close() twice is fine."""
    audit_path = tmp_path / "double_close.jsonl"
    w = AuditWriter(audit_path)
    w.close()
    w.close()  # no exception