feat(prompts): add round-0 + zaya_v1 aggregation templates with versioned registry
This commit is contained in:
50
tests/test_prompts.py
Normal file
50
tests/test_prompts.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
from markovian_rsa_mlx.prompts import (
|
||||
build_round_0_messages,
|
||||
build_aggregation_user_content,
|
||||
AGGREGATION_TEMPLATES,
|
||||
)
|
||||
|
||||
|
||||
def test_round_0_returns_chat_messages_with_user_role():
|
||||
msgs = build_round_0_messages("What is 2+2?")
|
||||
assert msgs == [{"role": "user", "content": "What is 2+2?"}]
|
||||
|
||||
|
||||
def test_zaya_v1_aggregation_includes_problem_and_K_tails():
|
||||
body = build_aggregation_user_content(
|
||||
original_prompt="What is 2+2?",
|
||||
tails=["...solution attempt 1...", "...solution attempt 2..."],
|
||||
template="zaya_v1",
|
||||
)
|
||||
assert "Original problem" in body
|
||||
assert "What is 2+2?" in body
|
||||
assert "[Approach 1]" in body
|
||||
assert "...solution attempt 1..." in body
|
||||
assert "[Approach 2]" in body
|
||||
assert "...solution attempt 2..." in body
|
||||
assert "synthesize" in body.lower()
|
||||
assert "{K}" not in body # placeholder must be filled
|
||||
|
||||
|
||||
def test_zaya_v1_uses_K_equal_len_tails():
|
||||
body = build_aggregation_user_content(
|
||||
original_prompt="P",
|
||||
tails=["t1", "t2", "t3"],
|
||||
template="zaya_v1",
|
||||
)
|
||||
assert "3 candidate" in body # K=3
|
||||
assert "[Approach 3]" in body
|
||||
|
||||
|
||||
def test_unknown_template_raises():
|
||||
with pytest.raises(ValueError, match="unknown template"):
|
||||
build_aggregation_user_content(
|
||||
original_prompt="P",
|
||||
tails=["t"],
|
||||
template="does_not_exist",
|
||||
)
|
||||
|
||||
|
||||
def test_aggregation_templates_registry():
|
||||
assert "zaya_v1" in AGGREGATION_TEMPLATES
|
||||
Reference in New Issue
Block a user