Files
supertonic-3-mlx/src/supertonic_3_mlx/text_encoder.py
transcrilive 12dbf4a821 v0.1.0 — initial release
MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the
full flow-matching + classifier-free-guidance pipeline at ~x100 realtime
on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and
cosine 0.98 vs the upstream ONNX Runtime reference.

Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx
and auto-downloaded on first use; this repository ships the port code,
the model card, audio samples, and a zero-config setup_and_test.sh.

Install:
    pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git

Quick test:
    git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
    cd supertonic-3-mlx && ./setup_and_test.sh

Licenses (dual): model weights = BigScience Open RAIL-M (Section 4
propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 09:17:05 +02:00

383 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Supertonic 3 text encoder MLX port.
Pipeline (every stage runs channels-last, NTC; only the final output is transposed to channels-first):
text_ids [B, T_text] int64 character IDs
→ char_embedder (Embedding 8322→256) [B, T_text, 256]
→ 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
→ 4× attn_encoder block:
RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
→ speech_prompted_text_encoder:
cross-attn1: text (Q) × style_ttl (K, V) → text features
cross-attn2: text (Q) × style_ttl (K, V) → text features
norm
→ output text_emb [B, 256, T_text] (channels-first to match vector_estimator)
Inputs:
text_ids: (B, T_text) int — character indices
style_ttl: (B, 50, 256) float — style token bank
text_mask: (B, 1, T_text) float — 1.0 where valid, 0.0 where padded
Submodule naming matches the ONNX initializer keys exactly so that
``model.load_weights(...)`` succeeds with no remapping.
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
from supertonic_3_mlx._config import EPS_LN
from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
from supertonic_3_mlx.vector_estimator import (
ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
)
# Vocab + dims (frozen by checkpoint)
# Number of rows in the character embedding table.
VOCAB_SIZE = 8322
# Channel width shared by every text-encoder stage in this module.
TE_DIM = 256
# Hidden width of the pointwise MLP inside each ConvNeXt block.
TE_CONVNEXT_HIDDEN = 1024
# Depthwise kernel size of the ConvNeXt blocks.
TE_CONVNEXT_K = 5
# Equals len(TE_CONVNEXT_DILATIONS); the stack itself is built by iterating
# the dilation tuple below.
TE_CONVNEXT_NUM_LAYERS = 6
# One dilation per ConvNeXt block, in stacking order.
TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)
# Number of (self-attention + FFN) layers in the attention encoder.
TE_ATTN_NUM_LAYERS = 4
# Self-attention head count; the head width is derived from it.
TE_ATTN_HEADS = 4
TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS # 64
# Hidden width of the attention-encoder feed-forward layers.
TE_FFN_HIDDEN = 1024
class TextConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block sized for the text encoder (dim=256, hidden=1024).

    Forward order on a channels-last input: ``_pad_sym_edge`` padding →
    depthwise dilated Conv1d → ``WrappedNorm`` → Linear 256→1024 →
    ``_gelu_exact`` → Linear 1024→256 → per-channel ``gamma`` scale →
    residual add → optional mask multiply.

    The layout mirrors ``vector_estimator.ConvNeXtBlock``; it is re-declared
    here with the text-encoder constants baked in so this module stays
    self-contained.

    Args:
        dilation: Dilation of the depthwise convolution.
    """

    def __init__(self, dilation: int = 1) -> None:
        super().__init__()
        self.dim = TE_DIM
        self.dilation = dilation
        # Amount handed to ``_pad_sym_edge``: half of dilation * (K - 1), so
        # the unpadded depthwise conv below returns the input length.
        self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
        # padding=0 because the sequence is padded explicitly in __call__.
        self.dwconv = nn.Conv1d(
            TE_DIM,
            TE_DIM,
            kernel_size=TE_CONVNEXT_K,
            padding=0,
            dilation=dilation,
            groups=TE_DIM,
            bias=True,
        )
        self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
        self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
        self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
        # Zero placeholder; the real per-channel scale comes from the checkpoint.
        self.gamma = mx.zeros((TE_DIM,))

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        """Apply the block to ``x`` of shape (B, T_text, 256).

        Args:
            x: Channels-last features.
            mask: Optional array multiplied into the output (the caller in
                this file passes a (B, T_text, 1) validity mask).

        Returns:
            ``x + gamma * branch(x)``, multiplied by ``mask`` when given.
        """
        hidden = self.dwconv(_pad_sym_edge(x, self.pad))
        hidden = self.pwconv1(self.norm(hidden))
        hidden = self.pwconv2(_gelu_exact(hidden))
        result = x + hidden * self.gamma
        return result if mask is None else result * mask
class TextConvNeXtStack(nn.Module):
    """Sequential stack of text ConvNeXt blocks, one per configured dilation.

    The list attribute is itself named ``convnext`` so that, under a parent
    attribute of the same name, parameters resolve as
    ``convnext.convnext.[0..5].X``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.convnext = list(map(TextConvNeXtBlock, TE_CONVNEXT_DILATIONS))

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        """Run ``x`` through every block in order, forwarding ``mask`` to each."""
        out = x
        for block in self.convnext:
            out = block(out, mask)
        return out
