fix(quality): use fixed style_key for conditional K in StyleCrossAttn

ROOT CAUSE of the dark/muffled MLX audio.

The ONNX vector_estimator graph has a fixed learned constant
'style_token_layer.style_key' (shape (1, 50, 256), bit-identical between
text_encoder.onnx and vector_estimator.onnx Expand_output_0). Inside
the StyleCrossAttn (mb 5, 11, 17, 23), this constant is used as the K
input for the CONDITIONAL branch; only V is taken from style_ttl. We
were using style_ttl for BOTH K and V on the cond branch — which
worked passably (Whisper 100% on natural FR) but compressed the
high-frequency content of the velocity prediction at each style_attn
block. Compounded across 4 style blocks × 5 Euler steps, this caused
the spectral centroid to shift down by 300-800 Hz vs ONNX on most
voices, audible as 'muffled / sourd' especially on the natural-dark
voices M2, M3, F3, F4.

Diagnostic trail:
- VE per-step cosine similarity still dropped from 1.0 to 0.45, even after 3 prior fixes
- MLX latent std consistently 2-4 % lower than ONNX at every step
- Per-block bisect: first divergence at block 5 (cos 0.9987)
- Codex (task-mp...-eb8) found the missing constant by tracing
  Concat_6 (K) vs Concat_7 (V) topology in the ONNX VE graph

Patch:
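- Add _load_shared_style_key() helper that reads the constant from
  vector_estimator.onnx (Expand_output_0) or text_encoder.onnx
  (tts.ttl.style_encoder.style_token_layer.style_key) — both contain
  the same bit-identical tensor. The helper also accepts a text_encoder
  .safetensors file, honours the SUPERTONIC3_STYLE_KEY_ONNX and
  SUPERTONIC3_TEXT_ENCODER_WEIGHTS environment variables, and falls back
  to an all-zero key bank when no source is found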
- _UncondMasker gains a 'style_key' attribute holding the cond K
- VectorEstimator.__call__ now passes style_key (broadcast) as the
  cond K in both cfg=False and cfg=True paths, and threads it through
  precompute_cross_kv via _style_k_for_precompute()

Measured impact (spectral centroid MLX vs ONNX, FR Newton phrase):

    voice  before-fix  after-fix
    F3       −776 Hz     +27 Hz    ← was dark, now ~match
    F4       −697 Hz     +20 Hz    ← was dark, now ~match
    M2       −815 Hz    −317 Hz    ← much improved
    M3       −712 Hz    +128 Hz    ← USER'S complaint voice, now bright
    M1       −537 Hz    −219 Hz
    F1        +62 Hz    +303 Hz    (a touch brighter, still good)
    others       small        small

Whisper word overlap stays at 100 % on all 10 voices for natural FR.
M3 on the user's reported 'inaudible' scenario should now sound
clean on any machine.
This commit is contained in:
ambassadia
2026-05-20 12:07:13 +02:00
parent 0cc254ff87
commit 485f2ff476

View File

@@ -23,6 +23,8 @@ quantisation, and kernel fusion are layered on later in T.3.
from __future__ import annotations from __future__ import annotations
import math import math
import os
from pathlib import Path
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@@ -59,6 +61,59 @@ def _mish(x: mx.array) -> mx.array:
return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype))) return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
def _load_shared_style_key() -> mx.array:
    """Best-effort load of the fixed conditional style-attention key bank.

    The upstream vector_estimator ONNX graph bakes this tensor in as the
    anonymous initializer ``/vector_estimator/Expand_output_0``. It is the same
    tensor as text_encoder ``tts.ttl.style_encoder.style_token_layer.style_key``.

    Candidate files are tried in order: the paths named by the
    ``SUPERTONIC3_STYLE_KEY_ONNX`` and ``SUPERTONIC3_TEXT_ENCODER_WEIGHTS``
    environment variables (when set), then a fixed list of default locations.
    Only ``.onnx`` and ``.safetensors`` files are inspected, and a tensor is
    accepted only when its shape is ``(1, STYLE_LEN, STYLE_DIM)``.

    This function never raises for an unreadable candidate: the failure is
    logged and the next candidate is tried.

    Returns:
        The key bank as a float32 array of shape ``(1, STYLE_LEN, STYLE_DIM)``,
        or an all-zero array of that shape when no candidate provides it (a
        warning is logged in that case).
    """
    import logging

    log = logging.getLogger(__name__)
    expected_shape = (1, STYLE_LEN, STYLE_DIM)
    text_encoder_key = "tts.ttl.style_encoder.style_token_layer.style_key"
    onnx_initializer_names = {
        "/vector_estimator/Expand_output_0",
        text_encoder_key,
    }

    def _read_from_onnx(path: Path):
        # Scan the graph initializers for either known name of the key bank.
        import onnx
        from onnx import numpy_helper

        model = onnx.load(str(path))
        for init in model.graph.initializer:
            if init.name in onnx_initializer_names:
                arr = numpy_helper.to_array(init)
                if arr.shape == expected_shape:
                    return arr
        return None

    def _read_from_safetensors(path: Path):
        # Look up the text-encoder key by name in a safetensors checkpoint.
        from safetensors import safe_open

        with safe_open(str(path), framework="np") as f:
            if text_encoder_key in f.keys():
                arr = f.get_tensor(text_encoder_key)
                if arr.shape == expected_shape:
                    return arr
        return None

    readers = {".onnx": _read_from_onnx, ".safetensors": _read_from_safetensors}

    candidates: list[Path] = []
    for env_name in ("SUPERTONIC3_STYLE_KEY_ONNX", "SUPERTONIC3_TEXT_ENCODER_WEIGHTS"):
        if value := os.environ.get(env_name):
            candidates.append(Path(value))
    candidates.extend(
        [
            Path("/tmp/supertonic3/model/onnx/vector_estimator.onnx"),
            Path("/tmp/supertonic3/model/onnx/text_encoder.onnx"),
            Path.cwd() / "weights" / "text_encoder.safetensors",
            Path.cwd() / "sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
        ]
    )

    for path in candidates:
        if not path.exists():
            continue
        reader = readers.get(path.suffix)
        if reader is None:
            continue
        try:
            arr = reader(path)
            if arr is not None:
                return mx.array(arr.astype("float32", copy=False))
        except Exception as exc:
            # Deliberately best-effort (missing optional package, corrupt
            # file, ...): report the failure, then try the next candidate.
            log.warning("Could not read style key from %s: %s", path, exc)
            continue

    # A zero key bank lets the model run, but conditional style attention
    # will not match the reference graph, so make the fallback visible.
    log.warning(
        "No style key bank found in any candidate file; falling back to zeros "
        "of shape %s.",
        expected_shape,
    )
    return mx.zeros(expected_shape)
# ────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────
# ConvNeXt building blocks # ConvNeXt building blocks
# ────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────
@@ -544,9 +599,10 @@ class _VectorField(nn.Module):
class _UncondMasker(nn.Module): class _UncondMasker(nn.Module):
"""Holds the three unconditional-token tensors used by CFG. """Holds the style-key bank plus unconditional-token tensors used by CFG.
Keys: Keys:
``style_key`` (1, 50, 256)
``text_special_token`` (1, 256, 1) ``text_special_token`` (1, 256, 1)
``style_key_special_token`` (1, 50, 256) ``style_key_special_token`` (1, 50, 256)
``style_value_special_token`` (1, 50, 256) ``style_value_special_token`` (1, 50, 256)
@@ -554,6 +610,10 @@ class _UncondMasker(nn.Module):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
# Conditional style attention uses the fixed text-encoder style key bank
# for K and the per-voice ``style_ttl`` for V. The vector_estimator ONNX
# graph stores this as an anonymous initializer, so load it best-effort.
self.style_key = _load_shared_style_key()
# Initialised to zero; checkpoint provides real values. # Initialised to zero; checkpoint provides real values.
self.text_special_token = mx.zeros((1, TEXT_DIM, 1)) self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM)) self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
@@ -565,8 +625,9 @@ class VectorEstimator(nn.Module):
Two inference paths: Two inference paths:
- :meth:`velocity`: single forward pass; predicts the velocity from one set - :meth:`velocity`: single forward pass; predicts the velocity from one set
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor of conditioning inputs. Conditional style attention uses the fixed
(cond path) or different (uncond path of CFG). style key bank for K and ``style_ttl`` for V; CFG uses special-token
K/V for the unconditional path.
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling - :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
(cond + uncond) internally and combines via (cond + uncond) internally and combines via
``final = noisy + (4*cond - 3*uncond) / total_step``. ``final = noisy + (4*cond - 3*uncond) / total_step``.
@@ -583,6 +644,28 @@ class VectorEstimator(nn.Module):
self.vector_field = _VectorField() self.vector_field = _VectorField()
self.uncond_masker = _UncondMasker() self.uncond_masker = _UncondMasker()
def _conditional_style_key(self, batch_size: int, dtype: mx.Dtype) -> mx.array:
    """Return the fixed conditional style key bank expanded to a batch.

    The tensor stored on ``self.uncond_masker.style_key`` is cast to ``dtype``
    and broadcast to ``(batch_size, STYLE_LEN, STYLE_DIM)``.
    """
    target_shape = (batch_size, STYLE_LEN, STYLE_DIM)
    bank = self.uncond_masker.style_key
    return mx.broadcast_to(bank.astype(dtype), target_shape)
def _style_k_for_precompute(self, style_k: mx.array, style_v: mx.array) -> mx.array:
    """Choose the style-attention K tensor to use for K/V pre-projection.

    The incoming ``style_k`` is treated as a CFG-doubled batch when its batch
    size is even and greater than one, and its second half matches the
    broadcast ``style_key_special_token`` to within ``1e-5`` element-wise.
    In that case the first half is replaced by the fixed conditional style
    key and the second half is kept. In every other case (including a failed
    comparison) the whole batch gets the fixed conditional style key.

    ``style_v`` is accepted for signature symmetry and is not read here.
    """
    total = style_k.shape[0]
    key_dtype = style_k.dtype

    # Odd or single-item batches cannot be a cond + uncond pair.
    if total <= 1 or total % 2 != 0:
        return self._conditional_style_key(total, key_dtype)

    n_cond = total // 2
    special_token = self.uncond_masker.style_key_special_token.astype(key_dtype)
    uncond_ref = mx.broadcast_to(
        special_token, (total - n_cond, STYLE_LEN, STYLE_DIM)
    )
    try:
        mx.eval(uncond_ref)
        deviation = mx.abs(style_k[n_cond:] - uncond_ref)
        is_cfg_batch = bool(mx.all(deviation < 1e-5).item())
    except Exception:
        # Any failure in the comparison is treated as "not a CFG batch".
        is_cfg_batch = False

    if not is_cfg_batch:
        return self._conditional_style_key(total, key_dtype)

    cond_half = self._conditional_style_key(n_cond, key_dtype)
    return mx.concatenate([cond_half, style_k[n_cond:]], axis=0)
# ── inference API ───────────────────────────────────────────── # ── inference API ─────────────────────────────────────────────
def velocity( def velocity(
self, self,
@@ -641,6 +724,7 @@ class VectorEstimator(nn.Module):
call; pre-projecting them once and feeding the result into call; pre-projecting them once and feeding the result into
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls. :meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
""" """
style_k = self._style_k_for_precompute(style_k, style_v)
text_seq_len = mx.sum(text_mask, axis=(1, 2)) text_seq_len = mx.sum(text_mask, axis=(1, 2))
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256) text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
@@ -700,7 +784,7 @@ class VectorEstimator(nn.Module):
self, self,
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
text_emb: mx.array, # (B, 256, T_text) channels-first text_emb: mx.array, # (B, 256, T_text) channels-first
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond style_ttl: mx.array, # (B, 50, 256) — V side for cond style attention
latent_mask: mx.array, # (B, 1, T_lat) latent_mask: mx.array, # (B, 1, T_lat)
text_mask: mx.array, # (B, 1, T_text) text_mask: mx.array, # (B, 1, T_text)
current_step: mx.array, # (B,) current_step: mx.array, # (B,)
@@ -721,15 +805,17 @@ class VectorEstimator(nn.Module):
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32) t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
if not cfg: if not cfg:
style_key = self._conditional_style_key(B, style_ttl.dtype)
v = self.velocity( v = self.velocity(
noisy_latent, text_emb, style_ttl, style_ttl, noisy_latent, text_emb, style_key, style_ttl,
latent_mask, text_mask, t_norm, latent_mask, text_mask, t_norm,
) )
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype) return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond. # CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
# uncond text_emb = text_special_token broadcast to (B, 256, T_text). # uncond text_emb = text_special_token broadcast to (B, 256, T_text).
# uncond style_k = style_key_special_token broadcast, similarly style_v. # cond style_k = fixed style_key broadcast; uncond style_k/style_v are
# the learned special tokens broadcast to the batch.
text_uncond = mx.broadcast_to( text_uncond = mx.broadcast_to(
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2]) self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
) )
@@ -739,10 +825,11 @@ class VectorEstimator(nn.Module):
style_v_uncond = mx.broadcast_to( style_v_uncond = mx.broadcast_to(
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM) self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
) )
style_key_cond = self._conditional_style_key(B, style_ttl.dtype)
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0) noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
text_2 = mx.concatenate([text_emb, text_uncond], axis=0) text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0) style_k_2 = mx.concatenate([style_key_cond, style_k_uncond], axis=0)
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0) style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0) lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
tm_2 = mx.concatenate([text_mask, text_mask], axis=0) tm_2 = mx.concatenate([text_mask, text_mask], axis=0)