feat(prompts): add round-0 + zaya_v1 aggregation templates with versioned registry
This commit is contained in:
77
src/markovian_rsa_mlx/prompts.py
Normal file
77
src/markovian_rsa_mlx/prompts.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Prompt construction for Markovian RSA on ZAYA1.
|
||||
|
||||
Round-0 templates feed the original problem in plain user-message form.
|
||||
Aggregation templates inject K sampled tails into a user message that asks
|
||||
the model to reconcile and continue. ZAYA1's chat template (chat_template.jinja)
|
||||
takes care of <|im_start|>/<|im_end|>/<think> wrapping; we only build the
|
||||
user-message body.
|
||||
|
||||
The exact co-trained format is not published in the ZAYA1 paper. zaya_v1 is
|
||||
a reverse-engineered template. New templates can be A/B-tested via the
|
||||
`aggregation_template` config field.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
from typing import Callable
|
||||
|
||||
ChatMessage = dict[str, str]
|
||||
|
||||
|
||||
def build_round_0_messages(original_prompt: str) -> list[ChatMessage]:
    """Build the round-0 (cold-start) conversation.

    Args:
        original_prompt: Problem statement; used verbatim as the message content.

    Returns:
        A one-element list containing a single ``user`` message.
    """
    user_message = {"role": "user", "content": original_prompt}
    return [user_message]
|
||||
|
||||
|
||||
# Building blocks of the "zaya_v1" aggregation prompt. Each placeholder is
# filled through str.format: {ORIGINAL_PROMPT} in the header, {K} in the
# intro, {i} and {TAIL} in the per-approach block. The footer has none.
_ZAYA_V1_HEADER = "Original problem:\n{ORIGINAL_PROMPT}\n\n"
_ZAYA_V1_INTRO = (
    "You have already produced {K} candidate reasoning approaches. Each excerpt "
    "below is the most recent token tail from one attempt.\n\n"
)
_ZAYA_V1_TAIL_BLOCK = "[Approach {i}]\n{TAIL}\n\n"
_ZAYA_V1_FOOTER = (
    "Your task: synthesize these approaches, identify the most promising direction, "
    "correct any visible mistakes, and continue the reasoning to reach a final answer."
)
|
||||
|
||||
|
||||
def _build_zaya_v1(original_prompt: str, tails: list[str]) -> str:
    """Render the ``zaya_v1`` aggregation prompt body.

    The result is, in order: the problem header, an intro stating how many
    tails follow, one ``[Approach i]`` block per tail (numbered from 1), and
    the task footer.

    Args:
        original_prompt: Problem statement substituted into the header.
        tails: Sampled tails, one per approach block, in the given order.

    Returns:
        The concatenated user-message body.
    """
    opening = _ZAYA_V1_HEADER.format(ORIGINAL_PROMPT=original_prompt)
    opening += _ZAYA_V1_INTRO.format(K=len(tails))
    approaches = "".join(
        _ZAYA_V1_TAIL_BLOCK.format(i=number, TAIL=excerpt)
        for number, excerpt in enumerate(tails, start=1)
    )
    return opening + approaches + _ZAYA_V1_FOOTER
|
||||
|
||||
|
||||
# Registry of aggregation prompt builders, keyed by template name. Each
# builder takes (original_prompt, tails) and returns the user-message body.
AGGREGATION_TEMPLATES: dict[str, Callable[[str, list[str]], str]] = {"zaya_v1": _build_zaya_v1}
|
||||
|
||||
|
||||
def build_aggregation_user_content(
    original_prompt: str,
    tails: list[str],
    template: str = "zaya_v1",
) -> str:
    """Return the user-message body of an aggregation prompt.

    Only the message body is produced here. The caller is expected to pass it
    through `tokenizer.apply_chat_template(..., add_generation_prompt=True,
    enable_thinking=True)` so that ZAYA's chat template adds the
    <|im_start|>/<|im_end|>/<think> wrappers.

    Args:
        original_prompt: Problem statement forwarded to the template builder.
        tails: Sampled tails forwarded to the template builder.
        template: Key into ``AGGREGATION_TEMPLATES`` selecting the builder.

    Returns:
        Whatever string the selected builder returns.

    Raises:
        ValueError: If ``template`` is not a registered template name.
    """
    builder = AGGREGATION_TEMPLATES.get(template)
    if builder is None:
        known = sorted(AGGREGATION_TEMPLATES)
        raise ValueError(f"unknown template '{template}' ; available: {known}")
    return builder(original_prompt, tails)
|
||||
|
||||
|
||||
def build_aggregation_messages(
    original_prompt: str,
    tails: list[str],
    template: str = "zaya_v1",
) -> list[ChatMessage]:
    """Wrap the aggregation prompt body in a one-message conversation.

    Args:
        original_prompt: Forwarded to ``build_aggregation_user_content``.
        tails: Forwarded to ``build_aggregation_user_content``.
        template: Forwarded to ``build_aggregation_user_content``.

    Returns:
        A one-element list containing a single ``user`` message whose content
        is the string returned by ``build_aggregation_user_content``.
    """
    body = build_aggregation_user_content(original_prompt, tails, template)
    return [{"role": "user", "content": body}]
|
||||
Reference in New Issue
Block a user