class _ConvLayer(nn.Module):
    """Kernel-size-1 Conv1d holding bare ``weight`` and ``bias`` parameters.

    The attn_encoder uses Conv1d k=1 rather than nn.Linear for its
    projections; with k=1 the two are equivalent per time step.

    The weight is allocated as ``(out_dim, 1, in_dim)``, i.e. the
    ``(C_out, K, C_in)`` layout that ``mx.conv1d`` consumes on channels-last
    input.
    NOTE(review): ONNX stores Conv weights as ``(out, in, 1)``; presumably the
    checkpoint conversion moves the kernel axis before loading — confirm.

    Args:
        in_dim: Input channel count.
        out_dim: Output channel count.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        weight_shape = (out_dim, 1, in_dim)  # (C_out, K=1, C_in)
        self.weight = mx.zeros(weight_shape)
        self.bias = mx.zeros((out_dim,))

    def __call__(self, x: mx.array) -> mx.array:
        """Project channels-last ``x`` (B, T, in_dim) to (B, T, out_dim)."""
        projected = mx.conv1d(x, self.weight, stride=1, padding=0)
        return projected + self.bias
# Half-width of the relative-position embedding table used by
# RelPosSelfAttention (its emb_rel_k / emb_rel_v have 2 * window + 1 rows).
REL_POS_WINDOW = 4 # rel_pos table size = 2*4 + 1 = 9
def _rel_to_abs(x: mx.array) -> mx.array:
    """Re-index relative-position scores onto absolute key positions.

    Implements the VITS shifted-skew reshape: [B, h, L, 2L-1] → [B, h, L, L].
    Output entry ``[i, j]`` is input entry ``[i, j - i + L - 1]``.

    Args:
        x: Scores indexed by (query position, relative offset).

    Returns:
        Scores indexed by (query position, key position).
    """
    batch, heads, length, _ = x.shape
    dtype = x.dtype

    # Append one zero column so every row has width 2L.
    col_pad = mx.zeros((batch, heads, length, 1), dtype=dtype)
    widened = mx.concatenate([x, col_pad], axis=-1)

    # Flatten the rows and append L-1 zeros. Re-folding into rows of width
    # 2L-1 then shifts each successive row by one element — the skew.
    tail_pad = mx.zeros((batch, heads, length - 1), dtype=dtype)
    flat = mx.concatenate(
        [widened.reshape(batch, heads, 2 * length * length), tail_pad],
        axis=-1,
    )
    skewed = flat.reshape(batch, heads, length + 1, 2 * length - 1)

    # Drop the overflow row and the leading L-1 columns.
    return skewed[:, :, :length, length - 1:]
def _abs_to_rel(x: mx.array) -> mx.array:
    """Re-index absolute attention weights onto relative offsets.

    [B, h, L, L] → [B, h, L, 2L-1]; the inverse skew of ``_rel_to_abs``.
    Input entry ``[i, j]`` lands at output ``[i, j - i + L - 1]``; slots with
    no source are zero.

    Args:
        x: Weights indexed by (query position, key position).

    Returns:
        Weights indexed by (query position, relative offset).
    """
    batch, heads, length, _ = x.shape
    dtype = x.dtype

    # Widen each row to 2L-1 with trailing zeros.
    right_pad = mx.zeros((batch, heads, length, length - 1), dtype=dtype)
    widened = mx.concatenate([x, right_pad], axis=-1)

    # Flatten, then prepend L zeros so that re-folding into rows of width 2L
    # shifts each successive row one step to the left.
    lead_pad = mx.zeros((batch, heads, length), dtype=dtype)
    flat = mx.concatenate(
        [lead_pad, widened.reshape(batch, heads, length * (2 * length - 1))],
        axis=-1,
    )
    skewed = flat.reshape(batch, heads, length, 2 * length)

    # The first column is padding only.
    return skewed[:, :, :, 1:]
def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
    """Fit a relative-embedding table to a sequence of ``length`` positions.

    ``rel`` (1, 2W+1, d) → (1, 2L-1, d). When the sequence needs more offsets
    than the table holds, the table is zero-padded symmetrically; otherwise
    the centred 2L-1 rows are sliced out.

    Args:
        rel: Embedding table with 2 * window + 1 rows on axis 1.
        length: Sequence length L.
        window: Half-width W of the table.

    Returns:
        The padded or sliced table with 2L-1 rows.
    """
    # Positive: rows missing on each side. Non-positive: surplus rows to skip.
    shortfall = length - (window + 1)
    if shortfall > 0:
        filler = mx.zeros((1, shortfall, rel.shape[-1]), dtype=rel.dtype)
        table = mx.concatenate([filler, rel, filler], axis=1)
        first = 0
    else:
        table = rel
        first = -shortfall
    return table[:, first: first + 2 * length - 1]
class RelPosSelfAttention(nn.Module):
    """VITS-style multi-head self-attention with relative positions (window=4).

    On top of ordinary scaled dot-product attention it adds:
      - a key-side term: ``q @ rel_k.T`` passed through ``_rel_to_abs`` and
        added to the attention logits;
      - a value-side term: ``_abs_to_rel(softmax(logits)) @ rel_v`` added to
        the attention output.

    Parameters held (per layer):
      ``conv_q/k/v/o`` — ``_ConvLayer`` 256→256 projections (weight + bias)
      ``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64) — shared across heads
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
        self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
        self.window = REL_POS_WINDOW
        table_shape = (1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM)
        self.emb_rel_k = mx.zeros(table_shape)
        self.emb_rel_v = mx.zeros(table_shape)

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        """Attend over ``x`` of shape (B, T, 256).

        Args:
            x: Channels-last features.
            mask: Optional validity mask indexed as ``mask[:, :, 0]`` over the
                time axis; positions where it is zero are excluded as keys.

        Returns:
            Attention output of shape (B, T, 256) after the ``conv_o``
            projection.
        """
        batch, seq_len, _ = x.shape
        heads, head_dim = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM

        def split_heads(t: mx.array) -> mx.array:
            # (B, T, H*D) → (B, H, T, D)
            return t.reshape(batch, seq_len, heads, head_dim).transpose(0, 2, 1, 3)

        q = split_heads(self.conv_q(x))
        k = split_heads(self.conv_k(x))
        v = split_heads(self.conv_v(x))
        scale = head_dim ** -0.5

        # Content-content logits, (B, H, T, T).
        scores = (q @ k.transpose(0, 1, 3, 2)) * scale

        # Content-position logits: computed per relative offset (B, H, T, 2T-1),
        # then re-indexed to absolute key positions.
        rel_k = _slice_rel_emb(self.emb_rel_k, seq_len, self.window)
        rel_scores = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]
        scores = scores + _rel_to_abs(rel_scores * scale)

        if mask is not None:
            # Broadcast the per-key validity to (B, 1, 1, T).
            key_mask = mask[:, :, 0][:, None, None, :]
            # Large negative fill (not -inf) for masked keys.
            fill = mx.array(-1e4, dtype=scores.dtype)
            scores = mx.where(key_mask.astype(mx.bool_), scores, fill)

        weights = mx.softmax(scores, axis=-1)  # (B, H, T, T)

        # Weighted values plus the relative-position value contribution.
        context = weights @ v  # (B, H, T, D)
        rel_v = _slice_rel_emb(self.emb_rel_v, seq_len, self.window)
        context = context + _abs_to_rel(weights) @ rel_v[:, None, :, :]

        # (B, H, T, D) → (B, T, H*D)
        merged = context.transpose(0, 2, 1, 3).reshape(
            batch, seq_len, heads * head_dim
        )
        return self.conv_o(merged)
