v0.1.0 — initial release

MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the
full flow-matching + classifier-free-guidance pipeline at ~x100 realtime
on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and
cosine 0.98 vs the upstream ONNX Runtime reference.

Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx
and auto-downloaded on first use; this repository ships the port code,
the model card, audio samples, and a zero-config setup_and_test.sh.

Install:
    pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git

Quick test:
    git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
    cd supertonic-3-mlx && ./setup_and_test.sh

Licenses (dual): model weights = BigScience Open RAIL-M (Section 4
propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
transcrilive
2026-05-20 09:17:05 +02:00
commit 12dbf4a821
36 changed files with 3812 additions and 0 deletions

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
.venv/
__pycache__/
*.pyc
hello.wav
*.egg-info/
build/
dist/

209
LICENSE Normal file
View File

@@ -0,0 +1,209 @@
BigScience Open RAIL-M License
dated August 18, 2022
Section I: PREAMBLE
This Open RAIL-M License was created by BigScience, a collaborative open innovation project aimed at
the responsible development and use of large multilingual datasets and Large Language Models
(“LLMs”). While a similar license was originally designed for the BLOOM model, we decided to adapt it
and create this license in order to propose a general open and responsible license applicable to other
machine learning based AI models (e.g. multimodal generative models).
In short, this license strives for both the open and responsible downstream use of the accompanying
model. When it comes to the open character, we took inspiration from open source permissive licenses
regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based
restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be
able to enforce the license in case potential misuses of the Model may occur. Even though downstream
derivative versions of the model could be released under different licensing terms, the latter will always
have to include - at minimum - the same use-based restrictions as the ones in the original license (this
license).
The development and use of artificial intelligence (“AI”), does not come without concerns. The world has
witnessed how AI techniques may, in some instances, become risky for the public in general. These risks
come in many forms, from racial discrimination to the misuse of sensitive information.
BigScience believes in the intersection between open and responsible AI development; thus, this License
aims to strike a balance between both in order to enable responsible open-science in the field of AI.
This License governs the use of the model (and its derivatives) and is informed by the model card
associated with the model.
NOW THEREFORE, You and Licensor agree as follows:
1. Definitions
(a) "License" means the terms and conditions for use, reproduction, and Distribution as defined in
this document.
(b) “Data” means a collection of information and/or content extracted from the dataset used with the
Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under
this License.
(c)“Output” means the results of operating a Model as embodied in informational content resulting
therefrom.
(d)“Model” means any accompanying machine-learning based assemblies (including checkpoints),
consisting of learnt weights, parameters (including optimizer states), corresponding to the model
architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or
in part on the Data, using the Complementary Material.
(e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any
other model which is created or initialized by transfer of patterns of the weights, parameters,
activations or output of the Model, to the other model, in order to cause the other model to perform
similarly to the Model, including - but not limited to - distillation methods entailing the use of
intermediate data representations or methods based on the generation of synthetic data by the Model
for training the other model.
(f)“Complementary Material” means the accompanying source code and scripts used to define,
run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if
any. This includes any accompanying documentation, tutorials, examples, etc, if any.
(g) “Distribution” means any transmission, reproduction, publication or other sharing of the Model or
Derivatives of the Model to a third party, including providing the Model as a hosted service made
available by electronic or other remote means - e.g. API-based or web access.
(h) “Licensor” means the copyright owner or entity authorized by the copyright owner that is
granting the License, including the persons or entities that may have rights in the Model and/or
distributing the Model.
(i) "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this
License and/or making use of the Model for whichever purpose and in any field of use, including
usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
(j) “Third Parties” means individuals or legal entities that are not under common control with
Licensor or You.
(k) "Contribution" means any work of authorship, including the original version of the Model and
any modifications or additions to that Model or Derivatives of the Model thereof, that is
intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an
individual or Legal Entity authorized to submit on behalf of the copyright owner. For the
purposes of this definition,
“submitted” means any form of electronic, verbal, or written
communication sent to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and issue tracking
systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and
improving the Model, but excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
(l) "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a
Contribution has been received by Licensor and subsequently incorporated within the Model.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary
Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the
Complementary Material, the Model, and Derivatives of the Model.
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as
applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer
to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such
license applies only to those patent claims licensable by such Contributor that are necessarily infringed by
their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such
Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim
or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
incorporated within the Model and/or Complementary Material constitutes direct or contributory patent
infringement, then any patent licenses granted to You under this License for the Model and/or Work shall
terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g.
software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof
in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision
by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the
Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to,
that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply
to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this
License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices
that do not pertain to any part of the Model, Derivatives of the Model.
You may add Your own copyright statement to Your modifications and may provide additional or
different license terms and conditions - respecting paragraph 4.a.
- for use, reproduction, or Distribution
of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use,
reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions.
Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You
may use the Model subject to this License, including only for lawful purposes and in accordance with the
License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or
reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model
to comply with the terms of this paragraph (paragraph 5).
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You
generate using the Model. You are accountable for the Output you generate and its subsequent uses. No
use of the output can contravene any provision as stated in the License.
Section IV: OTHER PROVISIONS
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the
right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model
through electronic means, or modify the Output of the Model based on updates. You shall undertake
reasonable efforts to use the latest version of the Model.
8. Trademarks and related. Nothing in this License permits You to make use of Licensors trademarks,
trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the
parties; and any rights not expressly granted herein are reserved by the Licensors.
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS
IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT,
MERCHANTABILITY , or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for
determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the
Complementary Material and assume any risks associated with Your exercise of permissions under this
License.
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence),
contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or
agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect,
special, incidental, or consequential damages of any character arising as a result of this License or out of
the use or inability to use the Model and the Complementary Material (including but not limited to
damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other
commercial damages or losses), even if such Contributor has been advised of the possibility of such
damages.
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the
Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance
of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License.
However, in accepting such obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and
hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor
by reason of your accepting any such warranty or additional liability.
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining
provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
(a) In any way that violates any applicable national, federal, state, local or international law
or regulation;
(b) For the purpose of exploiting, harming or attempting to exploit or harm minors in any
way;
(c) To generate or disseminate verifiably false information and/or content with the purpose of
harming others;
(d) To generate or disseminate personal identifiable information that can be used to harm an
individual;
(e) To generate or disseminate information and/or content (e.g. images, code, posts, articles),
and place the information and/or content in any context (e.g. bot generating tweets)
without expressly and intelligibly disclaiming that the information and/or content is
machine generated;
(f) To defame, disparage or otherwise harass others;
(g) To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;
(h) For fully automated decision making that adversely impacts an individuals legal rights or
otherwise creates or modifies a binding, enforceable obligation;
(i) For any use intended to or which has the effect of discriminating against or harming
individuals or groups based on online or offline social behavior or known or predicted
personal or personality characteristics;
(j) To exploit any of the vulnerabilities of a specific group of persons based on their age,
social, physical or mental characteristics, in order to materially distort the behavior of a
person pertaining to that group in a manner that causes or is likely to cause that person or
another person physical or psychological harm;
(k) For any use intended to or which has the effect of discriminating against individuals or
groups based on legally protected characteristics or categories;
(l) To provide medical advice and medical results interpretation;
(m) To generate or disseminate information for the purpose to be used for administration of
justice, law enforcement, immigration or asylum processes, such as predicting an
individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal
relationships between assertions made in documents, indiscriminate and
arbitrarily-targeted use).

202
LICENSE-CODE Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

39
NOTICE Normal file
View File

