fix(pipeline): wire TextEncoder style_key into VectorEstimator.uncond_masker

Sync from GitHub commit 42c7ca7 (user pushed directly).

VE's conditional-style-attention K is the shared style_key bank that
lives in TextEncoder ('tts.ttl.style_encoder.style_token_layer.style_key').
The MLX pipeline built VE and TE independently and never wired the key
across. '_load_shared_style_key()' (added in the previous fix) silently
falls back to mx.zeros((1, 50, 256)) when its disk path scan finds
nothing, which happens on any machine that lacks the ONNX cache at
/tmp/supertonic3/.

Effect: on the dev M3 Ultra (where the ONNX cache exists), the loader
found the file → audio was fine. On the user's other Mac (no cache) →
style_key fell back to zeros → conditional attention K = 0 → the CFG
combine 4*cond - 3*uncond collapsed the M3 voice (the lowest-norm
style_ttl) to near-DC noise → Whisper hallucinated 'Merci.' / 'PO PO PO...'.

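One way to see why a zero K is harmful, sketched in plain Python rather than MLX: if the conditional branch uses ordinary scaled dot-product attention (an assumption; the VE attention code is not shown here), then an all-zero key bank makes every score zero, so the softmax weights are uniform and no longer depend on the query. The function names below are illustrative and not part of the pipeline; only the 4*cond - 3*uncond combine is taken from this commit.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(query, keys):
    """Scaled dot-product attention weights of one query over a key bank."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    return softmax(scores)


def cfg_combine(cond, uncond, scale=4.0):
    """Classifier-free guidance combine; scale=4 gives 4*cond - 3*uncond."""
    return [scale * c - (scale - 1.0) * u for c, u in zip(cond, uncond)]


if __name__ == "__main__":
    zero_keys = [[0.0, 0.0] for _ in range(4)]
    # Any query yields the same uniform weights when the key bank is zero.
    print(attention_weights([0.3, -1.2], zero_keys))
    print(attention_weights([5.0, 7.0], zero_keys))
    print(cfg_combine([1.0, 2.0], [0.5, 1.0]))
```

With uniform weights the conditional branch stops reacting to the style tokens, and the guidance combine then amplifies whatever difference remains between the two branches.
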
Fix: copy te.tts.ttl.style_encoder.style_token_layer.style_key into
ve.uncond_masker.style_key right after both submodules are built, in
both the _from_safetensors and _from_onnx code paths.

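The wiring step can be sketched with stand-in objects. The helper name wire_shared_style_key, the SimpleNamespace stand-ins, and the all-zero guard are hypothetical additions for illustration; the commit itself performs only the plain attribute assignment. The toy key is a small 2-D nested list instead of the real (1, 50, 256) MLX array.

```python
from types import SimpleNamespace


def wire_shared_style_key(te, ve):
    """Point VE's uncond_masker at TE's style_key; reject an all-zero bank.

    Hypothetical helper: the guard turns a silent zero fallback into an error.
    Returns the sum of absolute values of the wired key.
    """
    key = te.tts.ttl.style_encoder.style_token_layer.style_key
    abs_sum = sum(abs(x) for row in key for x in row)
    if abs_sum == 0.0:
        raise ValueError("style_key is all zeros; refusing to wire it into VE")
    ve.uncond_masker.style_key = key
    return abs_sum


def fake_text_encoder(key):
    """Stand-in exposing the same attribute path as the real TextEncoder."""
    layer = SimpleNamespace(style_key=key)
    encoder = SimpleNamespace(style_token_layer=layer)
    return SimpleNamespace(tts=SimpleNamespace(ttl=SimpleNamespace(style_encoder=encoder)))


def fake_vector_estimator(rows, cols):
    """Stand-in VE whose key starts as zeros, like the silent fallback."""
    zeros = [[0.0] * cols for _ in range(rows)]
    return SimpleNamespace(uncond_masker=SimpleNamespace(style_key=zeros))


if __name__ == "__main__":
    te = fake_text_encoder([[0.5, -1.5], [2.0, 0.0]])
    ve = fake_vector_estimator(2, 2)
    print(wire_shared_style_key(te, ve))
```

Failing loudly on a zero bank is a design choice of this sketch: it would surface the missing-cache case at load time instead of in the audio.
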
Validated on M3 Ultra: sum(|x|) over ve.uncond_masker.style_key goes
from 0.0 to ~3627.34; Whisper transcripts of all 13 voices (10 presets +
3 customs) reach 81-94 % word overlap with the test phrase, with the M3
voice at 94 %.
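
A word-overlap score of this kind could be computed roughly as follows. The exact metric used for the validation is not stated in this commit, so the definition below (the fraction of reference words, counted with multiplicity, that also appear in the transcript) and the function name word_overlap are assumptions for illustration only.

```python
import re
from collections import Counter

_WORD = re.compile(r"[a-z0-9']+")


def word_overlap(reference, hypothesis):
    """Fraction of reference words (with multiplicity) found in hypothesis."""
    ref = Counter(_WORD.findall(reference.lower()))
    hyp = Counter(_WORD.findall(hypothesis.lower()))
    total = sum(ref.values())
    if total == 0:
        return 0.0
    # Counter intersection keeps the minimum count of each shared word.
    matched = sum((ref & hyp).values())
    return matched / total


if __name__ == "__main__":
    print(word_overlap("the quick brown fox", "The quick brown dog"))
    print(word_overlap("hello world", "Merci."))
```

A hallucinated transcript such as 'Merci.' shares no words with an English test phrase under this definition, so it scores zero, which makes the metric a cheap regression check for the collapse described above.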
@@ -441,6 +441,12 @@ class SupertonicMLXPipeline:
         dp = _build(DurationPredictor, "duration_predictor")
         voc = _build(Vocoder, "vocoder")
 
+        # Conditional-style-attention K is the shared style_key bank that lives
+        # in TextEncoder. Wire it into VectorEstimator's uncond_masker so the
+        # CFG (4*cond - 3*uncond) combine has a valid K on the cond branch.
+        # Without this, low-norm voice styles (M3) collapse to near-DC noise.
+        ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
+
         if dtype is not None and dtype != mx.float32:
             cls._cast_all(dp, te, ve, voc, dtype=dtype)
 
@@ -463,6 +469,8 @@ class SupertonicMLXPipeline:
         voc = Vocoder()
         _load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
 
+        ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key
+
         if dtype is not None and dtype != mx.float32:
             cls._cast_all(dp, te, ve, voc, dtype=dtype)
 