fix(critical): missing residual in DurationPredictor.sentence_encoder
The root-cause of the audio gibberish. The ONNX graph has a residual ADD
between attn_encoder output and convnext output before the slot-0
extraction that feeds proj_out:
/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output
/sentence_encoder/Slice_1 = Add[:, :, 0:1]
/sentence_encoder/proj_out/Conv = Conv1d(Slice_1, ...)
The MLX port was skipping this residual:
x = self.convnext(x, mask_ntc)
x = self.attn_encoder(x, mask_ntc)
sentence_out = x[:, :1, :] # ← missing + convnext residual
Effect: the sentence vector fed into the predictor MLP was wrong → log
duration was systematically 0.95 nats lower than ONNX → predicted
duration was 35 % of correct length → T_lat 3 × too short → VE had to
compress speech into 1/3 of the proper frames → audio unintelligible.
Fix (one line): explicitly hold both x_conv and x_attn outputs and add
them before the slot-0 slice.
Measured impact on the FR test phrase
'Bonjour, je suis une voix générée par le modèle Supertonic trois en MLX
sur Apple Silicon.' (Whisper-large-v3 word overlap, MLX FP32):
voice before-fix after-fix
F1 25 % 88 %
F2 25 % 88 %
F3 19 % 88 %
F4 0 % 88 %
F5 12 % 81 %
M1 12 % 88 %
M2 56 % 88 %
M3 0 % 75 %
M4 6 % 81 %
M5 0 % 94 %
avg 16 % 86 %
The ONNX SDK reference ceiling on the same phrase is 81-88 %, so MLX is
now AT parity with the upstream ONNX SDK.
Bisection trail: DurationPredictor MLX output was 35 % of ONNX on a
side-by-side check; sentence_encoder per-stage compare showed cosine 1.0
through text_embedder + convnext + attn_encoder, then a drop to 0.149 at
proj_out — caught by tracing the ONNX Slice_1 producer to a missing Add
node. Both the timestep schedule fix (step+1 → step) and the
<lang>-token tokenization fix from the previous commit are still needed;
this third fix closes the gap to ONNX SDK quality.
Repos can be re-published after this commit.
This commit is contained in:
@@ -247,8 +247,19 @@ class _DPSentenceEncoder(nn.Module):
|
||||
else:
|
||||
mask_ntc = None
|
||||
|
||||
x = self.convnext(x, mask_ntc)
|
||||
x = self.attn_encoder(x, mask_ntc)
|
||||
x_conv = self.convnext(x, mask_ntc)
|
||||
x_attn = self.attn_encoder(x_conv, mask_ntc)
|
||||
|
||||
# Residual connection: ONNX graph adds the convnext output back to the
|
||||
# attn_encoder output before the slot-0 extraction
|
||||
# (``/sentence_encoder/Add = attn_encoder/Mul_2_output + convnext/convnext.5/Mul_3_output``).
|
||||
# Missing this residual is what caused MLX DurationPredictor to return
|
||||
# ~35 % of the correct duration (T_lat too short → audio gibberish);
|
||||
# see Whisper validation in tools/whisper_validate.py for the smoking
|
||||
# gun. Inputs were forwarded with cosine 1.0 through both convnext and
|
||||
# attn_encoder, but slot-0 of the missing-residual output diverged to
|
||||
# cosine 0.149 vs ONNX.
|
||||
x = x_attn + x_conv
|
||||
|
||||
# Take slot 0 (sentence token output) → (B, 1, 64)
|
||||
sentence_out = x[:, :1, :] # (B, 1, 64)
|
||||
|
||||
Reference in New Issue
Block a user