diff --git a/src/supertonic_3_mlx/duration_predictor.py b/src/supertonic_3_mlx/duration_predictor.py index 946f8af..aac2eb1 100644 --- a/src/supertonic_3_mlx/duration_predictor.py +++ b/src/supertonic_3_mlx/duration_predictor.py @@ -247,8 +247,19 @@ class _DPSentenceEncoder(nn.Module): else: mask_ntc = None - x = self.convnext(x, mask_ntc) - x = self.attn_encoder(x, mask_ntc) + x_conv = self.convnext(x, mask_ntc) + x_attn = self.attn_encoder(x_conv, mask_ntc) + + # Residual connection: ONNX graph adds the convnext output back to the + # attn_encoder output before the slot-0 extraction + # (``/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output``). + # Missing this residual is what caused MLX DurationPredictor to return + # ~35 % of the correct duration (T_lat too short → audio gibberish); + # see Whisper validation in tools/whisper_validate.py for the smoking + # gun. Inputs were forwarded with cosine 1.0 through both convnext and + # attn_encoder, but slot-0 of the missing-residual output diverged to + # cosine 0.149 vs ONNX. + x = x_attn + x_conv # Take slot 0 (sentence token output) → (B, 1, 64) sentence_out = x[:, :1, :] # (B, 1, 64)