feat(prompts): add round-0 + zaya_v1 aggregation templates with versioned registry

This commit is contained in:
transcrilive
2026-05-10 02:50:44 +02:00
parent 3d595e021f
commit db710cc157
2 changed files with 127 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
"""Prompt construction for Markovian RSA on ZAYA1.
Round-0 templates feed the original problem in plain user-message form.
Aggregation templates inject K sampled tails into a user message that asks
the model to reconcile and continue. ZAYA1's chat template (chat_template.jinja)
takes care of <|im_start|>/<|im_end|>/<think> wrapping ; we only build the
user-message body.
The exact co-trained format is not published in the ZAYA1 paper. zaya_v1 is
a reverse-engineered template. New templates can be A/B-tested via the
`aggregation_template` config field.
"""
from __future__ import annotations
from typing import Callable
ChatMessage = dict[str, str]
def build_round_0_messages(original_prompt: str) -> list[ChatMessage]:
"""Round 0 cold-start : plain user message with the problem."""
return [{"role": "user", "content": original_prompt}]
_ZAYA_V1_HEADER = "Original problem:\n{ORIGINAL_PROMPT}\n\n"
_ZAYA_V1_INTRO = (
"You have already produced {K} candidate reasoning approaches. "
"Each excerpt below is the most recent token tail from one attempt.\n\n"
)
_ZAYA_V1_TAIL_BLOCK = "[Approach {i}]\n{TAIL}\n\n"
_ZAYA_V1_FOOTER = (
"Your task: synthesize these approaches, identify the most promising "
"direction, correct any visible mistakes, and continue the reasoning "
"to reach a final answer."
)
def _build_zaya_v1(original_prompt: str, tails: list[str]) -> str:
parts: list[str] = []
parts.append(_ZAYA_V1_HEADER.format(ORIGINAL_PROMPT=original_prompt))
parts.append(_ZAYA_V1_INTRO.format(K=len(tails)))
for i, tail in enumerate(tails, start=1):
parts.append(_ZAYA_V1_TAIL_BLOCK.format(i=i, TAIL=tail))
parts.append(_ZAYA_V1_FOOTER)
return "".join(parts)
AGGREGATION_TEMPLATES: dict[str, Callable[[str, list[str]], str]] = {
"zaya_v1": _build_zaya_v1,
}
def build_aggregation_user_content(
original_prompt: str,
tails: list[str],
template: str = "zaya_v1",
) -> str:
"""Build the user-message body for an aggregation prompt.
The caller wraps this through `tokenizer.apply_chat_template(..., add_generation_prompt=True, enable_thinking=True)`
so ZAYA's chat template adds <|im_start|>/<|im_end|>/<think> wrappers.
"""
if template not in AGGREGATION_TEMPLATES:
raise ValueError(
f"unknown template '{template}' ; available: {sorted(AGGREGATION_TEMPLATES)}"
)
return AGGREGATION_TEMPLATES[template](original_prompt, tails)
def build_aggregation_messages(
original_prompt: str,
tails: list[str],
template: str = "zaya_v1",
) -> list[ChatMessage]:
"""Convenience : aggregation user content as chat messages."""
return [
{"role": "user", "content": build_aggregation_user_content(original_prompt, tails, template)}
]

50
tests/test_prompts.py Normal file
View File

@@ -0,0 +1,50 @@
import pytest
from markovian_rsa_mlx.prompts import (
build_round_0_messages,
build_aggregation_user_content,
AGGREGATION_TEMPLATES,
)
def test_round_0_returns_chat_messages_with_user_role():
msgs = build_round_0_messages("What is 2+2?")
assert msgs == [{"role": "user", "content": "What is 2+2?"}]
def test_zaya_v1_aggregation_includes_problem_and_K_tails():
body = build_aggregation_user_content(
original_prompt="What is 2+2?",
tails=["...solution attempt 1...", "...solution attempt 2..."],
template="zaya_v1",
)
assert "Original problem" in body
assert "What is 2+2?" in body
assert "[Approach 1]" in body
assert "...solution attempt 1..." in body
assert "[Approach 2]" in body
assert "...solution attempt 2..." in body
assert "synthesize" in body.lower()
assert "{K}" not in body # placeholder must be filled
def test_zaya_v1_uses_K_equal_len_tails():
body = build_aggregation_user_content(
original_prompt="P",
tails=["t1", "t2", "t3"],
template="zaya_v1",
)
assert "3 candidate" in body # K=3
assert "[Approach 3]" in body
def test_unknown_template_raises():
with pytest.raises(ValueError, match="unknown template"):
build_aggregation_user_content(
original_prompt="P",
tails=["t"],
template="does_not_exist",
)
def test_aggregation_templates_registry():
assert "zaya_v1" in AGGREGATION_TEMPLATES