v0.1.0 — initial release
MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the full flow-matching + classifier-free-guidance pipeline at ~x100 realtime on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and cosine 0.98 vs the upstream ONNX Runtime reference. Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx and auto-downloaded on first use; this repository ships the port code, the model card, audio samples, and a zero-config setup_and_test.sh. Install: pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git Quick test: git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git cd supertonic-3-mlx && ./setup_and_test.sh Licenses (dual): model weights = BigScience Open RAIL-M (Section 4 propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
382
src/supertonic_3_mlx/text_encoder.py
Normal file
382
src/supertonic_3_mlx/text_encoder.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""Supertonic 3 text encoder MLX port.
|
||||
|
||||
Pipeline (operating in channels-last NTC after the initial conv):
|
||||
|
||||
text_ids [B, T_text] int64 character IDs
|
||||
→ char_embedder (Embedding 8322→256) [B, T_text, 256]
|
||||
→ 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
|
||||
→ 4× attn_encoder block:
|
||||
RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
|
||||
FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
|
||||
→ speech_prompted_text_encoder:
|
||||
cross-attn1: text (Q) × style_ttl (K, V) → text features
|
||||
cross-attn2: text (Q) × style_ttl (K, V) → text features
|
||||
norm
|
||||
→ output text_emb [B, 256, T_text] (channels-first to match vector_estimator)
|
||||
|
||||
Inputs:
|
||||
text_ids: (B, T_text) int — character indices
|
||||
style_ttl: (B, 50, 256) float — style token bank
|
||||
text_mask: (B, 1, T_text) float — 1.0 where valid, 0.0 where padded
|
||||
|
||||
Submodule naming matches the ONNX initializer keys exactly so that
|
||||
``model.load_weights(...)`` succeeds with no remapping.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from supertonic_3_mlx._config import EPS_LN
|
||||
from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
|
||||
from supertonic_3_mlx.vector_estimator import (
|
||||
ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
|
||||
)
|
||||
|
||||
|
||||
# Vocab + dims (frozen by checkpoint)
|
||||
VOCAB_SIZE = 8322
|
||||
TE_DIM = 256
|
||||
TE_CONVNEXT_HIDDEN = 1024
|
||||
TE_CONVNEXT_K = 5
|
||||
TE_CONVNEXT_NUM_LAYERS = 6
|
||||
TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)
|
||||
|
||||
TE_ATTN_NUM_LAYERS = 4
|
||||
TE_ATTN_HEADS = 4
|
||||
TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS # 64
|
||||
TE_FFN_HIDDEN = 1024
|
||||
|
||||
|
||||
class TextConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt for the text encoder (dim=256, hidden=1024).
|
||||
|
||||
Shares the same architecture as ``vector_estimator.ConvNeXtBlock`` but is
|
||||
redefined here with text-encoder-specific defaults to keep the modules
|
||||
self-contained.
|
||||
"""
|
||||
|
||||
def __init__(self, dilation: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.dim = TE_DIM
|
||||
self.dilation = dilation
|
||||
self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
TE_DIM, TE_DIM, kernel_size=TE_CONVNEXT_K, padding=0,
|
||||
dilation=dilation, groups=TE_DIM, bias=True,
|
||||
)
|
||||
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
|
||||
self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
|
||||
self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
|
||||
self.gamma = mx.zeros((TE_DIM,))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
# x: (B, T_text, 256)
|
||||
residual = x
|
||||
y = _pad_sym_edge(x, self.pad)
|
||||
y = self.dwconv(y)
|
||||
y = self.norm(y)
|
||||
y = self.pwconv1(y)
|
||||
y = _gelu_exact(y)
|
||||
y = self.pwconv2(y)
|
||||
y = y * self.gamma
|
||||
out = residual + y
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class TextConvNeXtStack(nn.Module):
|
||||
"""6 stacked ConvNeXt blocks. Loaded as ``convnext.convnext.[0..5].X``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [TextConvNeXtBlock(d) for d in TE_CONVNEXT_DILATIONS]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for b in self.convnext:
|
||||
x = b(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class _ConvLayer(nn.Module):
|
||||
"""Conv1d k=1 expressed via the ONNX-style ``X.weight (out, in, 1) + X.bias``.
|
||||
|
||||
The attn_encoder uses Conv1d k=1 instead of nn.Linear for its Q/K/V/O.
|
||||
This wrapper keeps the weight shape (out, in, 1) intact and runs as a
|
||||
Conv1d (the equivalent of a Linear when k=1).
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((out_dim, 1, in_dim)) # (C_out, K=1, C_in)
|
||||
self.bias = mx.zeros((out_dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (B, T, in_dim) — channels-last
|
||||
# equivalent to nn.Conv1d(in_dim, out_dim, k=1) in NTC layout
|
||||
return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
|
||||
|
||||
|
||||
REL_POS_WINDOW = 4 # rel_pos table size = 2*4 + 1 = 9
|
||||
|
||||
|
||||
def _rel_to_abs(x: mx.array) -> mx.array:
|
||||
"""[B, h, L, 2L-1] → [B, h, L, L] via the VITS shifted-skew reshape."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * 2 * L)
|
||||
x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
|
||||
return x_final[:, :, :L, L - 1:]
|
||||
|
||||
|
||||
def _abs_to_rel(x: mx.array) -> mx.array:
|
||||
"""[B, h, L, L] → [B, h, L, 2L-1] (inverse of _rel_to_abs)."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * (2 * L - 1))
|
||||
x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L, 2 * L)
|
||||
return x_final[:, :, :, 1:]
|
||||
|
||||
|
||||
def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
|
||||
"""``rel`` (1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
|
||||
pad_l = max(length - (window + 1), 0)
|
||||
if pad_l > 0:
|
||||
zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
|
||||
padded = mx.concatenate([zero, rel, zero], axis=1)
|
||||
else:
|
||||
padded = rel
|
||||
start = max(window + 1 - length, 0)
|
||||
return padded[:, start: start + 2 * length - 1]
|
||||
|
||||
|
||||
class RelPosSelfAttention(nn.Module):
|
||||
"""VITS-style relative-position self-attention with window=4.
|
||||
|
||||
Adds two contributions to vanilla MHA:
|
||||
- ``rel_logits = q @ rel_k.T`` then ``_rel_to_abs`` and added to attention logits
|
||||
- ``rel_attn = _abs_to_rel(softmax(logits))`` then ``@ rel_v`` and added to output
|
||||
|
||||
Loaded keys (per layer):
|
||||
``conv_q/k/v/o.weight`` (256, 256, 1) and ``.bias`` (256)
|
||||
``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.window = REL_POS_WINDOW
|
||||
self.emb_rel_k = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
|
||||
self.emb_rel_v = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
B, T, _ = x.shape
|
||||
H, D = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM
|
||||
q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
scale = D ** -0.5
|
||||
|
||||
# Standard attention logits
|
||||
logits = (q @ k.transpose(0, 1, 3, 2)) * scale # (B, H, T, T)
|
||||
|
||||
# VITS relative-position contribution to logits
|
||||
rel_k = _slice_rel_emb(self.emb_rel_k, T, self.window) # (1, 2T-1, D)
|
||||
rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :] # (B, H, T, 2T-1)
|
||||
rel_logits = _rel_to_abs(rel_logits * scale) # (B, H, T, T)
|
||||
logits = logits + rel_logits
|
||||
|
||||
if mask is not None:
|
||||
key_mask = mask[:, :, 0][:, None, None, :]
|
||||
neg_inf = mx.array(-1e4, dtype=logits.dtype)
|
||||
logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
|
||||
|
||||
attn = mx.softmax(logits, axis=-1) # (B, H, T, T)
|
||||
out = attn @ v # (B, H, T, D)
|
||||
|
||||
# VITS rel-pos value contribution
|
||||
rel_v = _slice_rel_emb(self.emb_rel_v, T, self.window) # (1, 2T-1, D)
|
||||
rel_weights = _abs_to_rel(attn) # (B, H, T, 2T-1)
|
||||
out = out + rel_weights @ rel_v[:, None, :, :] # (B, H, T, D)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
|
||||
return self.conv_o(out)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
"""FFN with Conv1d k=1 wrappers: conv_1 (256→1024) + ReLU + conv_2 (1024→256).
|
||||
|
||||
Activation is ReLU (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``),
|
||||
not GELU. The mask is applied before each Conv to match the ONNX semantics.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
|
||||
self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
y = self.conv_1(x)
|
||||
y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
y = self.conv_2(y)
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
return y
|
||||
|
||||
|
||||
class AttnEncoder(nn.Module):
|
||||
"""Stack of (RelPosSelfAttn + norm1) + (FFN + norm2) × 4."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attn_layers = [RelPosSelfAttention() for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.ffn_layers = [FFN() for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for i in range(TE_ATTN_NUM_LAYERS):
|
||||
y = self.attn_layers[i](x, mask=mask)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
y = self.ffn_layers[i](x, mask)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
return x
|
||||
|
||||
|
||||
class _TextEmbedder(nn.Module):
|
||||
"""char_embedder: VOCAB → TE_DIM. Loaded as ``char_embedder.weight (8322, 256)``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)
|
||||
|
||||
def __call__(self, text_ids: mx.array) -> mx.array:
|
||||
return self.char_embedder(text_ids)
|
||||
|
||||
|
||||
class _InnerTextEncoder(nn.Module):
|
||||
"""Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.text_embedder = _TextEmbedder()
|
||||
self.convnext = TextConvNeXtStack()
|
||||
self.attn_encoder = AttnEncoder()
|
||||
|
||||
def __call__(self, text_ids: mx.array, mask: mx.array) -> mx.array:
|
||||
x = self.text_embedder(text_ids) # (B, T, 256)
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
x = self.convnext(x, mask)
|
||||
x = self.attn_encoder(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class _StyleEncoder(nn.Module):
|
||||
"""Holds ``style_token_layer.style_key`` (1, 50, 256)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Use a child module so the parameter path matches ``style_token_layer.style_key``
|
||||
class _StyleTokenLayer(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.style_key = mx.zeros((1, 50, 256))
|
||||
self.style_token_layer = _StyleTokenLayer()
|
||||
|
||||
|
||||
class _SpeechPromptedAttn(nn.Module):
|
||||
"""Cross-attention from text (Q) to style_ttl (K, V). Single head, 256-d."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
|
||||
# x: (B, T_text, 256); style: (B, 50, 256)
|
||||
# Single-head cross attention.
|
||||
B, T, D = x.shape
|
||||
q = self.W_query(x)
|
||||
k = self.W_key(style)
|
||||
v = self.W_value(style)
|
||||
scale = D ** -0.5
|
||||
logits = (q @ k.transpose(0, 2, 1)) * scale
|
||||
attn = mx.softmax(logits, axis=-1)
|
||||
out = attn @ v
|
||||
return self.out_fc(out)
|
||||
|
||||
|
||||
class _SpeechPromptedTextEncoder(nn.Module):
|
||||
"""Two cross-attention layers modulating text features with style_ttl."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attention1 = _SpeechPromptedAttn()
|
||||
self.attention2 = _SpeechPromptedAttn()
|
||||
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
|
||||
|
||||
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
|
||||
x = x + self.attention1(x, style)
|
||||
x = x + self.attention2(x, style)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class _RootTextEncoder(nn.Module):
|
||||
"""Top-level container matching ONNX ``tts.ttl.*`` namespace."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.text_encoder = _InnerTextEncoder()
|
||||
self.style_encoder = _StyleEncoder()
|
||||
self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()
|
||||
|
||||
|
||||
class _TtsContainer(nn.Module):
|
||||
"""Outer container so weight keys ``tts.ttl.X.Y`` resolve."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.ttl = _RootTextEncoder()
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
"""Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.
|
||||
|
||||
Submodule naming matches the ONNX initializer keys after a single
|
||||
``tts.ttl.`` prefix wrap (so weight keys look like
|
||||
``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tts = _TtsContainer()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text_ids: mx.array, # (B, T_text) int
|
||||
style_ttl: mx.array, # (B, 50, 256)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
) -> mx.array:
|
||||
mask_ntc = text_mask.transpose(0, 2, 1) # (B, T_text, 1)
|
||||
x = self.tts.ttl.text_encoder(text_ids, mask_ntc)
|
||||
x = self.tts.ttl.speech_prompted_text_encoder(x, style_ttl)
|
||||
if mask_ntc is not None:
|
||||
x = x * mask_ntc
|
||||
# Return channels-first (B, 256, T_text) to match the vector_estimator input.
|
||||
return x.transpose(0, 2, 1)
|
||||
|
||||
|
||||
__all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]
|
||||
Reference in New Issue
Block a user