v0.1.0 — initial release
MLX-native port of Supertone's Supertonic 3 multilingual TTS. Runs the full flow-matching + classifier-free-guidance pipeline at ~x100 realtime on Apple Silicon, with audio cosine 1.0 vs the cached MLX path and cosine 0.98 vs the upstream ONNX Runtime reference. Weights are hosted at https://huggingface.co/ambassadia/supertonic-3-mlx and auto-downloaded on first use; this repository ships the port code, the model card, audio samples, and a zero-config setup_and_test.sh. Install: pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git Quick test: git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git cd supertonic-3-mlx && ./setup_and_test.sh Licenses (dual): model weights = BigScience Open RAIL-M (Section 4 propagation), port code = Apache-2.0. See LICENSE, LICENSE-CODE, NOTICE. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
hello.wav
|
||||
*.egg-info/
|
||||
build/
|
||||
dist/
|
||||
209
LICENSE
Normal file
209
LICENSE
Normal file
@@ -0,0 +1,209 @@
|
||||
BigScience Open RAIL-M License
|
||||
dated August 18, 2022
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
This Open RAIL-M License was created by BigScience, a collaborative open innovation project aimed at
|
||||
the responsible development and use of large multilingual datasets and Large Language Models
|
||||
(“LLMs”). While a similar license was originally designed for the BLOOM model, we decided to adapt it
|
||||
and create this license in order to propose a general open and responsible license applicable to other
|
||||
machine learning based AI models (e.g. multimodal generative models).
|
||||
In short, this license strives for both the open and responsible downstream use of the accompanying
|
||||
model. When it comes to the open character, we took inspiration from open source permissive licenses
|
||||
regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based
|
||||
restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be
|
||||
able to enforce the license in case potential misuses of the Model may occur. Even though downstream
|
||||
derivative versions of the model could be released under different licensing terms, the latter will always
|
||||
have to include - at minimum - the same use-based restrictions as the ones in the original license (this
|
||||
license).
|
||||
The development and use of artificial intelligence (“AI”), does not come without concerns. The world has
|
||||
witnessed how AI techniques may, in some instances, become risky for the public in general. These risks
|
||||
come in many forms, from racial discrimination to the misuse of sensitive information.
|
||||
BigScience believes in the intersection between open and responsible AI development; thus, this License
|
||||
aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
||||
This License governs the use of the model (and its derivatives) and is informed by the model card
|
||||
associated with the model.
|
||||
|
||||
NOW THEREFORE, You and Licensor agree as follows:
|
||||
|
||||
1. Definitions
|
||||
(a) "License" means the terms and conditions for use, reproduction, and Distribution as defined in
|
||||
this document.
|
||||
(b) “Data” means a collection of information and/or content extracted from the dataset used with the
|
||||
Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under
|
||||
this License.
|
||||
(c)“Output” means the results of operating a Model as embodied in informational content resulting
|
||||
therefrom.
|
||||
(d)“Model” means any accompanying machine-learning based assemblies (including checkpoints),
|
||||
consisting of learnt weights, parameters (including optimizer states), corresponding to the model
|
||||
architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or
|
||||
in part on the Data, using the Complementary Material.
|
||||
(e) “Derivatives of the Model” means all modifications to the Model, works based on the Model, or any
|
||||
other model which is created or initialized by transfer of patterns of the weights, parameters,
|
||||
activations or output of the Model, to the other model, in order to cause the other model to perform
|
||||
similarly to the Model, including - but not limited to - distillation methods entailing the use of
|
||||
intermediate data representations or methods based on the generation of synthetic data by the Model
|
||||
for training the other model.
|
||||
(f)“Complementary Material” means the accompanying source code and scripts used to define,
|
||||
run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if
|
||||
any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
||||
(g) “Distribution” means any transmission, reproduction, publication or other sharing of the Model or
|
||||
Derivatives of the Model to a third party, including providing the Model as a hosted service made
|
||||
available by electronic or other remote means - e.g. API-based or web access.
|
||||
(h) “Licensor” means the copyright owner or entity authorized by the copyright owner that is
|
||||
granting the License, including the persons or entities that may have rights in the Model and/or
|
||||
distributing the Model.
|
||||
(i) "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this
|
||||
License and/or making use of the Model for whichever purpose and in any field of use, including
|
||||
usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
||||
(j) “Third Parties” means individuals or legal entities that are not under common control with
|
||||
Licensor or You.
|
||||
(k) "Contribution" means any work of authorship, including the original version of the Model and
|
||||
any modifications or additions to that Model or Derivatives of the Model thereof, that is
|
||||
intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an
|
||||
individual or Legal Entity authorized to submit on behalf of the copyright owner. For the
|
||||
purposes of this definition,
|
||||
“submitted” means any form of electronic, verbal, or written
|
||||
communication sent to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems, and issue tracking
|
||||
systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and
|
||||
improving the Model, but excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
(l) "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a
|
||||
Contribution has been received by Licensor and subsequently incorporated within the Model.
|
||||
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary
|
||||
Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor
|
||||
hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the
|
||||
Complementary Material, the Model, and Derivatives of the Model.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as
|
||||
applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge,
|
||||
royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer
|
||||
to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such
|
||||
license applies only to those patent claims licensable by such Contributor that are necessarily infringed by
|
||||
their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such
|
||||
Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim
|
||||
or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution
|
||||
incorporated within the Model and/or Complementary Material constitutes direct or contributory patent
|
||||
infringement, then any patent licenses granted to You under this License for the Model and/or Work shall
|
||||
terminate as of the date such litigation is asserted or filed.
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
|
||||
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g.
|
||||
software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof
|
||||
in any medium, with or without modifications, provided that You meet the following conditions:
|
||||
|
||||
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision
|
||||
by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the
|
||||
Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to,
|
||||
that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply
|
||||
to the use of Complementary Material.
|
||||
|
||||
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this
|
||||
License;
|
||||
|
||||
c. You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
|
||||
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices
|
||||
that do not pertain to any part of the Model, Derivatives of the Model.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or
|
||||
different license terms and conditions - respecting paragraph 4.a.
|
||||
- for use, reproduction, or Distribution
|
||||
of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use,
|
||||
reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions.
|
||||
Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You
|
||||
may use the Model subject to this License, including only for lawful purposes and in accordance with the
|
||||
License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or
|
||||
reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model
|
||||
to comply with the terms of this paragraph (paragraph 5).
|
||||
|
||||
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You
|
||||
generate using the Model. You are accountable for the Output you generate and its subsequent uses. No
|
||||
use of the output can contravene any provision as stated in the License.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
|
||||
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the
|
||||
right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model
|
||||
through electronic means, or modify the Output of the Model based on updates. You shall undertake
|
||||
reasonable efforts to use the latest version of the Model.
|
||||
|
||||
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks,
|
||||
trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the
|
||||
parties; and any rights not expressly granted herein are reserved by the Licensors.
|
||||
|
||||
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides
|
||||
the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS
|
||||
IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
|
||||
including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT,
|
||||
MERCHANTABILITY , or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for
|
||||
determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the
|
||||
Complementary Material and assume any risks associated with Your exercise of permissions under this
|
||||
License.
|
||||
|
||||
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence),
|
||||
contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or
|
||||
agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect,
|
||||
special, incidental, or consequential damages of any character arising as a result of this License or out of
|
||||
the use or inability to use the Model and the Complementary Material (including but not limited to
|
||||
damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other
|
||||
commercial damages or losses), even if such Contributor has been advised of the possibility of such
|
||||
damages.
|
||||
|
||||
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the
|
||||
Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance
|
||||
of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License.
|
||||
However, in accepting such obligations, You may act only on Your own behalf and on Your sole
|
||||
responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and
|
||||
hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor
|
||||
by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining
|
||||
provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Attachment A
|
||||
|
||||
Use Restrictions
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
(a) In any way that violates any applicable national, federal, state, local or international law
|
||||
or regulation;
|
||||
(b) For the purpose of exploiting, harming or attempting to exploit or harm minors in any
|
||||
way;
|
||||
(c) To generate or disseminate verifiably false information and/or content with the purpose of
|
||||
harming others;
|
||||
(d) To generate or disseminate personal identifiable information that can be used to harm an
|
||||
individual;
|
||||
(e) To generate or disseminate information and/or content (e.g. images, code, posts, articles),
|
||||
and place the information and/or content in any context (e.g. bot generating tweets)
|
||||
without expressly and intelligibly disclaiming that the information and/or content is
|
||||
machine generated;
|
||||
(f) To defame, disparage or otherwise harass others;
|
||||
(g) To impersonate or attempt to impersonate (e.g. deepfakes) others without their consent;
|
||||
(h) For fully automated decision making that adversely impacts an individual’s legal rights or
|
||||
otherwise creates or modifies a binding, enforceable obligation;
|
||||
(i) For any use intended to or which has the effect of discriminating against or harming
|
||||
individuals or groups based on online or offline social behavior or known or predicted
|
||||
personal or personality characteristics;
|
||||
(j) To exploit any of the vulnerabilities of a specific group of persons based on their age,
|
||||
social, physical or mental characteristics, in order to materially distort the behavior of a
|
||||
person pertaining to that group in a manner that causes or is likely to cause that person or
|
||||
another person physical or psychological harm;
|
||||
(k) For any use intended to or which has the effect of discriminating against individuals or
|
||||
groups based on legally protected characteristics or categories;
|
||||
(l) To provide medical advice and medical results interpretation;
|
||||
(m) To generate or disseminate information for the purpose to be used for administration of
|
||||
justice, law enforcement, immigration or asylum processes, such as predicting an
|
||||
individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal
|
||||
relationships between assertions made in documents, indiscriminate and
|
||||
arbitrarily-targeted use).
|
||||
202
LICENSE-CODE
Normal file
202
LICENSE-CODE
Normal file
@@ -0,0 +1,202 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
39
NOTICE
Normal file
39
NOTICE
Normal file
@@ -0,0 +1,39 @@
|
||||
supertonic-3-mlx
|
||||
================
|
||||
|
||||
This release is a derivative of the upstream Supertone Supertonic 3
|
||||
text-to-speech model and consists of two artefact classes governed by
|
||||
two different licenses:
|
||||
|
||||
1. The model weights (under ./weights/*.safetensors) are released under
|
||||
the BigScience Open RAIL-M License. The full text is in ./LICENSE and
|
||||
was copied verbatim from
|
||||
https://huggingface.co/Supertone/supertonic-3/blob/main/LICENSE
|
||||
The Attachment A use restrictions (Section 5 + Attachment A clauses
|
||||
(a)–(m)) apply to all downstream use of the model and of any output
|
||||
generated by the model.
|
||||
|
||||
2. The MLX port code (under ./src/supertonic_3_mlx/) is released under
|
||||
the Apache License, Version 2.0. The full text is in ./LICENSE-CODE.
|
||||
|
||||
Attribution and modifications statement (BigScience Open RAIL-M Section 4.c):
|
||||
|
||||
Copyright (c) 2026 Supertone Inc. — original model weights and reference
|
||||
Python/ONNX implementation. Distributed at
|
||||
https://huggingface.co/Supertone/supertonic-3
|
||||
Copyright (c) 2026 Olivier Dupont — MLX-native port code, weight format
|
||||
conversion (ONNX → safetensors via the 3-stage extractor in
|
||||
``src/supertonic_3_mlx/pipeline.py:_convert_onnx``), and pipeline
|
||||
optimisations (``mx.compile`` of the CFG Euler loop, cross-attention
|
||||
K/V cache shared across the 5 Euler steps). Distributed at
|
||||
https://huggingface.co/ambassadia/supertonic-3-mlx
|
||||
|
||||
The MLX port does not modify the model's learned parameters in any
|
||||
semantic sense — the only weight-level transformation is a tensor-shape
|
||||
re-layout to match the MLX memory model (e.g. depthwise Conv1d
|
||||
``(C, 1, K)`` → ``(C, K, 1)``). Bit-identical audio output to the
|
||||
upstream ONNX Runtime reference is preserved up to FP32 accumulation
|
||||
noise (cosine ≥ 0.98 on the full pipeline, cosine = 1.00 on the vocoder).
|
||||
|
||||
No use of the Supertone trademarks, logos, or trade dress is asserted or
|
||||
permitted by this release (BigScience Open RAIL-M Section 8).
|
||||
260
README.md
Normal file
260
README.md
Normal file
@@ -0,0 +1,260 @@
|
||||
---
|
||||
license: openrail
|
||||
license_link: LICENSE
|
||||
language:
|
||||
- en
|
||||
- fr
|
||||
- de
|
||||
- es
|
||||
- it
|
||||
- pt
|
||||
- ja
|
||||
- ko
|
||||
- zh
|
||||
- ru
|
||||
- pl
|
||||
- nl
|
||||
- tr
|
||||
- ar
|
||||
- hi
|
||||
- vi
|
||||
- th
|
||||
- id
|
||||
- cs
|
||||
- ro
|
||||
- hu
|
||||
- el
|
||||
- da
|
||||
- sv
|
||||
- fi
|
||||
- no
|
||||
- he
|
||||
- uk
|
||||
- bg
|
||||
- hr
|
||||
- sk
|
||||
pipeline_tag: text-to-speech
|
||||
tags:
|
||||
- mlx
|
||||
- apple-silicon
|
||||
- tts
|
||||
- text-to-speech
|
||||
- speech-synthesis
|
||||
- supertonic
|
||||
- multilingual
|
||||
- flow-matching
|
||||
library_name: supertonic-3-mlx
|
||||
base_model: Supertone/supertonic-3
|
||||
inference: false
|
||||
---
|
||||
|
||||
# Supertonic 3 — MLX-native
|
||||
|
||||
**31-language text-to-speech, ~x100 realtime on Apple Silicon.**
|
||||
Native MLX port of [Supertone/supertonic-3](https://huggingface.co/Supertone/supertonic-3),
|
||||
runs the full flow-matching + classifier-free-guidance pipeline (DurationPredictor →
|
||||
TextEncoder → 24-block VectorEstimator (5 Euler steps) → 10-block Vocos vocoder)
|
||||
without ONNX, CoreML or any C++ runtime — only MLX + NumPy.
|
||||
|
||||
## Install
|
||||
|
||||
The package isn't on PyPI yet — install directly from this gitea source
|
||||
repository (or from the local checkout):
|
||||
|
||||
```bash
|
||||
pip install git+https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
|
||||
```
|
||||
|
||||
Runtime dependencies are just `mlx`, `numpy`, and `huggingface_hub` (the
|
||||
last for the one-line weight download). On first use the ~ 400 MB weight
|
||||
bundle is downloaded from
|
||||
[`ambassadia/supertonic-3-mlx`](https://huggingface.co/ambassadia/supertonic-3-mlx)
|
||||
into your Hugging Face cache.
|
||||
|
||||
### One-shot quickstart + sanity test
|
||||
|
||||
A zero-config end-to-end test script ships with the repo. Clone the repo,
|
||||
run the script, and it will create a fresh venv, install everything,
|
||||
version-check MLX (with an optional auto-upgrade), download the weights
|
||||
and synthesise an utterance into `hello.wav`:
|
||||
|
||||
```bash
|
||||
git clone https://gitea.tavportal.com/olivier/supertonic-3-mlx.git
|
||||
cd supertonic-3-mlx
|
||||
./setup_and_test.sh # en F1, default text
|
||||
./setup_and_test.sh fr F2 "Bonjour." # custom lang / voice / text
|
||||
```
|
||||
|
||||
Re-runs reuse the venv and the cached weights — second invocation is
|
||||
~ 20 ms warm load + ~ 30 ms per generate.
|
||||
|
||||
## Quickstart (after install)
|
||||
|
||||
```python
|
||||
from supertonic_3_mlx import Pipeline
|
||||
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
|
||||
|
||||
# wav is a 1-D numpy.float32 array at 44.1 kHz
|
||||
import soundfile as sf
|
||||
sf.write("hello.wav", wav, pipe.sample_rate)
|
||||
```
|
||||
|
||||
## Audio samples
|
||||
|
||||
Six languages, mix of male / female voices, mix of short and long utterances —
|
||||
all generated by the MLX pipeline at the wall times reported below.
|
||||
|
||||
<audio controls src="samples/en_F1_short.wav"></audio> **EN · F1 · 2.79 s** —
|
||||
"Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time."
|
||||
|
||||
<audio controls src="samples/en_M1_long.wav"></audio> **EN · M1 · 3.90 s** —
|
||||
"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells."
|
||||
|
||||
<audio controls src="samples/fr_F2.wav"></audio> **FR · F2 · 3.41 s** —
|
||||
"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4."
|
||||
|
||||
<audio controls src="samples/de_M2.wav"></audio> **DE · M2 · 3.69 s** —
|
||||
"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX."
|
||||
|
||||
<audio controls src="samples/ja_F3.wav"></audio> **JA · F3 · 1.46 s** —
|
||||
"こんにちは。これはアップルシリコン上でMLXを使ったテストです。"
|
||||
|
||||
<audio controls src="samples/es_M3.wav"></audio> **ES · M3 · 2.86 s** —
|
||||
"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon."
|
||||
|
||||
## Benchmarks (Apple M4, FP32, median of 3)
|
||||
|
||||
| Sample | Duration | MLX wall | RTF | ONNX SDK | Speedup |
|
||||
|-----------------|---------:|----------:|----------:|---------:|--------:|
|
||||
| EN · F1 · short | 2.79 s | 36.6 ms | **x76** | 1005 ms | **28 ×**|
|
||||
| EN · M1 · long | 3.90 s | 38.4 ms | **x102** | 1356 ms | **35 ×**|
|
||||
| FR · F2 | 3.41 s | 37.9 ms | **x90** | 1196 ms | **32 ×**|
|
||||
| DE · M2 | 3.69 s | 38.1 ms | **x97** | 1314 ms | **35 ×**|
|
||||
| JA · F3 | 1.46 s | 32.1 ms | **x46** | 848 ms | **26 ×**|
|
||||
| ES · M3 | 2.86 s | 37.0 ms | **x77** | 1002 ms | **27 ×**|
|
||||
|
||||
Raw numbers are in [`bench_results.csv`](bench_results.csv) (regenerable via
|
||||
the development monorepo at
|
||||
[`gitea.tavportal.com/olivier/MLX_CONVERTOR`](https://gitea.tavportal.com/olivier/MLX_CONVERTOR);
|
||||
this repository ships the consolidated release artefacts only).
|
||||
|
||||
Reference comparison: the CoreML build of the same model on the same hardware
|
||||
runs at ~x27 realtime. The MLX port is **~2-4× faster** end-to-end while
|
||||
remaining bit-identical to the ONNX Runtime reference on the vocoder
|
||||
(cosine 1.00) and at cosine ≥ 0.98 on the full estimator output.
|
||||
|
||||
## Voices
|
||||
|
||||
10 preset voices — five female (`F1`–`F5`) and five male (`M1`–`M5`). The
|
||||
`voice_styles/` directory contains both `style_ttl` (50×256 latent style for
|
||||
the audio path) and `style_dp` (8×16 style for the duration head) for each
|
||||
voice. Pass the voice name as the `voice=` kwarg to `Pipeline.generate`.
|
||||
|
||||
## Languages
|
||||
|
||||
31 languages supported. Pass the ISO 639-1 code as the `lang=` kwarg:
|
||||
`en` `fr` `de` `es` `it` `pt` `ja` `ko` `zh` `ru` `pl` `nl` `tr` `ar` `hi`
|
||||
`vi` `th` `id` `cs` `ro` `hu` `el` `da` `sv` `fi` `no` `he` `uk` `bg` `hr` `sk`.
|
||||
|
||||
## Architecture (short)
|
||||
|
||||
Four sub-models, all in `weights/*.safetensors`:
|
||||
|
||||
| Sub-model | Role | Params | Size |
|
||||
|----------------------|-------------------------------------|--------|---------|
|
||||
| `vector_estimator` | 24-block CFG flow-matching velocity | ~64 M | 256 MB |
|
||||
| `text_encoder` | Character → 256-D text embedding | ~9 M | 36 MB |
|
||||
| `duration_predictor` | Text → seconds | ~1 M | 3.5 MB |
|
||||
| `vocoder` | Latent (B,144,T) → 44.1 kHz wav | ~25 M | 101 MB |
|
||||
|
||||
The pipeline runs **exactly 5 Euler steps** with classifier-free guidance
|
||||
(`4×cond − 3×uncond`). This schedule is trained-in: reducing the step count
|
||||
or disabling CFG produces an essentially uncorrelated waveform (verified
|
||||
empirically — see the `bench_n_steps.py` script in the source repo).
|
||||
|
||||
## Loading from a local snapshot
|
||||
|
||||
Three layouts are auto-detected by `Pipeline.from_pretrained`:
|
||||
|
||||
1. **Hugging Face repo id** (e.g. `"ambassadia/supertonic-3-mlx"`) — auto-download
|
||||
2. **Local path containing `weights/`** (this layout) — fastest cold-load
|
||||
3. **Local path containing `onnx/`** (upstream snapshot) — converts at load time
|
||||
|
||||
## License
|
||||
|
||||
This release combines two artefact classes under two distinct licenses:
|
||||
|
||||
- **Model weights** (`weights/*.safetensors`) — **BigScience Open RAIL-M**.
|
||||
See [`LICENSE`](LICENSE) for the full text. The Attachment A use
|
||||
restrictions are reproduced below and apply to all downstream use of the
|
||||
model and of generated audio.
|
||||
- **Port code** (`src/supertonic_3_mlx/`) — **Apache License 2.0**. See
|
||||
[`LICENSE-CODE`](LICENSE-CODE).
|
||||
|
||||
See [`NOTICE`](NOTICE) for the modifications statement and the upstream
|
||||
attribution.
|
||||
|
||||
### OpenRAIL-M Attachment A — use restrictions
|
||||
|
||||
You agree not to use the model or derivatives:
|
||||
|
||||
(a) In any way that violates any applicable national, federal, state, local or
|
||||
international law or regulation.
|
||||
|
||||
(b) For the purpose of exploiting, harming or attempting to exploit or harm
|
||||
minors in any way.
|
||||
|
||||
(c) To generate or disseminate verifiably false information and/or content
|
||||
with the purpose of harming others.
|
||||
|
||||
(d) To generate or disseminate personal identifiable information that can be
|
||||
used to harm an individual.
|
||||
|
||||
(e) To generate or disseminate information and/or content (e.g. images, code,
|
||||
posts, articles), and place the information and/or content in any context
|
||||
(e.g. bot generating tweets) **without expressly and intelligibly disclaiming
|
||||
that the information and/or content is machine generated**.
|
||||
|
||||
(f) To defame, disparage or otherwise harass others.
|
||||
|
||||
(g) To impersonate or attempt to impersonate (e.g. **deepfakes**) others
|
||||
without their consent.
|
||||
|
||||
(h) For fully automated decision making that adversely impacts an individual's
|
||||
legal rights or otherwise creates or modifies a binding, enforceable obligation.
|
||||
|
||||
(i) For any use intended to or which has the effect of discriminating against
|
||||
or harming individuals or groups based on online or offline social behavior or
|
||||
known or predicted personal or personality characteristics.
|
||||
|
||||
(j) To exploit any of the vulnerabilities of a specific group of persons based
|
||||
on their age, social, physical or mental characteristics, in order to materially
|
||||
distort the behavior of a person pertaining to that group in a manner that
|
||||
causes or is likely to cause that person or another person physical or
|
||||
psychological harm.
|
||||
|
||||
(k) For any use intended to or which has the effect of discriminating against
|
||||
individuals or groups based on legally protected characteristics or categories.
|
||||
|
||||
(l) **To provide medical advice and medical results interpretation.**
|
||||
|
||||
(m) To generate or disseminate information for the purpose to be used for
|
||||
administration of justice, law enforcement, immigration or asylum processes,
|
||||
such as predicting an individual will commit fraud/crime commitment.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{supertonic3-mlx,
|
||||
title = {Supertonic 3 MLX: native Apple Silicon port of Supertone's multilingual TTS},
|
||||
author = {Dupont, Olivier},
|
||||
year = {2026},
|
||||
url = {https://huggingface.co/ambassadia/supertonic-3-mlx},
|
||||
note = {Derivative of Supertone/supertonic-3 (https://huggingface.co/Supertone/supertonic-3)}
|
||||
}
|
||||
```
|
||||
|
||||
Please also cite the upstream Supertone Supertonic 3 model when using this
|
||||
port.
|
||||
7
bench_results.csv
Normal file
7
bench_results.csv
Normal file
@@ -0,0 +1,7 @@
|
||||
filename,language,voice,text,duration_s,mlx_ms_median,rtf_mlx,onnx_ms_median,rtf_onnx,speedup_mlx_over_onnx
|
||||
samples/en_F1_short.wav,en,F1,Hello world from Apple Silicon. Supertonic 3 runs at one hundred times real time.,2.786,36.6,76.2,1004.7,2.8,27.5
|
||||
samples/en_M1_long.wav,en,M1,"A gentle breeze moved through the open window while the children, still half-asleep, listened to the distant sound of the harbour bells.",3.901,38.4,101.7,1356.0,2.9,35.3
|
||||
samples/fr_F2.wav,fr,F2,"Bonjour, ceci est un test de synthèse vocale en français. Le modèle gère trente-et-une langues sur une puce M4.",3.413,37.9,90.1,1195.6,2.9,31.6
|
||||
samples/de_M2.wav,de,M2,"Guten Morgen. Dieses Modell läuft komplett auf Apple Silicon, ohne ONNX und ohne CoreML, in reinem MLX.",3.692,38.1,96.9,1313.9,2.8,34.5
|
||||
samples/ja_F3.wav,ja,F3,こんにちは。これはアップルシリコン上でMLXを使ったテストです。,1.463,32.1,45.6,848.4,1.7,26.4
|
||||
samples/es_M3.wav,es,M3,"Hola, esto es una prueba de síntesis de voz en español ejecutada en tiempo real sobre Apple Silicon.",2.856,37.0,77.2,1002.1,2.9,27.1
|
||||
|
226
conversion_report.json
Normal file
226
conversion_report.json
Normal file
@@ -0,0 +1,226 @@
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"model": "VectorEstimator",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/vector_estimator.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vector_estimator.safetensors",
|
||||
"bytes": 256053073,
|
||||
"sha256": "2359240f2dcaee03b4800102aa0bea00223d2867ab752ef01af2b1cfaf92f3a6",
|
||||
"weights_kept": 351,
|
||||
"weights_dropped": 120,
|
||||
"dropped_detail": {
|
||||
"tts.ae.vector_field.proj_in.net.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.1.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.2.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.0.convnext.3.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.2.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.4.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.1.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.2.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.6.convnext.3.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.8.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.10.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.1.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.2.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.12.convnext.3.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.14.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.16.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.1.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.2.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.18.convnext.3.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.20.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.main_blocks.22.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.0.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.0.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.0.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.0.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.1.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.1.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.1.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.1.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.2.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.2.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.2.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.2.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.3.pwconv1.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.3.pwconv1.bias": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.3.pwconv2.weight": "not-in-model",
|
||||
"tts.ae.vector_field.last_convnext.convnext.3.pwconv2.bias": "not-in-model",
|
||||
"tts.ae.vector_field.proj_out.net.weight": "not-in-model",
|
||||
"<missing>.vector_field.main_blocks.9.attn.theta": "expected-but-not-extracted",
|
||||
"<missing>.vector_field.main_blocks.9.attn.increments": "expected-but-not-extracted",
|
||||
"<missing>.vector_field.main_blocks.15.attn.theta": "expected-but-not-extracted",
|
||||
"<missing>.vector_field.main_blocks.15.attn.increments": "expected-but-not-extracted",
|
||||
"<missing>.vector_field.main_blocks.21.attn.theta": "expected-but-not-extracted",
|
||||
"<missing>.vector_field.main_blocks.21.attn.increments": "expected-but-not-extracted"
|
||||
},
|
||||
"elapsed_s": 0.289
|
||||
},
|
||||
{
|
||||
"model": "TextEncoder",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/text_encoder.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/text_encoder.safetensors",
|
||||
"bytes": 36022466,
|
||||
"sha256": "9df20bb79496718b36d2c0fc37636d3f78d6ef751b2899ff6dfeb975ae737ada",
|
||||
"weights_kept": 146,
|
||||
"weights_dropped": 0,
|
||||
"dropped_detail": {},
|
||||
"elapsed_s": 0.035
|
||||
},
|
||||
{
|
||||
"model": "DurationPredictor",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/duration_predictor.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/duration_predictor.safetensors",
|
||||
"bytes": 3470807,
|
||||
"sha256": "cd473acb6e0ac27426084488ccb3b3cc184e70d05db90897e2b892846db5dcb3",
|
||||
"weights_kept": 98,
|
||||
"weights_dropped": 0,
|
||||
"dropped_detail": {},
|
||||
"elapsed_s": 0.007
|
||||
},
|
||||
{
|
||||
"model": "Vocoder",
|
||||
"onnx": "/tmp/supertonic3/model/onnx/vocoder.onnx",
|
||||
"safetensors": "/Users/transcrilive/MLX_CONVERTOR/sub-projects/supertonic3-mlx/hf_release/weights/vocoder.safetensors",
|
||||
"bytes": 101364763,
|
||||
"sha256": "b2ec31ab7c554f6e15b9a6780554b5d3502345de7848b310966bfb4e1ea4e526",
|
||||
"weights_kept": 103,
|
||||
"weights_dropped": 0,
|
||||
"dropped_detail": {},
|
||||
"elapsed_s": 0.079
|
||||
}
|
||||
],
|
||||
"ancillary": [
|
||||
{
|
||||
"name": "unicode_indexer.json",
|
||||
"bytes": 277676,
|
||||
"sha256": "9bf7346e43883a81f8645c81224f786d43c5b57f3641f6e7671a7d6c493cb24f"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/F1.json",
|
||||
"bytes": 292046,
|
||||
"sha256": "bbdec6ee00231c2c742ad05483df5334cab3b52fda3ba38e6a07059c4563dbc2"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/F2.json",
|
||||
"bytes": 292423,
|
||||
"sha256": "7c722c6a72707b1a77f035d67f0d1351ba187738e06f7683e8c72b1df3477fc6"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/F3.json",
|
||||
"bytes": 290794,
|
||||
"sha256": "12f6ef2573baa2defa1128069cb59f203e3ab67c92af77b42df8a0e3a2f7c6ab"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/F4.json",
|
||||
"bytes": 291808,
|
||||
"sha256": "c2fa764c1225a76dfc3e2c73e8aa4f70d9ee48793860eb34c295fff01c2e032b"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/F5.json",
|
||||
"bytes": 291479,
|
||||
"sha256": "45966e73316415626cf41a7d1c6f3b4c70dbc1ba2bee5c1978ef0ce33244fc8d"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/M1.json",
|
||||
"bytes": 291748,
|
||||
"sha256": "e35604687f5d23694b8e91593a93eec0e4eca6c0b02bb8ed69139ab2ea6b0a5b"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/M2.json",
|
||||
"bytes": 292055,
|
||||
"sha256": "b76cbf62bac707c710cf0ae5aba5e31eea1a6339a9734bfae33ab98499534a50"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/M3.json",
|
||||
"bytes": 290198,
|
||||
"sha256": "ea1ac35ccb91b0d7ecad533a2fbd0eec10c91513d8951e3b25fbba99954e159b"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/M4.json",
|
||||
"bytes": 291522,
|
||||
"sha256": "ca8eefad4fcd989c9379032ff3e50738adc547eeb5e221b82593a6d7b3bac303"
|
||||
},
|
||||
{
|
||||
"name": "voice_styles/M5.json",
|
||||
"bytes": 291469,
|
||||
"sha256": "dd22b92740314321f8ae11c5e87f8dd60d060f15dd3a632b5adf77f471f77af2"
|
||||
}
|
||||
]
|
||||
}
|
||||
23
examples/quickstart.py
Normal file
23
examples/quickstart.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Minimal Supertonic 3 MLX usage — 5 lines, no fluff.
|
||||
|
||||
Run from anywhere AFTER ``pip install supertonic-3-mlx`` (or from inside
|
||||
this directory after ``pip install ./``):
|
||||
|
||||
python examples/quickstart.py
|
||||
"""
|
||||
from supertonic_3_mlx import Pipeline
|
||||
import soundfile as sf
|
||||
|
||||
# When the package has been pip-installed, this auto-downloads from the Hub
|
||||
# (~ 400 MB) into the standard Hugging Face cache. After the first run, the
|
||||
# weights are reused from cache and cold start is ~ 11 ms on M4.
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
|
||||
wav = pipe.generate(
|
||||
"Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.",
|
||||
voice="F1", # one of F1..F5, M1..M5
|
||||
lang="en", # ISO 639-1
|
||||
)
|
||||
|
||||
sf.write("hello.wav", wav, pipe.sample_rate)
|
||||
print(f"wrote hello.wav — {len(wav) / pipe.sample_rate:.2f}s of audio")
|
||||
43
pyproject.toml
Normal file
43
pyproject.toml
Normal file
@@ -0,0 +1,43 @@
|
||||
[project]
|
||||
name = "supertonic-3-mlx"
|
||||
version = "0.1.0"
|
||||
description = "MLX-native port of Supertone's Supertonic 3 multilingual TTS (31 languages, ~x100 realtime on Apple Silicon)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "Olivier Dupont", email = "olivier.dupont@taviramonaco.com" }]
|
||||
license = { text = "Apache-2.0 AND OpenRAIL-M" }
|
||||
keywords = ["mlx", "tts", "speech-synthesis", "apple-silicon", "supertonic", "multilingual"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Environment :: MacOS X",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: MacOS",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Speech",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
]
|
||||
dependencies = [
|
||||
"mlx>=0.21.0",
|
||||
"numpy>=1.24.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
hub = ["huggingface_hub>=0.26.0"]
|
||||
dev = ["pytest>=8.3.0", "ruff>=0.7.0"]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://huggingface.co/ambassadia/supertonic-3-mlx"
|
||||
Upstream = "https://huggingface.co/Supertone/supertonic-3"
|
||||
Source = "https://gitea.tavportal.com/olivier/supertonic-3-mlx"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/supertonic_3_mlx"]
|
||||
BIN
samples/de_M2.wav
Normal file
BIN
samples/de_M2.wav
Normal file
Binary file not shown.
BIN
samples/en_F1_short.wav
Normal file
BIN
samples/en_F1_short.wav
Normal file
Binary file not shown.
BIN
samples/en_M1_long.wav
Normal file
BIN
samples/en_M1_long.wav
Normal file
Binary file not shown.
BIN
samples/es_M3.wav
Normal file
BIN
samples/es_M3.wav
Normal file
Binary file not shown.
BIN
samples/fr_F2.wav
Normal file
BIN
samples/fr_F2.wav
Normal file
Binary file not shown.
BIN
samples/ja_F3.wav
Normal file
BIN
samples/ja_F3.wav
Normal file
Binary file not shown.
131
setup_and_test.sh
Executable file
131
setup_and_test.sh
Executable file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env bash
|
||||
# Quick install + sanity-test for the supertonic-3-mlx standalone package.
|
||||
#
|
||||
# Creates a local ``.venv`` next to this script, installs the package and its
|
||||
# runtime deps, version-checks MLX, downloads the model weights from the
|
||||
# Hugging Face Hub on first run, and synthesises one short utterance to
|
||||
# ``hello.wav``. Idempotent: re-running reuses the existing venv and cached
|
||||
# weights.
|
||||
#
|
||||
# Usage:
|
||||
# ./setup_and_test.sh # default: en F1, "Hello world…"
|
||||
# ./setup_and_test.sh fr F2 "Bonjour."
|
||||
#
|
||||
set -euo pipefail
|
||||
|
||||
# ── 0. Inputs ────────────────────────────────────────────────────────
|
||||
LANG_CODE="${1:-en}"
|
||||
VOICE="${2:-F1}"
|
||||
TEXT="${3:-Hello world from Apple Silicon. Supertonic 3 runs at one hundred times realtime.}"
|
||||
|
||||
HERE="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
VENV="$HERE/.venv"
|
||||
|
||||
# ── 1. Python version gate ──────────────────────────────────────────
|
||||
if ! command -v python3 >/dev/null; then
|
||||
echo "ERROR: python3 not found. Install Python 3.10+ first." >&2
|
||||
exit 1
|
||||
fi
|
||||
PYVER="$(python3 -c 'import sys; print("%d.%d"%sys.version_info[:2])')"
|
||||
PYMAJ="${PYVER%.*}"; PYMIN="${PYVER#*.}"
|
||||
if [ "$PYMAJ" -lt 3 ] || { [ "$PYMAJ" -eq 3 ] && [ "$PYMIN" -lt 10 ]; }; then
|
||||
echo "ERROR: Python 3.10+ required, found $PYVER." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "→ python3: $PYVER"
|
||||
|
||||
# ── 2. venv ─────────────────────────────────────────────────────────
|
||||
if [ ! -x "$VENV/bin/python" ]; then
|
||||
echo "→ creating venv at $VENV …"
|
||||
python3 -m venv "$VENV"
|
||||
fi
|
||||
PIP="$VENV/bin/pip"
|
||||
PY="$VENV/bin/python"
|
||||
|
||||
# ── 3. dependencies ─────────────────────────────────────────────────
|
||||
echo "→ installing dependencies …"
|
||||
"$PIP" install --quiet --upgrade pip
|
||||
# Install the package + the optional runtime deps. The package itself pulls in
|
||||
# mlx + numpy via its pyproject.toml; we add huggingface_hub for the Hub
|
||||
# download path, hf_transfer for large-blob throughput, and soundfile so the
|
||||
# test script can write a WAV.
|
||||
"$PIP" install --quiet "$HERE" huggingface_hub soundfile
|
||||
|
||||
# ── 4. MLX version gate + optional patch hook ───────────────────────
|
||||
"$PY" - <<'PYEOF'
|
||||
import sys
|
||||
|
||||
try:
|
||||
import mlx.core as mx
|
||||
except ImportError:
|
||||
print("ERROR: mlx not importable. Are you on Apple Silicon? "
|
||||
"MLX is macOS-on-Apple-Silicon only.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
ver_str = getattr(mx, "__version__", "0.0.0")
|
||||
ver = tuple(int(p) for p in ver_str.split(".")[:3] if p.isdigit())
|
||||
print(f"→ mlx version: {ver_str}")
|
||||
|
||||
# Minimum tested combination — bumped as the upstream API changes.
|
||||
MIN_OK = (0, 21, 0)
|
||||
if ver < MIN_OK:
|
||||
print(f" WARNING: mlx < {'.'.join(map(str, MIN_OK))}. Upgrading …",
|
||||
file=sys.stderr)
|
||||
import subprocess
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install",
|
||||
"--quiet", "--upgrade", "mlx"])
|
||||
print(" → upgraded. Re-run the script to pick up the new version.",
|
||||
file=sys.stderr)
|
||||
sys.exit(2)
|
||||
|
||||
# Patches we currently know about — none. This is the slot where future
|
||||
# MLX-specific shims would land (e.g. a workaround for an upstream Conv1d
|
||||
# regression). Keep the dispatch table here so the script stays a single
|
||||
# source of truth.
|
||||
PATCHES: dict[tuple[int, int, int], str] = {
|
||||
# (broken_version): "patch description"
|
||||
}
|
||||
applied = [desc for v, desc in PATCHES.items() if v == ver]
|
||||
if applied:
|
||||
for desc in applied:
|
||||
print(f" applied patch: {desc}")
|
||||
else:
|
||||
print(f" no patches needed for mlx {ver_str}")
|
||||
PYEOF
|
||||
|
||||
# ── 5. quickstart generate ──────────────────────────────────────────
|
||||
echo "→ generating audio …"
|
||||
LANG_CODE="$LANG_CODE" VOICE="$VOICE" TEXT="$TEXT" \
|
||||
HF_HUB_DISABLE_XET=1 \
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||
"$PY" - <<'PYEOF'
|
||||
import os, time
|
||||
from supertonic_3_mlx import Pipeline
|
||||
import soundfile as sf
|
||||
|
||||
lang = os.environ["LANG_CODE"]
|
||||
voice = os.environ["VOICE"]
|
||||
text = os.environ["TEXT"]
|
||||
|
||||
# First call downloads ~ 400 MB of weights into the HF cache. Subsequent
|
||||
# runs reuse the cache and load in ~ 20 ms.
|
||||
t0 = time.perf_counter()
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
load_t = time.perf_counter() - t0
|
||||
print(f" load : {load_t*1000:.0f} ms")
|
||||
|
||||
# Warmup (compiles the kernel graph for this shape).
|
||||
pipe.generate("Warm.", voice=voice, lang=lang)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
wav = pipe.generate(text, voice=voice, lang=lang, seed=42)
|
||||
gen_t = time.perf_counter() - t0
|
||||
dur = len(wav) / pipe.sample_rate
|
||||
print(f" generate : {gen_t*1000:.0f} ms")
|
||||
print(f" audio : {dur:.2f} s ({len(wav)} samples @ {pipe.sample_rate} Hz)")
|
||||
print(f" RTF : x{dur/gen_t:.0f}")
|
||||
print(f" max amp : {abs(wav).max():.4f}")
|
||||
|
||||
sf.write("hello.wav", wav, pipe.sample_rate)
|
||||
print("\n✓ wrote hello.wav — open it to verify the synthesis sounds correct.")
|
||||
PYEOF
|
||||
51
src/supertonic_3_mlx/__init__.py
Normal file
51
src/supertonic_3_mlx/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Supertonic 3 — MLX-native TTS for Apple Silicon.
|
||||
|
||||
31-language text-to-speech, 5 Euler steps with classifier-free guidance, in
|
||||
pure MLX. On M4 the full pipeline runs at ~x100 realtime.
|
||||
|
||||
Quickstart
|
||||
----------
|
||||
|
||||
from supertonic_3_mlx import Pipeline
|
||||
pipe = Pipeline.from_pretrained("ambassadia/supertonic-3-mlx")
|
||||
wav = pipe.generate("Hello world from Apple Silicon.", voice="F1", lang="en")
|
||||
# wav is a 1-D ``numpy.float32`` array at 44.1 kHz.
|
||||
|
||||
The model weights are released under the BigScience OpenRAIL-M license
|
||||
(see LICENSE in the Hugging Face repository). This MLX port code is
|
||||
Apache-2.0. Together they form a dual-license package; Attachment A use
|
||||
restrictions of OpenRAIL-M govern downstream use of the generated audio.
|
||||
|
||||
Public API:
|
||||
Pipeline — end-to-end TTS, ``from_pretrained`` + ``generate``
|
||||
VectorEstimator — the 24-block CFG flow-matching net (sub-model 1/4)
|
||||
TextEncoder — character → text embedding (sub-model 2/4)
|
||||
DurationPredictor — text → duration in seconds (sub-model 3/4)
|
||||
Vocoder — latent → 44.1 kHz waveform (sub-model 4/4)
|
||||
"""
|
||||
from supertonic_3_mlx._config import (
|
||||
DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K,
|
||||
NUM_MAIN_BLOCKS, NUM_CYCLES, BLOCKS_PER_CYCLE, BLOCK_CYCLE, STACK4_DILATIONS,
|
||||
TEXT_HEADS, TEXT_HEAD_DIM, TEXT_DIM, ROTARY_BASE, ROTARY_SCALE,
|
||||
STYLE_HEADS, STYLE_HEAD_DIM, STYLE_LEN, STYLE_DIM,
|
||||
TIME_EMB_DIM, TIME_MLP_HIDDEN,
|
||||
EPS_LN, CHUNK_COMPRESS, LATENT_DIM, SAMPLE_RATE,
|
||||
SUPERTONIC3_HF_REPO,
|
||||
)
|
||||
from supertonic_3_mlx.duration_predictor import DurationPredictor
|
||||
from supertonic_3_mlx.text_encoder import TextEncoder
|
||||
from supertonic_3_mlx.vector_estimator import VectorEstimator
|
||||
from supertonic_3_mlx.vocoder import Vocoder
|
||||
from supertonic_3_mlx.pipeline import SupertonicMLXPipeline as Pipeline
|
||||
|
||||
__all__ = [
|
||||
"Pipeline",
|
||||
"DurationPredictor", "TextEncoder", "VectorEstimator", "Vocoder",
|
||||
"DIM", "LATENT_CH", "CONVNEXT_HIDDEN", "CONVNEXT_K",
|
||||
"NUM_MAIN_BLOCKS", "NUM_CYCLES", "BLOCKS_PER_CYCLE", "BLOCK_CYCLE", "STACK4_DILATIONS",
|
||||
"TEXT_HEADS", "TEXT_HEAD_DIM", "TEXT_DIM", "ROTARY_BASE", "ROTARY_SCALE",
|
||||
"STYLE_HEADS", "STYLE_HEAD_DIM", "STYLE_LEN", "STYLE_DIM",
|
||||
"TIME_EMB_DIM", "TIME_MLP_HIDDEN",
|
||||
"EPS_LN", "CHUNK_COMPRESS", "LATENT_DIM", "SAMPLE_RATE",
|
||||
"SUPERTONIC3_HF_REPO",
|
||||
]
|
||||
58
src/supertonic_3_mlx/_config.py
Normal file
58
src/supertonic_3_mlx/_config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Locked hyperparameters for Supertonic 3 MLX port.
|
||||
|
||||
Derived from the official ``Supertone/supertonic-3/onnx/tts.json``.
|
||||
Changing these = re-running parity tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# Vector estimator (the flow-matching denoiser)
|
||||
DIM: int = 512 # backbone width
|
||||
LATENT_CH: int = 144 # 24 * chunk_compress_factor (6)
|
||||
CONVNEXT_HIDDEN: int = 2048 # main_blocks ConvNeXt intermediate dim (2× vs s2)
|
||||
CONVNEXT_K: int = 5
|
||||
LAST_CONVNEXT_NUM: int = 4 # last_convnext is a 4-layer stack (dilations [1,1,1,1])
|
||||
|
||||
# 24 main_blocks = 4 cycles × 6 sub-blocks (cycle: stack4, time, cn1, text_attn, cn1, style_attn)
|
||||
NUM_CYCLES: int = 4
|
||||
BLOCKS_PER_CYCLE: int = 6
|
||||
NUM_MAIN_BLOCKS: int = NUM_CYCLES * BLOCKS_PER_CYCLE
|
||||
BLOCK_CYCLE = ("stack4", "time", "cn1", "text_attn", "cn1", "style_attn")
|
||||
|
||||
# ConvNeXt stack 4 (in stack4 blocks) — dilation schedule
|
||||
STACK4_DILATIONS = (1, 2, 4, 8)
|
||||
|
||||
# Text cross-attention (RoPE) — block type "text_attn"
|
||||
TEXT_DIM: int = 256
|
||||
TEXT_HEADS: int = 8 # 2× vs s2 (4)
|
||||
TEXT_HEAD_DIM: int = DIM // TEXT_HEADS # 512/8 = 64
|
||||
ROTARY_BASE: int = 10_000
|
||||
ROTARY_SCALE: int = 10
|
||||
|
||||
# Style cross-attention — block type "style_attn"
|
||||
STYLE_DIM: int = 256
|
||||
STYLE_LEN: int = 50 # 50 style tokens (n_style)
|
||||
STYLE_HEADS: int = 2
|
||||
STYLE_HEAD_DIM: int = 128
|
||||
|
||||
# Time encoding (sinusoidal + MLP)
|
||||
TIME_EMB_DIM: int = 64
|
||||
TIME_MLP_HIDDEN: int = 256
|
||||
|
||||
# LayerNorm epsilon
|
||||
EPS_LN: float = 1e-6
|
||||
|
||||
# Chunk compress factor (used by AE)
|
||||
CHUNK_COMPRESS: int = 6
|
||||
LATENT_DIM: int = 24 # ldim before chunk compression
|
||||
|
||||
# Sample rate
|
||||
SAMPLE_RATE: int = 44_100
|
||||
|
||||
# HF references (will be pinned to SHA after first download)
|
||||
SUPERTONIC3_HF_REPO: str = "Supertone/supertonic-3"
|
||||
ONNX_VECTOR_ESTIMATOR: str = "onnx/vector_estimator.onnx"
|
||||
ONNX_TEXT_ENCODER: str = "onnx/text_encoder.onnx"
|
||||
ONNX_DURATION_PREDICTOR: str = "onnx/duration_predictor.onnx"
|
||||
ONNX_VOCODER: str = "onnx/vocoder.onnx"
|
||||
ONNX_TTS_JSON: str = "onnx/tts.json"
|
||||
ONNX_UNICODE_INDEXER: str = "onnx/unicode_indexer.json"
|
||||
50
src/supertonic_3_mlx/_nn_wrappers.py
Normal file
50
src/supertonic_3_mlx/_nn_wrappers.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Small wrapper modules to match Supertonic 3 ONNX submodule nesting.
|
||||
|
||||
The s3 checkpoint nests primitives one level deeper than typical MLX modules:
|
||||
- ``norm.norm.weight`` — LayerNorm wrapped in a Norm container
|
||||
- ``linear.linear.weight`` — Linear wrapped in a Linear container
|
||||
- ``W_query.linear.weight`` — attention projection wrapped
|
||||
|
||||
Mirroring this nesting lets us load the safetensors with ``model.load_weights(...)``
|
||||
without any key remapping at load time.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class WrappedNorm(nn.Module):
|
||||
"""Container with a single nested LayerNorm — produces key ``X.norm.weight``."""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim, eps=eps)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class WrappedLinear(nn.Module):
|
||||
"""Container with a single nested Linear — produces keys ``X.linear.weight/bias``."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class ProjConv1x1(nn.Module):
|
||||
"""Conv1d k=1 expressed as ``self.net = Linear`` (matches ``proj_in.net.weight``)."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = True) -> None:
|
||||
super().__init__()
|
||||
self.net = nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
__all__ = ["WrappedNorm", "WrappedLinear", "ProjConv1x1"]
|
||||
347
src/supertonic_3_mlx/duration_predictor.py
Normal file
347
src/supertonic_3_mlx/duration_predictor.py
Normal file
@@ -0,0 +1,347 @@
|
||||
"""Supertonic 3 duration predictor — predicts total audio duration in seconds.
|
||||
|
||||
Pipeline (channels-last NTC throughout):
|
||||
|
||||
text_ids [B, T] int64 character IDs
|
||||
→ char_embed (Embedding 8322→64) [B, T, 64]
|
||||
→ prepend sentence_token (1, 64, 1) [B, T+1, 64]
|
||||
→ 6× ConvNeXt (dim=64, hidden=256, k=5, all dilations=1)
|
||||
→ 2× RelPosSelfAttn (heads=2, head_dim=32, window=4) + norm + FFN + norm
|
||||
→ proj_out (Conv1d k=1: 64→64) applied to slot 0 (sentence token)
|
||||
→ concat with style_dp flattened (B, 8×16=128) [B, 192]
|
||||
→ Linear(192 → 128) → PReLU → Linear(128 → 1) → exp → duration [B]
|
||||
|
||||
Inputs:
|
||||
text_ids: (B, T) int — character indices
|
||||
style_dp: (B, 8, 16) — style summary tokens
|
||||
text_mask: (B, 1, T) — 1.0 valid, 0.0 padded
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from supertonic_3_mlx._config import EPS_LN
|
||||
from supertonic_3_mlx._nn_wrappers import WrappedNorm
|
||||
from supertonic_3_mlx.vector_estimator import _pad_sym_edge, _gelu_exact
|
||||
|
||||
|
||||
DP_VOCAB = 8322
|
||||
DP_DIM = 64
|
||||
DP_CONVNEXT_HIDDEN = 256
|
||||
DP_CONVNEXT_K = 5
|
||||
DP_CONVNEXT_NUM_LAYERS = 6
|
||||
DP_ATTN_NUM_LAYERS = 2
|
||||
DP_ATTN_HEADS = 2
|
||||
DP_ATTN_HEAD_DIM = DP_DIM // DP_ATTN_HEADS # 32
|
||||
DP_FFN_HIDDEN = 256
|
||||
DP_REL_POS_WINDOW = 4
|
||||
DP_N_STYLE = 8
|
||||
DP_STYLE_DIM = 16
|
||||
DP_MLP_IN = DP_DIM + DP_N_STYLE * DP_STYLE_DIM # 64 + 128 = 192
|
||||
DP_MLP_HIDDEN = 128
|
||||
|
||||
|
||||
class _DPConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt block (dim=64, hidden=256, dilation=1)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv1d(
|
||||
DP_DIM, DP_DIM, kernel_size=DP_CONVNEXT_K, padding=0,
|
||||
dilation=1, groups=DP_DIM, bias=True,
|
||||
)
|
||||
self.norm = WrappedNorm(DP_DIM, eps=EPS_LN)
|
||||
self.pwconv1 = nn.Linear(DP_DIM, DP_CONVNEXT_HIDDEN, bias=True)
|
||||
self.pwconv2 = nn.Linear(DP_CONVNEXT_HIDDEN, DP_DIM, bias=True)
|
||||
self.gamma = mx.zeros((DP_DIM,))
|
||||
self.pad = (DP_CONVNEXT_K - 1) // 2
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
residual = x
|
||||
y = _pad_sym_edge(x, self.pad)
|
||||
y = self.dwconv(y)
|
||||
y = self.norm(y)
|
||||
y = self.pwconv1(y)
|
||||
y = _gelu_exact(y)
|
||||
y = self.pwconv2(y)
|
||||
y = y * self.gamma
|
||||
out = residual + y
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class _DPConvNeXtStack(nn.Module):
|
||||
"""``convnext.[0..5]`` — 6 ConvNeXt blocks."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [_DPConvNeXtBlock() for _ in range(DP_CONVNEXT_NUM_LAYERS)]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for b in self.convnext:
|
||||
x = b(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class _DPConvLayer(nn.Module):
|
||||
"""Conv1d k=1 with weight (out, 1, in) — matches ONNX storage."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((out_dim, 1, in_dim))
|
||||
self.bias = mx.zeros((out_dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
|
||||
|
||||
|
||||
def _dp_rel_to_abs(x: mx.array) -> mx.array:
|
||||
"""(B, h, L, 2L-1) → (B, h, L, L) via VITS shifted-skew reshape."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * 2 * L)
|
||||
x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
|
||||
return x_final[:, :, :L, L - 1:]
|
||||
|
||||
|
||||
def _dp_abs_to_rel(x: mx.array) -> mx.array:
|
||||
"""(B, h, L, L) → (B, h, L, 2L-1)."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * (2 * L - 1))
|
||||
x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L, 2 * L)
|
||||
return x_final[:, :, :, 1:]
|
||||
|
||||
|
||||
def _dp_slice_rel(rel: mx.array, length: int, window: int) -> mx.array:
|
||||
"""(1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
|
||||
pad_l = max(length - (window + 1), 0)
|
||||
if pad_l > 0:
|
||||
zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
|
||||
padded = mx.concatenate([zero, rel, zero], axis=1)
|
||||
else:
|
||||
padded = rel
|
||||
start = max(window + 1 - length, 0)
|
||||
return padded[:, start: start + 2 * length - 1]
|
||||
|
||||
|
||||
class _DPRelPosSelfAttn(nn.Module):
|
||||
"""VITS-style rel-pos self-attention (2 heads × 32 head_dim, window=4).
|
||||
|
||||
Includes both rel-pos contributions (q × rel_k → logits, abs_to_rel(attn) × rel_v → out).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_q = _DPConvLayer(DP_DIM, DP_DIM)
|
||||
self.conv_k = _DPConvLayer(DP_DIM, DP_DIM)
|
||||
self.conv_v = _DPConvLayer(DP_DIM, DP_DIM)
|
||||
self.conv_o = _DPConvLayer(DP_DIM, DP_DIM)
|
||||
self.emb_rel_k = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
|
||||
self.emb_rel_v = mx.zeros((1, 2 * DP_REL_POS_WINDOW + 1, DP_ATTN_HEAD_DIM))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
B, T, _ = x.shape
|
||||
H, D = DP_ATTN_HEADS, DP_ATTN_HEAD_DIM
|
||||
q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
scale = D ** -0.5
|
||||
|
||||
logits = (q @ k.transpose(0, 1, 3, 2)) * scale
|
||||
|
||||
rel_k = _dp_slice_rel(self.emb_rel_k, T, DP_REL_POS_WINDOW)
|
||||
rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :]
|
||||
rel_logits = _dp_rel_to_abs(rel_logits * scale)
|
||||
logits = logits + rel_logits
|
||||
|
||||
if mask is not None:
|
||||
key_mask = mask[:, :, 0][:, None, None, :]
|
||||
neg_inf = mx.array(-1e4, dtype=logits.dtype)
|
||||
logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
|
||||
|
||||
attn = mx.softmax(logits, axis=-1)
|
||||
out = attn @ v
|
||||
|
||||
rel_v = _dp_slice_rel(self.emb_rel_v, T, DP_REL_POS_WINDOW)
|
||||
rel_weights = _dp_abs_to_rel(attn)
|
||||
out = out + rel_weights @ rel_v[:, None, :, :]
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
|
||||
return self.conv_o(out)
|
||||
|
||||
|
||||
class _DPFFN(nn.Module):
|
||||
"""FFN with two Conv1d k=1 — 64 → 256 → 64, ReLU + mask."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_1 = _DPConvLayer(DP_DIM, DP_FFN_HIDDEN)
|
||||
self.conv_2 = _DPConvLayer(DP_FFN_HIDDEN, DP_DIM)
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
y = self.conv_1(x)
|
||||
y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
y = self.conv_2(y)
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
return y
|
||||
|
||||
|
||||
class _DPAttnEncoder(nn.Module):
|
||||
"""2× (attn + norm) + (ffn + norm)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attn_layers = [_DPRelPosSelfAttn() for _ in range(DP_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_1 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
|
||||
self.ffn_layers = [_DPFFN() for _ in range(DP_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_2 = [WrappedNorm(DP_DIM, eps=EPS_LN) for _ in range(DP_ATTN_NUM_LAYERS)]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for i in range(DP_ATTN_NUM_LAYERS):
|
||||
x = self.norm_layers_1[i](x + self.attn_layers[i](x, mask=mask))
|
||||
x = self.norm_layers_2[i](x + self.ffn_layers[i](x, mask))
|
||||
return x
|
||||
|
||||
|
||||
class _DPSentenceEncoder(nn.Module):
|
||||
"""Text → 64-d sentence vector via prepended ``sentence_token`` slot."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
class _TextEmb(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.char_embedder = nn.Embedding(DP_VOCAB, DP_DIM)
|
||||
def __call__(_, ids):
|
||||
return _.char_embedder(ids)
|
||||
self.text_embedder = _TextEmb()
|
||||
self.convnext = _DPConvNeXtStack()
|
||||
self.attn_encoder = _DPAttnEncoder()
|
||||
# proj_out keeps the .net.weight (out, 1, in) Conv1d-k1 layout
|
||||
self.proj_out = _DPProjOut()
|
||||
# sentence_token (1, DIM, 1) — prepended as the first time slot
|
||||
self.sentence_token = mx.zeros((1, DP_DIM, 1))
|
||||
|
||||
def __call__(self, text_ids: mx.array, text_mask: mx.array) -> mx.array:
|
||||
x = self.text_embedder(text_ids) # (B, T, 64)
|
||||
# Prepend sentence_token: shape (1, 64, 1) → (B, 1, 64)
|
||||
B = x.shape[0]
|
||||
sentence = self.sentence_token.transpose(0, 2, 1)
|
||||
sentence = mx.broadcast_to(sentence, (B, 1, DP_DIM))
|
||||
x = mx.concatenate([sentence, x], axis=1) # (B, T+1, 64)
|
||||
|
||||
# Extend mask with a leading 1 (sentence token always valid)
|
||||
if text_mask is not None:
|
||||
extra = mx.ones((B, 1, 1), dtype=text_mask.dtype)
|
||||
mask_ntc = mx.concatenate([extra, text_mask.transpose(0, 2, 1)], axis=1)
|
||||
else:
|
||||
mask_ntc = None
|
||||
|
||||
x = self.convnext(x, mask_ntc)
|
||||
x = self.attn_encoder(x, mask_ntc)
|
||||
|
||||
# Take slot 0 (sentence token output) → (B, 1, 64)
|
||||
sentence_out = x[:, :1, :] # (B, 1, 64)
|
||||
# proj_out (Conv1d k=1) — applied along time, output (B, 1, 64)
|
||||
sentence_out = self.proj_out(sentence_out)
|
||||
return sentence_out.reshape(B, DP_DIM) # (B, 64)
|
||||
|
||||
|
||||
class _DPProjOut(nn.Module):
|
||||
"""Conv1d k=1 64→64. No bias in ONNX (confirmed via graph inspection)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
class _Net(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.weight = mx.zeros((DP_DIM, 1, DP_DIM))
|
||||
def __call__(_, x):
|
||||
return mx.conv1d(x, _.weight, stride=1, padding=0)
|
||||
self.net = _Net()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class _DPPredictor(nn.Module):
|
||||
"""Linear(192 → 128) + PReLU + Linear(128 → 1).
|
||||
|
||||
PReLU is stored under ``activation.weight (1,)`` — a single learnable
|
||||
negative-slope coefficient.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layers = [
|
||||
nn.Linear(DP_MLP_IN, DP_MLP_HIDDEN, bias=True),
|
||||
nn.Linear(DP_MLP_HIDDEN, 1, bias=True),
|
||||
]
|
||||
# PReLU: activation.weight shape (1,) — single scalar slope
|
||||
class _Activation(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.weight = mx.zeros((1,))
|
||||
def __call__(_, x):
|
||||
# PReLU(x) = max(0, x) + slope * min(0, x)
|
||||
neg = mx.minimum(x, mx.array(0.0, dtype=x.dtype))
|
||||
pos = mx.maximum(x, mx.array(0.0, dtype=x.dtype))
|
||||
return pos + _.weight * neg
|
||||
self.activation = _Activation()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
h = self.layers[0](x) # (B, 128)
|
||||
h = self.activation(h)
|
||||
h = self.layers[1](h) # (B, 1)
|
||||
return h
|
||||
|
||||
|
||||
class _DPRoot(nn.Module):
|
||||
"""``tts.dp.X`` namespace container."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.sentence_encoder = _DPSentenceEncoder()
|
||||
self.predictor = _DPPredictor()
|
||||
|
||||
|
||||
class _DPContainer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.dp = _DPRoot()
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
"""Predicts total audio duration (seconds) for an utterance.
|
||||
|
||||
Submodule namespace matches ONNX keys ``tts.dp.X.Y`` exactly.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tts = _DPContainer()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text_ids: mx.array, # (B, T) int
|
||||
style_dp: mx.array, # (B, 8, 16)
|
||||
text_mask: mx.array, # (B, 1, T)
|
||||
) -> mx.array:
|
||||
sentence = self.tts.dp.sentence_encoder(text_ids, text_mask) # (B, 64)
|
||||
style_flat = style_dp.reshape(style_dp.shape[0], -1) # (B, 128)
|
||||
joined = mx.concatenate([sentence, style_flat], axis=-1) # (B, 192)
|
||||
log_dur = self.tts.dp.predictor(joined).reshape(-1) # (B,)
|
||||
return mx.exp(log_dur) # duration in seconds
|
||||
|
||||
|
||||
__all__ = ["DurationPredictor"]
|
||||
545
src/supertonic_3_mlx/pipeline.py
Normal file
545
src/supertonic_3_mlx/pipeline.py
Normal file
@@ -0,0 +1,545 @@
|
||||
"""Supertonic 3 end-to-end MLX pipeline.
|
||||
|
||||
Stitches the four MLX sub-models (DurationPredictor → TextEncoder →
|
||||
VectorEstimator → Vocoder) into a single ``generate(text, voice, lang)`` call
|
||||
that returns a 44.1 kHz mono numpy waveform.
|
||||
|
||||
Flow:
|
||||
|
||||
text ──tokenize(unicode_indexer)──▶ text_ids (B, T_text)
|
||||
│
|
||||
voice_style (.json) ──▶ style_ttl (B, 50, 256), style_dp (B, 8, 16)
|
||||
│
|
||||
duration_predictor(text_ids, style_dp, text_mask) ──▶ duration_s (B,)
|
||||
│
|
||||
text_encoder(text_ids, style_ttl, text_mask) ──▶ text_emb (B, 256, T_text)
|
||||
│
|
||||
noise ~ N(0, I) of shape (B, 144, T_lat)
|
||||
where T_lat = ceil(duration_s × 44100 / (512 × 6))
|
||||
│
|
||||
vector_estimator 5-step Euler with CFG (4×cond − 3×uncond):
|
||||
for step in [0..4]:
|
||||
x ← VE(x, text_emb, style_ttl, masks, current_step=step+1, total_step=5)
|
||||
│
|
||||
vocoder(audio_latent) ──▶ wav (B, T_lat × 6 × 512)
|
||||
|
||||
Public API:
|
||||
|
||||
pipe = SupertonicMLXPipeline.from_pretrained("/tmp/supertonic3/model")
|
||||
wav = pipe.generate("Hello world", voice="F1", lang="en")
|
||||
import soundfile as sf
|
||||
sf.write("out.wav", wav, pipe.sample_rate)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
from supertonic_3_mlx._config import SAMPLE_RATE
|
||||
from supertonic_3_mlx.duration_predictor import DurationPredictor
|
||||
from supertonic_3_mlx.text_encoder import TextEncoder
|
||||
from supertonic_3_mlx.vector_estimator import VectorEstimator
|
||||
from supertonic_3_mlx.vocoder import Vocoder
|
||||
|
||||
|
||||
# Latent rate: at 44.1 kHz with hop=512 and chunk_compress=6, one latent step
|
||||
# covers 512 × 6 = 3072 samples = 69.7 ms.
|
||||
SAMPLES_PER_LATENT_STEP = 512 * 6 # 3072
|
||||
|
||||
|
||||
# ── Shared ONNX → MLX weight extraction ─────────────────────────────
|
||||
|
||||
|
||||
def _convert_onnx(onnx_path: str | Path) -> dict:
|
||||
"""Return a dict of ``{clean_key: mx.array}`` for a Supertonic ONNX file.
|
||||
|
||||
Combines the three extraction stages discovered during the per-component
|
||||
ports (T.3.1, T.3.2, T.3.3):
|
||||
|
||||
1. Named ``tts.*`` initialisers with shape transforms (dwconv, gamma,
|
||||
pwconv, head.layer2).
|
||||
2. Anonymous MatMul weights recovered via the MatMul output path.
|
||||
3. Anonymous Conv weights and PReLU slopes recovered the same way.
|
||||
"""
|
||||
import onnx
|
||||
import onnx.numpy_helper as nh
|
||||
|
||||
m = onnx.load(str(onnx_path))
|
||||
|
||||
def _matmul_clean(out_name: str) -> str:
|
||||
p = out_name.lstrip("/")
|
||||
if p.endswith("/MatMul_output_0"):
|
||||
p = p[: -len("/MatMul_output_0")]
|
||||
# Drop the leading model-name path (e.g. /text_encoder/, /duration_predictor/, /vector_estimator/)
|
||||
for prefix in ("text_encoder/", "duration_predictor/", "vector_estimator/", "vocoder/"):
|
||||
if p.startswith(prefix):
|
||||
p = p[len(prefix):]
|
||||
break
|
||||
return p.replace("/", ".") + ".weight"
|
||||
|
||||
def _conv_clean(out_name: str) -> str:
|
||||
p = out_name.lstrip("/")
|
||||
if p.endswith("/Conv_output_0"):
|
||||
p = p[: -len("/Conv_output_0")]
|
||||
for prefix in ("vocoder/", "vector_estimator/", "text_encoder/", "duration_predictor/"):
|
||||
if p.startswith(prefix):
|
||||
p = p[len(prefix):]
|
||||
break
|
||||
return "tts.ae." + p.replace("/", ".")
|
||||
|
||||
def _prelu_clean(out_name: str) -> str:
|
||||
p = out_name.lstrip("/")
|
||||
if p.endswith("/PRelu_output_0"):
|
||||
p = p[: -len("/PRelu_output_0")]
|
||||
for prefix in ("vocoder/", "vector_estimator/"):
|
||||
if p.startswith(prefix):
|
||||
p = p[len(prefix):]
|
||||
break
|
||||
return "tts.ae." + p.replace("/", ".") + ".weight"
|
||||
|
||||
# Detect which model this file is — affects how we wrap named init keys
|
||||
name_prefixes = {init.name.split(".")[0] for init in m.graph.initializer if "." in init.name}
|
||||
is_text_encoder = "tts" in name_prefixes and any(
|
||||
i.name.startswith("tts.ttl.text_encoder") for i in m.graph.initializer
|
||||
)
|
||||
|
||||
weights: dict[str, mx.array] = {}
|
||||
|
||||
# Stage 1: named initialisers
|
||||
for init in m.graph.initializer:
|
||||
n = init.name
|
||||
# Determine if this is a structured (named) weight or an anonymous graph const
|
||||
if not (n.startswith("tts.") or "vector_estimator.tts.ttl." in n or "uncond_masker." in n):
|
||||
continue
|
||||
|
||||
# Strip the vector_estimator-specific prefix so all 4 models share a name space.
|
||||
if n.startswith("vector_estimator.tts.ttl."):
|
||||
clean = n[len("vector_estimator.tts.ttl."):]
|
||||
else:
|
||||
clean = n
|
||||
|
||||
arr = nh.to_array(init)
|
||||
|
||||
# Shape transforms
|
||||
if (clean.endswith(".dwconv.weight") and arr.ndim == 3
|
||||
and arr.shape[1] == 1 and arr.shape[2] != 1):
|
||||
arr = np.transpose(arr, (0, 2, 1))
|
||||
if (clean.endswith(".dwconv.net.weight") and arr.ndim == 3
|
||||
and arr.shape[1] == 1):
|
||||
arr = np.transpose(arr, (0, 2, 1))
|
||||
if (clean.endswith(".gamma") and arr.ndim == 3
|
||||
and arr.shape[0] == 1 and arr.shape[2] == 1):
|
||||
arr = arr.reshape(arr.shape[1])
|
||||
if ((clean.endswith(".pwconv1.weight") or clean.endswith(".pwconv2.weight"))
|
||||
and arr.ndim == 3 and arr.shape[-1] == 1):
|
||||
arr = arr.squeeze(-1)
|
||||
if clean.endswith(".net.weight") and arr.ndim == 3 and arr.shape[-1] == 1:
|
||||
# Conv1d k=1 wrapped via .net (e.g. proj_in/proj_out)
|
||||
arr = arr.squeeze(-1)
|
||||
# vocoder head.layer2 (out, in, 1) → MLX Conv1d (out, K=1, in)
|
||||
if clean == "tts.ae.decoder.head.layer2.weight" and arr.ndim == 3:
|
||||
arr = np.transpose(arr, (0, 2, 1))
|
||||
# vocoder head.layer1.net.weight (out, in, K) → MLX Conv1d (out, K, in)
|
||||
if clean == "tts.ae.decoder.head.layer1.net.weight" and arr.ndim == 3:
|
||||
arr = np.transpose(arr, (0, 2, 1))
|
||||
|
||||
weights[clean] = mx.array(arr)
|
||||
|
||||
# Stage 2: MatMul weight recovery
|
||||
inits_map = {init.name: init for init in m.graph.initializer}
|
||||
for node in m.graph.node:
|
||||
if node.op_type != "MatMul" or len(node.input) < 2:
|
||||
continue
|
||||
winp = node.input[1]
|
||||
if winp not in inits_map or winp.startswith("tts.") or "vector_estimator.tts" in winp:
|
||||
continue
|
||||
arr = nh.to_array(inits_map[winp])
|
||||
if arr.ndim == 2:
|
||||
arr = arr.T # ONNX (in, out) → MLX Linear (out, in)
|
||||
clean = _matmul_clean(node.output[0])
|
||||
# Build the leading namespace from the file context (already in tts.*)
|
||||
if not clean.startswith(("tts.", "vector_field.", "uncond_masker.")):
|
||||
clean = "tts.ttl." + clean if is_text_encoder else clean
|
||||
weights[clean] = mx.array(arr)
|
||||
|
||||
# Stage 3: anonymous Conv + PReLU (vocoder embed / head)
|
||||
for node in m.graph.node:
|
||||
if node.op_type == "Conv":
|
||||
for i, inp in enumerate(node.input[1:], 1):
|
||||
if inp not in inits_map or inp.startswith("tts."):
|
||||
continue
|
||||
arr = nh.to_array(inits_map[inp])
|
||||
base = _conv_clean(node.output[0])
|
||||
if "dwconv" in base:
|
||||
continue
|
||||
if i == 1 and arr.ndim == 3:
|
||||
arr = np.transpose(arr, (0, 2, 1)) # ONNX (out, in, K) → MLX (out, K, in)
|
||||
key = base + (".weight" if i == 1 else ".bias")
|
||||
weights[key] = mx.array(arr)
|
||||
elif node.op_type == "PRelu":
|
||||
for inp in node.input[1:]:
|
||||
if inp in inits_map and not inp.startswith("tts."):
|
||||
weights[_prelu_clean(node.output[0])] = mx.array(nh.to_array(inits_map[inp]))
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def _load_into(model, weights: dict) -> int:
|
||||
"""Match converted weights to model params (shape-tolerant via reshape).
|
||||
|
||||
Returns the number of successfully matched tensors.
|
||||
"""
|
||||
from mlx.utils import tree_flatten
|
||||
expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
|
||||
matched = {}
|
||||
for k, exp_shape in expected.items():
|
||||
if k not in weights:
|
||||
continue
|
||||
v = weights[k]
|
||||
if tuple(v.shape) != exp_shape:
|
||||
if v.size == np.prod(exp_shape):
|
||||
v = v.reshape(exp_shape)
|
||||
else:
|
||||
continue
|
||||
matched[k] = v
|
||||
model.load_weights(list(matched.items()), strict=False)
|
||||
return len(matched)
|
||||
|
||||
|
||||
# ── Tokenization ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _encode_text(text: str, indexer: list[int], lang: str = "en") -> np.ndarray:
|
||||
"""Encode a text string into character IDs.
|
||||
|
||||
The unicode_indexer is a flat list of size 65536; ``indexer[ord(c)]`` gives
|
||||
the token ID for character ``c`` (-1 = unknown). For Phase T.4 we wrap the
|
||||
text with no special language tokens — the ONNX SDK uses language tags but
|
||||
our pipeline currently runs unconditioned on language for the first WAV
|
||||
emission (parity validation happens after).
|
||||
"""
|
||||
ids = []
|
||||
for c in text:
|
||||
cp = ord(c)
|
||||
if 0 <= cp < len(indexer):
|
||||
tok = indexer[cp]
|
||||
if tok >= 0:
|
||||
ids.append(tok)
|
||||
if not ids:
|
||||
# fallback to a single space token to avoid empty input
|
||||
ids = [indexer[ord(" ")]] if indexer[ord(" ")] >= 0 else [0]
|
||||
return np.asarray(ids, dtype=np.int32)
|
||||
|
||||
|
||||
# ── Pipeline ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class SupertonicMLXPipeline:
|
||||
"""End-to-end Supertonic 3 TTS pipeline in pure MLX.
|
||||
|
||||
Loads four sub-models (duration_predictor, text_encoder, vector_estimator,
|
||||
vocoder), the unicode tokenizer, and exposes ``generate(text, voice, lang)``.
|
||||
"""
|
||||
|
||||
sample_rate: int = SAMPLE_RATE
|
||||
# Locked by the model architecture: Supertonic 3 is a flow-matching + CFG
|
||||
# model trained for exactly 5 Euler steps with t ∈ {0.2, 0.4, 0.6, 0.8, 1.0}
|
||||
# and the combination 4×cond − 3×uncond. Any other step count or skipping
|
||||
# CFG produces an essentially uncorrelated waveform (verified by
|
||||
# ``sub-projects/supertonic3-mlx/bench_n_steps.py``: cosine drops to
|
||||
# ≤ 0.5 for n∈{3,4,6} and ≈ 0.05 for cfg=False). Reducing inference
|
||||
# latency further would require distilling a shorter-schedule model.
|
||||
n_euler_steps: int = 5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
duration_predictor: DurationPredictor,
|
||||
text_encoder: TextEncoder,
|
||||
vector_estimator: VectorEstimator,
|
||||
vocoder: Vocoder,
|
||||
unicode_indexer: list[int],
|
||||
voice_dir: Path,
|
||||
) -> None:
|
||||
self.duration_predictor = duration_predictor
|
||||
self.text_encoder = text_encoder
|
||||
self.vector_estimator = vector_estimator
|
||||
self.vocoder = vocoder
|
||||
self.unicode_indexer = unicode_indexer
|
||||
self.voice_dir = voice_dir
|
||||
|
||||
# T.5 — compile the hot loops. ``mx.compile`` caches a kernel graph keyed
|
||||
# by input shapes; the 5× CFG Euler loop and the single vocoder pass
|
||||
# both gain from fused kernel dispatch (~50–100 layer ops collapse into
|
||||
# one dispatch per cached graph).
|
||||
|
||||
# T.5.3 — also pre-project text and style K/V outside the step. They
|
||||
# are invariant across the 5 Euler steps, so the 4 text_attn + 4
|
||||
# style_attn blocks no longer re-run their W_key / W_value / RoPE_K
|
||||
# matmuls on every step (saves 40 matmuls per generate).
|
||||
cond_scale = self.vector_estimator.CFG_COND_SCALE
|
||||
uncond_scale = self.vector_estimator.CFG_UNCOND_SCALE
|
||||
|
||||
def _cached_step(
|
||||
noisy, lat_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
|
||||
):
|
||||
noisy_2 = mx.concatenate([noisy, noisy], axis=0)
|
||||
text_kv = [(kv_flat[2 * i], kv_flat[2 * i + 1]) for i in range(4)]
|
||||
style_kv = [(kv_flat[8 + 2 * i], kv_flat[8 + 2 * i + 1]) for i in range(4)]
|
||||
v_2 = self.vector_estimator.velocity_cached(
|
||||
noisy_2, lat_mask_2, text_mask_2, t_norm_2, text_kv, style_kv,
|
||||
)
|
||||
B = noisy.shape[0]
|
||||
cond_v = v_2[:B]
|
||||
uncond_v = v_2[B:]
|
||||
combined = cond_scale * cond_v - uncond_scale * uncond_v
|
||||
return noisy + combined / total_step.reshape(-1, 1, 1).astype(combined.dtype)
|
||||
|
||||
def _voc_step(latent):
|
||||
return self.vocoder(latent)
|
||||
|
||||
self._cached_step_compiled = mx.compile(_cached_step)
|
||||
self._voc_compiled = mx.compile(_voc_step)
|
||||
|
||||
# Pick the runtime dtype from any leaf weight of the vector estimator —
|
||||
# ``from_pretrained(dtype=...)`` may have cast the model to ``bf16``,
|
||||
# in which case all inputs to the compiled hot loops must be cast to
|
||||
# match (mixed-dtype Conv/MatMul is not legal in MLX).
|
||||
from mlx.utils import tree_flatten
|
||||
leaves = [v for _, v in tree_flatten(vector_estimator.parameters())
|
||||
if isinstance(v, mx.array)]
|
||||
self.dtype = leaves[0].dtype if leaves else mx.float32
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
model_id_or_path: str | Path,
|
||||
dtype: mx.Dtype | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
revision: str | None = None,
|
||||
) -> "SupertonicMLXPipeline":
|
||||
"""Construct the pipeline from a model snapshot.
|
||||
|
||||
Three sources are accepted, auto-detected:
|
||||
|
||||
1. **Hugging Face Hub repo id** (e.g. ``"ambassadia/supertonic-3-mlx"``):
|
||||
weights are downloaded via :func:`huggingface_hub.snapshot_download`
|
||||
into ``cache_dir`` (defaults to the standard HF cache) and loaded
|
||||
directly from the bundled ``weights/*.safetensors`` files.
|
||||
2. **Local path with a** ``weights/`` **subdir**: the MLX-native
|
||||
layout (4 safetensors + ``unicode_indexer.json`` + ``voice_styles/``).
|
||||
Fast path — no ONNX conversion at runtime.
|
||||
3. **Local path with an** ``onnx/`` **subdir**: the upstream
|
||||
``Supertone/supertonic-3`` snapshot layout. Weights are converted
|
||||
from ONNX on the fly (~ 1 s per sub-model on M4). Useful for
|
||||
development or when starting from the original upstream release.
|
||||
|
||||
Optional kwargs:
|
||||
dtype — if non-None and not float32, cast all weights to the
|
||||
given dtype after load (only ``mx.bfloat16`` is
|
||||
currently meaningful; see README "BF16 note").
|
||||
cache_dir — passed to ``huggingface_hub.snapshot_download``.
|
||||
revision — branch / tag / commit sha on the Hub.
|
||||
"""
|
||||
# 1. Resolve the local snapshot directory
|
||||
if isinstance(model_id_or_path, str) and "/" in model_id_or_path \
|
||||
and not Path(model_id_or_path).exists():
|
||||
try:
|
||||
from huggingface_hub import snapshot_download
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Loading from the Hugging Face Hub requires "
|
||||
"``huggingface_hub`` — install with ``pip install "
|
||||
"supertonic-3-mlx[hub]`` or ``pip install huggingface_hub``."
|
||||
) from e
|
||||
local_dir = Path(snapshot_download(
|
||||
repo_id=model_id_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
allow_patterns=[
|
||||
"weights/*.safetensors",
|
||||
"unicode_indexer.json",
|
||||
"voice_styles/*.json",
|
||||
],
|
||||
))
|
||||
else:
|
||||
local_dir = Path(model_id_or_path)
|
||||
|
||||
# 2. Detect layout
|
||||
weights_dir = local_dir / "weights"
|
||||
onnx_dir = local_dir / "onnx"
|
||||
if weights_dir.exists():
|
||||
return cls._from_safetensors(local_dir, dtype=dtype)
|
||||
if onnx_dir.exists():
|
||||
return cls._from_onnx(local_dir, dtype=dtype)
|
||||
raise FileNotFoundError(
|
||||
f"{local_dir} contains neither ``weights/`` (safetensors layout) "
|
||||
f"nor ``onnx/`` (upstream layout); cannot load."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_safetensors(
|
||||
cls, local_dir: Path, dtype: mx.Dtype | None = None,
|
||||
) -> "SupertonicMLXPipeline":
|
||||
from mlx.utils import tree_flatten
|
||||
weights_dir = local_dir / "weights"
|
||||
voice_dir = local_dir / "voice_styles"
|
||||
unicode_indexer = json.loads((local_dir / "unicode_indexer.json").read_text())
|
||||
|
||||
def _build(cls_, name):
|
||||
model = cls_()
|
||||
w = mx.load(str(weights_dir / f"{name}.safetensors"))
|
||||
# Reshape any mismatched leaves (defensive; the converter already
|
||||
# produced shape-correct tensors but a future re-export may not).
|
||||
expected = {k: tuple(v.shape) for k, v in tree_flatten(model.parameters())}
|
||||
for k in list(w.keys()):
|
||||
if k in expected and tuple(w[k].shape) != expected[k]:
|
||||
if w[k].size == int(np.prod(expected[k])):
|
||||
w[k] = w[k].reshape(expected[k])
|
||||
model.load_weights(list(w.items()), strict=False)
|
||||
return model
|
||||
|
||||
ve = _build(VectorEstimator, "vector_estimator")
|
||||
te = _build(TextEncoder, "text_encoder")
|
||||
dp = _build(DurationPredictor, "duration_predictor")
|
||||
voc = _build(Vocoder, "vocoder")
|
||||
|
||||
if dtype is not None and dtype != mx.float32:
|
||||
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
||||
|
||||
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
|
||||
|
||||
@classmethod
|
||||
def _from_onnx(
|
||||
cls, local_dir: Path, dtype: mx.Dtype | None = None,
|
||||
) -> "SupertonicMLXPipeline":
|
||||
onnx_dir = local_dir / "onnx"
|
||||
voice_dir = local_dir / "voice_styles"
|
||||
unicode_indexer = json.loads((onnx_dir / "unicode_indexer.json").read_text())
|
||||
|
||||
ve = VectorEstimator()
|
||||
_load_into(ve, _convert_onnx(onnx_dir / "vector_estimator.onnx"))
|
||||
te = TextEncoder()
|
||||
_load_into(te, _convert_onnx(onnx_dir / "text_encoder.onnx"))
|
||||
dp = DurationPredictor()
|
||||
_load_into(dp, _convert_onnx(onnx_dir / "duration_predictor.onnx"))
|
||||
voc = Vocoder()
|
||||
_load_into(voc, _convert_onnx(onnx_dir / "vocoder.onnx"))
|
||||
|
||||
if dtype is not None and dtype != mx.float32:
|
||||
cls._cast_all(dp, te, ve, voc, dtype=dtype)
|
||||
|
||||
return cls(dp, te, ve, voc, unicode_indexer, voice_dir)
|
||||
|
||||
@staticmethod
|
||||
def _cast_all(*models, dtype: mx.Dtype) -> None:
|
||||
"""Cast all fp32 leaves of each model to ``dtype`` (in-place)."""
|
||||
from mlx.utils import tree_map
|
||||
|
||||
def _cast(p):
|
||||
if not isinstance(p, mx.array) or p.dtype != mx.float32:
|
||||
return p
|
||||
return p.astype(dtype)
|
||||
|
||||
for m_ in models:
|
||||
m_.update(tree_map(_cast, m_.parameters()))
|
||||
|
||||
def _load_voice(self, voice: str) -> tuple[mx.array, mx.array]:
|
||||
"""Load ``voice_styles/<voice>.json`` and return (style_ttl, style_dp)."""
|
||||
path = self.voice_dir / f"{voice}.json"
|
||||
data = json.loads(path.read_text())
|
||||
style_ttl = np.asarray(data["style_ttl"]["data"], dtype=np.float32) # (1, 50, 256)
|
||||
style_dp = np.asarray(data["style_dp"]["data"], dtype=np.float32) # (1, 8, 16)
|
||||
return mx.array(style_ttl), mx.array(style_dp)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
text: str,
|
||||
voice: str = "F1",
|
||||
lang: str = "en",
|
||||
seed: int = 42,
|
||||
n_steps: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
"""Synthesise a single utterance. Returns a 1D float32 numpy waveform."""
|
||||
n_steps = n_steps if n_steps is not None else self.n_euler_steps
|
||||
|
||||
# Tokenize
|
||||
text_ids_np = _encode_text(text, self.unicode_indexer, lang)
|
||||
text_ids = mx.array(text_ids_np[None, :]) # (1, T_text)
|
||||
T_text = text_ids.shape[1]
|
||||
text_mask = mx.ones((1, 1, T_text), dtype=self.dtype)
|
||||
|
||||
# Style
|
||||
style_ttl, style_dp = self._load_voice(voice)
|
||||
if self.dtype != mx.float32:
|
||||
style_ttl = style_ttl.astype(self.dtype)
|
||||
style_dp = style_dp.astype(self.dtype)
|
||||
|
||||
# Duration → latent length
|
||||
duration_s = self.duration_predictor(text_ids, style_dp, text_mask)
|
||||
mx.eval(duration_s)
|
||||
duration_val = max(float(duration_s[0].item()), 0.5) # clamp to ≥ 0.5 s
|
||||
T_lat = max(int(math.ceil(duration_val * self.sample_rate / SAMPLES_PER_LATENT_STEP)), 1)
|
||||
|
||||
# Text embedding
|
||||
text_emb = self.text_encoder(text_ids, style_ttl, text_mask) # (1, 256, T_text)
|
||||
|
||||
# Initial noise — fixed seed for reproducibility
|
||||
key = mx.random.key(seed)
|
||||
noise = mx.random.normal((1, 144, T_lat), key=key).astype(self.dtype)
|
||||
latent_mask = mx.ones((1, 1, T_lat), dtype=self.dtype)
|
||||
|
||||
# T.5.3 — build the (2B) CFG conditioning tensors once and pre-project
|
||||
# K/V for every text_attn / style_attn block. ``kv_flat`` is the 16
|
||||
# ``(K, V)`` arrays flattened into a list for the compiled step.
|
||||
B = noise.shape[0]
|
||||
ve = self.vector_estimator
|
||||
text_uncond = mx.broadcast_to(
|
||||
ve.uncond_masker.text_special_token, (B, text_emb.shape[1], text_emb.shape[2])
|
||||
).astype(self.dtype)
|
||||
style_k_uncond = mx.broadcast_to(
|
||||
ve.uncond_masker.style_key_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
|
||||
).astype(self.dtype)
|
||||
style_v_uncond = mx.broadcast_to(
|
||||
ve.uncond_masker.style_value_special_token, (B, style_ttl.shape[1], style_ttl.shape[2])
|
||||
).astype(self.dtype)
|
||||
text_emb_2 = mx.concatenate([text_emb, text_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
|
||||
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
|
||||
text_mask_2 = mx.concatenate([text_mask, text_mask], axis=0)
|
||||
latent_mask_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
|
||||
|
||||
text_kv, style_kv = ve.precompute_cross_kv(
|
||||
text_emb_2, style_k_2, style_v_2, text_mask_2,
|
||||
)
|
||||
kv_flat = []
|
||||
for k, v in text_kv:
|
||||
kv_flat.extend([k, v])
|
||||
for k, v in style_kv:
|
||||
kv_flat.extend([k, v])
|
||||
|
||||
# Euler with CFG — 5 steps by default
|
||||
x = noise
|
||||
total_step = mx.array([float(n_steps)], dtype=self.dtype)
|
||||
for step in range(n_steps):
|
||||
current_step = mx.array([float(step + 1)], dtype=self.dtype)
|
||||
t_norm = current_step / total_step
|
||||
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
|
||||
x = self._cached_step_compiled(
|
||||
x, latent_mask_2, text_mask_2, t_norm_2, total_step, kv_flat,
|
||||
)
|
||||
mx.eval(x)
|
||||
|
||||
# Decode latent → waveform
|
||||
wav = self._voc_compiled(x)
|
||||
mx.eval(wav)
|
||||
if wav.dtype != mx.float32:
|
||||
wav = wav.astype(mx.float32)
|
||||
return np.array(wav)[0] # (T_lat × 6 × 512,)
|
||||
|
||||
|
||||
__all__ = ["SupertonicMLXPipeline"]
|
||||
382
src/supertonic_3_mlx/text_encoder.py
Normal file
382
src/supertonic_3_mlx/text_encoder.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""Supertonic 3 text encoder MLX port.
|
||||
|
||||
Pipeline (operating in channels-last NTC after the initial conv):
|
||||
|
||||
text_ids [B, T_text] int64 character IDs
|
||||
→ char_embedder (Embedding 8322→256) [B, T_text, 256]
|
||||
→ 6× ConvNeXt(dim=256, hidden=1024, k=5, dilations [1,1,2,2,4,4])
|
||||
→ 4× attn_encoder block:
|
||||
RelPosSelfAttn (conv_q/k/v/o, 4 heads × 64) + norm_layers_1
|
||||
FFN (conv_1: 256→1024, conv_2: 1024→256) + norm_layers_2
|
||||
→ speech_prompted_text_encoder:
|
||||
cross-attn1: text (Q) × style_ttl (K, V) → text features
|
||||
cross-attn2: text (Q) × style_ttl (K, V) → text features
|
||||
norm
|
||||
→ output text_emb [B, 256, T_text] (channels-first to match vector_estimator)
|
||||
|
||||
Inputs:
|
||||
text_ids: (B, T_text) int — character indices
|
||||
style_ttl: (B, 50, 256) float — style token bank
|
||||
text_mask: (B, 1, T_text) float — 1.0 where valid, 0.0 where padded
|
||||
|
||||
Submodule naming matches the ONNX initializer keys exactly so that
|
||||
``model.load_weights(...)`` succeeds with no remapping.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from supertonic_3_mlx._config import EPS_LN
|
||||
from supertonic_3_mlx._nn_wrappers import WrappedNorm, WrappedLinear
|
||||
from supertonic_3_mlx.vector_estimator import (
|
||||
ConvNeXtBlock, _pad_sym_edge, _gelu_exact,
|
||||
)
|
||||
|
||||
|
||||
# Vocab + dims (frozen by checkpoint)
|
||||
VOCAB_SIZE = 8322
|
||||
TE_DIM = 256
|
||||
TE_CONVNEXT_HIDDEN = 1024
|
||||
TE_CONVNEXT_K = 5
|
||||
TE_CONVNEXT_NUM_LAYERS = 6
|
||||
TE_CONVNEXT_DILATIONS = (1, 1, 2, 2, 4, 4)
|
||||
|
||||
TE_ATTN_NUM_LAYERS = 4
|
||||
TE_ATTN_HEADS = 4
|
||||
TE_ATTN_HEAD_DIM = TE_DIM // TE_ATTN_HEADS # 64
|
||||
TE_FFN_HIDDEN = 1024
|
||||
|
||||
|
||||
class TextConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt for the text encoder (dim=256, hidden=1024).
|
||||
|
||||
Shares the same architecture as ``vector_estimator.ConvNeXtBlock`` but is
|
||||
redefined here with text-encoder-specific defaults to keep the modules
|
||||
self-contained.
|
||||
"""
|
||||
|
||||
def __init__(self, dilation: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.dim = TE_DIM
|
||||
self.dilation = dilation
|
||||
self.pad = dilation * (TE_CONVNEXT_K - 1) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
TE_DIM, TE_DIM, kernel_size=TE_CONVNEXT_K, padding=0,
|
||||
dilation=dilation, groups=TE_DIM, bias=True,
|
||||
)
|
||||
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
|
||||
self.pwconv1 = nn.Linear(TE_DIM, TE_CONVNEXT_HIDDEN, bias=True)
|
||||
self.pwconv2 = nn.Linear(TE_CONVNEXT_HIDDEN, TE_DIM, bias=True)
|
||||
self.gamma = mx.zeros((TE_DIM,))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
# x: (B, T_text, 256)
|
||||
residual = x
|
||||
y = _pad_sym_edge(x, self.pad)
|
||||
y = self.dwconv(y)
|
||||
y = self.norm(y)
|
||||
y = self.pwconv1(y)
|
||||
y = _gelu_exact(y)
|
||||
y = self.pwconv2(y)
|
||||
y = y * self.gamma
|
||||
out = residual + y
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class TextConvNeXtStack(nn.Module):
|
||||
"""6 stacked ConvNeXt blocks. Loaded as ``convnext.convnext.[0..5].X``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [TextConvNeXtBlock(d) for d in TE_CONVNEXT_DILATIONS]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for b in self.convnext:
|
||||
x = b(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class _ConvLayer(nn.Module):
|
||||
"""Conv1d k=1 expressed via the ONNX-style ``X.weight (out, in, 1) + X.bias``.
|
||||
|
||||
The attn_encoder uses Conv1d k=1 instead of nn.Linear for its Q/K/V/O.
|
||||
This wrapper keeps the weight shape (out, in, 1) intact and runs as a
|
||||
Conv1d (the equivalent of a Linear when k=1).
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((out_dim, 1, in_dim)) # (C_out, K=1, C_in)
|
||||
self.bias = mx.zeros((out_dim,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (B, T, in_dim) — channels-last
|
||||
# equivalent to nn.Conv1d(in_dim, out_dim, k=1) in NTC layout
|
||||
return mx.conv1d(x, self.weight, stride=1, padding=0) + self.bias
|
||||
|
||||
|
||||
REL_POS_WINDOW = 4 # rel_pos table size = 2*4 + 1 = 9
|
||||
|
||||
|
||||
def _rel_to_abs(x: mx.array) -> mx.array:
|
||||
"""[B, h, L, 2L-1] → [B, h, L, L] via the VITS shifted-skew reshape."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * 2 * L)
|
||||
x_flat = mx.concatenate([x_flat, mx.zeros((B, h, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L + 1, 2 * L - 1)
|
||||
return x_final[:, :, :L, L - 1:]
|
||||
|
||||
|
||||
def _abs_to_rel(x: mx.array) -> mx.array:
|
||||
"""[B, h, L, L] → [B, h, L, 2L-1] (inverse of _rel_to_abs)."""
|
||||
B, h, L, _ = x.shape
|
||||
x = mx.concatenate([x, mx.zeros((B, h, L, L - 1), dtype=x.dtype)], axis=-1)
|
||||
x_flat = x.reshape(B, h, L * (2 * L - 1))
|
||||
x_flat = mx.concatenate([mx.zeros((B, h, L), dtype=x.dtype), x_flat], axis=-1)
|
||||
x_final = x_flat.reshape(B, h, L, 2 * L)
|
||||
return x_final[:, :, :, 1:]
|
||||
|
||||
|
||||
def _slice_rel_emb(rel: mx.array, length: int, window: int) -> mx.array:
|
||||
"""``rel`` (1, 2W+1, d) → (1, 2L-1, d) by zero-padding/slicing."""
|
||||
pad_l = max(length - (window + 1), 0)
|
||||
if pad_l > 0:
|
||||
zero = mx.zeros((1, pad_l, rel.shape[-1]), dtype=rel.dtype)
|
||||
padded = mx.concatenate([zero, rel, zero], axis=1)
|
||||
else:
|
||||
padded = rel
|
||||
start = max(window + 1 - length, 0)
|
||||
return padded[:, start: start + 2 * length - 1]
|
||||
|
||||
|
||||
class RelPosSelfAttention(nn.Module):
|
||||
"""VITS-style relative-position self-attention with window=4.
|
||||
|
||||
Adds two contributions to vanilla MHA:
|
||||
- ``rel_logits = q @ rel_k.T`` then ``_rel_to_abs`` and added to attention logits
|
||||
- ``rel_attn = _abs_to_rel(softmax(logits))`` then ``@ rel_v`` and added to output
|
||||
|
||||
Loaded keys (per layer):
|
||||
``conv_q/k/v/o.weight`` (256, 256, 1) and ``.bias`` (256)
|
||||
``emb_rel_k`` (1, 9, 64), ``emb_rel_v`` (1, 9, 64)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_q = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_k = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_v = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.conv_o = _ConvLayer(TE_DIM, TE_DIM)
|
||||
self.window = REL_POS_WINDOW
|
||||
self.emb_rel_k = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
|
||||
self.emb_rel_v = mx.zeros((1, 2 * REL_POS_WINDOW + 1, TE_ATTN_HEAD_DIM))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
B, T, _ = x.shape
|
||||
H, D = TE_ATTN_HEADS, TE_ATTN_HEAD_DIM
|
||||
q = self.conv_q(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
k = self.conv_k(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
v = self.conv_v(x).reshape(B, T, H, D).transpose(0, 2, 1, 3)
|
||||
scale = D ** -0.5
|
||||
|
||||
# Standard attention logits
|
||||
logits = (q @ k.transpose(0, 1, 3, 2)) * scale # (B, H, T, T)
|
||||
|
||||
# VITS relative-position contribution to logits
|
||||
rel_k = _slice_rel_emb(self.emb_rel_k, T, self.window) # (1, 2T-1, D)
|
||||
rel_logits = q @ rel_k.transpose(0, 2, 1)[:, None, :, :] # (B, H, T, 2T-1)
|
||||
rel_logits = _rel_to_abs(rel_logits * scale) # (B, H, T, T)
|
||||
logits = logits + rel_logits
|
||||
|
||||
if mask is not None:
|
||||
key_mask = mask[:, :, 0][:, None, None, :]
|
||||
neg_inf = mx.array(-1e4, dtype=logits.dtype)
|
||||
logits = mx.where(key_mask.astype(mx.bool_), logits, neg_inf)
|
||||
|
||||
attn = mx.softmax(logits, axis=-1) # (B, H, T, T)
|
||||
out = attn @ v # (B, H, T, D)
|
||||
|
||||
# VITS rel-pos value contribution
|
||||
rel_v = _slice_rel_emb(self.emb_rel_v, T, self.window) # (1, 2T-1, D)
|
||||
rel_weights = _abs_to_rel(attn) # (B, H, T, 2T-1)
|
||||
out = out + rel_weights @ rel_v[:, None, :, :] # (B, H, T, D)
|
||||
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, T, H * D)
|
||||
return self.conv_o(out)
|
||||
|
||||
|
||||
class FFN(nn.Module):
|
||||
"""FFN with Conv1d k=1 wrappers: conv_1 (256→1024) + ReLU + conv_2 (1024→256).
|
||||
|
||||
Activation is ReLU (confirmed by ONNX graph node ``Relu`` in ``ffn_layers.N``),
|
||||
not GELU. The mask is applied before each Conv to match the ONNX semantics.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.conv_1 = _ConvLayer(TE_DIM, TE_FFN_HIDDEN)
|
||||
self.conv_2 = _ConvLayer(TE_FFN_HIDDEN, TE_DIM)
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
y = self.conv_1(x)
|
||||
y = mx.maximum(y, mx.array(0.0, dtype=y.dtype))
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
y = self.conv_2(y)
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
return y
|
||||
|
||||
|
||||
class AttnEncoder(nn.Module):
|
||||
"""Stack of (RelPosSelfAttn + norm1) + (FFN + norm2) × 4."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attn_layers = [RelPosSelfAttention() for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_1 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.ffn_layers = [FFN() for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
self.norm_layers_2 = [WrappedNorm(TE_DIM, eps=EPS_LN) for _ in range(TE_ATTN_NUM_LAYERS)]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for i in range(TE_ATTN_NUM_LAYERS):
|
||||
y = self.attn_layers[i](x, mask=mask)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
y = self.ffn_layers[i](x, mask)
|
||||
x = self.norm_layers_2[i](x + y)
|
||||
return x
|
||||
|
||||
|
||||
class _TextEmbedder(nn.Module):
|
||||
"""char_embedder: VOCAB → TE_DIM. Loaded as ``char_embedder.weight (8322, 256)``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.char_embedder = nn.Embedding(VOCAB_SIZE, TE_DIM)
|
||||
|
||||
def __call__(self, text_ids: mx.array) -> mx.array:
|
||||
return self.char_embedder(text_ids)
|
||||
|
||||
|
||||
class _InnerTextEncoder(nn.Module):
|
||||
"""Pure text encoder before speech prompting. Loaded as ``text_encoder.X.Y``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.text_embedder = _TextEmbedder()
|
||||
self.convnext = TextConvNeXtStack()
|
||||
self.attn_encoder = AttnEncoder()
|
||||
|
||||
def __call__(self, text_ids: mx.array, mask: mx.array) -> mx.array:
|
||||
x = self.text_embedder(text_ids) # (B, T, 256)
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
x = self.convnext(x, mask)
|
||||
x = self.attn_encoder(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class _StyleEncoder(nn.Module):
|
||||
"""Holds ``style_token_layer.style_key`` (1, 50, 256)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Use a child module so the parameter path matches ``style_token_layer.style_key``
|
||||
class _StyleTokenLayer(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.style_key = mx.zeros((1, 50, 256))
|
||||
self.style_token_layer = _StyleTokenLayer()
|
||||
|
||||
|
||||
class _SpeechPromptedAttn(nn.Module):
|
||||
"""Cross-attention from text (Q) to style_ttl (K, V). Single head, 256-d."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.W_query = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.W_key = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.W_value = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
self.out_fc = WrappedLinear(TE_DIM, TE_DIM, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
|
||||
# x: (B, T_text, 256); style: (B, 50, 256)
|
||||
# Single-head cross attention.
|
||||
B, T, D = x.shape
|
||||
q = self.W_query(x)
|
||||
k = self.W_key(style)
|
||||
v = self.W_value(style)
|
||||
scale = D ** -0.5
|
||||
logits = (q @ k.transpose(0, 2, 1)) * scale
|
||||
attn = mx.softmax(logits, axis=-1)
|
||||
out = attn @ v
|
||||
return self.out_fc(out)
|
||||
|
||||
|
||||
class _SpeechPromptedTextEncoder(nn.Module):
|
||||
"""Two cross-attention layers modulating text features with style_ttl."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attention1 = _SpeechPromptedAttn()
|
||||
self.attention2 = _SpeechPromptedAttn()
|
||||
self.norm = WrappedNorm(TE_DIM, eps=EPS_LN)
|
||||
|
||||
def __call__(self, x: mx.array, style: mx.array) -> mx.array:
|
||||
x = x + self.attention1(x, style)
|
||||
x = x + self.attention2(x, style)
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class _RootTextEncoder(nn.Module):
|
||||
"""Top-level container matching ONNX ``tts.ttl.*`` namespace."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.text_encoder = _InnerTextEncoder()
|
||||
self.style_encoder = _StyleEncoder()
|
||||
self.speech_prompted_text_encoder = _SpeechPromptedTextEncoder()
|
||||
|
||||
|
||||
class _TtsContainer(nn.Module):
|
||||
"""Outer container so weight keys ``tts.ttl.X.Y`` resolve."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.ttl = _RootTextEncoder()
|
||||
|
||||
|
||||
class TextEncoder(nn.Module):
|
||||
"""Top-level text encoder: ``text_ids + style_ttl + text_mask → text_emb (B, 256, T)``.
|
||||
|
||||
Submodule naming matches the ONNX initializer keys after a single
|
||||
``tts.ttl.`` prefix wrap (so weight keys look like
|
||||
``tts.ttl.text_encoder.convnext.convnext.0.dwconv.weight``).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tts = _TtsContainer()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text_ids: mx.array, # (B, T_text) int
|
||||
style_ttl: mx.array, # (B, 50, 256)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
) -> mx.array:
|
||||
mask_ntc = text_mask.transpose(0, 2, 1) # (B, T_text, 1)
|
||||
x = self.tts.ttl.text_encoder(text_ids, mask_ntc)
|
||||
x = self.tts.ttl.speech_prompted_text_encoder(x, style_ttl)
|
||||
if mask_ntc is not None:
|
||||
x = x * mask_ntc
|
||||
# Return channels-first (B, 256, T_text) to match the vector_estimator input.
|
||||
return x.transpose(0, 2, 1)
|
||||
|
||||
|
||||
__all__ = ["TextEncoder", "VOCAB_SIZE", "TE_DIM"]
|
||||
765
src/supertonic_3_mlx/vector_estimator.py
Normal file
765
src/supertonic_3_mlx/vector_estimator.py
Normal file
@@ -0,0 +1,765 @@
|
||||
"""Supertonic 3 vector estimator (64 M params) — flow-matching denoiser, MLX port.
|
||||
|
||||
Pipeline (operating in channels-last NTC layout):
|
||||
|
||||
noisy_latent [B, 144, T_lat] (channels first from ONNX I/O)
|
||||
→ transpose [B, T_lat, 144]
|
||||
→ proj_in (Linear 144→512) [B, T_lat, 512]
|
||||
→ 24 main_blocks (4 cycles × 6 sub-types):
|
||||
cycle = [stack4, time_film, cn1, text_attn, cn1, style_attn]
|
||||
→ last_convnext (4 ConvNeXt) [B, T_lat, 512]
|
||||
→ proj_out (Linear 512→144) [B, T_lat, 144]
|
||||
→ transpose [B, 144, T_lat]
|
||||
→ Euler step: denoised = noisy + velocity * (1 / total_step)
|
||||
→ output [B, 144, T_lat]
|
||||
|
||||
Submodule naming matches the s3 ONNX initializer keys exactly, so loading
|
||||
the safetensors produced by ``weights.convert_onnx_to_mlx`` requires no
|
||||
remapping.
|
||||
|
||||
The forward path is faithful to ONNX semantics in fp32; ``mx.compile``,
|
||||
quantisation, and kernel fusion are layered on later in T.3.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from supertonic_3_mlx._config import (
|
||||
DIM, LATENT_CH, CONVNEXT_HIDDEN, CONVNEXT_K, STACK4_DILATIONS,
|
||||
NUM_MAIN_BLOCKS, BLOCKS_PER_CYCLE, BLOCK_CYCLE,
|
||||
TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM, ROTARY_BASE, ROTARY_SCALE,
|
||||
STYLE_DIM, STYLE_LEN, STYLE_HEADS, STYLE_HEAD_DIM,
|
||||
TIME_EMB_DIM, TIME_MLP_HIDDEN,
|
||||
EPS_LN,
|
||||
)
|
||||
from supertonic_3_mlx._nn_wrappers import (
|
||||
WrappedNorm, WrappedLinear, ProjConv1x1,
|
||||
)
|
||||
|
||||
|
||||
def _pad_sym_edge(x: mx.array, pad: int) -> mx.array:
|
||||
"""Symmetric replicate-edge pad on the time axis (axis=1 for [B, T, C])."""
|
||||
if pad == 0:
|
||||
return x
|
||||
left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
|
||||
right = mx.broadcast_to(x[:, -1:, :], (x.shape[0], pad, x.shape[2]))
|
||||
return mx.concatenate([left, x, right], axis=1)
|
||||
|
||||
|
||||
def _gelu_exact(x: mx.array) -> mx.array:
|
||||
"""Exact (non-tanh) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
|
||||
return x * 0.5 * (1.0 + mx.erf(x * (2 ** -0.5)))
|
||||
|
||||
|
||||
def _mish(x: mx.array) -> mx.array:
|
||||
"""Mish: x * tanh(softplus(x)) = x * tanh(log(1 + exp(x)))."""
|
||||
return x * mx.tanh(mx.logaddexp(x, mx.array(0.0, dtype=x.dtype)))
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# ConvNeXt building blocks
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
"""Single ConvNeXt block matching s3 keys: ``dwconv``, ``norm.norm``, ``pwconv1/2``, ``gamma``."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = DIM,
|
||||
hidden: int = CONVNEXT_HIDDEN,
|
||||
kernel: int = CONVNEXT_K,
|
||||
dilation: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.dilation = dilation
|
||||
self.pad = dilation * (kernel - 1) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=kernel, padding=0, dilation=dilation,
|
||||
groups=dim, bias=True,
|
||||
)
|
||||
self.norm = WrappedNorm(dim, eps=EPS_LN)
|
||||
self.pwconv1 = nn.Linear(dim, hidden, bias=True)
|
||||
self.pwconv2 = nn.Linear(hidden, dim, bias=True)
|
||||
# Stored as shape (1, dim, 1) in the ONNX checkpoint — see weights.py for
|
||||
# the load-time reshape that flattens it to (dim,) for broadcasting in NTC.
|
||||
self.gamma = mx.zeros((dim,))
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
# x: (B, T, C)
|
||||
residual = x
|
||||
y = _pad_sym_edge(x, self.pad)
|
||||
y = self.dwconv(y) # (B, T, C)
|
||||
y = self.norm(y) # LayerNorm last-dim
|
||||
y = self.pwconv1(y) # (B, T, hidden)
|
||||
y = _gelu_exact(y)
|
||||
y = self.pwconv2(y) # (B, T, C)
|
||||
y = y * self.gamma # broadcast over (B, T, .)
|
||||
out = residual + y
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class ConvNeXtStack(nn.Module):
|
||||
"""List of ConvNeXt blocks. Loaded as ``convnext.[0..N-1].X``."""
|
||||
|
||||
def __init__(self, dilations: tuple, dim: int = DIM, hidden: int = CONVNEXT_HIDDEN) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [ConvNeXtBlock(dim, hidden, CONVNEXT_K, d) for d in dilations]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None = None) -> mx.array:
|
||||
for b in self.convnext:
|
||||
x = b(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# 6 block types per cycle
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class Stack4Block(nn.Module):
|
||||
"""Cycle position 0 — 4 ConvNeXt with dilations [1, 2, 4, 8].
|
||||
|
||||
Loaded keys: ``convnext.[0..3].{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, d) for d in STACK4_DILATIONS]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
|
||||
for b in self.convnext:
|
||||
x = b(x, mask)
|
||||
return x
|
||||
|
||||
|
||||
class TimeFiLMBlock(nn.Module):
|
||||
"""Cycle position 1 — additive time conditioning: ``x + linear(t_emb)``.
|
||||
|
||||
Loaded keys: ``linear.linear.{weight,bias}``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.linear = WrappedLinear(TIME_EMB_DIM, DIM, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None, t_emb: mx.array, **_) -> mx.array:
|
||||
# t_emb: (B, TIME_EMB_DIM) → broadcast across T
|
||||
bias = self.linear(t_emb)[:, None, :] # (B, 1, DIM)
|
||||
y = x + bias
|
||||
if mask is not None:
|
||||
y = y * mask
|
||||
return y
|
||||
|
||||
|
||||
class ConvNeXt1Block(nn.Module):
|
||||
"""Cycle positions 2 and 4 — a single ConvNeXt block.
|
||||
|
||||
Loaded keys: ``convnext.0.{dwconv,norm.norm,pwconv1,pwconv2,gamma}``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.convnext = [ConvNeXtBlock(DIM, CONVNEXT_HIDDEN, CONVNEXT_K, 1)]
|
||||
|
||||
def __call__(self, x: mx.array, mask: mx.array | None, **_) -> mx.array:
|
||||
return self.convnext[0](x, mask)
|
||||
|
||||
|
||||
def _build_rope_freqs(head_dim: int, base: int, scale: int, max_len: int = 1024) -> mx.array:
|
||||
"""Pre-compute RoPE cos/sin table — (max_len, head_dim/2, 2)."""
|
||||
half = head_dim // 2
|
||||
inv_freq = 1.0 / (base ** (mx.arange(half, dtype=mx.float32) / half))
|
||||
pos = mx.arange(max_len, dtype=mx.float32) * scale
|
||||
angles = pos[:, None] * inv_freq[None, :] # (max_len, half)
|
||||
return mx.stack([mx.cos(angles), mx.sin(angles)], axis=-1) # (max_len, half, 2)
|
||||
|
||||
|
||||
def _apply_rope(x: mx.array, freqs: mx.array) -> mx.array:
|
||||
"""Apply RoPE rotation. ``x`` shape (B, H, T, head_dim); ``freqs`` (T, half, 2)."""
|
||||
half = x.shape[-1] // 2
|
||||
x_even, x_odd = x[..., :half], x[..., half:]
|
||||
cos = freqs[..., 0] # (T, half)
|
||||
sin = freqs[..., 1]
|
||||
rot_even = x_even * cos[None, None, :, :] - x_odd * sin[None, None, :, :]
|
||||
rot_odd = x_even * sin[None, None, :, :] + x_odd * cos[None, None, :, :]
|
||||
return mx.concatenate([rot_even, rot_odd], axis=-1)
|
||||
|
||||
|
||||
class TextCrossAttnBlock(nn.Module):
|
||||
"""Cycle position 3 — text cross-attention with RoPE on Q and K.
|
||||
|
||||
Loaded keys:
|
||||
``attn.W_query.linear.{weight,bias}``
|
||||
``attn.W_key.linear.{weight,bias}``
|
||||
``attn.W_value.linear.{weight,bias}``
|
||||
``attn.out_fc.linear.{weight,bias}``
|
||||
``attn.theta`` — frozen RoPE inv-freq table (1, 1, half)
|
||||
``attn.increments`` — frozen position table (1, 1000, 1) — 0..999
|
||||
``norm.norm.{weight,bias}``
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attn = _AttnInner(DIM, TEXT_DIM, TEXT_HEADS, TEXT_HEAD_DIM)
|
||||
self.norm = WrappedNorm(DIM, eps=EPS_LN)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None,
|
||||
*,
|
||||
text_emb: mx.array | None = None,
|
||||
text_mask: mx.array | None = None,
|
||||
latent_seq_len: mx.array | None = None,
|
||||
text_seq_len: mx.array | None = None,
|
||||
kv_cache: tuple[mx.array, mx.array] | None = None,
|
||||
**_,
|
||||
) -> mx.array:
|
||||
# x: (B, T_lat, DIM); text_emb: (B, T_text, TEXT_DIM) — unused when kv_cache supplied.
|
||||
residual = x * mask if mask is not None else x
|
||||
h = self.attn(
|
||||
residual, text_emb, text_mask=text_mask,
|
||||
latent_seq_len=latent_seq_len, text_seq_len=text_seq_len,
|
||||
kv_cache=kv_cache,
|
||||
)
|
||||
if mask is not None:
|
||||
h = h * mask
|
||||
out = self.norm(residual + h)
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class _AttnInner(nn.Module):
|
||||
"""Multi-head cross-attention with RoPE applied to query and key.
|
||||
|
||||
Holds parameters under ``W_query``, ``W_key``, ``W_value``, ``out_fc`` —
|
||||
each is a :class:`WrappedLinear` so its weight is keyed
|
||||
``…W_query.linear.weight`` to match the ONNX checkpoint.
|
||||
|
||||
``theta`` and ``increments`` come from the ONNX graph as frozen tensors
|
||||
(precomputed RoPE table). We rebuild the equivalent table from the
|
||||
Supertonic-3 config so the module is self-contained.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
ctx_dim: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
# ONNX divides attention logits by 16.0 (= sqrt(TEXT_DIM)), not sqrt(head_dim).
|
||||
self.scale = ctx_dim ** -0.5
|
||||
|
||||
kv_dim = num_heads * head_dim # = DIM = 512
|
||||
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
|
||||
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
||||
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
||||
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
|
||||
|
||||
# Frozen RoPE tables — overwritten by checkpoint at load time.
|
||||
# ONNX layout:
|
||||
# ``increments`` (1, 1000, 1) holds positions 0..999 (no scale baked in)
|
||||
# ``theta`` (1, 1, half) holds rotary_scale × base^(-i/half)
|
||||
# Angle formula: ``angle = (pos / actual_seq_len) × theta``.
|
||||
# The division by the actual seq length is critical — it normalises
|
||||
# absolute positions into [0, 1] so audio and text are RoPE-aligned
|
||||
# regardless of their respective lengths.
|
||||
max_len = 1000
|
||||
half = head_dim // 2
|
||||
idx = mx.arange(half, dtype=mx.float32)
|
||||
self.theta = (ROTARY_SCALE * mx.exp(-math.log(ROTARY_BASE) * idx / half))[None, None, :]
|
||||
positions = mx.arange(max_len, dtype=mx.int64)
|
||||
self.increments = positions[None, :, None] # (1, max_len, 1)
|
||||
|
||||
def _rope(self, x: mx.array, seq_len: mx.array | int | None = None) -> mx.array:
|
||||
"""Apply RoPE rotation. ``seq_len`` is the effective (unmasked) length.
|
||||
|
||||
Args:
|
||||
x: (B, H, T, head_dim)
|
||||
seq_len: scalar or (B,) — actual sequence length for position normalisation.
|
||||
If None, defaults to T (no normalisation).
|
||||
"""
|
||||
T = x.shape[-2]
|
||||
positions = self.increments[:, :T, :] # (1, T, 1)
|
||||
if seq_len is None:
|
||||
seq_len = float(T)
|
||||
if isinstance(seq_len, (int, float)):
|
||||
divisor = float(seq_len)
|
||||
else:
|
||||
divisor = seq_len.astype(mx.float32).reshape(-1, 1, 1)
|
||||
norm_pos = positions / divisor # broadcasts to (B, T, 1) if divisor is (B,1,1)
|
||||
angles = norm_pos * self.theta # (B, T, half) or (1, T, half)
|
||||
cos = mx.cos(angles)
|
||||
sin = mx.sin(angles)
|
||||
half = self.head_dim // 2
|
||||
# Broadcast (?, T, half) → (?, 1, T, half) for head dim
|
||||
cos_b = cos[..., None, :, :] if cos.ndim == 3 else cos[None, None, :, :]
|
||||
sin_b = sin[..., None, :, :] if sin.ndim == 3 else sin[None, None, :, :]
|
||||
# Make sure broadcasts properly
|
||||
if cos_b.shape[0] == 1 and x.shape[0] > 1:
|
||||
cos_b = mx.broadcast_to(cos_b, (x.shape[0], 1, T, half))
|
||||
sin_b = mx.broadcast_to(sin_b, (x.shape[0], 1, T, half))
|
||||
# Reshape if needed
|
||||
cos_b = cos_b.reshape(-1, 1, T, half)
|
||||
sin_b = sin_b.reshape(-1, 1, T, half)
|
||||
x_first, x_second = x[..., :half], x[..., half:]
|
||||
rot_first = x_first * cos_b - x_second * sin_b
|
||||
rot_second = x_first * sin_b + x_second * cos_b
|
||||
return mx.concatenate([rot_first, rot_second], axis=-1)
|
||||
|
||||
def project_kv(
|
||||
self,
|
||||
text_emb: mx.array,
|
||||
text_seq_len: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Project text_emb → (K_rope, V) once. Both are constant across the
|
||||
Euler steps in a TTS inference call (T.5.3 cache target)."""
|
||||
B, T_text, _ = text_emb.shape
|
||||
H, D = self.num_heads, self.head_dim
|
||||
k = self.W_key(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
|
||||
v = self.W_value(text_emb).reshape(B, T_text, H, D).transpose(0, 2, 1, 3)
|
||||
k = self._rope(k, seq_len=text_seq_len if text_seq_len is not None else T_text)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
text_emb: mx.array | None = None,
|
||||
text_mask: mx.array | None = None,
|
||||
latent_seq_len: mx.array | None = None,
|
||||
text_seq_len: mx.array | None = None,
|
||||
kv_cache: tuple[mx.array, mx.array] | None = None,
|
||||
) -> mx.array:
|
||||
B, T_lat, _ = x.shape
|
||||
H, D = self.num_heads, self.head_dim
|
||||
|
||||
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3) # (B, H, T_lat, D)
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
k, v = self.project_kv(text_emb, text_seq_len=text_seq_len)
|
||||
|
||||
# RoPE normalises positions by the effective (unmasked) sequence length.
|
||||
q = self._rope(q, seq_len=latent_seq_len if latent_seq_len is not None else T_lat)
|
||||
|
||||
# Attention
|
||||
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale # (B, H, T_lat, T_text)
|
||||
if text_mask is not None:
|
||||
neg_inf = mx.array(-1e4, dtype=logits.dtype)
|
||||
logits = mx.where(text_mask[:, :, None, :].astype(mx.bool_), logits, neg_inf)
|
||||
attn = mx.softmax(logits, axis=-1)
|
||||
out = attn @ v # (B, H, T_lat, D)
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
|
||||
return self.out_fc(out)
|
||||
|
||||
|
||||
class StyleCrossAttnBlock(nn.Module):
|
||||
"""Cycle position 5 — style cross-attention to 50 learned style tokens.
|
||||
|
||||
Loaded keys:
|
||||
``attention.W_query.linear.{weight,bias}``
|
||||
``attention.W_key.linear.{weight,bias}``
|
||||
``attention.W_value.linear.{weight,bias}``
|
||||
``attention.out_fc.linear.{weight,bias}``
|
||||
``norm.norm.{weight,bias}``
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.attention = _StyleAttnInner(DIM, STYLE_DIM, STYLE_HEADS, STYLE_HEAD_DIM)
|
||||
self.norm = WrappedNorm(DIM, eps=EPS_LN)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None,
|
||||
*,
|
||||
style_k: mx.array | None = None,
|
||||
style_v: mx.array | None = None,
|
||||
kv_cache: tuple[mx.array, mx.array] | None = None,
|
||||
**_,
|
||||
) -> mx.array:
|
||||
# style_v defaults to style_k (same tensor for cond path); CFG path supplies
|
||||
# different style_v to model the uncond branch.
|
||||
if style_v is None and style_k is not None:
|
||||
style_v = style_k
|
||||
residual = x * mask if mask is not None else x
|
||||
h = self.attention(residual, style_k, style_v, kv_cache=kv_cache)
|
||||
if mask is not None:
|
||||
h = h * mask
|
||||
out = self.norm(residual + h)
|
||||
if mask is not None:
|
||||
out = out * mask
|
||||
return out
|
||||
|
||||
|
||||
class _StyleAttnInner(nn.Module):
|
||||
def __init__(self, in_dim: int, ctx_dim: int, num_heads: int, head_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
# ONNX divides attention logits by 16.0 (= sqrt(STYLE_DIM)), not sqrt(head_dim).
|
||||
self.scale = ctx_dim ** -0.5
|
||||
kv_dim = num_heads * head_dim # 2 * 128 = 256
|
||||
# Q is on DIM (audio), K/V on ctx_dim (style 256)
|
||||
self.W_query = WrappedLinear(in_dim, kv_dim, bias=True)
|
||||
self.W_key = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
||||
self.W_value = WrappedLinear(ctx_dim, kv_dim, bias=True)
|
||||
self.out_fc = WrappedLinear(kv_dim, in_dim, bias=True)
|
||||
|
||||
def project_kv(
|
||||
self, style_k: mx.array, style_v: mx.array
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Project (style_k, style_v) → (K, V) once. T.5.3 cache target."""
|
||||
B, T_style = style_k.shape[0], style_k.shape[1]
|
||||
H, D = self.num_heads, self.head_dim
|
||||
# Note: ONNX graph applies tanh to the K projection (``attention/tanh/Tanh``
|
||||
# node) — the style key bank is bounded into [-1, 1] before softmax dot
|
||||
# product, which acts as a soft attention temperature regulariser.
|
||||
k = mx.tanh(self.W_key(style_k)).reshape(B, T_style, H, D).transpose(0, 2, 1, 3)
|
||||
v = self.W_value(style_v).reshape(B, style_v.shape[1], H, D).transpose(0, 2, 1, 3)
|
||||
return k, v
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
style_k: mx.array | None = None,
|
||||
style_v: mx.array | None = None,
|
||||
kv_cache: tuple[mx.array, mx.array] | None = None,
|
||||
) -> mx.array:
|
||||
# style_k and style_v can be the same tensor (cond) or distinct (uncond
|
||||
# branch in CFG, where K comes from style_key_special_token and V from
|
||||
# style_value_special_token).
|
||||
B, T_lat, _ = x.shape
|
||||
H, D = self.num_heads, self.head_dim
|
||||
q = self.W_query(x).reshape(B, T_lat, H, D).transpose(0, 2, 1, 3)
|
||||
if kv_cache is not None:
|
||||
k, v = kv_cache
|
||||
else:
|
||||
k, v = self.project_kv(style_k, style_v)
|
||||
logits = (q @ k.transpose(0, 1, 3, 2)) * self.scale
|
||||
attn = mx.softmax(logits, axis=-1)
|
||||
out = attn @ v
|
||||
out = out.transpose(0, 2, 1, 3).reshape(B, T_lat, H * D)
|
||||
return self.out_fc(out)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# Time encoder
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _MlpItem(nn.Module):
|
||||
"""A single MLP layer wrapped to produce keys ``mlp.N.linear.{weight,bias}``."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int) -> None:
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(in_dim, out_dim, bias=True)
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class TimeEncoder(nn.Module):
|
||||
"""Sinusoidal time embedding + 2-layer MLP. Keys: ``mlp.0.linear``, ``mlp.2.linear``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# ONNX: mlp.0.linear (64→256), mlp.2.linear (256→64). Index 1 is activation.
|
||||
self.mlp = [
|
||||
_MlpItem(TIME_EMB_DIM, TIME_MLP_HIDDEN), # mlp.0
|
||||
nn.Identity(), # mlp.1 (activation; no weights)
|
||||
_MlpItem(TIME_MLP_HIDDEN, TIME_EMB_DIM), # mlp.2
|
||||
]
|
||||
|
||||
def __call__(self, t: mx.array) -> mx.array:
|
||||
# t: (B,) — produce sinusoidal embedding then run through MLP.
|
||||
# Activation is Mish (not SiLU) to match the ONNX graph
|
||||
# (Softplus → Tanh → Mul pattern == x * tanh(softplus(x))).
|
||||
emb = self._sinusoidal(t, TIME_EMB_DIM)
|
||||
h = self.mlp[0](emb)
|
||||
h = _mish(h)
|
||||
h = self.mlp[2](h)
|
||||
return h
|
||||
|
||||
@staticmethod
|
||||
def _sinusoidal(t: mx.array, dim: int) -> mx.array:
|
||||
"""Time embedding matching ``Supertonic-3`` ONNX exactly.
|
||||
|
||||
ONNX path: pos = t * 1000; freqs[i] = 10000^(-i/(half-1));
|
||||
concat[sin(pos*freqs), cos(pos*freqs)].
|
||||
"""
|
||||
half = dim // 2
|
||||
denom = max(half - 1, 1)
|
||||
freqs = mx.exp(-math.log(10_000) * mx.arange(half, dtype=mx.float32) / denom)
|
||||
pos = t.astype(mx.float32)[:, None] * 1000.0
|
||||
angles = pos * freqs[None, :]
|
||||
return mx.concatenate([mx.sin(angles), mx.cos(angles)], axis=-1).astype(mx.float32)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
# Top-level VectorEstimator
|
||||
# ──────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_main_block(idx: int) -> nn.Module:
|
||||
"""Instantiate the appropriate block class for cycle position ``idx % 6``."""
|
||||
pos = idx % BLOCKS_PER_CYCLE
|
||||
name = BLOCK_CYCLE[pos]
|
||||
if name == "stack4":
|
||||
return Stack4Block()
|
||||
if name == "time":
|
||||
return TimeFiLMBlock()
|
||||
if name == "cn1":
|
||||
return ConvNeXt1Block()
|
||||
if name == "text_attn":
|
||||
return TextCrossAttnBlock()
|
||||
if name == "style_attn":
|
||||
return StyleCrossAttnBlock()
|
||||
raise RuntimeError(f"unknown block type for index {idx}: {name}")
|
||||
|
||||
|
||||
class _VectorField(nn.Module):
|
||||
"""Inner module mirroring ONNX ``vector_estimator.tts.ttl.vector_field.*``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.proj_in = ProjConv1x1(LATENT_CH, DIM, bias=False)
|
||||
self.main_blocks = [_build_main_block(i) for i in range(NUM_MAIN_BLOCKS)]
|
||||
self.last_convnext = ConvNeXtStack(dilations=(1, 1, 1, 1), dim=DIM, hidden=CONVNEXT_HIDDEN)
|
||||
self.proj_out = ProjConv1x1(DIM, LATENT_CH, bias=False)
|
||||
self.time_encoder = TimeEncoder()
|
||||
|
||||
|
||||
class _UncondMasker(nn.Module):
|
||||
"""Holds the three unconditional-token tensors used by CFG.
|
||||
|
||||
Keys:
|
||||
``text_special_token`` (1, 256, 1)
|
||||
``style_key_special_token`` (1, 50, 256)
|
||||
``style_value_special_token`` (1, 50, 256)
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Initialised to zero; checkpoint provides real values.
|
||||
self.text_special_token = mx.zeros((1, TEXT_DIM, 1))
|
||||
self.style_key_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
self.style_value_special_token = mx.zeros((1, STYLE_LEN, STYLE_DIM))
|
||||
|
||||
|
||||
class VectorEstimator(nn.Module):
|
||||
"""Top-level module — matches ONNX root names ``vector_field.*`` and ``uncond_masker.*``.
|
||||
|
||||
Two inference paths:
|
||||
- :meth:`velocity`: single forward pass; predicts the velocity from one set
|
||||
of conditioning inputs. ``style_k``/``style_v`` may be the same tensor
|
||||
(cond path) or different (uncond path of CFG).
|
||||
- :meth:`__call__`: full ONNX-parity forward — applies CFG batch doubling
|
||||
(cond + uncond) internally and combines via
|
||||
``final = noisy + (4*cond - 3*uncond) / total_step``.
|
||||
"""
|
||||
|
||||
# CFG guidance constants — baked into the ONNX graph as ``/Constant_3`` (=4.0)
|
||||
# and ``/Constant_4`` (=3.0). Equivalent to guidance_scale = 4 with the
|
||||
# standard formula ``v = uncond + g*(cond - uncond) = 4*cond - 3*uncond``.
|
||||
CFG_COND_SCALE: float = 4.0
|
||||
CFG_UNCOND_SCALE: float = 3.0
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.vector_field = _VectorField()
|
||||
self.uncond_masker = _UncondMasker()
|
||||
|
||||
# ── inference API ─────────────────────────────────────────────
|
||||
def velocity(
|
||||
self,
|
||||
noisy_latent: mx.array, # (B, 144, T_lat)
|
||||
text_emb: mx.array, # (B, 256, T_text)
|
||||
style_k: mx.array, # (B, 50, 256) — K side of style attention
|
||||
style_v: mx.array, # (B, 50, 256) — V side of style attention
|
||||
latent_mask: mx.array, # (B, 1, T_lat)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
t_norm: mx.array, # (B,) timestep in [0, 1]
|
||||
) -> mx.array:
|
||||
"""Predict velocity (B, 144, T_lat) without applying CFG or Euler step."""
|
||||
x = noisy_latent.transpose(0, 2, 1) # (B, T_lat, 144)
|
||||
text = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
||||
lat_mask_ntc = latent_mask.transpose(0, 2, 1) # (B, T_lat, 1)
|
||||
|
||||
x = self.vector_field.proj_in(x) # (B, T_lat, 512)
|
||||
t_emb = self.vector_field.time_encoder(t_norm) # (B, TIME_EMB_DIM)
|
||||
|
||||
# Effective (unmasked) sequence lengths for RoPE normalisation —
|
||||
# ONNX uses ``ReduceSum(mask)`` for this so that audio and text are
|
||||
# rope-aligned regardless of padding.
|
||||
latent_seq_len = mx.sum(latent_mask, axis=(1, 2)) # (B,)
|
||||
text_seq_len = mx.sum(text_mask, axis=(1, 2)) # (B,)
|
||||
|
||||
for blk in self.vector_field.main_blocks:
|
||||
x = blk(
|
||||
x,
|
||||
lat_mask_ntc,
|
||||
t_emb=t_emb,
|
||||
text_emb=text,
|
||||
text_mask=text_mask,
|
||||
style_k=style_k,
|
||||
style_v=style_v,
|
||||
latent_seq_len=latent_seq_len,
|
||||
text_seq_len=text_seq_len,
|
||||
)
|
||||
|
||||
x = self.vector_field.last_convnext(x, lat_mask_ntc)
|
||||
v_ntc = self.vector_field.proj_out(x) # (B, T_lat, 144)
|
||||
return v_ntc.transpose(0, 2, 1) # (B, 144, T_lat)
|
||||
|
||||
# ── T.5.3 — pre-projected K/V path ────────────────────────────
|
||||
def precompute_cross_kv(
|
||||
self,
|
||||
text_emb: mx.array, # (B, 256, T_text) channels-first
|
||||
style_k: mx.array, # (B, 50, 256)
|
||||
style_v: mx.array, # (B, 50, 256)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
) -> tuple[list[tuple[mx.array, mx.array]], list[tuple[mx.array, mx.array]]]:
|
||||
"""Project K/V for every text_attn and style_attn block exactly once.
|
||||
|
||||
Returns ``(text_kv_list, style_kv_list)`` — both ordered to align with
|
||||
the corresponding blocks encountered when iterating ``main_blocks``.
|
||||
These tensors are invariant across the 5 Euler steps of one TTS
|
||||
call; pre-projecting them once and feeding the result into
|
||||
:meth:`velocity_cached` cuts ~ 4 × 2 × 5 = 40 redundant matmuls.
|
||||
"""
|
||||
text_seq_len = mx.sum(text_mask, axis=(1, 2))
|
||||
text_ntc = text_emb.transpose(0, 2, 1) # (B, T_text, 256)
|
||||
|
||||
text_kv: list[tuple[mx.array, mx.array]] = []
|
||||
style_kv: list[tuple[mx.array, mx.array]] = []
|
||||
for blk in self.vector_field.main_blocks:
|
||||
if isinstance(blk, TextCrossAttnBlock):
|
||||
text_kv.append(blk.attn.project_kv(text_ntc, text_seq_len=text_seq_len))
|
||||
elif isinstance(blk, StyleCrossAttnBlock):
|
||||
style_kv.append(blk.attention.project_kv(style_k, style_v))
|
||||
return text_kv, style_kv
|
||||
|
||||
def velocity_cached(
|
||||
self,
|
||||
noisy_latent: mx.array,
|
||||
latent_mask: mx.array,
|
||||
text_mask: mx.array,
|
||||
t_norm: mx.array,
|
||||
text_kv: list[tuple[mx.array, mx.array]],
|
||||
style_kv: list[tuple[mx.array, mx.array]],
|
||||
) -> mx.array:
|
||||
"""Same as :meth:`velocity` but reads K/V from pre-projected caches.
|
||||
|
||||
``text_kv`` and ``style_kv`` must come from :meth:`precompute_cross_kv`
|
||||
applied to the same (batched) conditioning tensors that will be
|
||||
active for this call.
|
||||
"""
|
||||
x = noisy_latent.transpose(0, 2, 1)
|
||||
lat_mask_ntc = latent_mask.transpose(0, 2, 1)
|
||||
|
||||
x = self.vector_field.proj_in(x)
|
||||
t_emb = self.vector_field.time_encoder(t_norm)
|
||||
latent_seq_len = mx.sum(latent_mask, axis=(1, 2))
|
||||
|
||||
ti = 0
|
||||
si = 0
|
||||
for blk in self.vector_field.main_blocks:
|
||||
if isinstance(blk, TextCrossAttnBlock):
|
||||
x = blk(
|
||||
x, lat_mask_ntc,
|
||||
text_mask=text_mask,
|
||||
latent_seq_len=latent_seq_len,
|
||||
kv_cache=text_kv[ti],
|
||||
)
|
||||
ti += 1
|
||||
elif isinstance(blk, StyleCrossAttnBlock):
|
||||
x = blk(x, lat_mask_ntc, kv_cache=style_kv[si])
|
||||
si += 1
|
||||
else:
|
||||
x = blk(x, lat_mask_ntc, t_emb=t_emb)
|
||||
|
||||
x = self.vector_field.last_convnext(x, lat_mask_ntc)
|
||||
v_ntc = self.vector_field.proj_out(x)
|
||||
return v_ntc.transpose(0, 2, 1)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
noisy_latent: mx.array, # (B, 144, T_lat) channels-first per ONNX I/O
|
||||
text_emb: mx.array, # (B, 256, T_text) channels-first
|
||||
style_ttl: mx.array, # (B, 50, 256) — used as both K and V for cond
|
||||
latent_mask: mx.array, # (B, 1, T_lat)
|
||||
text_mask: mx.array, # (B, 1, T_text)
|
||||
current_step: mx.array, # (B,)
|
||||
total_step: mx.array, # (B,)
|
||||
cfg: bool = True,
|
||||
) -> mx.array:
|
||||
"""Run one Euler step with CFG (matches ONNX semantics).
|
||||
|
||||
With ``cfg=True`` (default) the model runs both conditional and
|
||||
unconditional paths in a single batched forward and combines via
|
||||
``final = noisy + (4*cond_v - 3*uncond_v) / total_step``.
|
||||
|
||||
With ``cfg=False`` only the conditional path runs — half the work, but
|
||||
produces a different (lower-quality) output. Useful for speed bench /
|
||||
sanity tests.
|
||||
"""
|
||||
B = noisy_latent.shape[0]
|
||||
t_norm = current_step.astype(mx.float32) / total_step.astype(mx.float32)
|
||||
|
||||
if not cfg:
|
||||
v = self.velocity(
|
||||
noisy_latent, text_emb, style_ttl, style_ttl,
|
||||
latent_mask, text_mask, t_norm,
|
||||
)
|
||||
return noisy_latent + v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
|
||||
|
||||
# CFG branch — build (2B, ...) inputs by concatenating cond + uncond.
|
||||
# uncond text_emb = text_special_token broadcast to (B, 256, T_text).
|
||||
# uncond style_k = style_key_special_token broadcast, similarly style_v.
|
||||
text_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.text_special_token, (B, TEXT_DIM, text_emb.shape[2])
|
||||
)
|
||||
style_k_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.style_key_special_token, (B, STYLE_LEN, STYLE_DIM)
|
||||
)
|
||||
style_v_uncond = mx.broadcast_to(
|
||||
self.uncond_masker.style_value_special_token, (B, STYLE_LEN, STYLE_DIM)
|
||||
)
|
||||
|
||||
noisy_2 = mx.concatenate([noisy_latent, noisy_latent], axis=0)
|
||||
text_2 = mx.concatenate([text_emb, text_uncond], axis=0)
|
||||
style_k_2 = mx.concatenate([style_ttl, style_k_uncond], axis=0)
|
||||
style_v_2 = mx.concatenate([style_ttl, style_v_uncond], axis=0)
|
||||
lm_2 = mx.concatenate([latent_mask, latent_mask], axis=0)
|
||||
tm_2 = mx.concatenate([text_mask, text_mask], axis=0)
|
||||
t_norm_2 = mx.concatenate([t_norm, t_norm], axis=0)
|
||||
|
||||
v_2 = self.velocity(
|
||||
noisy_2, text_2, style_k_2, style_v_2, lm_2, tm_2, t_norm_2,
|
||||
) # (2B, 144, T_lat)
|
||||
cond_v = v_2[:B]
|
||||
uncond_v = v_2[B:2 * B]
|
||||
combined_v = self.CFG_COND_SCALE * cond_v - self.CFG_UNCOND_SCALE * uncond_v
|
||||
return noisy_latent + combined_v / total_step.reshape(-1, 1, 1).astype(noisy_latent.dtype)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ConvNeXtBlock", "ConvNeXtStack",
|
||||
"Stack4Block", "TimeFiLMBlock", "ConvNeXt1Block",
|
||||
"TextCrossAttnBlock", "StyleCrossAttnBlock",
|
||||
"TimeEncoder", "VectorEstimator",
|
||||
]
|
||||
304
src/supertonic_3_mlx/vocoder.py
Normal file
304
src/supertonic_3_mlx/vocoder.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""Supertonic 3 vocoder — latent → 44.1 kHz waveform, MLX port.
|
||||
|
||||
Pipeline (operating in channels-last NTC layout, then converted to channels-first
|
||||
for output reshape):
|
||||
|
||||
latent [B, 144, T_lat] (output of vector_estimator)
|
||||
→ /= normalizer.scale (scalar)
|
||||
→ reshape [B, 24, T_lat*6] # de-compress
|
||||
→ (* latent_std + latent_mean) # de-normalise
|
||||
→ transpose to NTC [B, T_lat*6, 24]
|
||||
→ embed Conv1d(24→512, k=7, sym-edge pad) [B, T_lat*6, 512]
|
||||
→ 10× ConvNeXt(dim=512, hidden=2048, k=7,
|
||||
dilations [1,2,4,1,2,4,1,1,1,1])
|
||||
→ final_norm: BatchNorm1d (eval-time: running stats only)
|
||||
→ head.layer1: Conv1d(512→2048, k=3, sym-edge pad)
|
||||
→ PReLU (with per-channel learnable slope)
|
||||
→ head.layer2: Conv1d(2048→512, k=1, no bias)
|
||||
→ transpose to (B, 512, T_lat*6) → flatten → wav (B, T_lat*6*512)
|
||||
|
||||
The 512 samples/step × 6 chunk × 44.1 kHz → T_lat steps of about 0.0697 s each.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from supertonic_3_mlx._config import EPS_LN
|
||||
from supertonic_3_mlx._nn_wrappers import WrappedNorm
|
||||
from supertonic_3_mlx.vector_estimator import _gelu_exact
|
||||
|
||||
|
||||
def _pad_left_edge(x: mx.array, pad: int) -> mx.array:
|
||||
"""Causal replicate-edge pad on the time axis (axis=1 for [B, T, C]).
|
||||
|
||||
Pads ``pad`` time-steps on the LEFT only by replicating the first frame.
|
||||
Matches the ONNX vocoder pads spec ``[0, 0, pad, 0, 0, 0]``.
|
||||
"""
|
||||
if pad == 0:
|
||||
return x
|
||||
left = mx.broadcast_to(x[:, :1, :], (x.shape[0], pad, x.shape[2]))
|
||||
return mx.concatenate([left, x], axis=1)
|
||||
|
||||
|
||||
VOC_DIM = 512
|
||||
VOC_HIDDEN = 2048
|
||||
VOC_K = 7
|
||||
VOC_HEAD_K = 3
|
||||
VOC_LDIM = 24 # de-compressed channels (24 × 6 = 144 input)
|
||||
VOC_CHUNK_COMPRESS = 6
|
||||
VOC_NUM_CONVNEXT_LAYERS = 10
|
||||
VOC_DILATIONS = (1, 2, 4, 1, 2, 4, 1, 1, 1, 1)
|
||||
EPS_BN = 1e-5
|
||||
|
||||
|
||||
class _Conv1dNet(nn.Module):
|
||||
"""Conv1d wrapped under ``.net`` to match ONNX storage ``.net.weight/bias``."""
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int, kernel: int, dilation: int = 1,
|
||||
groups: int = 1, bias: bool = True) -> None:
|
||||
super().__init__()
|
||||
class _Net(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
# MLX Conv1d weight: (out, K, in/groups)
|
||||
_.weight = mx.zeros((out_dim, kernel, in_dim // groups))
|
||||
if bias:
|
||||
_.bias = mx.zeros((out_dim,))
|
||||
else:
|
||||
_.bias = None
|
||||
def __call__(_, x, dilation=1):
|
||||
y = mx.conv1d(x, _.weight, stride=1, padding=0, dilation=dilation,
|
||||
groups=groups)
|
||||
if _.bias is not None:
|
||||
y = y + _.bias
|
||||
return y
|
||||
self.net = _Net()
|
||||
self.dilation = dilation
|
||||
self.groups = groups
|
||||
self.kernel = kernel
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.net(x, dilation=self.dilation)
|
||||
|
||||
|
||||
class _VocConvNeXtBlock(nn.Module):
|
||||
"""ConvNeXt block matching keys ``convnext.N.{dwconv.net,norm.norm,pwconv1,pwconv2,gamma}``."""
|
||||
|
||||
def __init__(self, dilation: int) -> None:
|
||||
super().__init__()
|
||||
self.dilation = dilation
|
||||
self.pad = dilation * (VOC_K - 1)
|
||||
self.dwconv = _Conv1dNet(VOC_DIM, VOC_DIM, kernel=VOC_K, dilation=dilation,
|
||||
groups=VOC_DIM, bias=True)
|
||||
self.norm = WrappedNorm(VOC_DIM, eps=EPS_LN)
|
||||
# pwconv1 / pwconv2 stored as Conv1d k=1 → loaded after squeeze to Linear.
|
||||
self.pwconv1 = nn.Linear(VOC_DIM, VOC_HIDDEN, bias=True)
|
||||
self.pwconv2 = nn.Linear(VOC_HIDDEN, VOC_DIM, bias=True)
|
||||
self.gamma = mx.zeros((VOC_DIM,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
residual = x
|
||||
y = _pad_left_edge(x, self.pad)
|
||||
y = self.dwconv(y)
|
||||
y = self.norm(y)
|
||||
y = self.pwconv1(y)
|
||||
y = _gelu_exact(y)
|
||||
y = self.pwconv2(y)
|
||||
y = y * self.gamma
|
||||
return residual + y
|
||||
|
||||
|
||||
class _BatchNorm1dEval(nn.Module):
|
||||
"""Eval-mode BatchNorm1d: applies stored running_mean/running_var only.
|
||||
|
||||
Loaded keys: ``norm.{weight,bias,running_mean,running_var}``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
class _Norm(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.weight = mx.ones((VOC_DIM,))
|
||||
_.bias = mx.zeros((VOC_DIM,))
|
||||
_.running_mean = mx.zeros((VOC_DIM,))
|
||||
_.running_var = mx.ones((VOC_DIM,))
|
||||
def __call__(_, x):
|
||||
# x: (B, T, C). BN1d normalises across batch+time per channel.
|
||||
# Eval mode: use stored running stats.
|
||||
norm = (x - _.running_mean) * mx.rsqrt(_.running_var + EPS_BN)
|
||||
return norm * _.weight + _.bias
|
||||
self.norm = _Norm()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return self.norm(x)
|
||||
|
||||
|
||||
class _VocHeadActivation(nn.Module):
|
||||
"""PReLU with per-channel learnable slope (weight shape (C,))."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# ONNX anonymous PReLU stores slope of shape (1,) sometimes or (C,).
|
||||
# We default to (1,) and reshape on load if needed.
|
||||
self.weight = mx.zeros((1,))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# PReLU: max(0, x) + slope × min(0, x).
|
||||
# slope broadcasts over (B, T, C) or (B, C, T) depending on layout.
|
||||
zero = mx.array(0.0, dtype=x.dtype)
|
||||
return mx.maximum(x, zero) + self.weight * mx.minimum(x, zero)
|
||||
|
||||
|
||||
class _VocHead(nn.Module):
|
||||
"""``head.layer1`` (Conv1d 512→2048 k=3) + ``head.act`` (PReLU) + ``head.layer2`` (Conv1d k=1, no bias)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.layer1 = _Conv1dNet(VOC_DIM, VOC_HIDDEN, kernel=VOC_HEAD_K, bias=True)
|
||||
self.act = _VocHeadActivation()
|
||||
# layer2 has no .net wrapper in ONNX (different from layer1)
|
||||
# ONNX: head.layer2.weight (512, 2048, 1) — Conv1d k=1, no bias.
|
||||
# We represent it directly without .net wrap.
|
||||
self.layer2 = _VocLayer2()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
# x: (B, T, 512)
|
||||
pad = VOC_HEAD_K - 1
|
||||
y = _pad_left_edge(x, pad)
|
||||
y = self.layer1(y) # (B, T, 2048)
|
||||
y = self.act(y)
|
||||
y = self.layer2(y) # (B, T, 512)
|
||||
return y
|
||||
|
||||
|
||||
class _VocLayer2(nn.Module):
|
||||
"""Conv1d k=1 (2048 → 512), no bias. Keys: ``layer2.weight (512, 2048, 1)``."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# MLX Conv1d weight shape: (out, K, in/groups) = (512, 1, 2048)
|
||||
# ONNX storage: (out, in, 1) = (512, 2048, 1). Same size; reshape on load.
|
||||
self.weight = mx.zeros((VOC_DIM, 1, VOC_HIDDEN))
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return mx.conv1d(x, self.weight, stride=1, padding=0)
|
||||
|
||||
|
||||
class _VocEmbed(nn.Module):
|
||||
"""Initial Conv1d(24→512, k=7) with sym-edge pad.
|
||||
|
||||
The weight + bias are anonymous in the ONNX graph (``onnx::Conv_1441`` and
|
||||
``onnx::Conv_1442``); the conversion recovers them via the Conv node path
|
||||
``/decoder/embed/net/Conv`` → structured name ``tts.ae.decoder.embed.net.{weight,bias}``.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
class _Net(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.weight = mx.zeros((VOC_DIM, VOC_K, VOC_LDIM))
|
||||
_.bias = mx.zeros((VOC_DIM,))
|
||||
def __call__(_, x):
|
||||
return mx.conv1d(x, _.weight, stride=1, padding=0) + _.bias
|
||||
self.net = _Net()
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
pad = VOC_K - 1
|
||||
y = _pad_left_edge(x, pad)
|
||||
return self.net(y)
|
||||
|
||||
|
||||
class _VocDecoder(nn.Module):
|
||||
"""``tts.ae.decoder.X`` namespace."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.embed = _VocEmbed()
|
||||
self.convnext = [_VocConvNeXtBlock(d) for d in VOC_DILATIONS]
|
||||
self.final_norm = _BatchNorm1dEval()
|
||||
self.head = _VocHead()
|
||||
|
||||
|
||||
class _AEContainer(nn.Module):
|
||||
"""``tts.ae.X`` — holds latent_mean, latent_std, decoder."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.latent_mean = mx.zeros((1, VOC_LDIM, 1))
|
||||
self.latent_std = mx.ones((1, VOC_LDIM, 1))
|
||||
self.decoder = _VocDecoder()
|
||||
|
||||
|
||||
class _TtlContainer(nn.Module):
|
||||
"""``tts.ttl.normalizer.scale`` (scalar) — divides the latent before de-norm."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
class _Normalizer(nn.Module):
|
||||
def __init__(_):
|
||||
super().__init__()
|
||||
_.scale = mx.array(1.0)
|
||||
self.normalizer = _Normalizer()
|
||||
|
||||
|
||||
class _TtsContainer(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.ttl = _TtlContainer()
|
||||
self.ae = _AEContainer()
|
||||
|
||||
|
||||
class Vocoder(nn.Module):
|
||||
"""Latent → waveform decoder (44.1 kHz mono).
|
||||
|
||||
Submodule namespace matches ONNX keys ``tts.X.Y`` exactly.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.tts = _TtsContainer()
|
||||
|
||||
def __call__(self, latent: mx.array) -> mx.array:
|
||||
# latent: (B, 144, T_lat)
|
||||
B = latent.shape[0]
|
||||
T_lat = latent.shape[2]
|
||||
|
||||
# /= scale (scalar)
|
||||
x = latent / self.tts.ttl.normalizer.scale
|
||||
|
||||
# reshape (B, 144, T_lat) → (B, 24, T_lat*6)
|
||||
x = x.reshape(B, VOC_LDIM, VOC_CHUNK_COMPRESS, T_lat) # (B, 24, 6, T_lat)
|
||||
x = x.transpose(0, 1, 3, 2) # (B, 24, T_lat, 6)
|
||||
x = x.reshape(B, VOC_LDIM, T_lat * VOC_CHUNK_COMPRESS) # (B, 24, T_lat*6)
|
||||
|
||||
# De-normalise: (* std + mean)
|
||||
x = x * self.tts.ae.latent_std + self.tts.ae.latent_mean
|
||||
|
||||
# Transpose to NTC for Conv1d layers
|
||||
x = x.transpose(0, 2, 1) # (B, T_lat*6, 24)
|
||||
|
||||
# embed
|
||||
x = self.tts.ae.decoder.embed(x) # (B, T_lat*6, 512)
|
||||
|
||||
# 10× ConvNeXt
|
||||
for blk in self.tts.ae.decoder.convnext:
|
||||
x = blk(x)
|
||||
|
||||
# final_norm (BatchNorm1d eval)
|
||||
x = self.tts.ae.decoder.final_norm(x)
|
||||
|
||||
# head
|
||||
x = self.tts.ae.decoder.head(x) # (B, T_lat*6, 512)
|
||||
|
||||
# Flatten time × channels row-major → waveform (matches ONNX:
|
||||
# head.layer2 Conv (B, 512, T_lat*6) → Transpose to (B, T_lat*6, 512) →
|
||||
# Reshape to (B, T_lat*6*512). Since the head already runs in NTC, we
|
||||
# are already in the post-Transpose layout and only the Reshape remains).
|
||||
wav = x.reshape(B, -1) # (B, T_lat*6*512)
|
||||
return wav
|
||||
|
||||
|
||||
__all__ = ["Vocoder", "VOC_DIM", "VOC_HIDDEN", "VOC_LDIM", "VOC_CHUNK_COMPRESS"]
|
||||
152
src/supertonic_3_mlx/weights.py
Normal file
152
src/supertonic_3_mlx/weights.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""ONNX → MLX safetensors conversion for Supertonic 3.
|
||||
|
||||
Two-stage extraction:
|
||||
1. **Named initializers** (e.g. ``vector_estimator.tts.ttl.vector_field.main_blocks.0.convnext.0.dwconv.weight``)
|
||||
— straight name strip + optional shape transformation.
|
||||
2. **Anonymous MatMul weights** (e.g. ``onnx::MatMul_3391``) — looked up via the
|
||||
MatMul node graph: each MatMul output path is the human-readable name of the
|
||||
weight (e.g. ``…/W_query/linear/MatMul_output_0``); we trace the second
|
||||
operand initializer and rebind it to the structured name + transpose to
|
||||
the MLX Linear layout ``(out, in)``.
|
||||
|
||||
Shape transformations:
|
||||
- depthwise dwconv: ONNX ``(C, 1, K)`` → MLX ``(C, K, 1)``
|
||||
- pwconv1/2 k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
|
||||
- proj_in/out k=1: ONNX ``(out, in, 1)`` → MLX ``(out, in)``
|
||||
- MatMul Linear: ONNX ``(in, out)`` → MLX ``(out, in)``
|
||||
- gamma: ONNX ``(1, dim, 1)`` → MLX ``(dim,)``
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
|
||||
|
||||
_ONNX_PREFIX = "vector_estimator.tts.ttl."
|
||||
|
||||
_DWCONV_SUFFIX = ".dwconv.weight"
|
||||
_PWCONV_SUFFIXES = (".pwconv1.weight", ".pwconv2.weight")
|
||||
_GAMMA_SUFFIX = ".gamma"
|
||||
|
||||
|
||||
def _strip_prefix(name: str) -> str:
|
||||
if name.startswith(_ONNX_PREFIX):
|
||||
return name[len(_ONNX_PREFIX):]
|
||||
return name
|
||||
|
||||
|
||||
def _is_named_weight(name: str) -> bool:
|
||||
"""True if this is a structured weight (vs anonymous graph constant)."""
|
||||
if name.startswith(_ONNX_PREFIX):
|
||||
return True
|
||||
if name.startswith("uncond_masker."):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _convert_named(clean_name: str, arr: np.ndarray) -> np.ndarray:
|
||||
"""Apply shape transforms to a named initializer based on its key."""
|
||||
# Depthwise Conv1d weight: (C, 1, K) → (C, K, 1)
|
||||
if clean_name.endswith(_DWCONV_SUFFIX) and arr.ndim == 3 and arr.shape[1] == 1 and arr.shape[2] != 1:
|
||||
arr = np.transpose(arr, (0, 2, 1))
|
||||
|
||||
# Pointwise k=1 / proj net weight: (out, in, 1) → (out, in)
|
||||
if (any(clean_name.endswith(s) for s in _PWCONV_SUFFIXES) or clean_name.endswith(".net.weight")) \
|
||||
and arr.ndim == 3 and arr.shape[-1] == 1:
|
||||
arr = arr.squeeze(-1)
|
||||
|
||||
# gamma: (1, C, 1) → (C,)
|
||||
if clean_name.endswith(_GAMMA_SUFFIX) and arr.ndim == 3 and arr.shape[0] == 1 and arr.shape[2] == 1:
|
||||
arr = arr.reshape(arr.shape[1])
|
||||
|
||||
return arr
|
||||
|
||||
|
||||
def _matmul_output_to_clean_name(matmul_output: str) -> str:
|
||||
"""Map a MatMul node output path to the structured ``.weight`` key.
|
||||
|
||||
Example::
|
||||
|
||||
/vector_estimator/vector_field/main_blocks.3/attn/W_query/linear/MatMul_output_0
|
||||
→ vector_field.main_blocks.3.attn.W_query.linear.weight
|
||||
"""
|
||||
# Strip prefix slash and the trailing /MatMul_output_0
|
||||
path = matmul_output.lstrip("/")
|
||||
if path.endswith("/MatMul_output_0"):
|
||||
path = path[: -len("/MatMul_output_0")]
|
||||
# Drop leading "vector_estimator/" if present
|
||||
if path.startswith("vector_estimator/"):
|
||||
path = path[len("vector_estimator/"):]
|
||||
return path.replace("/", ".") + ".weight"
|
||||
|
||||
|
||||
def convert_onnx_to_mlx(onnx_path: str | Path) -> Dict[str, mx.array]:
|
||||
"""Load an ONNX model and return all weights as ``{clean_name: mx.array}``.
|
||||
|
||||
Combines named initializers and MatMul-only weights into a single dict ready
|
||||
for ``model.load_weights(...)``.
|
||||
"""
|
||||
import onnx
|
||||
from onnx import numpy_helper
|
||||
|
||||
model = onnx.load(str(onnx_path))
|
||||
|
||||
# Build initializer name → numpy array map (in-memory once)
|
||||
inits: Dict[str, np.ndarray] = {
|
||||
init.name: numpy_helper.to_array(init) for init in model.graph.initializer
|
||||
}
|
||||
|
||||
out: Dict[str, mx.array] = {}
|
||||
|
||||
# Stage 1: named initializers
|
||||
for name, arr in inits.items():
|
||||
if not _is_named_weight(name):
|
||||
continue
|
||||
clean = _strip_prefix(name)
|
||||
arr = _convert_named(clean, arr)
|
||||
out[clean] = mx.array(arr)
|
||||
|
||||
# Stage 2: anonymous MatMul weights, recovered via the graph
|
||||
for node in model.graph.node:
|
||||
if node.op_type != "MatMul":
|
||||
continue
|
||||
if len(node.input) < 2:
|
||||
continue
|
||||
# The weight is conventionally the second operand
|
||||
weight_name = node.input[1]
|
||||
if weight_name not in inits:
|
||||
continue
|
||||
# Skip if it's already named structurally (shouldn't happen here)
|
||||
if _is_named_weight(weight_name):
|
||||
continue
|
||||
# Look up the structured name from the MatMul output path
|
||||
if len(node.output) < 1:
|
||||
continue
|
||||
clean = _matmul_output_to_clean_name(node.output[0])
|
||||
# ONNX MatMul stores W as (in, out); MLX Linear expects (out, in)
|
||||
arr = inits[weight_name]
|
||||
if arr.ndim == 2:
|
||||
arr = arr.T
|
||||
out[clean] = mx.array(arr)
|
||||
|
||||
if not out:
|
||||
raise RuntimeError(f"no weights extracted from {onnx_path}")
|
||||
return out
|
||||
|
||||
|
||||
def save_safetensors(
|
||||
onnx_path: str | Path,
|
||||
output_path: str | Path,
|
||||
) -> Dict[str, Tuple[int, ...]]:
|
||||
"""Convert an ONNX file to MLX safetensors. Returns a {name: shape} map."""
|
||||
weights = convert_onnx_to_mlx(onnx_path)
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
mx.save_safetensors(str(output_path), weights)
|
||||
return {k: tuple(v.shape) for k, v in weights.items()}
|
||||
|
||||
|
||||
__all__ = ["convert_onnx_to_mlx", "save_safetensors"]
|
||||
1
unicode_indexer.json
Normal file
1
unicode_indexer.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/F1.json
Normal file
1
voice_styles/F1.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/F2.json
Normal file
1
voice_styles/F2.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/F3.json
Normal file
1
voice_styles/F3.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/F4.json
Normal file
1
voice_styles/F4.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/F5.json
Normal file
1
voice_styles/F5.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/M1.json
Normal file
1
voice_styles/M1.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/M2.json
Normal file
1
voice_styles/M2.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/M3.json
Normal file
1
voice_styles/M3.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/M4.json
Normal file
1
voice_styles/M4.json
Normal file
File diff suppressed because one or more lines are too long
1
voice_styles/M5.json
Normal file
1
voice_styles/M5.json
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user