v0.1.0 — initial release

MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the
full flow-matching + classifier-free-guidance pipeline at ~100× realtime
on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and
cosine 0.98 vs the upstream ONNX Runtime reference.

Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx
and auto-downloaded on first use; this repository ships the port code,
the model card, audio samples, and a zero-config setup_and_test.sh.

Install:
    pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git

Quick test:
    git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
    cd supertonic-3-mlx && ./setup_and_test.sh

Licenses (dual): model weights = BigScience Open RAIL-M (Section 4
propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
transcrilive
2026-05-20 09:17:05 +02:00
commit 12dbf4a821
36 changed files with 3812 additions and 0 deletions

View File

@@ -0,0 +1,545 @@
"""Supertonic 3 end-to-end MLX pipeline.
Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
that returns a 44.1 kHz mono numpy waveform.
Flow:
text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
noise ~ N(0, I) of shape (B, 144, T_lat)
where T_lat = ceil(duration_s × 44100 / (512 × 6))
vector_estimator 5-step Euler with CFG (4×cond − 3×uncond):
for step in [0..4]:
x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)
Public API:
pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
wav = pipe.generate("Hello world", voice="F1", lang="en")
import soundfile as sf
sf.write("out.wav", wav, pipe.sample_rate)
"""
from __future__ import annotations
import json
import math
from pathlib import Path
from typing import Optional
import mlx.core as mx
import numpy as np
from supertonic_3_mlx._config import SAMPLE_RATE
from supertonic_3_mlx.duration_predictor import DurationPredictor
from supertonic_3_mlx.text_encoder import TextEncoder
from supertonic_3_mlx.vector_estimator import VectorEstimator
from supertonic_3_mlx.vocoder import Vocoder
# Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
# covers 512 × 6 = 3072 samples ≈ 69.7 ms. The pipeline divides the predicted
# duration (in samples) by this value to obtain the latent sequence length.
SAMPLES_PER_LATENT_STEP = 512 * 6 # 3072
# ── Shared ONNX → MLX weight extraction ─────────────────────────────
def _convert_onnx(onnx_path: str | Path) -> dict:
    """Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.

    Combines the three extraction stages discovered during the per-component
    ports (T.3.1, T.3.2, T.3.3):

    1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
       pwconv, head.layer2).
    2. Anonymous MatMul weights recovered via the MatMul output path.
    3. Anonymous Conv weights and PReLU slopes recovered the same way.

    All three stages write into the same dict, so when two stages produce the
    same key the later stage's tensor wins.

    Args:
        onnx_path: Path of the ``.onnx`` file to read (passed to ``onnx.load``).

    Returns:
        Mapping from cleaned parameter name to ``mx.array``.
    """
    # Imported lazily so that ``onnx`` is only needed when this converter runs.
    import onnx
    import onnx.numpy_helper as nh
    m = onnx.load(str(onnx_path))

    def _matmul_clean(out_name: str) -> str:
        # Turn a MatMul node's output name into a dotted ``<path>.weight`` key.
        p = out_name.lstrip("/")
        if p.endswith("/MatMul_output_0"):
            p = p[: -len("/MatMul_output_0")]
        # Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
        for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return p.replace("/", ".") + ".weight"

    def _conv_clean(out_name: str) -> str:
        # Turn a Conv node's output name into a ``tts.ae.<path>`` key stem;
        # the caller appends ``.weight`` / ``.bias``.
        p = out_name.lstrip("/")
        if p.endswith("/Conv_output_0"):
            p = p[: -len("/Conv_output_0")]
        for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".")

    def _prelu_clean(out_name: str) -> str:
        # Turn a PRelu node's output name into a ``tts.ae.<path>.weight`` key.
        p = out_name.lstrip("/")
        if p.endswith("/PRelu_output_0"):
            p = p[: -len("/PRelu_output_0")]
        for prefix in ("vocoder/", "vector_estimator/"):
            if p.startswith(prefix):
                p = p[len(prefix):]
                break
        return "tts.ae." + p.replace("/", ".") + ".weight"

    # Detect which model this file is — affects how we wrap named init keys
    name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
    is_text_encoder = "tts" in name_prefixes and any(
        i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
    )
    weights: dict[str, mx.array] = {}

    # Stage 1: named initialisers
    for init in m.graph.initializer:
        n = init.name
        # Determine if this is a structured (named) weight or an anonymous graph const
        if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
            continue
        # Strip the vector_estimator-specific prefix so all 4 models share a name space.
        if n.startswith("vector_estimator.tts.ttl."):
            clean = n[len("vector_estimator.tts.ttl."):]
        else:
            clean = n
        arr = nh.to_array(init)
        # Shape transforms. These run in sequence on the same ``arr``, so more
        # than one rule can fire for a single tensor and the order matters.
        # (out, 1, K) with K != 1 → (out, K, 1)
        if (clean.endswith(".dwconv.weight") and arr.ndim == 3
                and arr.shape[1] == 1 and arr.shape[2] != 1):
            arr = np.transpose(arr, (0, 2, 1))
        if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
                and arr.shape[1] == 1):
            arr = np.transpose(arr, (0, 2, 1))
        # (1, C, 1) → (C,)
        if (clean.endswith(".gamma") and arr.ndim == 3
                and arr.shape[0] == 1 and arr.shape[2] == 1):
            arr = arr.reshape(arr.shape[1])
        # Drop the trailing size-1 axis: (out, in, 1) → (out, in)
        if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
                and arr.ndim == 3 and arr.shape[-1] == 1):
            arr = arr.squeeze(-1)
        # NOTE(review): a ``.dwconv.net.weight`` name also ends with
        # ``.net.weight``; after the transpose above its trailing axis is 1, so
        # it is squeezed to 2-D here as well — confirm this is intended.
        if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
            # Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
            arr = arr.squeeze(-1)
        # vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
        if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))
        # vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
        if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
            arr = np.transpose(arr, (0, 2, 1))
        weights[clean] = mx.array(arr)

    # Stage 2: MatMul weight recovery
    inits_map = {init.name: init for init in m.graph.initializer}
    for node in m.graph.node:
        if node.op_type != "MatMul" or len(node.input) < 2:
            continue
        winp = node.input[1]
        # Only anonymous initialisers are handled here; named ones were
        # already collected in stage 1.
        if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
            continue
        arr = nh.to_array(inits_map[winp])
        if arr.ndim == 2:
            arr = arr.T # ONNX (in, out) → MLX Linear (out, in)
        clean = _matmul_clean(node.output[0])
        # Build the leading namespace from the file context (already in tts.*)
        if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
            # Only the text-encoder file gets the ``tts.ttl.`` namespace added.
            clean = "tts.ttl." + clean if is_text_encoder else clean
        weights[clean] = mx.array(arr)

    # Stage 3: anonymous Conv + PReLU (vocoder embed / head)
    for node in m.graph.node:
        if node.op_type == "Conv":
            # Input 0 is the activation; input 1 is stored as ``.weight`` and
            # any further input as ``.bias``.
            for i, inp in enumerate(node.input[1:], 1):
                if inp not in inits_map or inp.startswith("tts."):
                    continue
                arr = nh.to_array(inits_map[inp])
                base = _conv_clean(node.output[0])
                # Conv nodes whose cleaned name contains "dwconv" are skipped.
                if "dwconv" in base:
                    continue
                if i == 1 and arr.ndim == 3:
                    arr = np.transpose(arr, (0, 2, 1)) # ONNX (out, in, K) → MLX (out, K, in)
                key = base + (".weight" if i == 1 else ".bias")
                weights[key] = mx.array(arr)
        elif node.op_type == "PRelu":
            for inp in node.input[1:]:
                if inp in inits_map and not inp.startswith("tts."):
                    weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))
    return weights
