"""Tests for the prompt-building helpers in markovian_rsa_mlx.prompts."""
import pytest
|
|
from markovian_rsa_mlx.prompts import (
|
|
build_round_0_messages,
|
|
build_aggregation_user_content,
|
|
AGGREGATION_TEMPLATES,
|
|
)
|
|
|
|
|
|
def test_round_0_returns_chat_messages_with_user_role():
    """Round 0 yields exactly one user-role chat message holding the prompt."""
    question = "What is 2+2?"
    expected = [{"role": "user", "content": question}]
    assert build_round_0_messages(question) == expected
|
|
|
|
|
|
def test_zaya_v1_aggregation_includes_problem_and_K_tails():
    """The zaya_v1 body contains the problem, each tail with its label, and no raw {K}."""
    problem = "What is 2+2?"
    attempts = ("...solution attempt 1...", "...solution attempt 2...")
    rendered = build_aggregation_user_content(
        original_prompt=problem,
        tails=list(attempts),
        template="zaya_v1",
    )

    assert "Original problem" in rendered
    assert problem in rendered
    # Each tail and its 1-based "[Approach N]" label must both be present.
    for position, attempt in enumerate(attempts, start=1):
        assert f"[Approach {position}]" in rendered
        assert attempt in rendered
    assert "synthesize" in rendered.lower()
    assert "{K}" not in rendered  # placeholder must be filled
|
|
|
|
|
|
def test_zaya_v1_uses_K_equal_len_tails():
    """With three tails, the zaya_v1 body mentions "3 candidate" and "[Approach 3]"."""
    rendered = build_aggregation_user_content(
        original_prompt="P",
        tails=["t1", "t2", "t3"],
        template="zaya_v1",
    )
    # K is expected to equal len(tails), i.e. 3 here.
    for fragment in ("3 candidate", "[Approach 3]"):
        assert fragment in rendered
|
|
|
|
|
|
def test_unknown_template_raises():
    """An unrecognised template name raises ValueError matching "unknown template"."""
    bad_request = {
        "original_prompt": "P",
        "tails": ["t"],
        "template": "does_not_exist",
    }
    with pytest.raises(ValueError, match="unknown template"):
        build_aggregation_user_content(**bad_request)
|
|
|
|
|
|
def test_aggregation_templates_registry():
    """The "zaya_v1" key is present in AGGREGATION_TEMPLATES."""
    registered = AGGREGATION_TEMPLATES
    assert "zaya_v1" in registered
|