Files
markovian-rsa-mlx/tests/test_prompts.py

51 lines
1.5 KiB
Python

import pytest
from markovian_rsa_mlx.prompts import (
build_round_0_messages,
build_aggregation_user_content,
AGGREGATION_TEMPLATES,
)
def test_round_0_returns_chat_messages_with_user_role():
    """Round 0 wraps the raw prompt in exactly one user-role chat message."""
    question = "What is 2+2?"
    expected = [{"role": "user", "content": question}]
    assert build_round_0_messages(question) == expected
def test_zaya_v1_aggregation_includes_problem_and_K_tails():
    """The zaya_v1 body embeds the problem and every tail under its own label.

    Besides checking that each piece is present, this pins the pairing:
    each tail must appear after its own ``[Approach i]`` label and before
    the next label. A body that attached tails to the wrong approach
    number would therefore fail, even though every substring is present.
    """
    tails = ["...solution attempt 1...", "...solution attempt 2..."]
    body = build_aggregation_user_content(
        original_prompt="What is 2+2?",
        tails=tails,
        template="zaya_v1",
    )
    assert "Original problem" in body
    assert "What is 2+2?" in body
    assert "synthesize" in body.lower()
    assert "{K}" not in body  # placeholder must be filled

    # Presence first, so a missing piece fails with a clear assertion
    # rather than a ValueError from str.index below.
    for i, tail in enumerate(tails, start=1):
        assert f"[Approach {i}]" in body
        assert tail in body

    # Ordering: label 1 < tail 1 < label 2 < tail 2 < ...
    positions = []
    for i, tail in enumerate(tails, start=1):
        positions.append(body.index(f"[Approach {i}]"))
        positions.append(body.index(tail))
    assert positions == sorted(positions)
def test_zaya_v1_uses_K_equal_len_tails():
    """K in the zaya_v1 body equals len(tails): exactly K labelled approaches.

    Checks the stated count, that every label 1..K and every tail is
    present, and that no label beyond K appears. An off-by-one in the
    approach numbering or padding is therefore caught.
    """
    tails = ["t1", "t2", "t3"]
    body = build_aggregation_user_content(
        original_prompt="P",
        tails=tails,
        template="zaya_v1",
    )
    assert "3 candidate" in body  # K=3
    for i, tail in enumerate(tails, start=1):
        assert f"[Approach {i}]" in body
        assert tail in body
    # No label past K: the builder must not over-count approaches.
    assert f"[Approach {len(tails) + 1}]" not in body
def test_unknown_template_raises():
    """An unregistered template name is rejected with a ValueError."""
    bogus_name = "does_not_exist"
    with pytest.raises(ValueError, match="unknown template"):
        build_aggregation_user_content(
            original_prompt="P",
            tails=["t"],
            template=bogus_name,
        )
def test_aggregation_templates_registry():
    """zaya_v1 is registered, and every registered name can actually be built.

    Registration alone does not prove a template is usable. Building each
    entry through the public builder catches a registered template that
    errors or leaves the ``{K}`` placeholder unfilled.
    """
    assert "zaya_v1" in AGGREGATION_TEMPLATES
    for name in AGGREGATION_TEMPLATES:
        body = build_aggregation_user_content(
            original_prompt="P",
            tails=["t1", "t2"],
            template=name,
        )
        # Pass the template name as the message so a failure says which entry broke.
        assert isinstance(body, str) and body, name
        assert "{K}" not in body, name  # placeholder must be filled