diff --git a/src/supertonic_3_mlx/pipeline.py b/src/supertonic_3_mlx/pipeline.py index 2af0662..5da6a62 100644 --- a/src/supertonic_3_mlx/pipeline.py +++ b/src/supertonic_3_mlx/pipeline.py @@ -25,7 +25,7 @@ Flow: Public API: - pipe = SupertonicMLXPipeline.from_pretrained("ambassadia/supertonic-3-mlx") + pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model") wav = pipe.generate("Hello world", voice="F1", lang="en") import soundfile as sf sf.write("out.wav", wav, pipe.sample_rate) @@ -214,15 +214,49 @@ def _load_into(model, weights: dict) -> int: # ── Tokenization ──────────────────────────────────────────────────── -def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray: - """Encode a text string into character IDs. +_ENDING_PUNCT = ".!?,;:'\")]}»›" - The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives - the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the - text with no special language tokens — the ONNX SDK uses language tags but - our pipeline currently runs unconditioned on language for the first WAV - emission (parity validation happens after). + +def _preprocess_text(text: str, lang: str = "en") -> str: + """Mirror the SDK's UnicodeProcessor._preprocess_text contract. + + Supertonic 3 is multilingual; the model is trained with utterances + wrapped in ``...`` language tokens (Supertone's + ``UnicodeProcessor._add_language_token``). Bypassing this wrapping was + the secondary bug that compounded with the off-by-one Euler schedule to + produce structureless audio (verified by ONNX-only ablation in + ``debug/supertonic3_schedule_ablation.py``). + + Minimum viable port of the SDK's pipeline: + 1. NFKD unicode normalisation + 2. Whitespace collapse + strip + 3. Trailing period if the string doesn't end with punctuation + 4. Language token wrap ``text`` + + The SDK additionally performs emoji removal, symbol normalisation, + abbreviation expansion, and quote deduplication — those are quality + polish and can be ported later; they are not load-bearing for the + primary fix. """ + import unicodedata, re + text = unicodedata.normalize("NFKD", text) + text = re.sub(r"\s+", " ", text).strip() + if text and text[-1] not in _ENDING_PUNCT: + text += "." + if lang is not None: + text = f"<{lang}>{text}" + return text + + +def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray: + """Encode a text string into character IDs via the SDK-compatible pipeline. + + ``indexer`` is a flat list of size 65536; ``indexer[ord(c)]`` gives the + token ID for character ``c`` (-1 = unknown). The text is first + preprocessed via :func:`_preprocess_text` so the encoding matches what + Supertonic 3 was trained on (NFKD-normalised + ````-wrapped). + """ + text = _preprocess_text(text, lang=lang) ids = [] for c in text: cp = ord(c) @@ -231,7 +265,6 @@ def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray: if tok >= 0: ids.append(tok) if not ids: - # fallback to a single space token to avoid empty input ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0] return np.asarray(ids, dtype=np.int32) @@ -522,11 +555,18 @@ class SupertonicMLXPipeline: for k, v in style_kv: kv_flat.extend([k, v]) - # Euler with CFG — 5 steps by default + # Euler with CFG — 5 steps by default. + # NOTE: ONNX SDK passes ``current_step = 0..N-1`` and computes + # ``t_norm = current_step / total_step`` → schedule = [0.0, 0.2, + # 0.4, 0.6, 0.8]. Previously we were passing ``step + 1`` which + # shifted the schedule to [0.2, 0.4, 0.6, 0.8, 1.0]; the flow-matching + # model is trained on the SDK schedule and the off-by-one collapses + # the audio to structureless noise (verified by ONNX-only ablation + # in debug/supertonic3_schedule_ablation.py — wav cosine 0.0037). x = noise total_step = mx.array([float(n_steps)], dtype=self.dtype) for step in range(n_steps): - current_step = mx.array([float(step + 1)], dtype=self.dtype) + current_step = mx.array([float(step)], dtype=self.dtype) t_norm = current_step / total_step t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0) x = self._cached_step_compiled(