feat(cli): add solve subcommand with profile + per-flag overrides + JSONL audit

This commit is contained in:
transcrilive
2026-05-10 03:15:21 +02:00
parent d4c241f91a
commit b58b894567
2 changed files with 115 additions and 0 deletions

View File

@@ -1,4 +1,9 @@
"""Typer CLI for markovian-rsa-mlx."""
from __future__ import annotations
import sys
from pathlib import Path
from typing import Optional
import typer
app = typer.Typer(help="Markovian RSA orchestrator for ZAYA1-8B on MLX.")
@@ -16,5 +21,96 @@ def version() -> None:
typer.echo(__version__)
@app.command()
def solve(
prompt: Optional[str] = typer.Argument(None, help="Problem text. Mutually exclusive with --prompt-file."),
prompt_file: Optional[Path] = typer.Option(None, "--prompt-file", "-f"),
model: str = typer.Option("kyr0/zaya1-base-8b-MLX", "--model", help="MLX model id or local path."),
quantization: str = typer.Option("q4_g64", "--quantization", help="q4_g64 or bf16 (metadata)."),
profile: str = typer.Option("default-16gb", "--profile",
help="default-16gb | paper-16k | paper-headline-40k"),
rounds: Optional[int] = typer.Option(None, "--rounds", "-T"),
parallel: Optional[int] = typer.Option(None, "--parallel", "-N"),
aggregation_subsample: Optional[int] = typer.Option(None, "--aggregation-subsample", "-K"),
tail_tokens: Optional[int] = typer.Option(None, "--tail-tokens"),
chunk_tokens: Optional[int] = typer.Option(None, "--chunk-tokens"),
final_tokens: Optional[int] = typer.Option(None, "--final-tokens"),
temperature: Optional[float] = typer.Option(None, "--temperature"),
top_p: Optional[float] = typer.Option(None, "--top-p"),
top_k: Optional[int] = typer.Option(None, "--top-k"),
seed: Optional[int] = typer.Option(None, "--seed"),
serial: bool = typer.Option(False, "--serial", help="Force sequential decodes."),
no_auto_serial: bool = typer.Option(False, "--no-auto-serial",
help="Disable automatic OOM fallback."),
memory_fraction: Optional[float] = typer.Option(None, "--memory-fraction"),
output: Optional[Path] = typer.Option(None, "--output", "-o", help="Final text path. Default stdout."),
audit: Optional[Path] = typer.Option(None, "--audit", help="JSONL audit path."),
json_summary: bool = typer.Option(False, "--json", help="Emit JSON RSAResult summary to stdout."),
verbose: bool = typer.Option(False, "--verbose", "-v"),
) -> None:
"""Solve a single problem with Markovian RSA."""
from markovian_rsa_mlx.config import RSAConfig
from markovian_rsa_mlx.orchestrator import MarkovianRSAOrchestrator
import dataclasses
import json as _json
if prompt is None and prompt_file is None:
typer.echo("error: provide PROMPT positional or --prompt-file", err=True)
raise typer.Exit(2)
if prompt and prompt_file:
typer.echo("error: PROMPT and --prompt-file are mutually exclusive", err=True)
raise typer.Exit(2)
if prompt_file:
prompt = prompt_file.read_text().strip()
profile_map = {
"default-16gb": RSAConfig.default_16gb,
"paper-16k": RSAConfig.paper_16k,
"paper-headline-40k": RSAConfig.paper_headline_40k,
}
if profile not in profile_map:
typer.echo(f"error: unknown profile '{profile}'", err=True)
raise typer.Exit(2)
cfg = profile_map[profile]()
overrides: dict = {}
if rounds is not None: overrides["rounds"] = rounds
if parallel is not None: overrides["parallel"] = parallel
if aggregation_subsample is not None: overrides["aggregation_subsample"] = aggregation_subsample
if tail_tokens is not None: overrides["tail_tokens"] = tail_tokens
if chunk_tokens is not None: overrides["chunk_tokens"] = chunk_tokens
if final_tokens is not None: overrides["final_tokens"] = final_tokens
if temperature is not None: overrides["temperature"] = temperature
if top_p is not None: overrides["top_p"] = top_p
if top_k is not None: overrides["top_k"] = top_k
if seed is not None: overrides["seed"] = seed
if memory_fraction is not None: overrides["memory_fraction"] = memory_fraction
if serial: overrides["serial"] = True
if no_auto_serial: overrides["auto_serial"] = False
if overrides:
cfg = cfg.replace(**overrides)
if verbose:
typer.echo(f"[mrm] config={cfg}", err=True)
typer.echo(f"[mrm] loading model {model} ...", err=True)
orch = MarkovianRSAOrchestrator.from_pretrained(model, quantization=quantization)
text, result = orch.solve(prompt, config=cfg, return_audit=True, audit_path=audit)
if output is not None:
output.write_text(text)
else:
typer.echo(text)
if json_summary:
summary = {
"run_id": result.run_id, "model_id": result.model_id,
"quantization": result.quantization, "config": dataclasses.asdict(result.config),
"stats": dataclasses.asdict(result.stats), "rounds": len(result.rounds),
"audit_path": str(result.audit_path) if result.audit_path else None,
}
sys.stderr.write(_json.dumps(summary, indent=2) + "\n")
if __name__ == "__main__":
app()

19
tests/test_cli.py Normal file
View File

@@ -0,0 +1,19 @@
from typer.testing import CliRunner
from markovian_rsa_mlx.cli import app
runner = CliRunner()
def test_version_command_prints_version():
result = runner.invoke(app, ["version"])
assert result.exit_code == 0
assert "0.1.0" in result.stdout
def test_solve_help_shows_required_flags():
result = runner.invoke(app, ["solve", "--help"])
assert result.exit_code == 0
assert "--model" in result.stdout
assert "--rounds" in result.stdout
assert "--parallel" in result.stdout
assert "--audit" in result.stdout