From d02690dc0bbdd706520e7f81b9e14ae1510e7b8a Mon Sep 17 00:00:00 2001
From: ambassadia
Date: Wed, 20 May 2026 11:14:27 +0200
Subject: [PATCH] fix(critical): missing residual in DurationPredictor.sentence_encoder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The root cause of the audio gibberish. The ONNX graph has a residual
ADD between attn_encoder output and convnext output before the slot-0
extraction that feeds proj_out:

    /sentence_encoder/Add     = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output
    /sentence_encoder/Slice_1 = Add[:, :, 0:1]
    /sentence_encoder/proj_out/Conv = Conv1d(Slice_1, ...)

The MLX port was skipping this residual:

    x = self.convnext(x, mask_ntc)
    x = self.attn_encoder(x, mask_ntc)
    sentence_out = x[:, :1, :]   # ← missing + convnext residual

Effect: the sentence vector fed into the predictor MLP was wrong
→ log duration was systematically 0.95 nats lower than ONNX
→ predicted duration was 35 % of correct length
→ T_lat 3 × too short
→ VE had to compress speech into 1/3 of the proper frames
→ audio unintelligible.

Fix (one line): explicitly hold both x_conv and x_attn outputs and add
them before the slot-0 slice.

Measured impact on the FR test phrase 'Bonjour, je suis une voix
générée par le modèle Supertonic trois en MLX sur Apple Silicon.'
(Whisper-large-v3 word overlap, MLX FP32):

    voice   before-fix   after-fix
    F1         25 %        88 %
    F2         25 %        88 %
    F3         19 %        88 %
    F4          0 %        88 %
    F5         12 %        81 %
    M1         12 %        88 %
    M2         56 %        88 %
    M3          0 %        75 %
    M4          6 %        81 %
    M5          0 %        94 %
    avg        16 %        86 %

The ONNX SDK reference ceiling on the same phrase is 81-88 %, so MLX is
now at parity with the upstream ONNX SDK.

Bisection trail: DurationPredictor MLX output was 35 % of ONNX on a
side-by-side check; sentence_encoder per-stage compare showed cosine
1.0 through text_embedder + convnext + attn_encoder, then a drop to
0.149 at proj_out. The cause was found by tracing the ONNX Slice_1
producer to a missing Add node.
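
The per-stage compare from the bisection trail can be sketched roughly
as below. This is a minimal illustration, not the script used for this
commit: the helper names (`cosine`, `first_divergence`), the 0.99
threshold, the stage label `residual_add` and the toy vectors are all
assumptions made for the example. It walks the stages in graph order
and reports the first one where the port's tensor stops matching the
reference.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def first_divergence(ref_stages, port_stages, threshold=0.99):
    """Return (stage_name, cosine) for the first stage below threshold, else None.

    ref_stages is an ordered list of (name, tensor) pairs from the reference
    graph; port_stages maps the same names to the port's flattened tensors.
    """
    for name, ref in ref_stages:
        sim = cosine(ref, port_stages[name])
        if sim < threshold:
            return name, sim
    return None


# Toy data (hypothetical values): both backends agree on convnext and
# attn_encoder, but the port forwards attn alone where the reference
# forwards attn + convnext.
conv = [1.0, 2.0, -1.0, 0.5]
attn = [0.2, -1.5, 0.3, 1.0]

ref_stages = [
    ("convnext", conv),
    ("attn_encoder", attn),
    ("residual_add", [a + c for a, c in zip(attn, conv)]),
]
port_stages = {
    "convnext": conv,
    "attn_encoder": attn,
    "residual_add": attn,  # residual skipped, as in the pre-fix port
}

result = first_divergence(ref_stages, port_stages)
print(result)
```

On real tensors the same loop would be fed flattened intermediate
outputs dumped from both backends; the first stage reported is where to
look for a missing or extra node.
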

Both the timestep schedule fix (step+1 → step) and the -token
tokenization fix from the previous commit are still needed; this third
fix closes the gap to ONNX SDK quality. Repos can be re-published after
this commit.
---
 src/supertonic_3_mlx/duration_predictor.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/src/supertonic_3_mlx/duration_predictor.py b/src/supertonic_3_mlx/duration_predictor.py
index 946f8af..aac2eb1 100644
--- a/src/supertonic_3_mlx/duration_predictor.py
+++ b/src/supertonic_3_mlx/duration_predictor.py
@@ -247,8 +247,19 @@ class _DPSentenceEncoder(nn.Module):
         else:
             mask_ntc = None
 
-        x = self.convnext(x, mask_ntc)
-        x = self.attn_encoder(x, mask_ntc)
+        x_conv = self.convnext(x, mask_ntc)
+        x_attn = self.attn_encoder(x_conv, mask_ntc)
+
+        # Residual connection: ONNX graph adds the convnext output back to the
+        # attn_encoder output before the slot-0 extraction
+        # (``/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output``).
+        # Missing this residual is what caused MLX DurationPredictor to return
+        # ~35 % of the correct duration (T_lat too short → audio gibberish);
+        # see Whisper validation in tools/whisper_validate.py for the smoking
+        # gun. Inputs were forwarded with cosine 1.0 through both convnext and
+        # attn_encoder, but slot-0 of the missing-residual output diverged to
+        # cosine 0.149 vs ONNX.
+        x = x_attn + x_conv
 
         # Take slot 0 (sentence token output) → (B, 1, 64)
         sentence_out = x[:, :1, :]  # (B, 1, 64)