def _load_into(model, weights: dict) -> int:
    """Copy every usable converted tensor into ``model``.

    A tensor is usable when its key names a parameter of ``model`` and its
    shape either equals the parameter's shape or has the same number of
    elements (in which case it is reshaped). Everything else is silently
    skipped.

    Args:
        model: Object exposing ``parameters()`` and ``load_weights()``.
        weights: Mapping of parameter name to array, e.g. from ``_convert_onnx``.

    Returns:
        The number of tensors handed to ``model.load_weights``.
    """
    from mlx.utils import tree_flatten

    target_shapes = {
        name: tuple(param.shape)
        for name, param in tree_flatten(model.parameters())
    }

    accepted = []
    for name, shape in target_shapes.items():
        if name in weights:
            tensor = weights[name]
            if tuple(tensor.shape) == shape:
                accepted.append((name, tensor))
            elif tensor.size == np.prod(shape):
                # Same element count, different layout: reshape to fit.
                accepted.append((name, tensor.reshape(shape)))

    model.load_weights(accepted, strict=False)
    return len(accepted)
# ── Tokenization ────────────────────────────────────────────────────
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
    """Encode a text string into character IDs.

    ``indexer[ord(c)]`` gives the token ID for character ``c``; a negative
    entry means "unknown". Characters that are unknown, or whose code point is
    outside the indexer, are dropped. For Phase T.4 the text is wrapped with
    no special language tokens — the ONNX SDK uses language tags but this
    pipeline currently runs unconditioned on language, so ``lang`` is accepted
    for API compatibility and not used.

    Args:
        text: Input string.
        indexer: Flat code-point → token-ID table (nominally 65536 entries).
        lang: Language code; currently ignored.

    Returns:
        1-D ``int32`` array of token IDs. It is never empty: when no character
        maps to a token, it holds the space token, or ``0`` if the indexer has
        no usable space entry.
    """
    table_size = len(indexer)
    ids = [
        tok
        for tok in (indexer[cp] for cp in map(ord, text) if cp < table_size)
        if tok >= 0
    ]
    if not ids:
        # Avoid empty input downstream. The space lookup is bounds-checked so
        # that a short or empty indexer falls back to 0 instead of raising
        # IndexError.
        space_cp = ord(" ")
        space_tok = indexer[space_cp] if space_cp < table_size else -1
        ids = [space_tok if space_tok >= 0 else 0]
    return np.asarray(ids, dtype=np.int32)
