diff --git a/src/supertonic_3_mlx/pipeline.py b/src/supertonic_3_mlx/pipeline.py
index 2af0662..5da6a62 100644
--- a/src/supertonic_3_mlx/pipeline.py
+++ b/src/supertonic_3_mlx/pipeline.py
@@ -25,7 +25,7 @@ Flow:
Public API:
- pipe = SupertonicMLXPipeline.from_pretrained("ambassadia/supertonic-3-mlx")
+ pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
wav = pipe.generate("Hello world", voice="F1", lang="en")
import soundfile as sf
sf.write("out.wav", wav, pipe.sample_rate)
@@ -214,15 +214,49 @@ def _load_into(model, weights: dict) -> int:
# ── Tokenization ────────────────────────────────────────────────────
-def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
- """Encode a text string into character IDs.
+_ENDING_PUNCT = ".!?,;:'\")]}»›"  # text already ending in one of these gets no auto-appended "."
- The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
- the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
- text with no special language tokens — the ONNX SDK uses language tags but
- our pipeline currently runs unconditioned on language for the first WAV
- emission (parity validation happens after).
+
+def _preprocess_text(text: str, lang: str = "en") -> str:
+    """Mirror the SDK's UnicodeProcessor._preprocess_text contract.
+
+    Supertonic 3 is multilingual; the model is trained with utterances
+    wrapped in ``<lang>...</lang>`` language tokens (Supertone's
+    ``UnicodeProcessor._add_language_token``). Bypassing this wrapping was
+    the secondary bug that compounded with the off-by-one Euler schedule to
+    produce structureless audio (verified by ONNX-only ablation in
+    ``debug/supertonic3_schedule_ablation.py``).
+
+    Minimum viable port of the SDK's pipeline:
+      1. NFKD unicode normalisation
+      2. Whitespace collapse + strip
+      3. Trailing period if the string doesn't end with punctuation
+      4. Language token wrap ``<lang>text</lang>``
+
+    The SDK additionally performs emoji removal, symbol normalisation,
+    abbreviation expansion, and quote deduplication — those are quality
+    polish and can be ported later; they are not load-bearing for the
+    primary fix.
+
+    Args:
+        text: Raw utterance to normalise.
+        lang: Language code used for the opening and closing tokens. Passing
+            ``None`` skips the wrapping step and returns the normalised text.
+
+    Returns:
+        The normalised text, wrapped in language tokens unless ``lang`` is
+        ``None``. An empty or whitespace-only input gets no trailing period
+        but is still wrapped.
     """
+    import re
+    import unicodedata
+
+    text = unicodedata.normalize("NFKD", text)
+    text = re.sub(r"\s+", " ", text).strip()
+    if text and text[-1] not in _ENDING_PUNCT:
+        text += "."
+    if lang is not None:
+        # The closing token must carry its "</" prefix; without it the wrap
+        # emits "<en>...en>", which is not the SDK's closing language token.
+        text = f"<{lang}>{text}</{lang}>"
+    return text
+
+
+def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
+ """Encode a text string into character IDs via the SDK-compatible pipeline.
+
+ ``indexer`` is a flat list of size 65536; ``indexer[ord(c)]`` gives the
+ token ID for character ``c`` (-1 = unknown). The text is first
+ preprocessed via :func:`_preprocess_text` so the encoding matches what
+ Supertonic 3 was trained on (NFKD-normalised + ``<lang>``-wrapped).
+ """
+ text = _preprocess_text(text, lang=lang)
ids = []
for c in text:
cp = ord(c)
@@ -231,7 +265,6 @@ def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
if tok >= 0:
ids.append(tok)
if not ids:
- # fallback to a single space token to avoid empty input
ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
return np.asarray(ids, dtype=np.int32)
@@ -522,11 +555,18 @@ class SupertonicMLXPipeline:
for k, v in style_kv:
kv_flat.extend([k, v])
- # Euler with CFG — 5 steps by default
+ # Euler with CFG — 5 steps by default.
+ # NOTE: ONNX SDK passes ``current_step = 0..N-1`` and computes
+ # ``t_norm = current_step / total_step`` → schedule = [0.0, 0.2,
+ # 0.4, 0.6, 0.8]. Previously we were passing ``step + 1`` which
+ # shifted the schedule to [0.2, 0.4, 0.6, 0.8, 1.0]; the flow-matching
+ # model is trained on the SDK schedule and the off-by-one collapses
+ # the audio to structureless noise (verified by ONNX-only ablation
+ # in debug/supertonic3_schedule_ablation.py — wav cosine 0.0037).
x = noise
total_step = mx.array([float(n_steps)], dtype=self.dtype)
for step in range(n_steps):
- current_step = mx.array([float(step + 1)], dtype=self.dtype)
+ current_step = mx.array([float(step)], dtype=self.dtype)
t_norm = current_step / total_step
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
x = self._cached_step_compiled(