fix(critical): missing residual in DurationPredictor.sentence_encoder

Root cause of the audio gibberish: the ONNX graph has a residual Add
that sums the attn_encoder output and the convnext output before the
slot-0 extraction that feeds proj_out:

    /sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output
    /sentence_encoder/Slice_1 = Add[:, :, 0:1]
    /sentence_encoder/proj_out/Conv = Conv1d(Slice_1, ...)

The MLX port was skipping this residual:

    x = self.convnext(x, mask_ntc)
    x = self.attn_encoder(x, mask_ntc)
    sentence_out = x[:, :1, :]            # ← missing + convnext residual
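
The two snippets above slice different axes because the layouts differ.
The toy check below is a sketch, not the port's code: it assumes the
ONNX tensors are channels-first (B, C, T), as ONNX Conv requires, and
that the MLX port is (B, T, C), as the name mask_ntc suggests. The helper
names and values are made up for illustration.

```python
# Toy illustration in pure Python for one batch item; values are made up.

def transpose(x):
    """Swap the two axes of a nested list: (C, T) <-> (T, C)."""
    return [list(col) for col in zip(*x)]

def add(a, b):
    """Elementwise sum of two equally shaped nested lists."""
    return [[p + q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]

# C=2 channels, T=3 time steps, channels-first as on the ONNX side.
conv_ct = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
attn_ct = [[0.5, 0.5, 0.5], [1.5, 1.5, 1.5]]

# ONNX side: Add, then Slice_1 = Add[:, :, 0:1] (first time step, all channels).
onnx_slot0 = [row[0:1] for row in add(attn_ct, conv_ct)]  # shape (C, 1)

# MLX side: same tensors as (T, C), residual add, then x[:, :1, :].
conv_tc, attn_tc = transpose(conv_ct), transpose(attn_ct)
mlx_slot0 = add(attn_tc, conv_tc)[:1]                     # shape (1, C)

# Pre-fix behavior: slot 0 of the attn_encoder output alone.
missing_residual = attn_tc[:1]

print(transpose(onnx_slot0) == mlx_slot0)   # layouts agree once transposed
print(missing_residual == mlx_slot0)        # pre-fix slice differs
```

With the residual in place the two slices hold the same values up to a
transpose; without it, slot 0 carries only the attn_encoder term.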

Effect: the sentence vector fed into the predictor MLP was wrong → log
duration was systematically ~0.95 nats lower than ONNX → predicted
duration was ~35 % of the correct length → T_lat ~3× too short → VE had
to compress speech into ~1/3 of the proper frames → audio unintelligible.
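
The link between a log-duration offset and the duration ratio can be
checked with a few lines of arithmetic. This is a sketch that assumes
the predictor emits a natural-log duration (hence "nats") and that the
duration is recovered with exp(); the variable names are illustrative.

```python
import math

# Offset of the MLX log duration relative to ONNX, in nats.
log_offset_nats = -0.95

# A constant offset in log space is a constant factor in linear space.
ratio = math.exp(log_offset_nats)

# Factor by which the latent length comes out too short.
shortfall = 1.0 / ratio

print(f"duration ratio ~ {ratio:.2f}, shortfall ~ {shortfall:.1f}x")
```

An offset of 0.95 nats gives a factor of about 0.39, which is in the
same range as the rounded ~35 % and ~3× figures quoted above.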

Fix: keep the convnext and attn_encoder outputs in separate variables
(x_conv, x_attn) and add them before the slot-0 slice. The residual
itself is a single added line.

Measured impact on the FR test phrase
'Bonjour, je suis une voix générée par le modèle Supertonic trois en MLX
sur Apple Silicon.' (Whisper-large-v3 word overlap, MLX FP32):

    voice  before-fix  after-fix
    F1     25 %        88 %
    F2     25 %        88 %
    F3     19 %        88 %
    F4      0 %        88 %
    F5     12 %        81 %
    M1     12 %        88 %
    M2     56 %        88 %
    M3      0 %        75 %
    M4      6 %        81 %
    M5      0 %        94 %
    avg    16 %        86 %

The ONNX SDK reference ceiling on the same phrase is 81-88 %, so MLX is
now at parity with the upstream ONNX SDK.
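
For readers who want to reproduce a score of this kind, here is a
minimal sketch of a word-overlap metric. It is an assumption, not the
code in tools/whisper_validate.py, which is not shown here: it counts
the fraction of reference words that also appear in the transcript. The
reported percentages are all multiples of 1/16, which is consistent
with a per-word count over the 16-word phrase. The transcript below is
made up.

```python
import re
from collections import Counter

def words(text):
    """Lowercase and split into word tokens, dropping punctuation."""
    return re.findall(r"\w+", text.lower())

def word_overlap(reference, hypothesis):
    """Fraction of reference words matched in the hypothesis (multiset)."""
    ref, hyp = Counter(words(reference)), Counter(words(hypothesis))
    matched = sum((ref & hyp).values())
    return matched / sum(ref.values())

REFERENCE = ("Bonjour, je suis une voix générée par le modèle Supertonic "
             "trois en MLX sur Apple Silicon.")

# Made-up transcript in which only "Supertonic" is misrecognized.
hypothesis = ("bonjour je suis une voix générée par le modèle super tonique "
              "trois en MLX sur Apple Silicon")

print(len(words(REFERENCE)))                  # number of reference words
print(word_overlap(REFERENCE, hypothesis))    # 15 of 16 words matched
```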

Bisection trail: a side-by-side check showed the MLX DurationPredictor
output at 35 % of ONNX. A per-stage comparison inside sentence_encoder
showed cosine similarity 1.0 through text_embedder, convnext, and
attn_encoder, then a drop to 0.149 at proj_out. Tracing the producer of
the ONNX Slice_1 node revealed the Add node that the port was missing.
Both the timestep schedule fix (step+1 → step) and the <lang>-token
tokenization fix from the previous commit are still needed; this third
fix closes the gap to ONNX SDK quality.
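
The per-stage comparison can be sketched as follows. This is a minimal
illustration under assumptions, not the tooling used for this commit:
it assumes each stage's activations have been dumped from both runtimes
and flattened to lists of floats, keyed by stage name in execution
order. The function names, threshold, and values are hypothetical.

```python
import math

def cosine(a, b):
    """Cosine similarity of two equally long, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def first_divergence(ref_stages, port_stages, threshold=0.999):
    """Return (stage, cosine) for the first stage below threshold, else None."""
    for name, ref in ref_stages.items():  # dicts keep insertion order
        c = cosine(ref, port_stages[name])
        if c < threshold:
            return name, c
    return None

# Made-up activations: identical until proj_out, where they diverge.
onnx = {
    "text_embedder": [1.0, 2.0],
    "convnext": [0.5, -1.0],
    "attn_encoder": [2.0, 1.0],
    "proj_out": [1.0, 0.0],
}
mlx = dict(onnx, proj_out=[0.0, 1.0])

print(first_divergence(onnx, mlx))
```

The first stage that falls below the threshold marks where to start
reading the reference graph, here the inputs of proj_out.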

Repos can be re-published after this commit.
ambassadia
2026-05-20 11:14:27 +02:00
parent ba1a5f5f31
commit d02690dc0b


@@ -247,8 +247,19 @@ class _DPSentenceEncoder(nn.Module):
         else:
             mask_ntc = None
-        x = self.convnext(x, mask_ntc)
-        x = self.attn_encoder(x, mask_ntc)
+        x_conv = self.convnext(x, mask_ntc)
+        x_attn = self.attn_encoder(x_conv, mask_ntc)
+        # Residual connection: ONNX graph adds the convnext output back to the
+        # attn_encoder output before the slot-0 extraction
+        # (``/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output``).
+        # Missing this residual is what caused MLX DurationPredictor to return
+        # ~35 % of the correct duration (T_lat too short → audio gibberish);
+        # see Whisper validation in tools/whisper_validate.py for the smoking
+        # gun. Inputs were forwarded with cosine 1.0 through both convnext and
+        # attn_encoder, but slot-0 of the missing-residual output diverged to
+        # cosine 0.149 vs ONNX.
+        x = x_attn + x_conv
         # Take slot 0 (sentence token output) → (B, 1, 64)
         sentence_out = x[:, :1, :]  # (B, 1, 64)