feat(cli): add solve subcommand with profile + per-flag overrides + JSONL audit
This commit is contained in:
@@ -1,4 +1,9 @@
|
|||||||
"""Typer CLI for markovian-rsa-mlx."""
|
"""Typer CLI for markovian-rsa-mlx."""
|
||||||
|
from __future__ import annotations
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
app = typer.Typer(help="Markovian RSA orchestrator for ZAYA1-8B on MLX.")
|
app = typer.Typer(help="Markovian RSA orchestrator for ZAYA1-8B on MLX.")
|
||||||
@@ -16,5 +21,96 @@ def version() -> None:
|
|||||||
typer.echo(__version__)
|
typer.echo(__version__)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
def solve(
    prompt: Optional[str] = typer.Argument(None, help="Problem text. Mutually exclusive with --prompt-file."),
    prompt_file: Optional[Path] = typer.Option(None, "--prompt-file", "-f"),
    model: str = typer.Option("kyr0/zaya1-base-8b-MLX", "--model", help="MLX model id or local path."),
    quantization: str = typer.Option("q4_g64", "--quantization", help="q4_g64 or bf16 (metadata)."),
    profile: str = typer.Option(
        "default-16gb",
        "--profile",
        help="default-16gb | paper-16k | paper-headline-40k",
    ),
    rounds: Optional[int] = typer.Option(None, "--rounds", "-T"),
    parallel: Optional[int] = typer.Option(None, "--parallel", "-N"),
    aggregation_subsample: Optional[int] = typer.Option(None, "--aggregation-subsample", "-K"),
    tail_tokens: Optional[int] = typer.Option(None, "--tail-tokens"),
    chunk_tokens: Optional[int] = typer.Option(None, "--chunk-tokens"),
    final_tokens: Optional[int] = typer.Option(None, "--final-tokens"),
    temperature: Optional[float] = typer.Option(None, "--temperature"),
    top_p: Optional[float] = typer.Option(None, "--top-p"),
    top_k: Optional[int] = typer.Option(None, "--top-k"),
    seed: Optional[int] = typer.Option(None, "--seed"),
    serial: bool = typer.Option(False, "--serial", help="Force sequential decodes."),
    no_auto_serial: bool = typer.Option(
        False,
        "--no-auto-serial",
        help="Disable automatic OOM fallback.",
    ),
    memory_fraction: Optional[float] = typer.Option(None, "--memory-fraction"),
    output: Optional[Path] = typer.Option(None, "--output", "-o", help="Final text path. Default stdout."),
    audit: Optional[Path] = typer.Option(None, "--audit", help="JSONL audit path."),
    json_summary: bool = typer.Option(False, "--json", help="Emit JSON RSAResult summary to stdout."),
    verbose: bool = typer.Option(False, "--verbose", "-v"),
) -> None:
    """Solve a single problem with Markovian RSA."""
    # NOTE: Typer renders the docstring above as this command's help text, so
    # implementation notes are kept in comments rather than in the docstring.
    #
    # Flow:
    #   1. Resolve the prompt from the positional argument or --prompt-file.
    #   2. Build an RSAConfig from the chosen profile, then apply every
    #      per-flag override that was actually supplied.
    #   3. Load the model, run the orchestrator, and emit the final text to
    #      --output (or stdout) plus, with --json, a JSON summary on stdout.
    #
    # Usage errors (no prompt, both prompt sources, unreadable prompt file,
    # unknown profile) are reported on stderr and exit with status 2.

    # Imported inside the command, as in the original; presumably so that
    # lightweight commands do not pay for loading the model stack.
    from markovian_rsa_mlx.config import RSAConfig
    from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
    import dataclasses
    import json as _json

    if prompt is None and prompt_file is None:
        typer.echo("error: provide PROMPT positional or --prompt-file", err=True)
        raise typer.Exit(2)
    # Compare against None rather than truthiness so that an explicitly empty
    # positional prompt combined with --prompt-file is still rejected.
    if prompt is not None and prompt_file is not None:
        typer.echo("error: PROMPT and --prompt-file are mutually exclusive", err=True)
        raise typer.Exit(2)
    if prompt_file is not None:
        try:
            # Explicit encoding keeps prompt decoding independent of the locale.
            prompt = prompt_file.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError) as exc:
            typer.echo(f"error: cannot read prompt file '{prompt_file}': {exc}", err=True)
            raise typer.Exit(2) from exc

    profile_map = {
        "default-16gb": RSAConfig.default_16gb,
        "paper-16k": RSAConfig.paper_16k,
        "paper-headline-40k": RSAConfig.paper_headline_40k,
    }
    if profile not in profile_map:
        typer.echo(f"error: unknown profile '{profile}'", err=True)
        raise typer.Exit(2)
    cfg = profile_map[profile]()

    # Value-carrying flags override the profile only when given (None = unset).
    optional_values = {
        "rounds": rounds,
        "parallel": parallel,
        "aggregation_subsample": aggregation_subsample,
        "tail_tokens": tail_tokens,
        "chunk_tokens": chunk_tokens,
        "final_tokens": final_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "seed": seed,
        "memory_fraction": memory_fraction,
    }
    overrides: dict = {
        key: value for key, value in optional_values.items() if value is not None
    }
    # Boolean switches only ever push in one direction; when absent, the
    # profile's own setting is left untouched.
    if serial:
        overrides["serial"] = True
    if no_auto_serial:
        overrides["auto_serial"] = False
    if overrides:
        cfg = cfg.replace(**overrides)

    if verbose:
        typer.echo(f"[mrm] config={cfg}", err=True)
        typer.echo(f"[mrm] loading model {model} ...", err=True)

    orch = MarkovianRSAOrchestrator.from_pretrained(model, quantization=quantization)
    text, result = orch.solve(prompt, config=cfg, return_audit=True, audit_path=audit)

    if output is not None:
        output.write_text(text, encoding="utf-8")
    else:
        typer.echo(text)

    if json_summary:
        summary = {
            "run_id": result.run_id,
            "model_id": result.model_id,
            "quantization": result.quantization,
            "config": dataclasses.asdict(result.config),
            "stats": dataclasses.asdict(result.stats),
            "rounds": len(result.rounds),
            "audit_path": str(result.audit_path) if result.audit_path else None,
        }
        # The --json help promises stdout; stderr is reserved for the
        # diagnostics above. Combine with --output to keep stdout pure JSON.
        typer.echo(_json.dumps(summary, indent=2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app()
|
app()
|
||||||
|
|||||||
19
tests/test_cli.py
Normal file
19
tests/test_cli.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from typer.testing import CliRunner
|
||||||
|
from markovian_rsa_mlx.cli import app
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_version_command_prints_version():
    """The `version` command exits cleanly and prints the package version."""
    cli_args = ["version"]
    invocation = runner.invoke(app, cli_args)

    assert invocation.exit_code == 0
    assert "0.1.0" in invocation.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_solve_help_shows_required_flags():
    """`solve --help` succeeds and lists the core option names."""
    help_run = runner.invoke(app, ["solve", "--help"])
    assert help_run.exit_code == 0

    # Checked in the same order as before so the first missing flag is the
    # one reported.
    help_text = help_run.stdout
    for flag in ("--model", "--rounds", "--parallel", "--audit"):
        assert flag in help_text
|
||||||
Reference in New Issue
Block a user