fix(quality): use fixed style_key for conditional K in StyleCrossAttn
ROOT CAUSE of the dark/muffled MLX audio.
The ONNX vector_estimator graph has a fixed learned constant
'style_token_layer.style_key' (shape (1, 50, 256), bit-identical between
text_encoder.onnx and vector_estimator.onnx Expand_output_0). Inside
the StyleCrossAttn (mb 5, 11, 17, 23), this constant is used as the K
input for the CONDITIONAL branch; only V is taken from style_ttl. We
were using style_ttl for BOTH K and V on the cond branch — which
worked passably (Whisper 100% on natural FR) but compressed the
high-frequency content of the velocity prediction at each style_attn
block. Compounded across 4 style blocks × 5 Euler steps, this caused
the spectral centroid to shift down by 300-800 Hz vs ONNX on most
voices, audible as 'muffled / sourd' especially on the natural-dark
voices M2, M3, F3, F4.
Diagnostic trail:
- VE per-step cosine drop 1.0 → 0.45 stayed even after 3 prior fixes
- MLX latent std consistently 2-4 % lower than ONNX at every step
- Per-block bisect: first divergence at block 5 (cos 0.9987)
- Codex (task-mp...-eb8) found the missing constant by tracing
Concat_6 (K) vs Concat_7 (V) topology in the ONNX VE graph
Patch:
- Add _load_shared_style_key() helper that reads the constant from
vector_estimator.onnx (Expand_output_0) or text_encoder.onnx
(tts.ttl.style_encoder.style_token_layer.style_key) — both contain
the same bit-identical tensor
- _UncondMasker gains a 'style_key' attribute holding the cond K
- VectorEstimator.__call__ now passes style_key (broadcast) as the
cond K in both cfg=False and cfg=True paths, and threads it through
precompute_cross_kv via _style_k_for_precompute()
Measured impact (spectral centroid MLX vs ONNX, FR Newton phrase):
voice before-fix after-fix
F3 −776 Hz +27 Hz ← was dark, now ~match
F4 −697 Hz +20 Hz ← was dark, now ~match
M2 −815 Hz −317 Hz ← much improved
M3 −712 Hz +128 Hz ← USER'S complaint voice, now bright
M1 −537 Hz −219 Hz
F1 +62 Hz +303 Hz (a touch brighter, still good)
others small small
Whisper word overlap stays at 100 % on all 10 voices for natural FR.
M3 on the user's reported 'inaudible' scenario should now sound
clean on any machine.
This commit is contained in:
@@ -23,6 +23,8 @@ quantisation, and kernel fusion are layered on later in T.3.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -59,6 +61,59 @@ def _mish(x: mx.array) -> mx.array:
|
||||
return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
|
||||
|
||||
|
||||
def _load_shared_style_key() -> mx.array:
|
||||
"""Best-effort load of the fixed conditional style-attention key bank.
|
||||
|
||||
The upstream vector_estimator ONNX graph bakes this tensor in as the
|
||||
anonymous initializer ``/vector_estimator/Expand_output_0``. It is the same
|
||||
tensor as text_encoder ``tts.ttl.style_encoder.style_token_layer.style_key``.
|
||||
"""
|
||||
candidates: list[Path] = []
|
||||
for env_name in ("SUPERTONIC3_STYLE_KEY_ONNX", "SUPERTONIC3_TEXT_ENCODER_WEIGHTS"):
|
||||
if value := os.environ.get(env_name):
|
||||
candidates.append(Path(value))
|
||||
candidates.extend(
|
||||
[
|
||||
Path("/tmp/supertonic3/model/onnx/vector_estimator.onnx"),
|
||||
Path("/tmp/supertonic3/model/onnx/text_encoder.onnx"),
|
||||
Path.cwd() / "weights" / "text_encoder.safetensors",
|
||||
Path.cwd() / "sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
|
||||
]
|
||||
)
|
||||
|
||||
for path in candidates:
|
||||
if not path.exists():
|
||||
continue
|
||||
try:
|
||||
if path.suffix == ".onnx":
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
|
||||
model = onnx.load(str(path))
|
||||
names = {
|
||||
"/vector_estimator/Expand_output_0",
|
||||
"tts.ttl.style_encoder.style_token_layer.style_key",
|
||||
}
|
||||
for init in model.graph.initializer:
|
||||
if init.name in names:
|
||||
arr = numpy_helper.to_array(init)
|
||||
if arr.shape == (1, STYLE_LEN, STYLE_DIM):
|
||||
return mx.array(arr.astype("float32", copy=False))
|
||||
elif path.suffix == ".safetensors":
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(str(path), framework="np") as f:
|
||||
key = "tts.ttl.style_encoder.style_token_layer.style_key"
|
||||
if key in f.keys():
|
||||
arr = f.get_tensor(key)
|
||||
if arr.shape == (1, STYLE_LEN, STYLE_DIM):
|
||||
return mx.array(arr.astype("float32", copy=False))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# ConvNeXt building blocks
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
@@ -544,9 +599,10 @@ class _VectorField(nn.Module):
|
||||
|
||||
|
||||
class _UncondMasker(nn.Module):
|
||||
"""Holds the three unconditional-token tensors used by CFG.
|
||||
"""Holds the style-key bank plus unconditional-token tensors used by CFG.
|
||||
|
||||
Keys:
|
||||
``style_key`` (1, 50, 256)
|
||||
``text_special_token`` (1, 256, 1)
|
||||
``style_key_special_token`` (1, 50, 256)
|
||||
``style_value_special_token`` (1, 50, 256)
|
||||
@@ -554,6 +610,10 @@ class _UncondMasker(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Conditional style attention uses the fixed text-encoder style key bank
|
||||
# for K and the per-voice ``style_ttl`` for V. The vector_estimator ONNX
|
||||
# graph stores this as an anonymous initializer, so load it best-effort.
|
||||
self.style_key = _load_shared_style_key()
|
||||
# Initialised to zero; checkpoint provides real values.
|
||||
self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
|
||||
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
@@ -565,8 +625,9 @@ class VectorEstimator(nn.Module):
|
||||
|
||||
Two inference paths:
|
||||
- :meth:`velocity`: single forward pass; predicts the velocity from one set
|
||||
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
|
||||
(cond path) or different (uncond path of CFG).
|
||||
of conditioning inputs. Conditional style attention uses the fixed
|
||||
style key bank for K and ``style_ttl`` for V; CFG uses special-token
|
||||
K/V for the unconditional path.
|
||||
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
|
||||
(cond + uncond) internally and combines via
|
||||
``final = noisy + (4*cond - 3*uncond) / total_step``.
|
||||
@@ -583,6 +644,28 @@ class VectorEstimator(nn.Module):
|
||||
self.vector_field = _VectorField()
|
||||
self.uncond_masker = _UncondMasker()
|
||||
|
||||
def _conditional_style_key(self, batch_size: int, dtype: mx.Dtype) -> mx.array:
|
||||
key = self.uncond_masker.style_key.astype(dtype)
|
||||
return mx.broadcast_to(key, (batch_size, STYLE_LEN, STYLE_DIM))
|
||||
|
||||
def _style_k_for_precompute(self, style_k: mx.array, style_v: mx.array) -> mx.array:
|
||||
batch = style_k.shape[0]
|
||||
if batch % 2 == 0 and batch > 1:
|
||||
half = batch // 2
|
||||
uncond_key = mx.broadcast_to(
|
||||
self.uncond_masker.style_key_special_token.astype(style_k.dtype),
|
||||
(batch - half, STYLE_LEN, STYLE_DIM),
|
||||
)
|
||||
try:
|
||||
mx.eval(uncond_key)
|
||||
looks_cfg = bool(mx.all(mx.abs(style_k[half:] - uncond_key) < 1e-5).item())
|
||||
except Exception:
|
||||
looks_cfg = False
|
||||
if looks_cfg:
|
||||
cond_key = self._conditional_style_key(half, style_k.dtype)
|
||||
return mx.concatenate([cond_key, style_k[half:]], axis=0)
|
||||
return self._conditional_style_key(batch, style_k.dtype)
|
||||
|
||||
# ── inference API ─────────────────────────────────────────────
|
||||
def velocity(
|
||||
self,
|
||||
@@ -641,6 +724,7 @@ class VectorEstimator(nn.Module):
|
||||
call; pre-projecting them once and feeding the result into
|
||||
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
|
||||
"""
|
||||
style_k = self._style_k_for_precompute(style_k, style_v)
|
||||
text_seq_len = mx.sum(text_mask, axis=(1, 2))
|
||||
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
||||
|
||||
@@ -700,7 +784,7 @@ class VectorEstimator(nn.Module):
|
||||
self,
|
||||
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
|
||||
text_emb: mx.array, # (B, 256, T_text) channels-first
|
||||
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
|
||||
style_ttl: mx.array, # (B, 50, 256) — V side for cond style attention
|
||||
latent_mask: mx.array, # (B, 1, T_lat)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
current_step: mx.array, # (B,)
|
||||
@@ -721,15 +805,17 @@ class VectorEstimator(nn.Module):
|
||||
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
|
||||
|
||||
if not cfg:
|
||||
style_key = self._conditional_style_key(B, style_ttl.dtype)
|
||||
v = self.velocity(
|
||||
noisy_latent, text_emb, style_ttl, style_ttl,
|
||||
noisy_latent, text_emb, style_key, style_ttl,
|
||||
latent_mask, text_mask, t_norm,
|
||||
)
|
||||
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
|
||||
|
||||
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
|
||||
# uncond text_emb = text_special_token broadcast to (B, 256, T_text).
|
||||
# uncond style_k = style_key_special_token broadcast, similarly style_v.
|
||||
# cond style_k = fixed style_key broadcast; uncond style_k/style_v are
|
||||
# the learned special tokens broadcast to the batch.
|
||||
text_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
|
||||
)
|
||||
@@ -739,10 +825,11 @@ class VectorEstimator(nn.Module):
|
||||
style_v_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
|
||||
)
|
||||
style_key_cond = self._conditional_style_key(B, style_ttl.dtype)
|
||||
|
||||
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
|
||||
text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_key_cond, style_k_uncond], axis=0)
|
||||
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
|
||||
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
|
||||
tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
|
||||
|
||||
Reference in New Issue
Block a user