Initial Granite Speech Plus MLX package
This commit is contained in:
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
154
src/granite_speech_plus_mlx/_vendored/loader.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .granite_speech import Model, ModelConfig
|
||||
|
||||
DEFAULT_ALLOW_PATTERNS = [
|
||||
"*.json",
|
||||
"*.safetensors",
|
||||
"*.py",
|
||||
"*.model",
|
||||
"*.tiktoken",
|
||||
"*.txt",
|
||||
"*.jsonl",
|
||||
"*.yaml",
|
||||
"*.npz",
|
||||
]
|
||||
|
||||
|
||||
def _is_local_path(path: str) -> bool:
|
||||
return (
|
||||
path.startswith(".")
|
||||
or path.startswith("/")
|
||||
or path.startswith("~")
|
||||
or (len(path) > 1 and path[1] == ":")
|
||||
)
|
||||
|
||||
|
||||
def get_model_path(
|
||||
path_or_hf_repo: str | Path,
|
||||
*,
|
||||
revision: str | None = None,
|
||||
force_download: bool = False,
|
||||
allow_patterns: list[str] | None = None,
|
||||
) -> Path:
|
||||
if isinstance(path_or_hf_repo, Path):
|
||||
path = path_or_hf_repo.expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
|
||||
|
||||
path = Path(path_or_hf_repo).expanduser()
|
||||
if path.exists():
|
||||
return path
|
||||
if _is_local_path(path_or_hf_repo):
|
||||
raise FileNotFoundError(f"Local path not found: {path_or_hf_repo}")
|
||||
|
||||
return Path(
|
||||
snapshot_download(
|
||||
path_or_hf_repo,
|
||||
revision=revision,
|
||||
allow_patterns=allow_patterns or DEFAULT_ALLOW_PATTERNS,
|
||||
force_download=force_download,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def load_config(model_path: str | Path) -> dict[str, Any]:
|
||||
model_path = Path(model_path)
|
||||
config_file = model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise FileNotFoundError(f"Config not found at {model_path}")
|
||||
return json.loads(config_file.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def load_weights(model_path: Path) -> dict[str, mx.array]:
|
||||
weight_files = sorted(glob.glob(str(model_path / "*.safetensors")))
|
||||
if not weight_files:
|
||||
weight_files = sorted(glob.glob(str(model_path / "*.npz")))
|
||||
if not weight_files:
|
||||
raise FileNotFoundError(
|
||||
f"No weight files (safetensors or npz) found in {model_path}"
|
||||
)
|
||||
|
||||
weights = {}
|
||||
for weight_file in weight_files:
|
||||
weights.update(mx.load(weight_file))
|
||||
return weights
|
||||
|
||||
|
||||
def apply_quantization(
|
||||
model: nn.Module,
|
||||
config: dict[str, Any],
|
||||
weights: dict[str, mx.array],
|
||||
model_quant_predicate=None,
|
||||
) -> None:
|
||||
quantization = config.get("quantization") or config.get("quantization_config")
|
||||
if quantization is None:
|
||||
return
|
||||
|
||||
group_size = quantization.get("group_size", 64)
|
||||
|
||||
def class_predicate(path, module):
|
||||
if not hasattr(module, "to_quantized"):
|
||||
return False
|
||||
if hasattr(module, "weight") and module.weight.shape[-1] % group_size != 0:
|
||||
return False
|
||||
if model_quant_predicate is not None:
|
||||
pred = model_quant_predicate(path, module)
|
||||
if isinstance(pred, dict):
|
||||
return pred
|
||||
if not pred:
|
||||
return False
|
||||
if path in quantization:
|
||||
return quantization[path]
|
||||
return f"{path}.scales" in weights
|
||||
|
||||
nn.quantize(
|
||||
model,
|
||||
group_size=group_size,
|
||||
bits=quantization["bits"],
|
||||
mode=quantization.get("mode", "affine"),
|
||||
class_predicate=class_predicate,
|
||||
)
|
||||
|
||||
|
||||
def load_model(
|
||||
model_path: str | Path,
|
||||
*,
|
||||
lazy: bool = False,
|
||||
strict: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> nn.Module:
|
||||
path = get_model_path(
|
||||
model_path,
|
||||
revision=kwargs.pop("revision", None),
|
||||
force_download=kwargs.pop("force_download", False),
|
||||
allow_patterns=kwargs.pop("allow_patterns", None),
|
||||
)
|
||||
config = load_config(path)
|
||||
model = Model(ModelConfig.from_dict(config))
|
||||
weights = load_weights(path)
|
||||
|
||||
if hasattr(model, "sanitize"):
|
||||
weights = model.sanitize(weights)
|
||||
|
||||
apply_quantization(model, config, weights, model.model_quant_predicate)
|
||||
model.load_weights(list(weights.items()), strict=strict)
|
||||
|
||||
if not lazy:
|
||||
mx.eval(model.parameters())
|
||||
model.eval()
|
||||
|
||||
if hasattr(Model, "post_load_hook"):
|
||||
model = Model.post_load_hook(model, path)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user