# ── Pipeline ────────────────────────────────────────────────────────
class SupertonicMLXPipeline:
    """End-to-end Supertonic 3 TTS pipeline in pure MLX.

    Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
    vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.
    Build instances with :meth:`from_pretrained`.
    """

    sample_rate: int = SAMPLE_RATE

    # Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
    # model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
    # and the combination 4×cond − 3×uncond. Any other step count or skipping
    # CFG produces an essentially uncorrelated waveform (verified by
    # ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
    # ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
    # latency further would require distilling a shorter-schedule model.
    n_euler_steps: int = 5

    def __init__(
        self,
        duration_predictor: DurationPredictor,
        text_encoder: TextEncoder,
        vector_estimator: VectorEstimator,
        vocoder: Vocoder,
        unicode_indexer: list[int],
        voice_dir: Path,
    ) -> None:
        """Wire up already-loaded sub-models and compile the hot loops.

        Args:
            duration_predictor: Predicts the utterance duration.
            text_encoder: Produces the text embedding.
            vector_estimator: Flow-matching velocity model; also supplies the
                CFG scales and the unconditional tokens.
            vocoder: Decodes the final latent into a waveform.
            unicode_indexer: Code-point → token-ID table used for tokenization.
            voice_dir: Directory holding ``<voice>.json`` style files.
        """
        self.duration_predictor = duration_predictor
        self.text_encoder = text_encoder
        self.vector_estimator = vector_estimator
        self.vocoder = vocoder
        self.unicode_indexer = unicode_indexer
        self.voice_dir = voice_dir

        # T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
        # by input shapes; the 5× CFG Euler loop and the single vocoder pass
        # both gain from fused kernel dispatch (~50–100 layer ops collapse into
        # one dispatch per cached graph).
        # T.5.3 — also pre-project text and style K/V outside the step. They
        # are invariant across the 5 Euler steps, so the 4 text_attn + 4
        # style_attn blocks no longer re-run their W_key / W_value / RoPE_K
        # matmuls on every step (saves 40 matmuls per generate).
        cond_scale = self.vector_estimator.CFG_COND_SCALE
        uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE

        def _cached_step(
            noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
        ):
            # One Euler update. The batch is doubled so the conditional and
            # unconditional branches share a single forward pass.
            noisy_2 = mx.concatenate([noisy, noisy], axis=0)
            # ``kv_flat`` layout: entries 0..7 are the four text (K, V) pairs,
            # entries 8..15 the four style (K, V) pairs.
            text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
            style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
            v_2 = self.vector_estimator.velocity_cached(
                noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
            )
            B = noisy.shape[0]
            cond_v = v_2[:B]
            uncond_v = v_2[B:]
            combined = cond_scale * cond_v - uncond_scale * uncond_v
            return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)

        def _voc_step(latent):
            return self.vocoder(latent)

        self._cached_step_compiled = mx.compile(_cached_step)
        self._voc_compiled = mx.compile(_voc_step)

        # Pick the runtime dtype from any leaf weight of the vector estimator —
        # ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
        # in which case all inputs to the compiled hot loops must be cast to
        # match (mixed-dtype Conv/MatMul is not legal in MLX).
        from mlx.utils import tree_flatten
        leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
                  if isinstance(v, mx.array)]
        self.dtype = leaves[0].dtype if leaves else mx.float32

    @staticmethod
    def _looks_like_hub_id(model_id_or_path: str | Path) -> bool:
        """Return True when the argument should be resolved on the HF Hub.

        Only a string that is not an existing path and has the
        ``namespace/name`` form qualifies. Absolute paths, ``./`` / ``../`` /
        ``~`` paths and anything with more than one ``/`` are treated as local
        paths, so a mistyped local directory is reported as missing instead of
        being sent to the Hub.
        """
        if not isinstance(model_id_or_path, str):
            return False
        if Path(model_id_or_path).exists():
            return False
        if model_id_or_path.startswith(("/", ".", "~")):
            return False
        namespace, sep, name = model_id_or_path.partition("/")
        return bool(sep and namespace and name) and "/" not in name

    @classmethod
    def from_pretrained(
        cls,
        model_id_or_path: str | Path,
        dtype: mx.Dtype | None = None,
        cache_dir: str | Path | None = None,
        revision: str | None = None,
    ) -> "SupertonicMLXPipeline":
        """Construct the pipeline from a model snapshot.

        Three sources are accepted, auto-detected:

        1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
           weights are downloaded via :func:`huggingface_hub.snapshot_download`
           into ``cache_dir`` (defaults to the standard HF cache) and loaded
           directly from the bundled ``weights/*.safetensors`` files.
        2. **Local path with a** ``weights/`` **subdir**: the MLX-native
           layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
           Fast path — no ONNX conversion at runtime.
        3. **Local path with an** ``onnx/`` **subdir**: the upstream
           ``Supertone/supertonic-3`` snapshot layout. Weights are converted
           from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
           development or when starting from the original upstream release.

        Optional kwargs:
            dtype     — if non-None and not float32, cast all weights to the
                        given dtype after load (only ``mx.bfloat16`` is
                        currently meaningful; see README "BF16 note").
            cache_dir — passed to ``huggingface_hub.snapshot_download``.
            revision  — branch / tag / commit sha on the Hub.

        Raises:
            ImportError: a Hub repo id was given but ``huggingface_hub`` is
                not installed.
            FileNotFoundError: the local path does not exist, or contains
                neither a ``weights/`` nor an ``onnx/`` subdirectory.
        """
        # 1. Resolve the local snapshot directory
        if cls._looks_like_hub_id(model_id_or_path):
            try:
                from huggingface_hub import snapshot_download
            except ImportError as e:
                raise ImportError(
                    "Loading from the Hugging Face Hub requires "
                    "``huggingface_hub`` — install with ``pip install "
                    "supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
                ) from e
            local_dir = Path(snapshot_download(
                repo_id=model_id_or_path,
                cache_dir=cache_dir,
                revision=revision,
                allow_patterns=[
                    "weights/*.safetensors",
                    "unicode_indexer.json",
                    "voice_styles/*.json",
                ],
            ))
        else:
            local_dir = Path(model_id_or_path)
            if not local_dir.exists():
                raise FileNotFoundError(
                    f"{local_dir} does not exist; pass an existing local "
                    f"snapshot directory or a Hugging Face Hub repo id of the "
                    f"form ``namespace/name``."
                )

        # 2. Detect layout
        weights_dir = local_dir / "weights"
        onnx_dir = local_dir / "onnx"
        if weights_dir.exists():
            return cls._from_safetensors(local_dir, dtype=dtype)
        if onnx_dir.exists():
            return cls._from_onnx(local_dir, dtype=dtype)
        raise FileNotFoundError(
            f"{local_dir} contains neither ``weights/`` (safetensors layout) "
            f"nor ``onnx/`` (upstream layout); cannot load."
        )

    @classmethod
    def _from_safetensors(
        cls, local_dir: Path, dtype: mx.Dtype | None = None,
    ) -> "SupertonicMLXPipeline":
        """Build the pipeline from the MLX-native layout under ``local_dir``.

        Reads ``weights/<name>.safetensors`` for each sub-model,
        ``unicode_indexer.json`` and uses ``voice_styles/`` as the voice
        directory. Casts the models when ``dtype`` is given and not float32.
        """
        from mlx.utils import tree_flatten
        weights_dir = local_dir / "weights"
        voice_dir = local_dir / "voice_styles"
        unicode_indexer = json.loads(
            (local_dir / "unicode_indexer.json").read_text(encoding="utf-8")
        )

        def _build(cls_, name):
            model = cls_()
            w = mx.load(str(weights_dir / f"{name}.safetensors"))
            # Reshape any mismatched leaves (defensive; the converter already
            # produced shape-correct tensors but a future re-export may not).
            expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
            for k in list(w.keys()):
                if k in expected and tuple(w[k].shape) != expected[k]:
                    if w[k].size == int(np.prod(expected[k])):
                        w[k] = w[k].reshape(expected[k])
            model.load_weights(list(w.items()), strict=False)
            return model

        ve = _build(VectorEstimator, "vector_estimator")
        te = _build(TextEncoder, "text_encoder")
        dp = _build(DurationPredictor, "duration_predictor")
        voc = _build(Vocoder, "vocoder")
        if dtype is not None and dtype != mx.float32:
            cls._cast_all(dp, te, ve, voc, dtype=dtype)
        return cls(dp, te, ve, voc, unicode_indexer, voice_dir)

    @classmethod
    def _from_onnx(
        cls, local_dir: Path, dtype: mx.Dtype | None = None,
    ) -> "SupertonicMLXPipeline":
        """Build the pipeline from the upstream ONNX layout under ``local_dir``.

        Converts each ``onnx/<name>.onnx`` file with ``_convert_onnx`` and
        loads the result with ``_load_into``. The tokenizer table is read from
        ``onnx/unicode_indexer.json``.
        """
        onnx_dir = local_dir / "onnx"
        voice_dir = local_dir / "voice_styles"
        unicode_indexer = json.loads(
            (onnx_dir / "unicode_indexer.json").read_text(encoding="utf-8")
        )
        ve = VectorEstimator()
        _load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
        te = TextEncoder()
        _load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
        dp = DurationPredictor()
        _load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
        voc = Vocoder()
        _load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
        if dtype is not None and dtype != mx.float32:
            cls._cast_all(dp, te, ve, voc, dtype=dtype)
        return cls(dp, te, ve, voc, unicode_indexer, voice_dir)

    @staticmethod
    def _cast_all(*models, dtype: mx.Dtype) -> None:
        """Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
        from mlx.utils import tree_map

        def _cast(p):
            # Leave non-arrays and non-fp32 arrays untouched.
            if not isinstance(p, mx.array) or p.dtype != mx.float32:
                return p
            return p.astype(dtype)

        for m_ in models:
            m_.update(tree_map(_cast, m_.parameters()))

    def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
        """Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp).

        Raises:
            FileNotFoundError: no ``<voice>.json`` exists in ``voice_dir``; the
                message lists the voices that are available.
        """
        path = self.voice_dir / f"{voice}.json"
        if not path.is_file():
            available = sorted(p.stem for p in self.voice_dir.glob("*.json"))
            raise FileNotFoundError(
                f"Unknown voice {voice!r}: {path} not found. "
                f"Available voices: {available}"
            )
        data = json.loads(path.read_text(encoding="utf-8"))
        style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32)  # (1, 50, 256)
        style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32)  # (1, 8, 16)
        return mx.array(style_ttl), mx.array(style_dp)

    def generate(
        self,
        text: str,
        voice: str = "F1",
        lang: str = "en",
        seed: int = 42,
        n_steps: Optional[int] = None,
    ) -> np.ndarray:
        """Synthesise a single utterance. Returns a 1D float32 numpy waveform.

        Args:
            text: Text to speak.
            voice: Stem of a JSON file in ``voice_dir``.
            lang: Language code, forwarded to the tokenizer.
            seed: Seed for the initial noise, so output is reproducible.
            n_steps: Number of Euler steps; defaults to ``n_euler_steps``.

        Raises:
            ValueError: ``n_steps`` is smaller than 1.
            FileNotFoundError: the requested voice file does not exist.
        """
        n_steps = n_steps if n_steps is not None else self.n_euler_steps
        if n_steps < 1:
            # With zero steps the loop below never runs and the vocoder would
            # be fed the raw initial noise.
            raise ValueError(f"n_steps must be >= 1, got {n_steps}")

        # Tokenize
        text_ids_np = _encode_text(text, self.unicode_indexer, lang)
        text_ids = mx.array(text_ids_np[None, :])  # (1, T_text)
        T_text = text_ids.shape[1]
        text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)

        # Style
        style_ttl, style_dp = self._load_voice(voice)
        if self.dtype != mx.float32:
            style_ttl = style_ttl.astype(self.dtype)
            style_dp = style_dp.astype(self.dtype)

        # Duration → latent length
        duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
        mx.eval(duration_s)
        duration_val = max(float(duration_s[0].item()), 0.5)  # clamp to ≥ 0.5 s
        T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)

        # Text embedding
        text_emb = self.text_encoder(text_ids, style_ttl, text_mask)  # (1, 256, T_text)

        # Initial noise — fixed seed for reproducibility
        key = mx.random.key(seed)
        noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
        latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)

        # T.5.3 — build the (2B) CFG conditioning tensors once and pre-project
        # K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
        # ``(K, V)`` arrays flattened into a list for the compiled step.
        B = noise.shape[0]
        ve = self.vector_estimator
        text_uncond = mx.broadcast_to(
            ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
        ).astype(self.dtype)
        style_k_uncond = mx.broadcast_to(
            ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        style_v_uncond = mx.broadcast_to(
            ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
        ).astype(self.dtype)
        # First half of each doubled batch is conditional, second half is
        # unconditional, matching the split done in the compiled step.
        text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
        style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
        style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
        text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
        latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
        text_kv, style_kv = ve.precompute_cross_kv(
            text_emb_2, style_k_2, style_v_2, text_mask_2,
        )
        kv_flat = []
        for k, v in text_kv:
            kv_flat.extend([k, v])
        for k, v in style_kv:
            kv_flat.extend([k, v])

        # Euler with CFG — 5 steps by default
        x = noise
        total_step = mx.array([float(n_steps)], dtype=self.dtype)
        for step in range(n_steps):
            current_step = mx.array([float(step + 1)], dtype=self.dtype)
            t_norm = current_step / total_step
            t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
            x = self._cached_step_compiled(
                x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
            )
            mx.eval(x)

        # Decode latent → waveform
        wav = self._voc_compiled(x)
        mx.eval(wav)
        if wav.dtype != mx.float32:
            wav = wav.astype(mx.float32)
        return np.array(wav)[0]  # (T_lat × 6 × 512,)
__all__ = ["SupertonicMLXPipeline"]