"""Tests for the prompt builders in ``markovian_rsa_mlx.prompts``."""

import pytest

from markovian_rsa_mlx.prompts import (
    AGGREGATION_TEMPLATES,
    build_aggregation_user_content,
    build_round_0_messages,
)


def test_round_0_returns_chat_messages_with_user_role():
    """Round 0 wraps the raw prompt in a single user-role chat message."""
    msgs = build_round_0_messages("What is 2+2?")
    assert msgs == [{"role": "user", "content": "What is 2+2?"}]


def test_zaya_v1_aggregation_includes_problem_and_K_tails():
    """The zaya_v1 body embeds the problem and one labelled section per tail."""
    body = build_aggregation_user_content(
        original_prompt="What is 2+2?",
        tails=["...solution attempt 1...", "...solution attempt 2..."],
        template="zaya_v1",
    )
    assert "Original problem" in body
    assert "What is 2+2?" in body
    assert "[Approach 1]" in body
    assert "...solution attempt 1..." in body
    assert "[Approach 2]" in body
    assert "...solution attempt 2..." in body
    assert "synthesize" in body.lower()
    assert "{K}" not in body  # placeholder must be filled


def test_zaya_v1_uses_K_equal_len_tails():
    """K in the rendered template equals the number of tails supplied."""
    body = build_aggregation_user_content(
        original_prompt="P",
        tails=["t1", "t2", "t3"],
        template="zaya_v1",
    )
    assert "3 candidate" in body  # K=3
    assert "[Approach 3]" in body
    # No section beyond the last tail should be rendered.
    assert "[Approach 4]" not in body


def test_unknown_template_raises():
    """An unregistered template name is rejected with a ValueError."""
    with pytest.raises(ValueError, match="unknown template"):
        build_aggregation_user_content(
            original_prompt="P",
            tails=["t"],
            template="does_not_exist",
        )


def test_aggregation_templates_registry():
    """The zaya_v1 template is registered in AGGREGATION_TEMPLATES."""
    assert "zaya_v1" in AGGREGATION_TEMPLATES