v0.1.0 — initial release
MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the full flow-matching + classifier-free-guidance pipeline at ~100× realtime on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and cosine 0.98 vs the upstream ONNX Runtime reference. Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx and auto-downloaded on first use; this repository ships the port code, the model card, audio samples, and a zero-config setup_and_test.sh. Install: pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git Quick test: git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git cd supertonic-3-mlx && ./setup_and_test.sh Licenses (dual): model weights = BigScience Open RAIL-M (Section 4 propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
545
src/supertonic_3_mlx/pipeline.py
Normal file
545
src/supertonic_3_mlx/pipeline.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""Supertonic 3 end-to-end MLX pipeline.
|
||||
|
||||
Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
|
||||
VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
|
||||
that returns a 44.1 kHz mono numpy waveform.
|
||||
|
||||
Flow:
|
||||
|
||||
text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
|
||||
│
|
||||
voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
|
||||
│
|
||||
duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
|
||||
│
|
||||
text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
|
||||
│
|
||||
noise ~ N(0, I) of shape (B, 144, T_lat)
|
||||
where T_lat = ceil(duration_s × 44100 / (512 × 6))
|
||||
│
|
||||
vector_estimator 5-step Euler with CFG (4×cond − 3×uncond):
|
||||
for step in [0..4]:
|
||||
x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
|
||||
│
|
||||
vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)
|
||||
|
||||
Public API:
|
||||
|
||||
pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
|
||||
wav = pipe.generate("Hello world", voice="F1", lang="en")
|
||||
import soundfile as sf
|
||||
sf.write("out.wav", wav, pipe.sample_rate)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from supertonic_3_mlx._config import SAMPLE_RATE
|
||||
from supertonic_3_mlx.duration_predictor import DurationPredictor
|
||||
from supertonic_3_mlx.text_encoder import TextEncoder
|
||||
from supertonic_3_mlx.vector_estimator import VectorEstimator
|
||||
from supertonic_3_mlx.vocoder import Vocoder
|
||||
|
||||
|
||||
# Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
# covers 512 × 6 = 3072 samples ≈ 69.7 ms. Used to convert a predicted
# duration in seconds into a latent sequence length.
SAMPLES_PER_LATENT_STEP = 512 * 6  # 3072
|
||||
|
||||
|
||||
# ── Shared ONNX → MLX weight extraction ─────────────────────────────
|
||||
|
||||
|
||||
def _convert_onnx(onnx_path: str | Path) -> dict:
    """Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.

    Combines the three extraction stages discovered during the per-component
    ports (T.3.1, T.3.2, T.3.3):

    1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
       pwconv, head.layer2).
    2. Anonymous MatMul weights recovered via the MatMul output path.
    3. Anonymous Conv weights and PReLU slopes recovered the same way.

    Args:
        onnx_path: Path to one of the Supertonic ``*.onnx`` files.

    Returns:
        Mapping from parameter name to ``mx.array``. Later stages write into
        the same dict as earlier ones, so a key produced twice keeps the value
        from the later stage.

    Note:
        ``onnx`` is imported lazily inside the function, so it is only needed
        when this conversion path is actually used.
    """
    import onnx
    import onnx.numpy_helper as nh

    m = onnx.load(str(onnx_path))

    def _matmul_clean(out_name: str) -> str:
        # Turn a MatMul node's output name (a "/"-separated graph path) into a
        # dotted parameter key ending in ".weight".
        p = out_name.lstrip("/")
        if p.endswith("/MatMul_output_0"):
            p = p[: -len("/MatMul_output_0")]
        # Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
        for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return p.replace("/", ".") + ".weight"

    def _conv_clean(out_name: str) -> str:
        # Turn a Conv node's output name into a dotted key under "tts.ae.".
        # The caller appends ".weight" / ".bias" depending on the input slot.
        p = out_name.lstrip("/")
        if p.endswith("/Conv_output_0"):
            p = p[: -len("/Conv_output_0")]
        for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".")

    def _prelu_clean(out_name: str) -> str:
        # Turn a PRelu node's output name into a dotted "tts.ae.*.weight" key.
        p = out_name.lstrip("/")
        if p.endswith("/PRelu_output_0"):
            p = p[: -len("/PRelu_output_0")]
        for prefix in ("vocoder/", "vector_estimator/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".") + ".weight"

    # Detect which model this file is — affects how we wrap named init keys
    name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
    is_text_encoder = "tts" in name_prefixes and any(
        i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
    )

    weights: dict[str, mx.array] = {}

    # Stage 1: named initialisers
    for init in m.graph.initializer:
        n = init.name
        # Determine if this is a structured (named) weight or an anonymous graph const
        if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
            continue

        # Strip the vector_estimator-specific prefix so all 4 models share a name space.
        if n.startswith("vector_estimator.tts.ttl."):
            clean = n[len("vector_estimator.tts.ttl."):]
        else:
            clean = n

        arr = nh.to_array(init)

        # Shape transforms. These are independent ``if`` statements applied in
        # sequence, so a tensor reshaped by one rule can still match a later
        # rule — the order below matters.
        # Depthwise conv (C, 1, K) with K != 1 → (C, K, 1).
        if (clean.endswith(".dwconv.weight") and arr.ndim == 3
                and arr.shape[1] == 1 and arr.shape[2] != 1):
            arr = np.transpose(arr, (0, 2, 1))
        # Same transpose for depthwise convs stored under a ``.net`` wrapper.
        if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
                and arr.shape[1] == 1):
            arr = np.transpose(arr, (0, 2, 1))
        # Per-channel gamma stored as (1, C, 1) → flat (C,).
        if (clean.endswith(".gamma") and arr.ndim == 3
                and arr.shape[0] == 1 and arr.shape[2] == 1):
            arr = arr.reshape(arr.shape[1])
        # Pointwise convs with a trailing singleton axis → 2-D matrix.
        if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
                and arr.ndim == 3 and arr.shape[-1] == 1):
            arr = arr.squeeze(-1)
        # NOTE(review): a ``.dwconv.net.weight`` tensor transposed above now has
        # a trailing dimension of 1 and therefore also matches this rule, ending
        # up 2-D. Presumably the size-based reshape in ``_load_into`` restores
        # the expected rank — confirm.
        if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
            # Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
            arr = arr.squeeze(-1)
        # vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
        if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))
        # vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
        if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))

        weights[clean] = mx.array(arr)

    # Stage 2: MatMul weight recovery
    inits_map = {init.name: init for init in m.graph.initializer}
    for node in m.graph.node:
        if node.op_type != "MatMul" or len(node.input) < 2:
            continue
        winp = node.input[1]
        # Only anonymous initialisers are handled here; named ``tts.*`` ones
        # were already processed in stage 1.
        if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
            continue
        arr = nh.to_array(inits_map[winp])
        if arr.ndim == 2:
            arr = arr.T  # ONNX (in, out) → MLX Linear (out, in)
        clean = _matmul_clean(node.output[0])
        # Build the leading namespace from the file context (already in tts.*)
        if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
            clean = "tts.ttl." + clean if is_text_encoder else clean
        weights[clean] = mx.array(arr)

    # Stage 3: anonymous Conv + PReLU (vocoder embed / head)
    for node in m.graph.node:
        if node.op_type == "Conv":
            # Input slot 1 is treated as the weight, any later slot as the bias.
            for i, inp in enumerate(node.input[1:], 1):
                if inp not in inits_map or inp.startswith("tts."):
                    continue
                arr = nh.to_array(inits_map[inp])
                base = _conv_clean(node.output[0])
                # Convs whose cleaned path mentions "dwconv" are skipped here.
                if "dwconv" in base:
                    continue
                if i == 1 and arr.ndim == 3:
                    arr = np.transpose(arr, (0, 2, 1))  # ONNX (out, in, K) → MLX (out, K, in)
                key = base + (".weight" if i == 1 else ".bias")
                weights[key] = mx.array(arr)
        elif node.op_type == "PRelu":
            for inp in node.input[1:]:
                if inp in inits_map and not inp.startswith("tts."):
                    weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))

    return weights