class FFN(nn.Module):
    """Position-wise feed-forward: conv_1 (256→1024) → ReLU → conv_2 (1024→256).

    Both projections are k=1 ``_ConvLayer`` wrappers. The activation is ReLU
    (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``), not GELU.
    When a mask is given it is multiplied in before each projection and once
    more on the result, to match the ONNX semantics.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
        self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        """Apply the feed-forward block to ``x``, zeroing masked positions."""

        def masked(t: mx.array) -> mx.array:
            # Identity when no mask was supplied.
            return t if mask is None else t * mask

        hidden = self.conv_1(masked(x))
        hidden = mx.maximum(hidden, mx.array(0.0, dtype=hidden.dtype))
        return masked(self.conv_2(masked(hidden)))
class AttnEncoder(nn.Module):
    """Four post-norm encoder layers.

    Each layer computes ``x = norm1(x + self_attn(x))`` followed by
    ``x = norm2(x + ffn(x))``. The sub-modules are kept in four parallel lists
    (``attn_layers``, ``norm_layers_1``, ``ffn_layers``, ``norm_layers_2``),
    each of length ``TE_ATTN_NUM_LAYERS``.
    """

    def __init__(self) -> None:
        super().__init__()
        depth = TE_ATTN_NUM_LAYERS
        self.attn_layers = [RelPosSelfAttention() for _ in range(depth)]
        self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(depth)]
        self.ffn_layers = [FFN() for _ in range(depth)]
        self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(depth)]

    def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
        """Run ``x`` through every layer, forwarding ``mask`` to attention and FFN."""
        layers = zip(
            self.attn_layers,
            self.norm_layers_1,
            self.ffn_layers,
            self.norm_layers_2,
        )
        for attn, norm_attn, ffn, norm_ffn in layers:
            x = norm_attn(x + attn(x, mask=mask))
            x = norm_ffn(x + ffn(x, mask))
        return x
class _TextEmbedder(nn.Module):
    """Character-ID embedding table (VOCAB_SIZE → TE_DIM).

    The single child is named ``char_embedder`` so its parameter loads as
    ``char_embedder.weight (8322, 256)``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)

    def __call__(self, text_ids: mx.array) -> mx.array:
        """Look up one embedding vector per character ID."""
        embedded = self.char_embedder(text_ids)
        return embedded
