feat(cli): add solve subcommand with profile + per-flag overrides + JSONL audit
This commit is contained in:
@@ -1,4 +1,9 @@
|
||||
"""Typer CLI for markovian-rsa-mlx."""
|
||||
from __future__ import annotations
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
|
||||
app = typer.Typer(help="Markovian RSA orchestrator for ZAYA1-8B on MLX.")
|
||||
@@ -16,5 +21,96 @@ def version() -> None:
|
||||
typer.echo(__version__)
|
||||
|
||||
|
||||
@app.command()
|
||||
def solve(
|
||||
prompt: Optional[str] = typer.Argument(None, help="Problem text. Mutually exclusive with --prompt-file."),
|
||||
prompt_file: Optional[Path] = typer.Option(None, "--prompt-file", "-f"),
|
||||
model: str = typer.Option("kyr0/zaya1-base-8b-MLX", "--model", help="MLX model id or local path."),
|
||||
quantization: str = typer.Option("q4_g64", "--quantization", help="q4_g64 or bf16 (metadata)."),
|
||||
profile: str = typer.Option("default-16gb", "--profile",
|
||||
help="default-16gb | paper-16k | paper-headline-40k"),
|
||||
rounds: Optional[int] = typer.Option(None, "--rounds", "-T"),
|
||||
parallel: Optional[int] = typer.Option(None, "--parallel", "-N"),
|
||||
aggregation_subsample: Optional[int] = typer.Option(None, "--aggregation-subsample", "-K"),
|
||||
tail_tokens: Optional[int] = typer.Option(None, "--tail-tokens"),
|
||||
chunk_tokens: Optional[int] = typer.Option(None, "--chunk-tokens"),
|
||||
final_tokens: Optional[int] = typer.Option(None, "--final-tokens"),
|
||||
temperature: Optional[float] = typer.Option(None, "--temperature"),
|
||||
top_p: Optional[float] = typer.Option(None, "--top-p"),
|
||||
top_k: Optional[int] = typer.Option(None, "--top-k"),
|
||||
seed: Optional[int] = typer.Option(None, "--seed"),
|
||||
serial: bool = typer.Option(False, "--serial", help="Force sequential decodes."),
|
||||
no_auto_serial: bool = typer.Option(False, "--no-auto-serial",
|
||||
help="Disable automatic OOM fallback."),
|
||||
memory_fraction: Optional[float] = typer.Option(None, "--memory-fraction"),
|
||||
output: Optional[Path] = typer.Option(None, "--output", "-o", help="Final text path. Default stdout."),
|
||||
audit: Optional[Path] = typer.Option(None, "--audit", help="JSONL audit path."),
|
||||
json_summary: bool = typer.Option(False, "--json", help="Emit JSON RSAResult summary to stdout."),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v"),
|
||||
) -> None:
|
||||
"""Solve a single problem with Markovian RSA."""
|
||||
from markovian_rsa_mlx.config import RSAConfig
|
||||
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
|
||||
import dataclasses
|
||||
import json as _json
|
||||
|
||||
if prompt is None and prompt_file is None:
|
||||
typer.echo("error: provide PROMPT positional or --prompt-file", err=True)
|
||||
raise typer.Exit(2)
|
||||
if prompt and prompt_file:
|
||||
typer.echo("error: PROMPT and --prompt-file are mutually exclusive", err=True)
|
||||
raise typer.Exit(2)
|
||||
if prompt_file:
|
||||
prompt = prompt_file.read_text().strip()
|
||||
|
||||
profile_map = {
|
||||
"default-16gb": RSAConfig.default_16gb,
|
||||
"paper-16k": RSAConfig.paper_16k,
|
||||
"paper-headline-40k": RSAConfig.paper_headline_40k,
|
||||
}
|
||||
if profile not in profile_map:
|
||||
typer.echo(f"error: unknown profile '{profile}'", err=True)
|
||||
raise typer.Exit(2)
|
||||
cfg = profile_map[profile]()
|
||||
|
||||
overrides: dict = {}
|
||||
if rounds is not None: overrides["rounds"] = rounds
|
||||
if parallel is not None: overrides["parallel"] = parallel
|
||||
if aggregation_subsample is not None: overrides["aggregation_subsample"] = aggregation_subsample
|
||||
if tail_tokens is not None: overrides["tail_tokens"] = tail_tokens
|
||||
if chunk_tokens is not None: overrides["chunk_tokens"] = chunk_tokens
|
||||
if final_tokens is not None: overrides["final_tokens"] = final_tokens
|
||||
if temperature is not None: overrides["temperature"] = temperature
|
||||
if top_p is not None: overrides["top_p"] = top_p
|
||||
if top_k is not None: overrides["top_k"] = top_k
|
||||
if seed is not None: overrides["seed"] = seed
|
||||
if memory_fraction is not None: overrides["memory_fraction"] = memory_fraction
|
||||
if serial: overrides["serial"] = True
|
||||
if no_auto_serial: overrides["auto_serial"] = False
|
||||
if overrides:
|
||||
cfg = cfg.replace(**overrides)
|
||||
|
||||
if verbose:
|
||||
typer.echo(f"[mrm] config={cfg}", err=True)
|
||||
typer.echo(f"[mrm] loading model {model} ...", err=True)
|
||||
|
||||
orch = MarkovianRSAOrchestrator.from_pretrained(model, quantization=quantization)
|
||||
text, result = orch.solve(prompt, config=cfg, return_audit=True, audit_path=audit)
|
||||
|
||||
if output is not None:
|
||||
output.write_text(text)
|
||||
else:
|
||||
typer.echo(text)
|
||||
|
||||
if json_summary:
|
||||
summary = {
|
||||
"run_id": result.run_id, "model_id": result.model_id,
|
||||
"quantization": result.quantization, "config": dataclasses.asdict(result.config),
|
||||
"stats": dataclasses.asdict(result.stats), "rounds": len(result.rounds),
|
||||
"audit_path": str(result.audit_path) if result.audit_path else None,
|
||||
}
|
||||
sys.stderr.write(_json.dumps(summary, indent=2) + "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
19
tests/test_cli.py
Normal file
19
tests/test_cli.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from typer.testing import CliRunner
|
||||
from markovian_rsa_mlx.cli import app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_version_command_prints_version():
|
||||
result = runner.invoke(app, ["version"])
|
||||
assert result.exit_code == 0
|
||||
assert "0.1.0" in result.stdout
|
||||
|
||||
|
||||
def test_solve_help_shows_required_flags():
|
||||
result = runner.invoke(app, ["solve", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "--model" in result.stdout
|
||||
assert "--rounds" in result.stdout
|
||||
assert "--parallel" in result.stdout
|
||||
assert "--audit" in result.stdout
|
||||
Reference in New Issue
Block a user