|
||||
|
||||
|
||||
def _load_into(model, weights: dict) -> int:
    """Copy converted weights into ``model`` wherever names line up.

    For every parameter of ``model`` that has a same-named entry in
    ``weights``, the entry is used as-is when the shapes agree, reshaped when
    only the element count agrees, and skipped otherwise. Parameters without
    a usable entry keep their current values (the load is non-strict).

    Returns the number of successfully matched tensors.
    """
    from mlx.utils import tree_flatten

    selected = {}
    for name, param in tree_flatten(model.parameters()):
        try:
            candidate = weights[name]
        except KeyError:
            continue
        target_shape = tuple(param.shape)
        if tuple(candidate.shape) != target_shape:
            # Same element count but different layout: reshape to fit.
            if candidate.size != np.prod(target_shape):
                continue
            candidate = candidate.reshape(target_shape)
        selected[name] = candidate

    model.load_weights(list(selected.items()), strict=False)
    return len(selected)
|
||||
|
||||
|
||||
# ── Tokenization ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
    """Encode a text string into character IDs.

    The unicode_indexer is a flat list (size 65536 upstream);
    ``indexer[ord(c)]`` gives the token ID for character ``c`` (-1 = unknown).
    Characters whose code point falls outside the table, or that map to a
    negative ID, are dropped.

    For Phase T.4 the text is wrapped with no special language tokens — the
    ONNX SDK uses language tags but this pipeline currently runs unconditioned
    on language.

    Args:
        text: The string to tokenize.
        indexer: Code-point → token-ID lookup table.
        lang: Accepted for API compatibility; currently unused.

    Returns:
        A 1-D ``int32`` array of token IDs. It is never empty: when no
        character yields a valid token, the result is the single space token,
        or ``0`` if the table has no valid entry for the space character
        (including a table too short to contain it).
    """
    table_size = len(indexer)
    ids = [
        indexer[cp]
        for cp in map(ord, text)
        if cp < table_size and indexer[cp] >= 0
    ]
    if not ids:
        # Fall back to a single space token to avoid empty input. Guard the
        # lookup so a short table cannot raise IndexError.
        space_cp = ord(" ")
        space_tok = indexer[space_cp] if space_cp < table_size else -1
        ids = [space_tok if space_tok >= 0 else 0]
    return np.asarray(ids, dtype=np.int32)