class _InnerTextEncoder(nn.Module):
    """Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``.

    Chains the character embedding, the ConvNeXt stack, and the attention
    encoder, all on channels-last (B, T, 256) features.
    """

    def __init__(self) -> None:
        super().__init__()
        self.text_embedder = _TextEmbedder()
        self.convnext = TextConvNeXtStack()
        self.attn_encoder = AttnEncoder()

    def __call__(self, text_ids: mx.array, mask: mx.array | None = None) -> mx.array:
        """Encode character IDs into contextual text features.

        Args:
            text_ids: (B, T) integer character indices.
            mask: Optional validity mask, broadcastable against (B, T, 256)
                (the caller in this file passes (B, T, 1)). ``None`` disables
                masking in every stage; the signature now says so, matching
                the ``mask is not None`` guard below and the optional ``mask``
                accepted by the ConvNeXt stack and the attention encoder.

        Returns:
            (B, T, 256) text features.
        """
        x = self.text_embedder(text_ids)  # (B, T, 256)
        if mask is not None:
            # Zero the embeddings of padded positions before any mixing.
            x = x * mask
        x = self.convnext(x, mask)
        return self.attn_encoder(x, mask)
class _StyleEncoder(nn.Module):
    """Parameter holder for ``style_token_layer.style_key`` (1, 50, 256).

    Defines no forward pass; in the code visible here the array is only
    allocated so the checkpoint key has somewhere to load.
    """

    def __init__(self) -> None:
        super().__init__()

        # A nested module (rather than a bare attribute) gives the array the
        # two-level parameter path ``style_token_layer.style_key``.
        class _StyleTokenLayer(nn.Module):
            def __init__(layer) -> None:
                super().__init__()
                layer.style_key = mx.zeros((1, 50, 256))

        self.style_token_layer = _StyleTokenLayer()
