feat(prompts): add round-0 + zaya_v1 aggregation templates with versioned registry

This commit is contained in:
transcrilive
2026-05-10 02:50:44 +02:00
parent 3d595e021f
commit db710cc157
2 changed files with 127 additions and 0 deletions

View File

@@ -0,0 +1,77 @@
"""Prompt construction for Markovian RSA on ZAYA1.
Round-0 templates feed the original problem in plain user-message form.
Aggregation templates inject K sampled tails into a user message that asks
the model to reconcile and continue. ZAYA1's chat template (chat_template.jinja)
takes care of <|im_start|>/<|im_end|>/<think> wrapping ; we only build the
user-message body.
The exact co-trained format is not published in the ZAYA1 paper. zaya_v1 is
a reverse-engineered template. New templates can be A/B-tested via the
`aggregation_template` config field.
"""
from __future__ import annotations
from typing import Callable
ChatMessage = dict[str, str]
def build_round_0_messages(original_prompt: str) -> list[ChatMessage]:
"""Round 0 cold-start : plain user message with the problem."""
return [{"role": "user", "content": original_prompt}]
_ZAYA_V1_HEADER = "Original problem:\n{ORIGINAL_PROMPT}\n\n"
_ZAYA_V1_INTRO = (
"You have already produced {K} candidate reasoning approaches. "
"Each excerpt below is the most recent token tail from one attempt.\n\n"
)
_ZAYA_V1_TAIL_BLOCK = "[Approach {i}]\n{TAIL}\n\n"
_ZAYA_V1_FOOTER = (
"Your task: synthesize these approaches, identify the most promising "
"direction, correct any visible mistakes, and continue the reasoning "
"to reach a final answer."
)
def _build_zaya_v1(original_prompt: str, tails: list[str]) -> str:
parts: list[str] = []
parts.append(_ZAYA_V1_HEADER.format(ORIGINAL_PROMPT=original_prompt))
parts.append(_ZAYA_V1_INTRO.format(K=len(tails)))
for i, tail in enumerate(tails, start=1):
parts.append(_ZAYA_V1_TAIL_BLOCK.format(i=i, TAIL=tail))
parts.append(_ZAYA_V1_FOOTER)
return "".join(parts)
AGGREGATION_TEMPLATES: dict[str, Callable[[str, list[str]], str]] = {
"zaya_v1": _build_zaya_v1,
}
def build_aggregation_user_content(
original_prompt: str,
tails: list[str],
template: str = "zaya_v1",
) -> str:
"""Build the user-message body for an aggregation prompt.
The caller wraps this through `tokenizer.apply_chat_template(..., add_generation_prompt=True, enable_thinking=True)`
so ZAYA's chat template adds <|im_start|>/<|im_end|>/<think> wrappers.
"""
if template not in AGGREGATION_TEMPLATES:
raise ValueError(
f"unknown template '{template}' ; available: {sorted(AGGREGATION_TEMPLATES)}"
)
return AGGREGATION_TEMPLATES[template](original_prompt, tails)
def build_aggregation_messages(
original_prompt: str,
tails: list[str],
template: str = "zaya_v1",
) -> list[ChatMessage]:
"""Convenience : aggregation user content as chat messages."""
return [
{"role": "user", "content": build_aggregation_user_content(original_prompt, tails, template)}
]