155 lines
4.1 KiB
Python
155 lines
4.1 KiB
Python
from __future__ import annotations
|
|
|
|
import glob
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from huggingface_hub import snapshot_download
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
|
|
from .granite_speech import Model, ModelConfig
|
|
|
|
# File patterns fetched from the Hugging Face Hub by ``get_model_path`` when
# the caller passes no ``allow_patterns`` (or an empty list).
DEFAULT_ALLOW_PATTERNS = [
    "*.json", "*.safetensors", "*.py",
    "*.model", "*.tiktoken", "*.txt",
    "*.jsonl", "*.yaml", "*.npz",
]
|
|
|
|
|
|
def _is_local_path(path: str) -> bool:
|
|
return (
|
|
path.startswith(".")
|
|
or path.startswith("/")
|
|
or path.startswith("~")
|
|
or (len(path) > 1 and path[1] == ":")
|
|
)
|
|
|
|
|
|
def get_model_path(
    path_or_hf_repo: str | Path,
    *,
    revision: str | None = None,
    force_download: bool = False,
    allow_patterns: list[str] | None = None,
) -> Path:
    """Resolve *path_or_hf_repo* to a local directory.

    If the user-expanded path exists locally, it is returned as-is. A
    ``Path`` argument, or a string that looks like a local path (see
    ``_is_local_path``), must exist; otherwise the string is treated as a
    Hugging Face repo id and fetched with ``snapshot_download``.

    Args:
        path_or_hf_repo: Local directory or Hub repo id.
        revision: Hub revision to download (ignored for local paths).
        force_download: Forwarded to ``snapshot_download``.
        allow_patterns: File patterns to download; falls back to
            ``DEFAULT_ALLOW_PATTERNS`` when None or empty.

    Returns:
        The local directory (existing or freshly downloaded snapshot).

    Raises:
        FileNotFoundError: If a local-looking path does not exist.
    """
    candidate = Path(path_or_hf_repo).expanduser()
    if candidate.exists():
        return candidate

    # Path objects are never interpreted as repo ids; the isinstance check
    # short-circuits before _is_local_path, which expects a str.
    if isinstance(path_or_hf_repo, Path) or _is_local_path(path_or_hf_repo):
        raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")

    patterns = allow_patterns or DEFAULT_ALLOW_PATTERNS
    snapshot_dir = snapshot_download(
        path_or_hf_repo,
        revision=revision,
        allow_patterns=patterns,
        force_download=force_download,
    )
    return Path(snapshot_dir)
|
|
|
|
|
|
def load_config(model_path: str | Path) -> dict[str, Any]:
    """Read and parse ``config.json`` from a model directory.

    Args:
        model_path: Directory expected to contain ``config.json``.

    Returns:
        The parsed JSON object from ``config.json``.

    Raises:
        FileNotFoundError: If ``config.json`` is absent from *model_path*.
    """
    directory = Path(model_path)
    config_file = directory / "config.json"
    if config_file.exists():
        return json.loads(config_file.read_text(encoding="utf-8"))
    raise FileNotFoundError(f"Config not found at {directory}")
|
|
|
|
|
|
def load_weights(model_path: Path) -> dict[str, mx.array]:
    """Load and merge every weight file found directly in *model_path*.

    ``*.safetensors`` files are preferred; ``*.npz`` files are used only
    when no safetensors file is present. Files are loaded in sorted order
    and merged into one dict, so a key repeated across files keeps the value
    from the last file.

    Args:
        model_path: Model directory (a ``str`` is also accepted).

    Returns:
        Mapping from parameter name to array, as returned by ``mx.load``.

    Raises:
        FileNotFoundError: If neither safetensors nor npz files exist.
    """
    model_path = Path(model_path)
    # Escape the directory so glob metacharacters in the path itself
    # ("[", "*", "?", e.g. "ckpt[v1]") are matched literally; only the
    # file-name part is a pattern.
    escaped_dir = Path(glob.escape(str(model_path)))

    for pattern in ("*.safetensors", "*.npz"):
        weight_files = sorted(glob.glob(str(escaped_dir / pattern)))
        if weight_files:
            break
    else:
        raise FileNotFoundError(
            f"No weight files (safetensors or npz) found in {model_path}"
        )

    weights: dict[str, mx.array] = {}
    for weight_file in weight_files:
        weights.update(mx.load(weight_file))
    return weights
