diff --git a/scripts/bench_hmmt.py b/scripts/bench_hmmt.py index 9be1b68..dd7ce92 100644 --- a/scripts/bench_hmmt.py +++ b/scripts/bench_hmmt.py @@ -21,7 +21,7 @@ _HMMT_2025_SUBSET = [ { "id": "hmmt-1", "question": "Find the number of positive integers n <= 100 such that n^2 + n is divisible by 6.", - "answer": "100", + "answer": "66", }, { "id": "hmmt-2", @@ -41,7 +41,7 @@ _HMMT_2025_SUBSET = [ { "id": "hmmt-5", "question": "What is the remainder when 2^100 is divided by 125?", - "answer": "76", + "answer": "1", }, ] diff --git a/src/markovian_rsa_mlx/config.py b/src/markovian_rsa_mlx/config.py index 3d10bf2..212d071 100644 --- a/src/markovian_rsa_mlx/config.py +++ b/src/markovian_rsa_mlx/config.py @@ -27,6 +27,7 @@ class RSAConfig: memory_fraction: float = 0.80 max_context_tokens: int | None = None aggregation_template: Literal["zaya_v1"] = "zaya_v1" + enable_thinking: bool = False answer_selection: Literal["first_final_candidate"] = "first_final_candidate" def __post_init__(self) -> None: diff --git a/src/markovian_rsa_mlx/orchestrator.py b/src/markovian_rsa_mlx/orchestrator.py index c7e39e1..61f6590 100644 --- a/src/markovian_rsa_mlx/orchestrator.py +++ b/src/markovian_rsa_mlx/orchestrator.py @@ -150,7 +150,7 @@ class MarkovianRSAOrchestrator: parent_ids_per_trace: list[list[str]] = [] if is_round_0: messages = build_round_0_messages(original_prompt) - prompt_ids = self._render_chat(messages) + prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking) prompts_token_ids = [prompt_ids for _ in range(cfg.parallel)] parent_ids_per_trace = [[] for _ in range(cfg.parallel)] else: @@ -171,7 +171,7 @@ class MarkovianRSAOrchestrator: original_prompt=original_prompt, tails=tails, template=cfg.aggregation_template, ) - prompt_ids = self._render_chat(messages) + prompt_ids = self._render_chat(messages, enable_thinking=cfg.enable_thinking) child_trace_id = f"r{round_idx}-t{trace_idx}-{run_id[:6]}" audit.write(AggregationPromptEvent( run_id=run_id, round=round_idx, trace_id=child_trace_id, @@ -217,10 +217,10 @@ class MarkovianRSAOrchestrator: round_elapsed = time.time() - round_t0 return records, round_elapsed - def _render_chat(self, messages: list[dict[str, str]]) -> list[int]: + def _render_chat(self, messages: list[dict[str, str]], *, enable_thinking: bool) -> list[int]: """Apply ZAYA chat template and return token ids.""" rendered = self.tokenizer.apply_chat_template( - messages, add_generation_prompt=True, enable_thinking=True, + messages, add_generation_prompt=True, enable_thinking=enable_thinking, ) if isinstance(rendered, str): return self.tokenizer.encode(rendered) diff --git a/tests/test_config.py b/tests/test_config.py index 6db91e3..f46bc82 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,3 +74,13 @@ def test_final_tokens_zero_rejected(): def test_final_tokens_negative_rejected(): with pytest.raises(ValueError, match="final_tokens must be positive"): RSAConfig(final_tokens=-1) + + +def test_enable_thinking_default_false(): + cfg = RSAConfig() + assert cfg.enable_thinking is False + + +def test_enable_thinking_can_be_enabled(): + cfg = RSAConfig(enable_thinking=True) + assert cfg.enable_thinking is True