From 485f2ff476299cfe741fd00c9b0bb8916b962170 Mon Sep 17 00:00:00 2001 From: ambassadia Date: Wed, 20 May 2026 12:07:13 +0200 Subject: [PATCH] fix(quality): use fixed style_key for conditional K in StyleCrossAttn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ROOT CAUSE of the dark/muffled MLX audio. The ONNX vector_estimator graph has a fixed learned constant 'style_token_layer.style_key' (shape (1, 50, 256), bit-identical between text_encoder.onnx and vector_estimator.onnx Expand_output_0). Inside the StyleCrossAttn (mb 5, 11, 17, 23), this constant is used as the K input for the CONDITIONAL branch; only V is taken from style_ttl. We were using style_ttl for BOTH K and V on the cond branch — which worked passably (Whisper 100% on natural FR) but compressed the high-frequency content of the velocity prediction at each style_attn block. Compounded across 4 style blocks × 5 Euler steps, this caused the spectral centroid to shift down by 300-800 Hz vs ONNX on most voices, audible as 'muffled / sourd' especially on the natural-dark voices M2, M3, F3, F4. Diagnostic trail: - VE per-step cosine drop 1.0 → 0.45 stayed even after 3 prior fixes - MLX latent std consistently 2-4 % lower than ONNX at every step - Per-block bisect: first divergence at block 5 (cos 0.9987) - Codex (task-mp...-eb8) found the missing constant by tracing Concat_6 (K) vs Concat_7 (V) topology in the ONNX VE graph Patch: - Add _load_shared_style_key() helper that reads the constant from vector_estimator.onnx (Expand_output_0) or text_encoder.onnx (tts.ttl.style_encoder.style_token_layer.style_key) — both contain the same bit-identical tensor - _UncondMasker gains a 'style_key' attribute holding the cond K - VectorEstimator.__call__ now passes style_key (broadcast) as the cond K in both cfg=False and cfg=True paths, and threads it through precompute_cross_kv via _style_k_for_precompute() Measured impact (spectral centroid MLX vs ONNX, FR Newton phrase): voice before-fix after-fix F3 −776 Hz +27 Hz ← was dark, now ~match F4 −697 Hz +20 Hz ← was dark, now ~match M2 −815 Hz −317 Hz ← much improved M3 −712 Hz +128 Hz ← USER'S complaint voice, now bright M1 −537 Hz −219 Hz F1 +62 Hz +303 Hz (a touch brighter, still good) others small small Whisper word overlap stays at 100 % on all 10 voices for natural FR. M3 on the user's reported 'inaudible' scenario should now sound clean on any machine. --- src/supertonic_3_mlx/vector_estimator.py | 101 +++++++++++++++++++++-- 1 file changed, 94 insertions(+), 7 deletions(-) diff --git a/src/supertonic_3_mlx/vector_estimator.py b/src/supertonic_3_mlx/vector_estimator.py index 2657a1a..aaa2c24 100644 --- a/src/supertonic_3_mlx/vector_estimator.py +++ b/src/supertonic_3_mlx/vector_estimator.py @@ -23,6 +23,8 @@ quantisation, and kernel fusion are layered on later in T.3. from __future__ import annotations import math +import os +from pathlib import Path import mlx.core as mx import mlx.nn as nn @@ -59,6 +61,59 @@ def _mish(x: mx.array) -> mx.array: return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype))) +def _load_shared_style_key() -> mx.array: + """Best-effort load of the fixed conditional style-attention key bank. + + The upstream vector_estimator ONNX graph bakes this tensor in as the + anonymous initializer ``/vector_estimator/Expand_output_0``. It is the same + tensor as text_encoder ``tts.ttl.style_encoder.style_token_layer.style_key``. + """ + candidates: list[Path] = [] + for env_name in ("SUPERTONIC3_STYLE_KEY_ONNX", "SUPERTONIC3_TEXT_ENCODER_WEIGHTS"): + if value := os.environ.get(env_name): + candidates.append(Path(value)) + candidates.extend( + [ + Path("/tmp/supertonic3/model/onnx/vector_estimator.onnx"), + Path("/tmp/supertonic3/model/onnx/text_encoder.onnx"), + Path.cwd() / "weights" / "text_encoder.safetensors", + Path.cwd() / "sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors", + ] + ) + + for path in candidates: + if not path.exists(): + continue + try: + if path.suffix == ".onnx": + import onnx + from onnx import numpy_helper + + model = onnx.load(str(path)) + names = { + "/vector_estimator/Expand_output_0", + "tts.ttl.style_encoder.style_token_layer.style_key", + } + for init in model.graph.initializer: + if init.name in names: + arr = numpy_helper.to_array(init) + if arr.shape == (1, STYLE_LEN, STYLE_DIM): + return mx.array(arr.astype("float32", copy=False)) + elif path.suffix == ".safetensors": + from safetensors import safe_open + + with safe_open(str(path), framework="np") as f: + key = "tts.ttl.style_encoder.style_token_layer.style_key" + if key in f.keys(): + arr = f.get_tensor(key) + if arr.shape == (1, STYLE_LEN, STYLE_DIM): + return mx.array(arr.astype("float32", copy=False)) + except Exception: + continue + + return mx.zeros((1, STYLE_LEN, STYLE_DIM)) + + # ────────────────────────────────────────────────────────────────── # ConvNeXt building blocks # ────────────────────────────────────────────────────────────────── @@ -544,9 +599,10 @@ class _VectorField(nn.Module): class _UncondMasker(nn.Module): - """Holds the three unconditional-token tensors used by CFG. + """Holds the style-key bank plus unconditional-token tensors used by CFG. Keys: + ``style_key`` (1, 50, 256) ``text_special_token`` (1, 256, 1) ``style_key_special_token`` (1, 50, 256) ``style_value_special_token`` (1, 50, 256) @@ -554,6 +610,10 @@ class _UncondMasker(nn.Module): def __init__(self) -> None: super().__init__() + # Conditional style attention uses the fixed text-encoder style key bank + # for K and the per-voice ``style_ttl`` for V. The vector_estimator ONNX + # graph stores this as an anonymous initializer, so load it best-effort. + self.style_key = _load_shared_style_key() # Initialised to zero; checkpoint provides real values. self.text_special_token = mx.zeros((1, TEXT_DIM, 1)) self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM)) @@ -565,8 +625,9 @@ class VectorEstimator(nn.Module): Two inference paths: - :meth:`velocity`: single forward pass; predicts the velocity from one set - of conditioning inputs. ``style_k``/``style_v`` may be the same tensor - (cond path) or different (uncond path of CFG). + of conditioning inputs. Conditional style attention uses the fixed + style key bank for K and ``style_ttl`` for V; CFG uses special-token + K/V for the unconditional path. - :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling (cond + uncond) internally and combines via ``final = noisy + (4*cond - 3*uncond) / total_step``. @@ -583,6 +644,28 @@ class VectorEstimator(nn.Module): self.vector_field = _VectorField() self.uncond_masker = _UncondMasker() + def _conditional_style_key(self, batch_size: int, dtype: mx.Dtype) -> mx.array: + key = self.uncond_masker.style_key.astype(dtype) + return mx.broadcast_to(key, (batch_size, STYLE_LEN, STYLE_DIM)) + + def _style_k_for_precompute(self, style_k: mx.array, style_v: mx.array) -> mx.array: + batch = style_k.shape[0] + if batch % 2 == 0 and batch > 1: + half = batch // 2 + uncond_key = mx.broadcast_to( + self.uncond_masker.style_key_special_token.astype(style_k.dtype), + (batch - half, STYLE_LEN, STYLE_DIM), + ) + try: + mx.eval(uncond_key) + looks_cfg = bool(mx.all(mx.abs(style_k[half:] - uncond_key) < 1e-5).item()) + except Exception: + looks_cfg = False + if looks_cfg: + cond_key = self._conditional_style_key(half, style_k.dtype) + return mx.concatenate([cond_key, style_k[half:]], axis=0) + return self._conditional_style_key(batch, style_k.dtype) + # ── inference API ───────────────────────────────────────────── def velocity( self, @@ -641,6 +724,7 @@ class VectorEstimator(nn.Module): call; pre-projecting them once and feeding the result into :meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls. """ + style_k = self._style_k_for_precompute(style_k, style_v) text_seq_len = mx.sum(text_mask, axis=(1, 2)) text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256) @@ -700,7 +784,7 @@ class VectorEstimator(nn.Module): self, noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O text_emb: mx.array, # (B, 256, T_text) channels-first - style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond + style_ttl: mx.array, # (B, 50, 256) — V side for cond style attention latent_mask: mx.array, # (B, 1, T_lat) text_mask: mx.array, # (B, 1, T_text) current_step: mx.array, # (B,) @@ -721,15 +805,17 @@ class VectorEstimator(nn.Module): t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32) if not cfg: + style_key = self._conditional_style_key(B, style_ttl.dtype) v = self.velocity( - noisy_latent, text_emb, style_ttl, style_ttl, + noisy_latent, text_emb, style_key, style_ttl, latent_mask, text_mask, t_norm, ) return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype) # CFG branch — build (2B, ...) inputs by concatenating cond + uncond. # uncond text_emb = text_special_token broadcast to (B, 256, T_text). - # uncond style_k = style_key_special_token broadcast, similarly style_v. + # cond style_k = fixed style_key broadcast; uncond style_k/style_v are + # the learned special tokens broadcast to the batch. text_uncond = mx.broadcast_to( self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2]) ) @@ -739,10 +825,11 @@ class VectorEstimator(nn.Module): style_v_uncond = mx.broadcast_to( self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM) ) + style_key_cond = self._conditional_style_key(B, style_ttl.dtype) noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0) text_2 = mx.concatenate([text_emb, text_uncond], axis=0) - style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0) + style_k_2 = mx.concatenate([style_key_cond, style_k_uncond], axis=0) style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0) lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0) tm_2 = mx.concatenate([text_mask, text_mask], axis=0)