|
|
|
|
|
|
def apply_quantization(
    model: nn.Module,
    config: dict[str, Any],
    weights: dict[str, mx.array],
    model_quant_predicate=None,
) -> None:
    """Quantize *model* in place according to its config, if configured.

    Settings come from ``config["quantization"]``, falling back to
    ``config["quantization_config"]``; when neither is set this is a no-op.
    A module is considered for quantization only if it has
    ``to_quantized`` and its weight's last dimension is divisible by
    ``group_size`` (default 64). The optional *model_quant_predicate* may
    reject a module (falsy result) or return a dict, which is handed back
    to ``nn.quantize`` as-is. Otherwise a per-path entry in the
    quantization settings is returned if present, else the module is
    quantized iff *weights* contains ``"<path>.scales"``.

    Raises:
        KeyError: If the quantization settings have no ``"bits"`` entry.
    """
    qconfig = config.get("quantization") or config.get("quantization_config")
    if qconfig is None:
        return

    group_size = qconfig.get("group_size", 64)

    def _should_quantize(layer_path, layer):
        if not hasattr(layer, "to_quantized"):
            return False
        # Layers whose input width does not split into whole groups are
        # left in full precision.
        fits_groups = (
            not hasattr(layer, "weight")
            or layer.weight.shape[-1] % group_size == 0
        )
        if not fits_groups:
            return False

        if model_quant_predicate is not None:
            verdict = model_quant_predicate(layer_path, layer)
            if isinstance(verdict, dict):
                return verdict
            if not verdict:
                return False

        # Per-layer overrides keyed by module path take precedence over
        # inferring quantization from the presence of saved scales.
        if layer_path in qconfig:
            return qconfig[layer_path]
        return f"{layer_path}.scales" in weights

    nn.quantize(
        model,
        group_size=group_size,
        bits=qconfig["bits"],
        mode=qconfig.get("mode", "affine"),
        class_predicate=_should_quantize,
    )
|
|
|
|
|
|
def load_model(
    model_path: str | Path,
    *,
    lazy: bool = False,
    strict: bool = False,
    **kwargs: Any,
) -> nn.Module:
    """Build a Granite Speech ``Model`` and load its (optionally quantized) weights.

    Args:
        model_path: Local directory or Hugging Face repo id, resolved by
            ``get_model_path``.
        lazy: If False, force evaluation of all parameters before returning.
        strict: Forwarded to ``model.load_weights``.
        **kwargs: ``revision``, ``force_download`` and ``allow_patterns`` are
            forwarded to ``get_model_path``; any other keys are ignored.

    Returns:
        The loaded model in eval mode, passed through
        ``Model.post_load_hook`` when the class defines one.
    """
    resolved = get_model_path(
        model_path,
        revision=kwargs.pop("revision", None),
        force_download=kwargs.pop("force_download", False),
        allow_patterns=kwargs.pop("allow_patterns", None),
    )
    config = load_config(resolved)
    model = Model(ModelConfig.from_dict(config))
    weights = load_weights(resolved)

    if hasattr(model, "sanitize"):
        weights = model.sanitize(weights)

    # Guarded like ``sanitize`` / ``post_load_hook``: a model without a
    # custom predicate falls back to apply_quantization's built-in checks
    # instead of raising AttributeError.
    quant_predicate = getattr(model, "model_quant_predicate", None)
    apply_quantization(model, config, weights, quant_predicate)
    model.load_weights(list(weights.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())
    model.eval()

    if hasattr(Model, "post_load_hook"):
        model = Model.post_load_hook(model, resolved)
    return model
|
|
|