feat(loader): add load_zaya_model wrapper around mlx-lm.load
This commit is contained in:
@@ -1,2 +1,7 @@
|
|||||||
"""Markovian RSA test-time compute methodology on MLX."""
|
"""Markovian RSA test-time compute methodology on MLX."""
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
|
from markovian_rsa_mlx.config import RSAConfig
|
||||||
|
from markovian_rsa_mlx.loader import load_zaya_model
|
||||||
|
|
||||||
|
__all__ = ["__version__", "RSAConfig", "load_zaya_model"]
|
||||||
|
|||||||
22
src/markovian_rsa_mlx/loader.py
Normal file
22
src/markovian_rsa_mlx/loader.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Model + tokenizer loading via mlx-lm. Local paths and HF repo ids supported."""
|
||||||
|
from __future__ import annotations
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model_path(model_id: str) -> str:
    """Resolve *model_id* to a local path string, or pass it through unchanged.

    Resolution order:

    1. If *model_id* names an existing filesystem entry, return ``str(Path(model_id))``.
    2. Otherwise, if it starts with ``~`` and the user-expanded path exists,
       return that expanded path.
    3. Otherwise return *model_id* untouched; it is then expected to be an HF
       repo id, whose download is left to mlx-lm.

    Args:
        model_id: Local path (optionally ``~``-prefixed) or HF repo id.

    Returns:
        The string to hand to the model loader.

    Raises:
        ValueError: If *model_id* is the empty string.
    """
    if model_id == "":
        # Path("") is the current directory, which always exists, so an empty
        # id would otherwise silently resolve to "." instead of failing.
        raise ValueError("model_id must be a non-empty local path or HF repo id")

    literal = Path(model_id)
    if literal.exists():
        return str(literal)

    # Only try "~" expansion after the literal lookup failed, so every input
    # that previously resolved to a local path still resolves identically.
    if model_id.startswith("~"):
        try:
            expanded = literal.expanduser()
        except RuntimeError:
            # Home directory could not be determined; treat as a non-local id.
            return model_id
        if expanded.exists():
            return str(expanded)

    return model_id
|
||||||
|
|
||||||
|
|
||||||
|
def load_zaya_model(model_id: str, *, trust_remote_code: bool = True) -> tuple[Any, Any]:
    """Load a ZAYA checkpoint and its tokenizer through mlx-lm.

    The `zaya` architecture is only available in the kyr0/mlx-lm fork
    (branch feat/zaya-support), so that fork must be the installed `mlx_lm`.

    Args:
        model_id: Local checkpoint path or HF repo id; it is first passed
            through ``resolve_model_path``.
        trust_remote_code: Forwarded to mlx-lm inside ``tokenizer_config``.
            NOTE(review): defaults to True — confirm this is intended for
            checkpoints from untrusted sources.

    Returns:
        Whatever ``mlx_lm.load`` returns for these arguments, documented here
        as the pair ``(model, tokenizer)``.
    """
    # Imported at call time rather than module import time — presumably so the
    # package stays importable without mlx-lm installed (TODO confirm).
    from mlx_lm import load as mlx_load

    checkpoint_location = resolve_model_path(model_id)
    tokenizer_options = {"trust_remote_code": trust_remote_code}
    return mlx_load(checkpoint_location, tokenizer_config=tokenizer_options)
|
||||||
25
tests/test_loader.py
Normal file
25
tests/test_loader.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import pytest
|
||||||
|
from markovian_rsa_mlx.loader import resolve_model_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_local_path_passthrough(tmp_path):
    """An existing local directory resolves to its own path string."""
    weights_dir = tmp_path / "weights"
    weights_dir.mkdir()
    expected = str(weights_dir)
    assert resolve_model_path(expected) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_hf_id_format_unchanged():
    """A HF repo id that is not a local path comes back unmodified."""
    # No download happens here; the resolver only decides what string to pass on.
    repo_id = "kyr0/zaya1-base-8b-MLX"
    assert resolve_model_path(repo_id) == repo_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.integration
def test_load_real_zaya_q4_smoke():
    """Smoke-test loading the real Q4 weights — slow, requires HF cache."""
    from markovian_rsa_mlx.loader import load_zaya_model

    loaded = load_zaya_model("kyr0/zaya1-base-8b-MLX")
    model, tokenizer = loaded

    assert model is not None
    assert tokenizer is not None
    # The tokenizer must expose both directions of the text <-> ids mapping.
    for method_name in ("encode", "decode"):
        assert hasattr(tokenizer, method_name)
|
||||||
Reference in New Issue
Block a user