class _SpeechPromptedAttn(nn.Module):
    """Single-head cross-attention: text supplies Q, style_ttl supplies K and V.

    All four projections (``W_query``, ``W_key``, ``W_value``, ``out_fc``) are
    256→256 ``WrappedLinear`` layers with bias.
    """

    def __init__(self) -> None:
        super().__init__()
        self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
        self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)

    def __call__(self, x: mx.array, style: mx.array) -> mx.array:
        """Attend from text features to style tokens.

        Args:
            x: (B, T_text, 256) text features (queries).
            style: (B, 50, 256) style tokens (keys and values).

        Returns:
            (B, T_text, 256) attended features after ``out_fc``.
        """
        # Three-way unpack keeps the rank-3 requirement on ``x``.
        _, _, feat_dim = x.shape
        queries = self.W_query(x)
        keys = self.W_key(style)
        values = self.W_value(style)
        # One head, so the scaling uses the full feature width.
        scores = (queries @ keys.transpose(0, 2, 1)) * feat_dim ** -0.5
        weights = mx.softmax(scores, axis=-1)
        return self.out_fc(weights @ values)
class _SpeechPromptedTextEncoder(nn.Module):
    """Conditions text features on style_ttl via two residual cross-attentions.

    ``attention1`` then ``attention2`` are each added back onto the running
    features, and a final ``WrappedNorm`` is applied.
    """

    def __init__(self) -> None:
        super().__init__()
        self.attention1 = _SpeechPromptedAttn()
        self.attention2 = _SpeechPromptedAttn()
        self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)

    def __call__(self, x: mx.array, style: mx.array) -> mx.array:
        """Return normalised text features modulated by ``style``."""
        for attend in (self.attention1, self.attention2):
            x = x + attend(x, style)
        return self.norm(x)
class _RootTextEncoder(nn.Module):
    """Container whose children mirror the ONNX ``tts.ttl.*`` namespace.

    Children: ``text_encoder``, ``style_encoder``,
    ``speech_prompted_text_encoder``. It defines no forward pass of its own.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute names are the checkpoint key segments — do not rename.
        self.text_encoder = _InnerTextEncoder()
        self.style_encoder = _StyleEncoder()
        self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()
class _TtsContainer(nn.Module):
    """Wrapper that contributes the ``ttl`` segment of ``tts.ttl.X.Y`` weight keys."""

    def __init__(self) -> None:
        super().__init__()
        # The attribute name is the checkpoint key segment — do not rename.
        self.ttl = _RootTextEncoder()
class TextEncoder(nn.Module):
    """Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.

    Submodule naming matches the ONNX initializer keys after a single
    ``tts.ttl.`` prefix wrap (so weight keys look like
    ``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
    """

    def __init__(self) -> None:
        super().__init__()
        self.tts = _TtsContainer()

    def __call__(
        self,
        text_ids: mx.array,  # (B, T_text) int
        style_ttl: mx.array,  # (B, 50, 256)
        text_mask: mx.array | None = None,  # (B, 1, T_text)
    ) -> mx.array:
        """Encode text conditioned on style tokens.

        Args:
            text_ids: (B, T_text) integer character indices.
            style_ttl: (B, 50, 256) style token bank.
            text_mask: (B, 1, T_text) float mask, 1.0 where valid and 0.0
                where padded. ``None`` means no padding, and masking is
                skipped in every stage.

        Returns:
            Channels-first text embedding of shape (B, 256, T_text).
        """
        # Transposing unconditionally would fail on None and make the later
        # None guard dead code, so the optional mask is resolved up front.
        mask_ntc = None if text_mask is None else text_mask.transpose(0, 2, 1)  # (B, T_text, 1)

        ttl = self.tts.ttl
        x = ttl.text_encoder(text_ids, mask_ntc)
        x = ttl.speech_prompted_text_encoder(x, style_ttl)
        if mask_ntc is not None:
            # The final norm makes padded positions non-zero again; re-zero them.
            x = x * mask_ntc
        # Return channels-first (B, 256, T_text) to match the vector_estimator input.
        return x.transpose(0, 2, 1)
# Public API: the top-level module plus the two dimensions exported for callers.
__all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]