diff --git a/src/supertonic_3_mlx/pipeline.py b/src/supertonic_3_mlx/pipeline.py index 058a6d3..3867ec5 100644 --- a/src/supertonic_3_mlx/pipeline.py +++ b/src/supertonic_3_mlx/pipeline.py @@ -441,6 +441,12 @@ class SupertonicMLXPipeline: dp = _build(DurationPredictor, "duration_predictor") voc = _build(Vocoder, "vocoder") + # Conditional-style-attention K is the shared style_key bank that lives + # in TextEncoder. Wire it into VectorEstimator's uncond_masker so the + # CFG (4*cond - 3*uncond) combine has a valid K on the cond branch. + # Without this, low-norm voice styles (M3) collapse to near-DC noise. + ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key + if dtype is not None and dtype != mx.float32: cls._cast_all(dp, te, ve, voc, dtype=dtype) @@ -463,6 +469,8 @@ class SupertonicMLXPipeline: voc = Vocoder() _load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx")) + ve.uncond_masker.style_key = te.tts.ttl.style_encoder.style_token_layer.style_key + if dtype is not None and dtype != mx.float32: cls._cast_all(dp, te, ve, voc, dtype=dtype)