From b58b8945676f800ed08334b6507c0befdcd6c5db Mon Sep 17 00:00:00 2001 From: transcrilive Date: Sun, 10 May 2026 03:15:21 +0200 Subject: [PATCH] feat(cli): add solve subcommand with profile + per-flag overrides + JSONL audit --- src/markovian_rsa_mlx/cli.py | 96 ++++++++++++++++++++++++++++++++++++ tests/test_cli.py | 19 +++++++ 2 files changed, 115 insertions(+) create mode 100644 tests/test_cli.py diff --git a/src/markovian_rsa_mlx/cli.py b/src/markovian_rsa_mlx/cli.py index 748b75c..f61f018 100644 --- a/src/markovian_rsa_mlx/cli.py +++ b/src/markovian_rsa_mlx/cli.py @@ -1,4 +1,9 @@ """Typer CLI for markovian-rsa-mlx.""" +from __future__ import annotations +import sys +from pathlib import Path +from typing import Optional + import typer app = typer.Typer(help="Markovian RSA orchestrator for ZAYA1-8B on MLX.") @@ -16,5 +21,96 @@ def version() -> None: typer.echo(__version__) +@app.command() +def solve( + prompt: Optional[str] = typer.Argument(None, help="Problem text. Mutually exclusive with --prompt-file."), + prompt_file: Optional[Path] = typer.Option(None, "--prompt-file", "-f"), + model: str = typer.Option("kyr0/zaya1-base-8b-MLX", "--model", help="MLX model id or local path."), + quantization: str = typer.Option("q4_g64", "--quantization", help="q4_g64 or bf16 (metadata)."), + profile: str = typer.Option("default-16gb", "--profile", + help="default-16gb | paper-16k | paper-headline-40k"), + rounds: Optional[int] = typer.Option(None, "--rounds", "-T"), + parallel: Optional[int] = typer.Option(None, "--parallel", "-N"), + aggregation_subsample: Optional[int] = typer.Option(None, "--aggregation-subsample", "-K"), + tail_tokens: Optional[int] = typer.Option(None, "--tail-tokens"), + chunk_tokens: Optional[int] = typer.Option(None, "--chunk-tokens"), + final_tokens: Optional[int] = typer.Option(None, "--final-tokens"), + temperature: Optional[float] = typer.Option(None, "--temperature"), + top_p: Optional[float] = typer.Option(None, "--top-p"), + top_k: Optional[int] = typer.Option(None, "--top-k"), + seed: Optional[int] = typer.Option(None, "--seed"), + serial: bool = typer.Option(False, "--serial", help="Force sequential decodes."), + no_auto_serial: bool = typer.Option(False, "--no-auto-serial", + help="Disable automatic OOM fallback."), + memory_fraction: Optional[float] = typer.Option(None, "--memory-fraction"), + output: Optional[Path] = typer.Option(None, "--output", "-o", help="Final text path. Default stdout."), + audit: Optional[Path] = typer.Option(None, "--audit", help="JSONL audit path."), + json_summary: bool = typer.Option(False, "--json", help="Emit JSON RSAResult summary to stdout."), + verbose: bool = typer.Option(False, "--verbose", "-v"), +) -> None: + """Solve a single problem with Markovian RSA.""" + from markovian_rsa_mlx.config import RSAConfig + from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator + import dataclasses + import json as _json + + if prompt is None and prompt_file is None: + typer.echo("error: provide PROMPT positional or --prompt-file", err=True) + raise typer.Exit(2) + if prompt and prompt_file: + typer.echo("error: PROMPT and --prompt-file are mutually exclusive", err=True) + raise typer.Exit(2) + if prompt_file: + prompt = prompt_file.read_text().strip() + + profile_map = { + "default-16gb": RSAConfig.default_16gb, + "paper-16k": RSAConfig.paper_16k, + "paper-headline-40k": RSAConfig.paper_headline_40k, + } + if profile not in profile_map: + typer.echo(f"error: unknown profile '{profile}'", err=True) + raise typer.Exit(2) + cfg = profile_map[profile]() + + overrides: dict = {} + if rounds is not None: overrides["rounds"] = rounds + if parallel is not None: overrides["parallel"] = parallel + if aggregation_subsample is not None: overrides["aggregation_subsample"] = aggregation_subsample + if tail_tokens is not None: overrides["tail_tokens"] = tail_tokens + if chunk_tokens is not None: overrides["chunk_tokens"] = chunk_tokens + if final_tokens is not None: overrides["final_tokens"] = final_tokens + if temperature is not None: overrides["temperature"] = temperature + if top_p is not None: overrides["top_p"] = top_p + if top_k is not None: overrides["top_k"] = top_k + if seed is not None: overrides["seed"] = seed + if memory_fraction is not None: overrides["memory_fraction"] = memory_fraction + if serial: overrides["serial"] = True + if no_auto_serial: overrides["auto_serial"] = False + if overrides: + cfg = cfg.replace(**overrides) + + if verbose: + typer.echo(f"[mrm] config={cfg}", err=True) + typer.echo(f"[mrm] loading model {model} ...", err=True) + + orch = MarkovianRSAOrchestrator.from_pretrained(model, quantization=quantization) + text, result = orch.solve(prompt, config=cfg, return_audit=True, audit_path=audit) + + if output is not None: + output.write_text(text) + else: + typer.echo(text) + + if json_summary: + summary = { + "run_id": result.run_id, "model_id": result.model_id, + "quantization": result.quantization, "config": dataclasses.asdict(result.config), + "stats": dataclasses.asdict(result.stats), "rounds": len(result.rounds), + "audit_path": str(result.audit_path) if result.audit_path else None, + } + sys.stderr.write(_json.dumps(summary, indent=2) + "\n") + + if __name__ == "__main__": app() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..c356da4 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,19 @@ +from typer.testing import CliRunner +from markovian_rsa_mlx.cli import app + +runner = CliRunner() + + +def test_version_command_prints_version(): + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0 + assert "0.1.0" in result.stdout + + +def test_solve_help_shows_required_flags(): + result = runner.invoke(app, ["solve", "--help"]) + assert result.exit_code == 0 + assert "--model" in result.stdout + assert "--rounds" in result.stdout + assert "--parallel" in result.stdout + assert "--audit" in result.stdout