"""Resolve a model directory and load its config and weights into ``Model``."""

from __future__ import annotations

import glob
import json
from collections.abc import Callable
from pathlib import Path
from typing import Any

from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn

from .granite_speech import Model, ModelConfig

# File patterns fetched from the Hub when the caller does not supply its own.
DEFAULT_ALLOW_PATTERNS = [
    "*.json",
    "*.safetensors",
    "*.py",
    "*.model",
    "*.tiktoken",
    "*.txt",
    "*.jsonl",
    "*.yaml",
    "*.npz",
]


def _is_local_path(path: str) -> bool:
    """Return True if *path* is spelled like a filesystem path.

    A string counts as a local path when it starts with ``.``, ``/`` or ``~``,
    or when its second character is ``:`` (a Windows drive prefix such as
    ``C:``). Anything else is left to be interpreted as a Hub repo id.
    """
    return path.startswith((".", "/", "~")) or path[1:2] == ":"


def get_model_path(
    path_or_hf_repo: str | Path,
    *,
    revision: str | None = None,
    force_download: bool = False,
    allow_patterns: list[str] | None = None,
) -> Path:
    """Resolve *path_or_hf_repo* to a local directory, downloading if needed.

    Args:
        path_or_hf_repo: A local path (``Path`` or ``str``) or a Hub repo id.
        revision: Passed through to ``snapshot_download``.
        force_download: Passed through to ``snapshot_download``.
        allow_patterns: File patterns to download; falls back to
            ``DEFAULT_ALLOW_PATTERNS`` when ``None`` or empty.

    Returns:
        The existing local path, or the directory returned by
        ``snapshot_download``.

    Raises:
        FileNotFoundError: If a ``Path`` argument, or a string that is spelled
            like a local path, does not exist on disk.
    """
    # A Path object is always treated as local; it is never sent to the Hub.
    if isinstance(path_or_hf_repo, Path):
        path = path_or_hf_repo.expanduser()
        if path.exists():
            return path
        raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")

    # A string that happens to exist on disk wins over a Hub lookup.
    path = Path(path_or_hf_repo).expanduser()
    if path.exists():
        return path

    # Fail fast instead of asking the Hub for something that is clearly a path.
    if _is_local_path(path_or_hf_repo):
        raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")

    return Path(
        snapshot_download(
            path_or_hf_repo,
            revision=revision,
            allow_patterns=allow_patterns or DEFAULT_ALLOW_PATTERNS,
            force_download=force_download,
        )
    )


def load_config(model_path: str | Path) -> dict[str, Any]:
    """Read and parse ``config.json`` from *model_path*.

    Raises:
        FileNotFoundError: If ``config.json`` is not present in *model_path*.
    """
    model_path = Path(model_path)
    config_file = model_path / "config.json"
    if not config_file.exists():
        raise FileNotFoundError(f"Config not found at {model_path}")
    return json.loads(config_file.read_text(encoding="utf-8"))


def load_weights(model_path: str | Path) -> dict[str, mx.array]:
    """Load and merge every weight file found directly in *model_path*.

    ``*.safetensors`` files are preferred; ``*.npz`` files are used only when
    no safetensors file is present. Files are loaded in sorted name order and
    merged into one dict, so a key repeated in a later file replaces the
    earlier value.

    Args:
        model_path: Directory containing the weight files.

    Raises:
        FileNotFoundError: If neither kind of weight file is found.
    """
    model_path = Path(model_path)
    # Escape the directory part so glob metacharacters in its name
    # (e.g. "[", "]", "*") are matched literally rather than as a pattern.
    escaped_dir = Path(glob.escape(str(model_path)))

    weight_files: list[str] = []
    for pattern in ("*.safetensors", "*.npz"):
        weight_files = sorted(glob.glob(str(escaped_dir / pattern)))
        if weight_files:
            break

    if not weight_files:
        raise FileNotFoundError(
            f"No weight files (safetensors or npz) found in {model_path}"
        )

    weights = {}
    for weight_file in weight_files:
        weights.update(mx.load(weight_file))
    return weights


def apply_quantization(
    model: nn.Module,
    config: dict[str, Any],
    weights: dict[str, mx.array],
    model_quant_predicate: (
        Callable[[str, nn.Module], bool | dict[str, Any]] | None
    ) = None,
) -> None:
    """Quantize *model* in place according to the quantization entry in *config*.

    The settings are read from ``config["quantization"]``, falling back to
    ``config["quantization_config"]``. When neither yields a value the model
    is left untouched.

    Args:
        model: Module passed to ``nn.quantize``.
        config: Parsed model config.
        weights: Loaded weights; a layer with no explicit override is
            quantized only if ``"<path>.scales"`` is among the keys.
        model_quant_predicate: Optional ``(path, module)`` callback. A dict
            result is returned to ``nn.quantize`` as-is, a falsy result skips
            the layer, and any other truthy result defers to the checks below.

    Raises:
        KeyError: If the quantization settings have no ``"bits"`` entry.
    """
    quantization = config.get("quantization") or config.get("quantization_config")
    if quantization is None:
        return

    group_size = quantization.get("group_size", 64)

    def class_predicate(path, module):
        # Only modules exposing ``to_quantized`` are candidates.
        if not hasattr(module, "to_quantized"):
            return False
        # Skip layers whose last weight dimension is not a multiple of the
        # group size.
        if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0:
            return False
        if model_quant_predicate is not None:
            pred = model_quant_predicate(path, module)
            if isinstance(pred, dict):
                return pred
            if not pred:
                return False
        # An entry keyed by the module path overrides the default decision.
        if path in quantization:
            return quantization[path]
        return f"{path}.scales" in weights

    nn.quantize(
        model,
        group_size=group_size,
        bits=quantization["bits"],
        mode=quantization.get("mode", "affine"),
        class_predicate=class_predicate,
    )


def load_model(
    model_path: str | Path,
    *,
    lazy: bool = False,
    strict: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """Build a ``Model`` from *model_path* and load its weights.

    Args:
        model_path: Local directory or Hub repo id, resolved through
            ``get_model_path``.
        lazy: When False, ``mx.eval`` is called on the model parameters
            before returning.
        strict: Forwarded to ``model.load_weights``.
        **kwargs: ``revision``, ``force_download`` and ``allow_patterns`` are
            forwarded to ``get_model_path``; any other keyword is accepted
            and ignored.

    Returns:
        The model after ``model.eval()``, or whatever ``Model.post_load_hook``
        returns when that hook is defined.
    """
    path = get_model_path(
        model_path,
        revision=kwargs.pop("revision", None),
        force_download=kwargs.pop("force_download", False),
        allow_patterns=kwargs.pop("allow_patterns", None),
    )
    config = load_config(path)
    model = Model(ModelConfig.from_dict(config))

    weights = load_weights(path)
    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    # Quantize before loading so the module tree matches the stored weights.
    # The predicate is optional, like the other model hooks used here.
    apply_quantization(
        model,
        config,
        weights,
        getattr(model, "model_quant_predicate", None),
    )
    model.load_weights(list(weights.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())
    model.eval()

    if hasattr(Model, "post_load_hook"):
        model = Model.post_load_hook(model, path)
    return model