fix(pipeline): wire TextEncoder style_key into VectorEstimator.uncond_masker

Sync from GitHub commit 42c7ca7 (user pushed directly).

VE's conditional-style-attention K is the shared style_key bank that
lives in TextEncoder ('tts.ttl.style_encoder.style_token_layer.style_key').
The MLX pipeline was building VE and TE independently and never wiring
the key over: '_load_shared_style_key()' (added in the previous fix)
falls back silently to mx.zeros((1, 50, 256)) when its disk path-scan
returns empty — which happens on any machine that doesn't have the
ONNX cache at /tmp/supertonic3/.

Effect: on the dev M3 Ultra (where the ONNX cache exists), the loader
found the file → audio was fine. On the user's other Mac (no cache) →
style_key fell back to zeros → conditional attention K = 0 → CFG combine
4*cond - 3*uncond collapsed M3 (the lowest-norm style_ttl) to near-DC
noise → Whisper hallucinated 'Merci.' / 'PO PO PO...'.

Fix: copy te.tts.ttl.style_encoder.style_token_layer.style_key into
ve.uncond_masker.style_key right after both submodules are built, in
both _from_safetensors and _from_onnx code paths.

Validated on M3 Ultra: VE.uncond_masker.style_key.sum(|x|) goes from
0.0 to ~3627.34; Whisper on all 13 voices (10 presets + 3 customs)
returns 81-94 % word overlap on the test phrase, with M3 at 94 %.
This commit is contained in:
ambassadia
2026-05-20 15:28:20 +02:00
parent a3f44d0661
commit 052b24d0ac

View File

@@ -441,6 +441,12 @@ class SupertonicMLXPipeline:
dp = _build(DurationPredictor, "duration_predictor")
voc = _build(Vocoder, "vocoder")
# Conditional-style-attention K is the shared style_key bank that lives
# in TextEncoder. Wire it into VectorEstimator's uncond_masker so the
# CFG (4*cond - 3*uncond) combine has a valid K on the cond branch.
# Without this, low-norm voice styles (M3) collapse to near-DC noise.
ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
if dtype is not None and dtype != mx.float32:
cls._cast_all(dp, te, ve, voc, dtype=dtype)
@@ -463,6 +469,8 @@ class SupertonicMLXPipeline:
voc = Vocoder()
_load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
if dtype is not None and dtype != mx.float32:
cls._cast_all(dp, te, ve, voc, dtype=dtype)