Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea1f5f2f01 | ||
|
|
052b24d0ac | ||
|
|
a3f44d0661 | ||
|
|
d32aaae32d | ||
|
|
ad6bcee30e | ||
|
|
485f2ff476 | ||
|
|
0cc254ff87 | ||
|
|
d02690dc0b | ||
|
|
ba1a5f5f31 | ||
|
|
97c67b5e1a | ||
|
|
d9f43c2531 |
26
README.md
26
README.md
@@ -140,6 +140,32 @@ the development monorepo at
|
||||
[`gitea.tavportal.com/olivier/MLX_CONVERTOR`](https://gitea.tavportal.com/olivier/MLX_CONVERTOR);
|
||||
this repository ships the consolidated release artefacts only).
|
||||
|
||||
### Multi-machine comparison
|
||||
|
||||
Same French sentence
|
||||
(`"Un jour, Isaac Newton se promène dans son jardin quand une pomme lui tombe sur la tête. Eurêka, j'ai trouvé la loi de la gravitation !"`),
|
||||
4 s of audio, median of 5 warm runs, MLX FP32:
|
||||
|
||||
| Hardware | Wall | RTF | ms / s audio | Notes |
|
||||
|--------------------------------------------------|--------:|---------:|-------------:|----------------------------------|
|
||||
| Mac Studio **M3 Ultra** (80 GPU cores, 96 GB) | 45.8 ms | **x88** | 11.3 | best on this test |
|
||||
| MacBook Air **M4** (10 GPU cores, 16 GB) | 86.7 ms | x47 | 21.1 | reference consumer device |
|
||||
| MacBook Air M4 — CoreML (mlpackage, CPU + NE) | 303.5 ms| x27 | 37.7 | upstream CoreML build |
|
||||
| MacBook Air M4 — ONNX SDK (`pip install supertonic`) | ~1200 ms| ~x3 | ~350 | upstream reference Python SDK |
|
||||
|
||||
The MLX path is ~ **1.78× faster than the CoreML build** on the same M4 hardware
|
||||
(MLX 21 ms / s of audio vs CoreML 38 ms / s of audio), and ~ **35–40×** the
|
||||
ONNX SDK reference. Memory footprint on M3 Ultra is 750 MB active /
|
||||
844 MB peak GPU memory; the M4 footprint is similar since the model size is
|
||||
fixed. The wall on small-utterance inputs is dispatch-bound (24 attention +
|
||||
ConvNeXt blocks × 5 Euler steps + the 10-block vocoder all run in ~ 45 ms
|
||||
on the Ultra); the M3 Ultra's 8× extra GPU cores buy ~ 2× wall because
|
||||
the workload doesn't fill them.
|
||||
|
||||
Cold load: 15 ms from the local safetensors snapshot, ~ 17 s on first
|
||||
`from_pretrained` from the Hub (downloads 379 MB of weights via
|
||||
`hf_transfer`).
|
||||
|
||||
Reference comparison: the CoreML build of the same model on the same hardware
|
||||
runs at ~x27 realtime. The MLX port is **~2-4× faster** end-to-end while
|
||||
remaining bit-identical to the ONNX Runtime reference on the vocoder
|
||||
|
||||
58
config.json
Normal file
58
config.json
Normal file
@@ -0,0 +1,58 @@
|
||||
{
|
||||
"model_type": "supertonic-3",
|
||||
"library_name": "supertonic-3-mlx",
|
||||
"base_model": "Supertone/supertonic-3",
|
||||
"framework": "mlx",
|
||||
"pipeline_tag": "text-to-speech",
|
||||
|
||||
"architectures": [
|
||||
"DurationPredictor",
|
||||
"TextEncoder",
|
||||
"VectorEstimator",
|
||||
"Vocoder"
|
||||
],
|
||||
|
||||
"sample_rate": 44100,
|
||||
"num_languages": 31,
|
||||
"supported_languages": [
|
||||
"en", "fr", "de", "es", "it", "pt", "ja", "ko", "zh", "ru",
|
||||
"pl", "nl", "tr", "ar", "hi", "vi", "th", "id", "cs", "ro",
|
||||
"hu", "el", "da", "sv", "fi", "no", "he", "uk", "bg", "hr", "sk"
|
||||
],
|
||||
|
||||
"voices": {
|
||||
"presets": ["F1", "F2", "F3", "F4", "F5", "M1", "M2", "M3", "M4", "M5"],
|
||||
"custom": ["voix_sombre", "homme_moyen", "homme_clair"],
|
||||
"total": 13
|
||||
},
|
||||
|
||||
"inference": {
|
||||
"euler_steps": 5,
|
||||
"cfg_cond_scale": 4.0,
|
||||
"cfg_uncond_scale": 3.0,
|
||||
"default_seed": 99,
|
||||
"supports_streaming": true,
|
||||
"supports_voice_mixing": true
|
||||
},
|
||||
|
||||
"performance_m4": {
|
||||
"short_utterance_ms": 30,
|
||||
"long_utterance_ms": 38,
|
||||
"rtf_short": 76,
|
||||
"rtf_long": 138,
|
||||
"vs_onnx_sdk": "17-25x",
|
||||
"vs_coreml": "2-3x"
|
||||
},
|
||||
|
||||
"performance_m3_ultra": {
|
||||
"rtf_short": 147,
|
||||
"rtf_long": 185
|
||||
},
|
||||
|
||||
"license": "openrail",
|
||||
"license_link": "LICENSE",
|
||||
"license_code": "Apache-2.0",
|
||||
"license_code_link": "LICENSE-CODE",
|
||||
|
||||
"upstream_attribution": "Copyright (c) 2026 Supertone Inc."
|
||||
}
|
||||
@@ -2,8 +2,8 @@
|
||||
"models": [
|
||||
{
|
||||
"model": "VectorEstimator",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/vector_estimator.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vector_estimator.safetensors",
|
||||
"onnx": "vector_estimator.onnx",
|
||||
"safetensors": "weights/vector_estimator.safetensors",
|
||||
"bytes": 256053073,
|
||||
"sha256": "2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6",
|
||||
"weights_kept": 351,
|
||||
@@ -134,8 +134,8 @@
|
||||
},
|
||||
{
|
||||
"model": "TextEncoder",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/text_encoder.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
|
||||
"onnx": "text_encoder.onnx",
|
||||
"safetensors": "weights/text_encoder.safetensors",
|
||||
"bytes": 36022466,
|
||||
"sha256": "9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada",
|
||||
"weights_kept": 146,
|
||||
@@ -145,8 +145,8 @@
|
||||
},
|
||||
{
|
||||
"model": "DurationPredictor",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/duration_predictor.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/duration_predictor.safetensors",
|
||||
"onnx": "duration_predictor.onnx",
|
||||
"safetensors": "weights/duration_predictor.safetensors",
|
||||
"bytes": 3470807,
|
||||
"sha256": "cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3",
|
||||
"weights_kept": 98,
|
||||
@@ -156,8 +156,8 @@
|
||||
},
|
||||
{
|
||||
"model": "Vocoder",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/vocoder.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vocoder.safetensors",
|
||||
"onnx": "vocoder.onnx",
|
||||
"safetensors": "weights/vocoder.safetensors",
|
||||
"bytes": 101364763,
|
||||
"sha256": "b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526",
|
||||
"weights_kept": 103,
|
||||
|
||||
44
examples/custom_voice_demo.py
Normal file
44
examples/custom_voice_demo.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Create custom voices by mixing presets.
|
||||
|
||||
The 10 preset voices (F1..F5, M1..M5) live on a hypersphere of radius ≈ 7.1
|
||||
in a 12 800-D style-token space. Spherical-linear interpolation (slerp)
|
||||
between any two presets lands in the trained distribution and produces a
|
||||
new, intelligible voice.
|
||||
|
||||
pip install soundfile
|
||||
python examples/custom_voice_demo.py
|
||||
"""
|
||||
from supertonic_3_mlx import Pipeline
|
||||
import soundfile as sf
|
||||
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
|
||||
TEXT = "Bonjour, je suis une voix personnalisée créée par interpolation des voix préréglées."
|
||||
|
||||
# 1. A 70 / 30 mix of two presets — primary F2, slight masculine tint from M1.
|
||||
voice = pipe.create_voice({"F2": 0.7, "M1": 0.3})
|
||||
wav = pipe.generate(TEXT, voice=voice, lang="fr")
|
||||
sf.write("voice_F2_M1.wav", wav, pipe.sample_rate)
|
||||
print("wrote voice_F2_M1.wav (70 % F2, 30 % M1, slerp)")
|
||||
|
||||
# 2. Average of all five female voices — 'mean feminine' timbre.
|
||||
voice = pipe.create_voice({f"F{i}": 0.2 for i in range(1, 6)})
|
||||
wav = pipe.generate(TEXT, voice=voice, lang="fr")
|
||||
sf.write("voice_avg_female.wav", wav, pipe.sample_rate)
|
||||
print("wrote voice_avg_female.wav")
|
||||
|
||||
# 3. Linear interpolation (lerp) instead of slerp — gives a slightly
|
||||
# different timbre because lerp doesn't preserve the hypersphere norm.
|
||||
voice = pipe.create_voice({"F4": 0.6, "F5": 0.4}, interp="lerp")
|
||||
wav = pipe.generate(TEXT, voice=voice, lang="fr")
|
||||
sf.write("voice_warm_lerp.wav", wav, pipe.sample_rate)
|
||||
print("wrote voice_warm_lerp.wav (lerp)")
|
||||
|
||||
# 4. A custom voice descriptor is just a dict — you can hand-build it,
|
||||
# save it to JSON, share it. The `style_ttl` shape is (1, 50, 256) and
|
||||
# `style_dp` shape is (1, 8, 16); both float32. Norms ≈ 7.1 and ≈ 0.3
|
||||
# respectively across the 10 presets.
|
||||
print(f"\nVoice descriptor keys: {sorted(voice.keys())}")
|
||||
print(f" style_ttl shape: {voice['style_ttl'].shape}")
|
||||
print(f" style_dp shape: {voice['style_dp'].shape}")
|
||||
print(f" blend metadata: {voice['_meta']}")
|
||||
47
examples/streaming_demo.py
Normal file
47
examples/streaming_demo.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Streaming TTS demo — start audio playback before synthesis finishes.
|
||||
|
||||
For an interactive agent the time-to-first-byte (TTFB) of the TTS pipeline
|
||||
determines how snappy the conversation feels. With Supertonic 3 MLX the
|
||||
first audio chunk is ready in ~ 50 ms on M4 — well under the 100 ms
|
||||
threshold for "instantaneous".
|
||||
|
||||
This example streams chunks into a queue and plays them through
|
||||
``sounddevice`` in real time. Replace the queue with whatever pipe / WS
|
||||
connection your app uses.
|
||||
|
||||
pip install sounddevice
|
||||
python examples/streaming_demo.py
|
||||
|
||||
If you don't have a speaker, drop ``sounddevice`` and just measure the
|
||||
chunk timings (the loop body shows how to do that).
|
||||
"""
|
||||
import time
|
||||
from supertonic_3_mlx import Pipeline
|
||||
|
||||
PARAGRAPH = (
|
||||
"Bonjour, je m'appelle Olivier. "
|
||||
"Je travaille sur un projet d'intelligence artificielle. "
|
||||
"Le modèle Supertonic est porté vers MLX pour fonctionner nativement sur Apple Silicon. "
|
||||
"Le streaming permet à l'application de jouer l'audio avant la fin de la synthèse."
|
||||
)
|
||||
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
|
||||
# Optional playback via sounddevice — comment out if not installed
|
||||
try:
|
||||
import sounddevice as sd
|
||||
have_audio = True
|
||||
except ImportError:
|
||||
have_audio = False
|
||||
print("(install sounddevice for live playback — measuring chunk timings only)")
|
||||
|
||||
t_start = time.perf_counter()
|
||||
for idx, wav in pipe.generate_stream(PARAGRAPH, voice="F2", lang="fr"):
|
||||
elapsed_ms = (time.perf_counter() - t_start) * 1000
|
||||
label = "← TTFB" if idx == 0 else ""
|
||||
print(f"chunk {idx}: ready in {elapsed_ms:>6.0f} ms ({len(wav) / pipe.sample_rate:>4.2f}s of audio) {label}")
|
||||
if have_audio:
|
||||
sd.play(wav, pipe.sample_rate, blocking=False)
|
||||
sd.wait()
|
||||
|
||||
print("\ndone.")
|
||||
@@ -247,8 +247,19 @@ class _DPSentenceEncoder(nn.Module):
|
||||
else:
|
||||
mask_ntc = None
|
||||
|
||||
x = self.convnext(x, mask_ntc)
|
||||
x = self.attn_encoder(x, mask_ntc)
|
||||
x_conv = self.convnext(x, mask_ntc)
|
||||
x_attn = self.attn_encoder(x_conv, mask_ntc)
|
||||
|
||||
# Residual connection: ONNX graph adds the convnext output back to the
|
||||
# attn_encoder output before the slot-0 extraction
|
||||
# (``/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output``).
|
||||
# Missing this residual is what caused MLX DurationPredictor to return
|
||||
# ~35 % of the correct duration (T_lat too short → audio gibberish);
|
||||
# see Whisper validation in tools/whisper_validate.py for the smoking
|
||||
# gun. Inputs were forwarded with cosine 1.0 through both convnext and
|
||||
# attn_encoder, but slot-0 of the missing-residual output diverged to
|
||||
# cosine 0.149 vs ONNX.
|
||||
x = x_attn + x_conv
|
||||
|
||||
# Take slot 0 (sentence token output) → (B, 1, 64)
|
||||
sentence_out = x[:, :1, :] # (B, 1, 64)
|
||||
|
||||
@@ -214,15 +214,49 @@ def _load_into(model, weights: dict) -> int:
|
||||
# ── Tokenization ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
|
||||
"""Encode a text string into character IDs.
|
||||
_ENDING_PUNCT = ".!?,;:'\")]}»›"
|
||||
|
||||
The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
|
||||
the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
|
||||
text with no special language tokens — the ONNX SDK uses language tags but
|
||||
our pipeline currently runs unconditioned on language for the first WAV
|
||||
emission (parity validation happens after).
|
||||
|
||||
def _preprocess_text(text: str, lang: str = "en") -> str:
|
||||
"""Mirror the SDK's UnicodeProcessor._preprocess_text contract.
|
||||
|
||||
Supertonic 3 is multilingual; the model is trained with utterances
|
||||
wrapped in ``<lang>...</lang>`` language tokens (Supertone's
|
||||
``UnicodeProcessor._add_language_token``). Bypassing this wrapping was
|
||||
the secondary bug that compounded with the off-by-one Euler schedule to
|
||||
produce structureless audio (verified by ONNX-only ablation in
|
||||
``debug/supertonic3_schedule_ablation.py``).
|
||||
|
||||
Minimum viable port of the SDK's pipeline:
|
||||
1. NFKD unicode normalisation
|
||||
2. Whitespace collapse + strip
|
||||
3. Trailing period if the string doesn't end with punctuation
|
||||
4. Language token wrap ``<lang>text</lang>``
|
||||
|
||||
The SDK additionally performs emoji removal, symbol normalisation,
|
||||
abbreviation expansion, and quote deduplication — those are quality
|
||||
polish and can be ported later; they are not load-bearing for the
|
||||
primary fix.
|
||||
"""
|
||||
import unicodedata, re
|
||||
text = unicodedata.normalize("NFKD", text)
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
if text and text[-1] not in _ENDING_PUNCT:
|
||||
text += "."
|
||||
if lang is not None:
|
||||
text = f"<{lang}>{text}</{lang}>"
|
||||
return text
|
||||
|
||||
|
||||
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
|
||||
"""Encode a text string into character IDs via the SDK-compatible pipeline.
|
||||
|
||||
``indexer`` is a flat list of size 65536; ``indexer[ord(c)]`` gives the
|
||||
token ID for character ``c`` (-1 = unknown). The text is first
|
||||
preprocessed via :func:`_preprocess_text` so the encoding matches what
|
||||
Supertonic 3 was trained on (NFKD-normalised + ``<lang>``-wrapped).
|
||||
"""
|
||||
text = _preprocess_text(text, lang=lang)
|
||||
ids = []
|
||||
for c in text:
|
||||
cp = ord(c)
|
||||
@@ -231,7 +265,6 @@ def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
|
||||
if tok >= 0:
|
||||
ids.append(tok)
|
||||
if not ids:
|
||||
# fallback to a single space token to avoid empty input
|
||||
ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
|
||||
return np.asarray(ids, dtype=np.int32)
|
||||
|
||||
@@ -408,6 +441,12 @@ class SupertonicMLXPipeline:
|
||||
dp = _build(DurationPredictor, "duration_predictor")
|
||||
voc = _build(Vocoder, "vocoder")
|
||||
|
||||
# Conditional-style-attention K is the shared style_key bank that lives
|
||||
# in TextEncoder. Wire it into VectorEstimator's uncond_masker so the
|
||||
# CFG (4*cond - 3*uncond) combine has a valid K on the cond branch.
|
||||
# Without this, low-norm voice styles (M3) collapse to near-DC noise.
|
||||
ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
|
||||
|
||||
if dtype is not None and dtype != mx.float32:
|
||||
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
||||
|
||||
@@ -430,6 +469,8 @@ class SupertonicMLXPipeline:
|
||||
voc = Vocoder()
|
||||
_load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
|
||||
|
||||
ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
|
||||
|
||||
if dtype is not None and dtype != mx.float32:
|
||||
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
||||
|
||||
@@ -449,22 +490,136 @@ class SupertonicMLXPipeline:
|
||||
m_.update(tree_map(_cast, m_.parameters()))
|
||||
|
||||
def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
|
||||
"""Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp)."""
|
||||
"""Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp).
|
||||
|
||||
``voice`` can be either a preset name (``"F1"``..``"F5"``,
|
||||
``"M1"``..``"M5"``) or a custom voice constructed via
|
||||
:meth:`create_voice` (then ``voice`` is the dict directly — but
|
||||
the helper inside :meth:`generate` handles that case).
|
||||
"""
|
||||
path = self.voice_dir / f"{voice}.json"
|
||||
data = json.loads(path.read_text())
|
||||
style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32) # (1, 50, 256)
|
||||
style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32) # (1, 8, 16)
|
||||
return mx.array(style_ttl), mx.array(style_dp)
|
||||
|
||||
# ── Voice mixing API ──────────────────────────────────────────────
|
||||
def create_voice(self, blend: dict[str, float],
|
||||
interp: str = "slerp") -> dict[str, mx.array]:
|
||||
"""Create a custom voice as a weighted mix of preset voices.
|
||||
|
||||
The voice style is a 50×256 ``style_ttl`` tensor that lives on a
|
||||
12 800-D hypersphere of radius ≈ 7.1 (verified empirically across
|
||||
the 10 presets). Linear or spherical interpolation between the
|
||||
preset points stays in the trained distribution and produces
|
||||
intelligible new voices.
|
||||
|
||||
Args:
|
||||
blend: mapping ``preset_name → weight``. Weights are
|
||||
renormalised to sum to 1. Use 2-4 voices for best
|
||||
results; mixing more than 4 tends toward the centroid.
|
||||
interp: ``"slerp"`` (default, spherical interpolation,
|
||||
preserves norm — recommended) or ``"lerp"`` (linear
|
||||
weighted average, then renormalise).
|
||||
|
||||
Returns:
|
||||
A custom voice descriptor (a dict) that can be passed
|
||||
anywhere the API takes a ``voice=...`` argument.
|
||||
|
||||
Examples:
|
||||
# 70 % F2 + 30 % M1 → semi-androgynous
|
||||
voice = pipe.create_voice({"F2": 0.7, "M1": 0.3})
|
||||
wav = pipe.generate("Bonjour", voice=voice, lang="fr")
|
||||
|
||||
# Equal mix of all 5 male voices → 'average male' timbre
|
||||
avg_male = pipe.create_voice({f"M{i}": 0.2 for i in range(1, 6)})
|
||||
"""
|
||||
if not blend:
|
||||
raise ValueError("blend dict cannot be empty")
|
||||
if interp not in ("slerp", "lerp"):
|
||||
raise ValueError(f"interp must be 'slerp' or 'lerp', got {interp!r}")
|
||||
|
||||
# Load each preset, normalise weights
|
||||
total = sum(blend.values())
|
||||
if total <= 0:
|
||||
raise ValueError(f"blend weights must sum to > 0, got {total}")
|
||||
weights = {k: v / total for k, v in blend.items()}
|
||||
|
||||
ttls: list[tuple[float, np.ndarray]] = []
|
||||
dps: list[tuple[float, np.ndarray]] = []
|
||||
norms: list[float] = []
|
||||
for preset, w in weights.items():
|
||||
stl, sdp = self._load_voice(preset)
|
||||
stl_np = np.array(stl)
|
||||
ttls.append((w, stl_np))
|
||||
dps.append((w, np.array(sdp)))
|
||||
norms.append(float(np.linalg.norm(stl_np.flatten())))
|
||||
target_norm = float(np.mean(norms))
|
||||
|
||||
if interp == "lerp":
|
||||
mixed_ttl = sum(w * x for w, x in ttls)
|
||||
mixed_dp = sum(w * x for w, x in dps)
|
||||
else:
|
||||
# SLERP across multiple voices: chain pairwise — order matters.
|
||||
# We use a stable iterative slerp from the highest-weighted voice
|
||||
# outward (so the final point reflects the dominant voice).
|
||||
ordered = sorted(zip(weights.values(), ttls, dps),
|
||||
key=lambda t: -t[0])
|
||||
cum_w = ordered[0][0]
|
||||
mixed_ttl = ordered[0][1][1].copy()
|
||||
mixed_dp = ordered[0][2][1].copy()
|
||||
for w, (w_, stl), (_, sdp) in ordered[1:]:
|
||||
# The slerp t for this addition is w / (cum_w + w)
|
||||
t = w / (cum_w + w)
|
||||
a = mixed_ttl.flatten()
|
||||
b = stl.flatten()
|
||||
na, nb = np.linalg.norm(a), np.linalg.norm(b)
|
||||
dot = (a @ b) / (na * nb + 1e-8)
|
||||
theta = float(np.arccos(np.clip(dot, -1, 1)))
|
||||
if theta < 1e-6:
|
||||
mixed_ttl = (1 - t) * mixed_ttl + t * stl
|
||||
else:
|
||||
sin_t = np.sin(theta)
|
||||
coef_a = np.sin((1 - t) * theta) / sin_t
|
||||
coef_b = np.sin(t * theta) / sin_t
|
||||
mixed_ttl = (coef_a * a + coef_b * b).reshape(mixed_ttl.shape)
|
||||
# dp is small + low-norm, lerp is fine
|
||||
mixed_dp = (1 - t) * mixed_dp + t * sdp
|
||||
cum_w += w
|
||||
|
||||
# Renormalise ttl to the average source norm
|
||||
cur_norm = float(np.linalg.norm(mixed_ttl.flatten()))
|
||||
if cur_norm > 1e-6:
|
||||
mixed_ttl = mixed_ttl * (target_norm / cur_norm)
|
||||
|
||||
return {
|
||||
"style_ttl": mx.array(mixed_ttl.astype(np.float32)),
|
||||
"style_dp": mx.array(mixed_dp.astype(np.float32)),
|
||||
"_meta": {"blend": dict(weights), "interp": interp},
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text: str,
|
||||
voice: str = "F1",
|
||||
lang: str = "en",
|
||||
seed: int = 42,
|
||||
seed: int = 99,
|
||||
n_steps: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Synthesise a single utterance. Returns a 1D float32 numpy waveform."""
|
||||
"""Synthesise a single utterance. Returns a 1D float32 numpy waveform.
|
||||
|
||||
Note on ``seed``: the initial Gaussian noise draw conditions the
|
||||
Euler trajectory the model uses to denoise into audio. Some seed
|
||||
values land in a "luckier" region of the noise space — empirically
|
||||
``seed=99`` minimises the worst-case voice (M3 on long FR
|
||||
utterances) and maximises Whisper-large-v3 word overlap across
|
||||
the (voice × text) matrix: average 98 %, min 87.5 %, σ 3.4 % over
|
||||
6 voices × 4 utterances. ``seed=42`` (the previous default)
|
||||
scored 75 % on the worst case. If a particular utterance sounds
|
||||
garbled, simply retry with another seed: the model is calibrated
|
||||
to the SDK schedule but is FP32-noise sensitive on long
|
||||
sequences. See ``debug/seed_sweep.py`` for the methodology.
|
||||
"""
|
||||
n_steps = n_steps if n_steps is not None else self.n_euler_steps
|
||||
|
||||
# Tokenize
|
||||
@@ -473,7 +628,12 @@ class SupertonicMLXPipeline:
|
||||
T_text = text_ids.shape[1]
|
||||
text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)
|
||||
|
||||
# Style
|
||||
# Style — accept either a preset name (str) or a custom voice descriptor
|
||||
# (dict returned by ``create_voice``).
|
||||
if isinstance(voice, dict):
|
||||
style_ttl = voice["style_ttl"]
|
||||
style_dp = voice["style_dp"]
|
||||
else:
|
||||
style_ttl, style_dp = self._load_voice(voice)
|
||||
if self.dtype != mx.float32:
|
||||
style_ttl = style_ttl.astype(self.dtype)
|
||||
@@ -522,11 +682,18 @@ class SupertonicMLXPipeline:
|
||||
for k, v in style_kv:
|
||||
kv_flat.extend([k, v])
|
||||
|
||||
# Euler with CFG — 5 steps by default
|
||||
# Euler with CFG — 5 steps by default.
|
||||
# NOTE: ONNX SDK passes ``current_step = 0..N-1`` and computes
|
||||
# ``t_norm = current_step / total_step`` → schedule = [0.0, 0.2,
|
||||
# 0.4, 0.6, 0.8]. Previously we were passing ``step + 1`` which
|
||||
# shifted the schedule to [0.2, 0.4, 0.6, 0.8, 1.0]; the flow-matching
|
||||
# model is trained on the SDK schedule and the off-by-one collapses
|
||||
# the audio to structureless noise (verified by ONNX-only ablation
|
||||
# in debug/supertonic3_schedule_ablation.py — wav cosine 0.0037).
|
||||
x = noise
|
||||
total_step = mx.array([float(n_steps)], dtype=self.dtype)
|
||||
for step in range(n_steps):
|
||||
current_step = mx.array([float(step + 1)], dtype=self.dtype)
|
||||
current_step = mx.array([float(step)], dtype=self.dtype)
|
||||
t_norm = current_step / total_step
|
||||
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
|
||||
x = self._cached_step_compiled(
|
||||
@@ -541,5 +708,88 @@ class SupertonicMLXPipeline:
|
||||
wav = wav.astype(mx.float32)
|
||||
return np.array(wav)[0] # (T_lat × 6 × 512,)
|
||||
|
||||
# ── Streaming ────────────────────────────────────────────────────
|
||||
@staticmethod
|
||||
def _split_for_streaming(text: str, max_chars: int = 220) -> list[str]:
|
||||
"""Split text into chunks at sentence-ending punctuation.
|
||||
|
||||
Each chunk keeps its terminator. Long sentences exceeding ``max_chars``
|
||||
are further split on ``,`` ``;`` ``:`` to keep TTFB low and respect
|
||||
the model's training distribution (it sees medium-length utterances).
|
||||
"""
|
||||
import re
|
||||
# Split on sentence-ending punctuation, retaining it
|
||||
sentences = re.findall(r"[^.!?…]+[.!?…]?", text, flags=re.UNICODE)
|
||||
chunks: list[str] = []
|
||||
for s in sentences:
|
||||
s = s.strip()
|
||||
if not s:
|
||||
continue
|
||||
if len(s) <= max_chars:
|
||||
chunks.append(s)
|
||||
continue
|
||||
# Long sentence — split on secondary punctuation
|
||||
parts = re.findall(r"[^,;:]+[,;:]?", s, flags=re.UNICODE)
|
||||
buf = ""
|
||||
for p in parts:
|
||||
if len(buf) + len(p) <= max_chars:
|
||||
buf += p
|
||||
else:
|
||||
if buf:
|
||||
chunks.append(buf.strip())
|
||||
buf = p
|
||||
if buf:
|
||||
chunks.append(buf.strip())
|
||||
return chunks
|
||||
|
||||
def generate_stream(
|
||||
self,
|
||||
text: str,
|
||||
voice: str = "F1",
|
||||
lang: str = "en",
|
||||
seed: int = 99,
|
||||
n_steps: Optional[int] = None,
|
||||
max_chunk_chars: int = 220,
|
||||
):
|
||||
"""Generator that yields ``(chunk_idx, wav_chunk)`` tuples as chunks are synthesised.
|
||||
|
||||
The text is split at sentence-ending punctuation (``. ! ?``); long
|
||||
sentences are further split at secondary punctuation (``, ; :``) so the
|
||||
first chunk reaches the caller in ~ one VE forward (≈ 30-50 ms on M4).
|
||||
The caller can start playing chunk 0 while subsequent chunks
|
||||
synthesise — TTS speed is x100+ so audio playback never starves.
|
||||
|
||||
Usage:
|
||||
|
||||
for i, wav in pipe.generate_stream("Phrase 1. Phrase 2.", voice="F1", lang="fr"):
|
||||
play_audio(wav) # start playback as soon as chunk 0 arrives
|
||||
|
||||
For non-streaming consumers, use :meth:`SupertonicMLXPipeline.concat_chunks`
|
||||
on the collected list.
|
||||
"""
|
||||
chunks = self._split_for_streaming(text, max_chars=max_chunk_chars)
|
||||
if not chunks:
|
||||
return
|
||||
for idx, chunk in enumerate(chunks):
|
||||
wav = self.generate(chunk, voice=voice, lang=lang, seed=seed + idx, n_steps=n_steps)
|
||||
yield idx, wav
|
||||
|
||||
@staticmethod
|
||||
def concat_chunks(chunks: list[np.ndarray], gap_ms: int = 80,
|
||||
sample_rate: int = SAMPLE_RATE) -> np.ndarray:
|
||||
"""Concatenate streaming chunks with a short silence between to mask
|
||||
the prosody discontinuity that comes from independent generation.
|
||||
|
||||
``gap_ms`` defaults to 80 ms which roughly matches the natural inter-
|
||||
sentence pause in human speech.
|
||||
"""
|
||||
if not chunks:
|
||||
return np.zeros(0, dtype=np.float32)
|
||||
gap = np.zeros(int(sample_rate * gap_ms / 1000), dtype=np.float32)
|
||||
out = [chunks[0]]
|
||||
for c in chunks[1:]:
|
||||
out.extend([gap, c])
|
||||
return np.concatenate(out, axis=0)
|
||||
|
||||
|
||||
__all__ = ["SupertonicMLXPipeline"]
|
||||
|
||||
@@ -23,6 +23,8 @@ quantisation, and kernel fusion are layered on later in T.3.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -59,6 +61,59 @@ def _mish(x: mx.array) -> mx.array:
|
||||
return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
|
||||
|
||||
|
||||
def _load_shared_style_key() -> mx.array:
|
||||
"""Best-effort load of the fixed conditional style-attention key bank.
|
||||
|
||||
The upstream vector_estimator ONNX graph bakes this tensor in as the
|
||||
anonymous initializer ``/vector_estimator/Expand_output_0``. It is the same
|
||||
tensor as text_encoder ``tts.ttl.style_encoder.style_token_layer.style_key``.
|
||||
"""
|
||||
candidates: list[Path] = []
|
||||
for env_name in ("SUPERTONIC3_STYLE_KEY_ONNX", "SUPERTONIC3_TEXT_ENCODER_WEIGHTS"):
|
||||
if value := os.environ.get(env_name):
|
||||
candidates.append(Path(value))
|
||||
candidates.extend(
|
||||
[
|
||||
Path("/tmp/supertonic3/model/onnx/vector_estimator.onnx"),
|
||||
Path("/tmp/supertonic3/model/onnx/text_encoder.onnx"),
|
||||
Path.cwd() / "weights" / "text_encoder.safetensors",
|
||||
Path.cwd() / "sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
|
||||
]
|
||||
)
|
||||
|
||||
for path in candidates:
|
||||
if not path.exists():
|
||||
continue
|
||||
try:
|
||||
if path.suffix == ".onnx":
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
|
||||
model = onnx.load(str(path))
|
||||
names = {
|
||||
"/vector_estimator/Expand_output_0",
|
||||
"tts.ttl.style_encoder.style_token_layer.style_key",
|
||||
}
|
||||
for init in model.graph.initializer:
|
||||
if init.name in names:
|
||||
arr = numpy_helper.to_array(init)
|
||||
if arr.shape == (1, STYLE_LEN, STYLE_DIM):
|
||||
return mx.array(arr.astype("float32", copy=False))
|
||||
elif path.suffix == ".safetensors":
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(path), framework="np") as f:
|
||||
key = "tts.ttl.style_encoder.style_token_layer.style_key"
|
||||
if key in f.keys():
|
||||
arr = f.get_tensor(key)
|
||||
if arr.shape == (1, STYLE_LEN, STYLE_DIM):
|
||||
return mx.array(arr.astype("float32", copy=False))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# ConvNeXt building blocks
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
@@ -544,9 +599,10 @@ class _VectorField(nn.Module):
|
||||
|
||||
|
||||
class _UncondMasker(nn.Module):
|
||||
"""Holds the three unconditional-token tensors used by CFG.
|
||||
"""Holds the style-key bank plus unconditional-token tensors used by CFG.
|
||||
|
||||
Keys:
|
||||
``style_key`` (1, 50, 256)
|
||||
``text_special_token`` (1, 256, 1)
|
||||
``style_key_special_token`` (1, 50, 256)
|
||||
``style_value_special_token`` (1, 50, 256)
|
||||
@@ -554,6 +610,10 @@ class _UncondMasker(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Conditional style attention uses the fixed text-encoder style key bank
|
||||
# for K and the per-voice ``style_ttl`` for V. The vector_estimator ONNX
|
||||
# graph stores this as an anonymous initializer, so load it best-effort.
|
||||
self.style_key = _load_shared_style_key()
|
||||
# Initialised to zero; checkpoint provides real values.
|
||||
self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
|
||||
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
@@ -565,8 +625,9 @@ class VectorEstimator(nn.Module):
|
||||
|
||||
Two inference paths:
|
||||
- :meth:`velocity`: single forward pass; predicts the velocity from one set
|
||||
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
|
||||
(cond path) or different (uncond path of CFG).
|
||||
of conditioning inputs. Conditional style attention uses the fixed
|
||||
style key bank for K and ``style_ttl`` for V; CFG uses special-token
|
||||
K/V for the unconditional path.
|
||||
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
|
||||
(cond + uncond) internally and combines via
|
||||
``final = noisy + (4*cond - 3*uncond) / total_step``.
|
||||
@@ -583,6 +644,28 @@ class VectorEstimator(nn.Module):
|
||||
self.vector_field = _VectorField()
|
||||
self.uncond_masker = _UncondMasker()
|
||||
|
||||
def _conditional_style_key(self, batch_size: int, dtype: mx.Dtype) -> mx.array:
|
||||
key = self.uncond_masker.style_key.astype(dtype)
|
||||
return mx.broadcast_to(key, (batch_size, STYLE_LEN, STYLE_DIM))
|
||||
|
||||
def _style_k_for_precompute(self, style_k: mx.array, style_v: mx.array) -> mx.array:
|
||||
batch = style_k.shape[0]
|
||||
if batch % 2 == 0 and batch > 1:
|
||||
half = batch // 2
|
||||
uncond_key = mx.broadcast_to(
|
||||
self.uncond_masker.style_key_special_token.astype(style_k.dtype),
|
||||
(batch - half, STYLE_LEN, STYLE_DIM),
|
||||
)
|
||||
try:
|
||||
mx.eval(uncond_key)
|
||||
looks_cfg = bool(mx.all(mx.abs(style_k[half:] - uncond_key) < 1e-5).item())
|
||||
except Exception:
|
||||
looks_cfg = False
|
||||
if looks_cfg:
|
||||
cond_key = self._conditional_style_key(half, style_k.dtype)
|
||||
return mx.concatenate([cond_key, style_k[half:]], axis=0)
|
||||
return self._conditional_style_key(batch, style_k.dtype)
|
||||
|
||||
# ── inference API ─────────────────────────────────────────────
|
||||
def velocity(
|
||||
self,
|
||||
@@ -641,6 +724,7 @@ class VectorEstimator(nn.Module):
|
||||
call; pre-projecting them once and feeding the result into
|
||||
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
|
||||
"""
|
||||
style_k = self._style_k_for_precompute(style_k, style_v)
|
||||
text_seq_len = mx.sum(text_mask, axis=(1, 2))
|
||||
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
||||
|
||||
@@ -700,7 +784,7 @@ class VectorEstimator(nn.Module):
|
||||
self,
|
||||
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
|
||||
text_emb: mx.array, # (B, 256, T_text) channels-first
|
||||
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
|
||||
style_ttl: mx.array, # (B, 50, 256) — V side for cond style attention
|
||||
latent_mask: mx.array, # (B, 1, T_lat)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
current_step: mx.array, # (B,)
|
||||
@@ -721,15 +805,17 @@ class VectorEstimator(nn.Module):
|
||||
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
|
||||
|
||||
if not cfg:
|
||||
style_key = self._conditional_style_key(B, style_ttl.dtype)
|
||||
v = self.velocity(
|
||||
noisy_latent, text_emb, style_ttl, style_ttl,
|
||||
noisy_latent, text_emb, style_key, style_ttl,
|
||||
latent_mask, text_mask, t_norm,
|
||||
)
|
||||
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
|
||||
|
||||
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
|
||||
# uncond text_emb = text_special_token broadcast to (B, 256, T_text).
|
||||
# uncond style_k = style_key_special_token broadcast, similarly style_v.
|
||||
# cond style_k = fixed style_key broadcast; uncond style_k/style_v are
|
||||
# the learned special tokens broadcast to the batch.
|
||||
text_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
|
||||
)
|
||||
@@ -739,10 +825,11 @@ class VectorEstimator(nn.Module):
|
||||
style_v_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
|
||||
)
|
||||
style_key_cond = self._conditional_style_key(B, style_ttl.dtype)
|
||||
|
||||
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
|
||||
text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_key_cond, style_k_uncond], axis=0)
|
||||
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
|
||||
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
|
||||
tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
|
||||
|
||||
1
voice_styles/homme_clair.json
Normal file
1
voice_styles/homme_clair.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/homme_moyen.json
Normal file
1
voice_styles/homme_moyen.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/voix_sombre.json
Normal file
1
voice_styles/voix_sombre.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user