|
||||
|
||||
|
||||
# ── Pipeline ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class SupertonicMLXPipeline:
    """End-to-end Supertonic 3 TTS pipeline in pure MLX.

    Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
    vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.

    Instances are normally built with :meth:`from_pretrained`; the constructor
    takes already-loaded sub-models.
    """

    sample_rate: int = SAMPLE_RATE
    # Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
    # model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
    # and the combination 4×cond − 3×uncond. Any other step count or skipping
    # CFG produces an essentially uncorrelated waveform (verified by
    # ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
    # ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
    # latency further would require distilling a shorter-schedule model.
    n_euler_steps: int = 5

    def __init__(
        self,
        duration_predictor: DurationPredictor,
        text_encoder: TextEncoder,
        vector_estimator: VectorEstimator,
        vocoder: Vocoder,
        unicode_indexer: list[int],
        voice_dir: Path,
    ) -> None:
        """Wire up the sub-models and compile the hot loops.

        Args:
            duration_predictor: Loaded duration predictor sub-model.
            text_encoder: Loaded text encoder sub-model.
            vector_estimator: Loaded vector estimator sub-model.
            vocoder: Loaded vocoder sub-model.
            unicode_indexer: Code-point → token-ID lookup table.
            voice_dir: Directory holding the ``<voice>.json`` style files.
        """
        self.duration_predictor = duration_predictor
        self.text_encoder = text_encoder
        self.vector_estimator = vector_estimator
        self.vocoder = vocoder
        self.unicode_indexer = unicode_indexer
        # Normalise to ``Path`` so a plain string is accepted as well.
        self.voice_dir = Path(voice_dir)

        # T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
        # by input shapes; the 5× CFG Euler loop and the single vocoder pass
        # both gain from fused kernel dispatch (~50–100 layer ops collapse into
        # one dispatch per cached graph).

        # T.5.3 — also pre-project text and style K/V outside the step. They
        # are invariant across the 5 Euler steps, so the 4 text_attn + 4
        # style_attn blocks no longer re-run their W_key / W_value / RoPE_K
        # matmuls on every step (saves 40 matmuls per generate).
        cond_scale = self.vector_estimator.CFG_COND_SCALE
        uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE

        def _cached_step(
            noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
        ):
            # One Euler update with CFG: the batch is doubled so the
            # conditional and unconditional velocities come from one call.
            noisy_2 = mx.concatenate([noisy, noisy], axis=0)
            # ``kv_flat`` layout: 4 text (K, V) pairs followed by 4 style
            # (K, V) pairs, flattened — hence the offset of 8 for style.
            text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
            style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
            v_2 = self.vector_estimator.velocity_cached(
                noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
            )
            B = noisy.shape[0]
            cond_v = v_2[:B]
            uncond_v = v_2[B:]
            combined = cond_scale * cond_v - uncond_scale * uncond_v
            return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)

        def _voc_step(latent):
            return self.vocoder(latent)

        self._cached_step_compiled = mx.compile(_cached_step)
        self._voc_compiled = mx.compile(_voc_step)

        # Pick the runtime dtype from any leaf weight of the vector estimator —
        # ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
        # in which case all inputs to the compiled hot loops must be cast to
        # match (mixed-dtype Conv/MatMul is not legal in MLX).
        from mlx.utils import tree_flatten
        leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
                  if isinstance(v, mx.array)]
        self.dtype = leaves[0].dtype if leaves else mx.float32

    @classmethod
    def from_pretrained(
        cls,
        model_id_or_path: str | Path,
        dtype: mx.Dtype | None = None,
        cache_dir: str | Path | None = None,
        revision: str | None = None,
    ) -> "SupertonicMLXPipeline":
        """Construct the pipeline from a model snapshot.

        Three sources are accepted, auto-detected:

        1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
           weights are downloaded via :func:`huggingface_hub.snapshot_download`
           into ``cache_dir`` (defaults to the standard HF cache) and loaded
           directly from the bundled ``weights/*.safetensors`` files.
        2. **Local path with a** ``weights/`` **subdir**: the MLX-native
           layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
           Fast path — no ONNX conversion at runtime.
        3. **Local path with an** ``onnx/`` **subdir**: the upstream
           ``Supertone/supertonic-3`` snapshot layout. Weights are converted
           from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
           development or when starting from the original upstream release.

        A string is treated as a Hub repo id when it contains ``"/"`` and does
        not exist on disk; note that this also applies to a mistyped local
        path. When both ``weights/`` and ``onnx/`` exist, ``weights/`` wins.

        Optional kwargs:
            dtype     — if non-None and not float32, cast all weights to the
                        given dtype after load (only ``mx.bfloat16`` is
                        currently meaningful; see README "BF16 note").
            cache_dir — passed to ``huggingface_hub.snapshot_download``.
            revision  — branch / tag / commit sha on the Hub.

        Raises:
            ImportError: A Hub repo id was given but ``huggingface_hub`` is
                not installed.
            FileNotFoundError: The resolved directory has neither a
                ``weights/`` nor an ``onnx/`` subdirectory.
        """
        # 1. Resolve the local snapshot directory
        if isinstance(model_id_or_path, str) and "/" in model_id_or_path \
                and not Path(model_id_or_path).exists():
            try:
                from huggingface_hub import snapshot_download
            except ImportError as e:
                raise ImportError(
                    "Loading from the Hugging Face Hub requires "
                    "``huggingface_hub`` — install with ``pip install "
                    "supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
                ) from e
            local_dir = Path(snapshot_download(
                repo_id=model_id_or_path,
                cache_dir=cache_dir,
                revision=revision,
                allow_patterns=[
                    "weights/*.safetensors",
                    "unicode_indexer.json",
                    "voice_styles/*.json",
                ],
            ))
        else:
            local_dir = Path(model_id_or_path)

        # 2. Detect layout
        weights_dir = local_dir / "weights"
        onnx_dir = local_dir / "onnx"
        if weights_dir.exists():
            return cls._from_safetensors(local_dir, dtype=dtype)
        if onnx_dir.exists():
            return cls._from_onnx(local_dir, dtype=dtype)
        raise FileNotFoundError(
            f"{local_dir} contains neither ``weights/`` (safetensors layout) "
            f"nor ``onnx/`` (upstream layout); cannot load."
        )

    @classmethod
    def _from_safetensors(
        cls, local_dir: Path, dtype: mx.Dtype | None = None,
    ) -> "SupertonicMLXPipeline":
        """Build the pipeline from the MLX-native ``weights/`` layout.

        Expects ``weights/<name>.safetensors`` for each of the four sub-models,
        ``unicode_indexer.json`` at the top level, and uses ``voice_styles/``
        as the voice directory.
        """
        from mlx.utils import tree_flatten
        weights_dir = local_dir / "weights"
        voice_dir = local_dir / "voice_styles"
        unicode_indexer = json.loads((local_dir / "unicode_indexer.json").read_text())

        def _build(cls_, name):
            model = cls_()
            w = mx.load(str(weights_dir / f"{name}.safetensors"))
            # Reshape any mismatched leaves (defensive; the converter already
            # produced shape-correct tensors but a future re-export may not).
            expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
            for k in list(w.keys()):
                if k in expected and tuple(w[k].shape) != expected[k]:
                    if w[k].size == int(np.prod(expected[k])):
                        w[k] = w[k].reshape(expected[k])
            model.load_weights(list(w.items()), strict=False)
            return model

        ve = _build(VectorEstimator, "vector_estimator")
        te = _build(TextEncoder, "text_encoder")
        dp = _build(DurationPredictor, "duration_predictor")
        voc = _build(Vocoder, "vocoder")

        if dtype is not None and dtype != mx.float32:
            cls._cast_all(dp, te, ve, voc, dtype=dtype)

        return cls(dp, te, ve, voc, unicode_indexer, voice_dir)

    @classmethod
    def _from_onnx(
        cls, local_dir: Path, dtype: mx.Dtype | None = None,
    ) -> "SupertonicMLXPipeline":
        """Build the pipeline from the upstream ``onnx/`` layout.

        Each ``onnx/<name>.onnx`` file is converted on the fly;
        ``unicode_indexer.json`` is read from the ``onnx/`` directory and
        ``voice_styles/`` is used as the voice directory.
        """
        onnx_dir = local_dir / "onnx"
        voice_dir = local_dir / "voice_styles"
        unicode_indexer = json.loads((onnx_dir / "unicode_indexer.json").read_text())

        ve = VectorEstimator()
        _load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
        te = TextEncoder()
        _load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
        dp = DurationPredictor()
        _load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
        voc = Vocoder()
        _load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))

        if dtype is not None and dtype != mx.float32:
            cls._cast_all(dp, te, ve, voc, dtype=dtype)

        return cls(dp, te, ve, voc, unicode_indexer, voice_dir)

    @staticmethod
    def _cast_all(*models, dtype: mx.Dtype) -> None:
        """Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
        from mlx.utils import tree_map

        def _cast(p):
            # Leave non-array leaves and non-fp32 arrays untouched.
            if not isinstance(p, mx.array) or p.dtype != mx.float32:
                return p
            return p.astype(dtype)

        for m_ in models:
            m_.update(tree_map(_cast, m_.parameters()))

    def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
        """Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp).

        Raises:
            FileNotFoundError: No ``<voice>.json`` exists in the voice
                directory. The message lists the voices that were found.
        """
        path = self.voice_dir / f"{voice}.json"
        if not path.is_file():
            available = sorted(p.stem for p in Path(self.voice_dir).glob("*.json"))
            listing = ", ".join(available) if available else "(none found)"
            raise FileNotFoundError(
                f"Unknown voice {voice!r}: {path} does not exist. "
                f"Available voices: {listing}"
            )
        data = json.loads(path.read_text())
        style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32)  # (1, 50, 256)
        style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32)  # (1, 8, 16)
        return mx.array(style_ttl), mx.array(style_dp)

    def generate(
        self,
        text: str,
        voice: str = "F1",
        lang: str = "en",
        seed: int = 42,
        n_steps: Optional[int] = None,
    ) -> np.ndarray:
        """Synthesise a single utterance. Returns a 1D float32 numpy waveform.

        Args:
            text: Text to speak.
            voice: Name of a style file in the voice directory (without the
                ``.json`` suffix).
            lang: Forwarded to the tokenizer.
            seed: Seed for the initial latent noise; a fixed seed makes the
                output reproducible.
            n_steps: Number of Euler steps. Defaults to ``n_euler_steps``;
                see the class-level comment on why other values degrade
                the output.

        Raises:
            ValueError: ``n_steps`` is smaller than 1.
            FileNotFoundError: ``voice`` does not name an existing style file.
        """
        n_steps = n_steps if n_steps is not None else self.n_euler_steps
        if n_steps < 1:
            # With zero steps the loop below would not run and the vocoder
            # would decode pure noise; fail loudly instead.
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        # Tokenize
        text_ids_np = _encode_text(text, self.unicode_indexer, lang)
        text_ids = mx.array(text_ids_np[None, :])  # (1, T_text)
        T_text = text_ids.shape[1]
        text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)

        # Style
        style_ttl, style_dp = self._load_voice(voice)
        if self.dtype != mx.float32:
            style_ttl = style_ttl.astype(self.dtype)
            style_dp = style_dp.astype(self.dtype)

        # Duration → latent length
        duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
        mx.eval(duration_s)
        duration_val = max(float(duration_s[0].item()), 0.5)  # clamp to ≥ 0.5 s
        T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)

        # Text embedding
        text_emb = self.text_encoder(text_ids, style_ttl, text_mask)  # (1, 256, T_text)

        # Initial noise — fixed seed for reproducibility
        key = mx.random.key(seed)
        noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
        latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)

        # T.5.3 — build the (2B) CFG conditioning tensors once and pre-project
        # K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
        # ``(K, V)`` arrays flattened into a list for the compiled step.
        B = noise.shape[0]
        ve = self.vector_estimator
        text_uncond = mx.broadcast_to(
            ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
        ).astype(self.dtype)
        style_k_uncond = mx.broadcast_to(
            ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        style_v_uncond = mx.broadcast_to(
            ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        # Conditional half first, unconditional half second — the compiled
        # step splits the velocity along the batch axis in the same order.
        text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
        style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
        style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
        text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
        latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)

        text_kv, style_kv = ve.precompute_cross_kv(
            text_emb_2, style_k_2, style_v_2, text_mask_2,
        )
        kv_flat = []
        for k, v in text_kv:
            kv_flat.extend([k, v])
        for k, v in style_kv:
            kv_flat.extend([k, v])

        # Euler with CFG — 5 steps by default
        x = noise
        total_step = mx.array([float(n_steps)], dtype=self.dtype)
        for step in range(n_steps):
            current_step = mx.array([float(step + 1)], dtype=self.dtype)
            t_norm = current_step / total_step
            t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
            x = self._cached_step_compiled(
                x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
            )
            mx.eval(x)

        # Decode latent → waveform
        wav = self._voc_compiled(x)
        mx.eval(wav)
        if wav.dtype != mx.float32:
            wav = wav.astype(mx.float32)
        return np.array(wav)[0]  # (T_lat × 6 × 512,)
|
||||
|
||||
|
||||
# Only the pipeline class is public; the conversion and tokenization helpers
# above are implementation details.
__all__ = ["SupertonicMLXPipeline"]
|
||||
Reference in New Issue
Block a user