diff --git a/src/markovian_rsa_mlx/__init__.py b/src/markovian_rsa_mlx/__init__.py
index d642c98..e5ae407 100644
--- a/src/markovian_rsa_mlx/__init__.py
+++ b/src/markovian_rsa_mlx/__init__.py
@@ -1,2 +1,7 @@
 """Markovian RSA test-time compute methodology on MLX."""
 __version__ = "0.1.0"
+
+from markovian_rsa_mlx.config import RSAConfig
+from markovian_rsa_mlx.loader import load_zaya_model
+
+__all__ = ["__version__", "RSAConfig", "load_zaya_model"]
diff --git a/src/markovian_rsa_mlx/loader.py b/src/markovian_rsa_mlx/loader.py
new file mode 100644
index 0000000..2b3eebd
--- /dev/null
+++ b/src/markovian_rsa_mlx/loader.py
@@ -0,0 +1,52 @@
+"""Model + tokenizer loading via mlx-lm. Local paths and HF repo ids supported."""
+from __future__ import annotations
+from pathlib import Path
+from typing import Any
+
+
+def resolve_model_path(model_id: str) -> str:
+    """Resolve ``model_id`` to the string handed to mlx-lm.
+
+    An existing local path is returned as a string, with a leading ``~``
+    expanded first. Anything else is returned unchanged; mlx-lm handles HF
+    download for repo ids.
+
+    Raises:
+        ValueError: if ``model_id`` is empty or whitespace-only.
+    """
+    # Path("") is Path("."), i.e. the current directory, so without this guard
+    # an empty id would silently resolve to the current working directory.
+    if not str(model_id).strip():
+        raise ValueError("model_id must be a non-empty local path or HF repo id")
+    try:
+        local = Path(model_id).expanduser()
+    except RuntimeError:
+        # No home directory could be determined; fall back to the raw path.
+        local = Path(model_id)
+    if local.exists():
+        return str(local)
+    return model_id
+
+
+def load_zaya_model(model_id: str, *, trust_remote_code: bool = True) -> tuple[Any, Any]:
+    """Load a ZAYA-supporting MLX checkpoint via mlx-lm.
+
+    Requires the kyr0/mlx-lm fork (feat/zaya-support) which adds the `zaya`
+    architecture support.
+
+    Args:
+        model_id: Local checkpoint path or HF repo id.
+        trust_remote_code: Forwarded to mlx-lm as
+            ``tokenizer_config["trust_remote_code"]``. NOTE(review): defaults
+            to True; presumably this allows repo-supplied tokenizer code to
+            run, so confirm before loading untrusted repos.
+
+    Returns:
+        The (model, tokenizer) pair returned by ``mlx_lm.load``.
+
+    Raises:
+        ValueError: if ``model_id`` is empty or whitespace-only.
+    """
+    # Function-scope import: importing this module does not import mlx_lm.
+    from mlx_lm import load
+    return load(resolve_model_path(model_id), tokenizer_config={"trust_remote_code": trust_remote_code})
diff --git a/tests/test_loader.py b/tests/test_loader.py
new file mode 100644
index 0000000..cb4f8df
--- /dev/null
+++ b/tests/test_loader.py
@@ -0,0 +1,40 @@
+import pytest
+from markovian_rsa_mlx.loader import resolve_model_path
+
+
+def test_resolve_local_path_passthrough(tmp_path):
+    # local dir is returned as-is
+    d = tmp_path / "weights"
+    d.mkdir()
+    assert resolve_model_path(str(d)) == str(d)
+
+
+def test_resolve_hf_id_format_unchanged():
+    # HF repo id is returned as-is (mlx-lm handles download)
+    assert resolve_model_path("kyr0/zaya1-base-8b-MLX") == "kyr0/zaya1-base-8b-MLX"
+
+
+def test_resolve_expands_user_home(tmp_path, monkeypatch):
+    # "~" is expanded before the existence check
+    (tmp_path / "weights").mkdir()
+    monkeypatch.setenv("HOME", str(tmp_path))
+    monkeypatch.setenv("USERPROFILE", str(tmp_path))
+    assert resolve_model_path("~/weights") == str(tmp_path / "weights")
+
+
+@pytest.mark.parametrize("bad_id", ["", "   "])
+def test_resolve_rejects_empty_id(bad_id):
+    # Path("") is the cwd, so an empty id must not count as a local path
+    with pytest.raises(ValueError):
+        resolve_model_path(bad_id)
+
+
+@pytest.mark.integration
+def test_load_real_zaya_q4_smoke():
+    """Hits the actual Q4 weights — slow, requires HF cache."""
+    from markovian_rsa_mlx.loader import load_zaya_model
+    model, tokenizer = load_zaya_model("kyr0/zaya1-base-8b-MLX")
+    assert model is not None
+    assert tokenizer is not None
+    assert hasattr(tokenizer, "encode")
+    assert hasattr(tokenizer, "decode")