From db710cc157ab553ec6cfe19d60b6e9a1843aeb8b Mon Sep 17 00:00:00 2001
From: transcrilive
Date: Sun, 10 May 2026 02:50:44 +0200
Subject: [PATCH] feat(prompts): add round-0 + zaya_v1 aggregation templates with versioned registry

---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy (4-space indent inferred from syntax); the
added-line counts match the hunk headers (77 and 50). Confirm against the
original commit before applying. The author e-mail on the From: line and a
token after "<|im_end|>/" in two docstrings appear to have been stripped by
the same extraction; they are left as found rather than guessed.

 src/markovian_rsa_mlx/prompts.py | 77 ++++++++++++++++++++++++++++++++
 tests/test_prompts.py            | 50 +++++++++++++++++++++
 2 files changed, 127 insertions(+)
 create mode 100644 src/markovian_rsa_mlx/prompts.py
 create mode 100644 tests/test_prompts.py

diff --git a/src/markovian_rsa_mlx/prompts.py b/src/markovian_rsa_mlx/prompts.py
new file mode 100644
index 0000000..dc528f3
--- /dev/null
+++ b/src/markovian_rsa_mlx/prompts.py
@@ -0,0 +1,77 @@
+"""Prompt construction for Markovian RSA on ZAYA1.
+
+Round-0 templates feed the original problem in plain user-message form.
+Aggregation templates inject K sampled tails into a user message that asks
+the model to reconcile and continue. ZAYA1's chat template (chat_template.jinja)
+takes care of <|im_start|>/<|im_end|>/ wrapping ; we only build the
+user-message body.
+
+The exact co-trained format is not published in the ZAYA1 paper. zaya_v1 is
+a reverse-engineered template. New templates can be A/B-tested via the
+`aggregation_template` config field.
+"""
+from __future__ import annotations
+from typing import Callable
+
+ChatMessage = dict[str, str]
+
+
+def build_round_0_messages(original_prompt: str) -> list[ChatMessage]:
+    """Round 0 cold-start : plain user message with the problem."""
+    return [{"role": "user", "content": original_prompt}]
+
+
+_ZAYA_V1_HEADER = "Original problem:\n{ORIGINAL_PROMPT}\n\n"
+_ZAYA_V1_INTRO = (
+    "You have already produced {K} candidate reasoning approaches. "
+    "Each excerpt below is the most recent token tail from one attempt.\n\n"
+)
+_ZAYA_V1_TAIL_BLOCK = "[Approach {i}]\n{TAIL}\n\n"
+_ZAYA_V1_FOOTER = (
+    "Your task: synthesize these approaches, identify the most promising "
+    "direction, correct any visible mistakes, and continue the reasoning "
+    "to reach a final answer."
+)
+
+
+def _build_zaya_v1(original_prompt: str, tails: list[str]) -> str:
+    parts: list[str] = []
+    parts.append(_ZAYA_V1_HEADER.format(ORIGINAL_PROMPT=original_prompt))
+    parts.append(_ZAYA_V1_INTRO.format(K=len(tails)))
+    for i, tail in enumerate(tails, start=1):
+        parts.append(_ZAYA_V1_TAIL_BLOCK.format(i=i, TAIL=tail))
+    parts.append(_ZAYA_V1_FOOTER)
+    return "".join(parts)
+
+
+AGGREGATION_TEMPLATES: dict[str, Callable[[str, list[str]], str]] = {
+    "zaya_v1": _build_zaya_v1,
+}
+
+
+def build_aggregation_user_content(
+    original_prompt: str,
+    tails: list[str],
+    template: str = "zaya_v1",
+) -> str:
+    """Build the user-message body for an aggregation prompt.
+
+    The caller wraps this through `tokenizer.apply_chat_template(..., add_generation_prompt=True, enable_thinking=True)`
+    so ZAYA's chat template adds <|im_start|>/<|im_end|>/ wrappers.
+    """
+    if template not in AGGREGATION_TEMPLATES:
+        raise ValueError(
+            f"unknown template '{template}' ; available: {sorted(AGGREGATION_TEMPLATES)}"
+        )
+    return AGGREGATION_TEMPLATES[template](original_prompt, tails)
+
+
+def build_aggregation_messages(
+    original_prompt: str,
+    tails: list[str],
+    template: str = "zaya_v1",
+) -> list[ChatMessage]:
+    """Convenience : aggregation user content as chat messages."""
+    return [
+        {"role": "user", "content": build_aggregation_user_content(original_prompt, tails, template)}
+    ]
diff --git a/tests/test_prompts.py b/tests/test_prompts.py
new file mode 100644
index 0000000..4f993f7
--- /dev/null
+++ b/tests/test_prompts.py
@@ -0,0 +1,50 @@
+import pytest
+from markovian_rsa_mlx.prompts import (
+    build_round_0_messages,
+    build_aggregation_user_content,
+    AGGREGATION_TEMPLATES,
+)
+
+
+def test_round_0_returns_chat_messages_with_user_role():
+    msgs = build_round_0_messages("What is 2+2?")
+    assert msgs == [{"role": "user", "content": "What is 2+2?"}]
+
+
+def test_zaya_v1_aggregation_includes_problem_and_K_tails():
+    body = build_aggregation_user_content(
+        original_prompt="What is 2+2?",
+        tails=["...solution attempt 1...", "...solution attempt 2..."],
+        template="zaya_v1",
+    )
+    assert "Original problem" in body
+    assert "What is 2+2?" in body
+    assert "[Approach 1]" in body
+    assert "...solution attempt 1..." in body
+    assert "[Approach 2]" in body
+    assert "...solution attempt 2..." in body
+    assert "synthesize" in body.lower()
+    assert "{K}" not in body  # placeholder must be filled
+
+
+def test_zaya_v1_uses_K_equal_len_tails():
+    body = build_aggregation_user_content(
+        original_prompt="P",
+        tails=["t1", "t2", "t3"],
+        template="zaya_v1",
+    )
+    assert "3 candidate" in body  # K=3
+    assert "[Approach 3]" in body
+
+
+def test_unknown_template_raises():
+    with pytest.raises(ValueError, match="unknown template"):
+        build_aggregation_user_content(
+            original_prompt="P",
+            tails=["t"],
+            template="does_not_exist",
+        )
+
+
+def test_aggregation_templates_registry():
+    assert "zaya_v1" in AGGREGATION_TEMPLATES