@@ -0,0 +1,39 @@
supertonic-3-mlx
================
This release is a derivative of the upstream Supertone Supertonic 3
text-to-speech model and consists of two artefact classes governed by
two different licenses:
1. The model weights (under ./weights/*.safetensors) are released under
the BigScience Open RAIL-M License. The full text is in ./LICENSE and
was copied verbatim from
https://huggingface.co/Supertone/supertonic-3/blob/main/LICENSE
The Attachment A use restrictions (Section 5 + Attachment A clauses
(a)(m)) apply to all downstream use of the model and of any output
generated by the model.
2. The MLX port code (under ./src/supertonic_3_mlx/) is released under
the Apache License, Version 2.0. The full text is in ./LICENSE-CODE.
Attribution and modifications statement (BigScience Open RAIL-M Section 4.c):
Copyright (c) 2026 Supertone Inc. — original model weights and reference
Python/ONNX implementation. Distributed at
https://huggingface.co/Supertone/supertonic-3
Copyright (c) 2026 Olivier Dupont — MLX-native port code, weight format
conversion (ONNX → safetensors via the 3-stage extractor in
``src/supertonic_3_mlx/pipeline.py:_convert_onnx``), and pipeline
optimisations (``mx.compile`` of the CFG Euler loop, cross-attention
K/V cache shared across the 5 Euler steps). Distributed at
https://huggingface.co/ambassadia/supertonic-3-mlx
The MLX port does not modify the model's learned parameters in any
semantic sense — the only weight-level transformation is a tensor-shape
re-layout to match the MLX memory model (e.g. depthwise Conv1d
``(C, 1, K)`` → ``(C, K, 1)``). Bit-identical audio output to the
upstream ONNX Runtime reference is preserved up to FP32 accumulation
noise (cosine ≥ 0.98 on the full pipeline, cosine = 1.00 on the vocoder).
No use of the Supertone trademarks, logos, or trade dress is asserted or
permitted by this release (BigScience Open RAIL-M Section 8).

260
README.md Normal file
View File

@@ -0,0 +1,260 @@
---
license: openrail
license_link: LICENSE
language:
- en
- fr
- de
- es
- it
- pt
- ja
- ko
- zh
- ru
- pl
- nl
- tr
- ar
- hi
- vi
- th
- id
- cs
- ro
- hu
- el
- da
- sv
- fi
- no
- he
- uk
- bg
- hr
- sk
pipeline_tag: text-to-speech
tags:
- mlx
- apple-silicon
- tts
- text-to-speech
- speech-synthesis
- supertonic
- multilingual
- flow-matching
library_name: supertonic-3-mlx
base_model: Supertone/supertonic-3
inference: false
---
# Supertonic 3 — MLX-native
**31-language text-to-speech, ~x100 realtime on Apple Silicon.**
Native MLX port of [Supertone/supertonic-3](https://huggingface.co/Supertone/supertonic-3),
runs the full flow-matching + classifier-free-guidance pipeline (DurationPredictor →
TextEncoder → 24-block VectorEstimator (5 Euler steps) → 10-block Vocos vocoder)
without ONNX, CoreML or any C++ runtime — only MLX + NumPy.
## Install
The package isn't on PyPI yet — install directly from this gitea source
repository (or from the local checkout):
```bash
pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
```
Runtime dependencies are just `mlx`, `numpy`, and `huggingface_hub` (the
last for the one-line weight download). On first use the ~ 400 MB weight
bundle is downloaded from
[`ambassadia/supertonic-3-mlx`](https://huggingface.co/ambassadia/supertonic-3-mlx)
into your Hugging Face cache.
### One-shot quickstart + sanity test
A zero-config end-to-end test script ships with the repo. Clone the repo,
run the script, and it will create a fresh venv, install everything,
version-check MLX (with an optional auto-upgrade), download the weights
and synthesise an utterance into `hello.wav`:
```bash
git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
cd supertonic-3-mlx
./setup_and_test.sh # en F1, default text
./setup_and_test.sh fr F2 "Bonjour." # custom lang / voice / text
```
Re-runs reuse the venv and the cached weights — second invocation is
~ 20 ms warm load + ~ 30 ms per generate.
## Quickstart (after install)
```python
from supertonic_3_mlx import Pipeline
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
# wav is a 1-D numpy.float32 array at 44.1 kHz
import soundfile as sf
sf.write("hello.wav", wav, pipe.sample_rate)
```
## Audio samples
Six languages, mix of male / female voices, mix of short and long utterances —
all generated by the MLX pipeline at the wall times reported below.
<audio controls src="samples/en_F1_short.wav"></audio> &nbsp; **EN · F1 · 2.79 s**
"Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time."
<audio controls src="samples/en_M1_long.wav"></audio> &nbsp; **EN · M1 · 3.90 s**
"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells."
<audio controls src="samples/fr_F2.wav"></audio> &nbsp; **FR · F2 · 3.41 s**
"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4."
<audio controls src="samples/de_M2.wav"></audio> &nbsp; **DE · M2 · 3.69 s**
"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX."
<audio controls src="samples/ja_F3.wav"></audio> &nbsp; **JA · F3 · 1.46 s**
"こんにちは。これはアップルシリコン上でMLXを使ったテストです。"
<audio controls src="samples/es_M3.wav"></audio> &nbsp; **ES · M3 · 2.86 s**
"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon."
## Benchmarks (Apple M4, FP32, median of 3)
| Sample | Duration | MLX wall | RTF | ONNX SDK | Speedup |
|-----------------|---------:|----------:|----------:|---------:|--------:|
| EN · F1 · short | 2.79 s | 36.6 ms | **x76** | 1005 ms | **28 ×**|
| EN · M1 · long | 3.90 s | 38.4 ms | **x102** | 1356 ms | **35 ×**|
| FR · F2 | 3.41 s | 37.9 ms | **x90** | 1196 ms | **32 ×**|
| DE · M2 | 3.69 s | 38.1 ms | **x97** | 1314 ms | **35 ×**|
| JA · F3 | 1.46 s | 32.1 ms | **x46** | 848 ms | **26 ×**|
| ES · M3 | 2.86 s | 37.0 ms | **x77** | 1002 ms | **27 ×**|
Raw numbers are in [`bench_results.csv`](bench_results.csv) (regenerable via
the development monorepo at
[`gitea.tavportal.com/olivier/MLX_CONVERTOR`](https://gitea.tavportal.com/olivier/MLX_CONVERTOR);
this repository ships the consolidated release artefacts only).
Reference comparison: the CoreML build of the same model on the same hardware
runs at ~x27 realtime. The MLX port is **~2-4× faster** end-to-end while
remaining bit-identical to the ONNX Runtime reference on the vocoder
(cosine 1.00) and at cosine ≥ 0.98 on the full estimator output.
## Voices
10 preset voices — five female (`F1``F5`) and five male (`M1``M5`). The
`voice_styles/` directory contains both `style_ttl` (50×256 latent style for
the audio path) and `style_dp` (8×16 style for the duration head) for each
voice. Pass the voice name as the `voice=` kwarg to `Pipeline.generate`.
## Languages
31 languages supported. Pass the ISO 639-1 code as the `lang=` kwarg:
`en` `fr` `de` `es` `it` `pt` `ja` `ko` `zh` `ru` `pl` `nl` `tr` `ar` `hi`
`vi` `th` `id` `cs` `ro` `hu` `el` `da` `sv` `fi` `no` `he` `uk` `bg` `hr` `sk`.
## Architecture (short)
Four sub-models, all in `weights/*.safetensors`:
| Sub-model | Role | Params | Size |
|----------------------|-------------------------------------|--------|---------|
| `vector_estimator` | 24-block CFG flow-matching velocity | ~64 M | 256 MB |
| `text_encoder` | Character → 256-D text embedding | ~9 M | 36 MB |
| `duration_predictor` | Text → seconds | ~1 M | 3.5 MB |
| `vocoder` | Latent (B,144,T) → 44.1 kHz wav | ~25 M | 101 MB |
The pipeline runs **exactly 5 Euler steps** with classifier-free guidance
(`4×cond 3×uncond`). This schedule is trained-in: reducing the step count
or disabling CFG produces an essentially uncorrelated waveform (verified
empirically — see the `bench_n_steps.py` script in the source repo).
## Loading from a local snapshot
Three layouts are auto-detected by `Pipeline.from_pretrained`:
1. **Hugging Face repo id** (e.g. `"ambassadia/supertonic-3-mlx"`) — auto-download
2. **Local path containing `weights/`** (this layout) — fastest cold-load
3. **Local path containing `onnx/`** (upstream snapshot) — converts at load time
## License
This release combines two artefact classes under two distinct licenses:
- **Model weights** (`weights/*.safetensors`) — **BigScience Open RAIL-M**.
See [`LICENSE`](LICENSE) for the full text. The Attachment A use
restrictions are reproduced below and apply to all downstream use of the
model and of generated audio.
- **Port code** (`src/supertonic_3_mlx/`) — **Apache License 2.0**. See
[`LICENSE-CODE`](LICENSE-CODE).
See [`NOTICE`](NOTICE) for the modifications statement and the upstream
attribution.
### OpenRAIL-M Attachment A — use restrictions
You agree not to use the model or derivatives:
(a) In any way that violates any applicable national, federal, state, local or
international law or regulation.
(b) For the purpose of exploiting, harming or attempting to exploit or harm
minors in any way.
(c) To generate or disseminate verifiably false information and/or content
with the purpose of harming others.
(d) To generate or disseminate personal identifiable information that can be
used to harm an individual.
(e) To generate or disseminate information and/or content (e.g. images, code,
posts, articles), and place the information and/or content in any context
(e.g. bot generating tweets) **without expressly and intelligibly disclaiming
that the information and/or content is machine generated**.
(f) To defame, disparage or otherwise harass others.
(g) To impersonate or attempt to impersonate (e.g. **deepfakes**) others
without their consent.
(h) For fully automated decision making that adversely impacts an individual's
legal rights or otherwise creates or modifies a binding, enforceable obligation.
(i) For any use intended to or which has the effect of discriminating against
or harming individuals or groups based on online or offline social behavior or
known or predicted personal or personality characteristics.
(j) To exploit any of the vulnerabilities of a specific group of persons based
on their age, social, physical or mental characteristics, in order to materially
distort the behavior of a person pertaining to that group in a manner that
causes or is likely to cause that person or another person physical or
psychological harm.
(k) For any use intended to or which has the effect of discriminating against
individuals or groups based on legally protected characteristics or categories.
(l) **To provide medical advice and medical results interpretation.**
(m) To generate or disseminate information for the purpose to be used for
administration of justice, law enforcement, immigration or asylum processes,
such as predicting an individual will commit fraud/crime commitment.
## Citation
```bibtex
@misc{supertonic3-mlx,
title = {Supertonic 3 MLX: native Apple Silicon port of Supertone's multilingual TTS},
author = {Dupont, Olivier},
year = {2026},
url = {https://huggingface.co/ambassadia/supertonic-3-mlx},
note = {Derivative of Supertone/supertonic-3 (https://huggingface.co/Supertone/supertonic-3)}
}
```
Please also cite the upstream Supertone Supertonic 3 model when using this
port.

7
bench_results.csv Normal file
View File

@@ -0,0 +1,7 @@
filename,language,voice,text,duration_s,mlx_ms_median,rtf_mlx,onnx_ms_median,rtf_onnx,speedup_mlx_over_onnx
samples/en_F1_short.wav,en,F1,Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time.,2.786,36.6,76.2,1004.7,2.8,27.5
samples/en_M1_long.wav,en,M1,"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells.",3.901,38.4,101.7,1356.0,2.9,35.3
samples/fr_F2.wav,fr,F2,"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4.",3.413,37.9,90.1,1195.6,2.9,31.6
samples/de_M2.wav,de,M2,"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX.",3.692,38.1,96.9,1313.9,2.8,34.5
samples/ja_F3.wav,ja,F3,こんにちは。これはアップルシリコン上でMLXを使ったテストです。,1.463,32.1,45.6,848.4,1.7,26.4
samples/es_M3.wav,es,M3,"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon.",2.856,37.0,77.2,1002.1,2.9,27.1
1 filename language voice text duration_s mlx_ms_median rtf_mlx onnx_ms_median rtf_onnx speedup_mlx_over_onnx
2 samples/en_F1_short.wav en F1 Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time. 2.786 36.6 76.2 1004.7 2.8 27.5
3 samples/en_M1_long.wav en M1 A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells. 3.901 38.4 101.7 1356.0 2.9 35.3
4 samples/fr_F2.wav fr F2 Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4. 3.413 37.9 90.1 1195.6 2.9 31.6
5 samples/de_M2.wav de M2 Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX. 3.692 38.1 96.9 1313.9 2.8 34.5
6 samples/ja_F3.wav ja F3 こんにちは。これはアップルシリコン上でMLXを使ったテストです。 1.463 32.1 45.6 848.4 1.7 26.4
7 samples/es_M3.wav es M3 Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon. 2.856 37.0 77.2 1002.1 2.9 27.1

226
conversion_report.json Normal file
View File

@@ -0,0 +1,226 @@
{
"models": [
{
"model": "VectorEstimator",
"onnx": "/tmp/supertonic3/model/onnx/vector_estimator.onnx",
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vector_estimator.safetensors",
"bytes": 256053073,
"sha256": "2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6",
"weights_kept": 351,
"weights_dropped": 120,
"dropped_detail": {
"tts.ae.vector_field.proj_in.net.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.0.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.0.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.0.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.0.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.1.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.1.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.1.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.1.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.2.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.2.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.2.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.2.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.3.pwconv1.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.3.pwconv1.bias": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.3.pwconv2.weight": "not-in-model",
"tts.ae.vector_field.last_convnext.convnext.3.pwconv2.bias": "not-in-model",
"tts.ae.vector_field.proj_out.net.weight": "not-in-model",
"<missing>.vector_field.main_blocks.9.attn.theta": "expected-but-not-extracted",
"<missing>.vector_field.main_blocks.9.attn.increments": "expected-but-not-extracted",
"<missing>.vector_field.main_blocks.15.attn.theta": "expected-but-not-extracted",
"<missing>.vector_field.main_blocks.15.attn.increments": "expected-but-not-extracted",
"<missing>.vector_field.main_blocks.21.attn.theta": "expected-but-not-extracted",
"<missing>.vector_field.main_blocks.21.attn.increments": "expected-but-not-extracted"
},
"elapsed_s": 0.289
},
{
"model": "TextEncoder",
"onnx": "/tmp/supertonic3/model/onnx/text_encoder.onnx",
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
"bytes": 36022466,
"sha256": "9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada",
"weights_kept": 146,
"weights_dropped": 0,
"dropped_detail": {},
"elapsed_s": 0.035
},
{
"model": "DurationPredictor",
"onnx": "/tmp/supertonic3/model/onnx/duration_predictor.onnx",
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/duration_predictor.safetensors",
"bytes": 3470807,
"sha256": "cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3",
"weights_kept": 98,
"weights_dropped": 0,
"dropped_detail": {},
"elapsed_s": 0.007
},
{
"model": "Vocoder",
"onnx": "/tmp/supertonic3/model/onnx/vocoder.onnx",
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vocoder.safetensors",
"bytes": 101364763,
"sha256": "b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526",
"weights_kept": 103,
"weights_dropped": 0,
"dropped_detail": {},
"elapsed_s": 0.079
}
],
"ancillary": [
{
"name": "unicode_indexer.json",
"bytes": 277676,
"sha256": "9bf7346e43883a81f8645c81224f786d43c5b57f3641f6e7671a7d6c493cb24f"
},
{
"name": "voice_styles/F1.json",
"bytes": 292046,
"sha256": "bbdec6ee00231c2c742ad05483df5334cab3b52fda3ba38e6a07059c4563dbc2"
},
{
"name": "voice_styles/F2.json",
"bytes": 292423,
"sha256": "7c722c6a72707b1a77f035d67f0d1351ba187738e06f7683e8c72b1df3477fc6"
},
{
"name": "voice_styles/F3.json",
"bytes": 290794,
"sha256": "12f6ef2573baa2defa1128069cb59f203e3ab67c92af77b42df8a0e3a2f7c6ab"
},
{
"name": "voice_styles/F4.json",
"bytes": 291808,
"sha256": "c2fa764c1225a76dfc3e2c73e8aa4f70d9ee48793860eb34c295fff01c2e032b"
},
{
"name": "voice_styles/F5.json",
"bytes": 291479,
"sha256": "45966e73316415626cf41a7d1c6f3b4c70dbc1ba2bee5c1978ef0ce33244fc8d"
},
{
"name": "voice_styles/M1.json",
"bytes": 291748,
"sha256": "e35604687f5d23694b8e91593a93eec0e4eca6c0b02bb8ed69139ab2ea6b0a5b"
},
{
"name": "voice_styles/M2.json",
"bytes": 292055,
"sha256": "b76cbf62bac707c710cf0ae5aba5e31eea1a6339a9734bfae33ab98499534a50"
},
{
"name": "voice_styles/M3.json",
"bytes": 290198,
"sha256": "ea1ac35ccb91b0d7ecad533a2fbd0eec10c91513d8951e3b25fbba99954e159b"
},
{
"name": "voice_styles/M4.json",
"bytes": 291522,
"sha256": "ca8eefad4fcd989c9379032ff3e50738adc547eeb5e221b82593a6d7b3bac303"
},
{
"name": "voice_styles/M5.json",
"bytes": 291469,
"sha256": "dd22b92740314321f8ae11c5e87f8dd60d060f15dd3a632b5adf77f471f77af2"
}
]
}

23
examples/quickstart.py Normal file
View File

@@ -0,0 +1,23 @@
"""Minimal Supertonic 3 MLX usage — 5 lines, no fluff.
Run from anywhere AFTER ``pip install supertonic-3-mlx`` (or from inside
this directory after ``pip install ./``):
python examples/quickstart.py
"""
from supertonic_3_mlx import Pipeline
import soundfile as sf
# When the package has been pip-installed, this auto-downloads from the Hub
# (~ 400 MB) into the standard Hugging Face cache. After the first run, the
# weights are reused from cache and cold start is ~ 11 ms on M4.
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
wav = pipe.generate(
"Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.",
voice="F1", # one of F1..F5, M1..M5
lang="en", # ISO 639-1
)
sf.write("hello.wav", wav, pipe.sample_rate)
print(f"wrote hello.wav — {len(wav) / pipe.sample_rate:.2f}s of audio")

43
pyproject.toml Normal file
View File

@@ -0,0 +1,43 @@
[project]
name = "supertonic-3-mlx"
version = "0.1.0"
description = "MLX-native port of Supertone's Supertonic 3 multilingual TTS (31 languages, ~x100 realtime on Apple Silicon)"
readme = "README.md"
requires-python = ">=3.10"
authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
license = { text = "Apache-2.0 AND OpenRAIL-M" }
keywords = ["mlx", "tts", "speech-synthesis", "apple-silicon", "supertonic", "multilingual"]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: MacOS X",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: MacOS",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Multimedia :: Sound/Audio :: Speech",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"mlx>=0.21.0",
"numpy>=1.24.0",
]
[project.optional-dependencies]
hub = ["huggingface_hub>=0.26.0"]
dev = ["pytest>=8.3.0", "ruff>=0.7.0"]
[project.urls]
Homepage = "https://huggingface.co/ambassadia/supertonic-3-mlx"
Upstream = "https://huggingface.co/Supertone/supertonic-3"
Source = "https://gitea.tavportal.com/olivier/supertonic-3-mlx"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/supertonic_3_mlx"]

BIN
samples/de_M2.wav Normal file

Binary file not shown.

BIN
samples/en_F1_short.wav Normal file

Binary file not shown.

BIN
samples/en_M1_long.wav Normal file

Binary file not shown.

BIN
samples/es_M3.wav Normal file

Binary file not shown.

BIN
samples/fr_F2.wav Normal file

Binary file not shown.

BIN
samples/ja_F3.wav Normal file

Binary file not shown.

131
setup_and_test.sh Executable file
View File

@@ -0,0 +1,131 @@
#!/usr/bin/env bash
# Quick install + sanity-test for the supertonic-3-mlx standalone package.
#
# Creates a local ``.venv`` next to this script, installs the package and its
# runtime deps, version-checks MLX, downloads the model weights from the
# Hugging Face Hub on first run, and synthesises one short utterance to
# ``hello.wav``. Idempotent: re-running reuses the existing venv and cached
# weights.
#
# Usage:
# ./setup_and_test.sh # default: en F1, "Hello world…"
# ./setup_and_test.sh fr F2 "Bonjour."
#
set -euo pipefail
# ── 0. Inputs ────────────────────────────────────────────────────────
LANG_CODE="${1:-en}"
VOICE="${2:-F1}"
TEXT="${3:-Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.}"
HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
VENV="$HERE/.venv"
# ── 1. Python version gate ──────────────────────────────────────────
if ! command -v python3 >/dev/null; then
echo "ERROR: python3 not found. Install Python 3.10+ first." >&2
exit 1
fi
PYVER="$(python3 -c 'import sys; print("%d.%d"%sys.version_info[:2])')"
PYMAJ="${PYVER%.*}"; PYMIN="${PYVER#*.}"
if [ "$PYMAJ" -lt 3 ] || { [ "$PYMAJ" -eq 3 ] && [ "$PYMIN" -lt 10 ]; }; then
echo "ERROR: Python 3.10+ required, found $PYVER." >&2
exit 1
fi
echo "→ python3: $PYVER"
# ── 2. venv ─────────────────────────────────────────────────────────
if [ ! -x "$VENV/bin/python" ]; then
echo "→ creating venv at $VENV"
python3 -m venv "$VENV"
fi
PIP="$VENV/bin/pip"
PY="$VENV/bin/python"
# ── 3. dependencies ─────────────────────────────────────────────────
echo "→ installing dependencies …"
"$PIP" install --quiet --upgrade pip
# Install the package + the optional runtime deps. The package itself pulls in
# mlx + numpy via its pyproject.toml; we add huggingface_hub for the Hub
# download path, hf_transfer for large-blob throughput, and soundfile so the
# test script can write a WAV.
"$PIP" install --quiet "$HERE" huggingface_hub soundfile
# ── 4. MLX version gate + optional patch hook ───────────────────────
"$PY" - <<'PYEOF'
import sys
try:
import mlx.core as mx
except ImportError:
print("ERROR: mlx not importable. Are you on Apple Silicon? "
"MLX is macOS-on-Apple-Silicon only.", file=sys.stderr)
sys.exit(1)
ver_str = getattr(mx, "__version__", "0.0.0")
ver = tuple(int(p) for p in ver_str.split(".")[:3] if p.isdigit())
print(f"→ mlx version: {ver_str}")
# Minimum tested combination — bumped as the upstream API changes.
MIN_OK = (0, 21, 0)
if ver < MIN_OK:
print(f" WARNING: mlx < {'.'.join(map(str, MIN_OK))}. Upgrading …",
file=sys.stderr)
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install",
"--quiet", "--upgrade", "mlx"])
print(" → upgraded. Re-run the script to pick up the new version.",
file=sys.stderr)
sys.exit(2)
# Patches we currently know about — none. This is the slot where future
# MLX-specific shims would land (e.g. a workaround for an upstream Conv1d
# regression). Keep the dispatch table here so the script stays a single
# source of truth.
PATCHES: dict[tuple[int, int, int], str] = {
# (broken_version): "patch description"
}
applied = [desc for v, desc in PATCHES.items() if v == ver]
if applied:
for desc in applied:
print(f" applied patch: {desc}")
else:
print(f" no patches needed for mlx {ver_str}")
PYEOF
# ── 5. quickstart generate ──────────────────────────────────────────
echo "→ generating audio …"
LANG_CODE="$LANG_CODE" VOICE="$VOICE" TEXT="$TEXT" \
HF_HUB_DISABLE_XET=1 \
HF_HUB_ENABLE_HF_TRANSFER=1 \
"$PY" - <<'PYEOF'
import os, time
from supertonic_3_mlx import Pipeline
import soundfile as sf
lang = os.environ["LANG_CODE"]
voice = os.environ["VOICE"]
text = os.environ["TEXT"]
# First call downloads ~ 400 MB of weights into the HF cache. Subsequent
# runs reuse the cache and load in ~ 20 ms.
t0 = time.perf_counter()
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
load_t = time.perf_counter() - t0
print(f" load : {load_t*1000:.0f} ms")
# Warmup (compiles the kernel graph for this shape).
pipe.generate("Warm.", voice=voice, lang=lang)
t0 = time.perf_counter()
wav = pipe.generate(text, voice=voice, lang=lang, seed=42)
gen_t = time.perf_counter() - t0
dur = len(wav) / pipe.sample_rate
print(f" generate : {gen_t*1000:.0f} ms")
print(f" audio : {dur:.2f} s ({len(wav)} samples @ {pipe.sample_rate} Hz)")
print(f" RTF : x{dur/gen_t:.0f}")
print(f" max amp : {abs(wav).max():.4f}")
sf.write("hello.wav", wav, pipe.sample_rate)
print("\n✓ wrote hello.wav — open it to verify the synthesis sounds correct.")
PYEOF

View File

@@ -0,0 +1,51 @@
"""Supertonic 3 — MLX-native TTS for Apple Silicon.
31-language text-to-speech, 5 Euler steps with classifier-free guidance, in
pure MLX. On M4 the full pipeline runs at ~x100 realtime.
Quickstart
----------
from supertonic_3_mlx import Pipeline
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
# wav is a 1-D ``numpy.float32`` array at 44.1 kHz.
The model weights are released under the BigScience OpenRAIL-M license
(see LICENSE in the Hugging Face repository). This MLX port code is
Apache-2.0. Together they form a dual-license package; Attachment A use
restrictions of OpenRAIL-M govern downstream use of the generated audio.
Public API:
Pipeline — end-to-end TTS, ``from_pretrained`` + ``generate``
VectorEstimator — the 24-block CFG flow-matching net (sub-model 1/4)
TextEncoder — character → text embedding (sub-model 2/4)
DurationPredictor — text → duration in seconds (sub-model 3/4)
Vocoder — latent → 44.1 kHz waveform (sub-model 4/4)
"""
from supertonic_3_mlx._config import (
DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K,
NUM_MAIN_BLOCKS, NUM_CYCLES, BLOCKS_PER_CYCLE, BLOCK_CYCLE, STACK4_DILATIONS,
TEXT_HEADS, TEXT_HEAD_DIM, TEXT_DIM, ROTARY_BASE, ROTARY_SCALE,
STYLE_HEADS, STYLE_HEAD_DIM, STYLE_LEN, STYLE_DIM,
TIME_EMB_DIM, TIME_MLP_HIDDEN,
EPS_LN, CHUNK_COMPRESS, LATENT_DIM, SAMPLE_RATE,
SUPERTONIC3_HF_REPO,
)
from supertonic_3_mlx.duration_predictor import DurationPredictor
from supertonic_3_mlx.text_encoder import TextEncoder
from supertonic_3_mlx.vector_estimator import VectorEstimator
from supertonic_3_mlx.vocoder import Vocoder
from supertonic_3_mlx.pipeline import SupertonicMLXPipeline as Pipeline
__all__ = [
"Pipeline",
"DurationPredictor", "TextEncoder", "VectorEstimator", "Vocoder",
"DIM", "LATENT_CH", "CONVNEXT_HIDDEN", "CONVNEXT_K",
"NUM_MAIN_BLOCKS", "NUM_CYCLES", "BLOCKS_PER_CYCLE", "BLOCK_CYCLE", "STACK4_DILATIONS",
"TEXT_HEADS", "TEXT_HEAD_DIM", "TEXT_DIM", "ROTARY_BASE", "ROTARY_SCALE",
"STYLE_HEADS", "STYLE_HEAD_DIM", "STYLE_LEN", "STYLE_DIM",
"TIME_EMB_DIM", "TIME_MLP_HIDDEN",
"EPS_LN", "CHUNK_COMPRESS", "LATENT_DIM", "SAMPLE_RATE",
"SUPERTONIC3_HF_REPO",
]

View File

@@ -0,0 +1,58 @@
"""Locked hyperparameters for Supertonic 3 MLX port.
Derived from the official ``Supertone/supertonic-3/onnx/tts.json``.
Changing these = re-running parity tests.
"""
from __future__ import annotations
# Vector estimator (the flow-matching denoiser)
DIM: int = 512 # backbone width
LATENT_CH: int = 144 # 24 * chunk_compress_factor (6)
CONVNEXT_HIDDEN: int = 2048 # main_blocks ConvNeXt intermediate dim (2× vs s2)
CONVNEXT_K: int = 5
LAST_CONVNEXT_NUM: int = 4 # last_convnext is a 4-layer stack (dilations [1,1,1,1])
# 24 main_blocks = 4 cycles × 6 sub-blocks (cycle: stack4, time, cn1, text_attn, cn1, style_attn)
NUM_CYCLES: int = 4
BLOCKS_PER_CYCLE: int = 6
NUM_MAIN_BLOCKS: int = NUM_CYCLES * BLOCKS_PER_CYCLE
BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
# ConvNeXt stack 4 (in stack4 blocks) — dilation schedule
STACK4_DILATIONS = (1, 2, 4, 8)
# Text cross-attention (RoPE) — block type "text_attn"
TEXT_DIM: int = 256
TEXT_HEADS: int = 8 # 2× vs s2 (4)
TEXT_HEAD_DIM: int = DIM // TEXT_HEADS # 512/8 = 64
ROTARY_BASE: int = 10_000
ROTARY_SCALE: int = 10
# Style cross-attention — block type "style_attn"
STYLE_DIM: int = 256
STYLE_LEN: int = 50 # 50 style tokens (n_style)
STYLE_HEADS: int = 2
STYLE_HEAD_DIM: int = 128
# Time encoding (sinusoidal + MLP)
TIME_EMB_DIM: int = 64
TIME_MLP_HIDDEN: int = 256
# LayerNorm epsilon
EPS_LN: float = 1e-6
# Chunk compress factor (used by AE)
CHUNK_COMPRESS: int = 6
LATENT_DIM: int = 24 # ldim before chunk compression
# Sample rate
SAMPLE_RATE: int = 44_100
# HF references (will be pinned to SHA after first download)
SUPERTONIC3_HF_REPO: str = "Supertone/supertonic-3"
ONNX_VECTOR_ESTIMATOR: str = "onnx/vector_estimator.onnx"
ONNX_TEXT_ENCODER: str = "onnx/text_encoder.onnx"
ONNX_DURATION_PREDICTOR: str = "onnx/duration_predictor.onnx"
ONNX_VOCODER: str = "onnx/vocoder.onnx"
ONNX_TTS_JSON: str = "onnx/tts.json"
ONNX_UNICODE_INDEXER: str = "onnx/unicode_indexer.json"

View File

@@ -0,0 +1,50 @@
"""Small wrapper modules to match Supertonic 3 ONNX submodule nesting.
The s3 checkpoint nests primitives one level deeper than typical MLX modules:
- ``norm.norm.weight`` — LayerNorm wrapped in a Norm container
- ``linear.linear.weight`` — Linear wrapped in a Linear container
- ``W_query.linear.weight`` — attention projection wrapped
Mirroring this nesting lets us load the safetensors with ``model.load_weights(...)``
without any key remapping at load time.
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
class WrappedNorm(nn.Module):
"""Container with a single nested LayerNorm — produces key ``X.norm.weight``."""
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps)
def __call__(self, x: mx.array) -> mx.array:
return self.norm(x)
class WrappedLinear(nn.Module):
"""Container with a single nested Linear — produces keys ``X.linear.weight/bias``."""
def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
super().__init__()
self.linear = nn.Linear(in_dim, out_dim, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
return self.linear(x)
class ProjConv1x1(nn.Module):
"""Conv1d k=1 expressed as ``self.net = Linear`` (matches ``proj_in.net.weight``)."""
def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
super().__init__()
self.net = nn.Linear(in_dim, out_dim, bias=bias)
def __call__(self, x: mx.array) -> mx.array:
return self.net(x)
__all__ = ["WrappedNorm", "WrappedLinear", "ProjConv1x1"]

View File

@@ -0,0 +1,347 @@
"""Supertonic 3 duration predictor — predicts total audio duration in seconds.
Pipeline (channels-last NTC throughout):
text_ids [B, T] int64 character IDs
→ char_embed (Embedding 8322→64) [B, T, 64]
→ prepend sentence_token (1, 64, 1) [B, T+1, 64]
→ 6× ConvNeXt (dim=64, hidden=256, k=5, all dilations=1)
→ 2× RelPosSelfAttn (heads=2, head_dim=32, window=4) + norm + FFN + norm
→ proj_out (Conv1d k=1: 64→64) applied to slot 0 (sentence token)
→ concat with style_dp flattened (B, 8×16=128) [B, 192]
→ Linear(192 → 128) → PReLU → Linear(128 → 1) → exp → duration [B]
Inputs:
text_ids: (B, T) int — character indices
style_dp: (B, 8, 16) — style summary tokens
text_mask: (B, 1, T) — 1.0 valid, 0.0 padded
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
from supertonic_3_mlx._config import EPS_LN
from supertonic_3_mlx._nn_wrappers import WrappedNorm
from supertonic_3_mlx.vector_estimator import _pad_sym_edge, _gelu_exact
DP_VOCAB = 8322
DP_DIM = 64
DP_CONVNEXT_HIDDEN = 256
DP_CONVNEXT_K = 5
DP_CONVNEXT_NUM_LAYERS = 6
DP_ATTN_NUM_LAYERS = 2
DP_ATTN_HEADS = 2
DP_ATTN_HEAD_DIM = DP_DIM // DP_ATTN_HEADS # 32
DP_FFN_HIDDEN = 256
DP_REL_POS_WINDOW = 4
DP_N_STYLE = 8
DP_STYLE_DIM = 16
DP_MLP_IN = DP_DIM + DP_N_STYLE * DP_STYLE_DIM # 64 + 128 = 192
DP_MLP_HIDDEN = 128
class _DPConvNeXtBlock(nn.Module):
"""ConvNeXt block (dim=64, hidden=256, dilation=1)."""
def __init__(self) -> None:
super().__init__()
self.dwconv = nn.Conv1d(
DP_DIM, DP_DIM, kernel_size=DP_CONVNEXT_K, padding=0,
dilation=1, groups=DP_DIM, bias=True,
)
self.norm = WrappedNorm(DP_DIM, eps=EPS_LN)
self.pwconv1 = nn.Linear(DP_DIM, DP_CONVNEXT_HIDDEN, bias=True)
self.pwconv2 = nn.Linear(DP_CONVNEXT_HIDDEN, DP_DIM, bias=True)
self.gamma = mx.zeros((DP_DIM,))
self.pad = (DP_CONVNEXT_K - 1) // 2
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
residual = x
y = _pad_sym_edge(x, self.pad)
y = self.dwconv(y)
y = self.norm(y)
y = self.pwconv1(y)
y = _gelu_exact(y)
y = self.pwconv2(y)
y = y * self.gamma
out = residual + y
if mask is not None:
out = out * mask
return out
class _DPConvNeXtStack(nn.Module):
"""``convnext.[0..5]`` — 6 ConvNeXt blocks."""
def __init__(self) -> None:
super().__init__()
self.convnext = [_DPConvNeXtBlock() for _ in range(DP_CONVNEXT_NUM_LAYERS)]
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
for b in self.convnext:
x = b(x, mask)
return x
class _DPConvLayer(nn.Module):
"""Conv1d k=1 with weight (out, 1, in) — matches ONNX storage."""
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.weight = mx.zeros((out_dim, 1, in_dim))
self.bias = mx.zeros((out_dim,))
def __call__(self, x: mx.array) -> mx.array:
return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
def _dp_rel_to_abs(x: mx.array) -> mx.array:
"""(B, h, L, 2L-1) → (B, h, L, L) via VITS shifted-skew reshape."""
B, h, L, _ = x.shape
x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
x_flat = x.reshape(B, h, L * 2 * L)
x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
return x_final[:, :, :L, L - 1:]
def _dp_abs_to_rel(x: mx.array) -> mx.array:
"""(B, h, L, L) → (B, h, L, 2L-1)."""
B, h, L, _ = x.shape
x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
x_flat = x.reshape(B, h, L * (2 * L - 1))
x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
x_final = x_flat.reshape(B, h, L, 2 * L)
return x_final[:, :, :, 1:]
def _dp_slice_rel(rel: mx.array, length: int, window: int) -> mx.array:
"""(1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
pad_l = max(length - (window + 1), 0)
if pad_l > 0:
zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
padded = mx.concatenate([zero, rel, zero], axis=1)
else:
padded = rel
start = max(window + 1 - length, 0)
return padded[:, start: start + 2 * length - 1]
class _DPRelPosSelfAttn(nn.Module):
"""VITS-style rel-pos self-attention (2 heads × 32 head_dim, window=4).
Includes both rel-pos contributions (q × rel_k → logits, abs_to_rel(attn) × rel_v → out).
"""
def __init__(self) -> None:
super().__init__()
self.conv_q = _DPConvLayer(DP_DIM, DP_DIM)
self.conv_k = _DPConvLayer(DP_DIM, DP_DIM)
self.conv_v = _DPConvLayer(DP_DIM, DP_DIM)
self.conv_o = _DPConvLayer(DP_DIM, DP_DIM)
self.emb_rel_k = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
self.emb_rel_v = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
B, T, _ = x.shape
H, D = DP_ATTN_HEADS, DP_ATTN_HEAD_DIM
q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
scale = D ** -0.5
logits = (q @ k.transpose(0, 1, 3, 2)) * scale
rel_k = _dp_slice_rel(self.emb_rel_k, T, DP_REL_POS_WINDOW)
rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]
rel_logits = _dp_rel_to_abs(rel_logits * scale)
logits = logits + rel_logits
if mask is not None:
key_mask = mask[:, :, 0][:, None, None, :]
neg_inf = mx.array(-1e4, dtype=logits.dtype)
logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
attn = mx.softmax(logits, axis=-1)
out = attn @ v
rel_v = _dp_slice_rel(self.emb_rel_v, T, DP_REL_POS_WINDOW)
rel_weights = _dp_abs_to_rel(attn)
out = out + rel_weights @ rel_v[:, None, :, :]
out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
return self.conv_o(out)
class _DPFFN(nn.Module):
"""FFN with two Conv1d k=1 — 64 → 256 → 64, ReLU + mask."""
def __init__(self) -> None:
super().__init__()
self.conv_1 = _DPConvLayer(DP_DIM, DP_FFN_HIDDEN)
self.conv_2 = _DPConvLayer(DP_FFN_HIDDEN, DP_DIM)
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
if mask is not None:
x = x * mask
y = self.conv_1(x)
y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
if mask is not None:
y = y * mask
y = self.conv_2(y)
if mask is not None:
y = y * mask
return y
class _DPAttnEncoder(nn.Module):
"""2× (attn + norm) + (ffn + norm)."""
def __init__(self) -> None:
super().__init__()
self.attn_layers = [_DPRelPosSelfAttn() for _ in range(DP_ATTN_NUM_LAYERS)]
self.norm_layers_1 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
self.ffn_layers = [_DPFFN() for _ in range(DP_ATTN_NUM_LAYERS)]
self.norm_layers_2 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
for i in range(DP_ATTN_NUM_LAYERS):
x = self.norm_layers_1[i](x + self.attn_layers[i](x, mask=mask))
x = self.norm_layers_2[i](x + self.ffn_layers[i](x, mask))
return x
class _DPSentenceEncoder(nn.Module):
"""Text → 64-d sentence vector via prepended ``sentence_token`` slot."""
def __init__(self) -> None:
super().__init__()
class _TextEmb(nn.Module):
def __init__(_):
super().__init__()
_.char_embedder = nn.Embedding(DP_VOCAB, DP_DIM)
def __call__(_, ids):
return _.char_embedder(ids)
self.text_embedder = _TextEmb()
self.convnext = _DPConvNeXtStack()
self.attn_encoder = _DPAttnEncoder()
# proj_out keeps the .net.weight (out, 1, in) Conv1d-k1 layout
self.proj_out = _DPProjOut()
# sentence_token (1, DIM, 1) — prepended as the first time slot
self.sentence_token = mx.zeros((1, DP_DIM, 1))
def __call__(self, text_ids: mx.array, text_mask: mx.array) -> mx.array:
x = self.text_embedder(text_ids) # (B, T, 64)
# Prepend sentence_token: shape (1, 64, 1) → (B, 1, 64)
B = x.shape[0]
sentence = self.sentence_token.transpose(0, 2, 1)
sentence = mx.broadcast_to(sentence, (B, 1, DP_DIM))
x = mx.concatenate([sentence, x], axis=1) # (B, T+1, 64)
# Extend mask with a leading 1 (sentence token always valid)
if text_mask is not None:
extra = mx.ones((B, 1, 1), dtype=text_mask.dtype)
mask_ntc = mx.concatenate([extra, text_mask.transpose(0, 2, 1)], axis=1)
else:
mask_ntc = None
x = self.convnext(x, mask_ntc)
x = self.attn_encoder(x, mask_ntc)
# Take slot 0 (sentence token output) → (B, 1, 64)
sentence_out = x[:, :1, :] # (B, 1, 64)
# proj_out (Conv1d k=1) — applied along time, output (B, 1, 64)
sentence_out = self.proj_out(sentence_out)
return sentence_out.reshape(B, DP_DIM) # (B, 64)
class _DPProjOut(nn.Module):
"""Conv1d k=1 64→64. No bias in ONNX (confirmed via graph inspection)."""
def __init__(self) -> None:
super().__init__()
class _Net(nn.Module):
def __init__(_):
super().__init__()
_.weight = mx.zeros((DP_DIM, 1, DP_DIM))
def __call__(_, x):
return mx.conv1d(x, _.weight, stride=1, padding=0)
self.net = _Net()
def __call__(self, x: mx.array) -> mx.array:
return self.net(x)
class _DPPredictor(nn.Module):
"""Linear(192 → 128) + PReLU + Linear(128 → 1).
PReLU is stored under ``activation.weight (1,)`` — a single learnable
negative-slope coefficient.
"""
def __init__(self) -> None:
super().__init__()
self.layers = [
nn.Linear(DP_MLP_IN, DP_MLP_HIDDEN, bias=True),
nn.Linear(DP_MLP_HIDDEN, 1, bias=True),
]
# PReLU: activation.weight shape (1,) — single scalar slope
class _Activation(nn.Module):
def __init__(_):
super().__init__()
_.weight = mx.zeros((1,))
def __call__(_, x):
# PReLU(x) = max(0, x) + slope * min(0, x)
neg = mx.minimum(x, mx.array(0.0, dtype=x.dtype))
pos = mx.maximum(x, mx.array(0.0, dtype=x.dtype))
return pos + _.weight * neg
self.activation = _Activation()
def __call__(self, x: mx.array) -> mx.array:
h = self.layers[0](x) # (B, 128)
h = self.activation(h)
h = self.layers[1](h) # (B, 1)
return h
class _DPRoot(nn.Module):
"""``tts.dp.X`` namespace container."""
def __init__(self) -> None:
super().__init__()
self.sentence_encoder = _DPSentenceEncoder()
self.predictor = _DPPredictor()
class _DPContainer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.dp = _DPRoot()
class DurationPredictor(nn.Module):
"""Predicts total audio duration (seconds) for an utterance.
Submodule namespace matches ONNX keys ``tts.dp.X.Y`` exactly.
"""
def __init__(self) -> None:
super().__init__()
self.tts = _DPContainer()
def __call__(
self,
text_ids: mx.array, # (B, T) int
style_dp: mx.array, # (B, 8, 16)
text_mask: mx.array, # (B, 1, T)
) -> mx.array:
sentence = self.tts.dp.sentence_encoder(text_ids, text_mask) # (B, 64)
style_flat = style_dp.reshape(style_dp.shape[0], -1) # (B, 128)
joined = mx.concatenate([sentence, style_flat], axis=-1) # (B, 192)
log_dur = self.tts.dp.predictor(joined).reshape(-1) # (B,)
return mx.exp(log_dur) # duration in seconds
__all__ = ["DurationPredictor"]

View File

@@ -0,0 +1,545 @@
"""Supertonic 3 end-to-end MLX pipeline.
Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
that returns a 44.1 kHz mono numpy waveform.
Flow:
text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
noise ~ N(0, I) of shape (B, 144, T_lat)
where T_lat = ceil(duration_s × 44100 / (512 × 6))
vector_estimator 5-step Euler with CFG (4×cond 3×uncond):
for step in [0..4]:
x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)
Public API:
pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
wav = pipe.generate("Hello world", voice="F1", lang="en")
import soundfile as sf
sf.write("out.wav", wav, pipe.sample_rate)
"""
from __future__ import annotations
import json
import math
from pathlib import Path
from typing import Optional
import mlx.core as mx
import numpy as np
from supertonic_3_mlx._config import SAMPLE_RATE
from supertonic_3_mlx.duration_predictor import DurationPredictor
from supertonic_3_mlx.text_encoder import TextEncoder
from supertonic_3_mlx.vector_estimator import VectorEstimator
from supertonic_3_mlx.vocoder import Vocoder
# Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
# covers 512 × 6 = 3072 samples = 69.7 ms.
SAMPLES_PER_LATENT_STEP = 512 * 6 # 3072
# ── Shared ONNX → MLX weight extraction ─────────────────────────────
def _convert_onnx(onnx_path: str | Path) -> dict:
"""Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.
Combines the three extraction stages discovered during the per-component
ports (T.3.1, T.3.2, T.3.3):
1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
pwconv, head.layer2).
2. Anonymous MatMul weights recovered via the MatMul output path.
3. Anonymous Conv weights and PReLU slopes recovered the same way.
"""
import onnx
import onnx.numpy_helper as nh
m = onnx.load(str(onnx_path))
def _matmul_clean(out_name: str) -> str:
p = out_name.lstrip("/")
if p.endswith("/MatMul_output_0"):
p = p[: -len("/MatMul_output_0")]
# Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
if p.startswith(prefix):
p = p[len(prefix):]
break
return p.replace("/", ".") + ".weight"
def _conv_clean(out_name: str) -> str:
p = out_name.lstrip("/")
if p.endswith("/Conv_output_0"):
p = p[: -len("/Conv_output_0")]
for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
if p.startswith(prefix):
p = p[len(prefix):]
break
return "tts.ae." + p.replace("/", ".")
def _prelu_clean(out_name: str) -> str:
p = out_name.lstrip("/")
if p.endswith("/PRelu_output_0"):
p = p[: -len("/PRelu_output_0")]
for prefix in ("vocoder/", "vector_estimator/"):
if p.startswith(prefix):
p = p[len(prefix):]
break
return "tts.ae." + p.replace("/", ".") + ".weight"
# Detect which model this file is — affects how we wrap named init keys
name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
is_text_encoder = "tts" in name_prefixes and any(
i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
)
weights: dict[str, mx.array] = {}
# Stage 1: named initialisers
for init in m.graph.initializer:
n = init.name
# Determine if this is a structured (named) weight or an anonymous graph const
if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
continue
# Strip the vector_estimator-specific prefix so all 4 models share a name space.
if n.startswith("vector_estimator.tts.ttl."):
clean = n[len("vector_estimator.tts.ttl."):]
else:
clean = n
arr = nh.to_array(init)
# Shape transforms
if (clean.endswith(".dwconv.weight") and arr.ndim == 3
and arr.shape[1] == 1 and arr.shape[2] != 1):
arr = np.transpose(arr, (0, 2, 1))
if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
and arr.shape[1] == 1):
arr = np.transpose(arr, (0, 2, 1))
if (clean.endswith(".gamma") and arr.ndim == 3
and arr.shape[0] == 1 and arr.shape[2] == 1):
arr = arr.reshape(arr.shape[1])
if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
and arr.ndim == 3 and arr.shape[-1] == 1):
arr = arr.squeeze(-1)
if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
# Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
arr = arr.squeeze(-1)
# vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
arr = np.transpose(arr, (0, 2, 1))
# vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
arr = np.transpose(arr, (0, 2, 1))
weights[clean] = mx.array(arr)
# Stage 2: MatMul weight recovery
inits_map = {init.name: init for init in m.graph.initializer}
for node in m.graph.node:
if node.op_type != "MatMul" or len(node.input) < 2:
continue
winp = node.input[1]
if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
continue
arr = nh.to_array(inits_map[winp])
if arr.ndim == 2:
arr = arr.T # ONNX (in, out) → MLX Linear (out, in)
clean = _matmul_clean(node.output[0])
# Build the leading namespace from the file context (already in tts.*)
if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
clean = "tts.ttl." + clean if is_text_encoder else clean
weights[clean] = mx.array(arr)
# Stage 3: anonymous Conv + PReLU (vocoder embed / head)
for node in m.graph.node:
if node.op_type == "Conv":
for i, inp in enumerate(node.input[1:], 1):
if inp not in inits_map or inp.startswith("tts."):
continue
arr = nh.to_array(inits_map[inp])
base = _conv_clean(node.output[0])
if "dwconv" in base:
continue
if i == 1 and arr.ndim == 3:
arr = np.transpose(arr, (0, 2, 1)) # ONNX (out, in, K) → MLX (out, K, in)
key = base + (".weight" if i == 1 else ".bias")
weights[key] = mx.array(arr)
elif node.op_type == "PRelu":
for inp in node.input[1:]:
if inp in inits_map and not inp.startswith("tts."):
weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))
return weights
def _load_into(model, weights: dict) -> int:
"""Match converted weights to model params (shape-tolerant via reshape).
Returns the number of successfully matched tensors.
"""
from mlx.utils import tree_flatten
expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
matched = {}
for k, exp_shape in expected.items():
if k not in weights:
continue
v = weights[k]
if tuple(v.shape) != exp_shape:
if v.size == np.prod(exp_shape):
v = v.reshape(exp_shape)
else:
continue
matched[k] = v
model.load_weights(list(matched.items()), strict=False)
return len(matched)
# ── Tokenization ────────────────────────────────────────────────────
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
"""Encode a text string into character IDs.
The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
text with no special language tokens — the ONNX SDK uses language tags but
our pipeline currently runs unconditioned on language for the first WAV
emission (parity validation happens after).
"""
ids = []
for c in text:
cp = ord(c)
if 0 <= cp < len(indexer):
tok = indexer[cp]
if tok >= 0:
ids.append(tok)
if not ids:
# fallback to a single space token to avoid empty input
ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
return np.asarray(ids, dtype=np.int32)
# ── Pipeline ────────────────────────────────────────────────────────
class SupertonicMLXPipeline:
"""End-to-end Supertonic 3 TTS pipeline in pure MLX.
Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.
"""
sample_rate: int = SAMPLE_RATE
# Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
# model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
# and the combination 4×cond 3×uncond. Any other step count or skipping
# CFG produces an essentially uncorrelated waveform (verified by
# ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
# ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
# latency further would require distilling a shorter-schedule model.
n_euler_steps: int = 5
def __init__(
self,
duration_predictor: DurationPredictor,
text_encoder: TextEncoder,
vector_estimator: VectorEstimator,
vocoder: Vocoder,
unicode_indexer: list[int],
voice_dir: Path,
) -> None:
self.duration_predictor = duration_predictor
self.text_encoder = text_encoder
self.vector_estimator = vector_estimator
self.vocoder = vocoder
self.unicode_indexer = unicode_indexer
self.voice_dir = voice_dir
# T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
# by input shapes; the 5× CFG Euler loop and the single vocoder pass
# both gain from fused kernel dispatch (~50100 layer ops collapse into
# one dispatch per cached graph).
# T.5.3 — also pre-project text and style K/V outside the step. They
# are invariant across the 5 Euler steps, so the 4 text_attn + 4
# style_attn blocks no longer re-run their W_key / W_value / RoPE_K
# matmuls on every step (saves 40 matmuls per generate).
cond_scale = self.vector_estimator.CFG_COND_SCALE
uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE
def _cached_step(
noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
):
noisy_2 = mx.concatenate([noisy, noisy], axis=0)
text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
v_2 = self.vector_estimator.velocity_cached(
noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
)
B = noisy.shape[0]
cond_v = v_2[:B]
uncond_v = v_2[B:]
combined = cond_scale * cond_v - uncond_scale * uncond_v
return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)
def _voc_step(latent):
return self.vocoder(latent)
self._cached_step_compiled = mx.compile(_cached_step)
self._voc_compiled = mx.compile(_voc_step)
# Pick the runtime dtype from any leaf weight of the vector estimator —
# ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
# in which case all inputs to the compiled hot loops must be cast to
# match (mixed-dtype Conv/MatMul is not legal in MLX).
from mlx.utils import tree_flatten
leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
if isinstance(v, mx.array)]
self.dtype = leaves[0].dtype if leaves else mx.float32
@classmethod
def from_pretrained(
cls,
model_id_or_path: str | Path,
dtype: mx.Dtype | None = None,
cache_dir: str | Path | None = None,
revision: str | None = None,
) -> "SupertonicMLXPipeline":
"""Construct the pipeline from a model snapshot.
Three sources are accepted, auto-detected:
1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
weights are downloaded via :func:`huggingface_hub.snapshot_download`
into ``cache_dir`` (defaults to the standard HF cache) and loaded
directly from the bundled ``weights/*.safetensors`` files.
2. **Local path with a** ``weights/`` **subdir**: the MLX-native
layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
Fast path — no ONNX conversion at runtime.
3. **Local path with an** ``onnx/`` **subdir**: the upstream
``Supertone/supertonic-3`` snapshot layout. Weights are converted
from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
development or when starting from the original upstream release.
Optional kwargs:
dtype — if non-None and not float32, cast all weights to the
given dtype after load (only ``mx.bfloat16`` is
currently meaningful; see README "BF16 note").
cache_dir — passed to ``huggingface_hub.snapshot_download``.
revision — branch / tag / commit sha on the Hub.
"""
# 1. Resolve the local snapshot directory
if isinstance(model_id_or_path, str) and "/" in model_id_or_path \
and not Path(model_id_or_path).exists():
try:
from huggingface_hub import snapshot_download
except ImportError as e:
raise ImportError(
"Loading from the Hugging Face Hub requires "
"``huggingface_hub`` — install with ``pip install "
"supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
) from e
local_dir = Path(snapshot_download(
repo_id=model_id_or_path,
cache_dir=cache_dir,
revision=revision,
allow_patterns=[
"weights/*.safetensors",
"unicode_indexer.json",
"voice_styles/*.json",
],
))
else:
local_dir = Path(model_id_or_path)
# 2. Detect layout
weights_dir = local_dir / "weights"
onnx_dir = local_dir / "onnx"
if weights_dir.exists():
return cls._from_safetensors(local_dir, dtype=dtype)
if onnx_dir.exists():
return cls._from_onnx(local_dir, dtype=dtype)
raise FileNotFoundError(
f"{local_dir} contains neither ``weights/`` (safetensors layout) "
f"nor ``onnx/`` (upstream layout); cannot load."
)
@classmethod
def _from_safetensors(
cls, local_dir: Path, dtype: mx.Dtype | None = None,
) -> "SupertonicMLXPipeline":
from mlx.utils import tree_flatten
weights_dir = local_dir / "weights"
voice_dir = local_dir / "voice_styles"
unicode_indexer = json.loads((local_dir / "unicode_indexer.json").read_text())
def _build(cls_, name):
model = cls_()
w = mx.load(str(weights_dir / f"{name}.safetensors"))
# Reshape any mismatched leaves (defensive; the converter already
# produced shape-correct tensors but a future re-export may not).
expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
for k in list(w.keys()):
if k in expected and tuple(w[k].shape) != expected[k]:
if w[k].size == int(np.prod(expected[k])):
w[k] = w[k].reshape(expected[k])
model.load_weights(list(w.items()), strict=False)
return model
ve = _build(VectorEstimator, "vector_estimator")
te = _build(TextEncoder, "text_encoder")
dp = _build(DurationPredictor, "duration_predictor")
voc = _build(Vocoder, "vocoder")
if dtype is not None and dtype != mx.float32:
cls._cast_all(dp, te, ve, voc, dtype=dtype)
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
@classmethod
def _from_onnx(
cls, local_dir: Path, dtype: mx.Dtype | None = None,
) -> "SupertonicMLXPipeline":
onnx_dir = local_dir / "onnx"
voice_dir = local_dir / "voice_styles"
unicode_indexer = json.loads((onnx_dir / "unicode_indexer.json").read_text())
ve = VectorEstimator()
_load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
te = TextEncoder()
_load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
dp = DurationPredictor()
_load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
voc = Vocoder()
_load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
if dtype is not None and dtype != mx.float32:
cls._cast_all(dp, te, ve, voc, dtype=dtype)
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
@staticmethod
def _cast_all(*models, dtype: mx.Dtype) -> None:
"""Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
from mlx.utils import tree_map
def _cast(p):
if not isinstance(p, mx.array) or p.dtype != mx.float32:
return p
return p.astype(dtype)
for m_ in models:
m_.update(tree_map(_cast, m_.parameters()))
def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
"""Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp)."""
path = self.voice_dir / f"{voice}.json"
data = json.loads(path.read_text())
style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32) # (1, 50, 256)
style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32) # (1, 8, 16)
return mx.array(style_ttl), mx.array(style_dp)
def generate(
self,
text: str,
voice: str = "F1",
lang: str = "en",
seed: int = 42,
n_steps: Optional[int] = None,
) -> np.ndarray:
"""Synthesise a single utterance. Returns a 1D float32 numpy waveform."""
n_steps = n_steps if n_steps is not None else self.n_euler_steps
# Tokenize
text_ids_np = _encode_text(text, self.unicode_indexer, lang)
text_ids = mx.array(text_ids_np[None, :]) # (1, T_text)
T_text = text_ids.shape[1]
text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)
# Style
style_ttl, style_dp = self._load_voice(voice)
if self.dtype != mx.float32:
style_ttl = style_ttl.astype(self.dtype)
style_dp = style_dp.astype(self.dtype)
# Duration → latent length
duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
mx.eval(duration_s)
duration_val = max(float(duration_s[0].item()), 0.5) # clamp to ≥ 0.5 s
T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)
# Text embedding
text_emb = self.text_encoder(text_ids, style_ttl, text_mask) # (1, 256, T_text)
# Initial noise — fixed seed for reproducibility
key = mx.random.key(seed)
noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)
# T.5.3 — build the (2B) CFG conditioning tensors once and pre-project
# K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
# ``(K, V)`` arrays flattened into a list for the compiled step.
B = noise.shape[0]
ve = self.vector_estimator
text_uncond = mx.broadcast_to(
ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
).astype(self.dtype)
style_k_uncond = mx.broadcast_to(
ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
).astype(self.dtype)
style_v_uncond = mx.broadcast_to(
ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
).astype(self.dtype)
text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
text_kv, style_kv = ve.precompute_cross_kv(
text_emb_2, style_k_2, style_v_2, text_mask_2,
)
kv_flat = []
for k, v in text_kv:
kv_flat.extend([k, v])
for k, v in style_kv:
kv_flat.extend([k, v])
# Euler with CFG — 5 steps by default
x = noise
total_step = mx.array([float(n_steps)], dtype=self.dtype)
for step in range(n_steps):
current_step = mx.array([float(step + 1)], dtype=self.dtype)
t_norm = current_step / total_step
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
x = self._cached_step_compiled(
x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
)
mx.eval(x)
# Decode latent → waveform
wav = self._voc_compiled(x)
mx.eval(wav)
if wav.dtype != mx.float32:
wav = wav.astype(mx.float32)
return np.array(wav)[0] # (T_lat × 6 × 512,)
__all__ = ["SupertonicMLXPipeline"]

View File

@@ -0,0 +1,382 @@
"""Supertonic 3 text encoder MLX port.
Pipeline (operating in channels-last NTC after the initial conv):
text_ids [B, T_text] int64 character IDs
→ char_embedder (Embedding 8322→256) [B, T_text, 256]
→ 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
→ 4× attn_encoder block:
RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
→ speech_prompted_text_encoder:
cross-attn1: text (Q) × style_ttl (K, V) → text features
cross-attn2: text (Q) × style_ttl (K, V) → text features
norm
→ output text_emb [B, 256, T_text] (channels-first to match vector_estimator)
Inputs:
text_ids: (B, T_text) int — character indices
style_ttl: (B, 50, 256) float — style token bank
text_mask: (B, 1, T_text) float — 1.0 where valid, 0.0 where padded
Submodule naming matches the ONNX initializer keys exactly so that
``model.load_weights(...)`` succeeds with no remapping.
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
from supertonic_3_mlx._config import EPS_LN
from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
from supertonic_3_mlx.vector_estimator import (
ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
)
# Vocab + dims (frozen by checkpoint)
VOCAB_SIZE = 8322
TE_DIM = 256
TE_CONVNEXT_HIDDEN = 1024
TE_CONVNEXT_K = 5
TE_CONVNEXT_NUM_LAYERS = 6
TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)
TE_ATTN_NUM_LAYERS = 4
TE_ATTN_HEADS = 4
TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS # 64
TE_FFN_HIDDEN = 1024
class TextConvNeXtBlock(nn.Module):
"""ConvNeXt for the text encoder (dim=256, hidden=1024).
Shares the same architecture as ``vector_estimator.ConvNeXtBlock`` but is
redefined here with text-encoder-specific defaults to keep the modules
self-contained.
"""
def __init__(self, dilation: int = 1) -> None:
super().__init__()
self.dim = TE_DIM
self.dilation = dilation
self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
self.dwconv = nn.Conv1d(
TE_DIM, TE_DIM, kernel_size=TE_CONVNEXT_K, padding=0,
dilation=dilation, groups=TE_DIM, bias=True,
)
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
self.gamma = mx.zeros((TE_DIM,))
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
# x: (B, T_text, 256)
residual = x
y = _pad_sym_edge(x, self.pad)
y = self.dwconv(y)
y = self.norm(y)
y = self.pwconv1(y)
y = _gelu_exact(y)
y = self.pwconv2(y)
y = y * self.gamma
out = residual + y
if mask is not None:
out = out * mask
return out
class TextConvNeXtStack(nn.Module):
"""6 stacked ConvNeXt blocks. Loaded as ``convnext.convnext.[0..5].X``."""
def __init__(self) -> None:
super().__init__()
self.convnext = [TextConvNeXtBlock(d) for d in TE_CONVNEXT_DILATIONS]
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
for b in self.convnext:
x = b(x, mask)
return x
class _ConvLayer(nn.Module):
"""Conv1d k=1 expressed via the ONNX-style ``X.weight (out, in, 1) + X.bias``.
The attn_encoder uses Conv1d k=1 instead of nn.Linear for its Q/K/V/O.
This wrapper keeps the weight shape (out, in, 1) intact and runs as a
Conv1d (the equivalent of a Linear when k=1).
"""
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.weight = mx.zeros((out_dim, 1, in_dim)) # (C_out, K=1, C_in)
self.bias = mx.zeros((out_dim,))
def __call__(self, x: mx.array) -> mx.array:
# x: (B, T, in_dim) — channels-last
# equivalent to nn.Conv1d(in_dim, out_dim, k=1) in NTC layout
return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
REL_POS_WINDOW = 4 # rel_pos table size = 2*4 + 1 = 9
def _rel_to_abs(x: mx.array) -> mx.array:
"""[B, h, L, 2L-1] → [B, h, L, L] via the VITS shifted-skew reshape."""
B, h, L, _ = x.shape
x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
x_flat = x.reshape(B, h, L * 2 * L)
x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
return x_final[:, :, :L, L - 1:]
def _abs_to_rel(x: mx.array) -> mx.array:
"""[B, h, L, L] → [B, h, L, 2L-1] (inverse of _rel_to_abs)."""
B, h, L, _ = x.shape
x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
x_flat = x.reshape(B, h, L * (2 * L - 1))
x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
x_final = x_flat.reshape(B, h, L, 2 * L)
return x_final[:, :, :, 1:]
def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
"""``rel`` (1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
pad_l = max(length - (window + 1), 0)
if pad_l > 0:
zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
padded = mx.concatenate([zero, rel, zero], axis=1)
else:
padded = rel
start = max(window + 1 - length, 0)
return padded[:, start: start + 2 * length - 1]
class RelPosSelfAttention(nn.Module):
"""VITS-style relative-position self-attention with window=4.
Adds two contributions to vanilla MHA:
- ``rel_logits = q @ rel_k.T`` then ``_rel_to_abs`` and added to attention logits
- ``rel_attn = _abs_to_rel(softmax(logits))`` then ``@ rel_v`` and added to output
Loaded keys (per layer):
``conv_q/k/v/o.weight`` (256, 256, 1) and ``.bias`` (256)
``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64)
"""
def __init__(self) -> None:
super().__init__()
self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
self.window = REL_POS_WINDOW
self.emb_rel_k = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
self.emb_rel_v = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
B, T, _ = x.shape
H, D = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM
q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
scale = D ** -0.5
# Standard attention logits
logits = (q @ k.transpose(0, 1, 3, 2)) * scale # (B, H, T, T)
# VITS relative-position contribution to logits
rel_k = _slice_rel_emb(self.emb_rel_k, T, self.window) # (1, 2T-1, D)
rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :] # (B, H, T, 2T-1)
rel_logits = _rel_to_abs(rel_logits * scale) # (B, H, T, T)
logits = logits + rel_logits
if mask is not None:
key_mask = mask[:, :, 0][:, None, None, :]
neg_inf = mx.array(-1e4, dtype=logits.dtype)
logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
attn = mx.softmax(logits, axis=-1) # (B, H, T, T)
out = attn @ v # (B, H, T, D)
# VITS rel-pos value contribution
rel_v = _slice_rel_emb(self.emb_rel_v, T, self.window) # (1, 2T-1, D)
rel_weights = _abs_to_rel(attn) # (B, H, T, 2T-1)
out = out + rel_weights @ rel_v[:, None, :, :] # (B, H, T, D)
out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
return self.conv_o(out)
class FFN(nn.Module):
"""FFN with Conv1d k=1 wrappers: conv_1 (256→1024) + ReLU + conv_2 (1024→256).
Activation is ReLU (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``),
not GELU. The mask is applied before each Conv to match the ONNX semantics.
"""
def __init__(self) -> None:
super().__init__()
self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
if mask is not None:
x = x * mask
y = self.conv_1(x)
y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
if mask is not None:
y = y * mask
y = self.conv_2(y)
if mask is not None:
y = y * mask
return y
class AttnEncoder(nn.Module):
"""Stack of (RelPosSelfAttn + norm1) + (FFN + norm2) × 4."""
def __init__(self) -> None:
super().__init__()
self.attn_layers = [RelPosSelfAttention() for _ in range(TE_ATTN_NUM_LAYERS)]
self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
self.ffn_layers = [FFN() for _ in range(TE_ATTN_NUM_LAYERS)]
self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
for i in range(TE_ATTN_NUM_LAYERS):
y = self.attn_layers[i](x, mask=mask)
x = self.norm_layers_1[i](x + y)
y = self.ffn_layers[i](x, mask)
x = self.norm_layers_2[i](x + y)
return x
class _TextEmbedder(nn.Module):
"""char_embedder: VOCAB → TE_DIM. Loaded as ``char_embedder.weight (8322, 256)``."""
def __init__(self) -> None:
super().__init__()
self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)
def __call__(self, text_ids: mx.array) -> mx.array:
return self.char_embedder(text_ids)
class _InnerTextEncoder(nn.Module):
"""Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``."""
def __init__(self) -> None:
super().__init__()
self.text_embedder = _TextEmbedder()
self.convnext = TextConvNeXtStack()
self.attn_encoder = AttnEncoder()
def __call__(self, text_ids: mx.array, mask: mx.array) -> mx.array:
x = self.text_embedder(text_ids) # (B, T, 256)
if mask is not None:
x = x * mask
x = self.convnext(x, mask)
x = self.attn_encoder(x, mask)
return x
class _StyleEncoder(nn.Module):
"""Holds ``style_token_layer.style_key`` (1, 50, 256)."""
def __init__(self) -> None:
super().__init__()
# Use a child module so the parameter path matches ``style_token_layer.style_key``
class _StyleTokenLayer(nn.Module):
def __init__(_):
super().__init__()
_.style_key = mx.zeros((1, 50, 256))
self.style_token_layer = _StyleTokenLayer()
class _SpeechPromptedAttn(nn.Module):
"""Cross-attention from text (Q) to style_ttl (K, V). Single head, 256-d."""
def __init__(self) -> None:
super().__init__()
self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
# x: (B, T_text, 256); style: (B, 50, 256)
# Single-head cross attention.
B, T, D = x.shape
q = self.W_query(x)
k = self.W_key(style)
v = self.W_value(style)
scale = D ** -0.5
logits = (q @ k.transpose(0, 2, 1)) * scale
attn = mx.softmax(logits, axis=-1)
out = attn @ v
return self.out_fc(out)
class _SpeechPromptedTextEncoder(nn.Module):
"""Two cross-attention layers modulating text features with style_ttl."""
def __init__(self) -> None:
super().__init__()
self.attention1 = _SpeechPromptedAttn()
self.attention2 = _SpeechPromptedAttn()
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
x = x + self.attention1(x, style)
x = x + self.attention2(x, style)
return self.norm(x)
class _RootTextEncoder(nn.Module):
"""Top-level container matching ONNX ``tts.ttl.*`` namespace."""
def __init__(self) -> None:
super().__init__()
self.text_encoder = _InnerTextEncoder()
self.style_encoder = _StyleEncoder()
self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()
class _TtsContainer(nn.Module):
"""Outer container so weight keys ``tts.ttl.X.Y`` resolve."""
def __init__(self) -> None:
super().__init__()
self.ttl = _RootTextEncoder()
class TextEncoder(nn.Module):
"""Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.
Submodule naming matches the ONNX initializer keys after a single
``tts.ttl.`` prefix wrap (so weight keys look like
``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
"""
def __init__(self) -> None:
super().__init__()
self.tts = _TtsContainer()
def __call__(
self,
text_ids: mx.array, # (B, T_text) int
style_ttl: mx.array, # (B, 50, 256)
text_mask: mx.array, # (B, 1, T_text)
) -> mx.array:
mask_ntc = text_mask.transpose(0, 2, 1) # (B, T_text, 1)
x = self.tts.ttl.text_encoder(text_ids, mask_ntc)
x = self.tts.ttl.speech_prompted_text_encoder(x, style_ttl)
if mask_ntc is not None:
x = x * mask_ntc
# Return channels-first (B, 256, T_text) to match the vector_estimator input.
return x.transpose(0, 2, 1)
__all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]

View File

@@ -0,0 +1,765 @@
"""Supertonic 3 vector estimator (64 M params) — flow-matching denoiser, MLX port.
Pipeline (operating in channels-last NTC layout):
noisy_latent [B, 144, T_lat] (channels first from ONNX I/O)
→ transpose [B, T_lat, 144]
→ proj_in (Linear 144→512) [B, T_lat, 512]
→ 24 main_blocks (4 cycles × 6 sub-types):
cycle = [stack4, time_film, cn1, text_attn, cn1, style_attn]
→ last_convnext (4 ConvNeXt) [B, T_lat, 512]
→ proj_out (Linear 512→144) [B, T_lat, 144]
→ transpose [B, 144, T_lat]
→ Euler step: denoised = noisy + velocity * (1 / total_step)
→ output [B, 144, T_lat]
Submodule naming matches the s3 ONNX initializer keys exactly, so loading
the safetensors produced by ``weights.convert_onnx_to_mlx`` requires no
remapping.
The forward path is faithful to ONNX semantics in fp32; ``mx.compile``,
quantisation, and kernel fusion are layered on later in T.3.
"""
from __future__ import annotations
import math
import mlx.core as mx
import mlx.nn as nn
from supertonic_3_mlx._config import (
DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K, STACK4_DILATIONS,
NUM_MAIN_BLOCKS, BLOCKS_PER_CYCLE, BLOCK_CYCLE,
TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM, ROTARY_BASE, ROTARY_SCALE,
STYLE_DIM, STYLE_LEN, STYLE_HEADS, STYLE_HEAD_DIM,
TIME_EMB_DIM, TIME_MLP_HIDDEN,
EPS_LN,
)
from supertonic_3_mlx._nn_wrappers import (
WrappedNorm, WrappedLinear, ProjConv1x1,
)
def _pad_sym_edge(x: mx.array, pad: int) -> mx.array:
"""Symmetric replicate-edge pad on the time axis (axis=1 for [B, T, C])."""
if pad == 0:
return x
left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
right = mx.broadcast_to(x[:, -1:, :], (x.shape[0], pad, x.shape[2]))
return mx.concatenate([left, x, right], axis=1)
def _gelu_exact(x: mx.array) -> mx.array:
"""Exact (non-tanh) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
return x * 0.5 * (1.0 + mx.erf(x * (2 ** -0.5)))
def _mish(x: mx.array) -> mx.array:
"""Mish: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))."""
return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
# ──────────────────────────────────────────────────────────────────
# ConvNeXt building blocks
# ──────────────────────────────────────────────────────────────────
class ConvNeXtBlock(nn.Module):
"""Single ConvNeXt block matching s3 keys: ``dwconv``, ``norm.norm``, ``pwconv1/2``, ``gamma``."""
def __init__(
self,
dim: int = DIM,
hidden: int = CONVNEXT_HIDDEN,
kernel: int = CONVNEXT_K,
dilation: int = 1,
) -> None:
super().__init__()
self.dim = dim
self.dilation = dilation
self.pad = dilation * (kernel - 1) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=kernel, padding=0, dilation=dilation,
groups=dim, bias=True,
)
self.norm = WrappedNorm(dim, eps=EPS_LN)
self.pwconv1 = nn.Linear(dim, hidden, bias=True)
self.pwconv2 = nn.Linear(hidden, dim, bias=True)
# Stored as shape (1, dim, 1) in the ONNX checkpoint — see weights.py for
# the load-time reshape that flattens it to (dim,) for broadcasting in NTC.
self.gamma = mx.zeros((dim,))
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
# x: (B, T, C)
residual = x
y = _pad_sym_edge(x, self.pad)
y = self.dwconv(y) # (B, T, C)
y = self.norm(y) # LayerNorm last-dim
y = self.pwconv1(y) # (B, T, hidden)
y = _gelu_exact(y)
y = self.pwconv2(y) # (B, T, C)
y = y * self.gamma # broadcast over (B, T, .)
out = residual + y
if mask is not None:
out = out * mask
return out
class ConvNeXtStack(nn.Module):
"""List of ConvNeXt blocks. Loaded as ``convnext.[0..N-1].X``."""
def __init__(self, dilations: tuple, dim: int = DIM, hidden: int = CONVNEXT_HIDDEN) -> None:
super().__init__()
self.convnext = [ConvNeXtBlock(dim, hidden, CONVNEXT_K, d) for d in dilations]
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
for b in self.convnext:
x = b(x, mask)
return x
# ──────────────────────────────────────────────────────────────────
# 6 block types per cycle
# ──────────────────────────────────────────────────────────────────
class Stack4Block(nn.Module):
"""Cycle position 0 — 4 ConvNeXt with dilations [1, 2, 4, 8].
Loaded keys: ``convnext.[0..3].{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
"""
def __init__(self) -> None:
super().__init__()
self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, d) for d in STACK4_DILATIONS]
def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
for b in self.convnext:
x = b(x, mask)
return x
class TimeFiLMBlock(nn.Module):
"""Cycle position 1 — additive time conditioning: ``x + linear(t_emb)``.
Loaded keys: ``linear.linear.{weight,bias}``.
"""
def __init__(self) -> None:
super().__init__()
self.linear = WrappedLinear(TIME_EMB_DIM, DIM, bias=True)
def __call__(self, x: mx.array, mask: mx.array | None, t_emb: mx.array, **_) -> mx.array:
# t_emb: (B, TIME_EMB_DIM) → broadcast across T
bias = self.linear(t_emb)[:, None, :] # (B, 1, DIM)
y = x + bias
if mask is not None:
y = y * mask
return y
class ConvNeXt1Block(nn.Module):
"""Cycle positions 2 and 4 — a single ConvNeXt block.
Loaded keys: ``convnext.0.{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
"""
def __init__(self) -> None:
super().__init__()
self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, 1)]
def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
return self.convnext[0](x, mask)
def _build_rope_freqs(head_dim: int, base: int, scale: int, max_len: int = 1024) -> mx.array:
"""Pre-compute RoPE cos/sin table — (max_len, head_dim/2, 2)."""
half = head_dim // 2
inv_freq = 1.0 / (base ** (mx.arange(half, dtype=mx.float32) / half))
pos = mx.arange(max_len, dtype=mx.float32) * scale
angles = pos[:, None] * inv_freq[None, :] # (max_len, half)
return mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1) # (max_len, half, 2)
def _apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
"""Apply RoPE rotation. ``x`` shape (B, H, T, head_dim); ``freqs`` (T, half, 2)."""
half = x.shape[-1] // 2
x_even, x_odd = x[..., :half], x[..., half:]
cos = freqs[..., 0] # (T, half)
sin = freqs[..., 1]
rot_even = x_even * cos[None, None, :, :] - x_odd * sin[None, None, :, :]
rot_odd = x_even * sin[None, None, :, :] + x_odd * cos[None, None, :, :]
return mx.concatenate([rot_even, rot_odd], axis=-1)
class TextCrossAttnBlock(nn.Module):
"""Cycle position 3 — text cross-attention with RoPE on Q and K.
Loaded keys:
``attn.W_query.linear.{weight,bias}``
``attn.W_key.linear.{weight,bias}``
``attn.W_value.linear.{weight,bias}``
``attn.out_fc.linear.{weight,bias}``
``attn.theta`` — frozen RoPE inv-freq table (1, 1, half)
``attn.increments`` — frozen position table (1, 1000, 1) — 0..999
``norm.norm.{weight,bias}``
"""
def __init__(self) -> None:
super().__init__()
self.attn = _AttnInner(DIM, TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM)
self.norm = WrappedNorm(DIM, eps=EPS_LN)
def __call__(
self,
x: mx.array,
mask: mx.array | None,
*,
text_emb: mx.array | None = None,
text_mask: mx.array | None = None,
latent_seq_len: mx.array | None = None,
text_seq_len: mx.array | None = None,
kv_cache: tuple[mx.array, mx.array] | None = None,
**_,
) -> mx.array:
# x: (B, T_lat, DIM); text_emb: (B, T_text, TEXT_DIM) — unused when kv_cache supplied.
residual = x * mask if mask is not None else x
h = self.attn(
residual, text_emb, text_mask=text_mask,
latent_seq_len=latent_seq_len, text_seq_len=text_seq_len,
kv_cache=kv_cache,
)
if mask is not None:
h = h * mask
out = self.norm(residual + h)
if mask is not None:
out = out * mask
return out
class _AttnInner(nn.Module):
"""Multi-head cross-attention with RoPE applied to query and key.
Holds parameters under ``W_query``, ``W_key``, ``W_value``, ``out_fc`` —
each is a :class:`WrappedLinear` so its weight is keyed
``…W_query.linear.weight`` to match the ONNX checkpoint.
``theta`` and ``increments`` come from the ONNX graph as frozen tensors
(precomputed RoPE table). We rebuild the equivalent table from the
Supertonic-3 config so the module is self-contained.
"""
def __init__(
self,
in_dim: int,
ctx_dim: int,
num_heads: int,
head_dim: int,
) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
# ONNX divides attention logits by 16.0 (= sqrt(TEXT_DIM)), not sqrt(head_dim).
self.scale = ctx_dim ** -0.5
kv_dim = num_heads * head_dim # = DIM = 512
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
# Frozen RoPE tables — overwritten by checkpoint at load time.
# ONNX layout:
# ``increments`` (1, 1000, 1) holds positions 0..999 (no scale baked in)
# ``theta`` (1, 1, half) holds rotary_scale × base^(-i/half)
# Angle formula: ``angle = (pos / actual_seq_len) × theta``.
# The division by the actual seq length is critical — it normalises
# absolute positions into [0, 1] so audio and text are RoPE-aligned
# regardless of their respective lengths.
max_len = 1000
half = head_dim // 2
idx = mx.arange(half, dtype=mx.float32)
self.theta = (ROTARY_SCALE * mx.exp(-math.log(ROTARY_BASE) * idx / half))[None, None, :]
positions = mx.arange(max_len, dtype=mx.int64)
self.increments = positions[None, :, None] # (1, max_len, 1)
def _rope(self, x: mx.array, seq_len: mx.array | int | None = None) -> mx.array:
"""Apply RoPE rotation. ``seq_len`` is the effective (unmasked) length.
Args:
x: (B, H, T, head_dim)
seq_len: scalar or (B,) — actual sequence length for position normalisation.
If None, defaults to T (no normalisation).
"""
T = x.shape[-2]
positions = self.increments[:, :T, :] # (1, T, 1)
if seq_len is None:
seq_len = float(T)
if isinstance(seq_len, (int, float)):
divisor = float(seq_len)
else:
divisor = seq_len.astype(mx.float32).reshape(-1, 1, 1)
norm_pos = positions / divisor # broadcasts to (B, T, 1) if divisor is (B,1,1)
angles = norm_pos * self.theta # (B, T, half) or (1, T, half)
cos = mx.cos(angles)
sin = mx.sin(angles)
half = self.head_dim // 2
# Broadcast (?, T, half) → (?, 1, T, half) for head dim
cos_b = cos[..., None, :, :] if cos.ndim == 3 else cos[None, None, :, :]
sin_b = sin[..., None, :, :] if sin.ndim == 3 else sin[None, None, :, :]
# Make sure broadcasts properly
if cos_b.shape[0] == 1 and x.shape[0] > 1:
cos_b = mx.broadcast_to(cos_b, (x.shape[0], 1, T, half))
sin_b = mx.broadcast_to(sin_b, (x.shape[0], 1, T, half))
# Reshape if needed
cos_b = cos_b.reshape(-1, 1, T, half)
sin_b = sin_b.reshape(-1, 1, T, half)
x_first, x_second = x[..., :half], x[..., half:]
rot_first = x_first * cos_b - x_second * sin_b
rot_second = x_first * sin_b + x_second * cos_b
return mx.concatenate([rot_first, rot_second], axis=-1)
def project_kv(
self,
text_emb: mx.array,
text_seq_len: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Project text_emb → (K_rope, V) once. Both are constant across the
Euler steps in a TTS inference call (T.5.3 cache target)."""
B, T_text, _ = text_emb.shape
H, D = self.num_heads, self.head_dim
k = self.W_key(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
v = self.W_value(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
k = self._rope(k, seq_len=text_seq_len if text_seq_len is not None else T_text)
return k, v
def __call__(
self,
x: mx.array,
text_emb: mx.array | None = None,
text_mask: mx.array | None = None,
latent_seq_len: mx.array | None = None,
text_seq_len: mx.array | None = None,
kv_cache: tuple[mx.array, mx.array] | None = None,
) -> mx.array:
B, T_lat, _ = x.shape
H, D = self.num_heads, self.head_dim
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3) # (B, H, T_lat, D)
if kv_cache is not None:
k, v = kv_cache
else:
k, v = self.project_kv(text_emb, text_seq_len=text_seq_len)
# RoPE normalises positions by the effective (unmasked) sequence length.
q = self._rope(q, seq_len=latent_seq_len if latent_seq_len is not None else T_lat)
# Attention
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, H, T_lat, T_text)
if text_mask is not None:
neg_inf = mx.array(-1e4, dtype=logits.dtype)
logits = mx.where(text_mask[:, :, None, :].astype(mx.bool_), logits, neg_inf)
attn = mx.softmax(logits, axis=-1)
out = attn @ v # (B, H, T_lat, D)
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
return self.out_fc(out)
class StyleCrossAttnBlock(nn.Module):
"""Cycle position 5 — style cross-attention to 50 learned style tokens.
Loaded keys:
``attention.W_query.linear.{weight,bias}``
``attention.W_key.linear.{weight,bias}``
``attention.W_value.linear.{weight,bias}``
``attention.out_fc.linear.{weight,bias}``
``norm.norm.{weight,bias}``
"""
def __init__(self) -> None:
super().__init__()
self.attention = _StyleAttnInner(DIM, STYLE_DIM, STYLE_HEADS, STYLE_HEAD_DIM)
self.norm = WrappedNorm(DIM, eps=EPS_LN)
def __call__(
self,
x: mx.array,
mask: mx.array | None,
*,
style_k: mx.array | None = None,
style_v: mx.array | None = None,
kv_cache: tuple[mx.array, mx.array] | None = None,
**_,
) -> mx.array:
# style_v defaults to style_k (same tensor for cond path); CFG path supplies
# different style_v to model the uncond branch.
if style_v is None and style_k is not None:
style_v = style_k
residual = x * mask if mask is not None else x
h = self.attention(residual, style_k, style_v, kv_cache=kv_cache)
if mask is not None:
h = h * mask
out = self.norm(residual + h)
if mask is not None:
out = out * mask
return out
class _StyleAttnInner(nn.Module):
def __init__(self, in_dim: int, ctx_dim: int, num_heads: int, head_dim: int) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
# ONNX divides attention logits by 16.0 (= sqrt(STYLE_DIM)), not sqrt(head_dim).
self.scale = ctx_dim ** -0.5
kv_dim = num_heads * head_dim # 2 * 128 = 256
# Q is on DIM (audio), K/V on ctx_dim (style 256)
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
def project_kv(
self, style_k: mx.array, style_v: mx.array
) -> tuple[mx.array, mx.array]:
"""Project (style_k, style_v) → (K, V) once. T.5.3 cache target."""
B, T_style = style_k.shape[0], style_k.shape[1]
H, D = self.num_heads, self.head_dim
# Note: ONNX graph applies tanh to the K projection (``attention/tanh/Tanh``
# node) — the style key bank is bounded into [-1, 1] before softmax dot
# product, which acts as a soft attention temperature regulariser.
k = mx.tanh(self.W_key(style_k)).reshape(B, T_style, H, D).transpose(0, 2, 1, 3)
v = self.W_value(style_v).reshape(B, style_v.shape[1], H, D).transpose(0, 2, 1, 3)
return k, v
def __call__(
self,
x: mx.array,
style_k: mx.array | None = None,
style_v: mx.array | None = None,
kv_cache: tuple[mx.array, mx.array] | None = None,
) -> mx.array:
# style_k and style_v can be the same tensor (cond) or distinct (uncond
# branch in CFG, where K comes from style_key_special_token and V from
# style_value_special_token).
B, T_lat, _ = x.shape
H, D = self.num_heads, self.head_dim
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3)
if kv_cache is not None:
k, v = kv_cache
else:
k, v = self.project_kv(style_k, style_v)
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale
attn = mx.softmax(logits, axis=-1)
out = attn @ v
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
return self.out_fc(out)
# ──────────────────────────────────────────────────────────────────
# Time encoder
# ──────────────────────────────────────────────────────────────────
class _MlpItem(nn.Module):
"""A single MLP layer wrapped to produce keys ``mlp.N.linear.{weight,bias}``."""
def __init__(self, in_dim: int, out_dim: int) -> None:
super().__init__()
self.linear = nn.Linear(in_dim, out_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.linear(x)
class TimeEncoder(nn.Module):
"""Sinusoidal time embedding + 2-layer MLP. Keys: ``mlp.0.linear``, ``mlp.2.linear``."""
def __init__(self) -> None:
super().__init__()
# ONNX: mlp.0.linear (64→256), mlp.2.linear (256→64). Index 1 is activation.
self.mlp = [
_MlpItem(TIME_EMB_DIM, TIME_MLP_HIDDEN), # mlp.0
nn.Identity(), # mlp.1 (activation; no weights)
_MlpItem(TIME_MLP_HIDDEN, TIME_EMB_DIM), # mlp.2
]
def __call__(self, t: mx.array) -> mx.array:
# t: (B,) — produce sinusoidal embedding then run through MLP.
# Activation is Mish (not SiLU) to match the ONNX graph
# (Softplus → Tanh → Mul pattern == x * tanh(softplus(x))).
emb = self._sinusoidal(t, TIME_EMB_DIM)
h = self.mlp[0](emb)
h = _mish(h)
h = self.mlp[2](h)
return h
@staticmethod
def _sinusoidal(t: mx.array, dim: int) -> mx.array:
"""Time embedding matching ``Supertonic-3`` ONNX exactly.
ONNX path: pos = t * 1000; freqs[i] = 10000^(-i/(half-1));
concat[sin(pos*freqs), cos(pos*freqs)].
"""
half = dim // 2
denom = max(half - 1, 1)
freqs = mx.exp(-math.log(10_000) * mx.arange(half, dtype=mx.float32) / denom)
pos = t.astype(mx.float32)[:, None] * 1000.0
angles = pos * freqs[None, :]
return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=-1).astype(mx.float32)
# ──────────────────────────────────────────────────────────────────
# Top-level VectorEstimator
# ──────────────────────────────────────────────────────────────────
def _build_main_block(idx: int) -> nn.Module:
"""Instantiate the appropriate block class for cycle position ``idx % 6``."""
pos = idx % BLOCKS_PER_CYCLE
name = BLOCK_CYCLE[pos]
if name == "stack4":
return Stack4Block()
if name == "time":
return TimeFiLMBlock()
if name == "cn1":
return ConvNeXt1Block()
if name == "text_attn":
return TextCrossAttnBlock()
if name == "style_attn":
return StyleCrossAttnBlock()
raise RuntimeError(f"unknown block type for index {idx}: {name}")
class _VectorField(nn.Module):
"""Inner module mirroring ONNX ``vector_estimator.tts.ttl.vector_field.*``."""
def __init__(self) -> None:
super().__init__()
self.proj_in = ProjConv1x1(LATENT_CH, DIM, bias=False)
self.main_blocks = [_build_main_block(i) for i in range(NUM_MAIN_BLOCKS)]
self.last_convnext = ConvNeXtStack(dilations=(1, 1, 1, 1), dim=DIM, hidden=CONVNEXT_HIDDEN)
self.proj_out = ProjConv1x1(DIM, LATENT_CH, bias=False)
self.time_encoder = TimeEncoder()
class _UncondMasker(nn.Module):
"""Holds the three unconditional-token tensors used by CFG.
Keys:
``text_special_token`` (1, 256, 1)
``style_key_special_token`` (1, 50, 256)
``style_value_special_token`` (1, 50, 256)
"""
def __init__(self) -> None:
super().__init__()
# Initialised to zero; checkpoint provides real values.
self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
self.style_value_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
class VectorEstimator(nn.Module):
"""Top-level module — matches ONNX root names ``vector_field.*`` and ``uncond_masker.*``.
Two inference paths:
- :meth:`velocity`: single forward pass; predicts the velocity from one set
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
(cond path) or different (uncond path of CFG).
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
(cond + uncond) internally and combines via
``final = noisy + (4*cond - 3*uncond) / total_step``.
"""
# CFG guidance constants — baked into the ONNX graph as ``/Constant_3`` (=4.0)
# and ``/Constant_4`` (=3.0). Equivalent to guidance_scale = 4 with the
# standard formula ``v = uncond + g*(cond - uncond) = 4*cond - 3*uncond``.
CFG_COND_SCALE: float = 4.0
CFG_UNCOND_SCALE: float = 3.0
def __init__(self) -> None:
super().__init__()
self.vector_field = _VectorField()
self.uncond_masker = _UncondMasker()
# ── inference API ─────────────────────────────────────────────
def velocity(
self,
noisy_latent: mx.array, # (B, 144, T_lat)
text_emb: mx.array, # (B, 256, T_text)
style_k: mx.array, # (B, 50, 256) — K side of style attention
style_v: mx.array, # (B, 50, 256) — V side of style attention
latent_mask: mx.array, # (B, 1, T_lat)
text_mask: mx.array, # (B, 1, T_text)
t_norm: mx.array, # (B,) timestep in [0, 1]
) -> mx.array:
"""Predict velocity (B, 144, T_lat) without applying CFG or Euler step."""
x = noisy_latent.transpose(0, 2, 1) # (B, T_lat, 144)
text = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
lat_mask_ntc = latent_mask.transpose(0, 2, 1) # (B, T_lat, 1)
x = self.vector_field.proj_in(x) # (B, T_lat, 512)
t_emb = self.vector_field.time_encoder(t_norm) # (B, TIME_EMB_DIM)
# Effective (unmasked) sequence lengths for RoPE normalisation —
# ONNX uses ``ReduceSum(mask)`` for this so that audio and text are
# rope-aligned regardless of padding.
latent_seq_len = mx.sum(latent_mask, axis=(1, 2)) # (B,)
text_seq_len = mx.sum(text_mask, axis=(1, 2)) # (B,)
for blk in self.vector_field.main_blocks:
x = blk(
x,
lat_mask_ntc,
t_emb=t_emb,
text_emb=text,
text_mask=text_mask,
style_k=style_k,
style_v=style_v,
latent_seq_len=latent_seq_len,
text_seq_len=text_seq_len,
)
x = self.vector_field.last_convnext(x, lat_mask_ntc)
v_ntc = self.vector_field.proj_out(x) # (B, T_lat, 144)
return v_ntc.transpose(0, 2, 1) # (B, 144, T_lat)
# ── T.5.3 — pre-projected K/V path ────────────────────────────
def precompute_cross_kv(
self,
text_emb: mx.array, # (B, 256, T_text) channels-first
style_k: mx.array, # (B, 50, 256)
style_v: mx.array, # (B, 50, 256)
text_mask: mx.array, # (B, 1, T_text)
) -> tuple[list[tuple[mx.array, mx.array]], list[tuple[mx.array, mx.array]]]:
"""Project K/V for every text_attn and style_attn block exactly once.
Returns ``(text_kv_list, style_kv_list)`` — both ordered to align with
the corresponding blocks encountered when iterating ``main_blocks``.
These tensors are invariant across the 5 Euler steps of one TTS
call; pre-projecting them once and feeding the result into
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
"""
text_seq_len = mx.sum(text_mask, axis=(1, 2))
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
text_kv: list[tuple[mx.array, mx.array]] = []
style_kv: list[tuple[mx.array, mx.array]] = []
for blk in self.vector_field.main_blocks:
if isinstance(blk, TextCrossAttnBlock):
text_kv.append(blk.attn.project_kv(text_ntc, text_seq_len=text_seq_len))
elif isinstance(blk, StyleCrossAttnBlock):
style_kv.append(blk.attention.project_kv(style_k, style_v))
return text_kv, style_kv
def velocity_cached(
self,
noisy_latent: mx.array,
latent_mask: mx.array,
text_mask: mx.array,
t_norm: mx.array,
text_kv: list[tuple[mx.array, mx.array]],
style_kv: list[tuple[mx.array, mx.array]],
) -> mx.array:
"""Same as :meth:`velocity` but reads K/V from pre-projected caches.
``text_kv`` and ``style_kv`` must come from :meth:`precompute_cross_kv`
applied to the same (batched) conditioning tensors that will be
active for this call.
"""
x = noisy_latent.transpose(0, 2, 1)
lat_mask_ntc = latent_mask.transpose(0, 2, 1)
x = self.vector_field.proj_in(x)
t_emb = self.vector_field.time_encoder(t_norm)
latent_seq_len = mx.sum(latent_mask, axis=(1, 2))
ti = 0
si = 0
for blk in self.vector_field.main_blocks:
if isinstance(blk, TextCrossAttnBlock):
x = blk(
x, lat_mask_ntc,
text_mask=text_mask,
latent_seq_len=latent_seq_len,
kv_cache=text_kv[ti],
)
ti += 1
elif isinstance(blk, StyleCrossAttnBlock):
x = blk(x, lat_mask_ntc, kv_cache=style_kv[si])
si += 1
else:
x = blk(x, lat_mask_ntc, t_emb=t_emb)
x = self.vector_field.last_convnext(x, lat_mask_ntc)
v_ntc = self.vector_field.proj_out(x)
return v_ntc.transpose(0, 2, 1)
def __call__(
self,
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
text_emb: mx.array, # (B, 256, T_text) channels-first
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
latent_mask: mx.array, # (B, 1, T_lat)
text_mask: mx.array, # (B, 1, T_text)
current_step: mx.array, # (B,)
total_step: mx.array, # (B,)
cfg: bool = True,
) -> mx.array:
"""Run one Euler step with CFG (matches ONNX semantics).
With ``cfg=True`` (default) the model runs both conditional and
unconditional paths in a single batched forward and combines via
``final = noisy + (4*cond_v - 3*uncond_v) / total_step``.
With ``cfg=False`` only the conditional path runs — half the work, but
produces a different (lower-quality) output. Useful for speed bench /
sanity tests.
"""
B = noisy_latent.shape[0]
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
if not cfg:
v = self.velocity(
noisy_latent, text_emb, style_ttl, style_ttl,
latent_mask, text_mask, t_norm,
)
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
# uncond text_emb = text_special_token broadcast to (B, 256, T_text).
# uncond style_k = style_key_special_token broadcast, similarly style_v.
text_uncond = mx.broadcast_to(
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
)
style_k_uncond = mx.broadcast_to(
self.uncond_masker.style_key_special_token, (B, STYLE_LEN, STYLE_DIM)
)
style_v_uncond = mx.broadcast_to(
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
)
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
v_2 = self.velocity(
noisy_2, text_2, style_k_2, style_v_2, lm_2, tm_2, t_norm_2,
) # (2B, 144, T_lat)
cond_v = v_2[:B]
uncond_v = v_2[B:2 * B]
combined_v = self.CFG_COND_SCALE * cond_v - self.CFG_UNCOND_SCALE * uncond_v
return noisy_latent + combined_v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
__all__ = [
"ConvNeXtBlock", "ConvNeXtStack",
"Stack4Block", "TimeFiLMBlock", "ConvNeXt1Block",
"TextCrossAttnBlock", "StyleCrossAttnBlock",
"TimeEncoder", "VectorEstimator",
]

View File

@@ -0,0 +1,304 @@
"""Supertonic 3 vocoder — latent → 44.1 kHz waveform, MLX port.
Pipeline (operating in channels-last NTC layout, then converted to channels-first
for output reshape):
latent [B, 144, T_lat] (output of vector_estimator)
→ /= normalizer.scale (scalar)
→ reshape [B, 24, T_lat*6] # de-compress
→ (* latent_std + latent_mean) # de-normalise
→ transpose to NTC [B, T_lat*6, 24]
→ embed Conv1d(24→512, k=7, sym-edge pad) [B, T_lat*6, 512]
→ 10× ConvNeXt(dim=512, hidden=2048, k=7,
dilations [1,2,4,1,2,4,1,1,1,1])
→ final_norm: BatchNorm1d (eval-time: running stats only)
→ head.layer1: Conv1d(512→2048, k=3, sym-edge pad)
→ PReLU (with per-channel learnable slope)
→ head.layer2: Conv1d(2048→512, k=1, no bias)
→ transpose to (B, 512, T_lat*6) → flatten → wav (B, T_lat*6*512)
The 512 samples/step × 6 chunk × 44.1 kHz → T_lat steps of about 0.0697 s each.
"""
from __future__ import annotations
import mlx.core as mx
import mlx.nn as nn
from supertonic_3_mlx._config import EPS_LN
from supertonic_3_mlx._nn_wrappers import WrappedNorm
from supertonic_3_mlx.vector_estimator import _gelu_exact
def _pad_left_edge(x: mx.array, pad: int) -> mx.array:
"""Causal replicate-edge pad on the time axis (axis=1 for [B, T, C]).
Pads ``pad`` time-steps on the LEFT only by replicating the first frame.
Matches the ONNX vocoder pads spec ``[0, 0, pad, 0, 0, 0]``.
"""
if pad == 0:
return x
left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
return mx.concatenate([left, x], axis=1)
VOC_DIM = 512
VOC_HIDDEN = 2048
VOC_K = 7
VOC_HEAD_K = 3
VOC_LDIM = 24 # de-compressed channels (24 × 6 = 144 input)
VOC_CHUNK_COMPRESS = 6
VOC_NUM_CONVNEXT_LAYERS = 10
VOC_DILATIONS = (1, 2, 4, 1, 2, 4, 1, 1, 1, 1)
EPS_BN = 1e-5
class _Conv1dNet(nn.Module):
"""Conv1d wrapped under ``.net`` to match ONNX storage ``.net.weight/bias``."""
def __init__(self, in_dim: int, out_dim: int, kernel: int, dilation: int = 1,
groups: int = 1, bias: bool = True) -> None:
super().__init__()
class _Net(nn.Module):
def __init__(_):
super().__init__()
# MLX Conv1d weight: (out, K, in/groups)
_.weight = mx.zeros((out_dim, kernel, in_dim // groups))
if bias:
_.bias = mx.zeros((out_dim,))
else:
_.bias = None
def __call__(_, x, dilation=1):
y = mx.conv1d(x, _.weight, stride=1, padding=0, dilation=dilation,
groups=groups)
if _.bias is not None:
y = y + _.bias
return y
self.net = _Net()
self.dilation = dilation
self.groups = groups
self.kernel = kernel
def __call__(self, x: mx.array) -> mx.array:
return self.net(x, dilation=self.dilation)
class _VocConvNeXtBlock(nn.Module):
"""ConvNeXt block matching keys ``convnext.N.{dwconv.net,norm.norm,pwconv1,pwconv2,gamma}``."""
def __init__(self, dilation: int) -> None:
super().__init__()
self.dilation = dilation
self.pad = dilation * (VOC_K - 1)
self.dwconv = _Conv1dNet(VOC_DIM, VOC_DIM, kernel=VOC_K, dilation=dilation,
groups=VOC_DIM, bias=True)
self.norm = WrappedNorm(VOC_DIM, eps=EPS_LN)
# pwconv1 / pwconv2 stored as Conv1d k=1 → loaded after squeeze to Linear.
self.pwconv1 = nn.Linear(VOC_DIM, VOC_HIDDEN, bias=True)
self.pwconv2 = nn.Linear(VOC_HIDDEN, VOC_DIM, bias=True)
self.gamma = mx.zeros((VOC_DIM,))
def __call__(self, x: mx.array) -> mx.array:
residual = x
y = _pad_left_edge(x, self.pad)
y = self.dwconv(y)
y = self.norm(y)
y = self.pwconv1(y)
y = _gelu_exact(y)
y = self.pwconv2(y)
y = y * self.gamma
return residual + y
class _BatchNorm1dEval(nn.Module):
"""Eval-mode BatchNorm1d: applies stored running_mean/running_var only.
Loaded keys: ``norm.{weight,bias,running_mean,running_var}``.
"""
def __init__(self) -> None:
super().__init__()
class _Norm(nn.Module):
def __init__(_):
super().__init__()
_.weight = mx.ones((VOC_DIM,))
_.bias = mx.zeros((VOC_DIM,))
_.running_mean = mx.zeros((VOC_DIM,))
_.running_var = mx.ones((VOC_DIM,))
def __call__(_, x):
# x: (B, T, C). BN1d normalises across batch+time per channel.
# Eval mode: use stored running stats.
norm = (x - _.running_mean) * mx.rsqrt(_.running_var + EPS_BN)
return norm * _.weight + _.bias
self.norm = _Norm()
def __call__(self, x: mx.array) -> mx.array:
return self.norm(x)
class _VocHeadActivation(nn.Module):
"""PReLU with per-channel learnable slope (weight shape (C,))."""
def __init__(self) -> None:
super().__init__()
# ONNX anonymous PReLU stores slope of shape (1,) sometimes or (C,).
# We default to (1,) and reshape on load if needed.
self.weight = mx.zeros((1,))
def __call__(self, x: mx.array) -> mx.array:
# PReLU: max(0, x) + slope × min(0, x).
# slope broadcasts over (B, T, C) or (B, C, T) depending on layout.
zero = mx.array(0.0, dtype=x.dtype)
return mx.maximum(x, zero) + self.weight * mx.minimum(x, zero)
class _VocHead(nn.Module):
"""``head.layer1`` (Conv1d 512→2048 k=3) + ``head.act`` (PReLU) + ``head.layer2`` (Conv1d k=1, no bias)."""
def __init__(self) -> None:
super().__init__()
self.layer1 = _Conv1dNet(VOC_DIM, VOC_HIDDEN, kernel=VOC_HEAD_K, bias=True)
self.act = _VocHeadActivation()
# layer2 has no .net wrapper in ONNX (different from layer1)
# ONNX: head.layer2.weight (512, 2048, 1) — Conv1d k=1, no bias.
# We represent it directly without .net wrap.
self.layer2 = _VocLayer2()
def __call__(self, x: mx.array) -> mx.array:
# x: (B, T, 512)
pad = VOC_HEAD_K - 1
y = _pad_left_edge(x, pad)
y = self.layer1(y) # (B, T, 2048)
y = self.act(y)
y = self.layer2(y) # (B, T, 512)
return y
class _VocLayer2(nn.Module):
"""Conv1d k=1 (2048 → 512), no bias. Keys: ``layer2.weight (512, 2048, 1)``."""
def __init__(self) -> None:
super().__init__()
# MLX Conv1d weight shape: (out, K, in/groups) = (512, 1, 2048)
# ONNX storage: (out, in, 1) = (512, 2048, 1). Same size; reshape on load.
self.weight = mx.zeros((VOC_DIM, 1, VOC_HIDDEN))
def __call__(self, x: mx.array) -> mx.array:
return mx.conv1d(x, self.weight, stride=1, padding=0)
class _VocEmbed(nn.Module):
"""Initial Conv1d(24→512, k=7) with sym-edge pad.
The weight + bias are anonymous in the ONNX graph (``onnx::Conv_1441`` and
``onnx::Conv_1442``); the conversion recovers them via the Conv node path
``/decoder/embed/net/Conv`` → structured name ``tts.ae.decoder.embed.net.{weight,bias}``.
"""
def __init__(self) -> None:
super().__init__()
class _Net(nn.Module):
def __init__(_):
super().__init__()
_.weight = mx.zeros((VOC_DIM, VOC_K, VOC_LDIM))
_.bias = mx.zeros((VOC_DIM,))
def __call__(_, x):
return mx.conv1d(x, _.weight, stride=1, padding=0) + _.bias
self.net = _Net()
def __call__(self, x: mx.array) -> mx.array:
pad = VOC_K - 1
y = _pad_left_edge(x, pad)
return self.net(y)
class _VocDecoder(nn.Module):
"""``tts.ae.decoder.X`` namespace."""
def __init__(self) -> None:
super().__init__()
self.embed = _VocEmbed()
self.convnext = [_VocConvNeXtBlock(d) for d in VOC_DILATIONS]
self.final_norm = _BatchNorm1dEval()
self.head = _VocHead()
class _AEContainer(nn.Module):
"""``tts.ae.X`` — holds latent_mean, latent_std, decoder."""
def __init__(self) -> None:
super().__init__()
self.latent_mean = mx.zeros((1, VOC_LDIM, 1))
self.latent_std = mx.ones((1, VOC_LDIM, 1))
self.decoder = _VocDecoder()
class _TtlContainer(nn.Module):
"""``tts.ttl.normalizer.scale`` (scalar) — divides the latent before de-norm."""
def __init__(self) -> None:
super().__init__()
class _Normalizer(nn.Module):
def __init__(_):
super().__init__()
_.scale = mx.array(1.0)
self.normalizer = _Normalizer()
class _TtsContainer(nn.Module):
def __init__(self) -> None:
super().__init__()
self.ttl = _TtlContainer()
self.ae = _AEContainer()
class Vocoder(nn.Module):
"""Latent → waveform decoder (44.1 kHz mono).
Submodule namespace matches ONNX keys ``tts.X.Y`` exactly.
"""
def __init__(self) -> None:
super().__init__()
self.tts = _TtsContainer()
def __call__(self, latent: mx.array) -> mx.array:
# latent: (B, 144, T_lat)
B = latent.shape[0]
T_lat = latent.shape[2]
# /= scale (scalar)
x = latent / self.tts.ttl.normalizer.scale
# reshape (B, 144, T_lat) → (B, 24, T_lat*6)
x = x.reshape(B, VOC_LDIM, VOC_CHUNK_COMPRESS, T_lat) # (B, 24, 6, T_lat)
x = x.transpose(0, 1, 3, 2) # (B, 24, T_lat, 6)
x = x.reshape(B, VOC_LDIM, T_lat * VOC_CHUNK_COMPRESS) # (B, 24, T_lat*6)
# De-normalise: (* std + mean)
x = x * self.tts.ae.latent_std + self.tts.ae.latent_mean
# Transpose to NTC for Conv1d layers
x = x.transpose(0, 2, 1) # (B, T_lat*6, 24)
# embed
x = self.tts.ae.decoder.embed(x) # (B, T_lat*6, 512)
# 10× ConvNeXt
for blk in self.tts.ae.decoder.convnext:
x = blk(x)
# final_norm (BatchNorm1d eval)
x = self.tts.ae.decoder.final_norm(x)
# head
x = self.tts.ae.decoder.head(x) # (B, T_lat*6, 512)
# Flatten time × channels row-major → waveform (matches ONNX:
# head.layer2 Conv (B, 512, T_lat*6) → Transpose to (B, T_lat*6, 512) →
# Reshape to (B, T_lat*6*512). Since the head already runs in NTC, we
# are already in the post-Transpose layout and only the Reshape remains).
wav = x.reshape(B, -1) # (B, T_lat*6*512)
return wav
__all__ = ["Vocoder", "VOC_DIM", "VOC_HIDDEN", "VOC_LDIM", "VOC_CHUNK_COMPRESS"]

View File

@@ -0,0 +1,152 @@
"""ONNX → MLX safetensors conversion for Supertonic 3.
Two-stage extraction:
1. **Named initializers** (e.g. ``vector_estimator.tts.ttl.vector_field.main_blocks.0.convnext.0.dwconv.weight``)
— straight name strip + optional shape transformation.
2. **Anonymous MatMul weights** (e.g. ``onnx::MatMul_3391``) — looked up via the
MatMul node graph: each MatMul output path is the human-readable name of the
weight (e.g. ``…/W_query/linear/MatMul_output_0``); we trace the second
operand initializer and rebind it to the structured name + transpose to
the MLX Linear layout ``(out, in)``.
Shape transformations:
- depthwise dwconv: ONNX ``(C, 1, K)`` → MLX ``(C, K, 1)``
- pwconv1/2 k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
- proj_in/out k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
- MatMul Linear: ONNX ``(in, out)`` → MLX ``(out, in)``
- gamma: ONNX ``(1, dim, 1)`` → MLX ``(dim,)``
"""
from __future__ import annotations
from pathlib import Path
from typing import Dict, Tuple
import mlx.core as mx
import numpy as np
_ONNX_PREFIX = "vector_estimator.tts.ttl."
_DWCONV_SUFFIX = ".dwconv.weight"
_PWCONV_SUFFIXES = (".pwconv1.weight", ".pwconv2.weight")
_GAMMA_SUFFIX = ".gamma"
def _strip_prefix(name: str) -> str:
if name.startswith(_ONNX_PREFIX):
return name[len(_ONNX_PREFIX):]
return name
def _is_named_weight(name: str) -> bool:
"""True if this is a structured weight (vs anonymous graph constant)."""
if name.startswith(_ONNX_PREFIX):
return True
if name.startswith("uncond_masker."):
return True
return False
def _convert_named(clean_name: str, arr: np.ndarray) -> np.ndarray:
"""Apply shape transforms to a named initializer based on its key."""
# Depthwise Conv1d weight: (C, 1, K) → (C, K, 1)
if clean_name.endswith(_DWCONV_SUFFIX) and arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] != 1:
arr = np.transpose(arr, (0, 2, 1))
# Pointwise k=1 / proj net weight: (out, in, 1) → (out, in)
if (any(clean_name.endswith(s) for s in _PWCONV_SUFFIXES) or clean_name.endswith(".net.weight")) \
and arr.ndim == 3 and arr.shape[-1] == 1:
arr = arr.squeeze(-1)
# gamma: (1, C, 1) → (C,)
if clean_name.endswith(_GAMMA_SUFFIX) and arr.ndim == 3 and arr.shape[0] == 1 and arr.shape[2] == 1:
arr = arr.reshape(arr.shape[1])
return arr
def _matmul_output_to_clean_name(matmul_output: str) -> str:
"""Map a MatMul node output path to the structured ``.weight`` key.
Example::
/vector_estimator/vector_field/main_blocks.3/attn/W_query/linear/MatMul_output_0
→ vector_field.main_blocks.3.attn.W_query.linear.weight
"""
# Strip prefix slash and the trailing /MatMul_output_0
path = matmul_output.lstrip("/")
if path.endswith("/MatMul_output_0"):
path = path[: -len("/MatMul_output_0")]
# Drop leading "vector_estimator/" if present
if path.startswith("vector_estimator/"):
path = path[len("vector_estimator/"):]
return path.replace("/", ".") + ".weight"
def convert_onnx_to_mlx(onnx_path: str | Path) -> Dict[str, mx.array]:
"""Load an ONNX model and return all weights as ``{clean_name: mx.array}``.
Combines named initializers and MatMul-only weights into a single dict ready
for ``model.load_weights(...)``.
"""
import onnx
from onnx import numpy_helper
model = onnx.load(str(onnx_path))
# Build initializer name → numpy array map (in-memory once)
inits: Dict[str, np.ndarray] = {
init.name: numpy_helper.to_array(init) for init in model.graph.initializer
}
out: Dict[str, mx.array] = {}
# Stage 1: named initializers
for name, arr in inits.items():
if not _is_named_weight(name):
continue
clean = _strip_prefix(name)
arr = _convert_named(clean, arr)
out[clean] = mx.array(arr)
# Stage 2: anonymous MatMul weights, recovered via the graph
for node in model.graph.node:
if node.op_type != "MatMul":
continue
if len(node.input) < 2:
continue
# The weight is conventionally the second operand
weight_name = node.input[1]
if weight_name not in inits:
continue
# Skip if it's already named structurally (shouldn't happen here)
if _is_named_weight(weight_name):
continue
# Look up the structured name from the MatMul output path
if len(node.output) < 1:
continue
clean = _matmul_output_to_clean_name(node.output[0])
# ONNX MatMul stores W as (in, out); MLX Linear expects (out, in)
arr = inits[weight_name]
if arr.ndim == 2:
arr = arr.T
out[clean] = mx.array(arr)
if not out:
raise RuntimeError(f"no weights extracted from {onnx_path}")
return out
def save_safetensors(
onnx_path: str | Path,
output_path: str | Path,
) -> Dict[str, Tuple[int, ...]]:
"""Convert an ONNX file to MLX safetensors. Returns a {name: shape} map."""
weights = convert_onnx_to_mlx(onnx_path)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
mx.save_safetensors(str(output_path), weights)
return {k: tuple(v.shape) for k, v in weights.items()}
__all__ = ["convert_onnx_to_mlx", "save_safetensors"]

1
unicode_indexer.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/F1.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/F2.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/F3.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/F4.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/F5.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/M1.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/M2.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/M3.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/M4.json Normal file

File diff suppressed because one or more lines are too long

1
voice_styles/M5.json Normal file

File diff suppressed because one or more lines are too long