Files
granite-speech-4.1-2b-plus-mlx/src/granite_speech_plus_mlx/_vendored/loader.py
2026-05-09 20:00:57 +02:00

155 lines
4.1 KiB
Python

from __future__ import annotations
import glob
import json
from pathlib import Path
from typing import Any
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
from .granite_speech import Model, ModelConfig
# Glob patterns passed to ``snapshot_download`` as ``allow_patterns`` when
# the caller of ``get_model_path`` does not supply its own list.
DEFAULT_ALLOW_PATTERNS = [
    "*.json",
    "*.safetensors",
    "*.py",
    "*.model",
    "*.tiktoken",
    "*.txt",
    "*.jsonl",
    "*.yaml",
    "*.npz",
]
def _is_local_path(path: str) -> bool:
    """Return whether *path* is written like a filesystem path.

    This is a purely textual check; the filesystem is never consulted.
    A string counts as a local path when it starts with ``.``, ``/``, ``~``
    or a backslash (Windows rooted / UNC paths), or when its second
    character is ``:`` (drive-letter form such as ``C:\\models``).

    Args:
        path: The string to classify.

    Returns:
        ``True`` if the string looks like a local path, ``False`` otherwise
        (including for the empty string).
    """
    # ``str.startswith`` accepts a tuple, so one call covers every prefix.
    if path.startswith((".", "/", "~", "\\")):
        return True
    # Slicing instead of indexing keeps strings shorter than two characters
    # from raising IndexError.
    return path[1:2] == ":"
def get_model_path(
    path_or_hf_repo: str | Path,
    *,
    revision: str | None = None,
    force_download: bool = False,
    allow_patterns: list[str] | None = None,
) -> Path:
    """Resolve a local path or a Hub repo id to a path on disk.

    The argument is first tried as a filesystem path (with ``~`` expanded).
    If nothing exists there and the argument is a plain string that does not
    look like a local path, it is handed to ``snapshot_download`` as a repo
    id.

    Args:
        path_or_hf_repo: A ``Path``, a path string, or a repo id string.
        revision: Forwarded to ``snapshot_download``.
        force_download: Forwarded to ``snapshot_download``.
        allow_patterns: Patterns forwarded to ``snapshot_download``; when
            ``None`` or empty, ``DEFAULT_ALLOW_PATTERNS`` is used instead.

    Returns:
        The existing local path, or the path returned by
        ``snapshot_download`` wrapped in a ``Path``.

    Raises:
        FileNotFoundError: If a ``Path`` object, or a string that looks like
            a local path, does not exist.
    """
    given_as_path = isinstance(path_or_hf_repo, Path)
    if given_as_path:
        candidate = path_or_hf_repo.expanduser()
    else:
        candidate = Path(path_or_hf_repo).expanduser()

    if candidate.exists():
        return candidate

    # A ``Path`` object is never interpreted as a repo id, so a missing one
    # is always an error; strings are only an error when they look local.
    if given_as_path or _is_local_path(path_or_hf_repo):
        raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")

    patterns = allow_patterns or DEFAULT_ALLOW_PATTERNS
    snapshot_dir = snapshot_download(
        path_or_hf_repo,
        revision=revision,
        allow_patterns=patterns,
        force_download=force_download,
    )
    return Path(snapshot_dir)
def load_config(model_path: str | Path) -> dict[str, Any]:
    """Read and parse ``config.json`` from a model directory.

    Args:
        model_path: Directory expected to contain ``config.json``.

    Returns:
        The parsed JSON content of the file.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in the
            directory.
    """
    model_dir = Path(model_path)
    config_file = model_dir / "config.json"

    if config_file.exists():
        with config_file.open(encoding="utf-8") as handle:
            return json.loads(handle.read())

    raise FileNotFoundError(f"Config not found at {model_dir}")
def load_weights(model_path: Path) -> dict[str, mx.array]:
    """Load every weight file in a model directory into one dictionary.

    ``*.safetensors`` files are preferred; ``*.npz`` files are used only
    when no safetensors file is present. Files are loaded in sorted name
    order with ``mx.load`` and merged with ``dict.update``, so when the same
    key appears in several files the later file wins.

    Args:
        model_path: Directory containing the weight files. A ``str`` is
            accepted as well as a ``Path``.

    Returns:
        The merged mapping produced by the ``mx.load`` calls.

    Raises:
        FileNotFoundError: If the directory contains neither safetensors nor
            npz files.
    """
    # Escape the directory part so glob metacharacters in its name
    # ("[", "]", "*", "?") are matched literally rather than as a pattern.
    escaped_dir = Path(glob.escape(str(model_path)))

    weight_files = []
    for suffix in ("safetensors", "npz"):
        weight_files = sorted(glob.glob(str(escaped_dir / f"*.{suffix}")))
        if weight_files:
            break

    if not weight_files:
        raise FileNotFoundError(
            f"No weight files (safetensors or npz) found in {model_path}"
        )

    weights = {}
    for weight_file in weight_files:
        weights.update(mx.load(weight_file))
    return weights
