156 lines
5.7 KiB
Python
156 lines
5.7 KiB
Python
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from markovian_rsa_mlx.audit import (
|
|
AuditWriter,
|
|
RunStartEvent,
|
|
TraceCompleteEvent,
|
|
TailExtractedEvent,
|
|
AggregationPromptEvent,
|
|
RoundCompleteEvent,
|
|
FinalEvent,
|
|
RunEndEvent,
|
|
)
|
|
from markovian_rsa_mlx.config import RSAConfig
|
|
|
|
|
|
def test_writer_emits_jsonl(tmp_path: Path):
    """AuditWriter writes one JSON object per line, in the order events were written."""
    audit_path = tmp_path / "audit.jsonl"
    config = RSAConfig()
    # Context-managed so the file handle is released even if a write() fails.
    with AuditWriter(audit_path) as writer:
        writer.write(RunStartEvent(
            run_id="run-1",
            model_id="kyr0/zaya1-base-8b-MLX",
            quantization="q4_g64",
            config=config,
            prompt="2+2",
            created_at="2026-05-10T00:00:00Z",
        ))
        writer.write(TraceCompleteEvent(
            run_id="run-1",
            round=0,
            trace_id="t0",
            text="The answer is 4.",
            token_ids=[1, 2, 3],
            generated_tokens=3,
            finish_reason="eos",
            elapsed_s=0.5,
        ))

    lines = audit_path.read_text().strip().split("\n")
    assert len(lines) == 2

    e0 = json.loads(lines[0])
    assert e0["event"] == "run_start"
    assert e0["model_id"] == "kyr0/zaya1-base-8b-MLX"
    # Compare with the config that was actually passed in rather than a
    # hard-coded default, so this writer test does not break when
    # RSAConfig's defaults change.
    assert e0["config"]["rounds"] == config.rounds

    e1 = json.loads(lines[1])
    assert e1["event"] == "trace_complete"
    assert e1["text"] == "The answer is 4."
    assert e1["finish_reason"] == "eos"
|
|
|
|
|
|
def test_event_serialization_includes_event_field():
    """to_dict() tags the payload with the event's type name under "event"."""
    payload = TailExtractedEvent(
        run_id="r",
        round=0,
        trace_id="t",
        tail_token_ids=[1, 2, 3],
        tail_text="hi",
        tail_tokens=3,
    ).to_dict()
    assert payload["event"] == "tail_extracted"
|
|
|
|
|
|
def test_writer_appends_subsequent_events(tmp_path: Path):
    """Reopening a writer with mode="a" appends after existing events instead of truncating."""
    audit_path = tmp_path / "audit.jsonl"
    with AuditWriter(audit_path) as w:
        w.write(RunEndEvent(run_id="r", elapsed_s=1.0,
                            total_generated_tokens=10, peak_memory_bytes=1000))
    with AuditWriter(audit_path, mode="a") as w2:
        w2.write(RunEndEvent(run_id="r2", elapsed_s=2.0,
                             total_generated_tokens=20, peak_memory_bytes=2000))

    lines = audit_path.read_text().strip().split("\n")
    assert len(lines) == 2
    # A line count alone does not prove append semantics: the first run's
    # event must survive and the appended event must come after it.
    assert [json.loads(line)["run_id"] for line in lines] == ["r", "r2"]
|
|
|
|
|
|
def test_writer_no_op_when_path_none():
    """A writer constructed with path=None accepts write() and close() without raising."""
    event = RunEndEvent(run_id="x", elapsed_s=1.0,
                        total_generated_tokens=0, peak_memory_bytes=0)
    disabled_writer = AuditWriter(None)
    disabled_writer.write(event)
    disabled_writer.close()  # must not raise
