feat(config): add enable_thinking flag (default False) + fix HMMT bench gold answers
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,7 +21,7 @@ _HMMT_2025_SUBSET = [
|
|||||||
{
|
{
|
||||||
"id": "hmmt-1",
|
"id": "hmmt-1",
|
||||||
"question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.",
|
"question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.",
|
||||||
"answer": "100",
|
"answer": "66",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "hmmt-2",
|
"id": "hmmt-2",
|
||||||
@@ -41,7 +41,7 @@ _HMMT_2025_SUBSET = [
|
|||||||
{
|
{
|
||||||
"id": "hmmt-5",
|
"id": "hmmt-5",
|
||||||
"question": "What is the remainder when 2^100 is divided by 125?",
|
"question": "What is the remainder when 2^100 is divided by 125?",
|
||||||
"answer": "76",
|
"answer": "1",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class RSAConfig:
|
|||||||
memory_fraction: float = 0.80
|
memory_fraction: float = 0.80
|
||||||
max_context_tokens: int | None = None
|
max_context_tokens: int | None = None
|
||||||
aggregation_template: Literal["zaya_v1"] = "zaya_v1"
|
aggregation_template: Literal["zaya_v1"] = "zaya_v1"
|
||||||
|
enable_thinking: bool = False
|
||||||
answer_selection: Literal["first_final_candidate"] = "first_final_candidate"
|
answer_selection: Literal["first_final_candidate"] = "first_final_candidate"
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ class MarkovianRSAOrchestrator:
|
|||||||
parent_ids_per_trace: list[list[str]] = []
|
parent_ids_per_trace: list[list[str]] = []
|
||||||
if is_round_0:
|
if is_round_0:
|
||||||
messages = build_round_0_messages(original_prompt)
|
messages = build_round_0_messages(original_prompt)
|
||||||
prompt_ids = self._render_chat(messages)
|
prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking)
|
||||||
prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)]
|
prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)]
|
||||||
parent_ids_per_trace = [[] for _ in range(cfg.parallel)]
|
parent_ids_per_trace = [[] for _ in range(cfg.parallel)]
|
||||||
else:
|
else:
|
||||||
@@ -171,7 +171,7 @@ class MarkovianRSAOrchestrator:
|
|||||||
original_prompt=original_prompt, tails=tails,
|
original_prompt=original_prompt, tails=tails,
|
||||||
template=cfg.aggregation_template,
|
template=cfg.aggregation_template,
|
||||||
)
|
)
|
||||||
prompt_ids = self._render_chat(messages)
|
prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking)
|
||||||
child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}"
|
child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}"
|
||||||
audit.write(AggregationPromptEvent(
|
audit.write(AggregationPromptEvent(
|
||||||
run_id=run_id, round=round_idx, trace_id=child_trace_id,
|
run_id=run_id, round=round_idx, trace_id=child_trace_id,
|
||||||
@@ -217,10 +217,10 @@ class MarkovianRSAOrchestrator:
|
|||||||
round_elapsed = time.time() - round_t0
|
round_elapsed = time.time() - round_t0
|
||||||
return records, round_elapsed
|
return records, round_elapsed
|
||||||
|
|
||||||
def _render_chat(self, messages: list[dict[str, str]]) -> list[int]:
|
def _render_chat(self, messages: list[dict[str, str]], *, enable_thinking: bool) -> list[int]:
|
||||||
"""Apply ZAYA chat template and return token ids."""
|
"""Apply ZAYA chat template and return token ids."""
|
||||||
rendered = self.tokenizer.apply_chat_template(
|
rendered = self.tokenizer.apply_chat_template(
|
||||||
messages, add_generation_prompt=True, enable_thinking=True,
|
messages, add_generation_prompt=True, enable_thinking=enable_thinking,
|
||||||
)
|
)
|
||||||
if isinstance(rendered, str):
|
if isinstance(rendered, str):
|
||||||
return self.tokenizer.encode(rendered)
|
return self.tokenizer.encode(rendered)
|
||||||
|
|||||||
@@ -74,3 +74,13 @@ def test_final_tokens_zero_rejected():
|
|||||||
def test_final_tokens_negative_rejected():
|
def test_final_tokens_negative_rejected():
|
||||||
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
with pytest.raises(ValueError, match="final_tokens must be positive"):
|
||||||
RSAConfig(final_tokens=-1)
|
RSAConfig(final_tokens=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enable_thinking_default_false():
|
||||||
|
cfg = RSAConfig()
|
||||||
|
assert cfg.enable_thinking is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_enable_thinking_can_be_enabled():
|
||||||
|
cfg = RSAConfig(enable_thinking=True)
|
||||||
|
assert cfg.enable_thinking is True
|
||||||
|
|||||||
Reference in New Issue
Block a user