def apply_quantization(
    model: nn.Module,
    config: dict[str, Any],
    weights: dict[str, mx.array],
    model_quant_predicate=None,
) -> None:
    """Quantize *model* via ``nn.quantize`` when the config requests it.

    The quantization settings are read from ``config["quantization"]`` or,
    failing that, ``config["quantization_config"]``. If neither yields a
    value the function returns without touching the model.

    Args:
        model: Module passed to ``nn.quantize``.
        config: Model configuration dictionary.
        weights: Loaded weights; used to detect which layers have stored
            ``.scales`` entries.
        model_quant_predicate: Optional callable ``(path, module)``
            consulted for each candidate layer. A ``dict`` result is
            returned as-is, a falsy result excludes the layer, and any other
            truthy result falls through to the remaining checks.

    Raises:
        KeyError: If quantization settings are present but lack ``"bits"``.
    """
    quant_cfg = config.get("quantization") or config.get("quantization_config")
    if quant_cfg is None:
        return

    group_size = quant_cfg.get("group_size", 64)

    def _should_quantize(layer_path, layer):
        # Only modules that know how to quantize themselves are candidates.
        if not hasattr(layer, "to_quantized"):
            return False

        # Skip layers whose last weight dimension is not a multiple of the
        # group size.
        misaligned = (
            hasattr(layer, "weight")
            and layer.weight.shape[-1] % group_size != 0
        )
        if misaligned:
            return False

        if model_quant_predicate is not None:
            verdict = model_quant_predicate(layer_path, layer)
            if isinstance(verdict, dict):
                return verdict
            if not verdict:
                return False

        # A config entry keyed by the layer path takes precedence over the
        # presence check on the stored scales.
        if layer_path in quant_cfg:
            return quant_cfg[layer_path]
        return f"{layer_path}.scales" in weights

    nn.quantize(
        model,
        group_size=group_size,
        bits=quant_cfg["bits"],
        mode=quant_cfg.get("mode", "affine"),
        class_predicate=_should_quantize,
    )
def load_model(
    model_path: str | Path,
    *,
    lazy: bool = False,
    strict: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """Build a ``Model`` from a local directory or Hub repo and load its weights.

    Steps, in order: resolve the path with ``get_model_path``, read the
    config, construct ``Model`` from ``ModelConfig.from_dict``, load the
    weight files, run ``model.sanitize`` on them when the model defines it,
    apply quantization, load the weights into the model, optionally
    evaluate the parameters, switch the model to eval mode, and finally run
    ``Model.post_load_hook`` when the class defines it.

    Args:
        model_path: Local path or repo id, as accepted by
            ``get_model_path``.
        lazy: When ``False``, ``mx.eval`` is called on the model parameters
            before returning.
        strict: Forwarded to ``model.load_weights``.
        **kwargs: ``revision``, ``force_download`` and ``allow_patterns``
            are forwarded to ``get_model_path``; any other keys are ignored.

    Returns:
        The loaded model, or whatever ``Model.post_load_hook`` returns when
        that hook exists.
    """
    download_options = {
        "revision": kwargs.pop("revision", None),
        "force_download": kwargs.pop("force_download", False),
        "allow_patterns": kwargs.pop("allow_patterns", None),
    }
    resolved_dir = get_model_path(model_path, **download_options)

    config = load_config(resolved_dir)
    model = Model(ModelConfig.from_dict(config))

    tensors = load_weights(resolved_dir)
    if hasattr(model, "sanitize"):
        tensors = model.sanitize(tensors)

    # Quantization runs before the weights are loaded into the model.
    apply_quantization(model, config, tensors, model.model_quant_predicate)
    model.load_weights(list(tensors.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())
    model.eval()

    if not hasattr(Model, "post_load_hook"):
        return model
    return Model.post_load_hook(model, resolved_dir)