|
|
|
|
|
|
import pytest
|
|
from markovian_rsa_mlx.audit import (
|
|
GenerationStartEvent, ErrorEvent,
|
|
)
|
|
|
|
# One row per audit event type:
# (event class, constructor kwargs, expected value of the serialized "event" field).
EVENT_CASES = [
    (
        RunStartEvent,
        {
            "run_id": "r",
            "model_id": "m",
            "quantization": "q",
            "config": RSAConfig(),
            "prompt": "p",
            "created_at": "t",
        },
        "run_start",
    ),
    (
        GenerationStartEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "seed": 1,
            "prompt_token_count": 2,
            "max_tokens": 3,
            "parent_trace_ids": [],
        },
        "generation_start",
    ),
    (
        TraceCompleteEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "text": "x",
            "token_ids": [1],
            "generated_tokens": 1,
            "finish_reason": "eos",
            "elapsed_s": 0.1,
        },
        "trace_complete",
    ),
    (
        TailExtractedEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "tail_token_ids": [1],
            "tail_text": "x",
            "tail_tokens": 1,
        },
        "tail_extracted",
    ),
    (
        AggregationPromptEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_id": "t",
            "selected_tail_trace_ids": [],
            "prompt_text": "p",
            "prompt_token_ids": [1],
        },
        "aggregation_prompt",
    ),
    (
        RoundCompleteEvent,
        {
            "run_id": "r",
            "round": 0,
            "trace_ids": [],
            "memory_estimate_bytes": 0,
            "elapsed_s": 0.1,
        },
        "round_complete",
    ),
    (
        FinalEvent,
        {
            "run_id": "r",
            "final_trace_id": "t",
            "final_text": "x",
            "all_final_trace_ids": [],
            "answer_selection": "first_final_candidate",
        },
        "final",
    ),
    (
        RunEndEvent,
        {
            "run_id": "r",
            "elapsed_s": 0.1,
            "total_generated_tokens": 1,
            "peak_memory_bytes": 1,
        },
        "run_end",
    ),
    (
        ErrorEvent,
        {"run_id": "r", "stage": "x", "message": "m", "recoverable": False},
        "error",
    ),
]
|
|
|
|
|
|
@pytest.mark.parametrize("event_cls,kwargs,expected_event_name", EVENT_CASES)
def test_event_round_trip_through_writer(tmp_path, event_cls, kwargs, expected_event_name):
    """Every EVENT_CASES entry serializes through AuditWriter with the expected `event` tag."""
    event = event_cls(**kwargs)
    target = tmp_path / f"{expected_event_name}.jsonl"
    with AuditWriter(target) as writer:
        writer.write(event)

    # json.loads on the whole stripped file also proves exactly one record was written.
    record = json.loads(target.read_text().strip())
    assert record["event"] == expected_event_name
|
|
|
|
|
|
def test_with_context_manager_closes_on_exit(tmp_path):
    """Context manager closes the writer on block exit, even on exception."""

    def _event():
        # Fresh minimal event for each write attempt.
        return RunEndEvent(run_id="r", elapsed_s=0.0,
                           total_generated_tokens=0, peak_memory_bytes=0)

    # --- Normal exit -------------------------------------------------------
    audit_path = tmp_path / "ctx.jsonl"
    w = AuditWriter(audit_path)
    with w:
        w.write(_event())
    # Writer must be closed after exiting the with block.
    with pytest.raises(ValueError, match="closed"):
        w.write(_event())

    # Exactly one event written. Check non-emptiness explicitly: counting
    # newlines in the stripped text alone would also accept an empty file.
    content = audit_path.read_text().strip()
    assert content, "expected exactly one event, found an empty file"
    assert "\n" not in content
    assert json.loads(content)["event"] == "run_end"

    # --- Exceptional exit --------------------------------------------------
    # The exception raised inside the block must propagate, and the writer
    # must still end up closed.
    err_writer = AuditWriter(tmp_path / "ctx_err.jsonl")
    with pytest.raises(RuntimeError, match="boom"):
        with err_writer:
            raise RuntimeError("boom")
    with pytest.raises(ValueError, match="closed"):
        err_writer.write(_event())
|
|
|
|
|
|
def test_write_after_close_raises(tmp_path):
    """write() on a closed writer fails loudly with a ValueError mentioning "closed"."""
    writer = AuditWriter(tmp_path / "post_close.jsonl")
    writer.close()
    late_event = RunEndEvent(run_id="r", elapsed_s=0.0,
                             total_generated_tokens=0, peak_memory_bytes=0)
    # Must be an explicit error, not a silent no-op.
    with pytest.raises(ValueError, match="closed"):
        writer.write(late_event)
|
|
|
|
|
|
def test_close_is_idempotent(tmp_path):
    """Closing an already-closed writer is harmless."""
    writer = AuditWriter(tmp_path / "double_close.jsonl")
    for _ in range(2):
        writer.close()  # neither call may raise
|