更新
This commit is contained in:
@@ -65,7 +65,7 @@ def load_model():
|
||||
start = time.time()
|
||||
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR))
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR), fp16=True)
|
||||
|
||||
_model_loaded = True
|
||||
print(f"✅ CosyVoice 3.0 model loaded in {time.time() - start:.1f}s")
|
||||
|
||||
159
models/MuseTalk/LICENSE
Normal file
159
models/MuseTalk/LICENSE
Normal file
@@ -0,0 +1,159 @@
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Tencent Music Entertainment Group
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
Other dependencies and licenses:
|
||||
|
||||
|
||||
Open Source Software Licensed under the MIT License:
|
||||
--------------------------------------------------------------------
|
||||
1. sd-vae-ft-mse
|
||||
Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main
|
||||
License:MIT license
|
||||
For details:https://choosealicense.com/licenses/mit/
|
||||
|
||||
2. whisper
|
||||
Files:https://github.com/openai/whisper
|
||||
License:MIT license
|
||||
Copyright (c) 2022 OpenAI
|
||||
For details:https://github.com/openai/whisper/blob/main/LICENSE
|
||||
|
||||
3. face-parsing.PyTorch
|
||||
Files:https://github.com/zllrunning/face-parsing.PyTorch
|
||||
License:MIT License
|
||||
Copyright (c) 2019 zll
|
||||
For details:https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE
|
||||
|
||||
|
||||
|
||||
Open Source Software Licensed under the Apache License Version 2.0:
|
||||
--------------------------------------------------------------------
|
||||
1. DWpose
|
||||
Files:https://huggingface.co/yzd-v/DWPose/tree/main
|
||||
License:Apache-2.0
|
||||
For details:https://choosealicense.com/licenses/apache-2.0/
|
||||
|
||||
|
||||
Terms of the Apache License Version 2.0:
|
||||
--------------------------------------------------------------------
|
||||
Apache License
|
||||
|
||||
Version 2.0, January 2004
|
||||
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
|
||||
|
||||
Open Source Software Licensed under the BSD 3-Clause License:
|
||||
--------------------------------------------------------------------
|
||||
1. face-alignment
|
||||
Files:https://github.com/1adrianb/face-alignment/tree/master
|
||||
License:BSD 3-Clause License
|
||||
Copyright (c) 2017, Adrian Bulat
|
||||
All rights reserved.
|
||||
For details:https://github.com/1adrianb/face-alignment/blob/master/LICENSE
|
||||
|
||||
|
||||
Terms of the BSD 3-Clause License:
|
||||
--------------------------------------------------------------------
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Open Source Software:
|
||||
--------------------------------------------------------------------
|
||||
1.s3FD
|
||||
Files:https://github.com/yxlijun/S3FD.pytorch
|
||||
556
models/MuseTalk/README.md
Normal file
556
models/MuseTalk/README.md
Normal file
@@ -0,0 +1,556 @@
|
||||
# MuseTalk
|
||||
|
||||
> **ViGent2 集成说明**
|
||||
>
|
||||
> 本目录为 MuseTalk v1.5 的部署副本,作为混合唇形同步方案的长视频引擎。
|
||||
>
|
||||
> - **服务**: `scripts/server.py` — FastAPI 常驻推理服务 (端口 8011, GPU0)
|
||||
> - **PM2**: `vigent2-musetalk` (启动脚本 `run_musetalk.sh`)
|
||||
> - **路由**: 音频 >=120s 自动路由到 MuseTalk, <120s 走 LatentSync
|
||||
> - **部署文档**: [`Docs/MUSETALK_DEPLOY.md`](../../Docs/MUSETALK_DEPLOY.md)
|
||||
> - **修改记录**: `scripts/inference.py` 增强 FFmpeg 调用 + CLI 参数; `musetalk/utils/audio_processor.py` 音视频长度不匹配时零填充
|
||||
|
||||
---
|
||||
|
||||
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
|
||||
|
||||
Yue Zhang<sup>\*</sup>,
|
||||
Zhizhou Zhong<sup>\*</sup>,
|
||||
Minhao Liu<sup>\*</sup>,
|
||||
Zhaokang Chen,
|
||||
Bin Wu<sup>†</sup>,
|
||||
Yubin Zeng,
|
||||
Chao Zhan,
|
||||
Junxin Huang,
|
||||
Yingjie He,
|
||||
Wenjiang Zhou
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
|
||||
|
||||
Lyra Lab, Tencent Music Entertainment
|
||||
|
||||
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
|
||||
|
||||
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
|
||||
|
||||
## 🔥 Updates
|
||||
We're excited to unveil MuseTalk 1.5.
|
||||
This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy.
|
||||
Learn more details [here](https://arxiv.org/abs/2410.10122).
|
||||
**The inference codes, training codes and model weights of MuseTalk 1.5 are all available now!** 🚀
|
||||
|
||||
# Overview
|
||||
`MuseTalk` is a real-time high quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
|
||||
|
||||
1. modifies an unseen face according to the input audio, with a size of face region of `256 x 256`.
|
||||
1. supports audio in various languages, such as Chinese, English, and Japanese.
|
||||
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
|
||||
1. supports modification of the center point of the face region proposes, which **SIGNIFICANTLY** affects generation results.
|
||||
1. checkpoint available trained on the HDTF and private dataset.
|
||||
|
||||
# News
|
||||
- [04/05/2025] :mega: We are excited to announce that the training code is now open-sourced! You can now train your own MuseTalk model using our provided training scripts and configurations.
|
||||
- [03/28/2025] We are thrilled to announce the release of our 1.5 version. This version is a significant improvement over the 1.0 version, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We update the [technical report](https://arxiv.org/abs/2410.10122) with more details.
|
||||
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
|
||||
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/02/2024] Release MuseTalk project and pretrained models.
|
||||
|
||||
|
||||
## Model
|
||||

|
||||
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
|
||||
|
||||
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
|
||||
|
||||
## Cases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33%">
|
||||
|
||||
### Input Video
|
||||
---
|
||||
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.0
|
||||
---
|
||||
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.5
|
||||
---
|
||||
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
# TODO:
|
||||
- [x] trained models and inference codes.
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
|
||||
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
|
||||
- [x] realtime inference code for 1.5 version.
|
||||
- [x] training and data preprocessing codes.
|
||||
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
|
||||
|
||||
|
||||
# Getting Started
|
||||
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
|
||||
|
||||
## Third party integration
|
||||
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
||||
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
||||
|
||||
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
|
||||
|
||||
## Installation
|
||||
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
|
||||
|
||||
### Build environment
|
||||
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
|
||||
|
||||
```shell
|
||||
conda create -n MuseTalk python==3.10
|
||||
conda activate MuseTalk
|
||||
```
|
||||
|
||||
### Install PyTorch 2.0.1
|
||||
Choose one of the following installation methods:
|
||||
|
||||
```shell
|
||||
# Option 1: Using pip
|
||||
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
# Option 2: Using conda
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
```
|
||||
|
||||
### Install Dependencies
|
||||
Install the remaining required packages:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Install MMLab Packages
|
||||
Install the MMLab ecosystem packages:
|
||||
|
||||
```bash
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv==2.0.1"
|
||||
mim install "mmdet==3.1.0"
|
||||
mim install "mmpose==1.1.0"
|
||||
```
|
||||
|
||||
### Setup FFmpeg
|
||||
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
|
||||
|
||||
2. Configure FFmpeg based on your operating system:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
export FFMPEG_PATH=/path/to/ffmpeg
|
||||
# Example:
|
||||
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
|
||||
```
|
||||
|
||||
For Windows:
|
||||
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
|
||||
|
||||
### Download weights
|
||||
You can download weights in two ways:
|
||||
|
||||
#### Option 1: Using Download Scripts
|
||||
We provide two scripts for automatic downloading:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
sh ./download_weights.sh
|
||||
```
|
||||
|
||||
For Windows:
|
||||
```batch
|
||||
# Run the script
|
||||
download_weights.bat
|
||||
```
|
||||
|
||||
#### Option 2: Manual Download
|
||||
You can also download the weights manually from the following links:
|
||||
|
||||
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
|
||||
2. Download the weights of other components:
|
||||
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
|
||||
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
|
||||
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
|
||||
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
|
||||
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
|
||||
|
||||
Finally, these weights should be organized in `models` as follows:
|
||||
```
|
||||
./models/
|
||||
├── musetalk
|
||||
│ └── musetalk.json
|
||||
│ └── pytorch_model.bin
|
||||
├── musetalkV15
|
||||
│ └── musetalk.json
|
||||
│ └── unet.pth
|
||||
├── syncnet
|
||||
│ └── latentsync_syncnet.pt
|
||||
├── dwpose
|
||||
│ └── dw-ll_ucoco_384.pth
|
||||
├── face-parse-bisent
|
||||
│ ├── 79999_iter.pth
|
||||
│ └── resnet18-5c106cde.pth
|
||||
├── sd-vae
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── whisper
|
||||
├── config.json
|
||||
├── pytorch_model.bin
|
||||
└── preprocessor_config.json
|
||||
|
||||
```
|
||||
## Quickstart
|
||||
|
||||
### Inference
|
||||
We provide inference scripts for both versions of MuseTalk:
|
||||
|
||||
#### Prerequisites
|
||||
Before running inference, please ensure ffmpeg is installed and accessible:
|
||||
```bash
|
||||
# Check ffmpeg installation
|
||||
ffmpeg -version
|
||||
```
|
||||
If ffmpeg is not found, please install it first:
|
||||
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
|
||||
- Linux: `sudo apt-get install ffmpeg`
|
||||
|
||||
#### Normal Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 normal
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 normal
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
|
||||
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
|
||||
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
#### Real-time Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 realtime
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 realtime
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
|
||||
- `video_path`: Path to the input video, image file, or directory of images
|
||||
- `audio_path`: Path to the input audio file
|
||||
|
||||
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
|
||||
|
||||
Important notes for real-time inference:
|
||||
1. Set `preparation` to `True` when processing a new avatar
|
||||
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
|
||||
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
|
||||
4. Set `preparation` to `False` for generating more videos with the same avatar
|
||||
|
||||
For faster generation without saving images, you can use:
|
||||
```bash
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
## Gradio Demo
|
||||
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
|
||||

|
||||
For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. 
|
||||
|
||||
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
|
||||
|
||||
```bash
|
||||
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
|
||||
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Data Preparation
|
||||
To train MuseTalk, you need to prepare your dataset following these steps:
|
||||
|
||||
1. **Place your source videos**
|
||||
|
||||
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
|
||||
|
||||
2. **Run the preprocessing script**
|
||||
```bash
|
||||
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
|
||||
```
|
||||
This script will:
|
||||
- Extract frames from videos
|
||||
- Detect and align faces
|
||||
- Generate audio features
|
||||
- Create the necessary data structure for training
|
||||
|
||||
### Training Process
|
||||
After data preprocessing, you can start the training process:
|
||||
|
||||
1. **First Stage**
|
||||
```bash
|
||||
sh train.sh stage1
|
||||
```
|
||||
|
||||
2. **Second Stage**
|
||||
```bash
|
||||
sh train.sh stage2
|
||||
```
|
||||
|
||||
### Configuration Adjustment
|
||||
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
|
||||
|
||||
1. **GPU Configuration** (`configs/training/gpu.yaml`):
|
||||
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
|
||||
- `num_processes`: Set this to match the number of GPUs you're using
|
||||
|
||||
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
|
||||
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
|
||||
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
|
||||
|
||||
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
|
||||
- `random_init_unet`: Must be set to `False` to use the model from stage 1
|
||||
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
|
||||
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
|
||||
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
|
||||
|
||||
|
||||
### GPU Memory Requirements
|
||||
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
|
||||
|
||||
#### Stage 1 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 8 | 1 | ~32GB | |
|
||||
| 16 | 1 | ~45GB | |
|
||||
| 32 | 1 | ~74GB | ✓ |
|
||||
|
||||
#### Stage 2 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 1 | 8 | ~54GB | |
|
||||
| 2 | 2 | ~80GB | |
|
||||
| 2 | 8 | ~85GB | ✓ |
|
||||
|
||||
<details close>
|
||||
## TestCases For 1.0
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="33%">Image</td>
|
||||
<td width="33%">MuseV</td>
|
||||
<td width="33%">+MuseTalk</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/musk/musk.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4a4bb2d1-9d14-4ca9-85c8-7f19c39f712e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/b2a879c2-e23a-4d39-911d-51f0343218e4 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/yongen/yongen.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/57ef9dee-a9fd-4dc8-839b-3fbbbf0ff3f4 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/94d8dcba-1bcd-4b54-9d1d-8b6fc53228f0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sit/sit.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/5fbab81b-d3f2-4c75-abb5-14c76e51769e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/f8100f4a-3df8-4151-8de2-291b09269f66 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/man/man.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a6e7d431-5643-4745-9868-8b423a454153 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/6ccf7bc7-cb48-42de-85bd-076d5ee8a623 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/monalisa/monalisa.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1568f604-a34f-4526-a13a-7d282aa2e773 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a40784fc-a885-4c1f-9b7e-8f87b7caf4e0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun1/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/172f4ff1-d432-45bd-a5a7-a07dec33a26b controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun2/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/85a6873d-a028-4cce-af2b-6c59a1f2971d controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
</table >
|
||||
|
||||
#### Use of bbox_shift to have adjustable results(For 1.0)
|
||||
:mag_right: We have found that upper-bound of the mask has an important impact on mouth openness. Thus, to control the mask region, we suggest using the `bbox_shift` parameter. Positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease mouth openness.
|
||||
|
||||
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script within this range.
|
||||
|
||||
For example, in the case of `Xinying Sun`, after running the default configuration, it shows that the adjustable value rage is [-9, 9]. Then, to decrease the mouth openness, we set the value to be `-7`.
|
||||
```
|
||||
python -m scripts.inference --inference_config configs/inference/test.yaml --bbox_shift -7
|
||||
```
|
||||
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
|
||||
|
||||
|
||||
#### Combining MuseV and MuseTalk
|
||||
|
||||
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
|
||||
|
||||
# Acknowledgement
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
|
||||
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
|
||||
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
|
||||
|
||||
Thanks for open-sourcing!
|
||||
|
||||
# Limitations
|
||||
- Resolution: Though MuseTalk uses a face region size of 256 x 256, which make it better than other open-source methods, it has not yet reached the theoretical resolution bound. We will continue to deal with this problem.
|
||||
If you need higher resolution, you could apply super resolution models such as [GFPGAN](https://github.com/TencentARC/GFPGAN) in combination with MuseTalk.
|
||||
|
||||
- Identity preservation: Some details of the original face are not well preserved, such as mustache, lip shape and color.
|
||||
|
||||
- Jitter: There exists some jitter as the current pipeline adopts single-frame generation.
|
||||
|
||||
# Citation
|
||||
```bib
|
||||
@article{musetalk,
|
||||
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
|
||||
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
|
||||
journal={arxiv},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
# Disclaimer/License
|
||||
1. `code`: The code of MuseTalk is released under the MIT License. There is no limitation for both academic and commercial usage.
|
||||
1. `model`: The trained model are available for any purpose, even commercially.
|
||||
1. `other opensource model`: Other open-source models used must comply with their license, such as `whisper`, `ft-mse-vae`, `dwpose`, `S3FD`, etc..
|
||||
1. The testdata are collected from internet, which are available for non-commercial research purposes only.
|
||||
1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
|
||||
570
models/MuseTalk/app.py
Normal file
570
models/MuseTalk/app.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import os
|
||||
import time
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
import requests
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
from argparse import Namespace
|
||||
import shutil
|
||||
import gdown
|
||||
import imageio
|
||||
import ffmpeg
|
||||
from moviepy.editor import *
|
||||
from transformers import WhisperModel
|
||||
|
||||
ProjectDir = os.path.abspath(os.path.dirname(__file__))
|
||||
CheckpointsDir = os.path.join(ProjectDir, "models")
|
||||
|
||||
@torch.no_grad()
|
||||
def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90):
|
||||
"""Debug inpainting parameters, only process the first frame"""
|
||||
# Set default parameters
|
||||
args_dict = {
|
||||
"result_dir": './results/debug',
|
||||
"fps": 25,
|
||||
"batch_size": 1,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15",
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Create debug directory
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
|
||||
# Read first frame
|
||||
if get_file_type(video_path) == "video":
|
||||
reader = imageio.get_reader(video_path)
|
||||
first_frame = reader.get_data(0)
|
||||
reader.close()
|
||||
else:
|
||||
first_frame = cv2.imread(video_path)
|
||||
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Save first frame
|
||||
debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
|
||||
cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# Get face coordinates
|
||||
coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
|
||||
bbox = coord_list[0]
|
||||
frame = frame_list[0]
|
||||
|
||||
if bbox == coord_placeholder:
|
||||
return None, "No face detected, please adjust bbox_shift parameter"
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
# Process first frame
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
|
||||
# Generate random audio features
|
||||
random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
|
||||
audio_feature = pe(random_audio)
|
||||
|
||||
# Get latents
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
latents = latents.to(dtype=weight_dtype)
|
||||
|
||||
# Generate prediction results
|
||||
pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
|
||||
# Inpaint back to original image
|
||||
res_frame = recon[0]
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
# Save results (no need to convert color space again since get_image already returns RGB format)
|
||||
debug_result_path = os.path.join(args.result_dir, "debug_result.png")
|
||||
cv2.imwrite(debug_result_path, combine_frame)
|
||||
|
||||
# Create information text
|
||||
info_text = f"Parameter information:\n" + \
|
||||
f"bbox_shift: {bbox_shift}\n" + \
|
||||
f"extra_margin: {extra_margin}\n" + \
|
||||
f"parsing_mode: {parsing_mode}\n" + \
|
||||
f"left_cheek_width: {left_cheek_width}\n" + \
|
||||
f"right_cheek_width: {right_cheek_width}\n" + \
|
||||
f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
|
||||
|
||||
return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
|
||||
|
||||
def print_directory_contents(path):
|
||||
for child in os.listdir(path):
|
||||
child_path = os.path.join(path, child)
|
||||
if os.path.isdir(child_path):
|
||||
print(child_path)
|
||||
|
||||
def download_model():
|
||||
# 检查必需的模型文件是否存在
|
||||
required_models = {
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
|
||||
"SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
|
||||
"Whisper": f"{CheckpointsDir}/whisper/config.json",
|
||||
"DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
|
||||
"SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
|
||||
"Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
|
||||
"ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
|
||||
}
|
||||
|
||||
missing_models = []
|
||||
for model_name, model_path in required_models.items():
|
||||
if not os.path.exists(model_path):
|
||||
missing_models.append(model_name)
|
||||
|
||||
if missing_models:
|
||||
# 全用英文
|
||||
print("The following required model files are missing:")
|
||||
for model in missing_models:
|
||||
print(f"- {model}")
|
||||
print("\nPlease run the download script to download the missing models:")
|
||||
if sys.platform == "win32":
|
||||
print("Windows: Run download_weights.bat")
|
||||
else:
|
||||
print("Linux/Mac: Run ./download_weights.sh")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("All required model files exist.")
|
||||
|
||||
|
||||
|
||||
|
||||
download_model() # for huggingface deployment.
|
||||
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
|
||||
# Set default parameters, aligned with inference.py
|
||||
args_dict = {
|
||||
"result_dir": './results/output',
|
||||
"fps": 25,
|
||||
"batch_size": 8,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15", # Fixed use v15 version
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Check ffmpeg
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
# Create temporary directory
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save path
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
os.makedirs(result_img_save_path, exist_ok=True)
|
||||
|
||||
if args.output_vid_name == "":
|
||||
output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path) == "video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
os.makedirs(save_dir_full, exist_ok=True)
|
||||
# Read video
|
||||
reader = imageio.get_reader(video_path)
|
||||
|
||||
# Save images
|
||||
for i, im in enumerate(reader):
|
||||
imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
else: # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
|
||||
############################################## extract audio feature ##############################################
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
# Ensure latent_batch is consistent with model weight type
|
||||
latent_batch = latent_batch.to(dtype=weight_dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
# Use v15 version blending
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
# Frame rate
|
||||
fps = 25
|
||||
# Output video path
|
||||
output_video = 'temp.mp4'
|
||||
|
||||
# Read images
|
||||
def is_valid_image(file):
|
||||
pattern = re.compile(r'\d{8}\.png')
|
||||
return pattern.match(file)
|
||||
|
||||
images = []
|
||||
files = [file for file in os.listdir(result_img_save_path) if is_valid_image(file)]
|
||||
files.sort(key=lambda x: int(x.split('.')[0]))
|
||||
|
||||
for file in files:
|
||||
filename = os.path.join(result_img_save_path, file)
|
||||
images.append(imageio.imread(filename))
|
||||
|
||||
|
||||
# Save video
|
||||
imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
|
||||
|
||||
input_video = './temp.mp4'
|
||||
# Check if the input_video and audio_path exist
|
||||
if not os.path.exists(input_video):
|
||||
raise FileNotFoundError(f"Input video file not found: {input_video}")
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
||||
|
||||
# Read video
|
||||
reader = imageio.get_reader(input_video)
|
||||
fps = reader.get_meta_data()['fps'] # Get original video frame rate
|
||||
reader.close() # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
|
||||
# Store frames in list
|
||||
frames = images
|
||||
|
||||
print(len(frames))
|
||||
|
||||
# Load the video
|
||||
video_clip = VideoFileClip(input_video)
|
||||
|
||||
# Load the audio
|
||||
audio_clip = AudioFileClip(audio_path)
|
||||
|
||||
# Set the audio to the video
|
||||
video_clip = video_clip.set_audio(audio_clip)
|
||||
|
||||
# Write the output video
|
||||
video_clip.write_videofile(output_vid_name, codec='libx264', audio_codec='aac',fps=25)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
#shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
return output_vid_name,bbox_shift_text
|
||||
|
||||
|
||||
|
||||
# load model weights
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path="./models/musetalkV15/unet.pth",
|
||||
vae_type="sd-vae",
|
||||
unet_config="./models/musetalkV15/musetalk.json",
|
||||
device=device
|
||||
)
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
|
||||
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
|
||||
parser.add_argument("--share", action="store_true", help="Create a public link")
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set data type
|
||||
if args.use_float16:
|
||||
# Convert models to half precision for better performance
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
weight_dtype = torch.float16
|
||||
else:
|
||||
weight_dtype = torch.float32
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
|
||||
whisper = WhisperModel.from_pretrained("./models/whisper")
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
|
||||
def check_video(video):
|
||||
if not isinstance(video, str):
|
||||
return video # in case of none type
|
||||
# Define the output video file name
|
||||
dir_path, file_name = os.path.split(video)
|
||||
if file_name.startswith("outputxxx_"):
|
||||
return video
|
||||
# Add the output prefix to the file name
|
||||
output_file_name = "outputxxx_" + file_name
|
||||
|
||||
os.makedirs('./results',exist_ok=True)
|
||||
os.makedirs('./results/output',exist_ok=True)
|
||||
os.makedirs('./results/input',exist_ok=True)
|
||||
|
||||
# Combine the directory path and the new file name
|
||||
output_video = os.path.join('./results/input', output_file_name)
|
||||
|
||||
|
||||
# read video
|
||||
reader = imageio.get_reader(video)
|
||||
fps = reader.get_meta_data()['fps'] # get fps from original video
|
||||
|
||||
# conver fps to 25
|
||||
frames = [im for im in reader]
|
||||
target_fps = 25
|
||||
|
||||
L = len(frames)
|
||||
L_target = int(L / fps * target_fps)
|
||||
original_t = [x / fps for x in range(1, L+1)]
|
||||
t_idx = 0
|
||||
target_frames = []
|
||||
for target_t in range(1, L_target+1):
|
||||
while target_t / target_fps > original_t[t_idx]:
|
||||
t_idx += 1 # find the first t_idx so that target_t / target_fps <= original_t[t_idx]
|
||||
if t_idx >= L:
|
||||
break
|
||||
target_frames.append(frames[t_idx])
|
||||
|
||||
# save video
|
||||
imageio.mimwrite(output_video, target_frames, 'FFMPEG', fps=25, codec='libx264', quality=9, pixelformat='yuv420p')
|
||||
return output_video
|
||||
|
||||
|
||||
|
||||
|
||||
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.Markdown(
|
||||
"""<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
|
||||
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
|
||||
</br>\
|
||||
Yue Zhang <sup>*</sup>,\
|
||||
Zhizhou Zhong <sup>*</sup>,\
|
||||
Minhao Liu<sup>*</sup>,\
|
||||
Zhaokang Chen,\
|
||||
Bin Wu<sup>†</sup>,\
|
||||
Yubin Zeng,\
|
||||
Chao Zhang,\
|
||||
Yingjie He,\
|
||||
Junxin Huang,\
|
||||
Wenjiang Zhou <br>\
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
|
||||
Lyra Lab, Tencent Music Entertainment\
|
||||
</h2> \
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio = gr.Audio(label="Drving Audio",type="filepath")
|
||||
video = gr.Video(label="Reference Video",sources=['upload'])
|
||||
bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
|
||||
extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
|
||||
parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
|
||||
left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' parameters determine the range of left and right cheeks editing when parsing model is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
|
||||
|
||||
with gr.Row():
|
||||
debug_btn = gr.Button("1. Test Inpainting ")
|
||||
btn = gr.Button("2. Generate")
|
||||
with gr.Column():
|
||||
debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
|
||||
debug_info = gr.Textbox(label="Parameter Information", lines=5)
|
||||
out1 = gr.Video()
|
||||
|
||||
video.change(
|
||||
fn=check_video, inputs=[video], outputs=[video]
|
||||
)
|
||||
btn.click(
|
||||
fn=inference,
|
||||
inputs=[
|
||||
audio,
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[out1,bbox_shift_scale]
|
||||
)
|
||||
debug_btn.click(
|
||||
fn=debug_inpainting,
|
||||
inputs=[
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[debug_image, debug_info]
|
||||
)
|
||||
|
||||
# Check ffmpeg and add to PATH
|
||||
if not fast_check_ffmpeg():
|
||||
print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
|
||||
# According to operating system, choose path separator
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Solve asynchronous IO issues on Windows
|
||||
if sys.platform == 'win32':
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
# Start Gradio application
|
||||
demo.queue().launch(
|
||||
share=args.share,
|
||||
debug=True,
|
||||
server_name=args.ip,
|
||||
server_port=args.port
|
||||
)
|
||||
10
models/MuseTalk/configs/inference/realtime.yaml
Normal file
10
models/MuseTalk/configs/inference/realtime.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
avator_1:
|
||||
preparation: True # your can set it to False if you want to use the existing avator, it will save time
|
||||
bbox_shift: 5
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_clips:
|
||||
audio_0: "data/audio/yongen.wav"
|
||||
audio_1: "data/audio/eng.wav"
|
||||
|
||||
|
||||
|
||||
10
models/MuseTalk/configs/inference/test.yaml
Normal file
10
models/MuseTalk/configs/inference/test.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
task_0:
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/yongen.wav"
|
||||
|
||||
task_1:
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/eng.wav"
|
||||
bbox_shift: -7
|
||||
|
||||
|
||||
21
models/MuseTalk/configs/training/gpu.yaml
Normal file
21
models/MuseTalk/configs/training/gpu.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: True
|
||||
deepspeed_config:
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: False
|
||||
zero_stage: 2
|
||||
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: "5, 7" # modify this according to your GPU number
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
num_machines: 1
|
||||
num_processes: 2 # it should be the same as the number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
31
models/MuseTalk/configs/training/preprocess.yaml
Normal file
31
models/MuseTalk/configs/training/preprocess.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
clip_len_second: 30 # the length of the video clip
|
||||
video_root_raw: "./dataset/HDTF/source/" # the path of the original video
|
||||
val_list_hdtf:
|
||||
- RD_Radio7_000
|
||||
- RD_Radio8_000
|
||||
- RD_Radio9_000
|
||||
- WDA_TinaSmith_000
|
||||
- WDA_TomCarper_000
|
||||
- WDA_TomPerez_000
|
||||
- WDA_TomUdall_000
|
||||
- WDA_VeronicaEscobar0_000
|
||||
- WDA_VeronicaEscobar1_000
|
||||
- WDA_WhipJimClyburn_000
|
||||
- WDA_XavierBecerra_000
|
||||
- WDA_XavierBecerra_001
|
||||
- WDA_XavierBecerra_002
|
||||
- WDA_ZoeLofgren_000
|
||||
- WRA_SteveScalise1_000
|
||||
- WRA_TimScott_000
|
||||
- WRA_ToddYoung_000
|
||||
- WRA_TomCotton_000
|
||||
- WRA_TomPrice_000
|
||||
- WRA_VickyHartzler_000
|
||||
|
||||
# following dir will be automatically generated
|
||||
video_root_25fps: "./dataset/HDTF/video_root_25fps/"
|
||||
video_file_list: "./dataset/HDTF/video_file_list.txt"
|
||||
video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
|
||||
meta_root: "./dataset/HDTF/meta/"
|
||||
video_clip_file_list_train: "./dataset/HDTF/train.txt"
|
||||
video_clip_file_list_val: "./dataset/HDTF/val.txt"
|
||||
89
models/MuseTalk/configs/training/stage1.yaml
Normal file
89
models/MuseTalk/configs/training/stage1.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage1/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 1 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 150 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0 # Weight for sync loss
|
||||
mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 2.0e-5 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 10000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
89
models/MuseTalk/configs/training/stage2.yaml
Normal file
89
models/MuseTalk/configs/training/stage2.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage2/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "dynamic_margin_crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 16 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 200 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0.01 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0.05 # Weight for sync loss
|
||||
mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 5.0e-6 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 2000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
19
models/MuseTalk/configs/training/syncnet.yaml
Normal file
19
models/MuseTalk/configs/training/syncnet.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
||||
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (48, 128, 256)
|
||||
in_channels: 48
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
save_ckpt_steps: 2500
|
||||
41
models/MuseTalk/download_weights.bat
Normal file
41
models/MuseTalk/download_weights.bat
Normal file
@@ -0,0 +1,41 @@
|
||||
@echo off
|
||||
setlocal
|
||||
|
||||
:: Set the checkpoints directory
|
||||
set CheckpointsDir=models
|
||||
|
||||
:: Create necessary directories
|
||||
mkdir %CheckpointsDir%\musetalk
|
||||
mkdir %CheckpointsDir%\musetalkV15
|
||||
mkdir %CheckpointsDir%\syncnet
|
||||
mkdir %CheckpointsDir%\dwpose
|
||||
mkdir %CheckpointsDir%\face-parse-bisent
|
||||
mkdir %CheckpointsDir%\sd-vae-ft-mse
|
||||
mkdir %CheckpointsDir%\whisper
|
||||
|
||||
:: Install required packages
|
||||
pip install -U "huggingface_hub[hf_xet]"
|
||||
|
||||
:: Set HuggingFace endpoint
|
||||
set HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
:: Download MuseTalk weights
|
||||
hf download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
|
||||
|
||||
:: Download SD VAE weights
|
||||
hf download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
:: Download Whisper weights
|
||||
hf download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
:: Download DWPose weights
|
||||
hf download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
|
||||
|
||||
:: Download SyncNet weights
|
||||
hf download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
|
||||
|
||||
:: Download face-parse-bisent weights
|
||||
hf download ManyOtherFunctions/face-parse-bisent --local-dir %CheckpointsDir%\face-parse-bisent --include "79999_iter.pth" "resnet18-5c106cde.pth"
|
||||
|
||||
echo All weights have been downloaded successfully!
|
||||
endlocal
|
||||
51
models/MuseTalk/download_weights.sh
Normal file
51
models/MuseTalk/download_weights.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set the checkpoints directory
|
||||
CheckpointsDir="models"
|
||||
|
||||
# Create necessary directories
|
||||
mkdir -p models/musetalk models/musetalkV15 models/syncnet models/dwpose models/face-parse-bisent models/sd-vae models/whisper
|
||||
|
||||
# Install required packages
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
pip install gdown
|
||||
|
||||
# Set HuggingFace mirror endpoint
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Download MuseTalk V1.0 weights
|
||||
huggingface-cli download TMElyralab/MuseTalk \
|
||||
--local-dir $CheckpointsDir \
|
||||
--include "musetalk/musetalk.json" "musetalk/pytorch_model.bin"
|
||||
|
||||
# Download MuseTalk V1.5 weights (unet.pth)
|
||||
huggingface-cli download TMElyralab/MuseTalk \
|
||||
--local-dir $CheckpointsDir \
|
||||
--include "musetalkV15/musetalk.json" "musetalkV15/unet.pth"
|
||||
|
||||
# Download SD VAE weights
|
||||
huggingface-cli download stabilityai/sd-vae-ft-mse \
|
||||
--local-dir $CheckpointsDir/sd-vae \
|
||||
--include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
# Download Whisper weights
|
||||
huggingface-cli download openai/whisper-tiny \
|
||||
--local-dir $CheckpointsDir/whisper \
|
||||
--include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
# Download DWPose weights
|
||||
huggingface-cli download yzd-v/DWPose \
|
||||
--local-dir $CheckpointsDir/dwpose \
|
||||
--include "dw-ll_ucoco_384.pth"
|
||||
|
||||
# Download SyncNet weights
|
||||
huggingface-cli download ByteDance/LatentSync \
|
||||
--local-dir $CheckpointsDir/syncnet \
|
||||
--include "latentsync_syncnet.pt"
|
||||
|
||||
# Download Face Parse Bisent weights
|
||||
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
|
||||
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth \
|
||||
-o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
|
||||
|
||||
echo "✅ All weights have been downloaded successfully!"
|
||||
9
models/MuseTalk/entrypoint.sh
Normal file
9
models/MuseTalk/entrypoint.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "entrypoint.sh"
|
||||
whoami
|
||||
which python
|
||||
source /opt/conda/etc/profile.d/conda.sh
|
||||
conda activate musev
|
||||
which python
|
||||
python app.py
|
||||
72
models/MuseTalk/inference.sh
Normal file
72
models/MuseTalk/inference.sh
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script runs inference based on the version and mode specified by the user.
|
||||
# Usage:
|
||||
# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
|
||||
# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
|
||||
|
||||
# Check if the correct number of arguments is provided
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <version> <mode>"
|
||||
echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the version and mode from the user input
|
||||
version=$1
|
||||
mode=$2
|
||||
|
||||
# Validate mode
|
||||
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
|
||||
echo "Invalid mode specified. Please use 'normal' or 'realtime'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set config path based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
config_path="./configs/inference/test.yaml"
|
||||
result_dir="./results/test"
|
||||
else
|
||||
config_path="./configs/inference/realtime.yaml"
|
||||
result_dir="./results/realtime"
|
||||
fi
|
||||
|
||||
# Define the model paths based on the version
|
||||
if [ "$version" = "v1.0" ]; then
|
||||
model_dir="./models/musetalk"
|
||||
unet_model_path="$model_dir/pytorch_model.bin"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v1"
|
||||
elif [ "$version" = "v1.5" ]; then
|
||||
model_dir="./models/musetalkV15"
|
||||
unet_model_path="$model_dir/unet.pth"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v15"
|
||||
else
|
||||
echo "Invalid version specified. Please use v1.0 or v1.5."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set script name based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
script_name="scripts.inference"
|
||||
else
|
||||
script_name="scripts.realtime_inference"
|
||||
fi
|
||||
|
||||
# Base command arguments
|
||||
cmd_args="--inference_config $config_path \
|
||||
--result_dir $result_dir \
|
||||
--unet_model_path $unet_model_path \
|
||||
--unet_config $unet_config \
|
||||
--version $version_arg"
|
||||
|
||||
# Add realtime-specific arguments if in realtime mode
|
||||
if [ "$mode" = "realtime" ]; then
|
||||
cmd_args="$cmd_args \
|
||||
--fps 25 \
|
||||
--version $version_arg"
|
||||
fi
|
||||
|
||||
# Run inference
|
||||
python3 -m $script_name $cmd_args
|
||||
168
models/MuseTalk/musetalk/data/audio.py
Normal file
168
models/MuseTalk/musetalk/data/audio.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
|
||||
class HParams:
|
||||
# copy from wav2lip
|
||||
def __init__(self):
|
||||
self.n_fft = 800
|
||||
self.hop_size = 200
|
||||
self.win_size = 800
|
||||
self.sample_rate = 16000
|
||||
self.frame_shift_ms = None
|
||||
self.signal_normalization = True
|
||||
|
||||
self.allow_clipping_in_normalization = True
|
||||
self.symmetric_mels = True
|
||||
self.max_abs_value = 4.0
|
||||
self.preemphasize = True
|
||||
self.preemphasis = 0.97
|
||||
self.min_level_db = -100
|
||||
self.ref_level_db = 20
|
||||
self.fmin = 55
|
||||
self.fmax=7600
|
||||
|
||||
self.use_lws=False
|
||||
self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
|
||||
self.rescale=True # Whether to rescale audio prior to preprocessing
|
||||
self.rescaling_max=0.9 # Rescaling value
|
||||
self.use_lws=False
|
||||
|
||||
|
||||
hp = HParams()
|
||||
|
||||
def load_wav(path, sr):
|
||||
return librosa.core.load(path, sr=sr)[0]
|
||||
#def load_wav(path, sr):
|
||||
# audio, sr_native = sf.read(path)
|
||||
# if sr != sr_native:
|
||||
# audio = librosa.resample(audio.T, sr_native, sr).T
|
||||
# return audio
|
||||
|
||||
def save_wav(wav, path, sr):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
#proposed by @dsmiller
|
||||
wavfile.write(path, sr, wav.astype(np.int16))
|
||||
|
||||
def save_wavenet_wav(wav, path, sr):
|
||||
librosa.output.write_wav(path, wav, sr=sr)
|
||||
|
||||
def preemphasis(wav, k, preemphasize=True):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
def get_hop_size():
|
||||
hop_size = hp.hop_size
|
||||
if hop_size is None:
|
||||
assert hp.frame_shift_ms is not None
|
||||
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
||||
return hop_size
|
||||
|
||||
def linearspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def melspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def _lws_processor():
|
||||
import lws
|
||||
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
||||
|
||||
def _stft(y):
|
||||
if hp.use_lws:
|
||||
return _lws_processor(hp).stft(y).T
|
||||
else:
|
||||
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
||||
|
||||
##########################################################
|
||||
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
||||
def num_frames(length, fsize, fshift):
|
||||
"""Compute number of time frames of spectrogram
|
||||
"""
|
||||
pad = (fsize - fshift)
|
||||
if length % fshift == 0:
|
||||
M = (length + pad * 2 - fsize) // fshift + 1
|
||||
else:
|
||||
M = (length + pad * 2 - fsize) // fshift + 2
|
||||
return M
|
||||
|
||||
|
||||
def pad_lr(x, fsize, fshift):
|
||||
"""Compute left and right padding
|
||||
"""
|
||||
M = num_frames(len(x), fsize, fshift)
|
||||
pad = (fsize - fshift)
|
||||
T = len(x) + 2 * pad
|
||||
r = (M - 1) * fshift + fsize - T
|
||||
return pad, pad + r
|
||||
##########################################################
|
||||
#Librosa correct padding
|
||||
def librosa_pad_lr(x, fsize, fshift):
|
||||
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
||||
|
||||
# Conversions
|
||||
_mel_basis = None
|
||||
|
||||
def _linear_to_mel(spectogram):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis()
|
||||
return np.dot(_mel_basis, spectogram)
|
||||
|
||||
def _build_mel_basis():
|
||||
assert hp.fmax <= hp.sample_rate // 2
|
||||
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
||||
fmin=hp.fmin, fmax=hp.fmax)
|
||||
|
||||
def _amp_to_db(x):
|
||||
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
||||
return 20 * np.log10(np.maximum(min_level, x))
|
||||
|
||||
def _db_to_amp(x):
|
||||
return np.power(10.0, (x) * 0.05)
|
||||
|
||||
def _normalize(S):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
||||
-hp.max_abs_value, hp.max_abs_value)
|
||||
else:
|
||||
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
||||
|
||||
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
||||
if hp.symmetric_mels:
|
||||
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
||||
else:
|
||||
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
||||
|
||||
def _denormalize(D):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return (((np.clip(D, -hp.max_abs_value,
|
||||
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
||||
+ hp.min_level_db)
|
||||
else:
|
||||
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
|
||||
if hp.symmetric_mels:
|
||||
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
||||
else:
|
||||
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
610
models/MuseTalk/musetalk/data/dataset.py
Normal file
610
models/MuseTalk/musetalk/data/dataset.py
Normal file
@@ -0,0 +1,610 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import AutoFeatureExtractor
|
||||
import librosa
|
||||
import time
|
||||
import json
|
||||
import math
|
||||
from decord import AudioReader, VideoReader
|
||||
from decord.ndarray import cpu
|
||||
|
||||
from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
|
||||
from musetalk.data import audio
|
||||
from musetalk.utils.audio_utils import ensure_wav
|
||||
|
||||
syncnet_mel_step_size = math.ceil(16 / 5 * 16) # latentsync
|
||||
|
||||
|
||||
class FaceDataset(Dataset):
|
||||
"""Dataset class for loading and processing video data
|
||||
|
||||
Each video can be represented as:
|
||||
- Concatenated frame images
|
||||
- '.mp4' or '.gif' files
|
||||
- Folder containing all frames
|
||||
"""
|
||||
def __init__(self,
|
||||
cfg,
|
||||
list_paths,
|
||||
root_path='./dataset/',
|
||||
repeats=None):
|
||||
# Initialize dataset paths
|
||||
meta_paths = []
|
||||
if repeats is None:
|
||||
repeats = [1] * len(list_paths)
|
||||
assert len(repeats) == len(list_paths)
|
||||
|
||||
# Load data list
|
||||
for list_path, repeat_time in zip(list_paths, repeats):
|
||||
with open(list_path, 'r') as f:
|
||||
num = 0
|
||||
f.readline() # Skip header line
|
||||
for line in f.readlines():
|
||||
line_info = line.strip()
|
||||
meta = line_info.split()
|
||||
meta = meta[0]
|
||||
meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
|
||||
num += 1
|
||||
print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
|
||||
|
||||
# Set basic attributes
|
||||
self.meta_paths = meta_paths
|
||||
self.root_path = root_path
|
||||
self.image_size = cfg['image_size']
|
||||
self.min_face_size = cfg['min_face_size']
|
||||
self.T = cfg['T']
|
||||
self.sample_method = cfg['sample_method']
|
||||
self.top_k_ratio = cfg['top_k_ratio']
|
||||
self.max_attempts = 200
|
||||
self.padding_pixel_mouth = cfg['padding_pixel_mouth']
|
||||
|
||||
# Cropping related parameters
|
||||
self.crop_type = cfg['crop_type']
|
||||
self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
|
||||
self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
|
||||
self.random_margin_method = cfg['random_margin_method']
|
||||
|
||||
# Image transformations
|
||||
self.to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
self.pose_to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
# Feature extractor
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
|
||||
self.contorl_face_min_size = cfg["contorl_face_min_size"]
|
||||
|
||||
print("The sample method is: ", self.sample_method)
|
||||
print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
|
||||
|
||||
def generate_random_value(self):
|
||||
"""Generate random value
|
||||
|
||||
Returns:
|
||||
float: Generated random value
|
||||
"""
|
||||
if self.random_margin_method == "uniform":
|
||||
random_value = np.random.uniform(
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std
|
||||
)
|
||||
elif self.random_margin_method == "normal":
|
||||
random_value = np.random.normal(
|
||||
loc=self.jaw2edge_margin_mean,
|
||||
scale=self.jaw2edge_margin_std
|
||||
)
|
||||
random_value = np.clip(
|
||||
random_value,
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
|
||||
return max(0, random_value)
|
||||
|
||||
def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
|
||||
"""Dynamically crop image with dynamic margin
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
original_bbox: Original bounding box
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (x1, y1, x2, y2, extra_margin)
|
||||
"""
|
||||
if extra_margin is None:
|
||||
extra_margin = self.generate_random_value()
|
||||
w, h = img.size
|
||||
x1, y1, x2, y2 = original_bbox
|
||||
y2 = min(y2 + int(extra_margin), h)
|
||||
return x1, y1, x2, y2, extra_margin
|
||||
|
||||
def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
|
||||
"""Crop and resize image
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
bbox: Bounding box
|
||||
crop_type: Type of cropping
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (Processed image, extra_margin, mask_scaled_factor)
|
||||
"""
|
||||
mask_scaled_factor = 1.
|
||||
if crop_type == 'crop_resize':
|
||||
x1, y1, x2, y2 = bbox
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'dynamic_margin_crop_resize':
|
||||
x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
|
||||
w_original, _ = img.size
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
w_cropped, _ = img.size
|
||||
mask_scaled_factor = w_cropped / w_original
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'resize':
|
||||
w, h = img.size
|
||||
scale = np.sqrt(self.image_size ** 2 / (h * w))
|
||||
new_w = int(w * scale) / 64 * 64
|
||||
new_h = int(h * scale) / 64 * 64
|
||||
img = img.resize((new_w, new_h), Image.LANCZOS)
|
||||
return img, extra_margin, mask_scaled_factor
|
||||
|
||||
def get_audio_file(self, wav_path, start_index):
|
||||
"""Get audio file features
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Audio features, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
wav_path_converted = ensure_wav(wav_path)
|
||||
audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
while start_index >= 25 * 30:
|
||||
audio_input = audio_input_librosa[16000*30:]
|
||||
start_index -= 25 * 30
|
||||
if start_index + 2 * 25 >= 25 * 30:
|
||||
start_index -= 4 * 25
|
||||
audio_input = audio_input_librosa[16000*4:16000*34]
|
||||
else:
|
||||
audio_input = audio_input_librosa[:16000*30]
|
||||
|
||||
assert 2 * (start_index) >= 0
|
||||
assert 2 * (start_index + 2 * 25) <= 1500
|
||||
|
||||
audio_input = self.feature_extractor(
|
||||
audio_input,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
return audio_input, start_index
|
||||
|
||||
def get_audio_file_mel(self, wav_path, start_index):
|
||||
"""Get mel spectrogram of audio file
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Mel spectrogram, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
|
||||
wav_path_converted = ensure_wav(wav_path)
|
||||
audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
audio_mel = self.mel_feature_extractor(audio_input_librosa)
|
||||
return audio_mel, start_index
|
||||
|
||||
def mel_feature_extractor(self, audio_input):
|
||||
"""Extract mel spectrogram features
|
||||
|
||||
Args:
|
||||
audio_input: Input audio
|
||||
|
||||
Returns:
|
||||
ndarray: Mel spectrogram features
|
||||
"""
|
||||
orig_mel = audio.melspectrogram(audio_input)
|
||||
return orig_mel.T
|
||||
|
||||
def crop_audio_window(self, spec, start_frame_num, fps=25):
|
||||
"""Crop audio window
|
||||
|
||||
Args:
|
||||
spec: Spectrogram
|
||||
start_frame_num: Starting frame number
|
||||
fps: Frames per second
|
||||
|
||||
Returns:
|
||||
ndarray: Cropped spectrogram
|
||||
"""
|
||||
start_idx = int(80. * (start_frame_num / float(fps)))
|
||||
end_idx = start_idx + syncnet_mel_step_size
|
||||
return spec[start_idx: end_idx, :]
|
||||
|
||||
def get_syncnet_input(self, video_path):
|
||||
"""Get SyncNet input features
|
||||
|
||||
Args:
|
||||
video_path: Video file path
|
||||
|
||||
Returns:
|
||||
ndarray: SyncNet input features
|
||||
"""
|
||||
ar = AudioReader(video_path, sample_rate=16000)
|
||||
original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
|
||||
return original_mel.T
|
||||
|
||||
def get_resized_mouth_mask(
|
||||
self,
|
||||
img_resized,
|
||||
landmark_array,
|
||||
face_shape,
|
||||
padding_pixel_mouth=0,
|
||||
image_size=256,
|
||||
crop_margin=0
|
||||
):
|
||||
landmark_array = np.array(landmark_array)
|
||||
resized_landmark = resize_landmark(
|
||||
landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
|
||||
|
||||
landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
|
||||
min_x, min_y = np.min(landmark_array, axis=0)
|
||||
max_x, max_y = np.max(landmark_array, axis=0)
|
||||
min_x = min_x - padding_pixel_mouth
|
||||
max_x = max_x + padding_pixel_mouth
|
||||
|
||||
# Calculate x-axis length and use it for y-axis
|
||||
width = max_x - min_x
|
||||
|
||||
# Calculate old center point
|
||||
center_y = (max_y + min_y) / 2
|
||||
|
||||
# Determine new min_y and max_y based on width
|
||||
min_y = center_y - width / 4
|
||||
max_y = center_y + width / 4
|
||||
|
||||
# Adjust mask position for dynamic crop, shift y-axis
|
||||
min_y = min_y - crop_margin
|
||||
max_y = max_y - crop_margin
|
||||
|
||||
# Prevent out of bounds
|
||||
min_x = max(min_x, 0)
|
||||
min_y = max(min_y, 0)
|
||||
max_x = min(max_x, face_shape[0])
|
||||
max_y = min(max_y, face_shape[1])
|
||||
|
||||
mask = np.zeros_like(np.array(img_resized))
|
||||
mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
|
||||
return Image.fromarray(mask)
|
||||
|
||||
def __len__(self):
|
||||
return 100000
|
||||
|
||||
def __getitem__(self, idx):
|
||||
attempts = 0
|
||||
while attempts < self.max_attempts:
|
||||
try:
|
||||
meta_path = random.sample(self.meta_paths, k=1)[0]
|
||||
with open(meta_path, 'r') as f:
|
||||
meta_data = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"meta file error:{meta_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
video_path = meta_data["mp4_path"]
|
||||
wav_path = meta_data["wav_path"]
|
||||
bbox_list = meta_data["face_list"]
|
||||
landmark_list = meta_data["landmark_list"]
|
||||
T = self.T
|
||||
|
||||
s = 0
|
||||
e = meta_data["frames"]
|
||||
len_valid_clip = e - s
|
||||
|
||||
if len_valid_clip < T * 10:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has less than {T * 10} frames")
|
||||
continue
|
||||
|
||||
try:
|
||||
cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
|
||||
total_frames = len(cap)
|
||||
assert total_frames == len(landmark_list)
|
||||
assert total_frames == len(bbox_list)
|
||||
landmark_shape = np.array(landmark_list).shape
|
||||
if landmark_shape != (total_frames, 68, 2):
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"video file error:{video_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
|
||||
landmark_list,
|
||||
bbox_list
|
||||
)
|
||||
if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
|
||||
print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
step = 1
|
||||
drive_idx_start = random.randint(s, e - T * step)
|
||||
drive_idx_list = list(
|
||||
range(drive_idx_start, drive_idx_start + T * step, step))
|
||||
assert len(drive_idx_list) == T
|
||||
|
||||
src_idx_list = []
|
||||
list_index_out_of_range = False
|
||||
for drive_idx in drive_idx_list:
|
||||
src_idx = get_src_idx(
|
||||
drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
|
||||
if src_idx is None:
|
||||
list_index_out_of_range = True
|
||||
break
|
||||
src_idx = min(src_idx, e - 1)
|
||||
src_idx = max(src_idx, s)
|
||||
src_idx_list.append(src_idx)
|
||||
|
||||
if list_index_out_of_range:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid source index for drive frames")
|
||||
continue
|
||||
|
||||
ref_face_valid_flag = True
|
||||
extra_margin = self.generate_random_value()
|
||||
|
||||
# Get reference images
|
||||
ref_imgs = []
|
||||
for src_idx in src_idx_list:
|
||||
imSrc = Image.fromarray(cap[src_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[src_idx]
|
||||
imSrc, _, _ = self.crop_resize_img(
|
||||
imSrc,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=None
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
|
||||
ref_face_valid_flag = False
|
||||
break
|
||||
ref_imgs.append(imSrc)
|
||||
|
||||
if not ref_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Get target images and masks
|
||||
imSameIDs = []
|
||||
bboxes = []
|
||||
face_masks = []
|
||||
face_mask_valid = True
|
||||
target_face_valid_flag = True
|
||||
|
||||
for drive_idx in drive_idx_list:
|
||||
imSameID = Image.fromarray(cap[drive_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[drive_idx]
|
||||
imSameID, _ , mask_scaled_factor = self.crop_resize_img(
|
||||
imSameID,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=extra_margin
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
|
||||
target_face_valid_flag = False
|
||||
break
|
||||
crop_margin = extra_margin * mask_scaled_factor
|
||||
face_mask = self.get_resized_mouth_mask(
|
||||
imSameID,
|
||||
shift_landmarks[drive_idx],
|
||||
face_shapes[drive_idx],
|
||||
self.padding_pixel_mouth,
|
||||
self.image_size,
|
||||
crop_margin=crop_margin
|
||||
)
|
||||
if np.count_nonzero(face_mask) == 0:
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
if face_mask.size[1] == 0 or face_mask.size[0] == 0:
|
||||
print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
imSameIDs.append(imSameID)
|
||||
bboxes.append(bbox_s)
|
||||
face_masks.append(face_mask)
|
||||
|
||||
if not face_mask_valid:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid face mask")
|
||||
continue
|
||||
|
||||
if not target_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Process audio features
|
||||
audio_offset = drive_idx_list[0]
|
||||
audio_step = step
|
||||
fps = 25.0 / step
|
||||
|
||||
try:
|
||||
audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
|
||||
_, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
|
||||
audio_feature_mel = self.get_syncnet_input(video_path)
|
||||
except Exception as e:
|
||||
print(f"audio file error:{wav_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
mel = self.crop_audio_window(audio_feature_mel, audio_offset)
|
||||
if mel.shape[0] != syncnet_mel_step_size:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
|
||||
continue
|
||||
|
||||
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
||||
|
||||
# Build sample dictionary
|
||||
sample = dict(
|
||||
pixel_values_vid=torch.stack(
|
||||
[self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
|
||||
pixel_values_ref_img=torch.stack(
|
||||
[self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
|
||||
pixel_values_face_mask=torch.stack(
|
||||
[self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
|
||||
audio_feature=audio_feature[0],
|
||||
audio_offset=audio_offset,
|
||||
audio_step=audio_step,
|
||||
mel=mel,
|
||||
wav_path=wav_path,
|
||||
fps=fps,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
raise ValueError("Unable to find a valid sample after maximum attempts.")
|
||||
|
||||
class HDTFDataset(FaceDataset):
|
||||
"""HDTF dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/HDTF/meta'
|
||||
list_paths = [
|
||||
'./dataset/HDTF/train.txt',
|
||||
]
|
||||
|
||||
|
||||
repeats = [10]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('HDTFDataset: ', len(self))
|
||||
|
||||
class VFHQDataset(FaceDataset):
|
||||
"""VFHQ dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/VFHQ/meta'
|
||||
list_paths = [
|
||||
'./dataset/VFHQ/train.txt',
|
||||
]
|
||||
repeats = [1]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('VFHQDataset: ', len(self))
|
||||
|
||||
def PortraitDataset(cfg=None):
|
||||
"""Return dataset based on configuration
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Dataset: Combined dataset
|
||||
"""
|
||||
if cfg["dataset_key"] == "HDTF":
|
||||
return ConcatDataset([HDTFDataset(cfg)])
|
||||
elif cfg["dataset_key"] == "VFHQ":
|
||||
return ConcatDataset([VFHQDataset(cfg)])
|
||||
else:
|
||||
print("############ use all dataset ############ ")
|
||||
return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Set random seeds for reproducibility
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Create dataset with configuration parameters
|
||||
dataset = PortraitDataset(cfg={
|
||||
'T': 1, # Number of frames to process at once
|
||||
'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
|
||||
'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
|
||||
'image_size': 256, # Size of processed images (height and width)
|
||||
'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
|
||||
'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
|
||||
'contorl_face_min_size': True, # Whether to enforce minimum face size
|
||||
'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
|
||||
'min_face_size': 200, # Minimum face size requirement for dataset
|
||||
'whisper_path': "./models/whisper", # Path to Whisper model
|
||||
'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
|
||||
'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
|
||||
'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
|
||||
})
|
||||
print(len(dataset))
|
||||
|
||||
import torchvision
|
||||
os.makedirs('debug', exist_ok=True)
|
||||
for i in range(10): # Check 10 samples
|
||||
sample = dataset[0]
|
||||
print(f"processing {i}")
|
||||
|
||||
# Get images and mask
|
||||
ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
|
||||
target_img = (sample['pixel_values_vid'] + 1.0) / 2
|
||||
face_mask = sample['pixel_values_face_mask']
|
||||
|
||||
# Print dimension information
|
||||
print(f"ref_img shape: {ref_img.shape}")
|
||||
print(f"target_img shape: {target_img.shape}")
|
||||
print(f"face_mask shape: {face_mask.shape}")
|
||||
|
||||
# Create visualization images
|
||||
b, c, h, w = ref_img.shape
|
||||
|
||||
# Apply mask only to target image
|
||||
target_mask = face_mask
|
||||
|
||||
# Keep reference image unchanged
|
||||
ref_with_mask = ref_img.clone()
|
||||
|
||||
# Create mask overlay for target image
|
||||
target_with_mask = target_img.clone()
|
||||
target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
|
||||
|
||||
# Save original images, mask, and overlay results
|
||||
# First row: original images
|
||||
# Second row: mask
|
||||
# Third row: overlay effect
|
||||
concatenated_img = torch.cat((
|
||||
ref_img, target_img, # Original images
|
||||
torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
|
||||
ref_with_mask, target_with_mask # Overlay effect
|
||||
), dim=3)
|
||||
|
||||
torchvision.utils.save_image(
|
||||
concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
|
||||
233
models/MuseTalk/musetalk/data/sample_method.py
Normal file
233
models/MuseTalk/musetalk/data/sample_method.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def summarize_tensor(x):
|
||||
return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
|
||||
|
||||
def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
|
||||
num_landmarks = len(landmarks_list)
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
print(np.shape(landmarks_list))
|
||||
## Calculate mouth opening ratios
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
mouth_top = landmarks[165] # Adjust index according to your landmarks format
|
||||
mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Calculate differences matrix
|
||||
differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
|
||||
differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
|
||||
print(differences_matrix.shape)
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(differences_matrix[i])[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-differences_matrix[i])[:top_k]
|
||||
similar_landmarks_indices = top_indices.tolist()
|
||||
similar_landmarks_distances = differences_matrix_with_signs[i].tolist() #注意这里不要排序
|
||||
|
||||
return similar_landmarks_indices, similar_landmarks_distances
|
||||
#############################################################################################
|
||||
def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
|
||||
num_landmarks = len(landmarks_list)
|
||||
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
## Calculate mouth opening ratios
|
||||
#print("landmarks shape",np.shape(landmarks_list))
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
#print(landmarks[165])
|
||||
mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
|
||||
mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(mouth_open_ratios)[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-mouth_open_ratios)[:top_k]
|
||||
return top_indices
|
||||
|
||||
def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
|
||||
"""
|
||||
Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
|
||||
|
||||
Parameters:
|
||||
landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
|
||||
image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
|
||||
start_index (int): The starting index of the facial landmarks.
|
||||
end_index (int): The ending index of the facial landmarks.
|
||||
top_k (int): The number of most similar landmark sets to return. Default is 50.
|
||||
ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
|
||||
|
||||
Returns:
|
||||
similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
|
||||
resized_landmarks (list): A list containing the resized facial landmarks.
|
||||
"""
|
||||
num_landmarks = len(landmarks_list)
|
||||
resized_landmarks = []
|
||||
|
||||
# Preprocess landmarks
|
||||
for i in range(num_landmarks):
|
||||
landmark_array = np.array(landmarks_list[i])
|
||||
selected_landmarks = landmark_array[start_index:end_index]
|
||||
resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
|
||||
resized_landmarks.append(resized_landmark)
|
||||
|
||||
resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
|
||||
|
||||
# Calculate similarity
|
||||
distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
|
||||
overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
|
||||
|
||||
if ascending:
|
||||
sorted_indices = np.argsort(overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
|
||||
else:
|
||||
sorted_indices = np.argsort(-overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[0:top_k].tolist()
|
||||
|
||||
return similar_landmarks_indices
|
||||
|
||||
def process_bbox_musetalk(face_array, landmark_array):
|
||||
x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
|
||||
x_min_lm = min([int(x) for x, y in landmark_array])
|
||||
y_min_lm = min([int(y) for x, y in landmark_array])
|
||||
x_max_lm = max([int(x) for x, y in landmark_array])
|
||||
y_max_lm = max([int(y) for x, y in landmark_array])
|
||||
x_min = min(x_min_face, x_min_lm)
|
||||
y_min = min(y_min_face, y_min_lm)
|
||||
x_max = max(x_max_face, x_max_lm)
|
||||
y_max = max(y_max_face, y_max_lm)
|
||||
|
||||
x_min = max(x_min, 0)
|
||||
y_min = max(y_min, 0)
|
||||
|
||||
return [x_min, y_min, x_max, y_max]
|
||||
|
||||
def shift_landmarks_to_face_coordinates(landmark_list, face_list):
|
||||
"""
|
||||
Translates the data in landmark_list to the coordinates of the cropped larger face.
|
||||
|
||||
Parameters:
|
||||
landmark_list (list): A list containing multiple sets of facial landmarks.
|
||||
face_list (list): A list containing multiple facial images.
|
||||
|
||||
Returns:
|
||||
landmark_list_shift (list): The list of translated landmarks.
|
||||
bbox_union (list): The list of union bounding boxes.
|
||||
face_shapes (list): The list of facial shapes.
|
||||
"""
|
||||
landmark_list_shift = []
|
||||
bbox_union = []
|
||||
face_shapes = []
|
||||
|
||||
for i in range(len(face_list)):
|
||||
landmark_array = np.array(landmark_list[i]) # 转换为numpy数组并创建副本
|
||||
face_array = face_list[i]
|
||||
f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
|
||||
x_min, y_min, x_max, y_max = f_landmark_bbox
|
||||
landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
|
||||
landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
|
||||
landmark_list_shift.append(landmark_array)
|
||||
bbox_union.append(f_landmark_bbox)
|
||||
face_shapes.append((x_max - x_min, y_max - y_min))
|
||||
|
||||
return landmark_list_shift, bbox_union, face_shapes
|
||||
|
||||
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||
landmark_norm = landmark / [w, h]
|
||||
landmark_resized = landmark_norm * [new_w, new_h]
|
||||
|
||||
return landmark_resized
|
||||
|
||||
def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
|
||||
"""
|
||||
Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.
|
||||
|
||||
Parameters:
|
||||
- drive_idx (int): The current drive index.
|
||||
- T (int): Total number of frames or a specific range limit.
|
||||
- sample_method (str): Sampling method, which can be "random" or other methods.
|
||||
- landmarks_list (list): List of facial landmarks.
|
||||
- image_shapes (list): List of image shapes.
|
||||
- top_k_ratio (float): Ratio for selecting top k similar frames.
|
||||
|
||||
Returns:
|
||||
- src_idx (int): The calculated source index.
|
||||
"""
|
||||
if sample_method == "random":
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
elif sample_method == "pose_similarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
while abs(src_idx-drive_idx)<5:
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
elif sample_method=="pose_similarity_and_closed_mouth":
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
|
||||
#print("closed_mouth_list",closed_mouth_list)
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
#print("pose_similarity_list",pose_similarity_list)
|
||||
common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
elif sample_method=="pose_similarity_and_mouth_dissimilarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
|
||||
# facial contour for 68 landmarks format
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
|
||||
# Mouth inner coutour for 68 landmarks format
|
||||
landmark_start_idx = 60
|
||||
landmark_end_idx = 67
|
||||
|
||||
mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
|
||||
|
||||
common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown sample_method: {sample_method}")
|
||||
return src_idx
|
||||
81
models/MuseTalk/musetalk/loss/basic_loss.py
Normal file
81
models/MuseTalk/musetalk/loss/basic_loss.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolate, self).__init__()
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, input):
|
||||
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
||||
|
||||
def set_requires_grad(net, requires_grad=False):
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
pyramid_scale = [1, 0.5, 0.25, 0.125]
|
||||
vgg_IN = vgg_face.Vgg19().to(device)
|
||||
pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
|
||||
|
||||
image = torch.rand(8, 3, 256, 256).to(device)
|
||||
image_pred = torch.rand(8, 3, 256, 256).to(device)
|
||||
pyramide_real = pyramid(downsampler(image))
|
||||
pyramide_generated = pyramid(downsampler(image_pred))
|
||||
|
||||
|
||||
loss_IN = 0
|
||||
for scale in cfg.loss_params.pyramid_scale:
|
||||
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
||||
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
loss_IN += weight * value
|
||||
loss_IN /= sum(cfg.loss_params.vgg_layer_weight) # 对vgg不同层取均值,金字塔loss是每层叠
|
||||
print(loss_IN)
|
||||
|
||||
#print(cfg.model_params.discriminator_params)
|
||||
|
||||
discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
|
||||
discriminator_full = DiscriminatorFullModel(discriminator)
|
||||
disc_scales = cfg.model_params.discriminator_params.scales
|
||||
# Prepare optimizer and loss function
|
||||
optimizer_D = optim.AdamW(discriminator.parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
scheduler_D = CosineAnnealingLR(optimizer_D,
|
||||
T_max=cfg.discriminator_train_params.epochs,
|
||||
eta_min=1e-6)
|
||||
|
||||
discriminator.train()
|
||||
|
||||
set_requires_grad(discriminator, False)
|
||||
|
||||
loss_G = 0.
|
||||
discriminator_maps_generated = discriminator(pyramide_generated)
|
||||
discriminator_maps_real = discriminator(pyramide_real)
|
||||
|
||||
for scale in disc_scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
|
||||
print(loss_G)
|
||||
44
models/MuseTalk/musetalk/loss/conv.py
Normal file
44
models/MuseTalk/musetalk/loss/conv.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
return self.act(out)
|
||||
|
||||
class nonorm_Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
)
|
||||
self.act = nn.LeakyReLU(0.01, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
|
||||
class Conv2dTranspose(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
145
models/MuseTalk/musetalk/loss/discriminator.py
Normal file
145
models/MuseTalk/musetalk/loss/discriminator.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from musetalk.loss.vgg_face import ImagePyramide
|
||||
|
||||
class DownBlock2d(nn.Module):
|
||||
"""
|
||||
Simple block for processing video (encoder).
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
||||
super(DownBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
||||
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
if norm:
|
||||
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
||||
else:
|
||||
self.norm = None
|
||||
self.pool = pool
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
out = self.conv(out)
|
||||
if self.norm:
|
||||
out = self.norm(out)
|
||||
out = F.leaky_relu(out, 0.2)
|
||||
if self.pool:
|
||||
out = F.avg_pool2d(out, (2, 2))
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""
|
||||
Discriminator similar to Pix2Pix
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
||||
sn=False, **kwargs):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_blocks):
|
||||
down_blocks.append(
|
||||
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
def forward(self, x):
|
||||
feature_maps = []
|
||||
out = x
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
feature_maps.append(down_block(out))
|
||||
out = feature_maps[-1]
|
||||
prediction_map = self.conv(out)
|
||||
|
||||
return feature_maps, prediction_map
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
"""
|
||||
Multi-scale (scale) discriminator
|
||||
"""
|
||||
|
||||
def __init__(self, scales=(), **kwargs):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.scales = scales
|
||||
discs = {}
|
||||
for scale in scales:
|
||||
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
||||
self.discs = nn.ModuleDict(discs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, disc in self.discs.items():
|
||||
scale = str(scale).replace('-', '.')
|
||||
key = 'prediction_' + scale
|
||||
#print(key)
|
||||
#print(x)
|
||||
feature_maps, prediction_map = disc(x[key])
|
||||
out_dict['feature_maps_' + scale] = feature_maps
|
||||
out_dict['prediction_map_' + scale] = prediction_map
|
||||
return out_dict
|
||||
|
||||
|
||||
|
||||
class DiscriminatorFullModel(torch.nn.Module):
|
||||
"""
|
||||
Merge all discriminator related updates into single model for better multi-gpu usage
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator):
|
||||
super(DiscriminatorFullModel, self).__init__()
|
||||
self.discriminator = discriminator
|
||||
self.scales = self.discriminator.scales
|
||||
print("scales",self.scales)
|
||||
self.pyramid = ImagePyramide(self.scales, 3)
|
||||
if torch.cuda.is_available():
|
||||
self.pyramid = self.pyramid.cuda()
|
||||
|
||||
self.zero_tensor = None
|
||||
|
||||
def get_zero_tensor(self, input):
|
||||
if self.zero_tensor is None:
|
||||
self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
|
||||
self.zero_tensor.requires_grad_(False)
|
||||
return self.zero_tensor.expand_as(input)
|
||||
|
||||
def forward(self, x, generated, gan_mode='ls'):
|
||||
pyramide_real = self.pyramid(x)
|
||||
pyramide_generated = self.pyramid(generated.detach())
|
||||
|
||||
discriminator_maps_generated = self.discriminator(pyramide_generated)
|
||||
discriminator_maps_real = self.discriminator(pyramide_real)
|
||||
|
||||
value_total = 0
|
||||
for scale in self.scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
if gan_mode == 'hinge':
|
||||
value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
|
||||
elif gan_mode == 'ls':
|
||||
value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
|
||||
else:
|
||||
raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
|
||||
|
||||
value_total += value
|
||||
|
||||
return value_total
|
||||
|
||||
def main():
|
||||
discriminator = MultiScaleDiscriminator(scales=[1],
|
||||
block_expansion=32,
|
||||
max_features=512,
|
||||
num_blocks=4,
|
||||
sn=True,
|
||||
image_channel=3,
|
||||
estimate_jacobian=False)
|
||||
152
models/MuseTalk/musetalk/loss/resnet.py
Normal file
152
models/MuseTalk/musetalk/loss/resnet.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
__all__ = ['ResNet', 'resnet50']
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, include_top=True):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.include_top = include_top
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7, stride=1)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 255.
|
||||
x = x.flip(1)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
|
||||
if not self.include_top:
|
||||
return x
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def resnet50(**kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
return model
|
||||
95
models/MuseTalk/musetalk/loss/syncnet.py
Normal file
95
models/MuseTalk/musetalk/loss/syncnet.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .conv import Conv2d
|
||||
|
||||
logloss = nn.BCELoss(reduction="none")
|
||||
def cosine_loss(a, v, y):
|
||||
d = nn.functional.cosine_similarity(a, v)
|
||||
d = d.clamp(0,1) # cosine_similarity的取值范围是【-1,1】,BCE如果输入负数会报错RuntimeError: CUDA error: device-side assert triggered
|
||||
loss = logloss(d.unsqueeze(1), y).squeeze()
|
||||
loss = loss.mean()
|
||||
return loss, d
|
||||
|
||||
def get_sync_loss(
|
||||
audio_embed,
|
||||
gt_frames,
|
||||
pred_frames,
|
||||
syncnet,
|
||||
adapted_weight,
|
||||
frames_left_index=0,
|
||||
frames_right_index=16,
|
||||
):
|
||||
# 跟gt_frames做随机的插入交换,节省显存开销
|
||||
assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
|
||||
# 3通道图像
|
||||
frames_sync_loss = torch.cat(
|
||||
[gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
|
||||
axis=1
|
||||
)
|
||||
vision_embed = syncnet.get_image_embed(frames_sync_loss)
|
||||
y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
|
||||
loss, score = cosine_loss(audio_embed, vision_embed, y)
|
||||
return loss, score
|
||||
|
||||
class SyncNet_color(nn.Module):
|
||||
def __init__(self):
|
||||
super(SyncNet_color, self).__init__()
|
||||
|
||||
self.face_encoder = nn.Sequential(
|
||||
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
|
||||
|
||||
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
self.audio_encoder = nn.Sequential(
|
||||
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
|
||||
face_embedding = self.face_encoder(face_sequences)
|
||||
audio_embedding = self.audio_encoder(audio_sequences)
|
||||
|
||||
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
|
||||
face_embedding = face_embedding.view(face_embedding.size(0), -1)
|
||||
|
||||
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
|
||||
face_embedding = F.normalize(face_embedding, p=2, dim=1)
|
||||
|
||||
|
||||
return audio_embedding, face_embedding
|
||||
237
models/MuseTalk/musetalk/loss/vgg_face.py
Normal file
237
models/MuseTalk/musetalk/loss/vgg_face.py
Normal file
@@ -0,0 +1,237 @@
|
||||
'''
|
||||
This part of code contains a pretrained vgg_face model.
|
||||
ref link: https://github.com/prlz77/vgg-face.pytorch
|
||||
'''
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo
|
||||
import pickle
|
||||
from musetalk.loss import resnet as ResNet
|
||||
|
||||
|
||||
MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
|
||||
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
|
||||
|
||||
# It was 93.5940, 104.7624, 129.1863 before dividing by 255
|
||||
MEAN_RGB = [
|
||||
0.367035294117647,
|
||||
0.41083294117647057,
|
||||
0.5066129411764705
|
||||
]
|
||||
def load_state_dict(model, fname):
|
||||
"""
|
||||
Set parameters converted from Caffe models authors of VGGFace2 provide.
|
||||
See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
|
||||
|
||||
Arguments:
|
||||
model: model
|
||||
fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
|
||||
"""
|
||||
with open(fname, 'rb') as f:
|
||||
weights = pickle.load(f, encoding='latin1')
|
||||
|
||||
own_state = model.state_dict()
|
||||
for name, param in weights.items():
|
||||
if name in own_state:
|
||||
try:
|
||||
own_state[name].copy_(torch.from_numpy(param))
|
||||
except Exception:
|
||||
raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
|
||||
'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
|
||||
else:
|
||||
raise KeyError('unexpected key "{}" in state_dict'.format(name))
|
||||
|
||||
|
||||
def vggface2(pretrained=True):
|
||||
vggface = ResNet.resnet50(num_classes=8631, include_top=True)
|
||||
load_state_dict(vggface, VGG_FACE_PATH)
|
||||
return vggface
|
||||
|
||||
def vggface(pretrained=False, **kwargs):
|
||||
"""VGGFace model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns pre-trained model
|
||||
"""
|
||||
model = VggFace(**kwargs)
|
||||
if pretrained:
|
||||
state = torch.utils.model_zoo.load_url(MODEL_URL)
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
class VggFace(torch.nn.Module):
|
||||
def __init__(self, classes=2622):
|
||||
"""VGGFace model.
|
||||
|
||||
Face recognition network. It takes as input a Bx3x224x224
|
||||
batch of face images and gives as output a BxC score vector
|
||||
(C is the number of identities).
|
||||
Input images need to be scaled in the 0-1 range and then
|
||||
normalized with respect to the mean RGB used during training.
|
||||
|
||||
Args:
|
||||
classes (int): number of identities recognized by the
|
||||
network
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.conv1 = _ConvBlock(3, 64, 64)
|
||||
self.conv2 = _ConvBlock(64, 128, 128)
|
||||
self.conv3 = _ConvBlock(128, 256, 256, 256)
|
||||
self.conv4 = _ConvBlock(256, 512, 512, 512)
|
||||
self.conv5 = _ConvBlock(512, 512, 512, 512)
|
||||
self.dropout = torch.nn.Dropout(0.5)
|
||||
self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
|
||||
self.fc2 = torch.nn.Linear(4096, 4096)
|
||||
self.fc3 = torch.nn.Linear(4096, classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = self.conv5(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(F.relu(self.fc1(x)))
|
||||
x = self.dropout(F.relu(self.fc2(x)))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ConvBlock(torch.nn.Module):
|
||||
"""A Convolutional block."""
|
||||
|
||||
def __init__(self, *units):
|
||||
"""Create a block with len(units) - 1 convolutions.
|
||||
|
||||
convolution number i transforms the number of channels from
|
||||
units[i - 1] to units[i] channels.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.convs = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(in_, out, 3, 1, 1)
|
||||
for in_, out in zip(units[:-1], units[1:])
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
# Each convolution is followed by a ReLU, then the block is
|
||||
# concluded by a max pooling.
|
||||
for c in self.convs:
|
||||
x = F.relu(c(x))
|
||||
return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
from torchvision import models
|
||||
class Vgg19(torch.nn.Module):
|
||||
"""
|
||||
Vgg19 network for perceptual loss.
|
||||
"""
|
||||
def __init__(self, requires_grad=False):
|
||||
super(Vgg19, self).__init__()
|
||||
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(2, 7):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(7, 12):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(12, 21):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(21, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
|
||||
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
X = (X - self.mean) / self.std
|
||||
h_relu1 = self.slice1(X)
|
||||
h_relu2 = self.slice2(h_relu1)
|
||||
h_relu3 = self.slice3(h_relu2)
|
||||
h_relu4 = self.slice4(h_relu3)
|
||||
h_relu5 = self.slice5(h_relu4)
|
||||
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
||||
return out
|
||||
|
||||
|
||||
from torch import nn
|
||||
class AntiAliasInterpolation2d(nn.Module):
|
||||
"""
|
||||
Band-limited downsampling, for better preservation of the input signal.
|
||||
"""
|
||||
def __init__(self, channels, scale):
|
||||
super(AntiAliasInterpolation2d, self).__init__()
|
||||
sigma = (1 / scale - 1) / 2
|
||||
kernel_size = 2 * round(sigma * 4) + 1
|
||||
self.ka = kernel_size // 2
|
||||
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
||||
|
||||
kernel_size = [kernel_size, kernel_size]
|
||||
sigma = [sigma, sigma]
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid(
|
||||
[
|
||||
torch.arange(size, dtype=torch.float32)
|
||||
for size in kernel_size
|
||||
]
|
||||
)
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer('weight', kernel)
|
||||
self.groups = channels
|
||||
self.scale = scale
|
||||
inv_scale = 1 / scale
|
||||
self.int_inv_scale = int(inv_scale)
|
||||
|
||||
def forward(self, input):
|
||||
if self.scale == 1.0:
|
||||
return input
|
||||
|
||||
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
||||
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
||||
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ImagePyramide(torch.nn.Module):
|
||||
"""
|
||||
Create image pyramide for computing pyramide perceptual loss.
|
||||
"""
|
||||
def __init__(self, scales, num_channels):
|
||||
super(ImagePyramide, self).__init__()
|
||||
downs = {}
|
||||
for scale in scales:
|
||||
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
||||
self.downs = nn.ModuleDict(downs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, down_module in self.downs.items():
|
||||
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
||||
return out_dict
|
||||
240
models/MuseTalk/musetalk/models/syncnet.py
Normal file
240
models/MuseTalk/musetalk/models/syncnet.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from torch.nn import functional as F
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import Attention as CrossAttention, FeedForward
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SyncNet(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.audio_encoder = DownEncoder2D(
|
||||
in_channels=config["audio_encoder"]["in_channels"],
|
||||
block_out_channels=config["audio_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["audio_encoder"]["downsample_factors"],
|
||||
dropout=config["audio_encoder"]["dropout"],
|
||||
attn_blocks=config["audio_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.visual_encoder = DownEncoder2D(
|
||||
in_channels=config["visual_encoder"]["in_channels"],
|
||||
block_out_channels=config["visual_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["visual_encoder"]["downsample_factors"],
|
||||
dropout=config["visual_encoder"]["dropout"],
|
||||
attn_blocks=config["visual_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
def forward(self, image_sequences, audio_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds, audio_embeds
|
||||
|
||||
def get_image_embed(self, image_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds
|
||||
|
||||
def get_audio_embed(self, audio_sequences):
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return audio_embeds
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
act_fn: str = "silu",
|
||||
downsample_factor=2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "silu":
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
if isinstance(downsample_factor, list):
|
||||
downsample_factor = tuple(downsample_factor)
|
||||
|
||||
if downsample_factor == 1:
|
||||
self.downsample_conv = None
|
||||
else:
|
||||
self.downsample_conv = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
|
||||
)
|
||||
self.pad = (0, 1, 0, 1)
|
||||
if isinstance(downsample_factor, tuple):
|
||||
if downsample_factor[0] == 1:
|
||||
self.pad = (0, 1, 1, 1) # The padding order is from back to front
|
||||
elif downsample_factor[1] == 1:
|
||||
self.pad = (1, 1, 0, 1)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
hidden_states += input_tensor
|
||||
|
||||
if self.downsample_conv is not None:
|
||||
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
||||
hidden_states = self.downsample_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock2D(nn.Module):
|
||||
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
||||
super().__init__()
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"You have to install xformers to enable memory efficient attetion", name="xformers"
|
||||
)
|
||||
# inner_dim = dim_head * heads
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm3 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
||||
|
||||
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
||||
self.attn._use_memory_efficient_attention_xformers = True
|
||||
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
||||
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoder2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4 * 16,
|
||||
block_out_channels=[64, 128, 256, 256],
|
||||
downsample_factors=[2, 2, 2, 2],
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
attn_blocks=[1, 1, 1, 1],
|
||||
dropout: float = 0.0,
|
||||
act_fn="silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
# in
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# down
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
output_channels = block_out_channels[0]
|
||||
for i, block_out_channel in enumerate(block_out_channels):
|
||||
input_channels = output_channels
|
||||
output_channels = block_out_channel
|
||||
# is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
downsample_factor=downsample_factors[i],
|
||||
norm_num_groups=norm_num_groups,
|
||||
dropout=dropout,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if attn_blocks[i] == 1:
|
||||
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
||||
self.down_blocks.append(attention_block)
|
||||
|
||||
# out
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.act_fn_out = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.act_fn_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
51
models/MuseTalk/musetalk/models/unet.py
Normal file
51
models/MuseTalk/musetalk/models/unet.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import json
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model=384, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
b, seq_len, d_model = x.size()
|
||||
pe = self.pe[:, :seq_len, :]
|
||||
x = x + pe.to(x.device)
|
||||
return x
|
||||
|
||||
class UNet():
|
||||
def __init__(self,
|
||||
unet_config,
|
||||
model_path,
|
||||
use_float16=False,
|
||||
device=None
|
||||
):
|
||||
with open(unet_config, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
self.model = UNet2DConditionModel(**unet_config)
|
||||
self.pe = PositionalEncoding(d_model=384)
|
||||
if device != None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(weights)
|
||||
if use_float16:
|
||||
self.model = self.model.half()
|
||||
self.model.to(self.device)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unet = UNet()
|
||||
148
models/MuseTalk/musetalk/models/vae.py
Normal file
148
models/MuseTalk/musetalk/models/vae.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from diffusers import AutoencoderKL
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn.functional as F
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
class VAE():
|
||||
"""
|
||||
VAE (Variational Autoencoder) class for image processing.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
|
||||
"""
|
||||
Initialize the VAE instance.
|
||||
|
||||
:param model_path: Path to the trained model.
|
||||
:param resized_img: The size to which images are resized.
|
||||
:param use_float16: Whether to use float16 precision.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.vae = AutoencoderKL.from_pretrained(self.model_path)
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.vae.to(self.device)
|
||||
|
||||
if use_float16:
|
||||
self.vae = self.vae.half()
|
||||
self._use_float16 = True
|
||||
else:
|
||||
self._use_float16 = False
|
||||
|
||||
self.scaling_factor = self.vae.config.scaling_factor
|
||||
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
self._resized_img = resized_img
|
||||
self._mask_tensor = self.get_mask_tensor()
|
||||
|
||||
def get_mask_tensor(self):
|
||||
"""
|
||||
Creates a mask tensor for image processing.
|
||||
:return: A mask tensor.
|
||||
"""
|
||||
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
|
||||
mask_tensor[:self._resized_img//2,:] = 1
|
||||
mask_tensor[mask_tensor< 0.5] = 0
|
||||
mask_tensor[mask_tensor>= 0.5] = 1
|
||||
return mask_tensor
|
||||
|
||||
def preprocess_img(self,img_name,half_mask=False):
|
||||
"""
|
||||
Preprocess an image for the VAE.
|
||||
|
||||
:param img_name: The image file path or a list of image file paths.
|
||||
:param half_mask: Whether to apply a half mask to the image.
|
||||
:return: A preprocessed image tensor.
|
||||
"""
|
||||
window = []
|
||||
if isinstance(img_name, str):
|
||||
window_fnames = [img_name]
|
||||
for fname in window_fnames:
|
||||
img = cv2.imread(fname)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (self._resized_img, self._resized_img),
|
||||
interpolation=cv2.INTER_LANCZOS4)
|
||||
window.append(img)
|
||||
else:
|
||||
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
||||
window.append(img)
|
||||
|
||||
x = np.asarray(window) / 255.
|
||||
x = np.transpose(x, (3, 0, 1, 2))
|
||||
x = torch.squeeze(torch.FloatTensor(x))
|
||||
if half_mask:
|
||||
x = x * (self._mask_tensor>0.5)
|
||||
x = self.transform(x)
|
||||
|
||||
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
||||
x = x.to(self.vae.device)
|
||||
|
||||
return x
|
||||
|
||||
def encode_latents(self,image):
|
||||
"""
|
||||
Encode an image into latent variables.
|
||||
|
||||
:param image: The image tensor to encode.
|
||||
:return: The encoded latent variables.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
|
||||
init_latents = self.scaling_factor * init_latent_dist.sample()
|
||||
return init_latents
|
||||
|
||||
def decode_latents(self, latents):
|
||||
"""
|
||||
Decode latent variables back into an image.
|
||||
:param latents: The latent variables to decode.
|
||||
:return: A NumPy array representing the decoded image.
|
||||
"""
|
||||
latents = (1/ self.scaling_factor) * latents
|
||||
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
image = image[...,::-1] # RGB to BGR
|
||||
return image
|
||||
|
||||
def get_latents_for_unet(self,img):
|
||||
"""
|
||||
Prepare latent variables for a U-Net model.
|
||||
:param img: The image to process.
|
||||
:return: A concatenated tensor of latents for U-Net input.
|
||||
"""
|
||||
|
||||
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
||||
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
||||
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
return latent_model_input
|
||||
|
||||
if __name__ == "__main__":
|
||||
vae_mode_path = "./models/sd-vae-ft-mse/"
|
||||
vae = VAE(model_path = vae_mode_path,use_float16=False)
|
||||
img_path = "./results/sun001_crop/00000.png"
|
||||
|
||||
crop_imgs_path = "./results/sun001_crop/"
|
||||
latents_out_path = "./results/latents/"
|
||||
if not os.path.exists(latents_out_path):
|
||||
os.mkdir(latents_out_path)
|
||||
|
||||
files = os.listdir(crop_imgs_path)
|
||||
files.sort()
|
||||
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||
|
||||
for file in files:
|
||||
index = file.split(".")[0]
|
||||
img_path = crop_imgs_path + file
|
||||
latents = vae.get_latents_for_unet(img_path)
|
||||
print(img_path,"latents",latents.size())
|
||||
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
|
||||
#reload_tensor = torch.load('tensor.pt')
|
||||
#print(reload_tensor.size())
|
||||
|
||||
|
||||
|
||||
5
models/MuseTalk/musetalk/utils/__init__.py
Normal file
5
models/MuseTalk/musetalk/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sys
|
||||
from os.path import abspath, dirname
|
||||
current_dir = dirname(abspath(__file__))
|
||||
parent_dir = dirname(current_dir)
|
||||
sys.path.append(parent_dir+'/utils')
|
||||
113
models/MuseTalk/musetalk/utils/audio_processor.py
Normal file
113
models/MuseTalk/musetalk/utils/audio_processor.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
|
||||
|
||||
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
# Split audio into 30s segments
|
||||
segment_length = 30 * sampling_rate
|
||||
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
|
||||
|
||||
features = []
|
||||
for segment in segments:
|
||||
audio_feature = self.feature_extractor(
|
||||
segment,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
if weight_dtype is not None:
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
features.append(audio_feature)
|
||||
|
||||
return features, len(librosa_output)
|
||||
|
||||
def get_whisper_chunk(
|
||||
self,
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=25,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
):
|
||||
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
|
||||
whisper_feature = []
|
||||
# Process multiple 30s mel input features
|
||||
for input_feature in whisper_input_features:
|
||||
input_feature = input_feature.to(device).to(weight_dtype)
|
||||
audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2)
|
||||
whisper_feature.append(audio_feats)
|
||||
|
||||
whisper_feature = torch.cat(whisper_feature, dim=1)
|
||||
# Trim the last segment to remove padding
|
||||
sr = 16000
|
||||
audio_fps = 50
|
||||
fps = int(fps)
|
||||
whisper_idx_multiplier = audio_fps / fps
|
||||
num_frames = math.floor((librosa_length / sr) * fps)
|
||||
actual_length = math.floor((librosa_length / sr) * audio_fps)
|
||||
whisper_feature = whisper_feature[:,:actual_length,...]
|
||||
|
||||
# Calculate padding amount
|
||||
padding_nums = math.ceil(whisper_idx_multiplier)
|
||||
# Add padding at start and end
|
||||
whisper_feature = torch.cat([
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
|
||||
whisper_feature,
|
||||
# Add extra padding to prevent out of bounds
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
|
||||
], 1)
|
||||
|
||||
audio_prompts = []
|
||||
for frame_index in range(num_frames):
|
||||
audio_index = math.floor(frame_index * whisper_idx_multiplier)
|
||||
end_index = audio_index + audio_feature_length_per_frame
|
||||
|
||||
# Handle case where audio is shorter than video
|
||||
if end_index > whisper_feature.shape[1]:
|
||||
available = whisper_feature[:, audio_index:]
|
||||
padding_size = end_index - whisper_feature.shape[1]
|
||||
if padding_size > 0:
|
||||
padding = torch.zeros((whisper_feature.shape[0], padding_size, *whisper_feature.shape[2:]),
|
||||
device=whisper_feature.device, dtype=whisper_feature.dtype)
|
||||
audio_clip = torch.cat([available, padding], dim=1)
|
||||
else:
|
||||
audio_clip = available
|
||||
else:
|
||||
audio_clip = whisper_feature[:, audio_index: end_index]
|
||||
|
||||
# Final size check and padding
|
||||
if audio_clip.shape[1] < audio_feature_length_per_frame:
|
||||
padding_size = audio_feature_length_per_frame - audio_clip.shape[1]
|
||||
padding = torch.zeros((whisper_feature.shape[0], padding_size, *whisper_feature.shape[2:]),
|
||||
device=whisper_feature.device, dtype=whisper_feature.dtype)
|
||||
audio_clip = torch.cat([audio_clip, padding], dim=1)
|
||||
|
||||
audio_prompts.append(audio_clip)
|
||||
|
||||
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
|
||||
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
|
||||
return audio_prompts
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_processor = AudioProcessor()
|
||||
wav_path = "./2.wav"
|
||||
audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
|
||||
print("Audio Feature shape:", audio_feature.shape)
|
||||
print("librosa_feature_length:", librosa_feature_length)
|
||||
17
models/MuseTalk/musetalk/utils/audio_utils.py
Normal file
17
models/MuseTalk/musetalk/utils/audio_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os, subprocess
|
||||
|
||||
def ensure_wav(input_path: str, target_path: str | None = None) -> str:
|
||||
"""
|
||||
Convert any audio (mp3/ogg/m4a/wav/…) to 16kHz mono PCM WAV via ffmpeg.
|
||||
Returns path to the converted .wav (original if already correct).
|
||||
"""
|
||||
if not isinstance(input_path, str) or not os.path.exists(input_path):
|
||||
return input_path
|
||||
base, ext = os.path.splitext(input_path)
|
||||
ext = ext.lower()
|
||||
|
||||
if target_path is None:
|
||||
target_path = base + "_16k.wav"
|
||||
cmd = ["ffmpeg", "-y", "-i", input_path, "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", target_path]
|
||||
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
return target_path
|
||||
136
models/MuseTalk/musetalk/utils/blending.py
Normal file
136
models/MuseTalk/musetalk/utils/blending.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
|
||||
|
||||
def get_crop_box(box, expand):
|
||||
x, y, x1, y1 = box
|
||||
x_c, y_c = (x+x1)//2, (y+y1)//2
|
||||
w, h = x1-x, y1-y
|
||||
s = int(max(w, h)//2*expand)
|
||||
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
||||
return crop_box, s
|
||||
|
||||
|
||||
def face_seg(image, mode="raw", fp=None):
|
||||
"""
|
||||
对图像进行面部解析,生成面部区域的掩码。
|
||||
|
||||
Args:
|
||||
image (PIL.Image): 输入图像。
|
||||
|
||||
Returns:
|
||||
PIL.Image: 面部区域的掩码图像。
|
||||
"""
|
||||
seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
|
||||
if seg_image is None:
|
||||
print("error, no person_segment") # 如果没有检测到面部,返回错误
|
||||
return None
|
||||
|
||||
seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
|
||||
return seg_image
|
||||
|
||||
|
||||
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
|
||||
"""
|
||||
将裁剪的面部图像粘贴回原始图像,并进行一些处理。
|
||||
|
||||
Args:
|
||||
image (numpy.ndarray): 原始图像(身体部分)。
|
||||
face (numpy.ndarray): 裁剪的面部图像。
|
||||
face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
|
||||
upper_boundary_ratio (float): 用于控制面部区域的保留比例。
|
||||
expand (float): 扩展因子,用于放大裁剪框。
|
||||
mode: 融合mask构建方式
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 处理后的图像。
|
||||
"""
|
||||
# 将 numpy 数组转换为 PIL 图像
|
||||
body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
|
||||
face = Image.fromarray(face[:, :, ::-1]) # 面部图像
|
||||
|
||||
x, y, x1, y1 = face_box # 获取面部边界框的坐标
|
||||
crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
|
||||
x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
|
||||
face_position = (x, y) # 面部在原始图像中的位置
|
||||
|
||||
# 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
|
||||
face_large = body.crop(crop_box)
|
||||
|
||||
ori_shape = face_large.size # 裁剪后图像的原始尺寸
|
||||
|
||||
# 对裁剪后的面部区域进行面部解析,生成掩码
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
|
||||
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
|
||||
|
||||
mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
|
||||
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
|
||||
|
||||
|
||||
# 保留面部区域的上半部分(用于控制说话区域)
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
|
||||
modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
|
||||
|
||||
|
||||
# 对掩码进行高斯模糊,使边缘更平滑
|
||||
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
|
||||
#mask_array = np.array(modified_mask_image)
|
||||
mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
|
||||
|
||||
# 将裁剪的面部图像粘贴回扩展后的面部区域
|
||||
face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
|
||||
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
|
||||
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
|
||||
|
||||
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
|
||||
|
||||
|
||||
def get_image_blending(image, face, face_box, mask_array, crop_box):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
face = Image.fromarray(face[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
face_large = body.crop(crop_box)
|
||||
|
||||
mask_image = Image.fromarray(mask_array)
|
||||
mask_image = mask_image.convert("L")
|
||||
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
body = np.array(body)
|
||||
return body[:,:,::-1]
|
||||
|
||||
|
||||
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
#print(x1-x,y1-y)
|
||||
crop_box, s = get_crop_box(face_box, expand)
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
# keep upper_boundary_ratio of talking area
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio)
|
||||
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||
|
||||
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||
return mask_array, crop_box
|
||||
54
models/MuseTalk/musetalk/utils/dwpose/default_runtime.py
Normal file
54
models/MuseTalk/musetalk/utils/dwpose/default_runtime.py
Normal file
@@ -0,0 +1,54 @@
|
||||
default_scope = 'mmpose'
|
||||
|
||||
# hooks
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=10),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='PoseVisualizationHook', enable=False),
|
||||
badcase=dict(
|
||||
type='BadCaseAnalysisHook',
|
||||
enable=False,
|
||||
out_dir='badcase',
|
||||
metric_type='loss',
|
||||
badcase_thr=5))
|
||||
|
||||
# custom hooks
|
||||
custom_hooks = [
|
||||
# Synchronize model buffers such as running_mean and running_var in BN
|
||||
# at the end of each epoch
|
||||
dict(type='SyncBuffersHook')
|
||||
]
|
||||
|
||||
# multi-processing backend
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=False,
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
)
|
||||
|
||||
# visualizer
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
# dict(type='TensorboardVisBackend'),
|
||||
# dict(type='WandbVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||
|
||||
# logger
|
||||
log_processor = dict(
|
||||
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume = False
|
||||
|
||||
# file I/O backend
|
||||
backend_args = dict(backend='local')
|
||||
|
||||
# training/validation/testing progress
|
||||
train_cfg = dict(by_epoch=True)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
@@ -0,0 +1,257 @@
|
||||
#_base_ = ['../../../_base_/default_runtime.py']
|
||||
_base_ = ['default_runtime.py']
|
||||
|
||||
# runtime
|
||||
max_epochs = 270
|
||||
stage2_num_epochs = 30
|
||||
base_lr = 4e-3
|
||||
train_batch_size = 32
|
||||
val_batch_size = 32
|
||||
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
randomness = dict(seed=21)
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
||||
|
||||
# learning rate
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1.0e-5,
|
||||
by_epoch=False,
|
||||
begin=0,
|
||||
end=1000),
|
||||
dict(
|
||||
# use cosine lr from 150 to 300 epoch
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=base_lr * 0.05,
|
||||
begin=max_epochs // 2,
|
||||
end=max_epochs,
|
||||
T_max=max_epochs // 2,
|
||||
by_epoch=True,
|
||||
convert_to_iter_based=True),
|
||||
]
|
||||
|
||||
# automatically scaling LR based on the actual training batch size
|
||||
auto_scale_lr = dict(base_batch_size=512)
|
||||
|
||||
# codec settings
|
||||
codec = dict(
|
||||
type='SimCCLabel',
|
||||
input_size=(288, 384),
|
||||
sigma=(6., 6.93),
|
||||
simcc_split_ratio=2.0,
|
||||
normalize=False,
|
||||
use_dark=False)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='TopdownPoseEstimator',
|
||||
data_preprocessor=dict(
|
||||
type='PoseDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
_scope_='mmdet',
|
||||
type='CSPNeXt',
|
||||
arch='P5',
|
||||
expand_ratio=0.5,
|
||||
deepen_factor=1.,
|
||||
widen_factor=1.,
|
||||
out_indices=(4, ),
|
||||
channel_attention=True,
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
act_cfg=dict(type='SiLU'),
|
||||
init_cfg=dict(
|
||||
type='Pretrained',
|
||||
prefix='backbone.',
|
||||
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
|
||||
'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
|
||||
)),
|
||||
head=dict(
|
||||
type='RTMCCHead',
|
||||
in_channels=1024,
|
||||
out_channels=133,
|
||||
input_size=codec['input_size'],
|
||||
in_featuremap_size=(9, 12),
|
||||
simcc_split_ratio=codec['simcc_split_ratio'],
|
||||
final_layer_kernel_size=7,
|
||||
gau_cfg=dict(
|
||||
hidden_dims=256,
|
||||
s=128,
|
||||
expansion_factor=2,
|
||||
dropout_rate=0.,
|
||||
drop_path=0.,
|
||||
act_fn='SiLU',
|
||||
use_rel_bias=False,
|
||||
pos_enc=False),
|
||||
loss=dict(
|
||||
type='KLDiscretLoss',
|
||||
use_target_weight=True,
|
||||
beta=10.,
|
||||
label_softmax=True),
|
||||
decoder=codec),
|
||||
test_cfg=dict(flip_test=True, ))
|
||||
|
||||
# base dataset settings
|
||||
dataset_type = 'UBody2dDataset'
|
||||
data_mode = 'topdown'
|
||||
data_root = 'data/UBody/'
|
||||
|
||||
backend_args = dict(backend='local')
|
||||
|
||||
scenes = [
|
||||
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
|
||||
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
|
||||
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
|
||||
]
|
||||
|
||||
train_datasets = [
|
||||
dict(
|
||||
type='CocoWholeBodyDataset',
|
||||
data_root='data/coco/',
|
||||
data_mode=data_mode,
|
||||
ann_file='annotations/coco_wholebody_train_v1.0.json',
|
||||
data_prefix=dict(img='train2017/'),
|
||||
pipeline=[])
|
||||
]
|
||||
|
||||
for scene in scenes:
|
||||
train_dataset = dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_mode=data_mode,
|
||||
ann_file=f'annotations/{scene}/train_annotations.json',
|
||||
data_prefix=dict(img='images/'),
|
||||
pipeline=[],
|
||||
sample_interval=10)
|
||||
train_datasets.append(train_dataset)
|
||||
|
||||
# pipelines
|
||||
train_pipeline = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='RandomFlip', direction='horizontal'),
|
||||
dict(type='RandomHalfBody'),
|
||||
dict(
|
||||
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(
|
||||
type='Albumentation',
|
||||
transforms=[
|
||||
dict(type='Blur', p=0.1),
|
||||
dict(type='MedianBlur', p=0.1),
|
||||
dict(
|
||||
type='CoarseDropout',
|
||||
max_holes=1,
|
||||
max_height=0.4,
|
||||
max_width=0.4,
|
||||
min_holes=1,
|
||||
min_height=0.2,
|
||||
min_width=0.2,
|
||||
p=1.0),
|
||||
]),
|
||||
dict(type='GenerateTarget', encoder=codec),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
val_pipeline = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
|
||||
train_pipeline_stage2 = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='RandomFlip', direction='horizontal'),
|
||||
dict(type='RandomHalfBody'),
|
||||
dict(
|
||||
type='RandomBBoxTransform',
|
||||
shift_factor=0.,
|
||||
scale_factor=[0.5, 1.5],
|
||||
rotate_factor=90),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(
|
||||
type='Albumentation',
|
||||
transforms=[
|
||||
dict(type='Blur', p=0.1),
|
||||
dict(type='MedianBlur', p=0.1),
|
||||
dict(
|
||||
type='CoarseDropout',
|
||||
max_holes=1,
|
||||
max_height=0.4,
|
||||
max_width=0.4,
|
||||
min_holes=1,
|
||||
min_height=0.2,
|
||||
min_width=0.2,
|
||||
p=0.5),
|
||||
]),
|
||||
dict(type='GenerateTarget', encoder=codec),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
|
||||
# data loaders
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size,
|
||||
num_workers=10,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='CombinedDataset',
|
||||
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
|
||||
datasets=train_datasets,
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False,
|
||||
))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=val_batch_size,
|
||||
num_workers=10,
|
||||
persistent_workers=True,
|
||||
drop_last=False,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
||||
dataset=dict(
|
||||
type='CocoWholeBodyDataset',
|
||||
data_root=data_root,
|
||||
data_mode=data_mode,
|
||||
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
|
||||
bbox_file='data/coco/person_detection_results/'
|
||||
'COCO_val2017_detections_AP_H_56_person.json',
|
||||
data_prefix=dict(img='coco/val2017/'),
|
||||
test_mode=True,
|
||||
pipeline=val_pipeline,
|
||||
))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# hooks
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='EMAHook',
|
||||
ema_type='ExpMomentumEMA',
|
||||
momentum=0.0002,
|
||||
update_buffers=True,
|
||||
priority=49),
|
||||
dict(
|
||||
type='mmdet.PipelineSwitchHook',
|
||||
switch_epoch=max_epochs - stage2_num_epochs,
|
||||
switch_pipeline=train_pipeline_stage2)
|
||||
]
|
||||
|
||||
# evaluators
|
||||
val_evaluator = dict(
|
||||
type='CocoWholeBodyMetric',
|
||||
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
|
||||
test_evaluator = val_evaluator
|
||||
1
models/MuseTalk/musetalk/utils/face_detection/README.md
Normal file
1
models/MuseTalk/musetalk/utils/face_detection/README.md
Normal file
@@ -0,0 +1 @@
|
||||
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
__author__ = """Adrian Bulat"""
|
||||
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
||||
__version__ = '1.0.1'
|
||||
|
||||
from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
|
||||
240
models/MuseTalk/musetalk/utils/face_detection/api.py
Normal file
240
models/MuseTalk/musetalk/utils/face_detection/api.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.model_zoo import load_url
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
import cv2
|
||||
try:
|
||||
import urllib.request as request_file
|
||||
except BaseException:
|
||||
import urllib as request_file
|
||||
|
||||
from .models import FAN, ResNetDepth
|
||||
from .utils import *
|
||||
|
||||
|
||||
class LandmarksType(Enum):
|
||||
"""Enum class defining the type of landmarks to detect.
|
||||
|
||||
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
||||
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
||||
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
||||
|
||||
"""
|
||||
_2D = 1
|
||||
_2halfD = 2
|
||||
_3D = 3
|
||||
|
||||
|
||||
class NetworkSize(Enum):
|
||||
# TINY = 1
|
||||
# SMALL = 2
|
||||
# MEDIUM = 3
|
||||
LARGE = 4
|
||||
|
||||
def __new__(cls, value):
|
||||
member = object.__new__(cls)
|
||||
member._value_ = value
|
||||
return member
|
||||
|
||||
def __int__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
|
||||
class FaceAlignment:
|
||||
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
||||
self.device = device
|
||||
self.flip_input = flip_input
|
||||
self.landmarks_type = landmarks_type
|
||||
self.verbose = verbose
|
||||
|
||||
network_size = int(network_size)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cuda.matmul.allow_tf32 = False
|
||||
# torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = False
|
||||
# torch.backends.cudnn.allow_tf32 = True
|
||||
print('cuda start')
|
||||
|
||||
|
||||
# Get the face detector
|
||||
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
||||
globals(), locals(), [face_detector], 0)
|
||||
|
||||
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
||||
|
||||
def get_detections_for_batch(self, images):
|
||||
images = images[..., ::-1]
|
||||
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
||||
results = []
|
||||
|
||||
for i, d in enumerate(detected_faces):
|
||||
if len(d) == 0:
|
||||
results.append(None)
|
||||
continue
|
||||
d = d[0]
|
||||
d = np.clip(d, 0, None)
|
||||
|
||||
x1, y1, x2, y2 = map(int, d[:-1])
|
||||
results.append((x1, y1, x2, y2))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class YOLOv8_face:
|
||||
def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
|
||||
self.conf_threshold = conf_thres
|
||||
self.iou_threshold = iou_thres
|
||||
self.class_names = ['face']
|
||||
self.num_classes = len(self.class_names)
|
||||
# Initialize model
|
||||
self.net = cv2.dnn.readNet(path)
|
||||
self.input_height = 640
|
||||
self.input_width = 640
|
||||
self.reg_max = 16
|
||||
|
||||
self.project = np.arange(self.reg_max)
|
||||
self.strides = (8, 16, 32)
|
||||
self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
|
||||
self.anchors = self.make_anchors(self.feats_hw)
|
||||
|
||||
def make_anchors(self, feats_hw, grid_cell_offset=0.5):
|
||||
"""Generate anchors from features."""
|
||||
anchor_points = {}
|
||||
for i, stride in enumerate(self.strides):
|
||||
h,w = feats_hw[i]
|
||||
x = np.arange(0, w) + grid_cell_offset # shift x
|
||||
y = np.arange(0, h) + grid_cell_offset # shift y
|
||||
sx, sy = np.meshgrid(x, y)
|
||||
# sy, sx = np.meshgrid(y, x)
|
||||
anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
|
||||
return anchor_points
|
||||
|
||||
def softmax(self, x, axis=1):
|
||||
x_exp = np.exp(x)
|
||||
# 如果是列向量,则axis=0
|
||||
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
|
||||
s = x_exp / x_sum
|
||||
return s
|
||||
|
||||
def resize_image(self, srcimg, keep_ratio=True):
|
||||
top, left, newh, neww = 0, 0, self.input_width, self.input_height
|
||||
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
|
||||
hw_scale = srcimg.shape[0] / srcimg.shape[1]
|
||||
if hw_scale > 1:
|
||||
newh, neww = self.input_height, int(self.input_width / hw_scale)
|
||||
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||
left = int((self.input_width - neww) * 0.5)
|
||||
img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0)) # add border
|
||||
else:
|
||||
newh, neww = int(self.input_height * hw_scale), self.input_width
|
||||
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||
top = int((self.input_height - newh) * 0.5)
|
||||
img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0))
|
||||
else:
|
||||
img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
|
||||
return img, newh, neww, top, left
|
||||
|
||||
def detect(self, srcimg):
|
||||
input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
|
||||
scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
|
||||
input_img = input_img.astype(np.float32) / 255.0
|
||||
|
||||
blob = cv2.dnn.blobFromImage(input_img)
|
||||
self.net.setInput(blob)
|
||||
outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
|
||||
# if isinstance(outputs, tuple):
|
||||
# outputs = list(outputs)
|
||||
# if float(cv2.__version__[:3])>=4.7:
|
||||
# outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要
|
||||
# Perform inference on the image
|
||||
det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
|
||||
return det_bboxes, det_conf, det_classid, landmarks
|
||||
|
||||
def post_process(self, preds, scale_h, scale_w, padh, padw):
|
||||
bboxes, scores, landmarks = [], [], []
|
||||
for i, pred in enumerate(preds):
|
||||
stride = int(self.input_height/pred.shape[2])
|
||||
pred = pred.transpose((0, 2, 3, 1))
|
||||
|
||||
box = pred[..., :self.reg_max * 4]
|
||||
cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
|
||||
kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
|
||||
|
||||
# tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
|
||||
tmp = box.reshape(-1, 4, self.reg_max)
|
||||
bbox_pred = self.softmax(tmp, axis=-1)
|
||||
bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
|
||||
|
||||
bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
|
||||
kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
|
||||
kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
|
||||
kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
|
||||
|
||||
bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则
|
||||
bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
|
||||
kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
|
||||
kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
|
||||
|
||||
bboxes.append(bbox)
|
||||
scores.append(cls)
|
||||
landmarks.append(kpts)
|
||||
|
||||
bboxes = np.concatenate(bboxes, axis=0)
|
||||
scores = np.concatenate(scores, axis=0)
|
||||
landmarks = np.concatenate(landmarks, axis=0)
|
||||
|
||||
bboxes_wh = bboxes.copy()
|
||||
bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
|
||||
classIds = np.argmax(scores, axis=1)
|
||||
confidences = np.max(scores, axis=1) ####max_class_confidence
|
||||
|
||||
mask = confidences>self.conf_threshold
|
||||
bboxes_wh = bboxes_wh[mask] ###合理使用广播法则
|
||||
confidences = confidences[mask]
|
||||
classIds = classIds[mask]
|
||||
landmarks = landmarks[mask]
|
||||
|
||||
indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
|
||||
self.iou_threshold).flatten()
|
||||
if len(indices) > 0:
|
||||
mlvl_bboxes = bboxes_wh[indices]
|
||||
confidences = confidences[indices]
|
||||
classIds = classIds[indices]
|
||||
landmarks = landmarks[indices]
|
||||
return mlvl_bboxes, confidences, classIds, landmarks
|
||||
else:
|
||||
print('nothing detect')
|
||||
return np.array([]), np.array([]), np.array([]), np.array([])
|
||||
|
||||
def distance2bbox(self, points, distance, max_shape=None):
|
||||
x1 = points[:, 0] - distance[:, 0]
|
||||
y1 = points[:, 1] - distance[:, 1]
|
||||
x2 = points[:, 0] + distance[:, 2]
|
||||
y2 = points[:, 1] + distance[:, 3]
|
||||
if max_shape is not None:
|
||||
x1 = np.clip(x1, 0, max_shape[1])
|
||||
y1 = np.clip(y1, 0, max_shape[0])
|
||||
x2 = np.clip(x2, 0, max_shape[1])
|
||||
y2 = np.clip(y2, 0, max_shape[0])
|
||||
return np.stack([x1, y1, x2, y2], axis=-1)
|
||||
|
||||
def draw_detections(self, image, boxes, scores, kpts):
|
||||
for box, score, kp in zip(boxes, scores, kpts):
|
||||
x, y, w, h = box.astype(int)
|
||||
# Draw rectangle
|
||||
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
|
||||
cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
|
||||
for i in range(5):
|
||||
cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
|
||||
# cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
|
||||
return image
|
||||
|
||||
ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -0,0 +1 @@
|
||||
from .core import FaceDetector
|
||||
130
models/MuseTalk/musetalk/utils/face_detection/detection/core.py
Normal file
130
models/MuseTalk/musetalk/utils/face_detection/detection/core.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import logging
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
|
||||
class FaceDetector(object):
|
||||
"""An abstract class representing a face detector.
|
||||
|
||||
Any other face detection implementation must subclass it. All subclasses
|
||||
must implement ``detect_from_image``, that return a list of detected
|
||||
bounding boxes. Optionally, for speed considerations detect from path is
|
||||
recommended.
|
||||
"""
|
||||
|
||||
def __init__(self, device, verbose):
|
||||
self.device = device
|
||||
self.verbose = verbose
|
||||
|
||||
if verbose:
|
||||
if 'cpu' in device:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning("Detection running on CPU, this may be potentially slow.")
|
||||
|
||||
if 'cpu' not in device and 'cuda' not in device:
|
||||
if verbose:
|
||||
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
||||
raise ValueError
|
||||
|
||||
def detect_from_image(self, tensor_or_path):
|
||||
"""Detects faces in a given image.
|
||||
|
||||
This function detects the faces present in a provided BGR(usually)
|
||||
image. The input can be either the image itself or the path to it.
|
||||
|
||||
Arguments:
|
||||
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
||||
to an image or the image itself.
|
||||
|
||||
Example::
|
||||
|
||||
>>> path_to_image = 'data/image_01.jpg'
|
||||
... detected_faces = detect_from_image(path_to_image)
|
||||
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||
>>> image = cv2.imread(path_to_image)
|
||||
... detected_faces = detect_from_image(image)
|
||||
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
||||
"""Detects faces from all the images present in a given directory.
|
||||
|
||||
Arguments:
|
||||
path {string} -- a string containing a path that points to the folder containing the images
|
||||
|
||||
Keyword Arguments:
|
||||
extensions {list} -- list of string containing the extensions to be
|
||||
consider in the following format: ``.extension_name`` (default:
|
||||
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
||||
folder recursively (default: {False}) show_progress_bar {bool} --
|
||||
display a progressbar (default: {True})
|
||||
|
||||
Example:
|
||||
>>> directory = 'data'
|
||||
... detected_faces = detect_from_directory(directory)
|
||||
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
||||
|
||||
"""
|
||||
if self.verbose:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if len(extensions) == 0:
|
||||
if self.verbose:
|
||||
logger.error("Expected at list one extension, but none was received.")
|
||||
raise ValueError
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Constructing the list of images.")
|
||||
additional_pattern = '/**/*' if recursive else '/*'
|
||||
files = []
|
||||
for extension in extensions:
|
||||
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Finished searching for images. %s images found", len(files))
|
||||
logger.info("Preparing to run the detection.")
|
||||
|
||||
predictions = {}
|
||||
for image_path in tqdm(files, disable=not show_progress_bar):
|
||||
if self.verbose:
|
||||
logger.info("Running the face detector on image: %s", image_path)
|
||||
predictions[image_path] = self.detect_from_image(image_path)
|
||||
|
||||
if self.verbose:
|
||||
logger.info("The detector was successfully run on all %s images", len(files))
|
||||
|
||||
return predictions
|
||||
|
||||
@property
|
||||
def reference_scale(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_x_shift(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_y_shift(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
||||
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
||||
|
||||
Arguments:
|
||||
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
||||
"""
|
||||
if isinstance(tensor_or_path, str):
|
||||
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
||||
elif torch.is_tensor(tensor_or_path):
|
||||
# Call cpu in case its coming from cuda
|
||||
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
||||
elif isinstance(tensor_or_path, np.ndarray):
|
||||
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
||||
else:
|
||||
raise TypeError
|
||||
@@ -0,0 +1 @@
|
||||
from .sfd_detector import SFDDetector as FaceDetector
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import random
|
||||
import datetime
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from iou import IOU
|
||||
except BaseException:
|
||||
# IOU cython speedup 10x
|
||||
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
||||
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
||||
sb = abs((bx2 - bx1) * (by2 - by1))
|
||||
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
||||
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
if w < 0 or h < 0:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0 * w * h / (sa + sb - w * h)
|
||||
|
||||
|
||||
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
||||
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
||||
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
||||
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
||||
return dx, dy, dw, dh
|
||||
|
||||
|
||||
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
||||
xc, yc = dx * aww + axc, dy * ahh + ayc
|
||||
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
||||
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
def nms(dets, thresh):
|
||||
if 0 == len(dets):
|
||||
return []
|
||||
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
||||
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
||||
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
||||
|
||||
inds = np.where(ovr <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def encode(matched, priors, variances):
|
||||
"""Encode the variances from the priorbox layers into the ground truth boxes
|
||||
we have matched (based on jaccard overlap) with the prior boxes.
|
||||
Args:
|
||||
matched: (tensor) Coords of ground truth for each prior in point-form
|
||||
Shape: [num_priors, 4].
|
||||
priors: (tensor) Prior boxes in center-offset form
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
encoded boxes (tensor), Shape: [num_priors, 4]
|
||||
"""
|
||||
|
||||
# dist b/t match center and prior's center
|
||||
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
||||
# encode variance
|
||||
g_cxcy /= (variances[0] * priors[:, 2:])
|
||||
# match wh / prior wh
|
||||
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
||||
g_wh = torch.log(g_wh) / variances[1]
|
||||
# return target for smooth_l1_loss
|
||||
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
||||
|
||||
|
||||
def decode(loc, priors, variances):
|
||||
"""Decode locations from predictions using priors to undo
|
||||
the encoding we did for offset regression at train time.
|
||||
Args:
|
||||
loc (tensor): location predictions for loc layers,
|
||||
Shape: [num_priors,4]
|
||||
priors (tensor): Prior boxes in center-offset form.
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
decoded bounding box predictions
|
||||
"""
|
||||
|
||||
boxes = torch.cat((
|
||||
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
||||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
||||
boxes[:, :2] -= boxes[:, 2:] / 2
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
return boxes
|
||||
|
||||
def batch_decode(loc, priors, variances):
|
||||
"""Decode locations from predictions using priors to undo
|
||||
the encoding we did for offset regression at train time.
|
||||
Args:
|
||||
loc (tensor): location predictions for loc layers,
|
||||
Shape: [num_priors,4]
|
||||
priors (tensor): Prior boxes in center-offset form.
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
decoded bounding box predictions
|
||||
"""
|
||||
|
||||
boxes = torch.cat((
|
||||
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
||||
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
||||
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
||||
boxes[:, :, 2:] += boxes[:, :, :2]
|
||||
return boxes
|
||||
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import random
|
||||
import datetime
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import scipy.io as sio
|
||||
import zipfile
|
||||
from .net_s3fd import s3fd
|
||||
from .bbox import *
|
||||
|
||||
|
||||
def detect(net, img, device):
|
||||
img = img - np.array([104, 117, 123])
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img.reshape((1,) + img.shape)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
img = torch.from_numpy(img).float().to(device)
|
||||
BB, CC, HH, WW = img.size()
|
||||
with torch.no_grad():
|
||||
olist = net(img)
|
||||
|
||||
bboxlist = []
|
||||
for i in range(len(olist) // 2):
|
||||
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||
olist = [oelem.data.cpu() for oelem in olist]
|
||||
for i in range(len(olist) // 2):
|
||||
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||
FB, FC, FH, FW = ocls.size() # feature map size
|
||||
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||
anchor = stride * 4
|
||||
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||
for Iindex, hindex, windex in poss:
|
||||
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||
score = ocls[0, 1, hindex, windex]
|
||||
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
||||
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
||||
variances = [0.1, 0.2]
|
||||
box = decode(loc, priors, variances)
|
||||
x1, y1, x2, y2 = box[0] * 1.0
|
||||
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||
bboxlist.append([x1, y1, x2, y2, score])
|
||||
bboxlist = np.array(bboxlist)
|
||||
if 0 == len(bboxlist):
|
||||
bboxlist = np.zeros((1, 5))
|
||||
|
||||
return bboxlist
|
||||
|
||||
def batch_detect(net, imgs, device):
|
||||
imgs = imgs - np.array([104, 117, 123])
|
||||
imgs = imgs.transpose(0, 3, 1, 2)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
imgs = torch.from_numpy(imgs).float().to(device)
|
||||
BB, CC, HH, WW = imgs.size()
|
||||
with torch.no_grad():
|
||||
olist = net(imgs)
|
||||
# print(olist)
|
||||
|
||||
bboxlist = []
|
||||
for i in range(len(olist) // 2):
|
||||
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||
|
||||
olist = [oelem.cpu() for oelem in olist]
|
||||
for i in range(len(olist) // 2):
|
||||
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||
FB, FC, FH, FW = ocls.size() # feature map size
|
||||
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||
anchor = stride * 4
|
||||
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||
for Iindex, hindex, windex in poss:
|
||||
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||
score = ocls[:, 1, hindex, windex]
|
||||
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
||||
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
||||
variances = [0.1, 0.2]
|
||||
box = batch_decode(loc, priors, variances)
|
||||
box = box[:, 0] * 1.0
|
||||
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
||||
bboxlist = np.array(bboxlist)
|
||||
if 0 == len(bboxlist):
|
||||
bboxlist = np.zeros((1, BB, 5))
|
||||
|
||||
return bboxlist
|
||||
|
||||
def flip_detect(net, img, device):
|
||||
img = cv2.flip(img, 1)
|
||||
b = detect(net, img, device)
|
||||
|
||||
bboxlist = np.zeros(b.shape)
|
||||
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
||||
bboxlist[:, 1] = b[:, 1]
|
||||
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
||||
bboxlist[:, 3] = b[:, 3]
|
||||
bboxlist[:, 4] = b[:, 4]
|
||||
return bboxlist
|
||||
|
||||
|
||||
def pts_to_bb(pts):
|
||||
min_x, min_y = np.min(pts, axis=0)
|
||||
max_x, max_y = np.max(pts, axis=0)
|
||||
return np.array([min_x, min_y, max_x, max_y])
|
||||
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def __init__(self, n_channels, scale=1.0):
|
||||
super(L2Norm, self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.scale = scale
|
||||
self.eps = 1e-10
|
||||
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
||||
self.weight.data *= 0.0
|
||||
self.weight.data += self.scale
|
||||
|
||||
def forward(self, x):
|
||||
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
||||
x = x / norm * self.weight.view(1, -1, 1, 1)
|
||||
return x
|
||||
|
||||
|
||||
class s3fd(nn.Module):
|
||||
def __init__(self):
|
||||
super(s3fd, self).__init__()
|
||||
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
||||
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
||||
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
||||
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.conv3_3_norm = L2Norm(256, scale=10)
|
||||
self.conv4_3_norm = L2Norm(512, scale=8)
|
||||
self.conv5_3_norm = L2Norm(512, scale=5)
|
||||
|
||||
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.relu(self.conv1_1(x))
|
||||
h = F.relu(self.conv1_2(h))
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv2_1(h))
|
||||
h = F.relu(self.conv2_2(h))
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv3_1(h))
|
||||
h = F.relu(self.conv3_2(h))
|
||||
h = F.relu(self.conv3_3(h))
|
||||
f3_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv4_1(h))
|
||||
h = F.relu(self.conv4_2(h))
|
||||
h = F.relu(self.conv4_3(h))
|
||||
f4_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv5_1(h))
|
||||
h = F.relu(self.conv5_2(h))
|
||||
h = F.relu(self.conv5_3(h))
|
||||
f5_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.fc6(h))
|
||||
h = F.relu(self.fc7(h))
|
||||
ffc7 = h
|
||||
h = F.relu(self.conv6_1(h))
|
||||
h = F.relu(self.conv6_2(h))
|
||||
f6_2 = h
|
||||
h = F.relu(self.conv7_1(h))
|
||||
h = F.relu(self.conv7_2(h))
|
||||
f7_2 = h
|
||||
|
||||
f3_3 = self.conv3_3_norm(f3_3)
|
||||
f4_3 = self.conv4_3_norm(f4_3)
|
||||
f5_3 = self.conv5_3_norm(f5_3)
|
||||
|
||||
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
||||
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
||||
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
||||
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
||||
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
||||
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
||||
cls4 = self.fc7_mbox_conf(ffc7)
|
||||
reg4 = self.fc7_mbox_loc(ffc7)
|
||||
cls5 = self.conv6_2_mbox_conf(f6_2)
|
||||
reg5 = self.conv6_2_mbox_loc(f6_2)
|
||||
cls6 = self.conv7_2_mbox_conf(f7_2)
|
||||
reg6 = self.conv7_2_mbox_loc(f7_2)
|
||||
|
||||
# max-out background label
|
||||
chunk = torch.chunk(cls1, 4, 1)
|
||||
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
||||
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
||||
|
||||
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
||||
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
import cv2
|
||||
from torch.utils.model_zoo import load_url
|
||||
|
||||
from ..core import FaceDetector
|
||||
|
||||
from .net_s3fd import s3fd
|
||||
from .bbox import *
|
||||
from .detect import *
|
||||
|
||||
models_urls = {
|
||||
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
||||
}
|
||||
|
||||
|
||||
class SFDDetector(FaceDetector):
|
||||
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
||||
super(SFDDetector, self).__init__(device, verbose)
|
||||
|
||||
# Initialise the face detector
|
||||
if not os.path.isfile(path_to_detector):
|
||||
model_weights = load_url(models_urls['s3fd'])
|
||||
else:
|
||||
model_weights = torch.load(path_to_detector)
|
||||
|
||||
self.face_detector = s3fd()
|
||||
self.face_detector.load_state_dict(model_weights)
|
||||
self.face_detector.to(device)
|
||||
self.face_detector.eval()
|
||||
|
||||
def detect_from_image(self, tensor_or_path):
|
||||
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
||||
|
||||
bboxlist = detect(self.face_detector, image, device=self.device)
|
||||
keep = nms(bboxlist, 0.3)
|
||||
bboxlist = bboxlist[keep, :]
|
||||
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
||||
|
||||
return bboxlist
|
||||
|
||||
def detect_from_batch(self, images):
|
||||
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
||||
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
||||
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
||||
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
||||
|
||||
return bboxlists
|
||||
|
||||
@property
|
||||
def reference_scale(self):
|
||||
return 195
|
||||
|
||||
@property
|
||||
def reference_x_shift(self):
|
||||
return 0
|
||||
|
||||
@property
|
||||
def reference_y_shift(self):
|
||||
return 0
|
||||
261
models/MuseTalk/musetalk/utils/face_detection/models.py
Normal file
261
models/MuseTalk/musetalk/utils/face_detection/models.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
||||
stride=strd, padding=padding, bias=bias)
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
||||
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
||||
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
||||
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
||||
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
||||
|
||||
if in_planes != out_planes:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.BatchNorm2d(in_planes),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size=1, stride=1, bias=False),
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out1 = self.bn1(x)
|
||||
out1 = F.relu(out1, True)
|
||||
out1 = self.conv1(out1)
|
||||
|
||||
out2 = self.bn2(out1)
|
||||
out2 = F.relu(out2, True)
|
||||
out2 = self.conv2(out2)
|
||||
|
||||
out3 = self.bn3(out2)
|
||||
out3 = F.relu(out3, True)
|
||||
out3 = self.conv3(out3)
|
||||
|
||||
out3 = torch.cat((out1, out2, out3), 1)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(residual)
|
||||
|
||||
out3 += residual
|
||||
|
||||
return out3
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class HourGlass(nn.Module):
|
||||
def __init__(self, num_modules, depth, num_features):
|
||||
super(HourGlass, self).__init__()
|
||||
self.num_modules = num_modules
|
||||
self.depth = depth
|
||||
self.features = num_features
|
||||
|
||||
self._generate_network(self.depth)
|
||||
|
||||
def _generate_network(self, level):
|
||||
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
if level > 1:
|
||||
self._generate_network(level - 1)
|
||||
else:
|
||||
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
def _forward(self, level, inp):
|
||||
# Upper branch
|
||||
up1 = inp
|
||||
up1 = self._modules['b1_' + str(level)](up1)
|
||||
|
||||
# Lower branch
|
||||
low1 = F.avg_pool2d(inp, 2, stride=2)
|
||||
low1 = self._modules['b2_' + str(level)](low1)
|
||||
|
||||
if level > 1:
|
||||
low2 = self._forward(level - 1, low1)
|
||||
else:
|
||||
low2 = low1
|
||||
low2 = self._modules['b2_plus_' + str(level)](low2)
|
||||
|
||||
low3 = low2
|
||||
low3 = self._modules['b3_' + str(level)](low3)
|
||||
|
||||
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
||||
|
||||
return up1 + up2
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward(self.depth, x)
|
||||
|
||||
|
||||
class FAN(nn.Module):
|
||||
|
||||
def __init__(self, num_modules=1):
|
||||
super(FAN, self).__init__()
|
||||
self.num_modules = num_modules
|
||||
|
||||
# Base part
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.conv2 = ConvBlock(64, 128)
|
||||
self.conv3 = ConvBlock(128, 128)
|
||||
self.conv4 = ConvBlock(128, 256)
|
||||
|
||||
# Stacking part
|
||||
for hg_module in range(self.num_modules):
|
||||
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
||||
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
||||
self.add_module('conv_last' + str(hg_module),
|
||||
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
||||
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
||||
68, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
if hg_module < self.num_modules - 1:
|
||||
self.add_module(
|
||||
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
||||
256, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.bn1(self.conv1(x)), True)
|
||||
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
|
||||
previous = x
|
||||
|
||||
outputs = []
|
||||
for i in range(self.num_modules):
|
||||
hg = self._modules['m' + str(i)](previous)
|
||||
|
||||
ll = hg
|
||||
ll = self._modules['top_m_' + str(i)](ll)
|
||||
|
||||
ll = F.relu(self._modules['bn_end' + str(i)]
|
||||
(self._modules['conv_last' + str(i)](ll)), True)
|
||||
|
||||
# Predict heatmaps
|
||||
tmp_out = self._modules['l' + str(i)](ll)
|
||||
outputs.append(tmp_out)
|
||||
|
||||
if i < self.num_modules - 1:
|
||||
ll = self._modules['bl' + str(i)](ll)
|
||||
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
||||
previous = previous + ll + tmp_out_
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ResNetDepth(nn.Module):
|
||||
|
||||
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
||||
self.inplanes = 64
|
||||
super(ResNetDepth, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
313
models/MuseTalk/musetalk/utils/face_detection/utils.py
Normal file
313
models/MuseTalk/musetalk/utils/face_detection/utils.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def _gaussian(
|
||||
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
||||
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
||||
mean_vert=0.5):
|
||||
# handle some defaults
|
||||
if width is None:
|
||||
width = size
|
||||
if height is None:
|
||||
height = size
|
||||
if sigma_horz is None:
|
||||
sigma_horz = sigma
|
||||
if sigma_vert is None:
|
||||
sigma_vert = sigma
|
||||
center_x = mean_horz * width + 0.5
|
||||
center_y = mean_vert * height + 0.5
|
||||
gauss = np.empty((height, width), dtype=np.float32)
|
||||
# generate kernel
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
||||
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
||||
if normalize:
|
||||
gauss = gauss / np.sum(gauss)
|
||||
return gauss
|
||||
|
||||
|
||||
def draw_gaussian(image, point, sigma):
|
||||
# Check if the gaussian is inside
|
||||
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
||||
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
||||
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
||||
return image
|
||||
size = 6 * sigma + 1
|
||||
g = _gaussian(size)
|
||||
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
||||
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
||||
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
||||
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
||||
assert (g_x[0] > 0 and g_y[1] > 0)
|
||||
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
||||
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
||||
image[image > 1] = 1
|
||||
return image
|
||||
|
||||
|
||||
def transform(point, center, scale, resolution, invert=False):
|
||||
"""Generate and affine transformation matrix.
|
||||
|
||||
Given a set of points, a center, a scale and a targer resolution, the
|
||||
function generates and affine transformation matrix. If invert is ``True``
|
||||
it will produce the inverse transformation.
|
||||
|
||||
Arguments:
|
||||
point {torch.tensor} -- the input 2D point
|
||||
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
||||
scale {float} -- the scale of the face/object
|
||||
resolution {float} -- the output resolution
|
||||
|
||||
Keyword Arguments:
|
||||
invert {bool} -- define wherever the function should produce the direct or the
|
||||
inverse transformation matrix (default: {False})
|
||||
"""
|
||||
_pt = torch.ones(3)
|
||||
_pt[0] = point[0]
|
||||
_pt[1] = point[1]
|
||||
|
||||
h = 200.0 * scale
|
||||
t = torch.eye(3)
|
||||
t[0, 0] = resolution / h
|
||||
t[1, 1] = resolution / h
|
||||
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
||||
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
||||
|
||||
if invert:
|
||||
t = torch.inverse(t)
|
||||
|
||||
new_point = (torch.matmul(t, _pt))[0:2]
|
||||
|
||||
return new_point.int()
|
||||
|
||||
|
||||
def crop(image, center, scale, resolution=256.0):
|
||||
"""Center crops an image or set of heatmaps
|
||||
|
||||
Arguments:
|
||||
image {numpy.array} -- an rgb image
|
||||
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
||||
scale {float} -- scale of the face
|
||||
|
||||
Keyword Arguments:
|
||||
resolution {float} -- the size of the output cropped image (default: {256.0})
|
||||
|
||||
Returns:
|
||||
[type] -- [description]
|
||||
""" # Crop around the center point
|
||||
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
||||
ul = transform([1, 1], center, scale, resolution, True)
|
||||
br = transform([resolution, resolution], center, scale, resolution, True)
|
||||
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
||||
if image.ndim > 2:
|
||||
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
||||
image.shape[2]], dtype=np.int32)
|
||||
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||
else:
|
||||
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
||||
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||
ht = image.shape[0]
|
||||
wd = image.shape[1]
|
||||
newX = np.array(
|
||||
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
||||
newY = np.array(
|
||||
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
||||
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
||||
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
||||
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
||||
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
||||
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
return newImg
|
||||
|
||||
|
||||
def get_preds_fromhm(hm, center=None, scale=None):
|
||||
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
||||
and the scale is provided the function will return the points also in
|
||||
the original coordinate frame.
|
||||
|
||||
Arguments:
|
||||
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||
|
||||
Keyword Arguments:
|
||||
center {torch.tensor} -- the center of the bounding box (default: {None})
|
||||
scale {float} -- face scale (default: {None})
|
||||
"""
|
||||
max, idx = torch.max(
|
||||
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||
idx += 1
|
||||
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||
|
||||
for i in range(preds.size(0)):
|
||||
for j in range(preds.size(1)):
|
||||
hm_ = hm[i, j, :]
|
||||
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||
diff = torch.FloatTensor(
|
||||
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||
|
||||
preds.add_(-.5)
|
||||
|
||||
preds_orig = torch.zeros(preds.size())
|
||||
if center is not None and scale is not None:
|
||||
for i in range(hm.size(0)):
|
||||
for j in range(hm.size(1)):
|
||||
preds_orig[i, j] = transform(
|
||||
preds[i, j], center, scale, hm.size(2), True)
|
||||
|
||||
return preds, preds_orig
|
||||
|
||||
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
||||
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
||||
and the scales is provided the function will return the points also in
|
||||
the original coordinate frame.
|
||||
|
||||
Arguments:
|
||||
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||
|
||||
Keyword Arguments:
|
||||
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
||||
scales {float} -- face scales (default: {None})
|
||||
"""
|
||||
max, idx = torch.max(
|
||||
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||
idx += 1
|
||||
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||
|
||||
for i in range(preds.size(0)):
|
||||
for j in range(preds.size(1)):
|
||||
hm_ = hm[i, j, :]
|
||||
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||
diff = torch.FloatTensor(
|
||||
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||
|
||||
preds.add_(-.5)
|
||||
|
||||
preds_orig = torch.zeros(preds.size())
|
||||
if centers is not None and scales is not None:
|
||||
for i in range(hm.size(0)):
|
||||
for j in range(hm.size(1)):
|
||||
preds_orig[i, j] = transform(
|
||||
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
||||
|
||||
return preds, preds_orig
|
||||
|
||||
def shuffle_lr(parts, pairs=None):
|
||||
"""Shuffle the points left-right according to the axis of symmetry
|
||||
of the object.
|
||||
|
||||
Arguments:
|
||||
parts {torch.tensor} -- a 3D or 4D object containing the
|
||||
heatmaps.
|
||||
|
||||
Keyword Arguments:
|
||||
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
||||
"""
|
||||
if pairs is None:
|
||||
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
||||
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
||||
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
||||
62, 61, 60, 67, 66, 65]
|
||||
if parts.ndimension() == 3:
|
||||
parts = parts[pairs, ...]
|
||||
else:
|
||||
parts = parts[:, pairs, ...]
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def flip(tensor, is_label=False):
|
||||
"""Flip an image or a set of heatmaps left-right
|
||||
|
||||
Arguments:
|
||||
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
||||
|
||||
Keyword Arguments:
|
||||
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
||||
"""
|
||||
if not torch.is_tensor(tensor):
|
||||
tensor = torch.from_numpy(tensor)
|
||||
|
||||
if is_label:
|
||||
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
||||
else:
|
||||
tensor = tensor.flip(tensor.ndimension() - 1)
|
||||
|
||||
return tensor
|
||||
|
||||
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
||||
|
||||
|
||||
def appdata_dir(appname=None, roaming=False):
|
||||
""" appdata_dir(appname=None, roaming=False)
|
||||
|
||||
Get the path to the application directory, where applications are allowed
|
||||
to write user specific files (e.g. configurations). For non-user specific
|
||||
data, consider using common_appdata_dir().
|
||||
If appname is given, a subdir is appended (and created if necessary).
|
||||
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
||||
"""
|
||||
|
||||
# Define default user directory
|
||||
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
||||
if userDir is None:
|
||||
userDir = os.path.expanduser('~')
|
||||
if not os.path.isdir(userDir): # pragma: no cover
|
||||
userDir = '/var/tmp' # issue #54
|
||||
|
||||
# Get system app data dir
|
||||
path = None
|
||||
if sys.platform.startswith('win'):
|
||||
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
||||
path = (path2 or path1) if roaming else (path1 or path2)
|
||||
elif sys.platform.startswith('darwin'):
|
||||
path = os.path.join(userDir, 'Library', 'Application Support')
|
||||
# On Linux and as fallback
|
||||
if not (path and os.path.isdir(path)):
|
||||
path = userDir
|
||||
|
||||
# Maybe we should store things local to the executable (in case of a
|
||||
# portable distro or a frozen application that wants to be portable)
|
||||
prefix = sys.prefix
|
||||
if getattr(sys, 'frozen', None):
|
||||
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
||||
for reldir in ('settings', '../settings'):
|
||||
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
||||
if os.path.isdir(localpath): # pragma: no cover
|
||||
try:
|
||||
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
||||
os.remove(os.path.join(localpath, 'test.write'))
|
||||
except IOError:
|
||||
pass # We cannot write in this directory
|
||||
else:
|
||||
path = localpath
|
||||
break
|
||||
|
||||
# Get path specific for this app
|
||||
if appname:
|
||||
if path == userDir:
|
||||
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
||||
path = os.path.join(path, appname)
|
||||
if not os.path.isdir(path): # pragma: no cover
|
||||
os.mkdir(path)
|
||||
|
||||
# Done
|
||||
return path
|
||||
117
models/MuseTalk/musetalk/utils/face_parsing/__init__.py
Normal file
117
models/MuseTalk/musetalk/utils/face_parsing/__init__.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import time
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from .model import BiSeNet
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
class FaceParsing():
|
||||
def __init__(self, left_cheek_width=80, right_cheek_width=80):
|
||||
self.net = self.model_init()
|
||||
self.preprocess = self.image_preprocess()
|
||||
# Ensure all size parameters are integers
|
||||
cone_height = 21
|
||||
tail_height = 12
|
||||
total_size = cone_height + tail_height
|
||||
|
||||
# Create kernel with explicit integer dimensions
|
||||
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
|
||||
center_x = total_size // 2 # Ensure center coordinates are integers
|
||||
|
||||
# Cone part
|
||||
for row in range(cone_height):
|
||||
if row < cone_height//2:
|
||||
continue
|
||||
width = int(2 * (row - cone_height//2) + 1)
|
||||
start = int(center_x - (width // 2))
|
||||
end = int(center_x + (width // 2) + 1)
|
||||
kernel[row, start:end] = 1
|
||||
|
||||
# Vertical extension part
|
||||
if cone_height > 0:
|
||||
base_width = int(kernel[cone_height-1].sum())
|
||||
else:
|
||||
base_width = 1
|
||||
|
||||
for row in range(cone_height, total_size):
|
||||
start = max(0, int(center_x - (base_width//2)))
|
||||
end = min(total_size, int(center_x + (base_width//2) + 1))
|
||||
kernel[row, start:end] = 1
|
||||
self.kernel = kernel
|
||||
|
||||
# Modify cheek erosion kernel to be flatter ellipse
|
||||
self.cheek_kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (35, 3))
|
||||
|
||||
# Add cheek area mask (protect chin area)
|
||||
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
|
||||
|
||||
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
|
||||
"""Create cheek area mask (1/4 area on both sides)"""
|
||||
mask = np.zeros((512, 512), dtype=np.uint8)
|
||||
center = 512 // 2
|
||||
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
|
||||
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
|
||||
return mask
|
||||
|
||||
def model_init(self,
|
||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||
net = BiSeNet(resnet_path)
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||
net.eval()
|
||||
return net
|
||||
|
||||
def image_preprocess(self):
|
||||
return transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, image, size=(512, 512), mode="raw"):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
width, height = image.size
|
||||
with torch.no_grad():
|
||||
image = image.resize(size, Image.BILINEAR)
|
||||
img = self.preprocess(image)
|
||||
if torch.cuda.is_available():
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
else:
|
||||
img = torch.unsqueeze(img, 0)
|
||||
out = self.net(img)[0]
|
||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||
|
||||
# Add 14:neck, remove 10:nose and 7:8:9
|
||||
if mode == "neck":
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
elif mode == "jaw":
|
||||
face_region = np.isin(parsing, [1])*255
|
||||
face_region = face_region.astype(np.uint8)
|
||||
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
|
||||
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
|
||||
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
|
||||
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
|
||||
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
|
||||
parsing[np.isin(parsing, [11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
else:
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
|
||||
parsing = Image.fromarray(parsing.astype(np.uint8))
|
||||
return parsing
|
||||
|
||||
if __name__ == "__main__":
|
||||
fp = FaceParsing()
|
||||
segmap = fp('154_small.png')
|
||||
segmap.save('res.png')
|
||||
|
||||
283
models/MuseTalk/musetalk/utils/face_parsing/model.py
Normal file
283
models/MuseTalk/musetalk/utils/face_parsing/model.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from .resnet import Resnet18
|
||||
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(in_chan,
|
||||
out_chan,
|
||||
kernel_size = ks,
|
||||
stride = stride,
|
||||
padding = padding,
|
||||
bias = False)
|
||||
self.bn = nn.BatchNorm2d(out_chan)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(self.bn(x))
|
||||
return x
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
class BiSeNetOutput(nn.Module):
|
||||
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
||||
super(BiSeNetOutput, self).__init__()
|
||||
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
||||
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class AttentionRefinementModule(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||
super(AttentionRefinementModule, self).__init__()
|
||||
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
||||
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
||||
self.bn_atten = nn.BatchNorm2d(out_chan)
|
||||
self.sigmoid_atten = nn.Sigmoid()
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
feat = self.conv(x)
|
||||
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||
atten = self.conv_atten(atten)
|
||||
atten = self.bn_atten(atten)
|
||||
atten = self.sigmoid_atten(atten)
|
||||
out = torch.mul(feat, atten)
|
||||
return out
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
|
||||
class ContextPath(nn.Module):
|
||||
def __init__(self, resnet_path, *args, **kwargs):
|
||||
super(ContextPath, self).__init__()
|
||||
self.resnet = Resnet18(resnet_path)
|
||||
self.arm16 = AttentionRefinementModule(256, 128)
|
||||
self.arm32 = AttentionRefinementModule(512, 128)
|
||||
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
H0, W0 = x.size()[2:]
|
||||
feat8, feat16, feat32 = self.resnet(x)
|
||||
H8, W8 = feat8.size()[2:]
|
||||
H16, W16 = feat16.size()[2:]
|
||||
H32, W32 = feat32.size()[2:]
|
||||
|
||||
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
||||
avg = self.conv_avg(avg)
|
||||
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
||||
|
||||
feat32_arm = self.arm32(feat32)
|
||||
feat32_sum = feat32_arm + avg_up
|
||||
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
||||
feat32_up = self.conv_head32(feat32_up)
|
||||
|
||||
feat16_arm = self.arm16(feat16)
|
||||
feat16_sum = feat16_arm + feat32_up
|
||||
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
||||
feat16_up = self.conv_head16(feat16_up)
|
||||
|
||||
return feat8, feat16_up, feat32_up # x8, x8, x16
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
### This is not used, since I replace this with the resnet feature with the same size
|
||||
class SpatialPath(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SpatialPath, self).__init__()
|
||||
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
||||
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
feat = self.conv1(x)
|
||||
feat = self.conv2(feat)
|
||||
feat = self.conv3(feat)
|
||||
feat = self.conv_out(feat)
|
||||
return feat
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||
super(FeatureFusionModule, self).__init__()
|
||||
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
||||
self.conv1 = nn.Conv2d(out_chan,
|
||||
out_chan//4,
|
||||
kernel_size = 1,
|
||||
stride = 1,
|
||||
padding = 0,
|
||||
bias = False)
|
||||
self.conv2 = nn.Conv2d(out_chan//4,
|
||||
out_chan,
|
||||
kernel_size = 1,
|
||||
stride = 1,
|
||||
padding = 0,
|
||||
bias = False)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, fsp, fcp):
|
||||
fcat = torch.cat([fsp, fcp], dim=1)
|
||||
feat = self.convblk(fcat)
|
||||
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||
atten = self.conv1(atten)
|
||||
atten = self.relu(atten)
|
||||
atten = self.conv2(atten)
|
||||
atten = self.sigmoid(atten)
|
||||
feat_atten = torch.mul(feat, atten)
|
||||
feat_out = feat_atten + feat
|
||||
return feat_out
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class BiSeNet(nn.Module):
|
||||
def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
|
||||
super(BiSeNet, self).__init__()
|
||||
self.cp = ContextPath(resnet_path)
|
||||
## here self.sp is deleted
|
||||
self.ffm = FeatureFusionModule(256, 256)
|
||||
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
||||
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
||||
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.size()[2:]
|
||||
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
||||
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
||||
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
||||
|
||||
feat_out = self.conv_out(feat_fuse)
|
||||
feat_out16 = self.conv_out16(feat_cp8)
|
||||
feat_out32 = self.conv_out32(feat_cp16)
|
||||
|
||||
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
||||
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
||||
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
||||
return feat_out, feat_out16, feat_out32
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
||||
for name, child in self.named_children():
|
||||
child_wd_params, child_nowd_params = child.get_params()
|
||||
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
||||
lr_mul_wd_params += child_wd_params
|
||||
lr_mul_nowd_params += child_nowd_params
|
||||
else:
|
||||
wd_params += child_wd_params
|
||||
nowd_params += child_nowd_params
|
||||
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = BiSeNet(19)
|
||||
net.cuda()
|
||||
net.eval()
|
||||
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
||||
out, out16, out32 = net(in_ten)
|
||||
print(out.shape)
|
||||
|
||||
net.get_params()
|
||||
109
models/MuseTalk/musetalk/utils/face_parsing/resnet.py
Normal file
109
models/MuseTalk/musetalk/utils/face_parsing/resnet.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo as modelzoo
|
||||
|
||||
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||
|
||||
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
||||
self.bn1 = nn.BatchNorm2d(out_chan)
|
||||
self.conv2 = conv3x3(out_chan, out_chan)
|
||||
self.bn2 = nn.BatchNorm2d(out_chan)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = None
|
||||
if in_chan != out_chan or stride != 1:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_chan, out_chan,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_chan),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
residual = self.conv1(x)
|
||||
residual = F.relu(self.bn1(residual))
|
||||
residual = self.conv2(residual)
|
||||
residual = self.bn2(residual)
|
||||
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x)
|
||||
|
||||
out = shortcut + residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
||||
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
||||
for i in range(bnum-1):
|
||||
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class Resnet18(nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super(Resnet18, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
||||
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
||||
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
||||
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
||||
self.init_weight(model_path)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(self.bn1(x))
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
feat8 = self.layer2(x) # 1/8
|
||||
feat16 = self.layer3(feat8) # 1/16
|
||||
feat32 = self.layer4(feat16) # 1/32
|
||||
return feat8, feat16, feat32
|
||||
|
||||
def init_weight(self, model_path):
|
||||
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
|
||||
self_state_dict = self.state_dict()
|
||||
for k, v in state_dict.items():
|
||||
if 'fc' in k: continue
|
||||
self_state_dict.update({k: v})
|
||||
self.load_state_dict(self_state_dict)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = Resnet18()
|
||||
x = torch.randn(16, 3, 224, 224)
|
||||
out = net(x)
|
||||
print(out[0].size())
|
||||
print(out[1].size())
|
||||
print(out[2].size())
|
||||
net.get_params()
|
||||
155
models/MuseTalk/musetalk/utils/preprocessing.py
Normal file
155
models/MuseTalk/musetalk/utils/preprocessing.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import sys
|
||||
from face_detection import FaceAlignment,LandmarksType
|
||||
from os import listdir, path
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pickle
|
||||
import os
|
||||
import json
|
||||
from mmpose.apis import inference_topdown, init_model
|
||||
from mmpose.structures import merge_data_samples
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
# initialize the mmpose model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
||||
model = init_model(config_file, checkpoint_file, device=device)
|
||||
|
||||
# initialize the face detection model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
|
||||
|
||||
# maker if the bbox is not sufficient
|
||||
coord_placeholder = (0.0,0.0,0.0,0.0)
|
||||
|
||||
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||
w_ratio = new_w / w
|
||||
h_ratio = new_h / h
|
||||
landmark_norm = landmark / [w, h]
|
||||
landmark_resized = landmark_norm * [new_w, new_h]
|
||||
return landmark_resized
|
||||
|
||||
def read_imgs(img_list):
|
||||
frames = []
|
||||
print('reading images...')
|
||||
for img_path in tqdm(img_list):
|
||||
frame = cv2.imread(img_path)
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
def get_bbox_range(img_list,upperbondrange =0):
|
||||
frames = read_imgs(img_list)
|
||||
batch_size_fa = 1
|
||||
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||
coords_list = []
|
||||
landmarks = []
|
||||
if upperbondrange != 0:
|
||||
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||
else:
|
||||
print('get key_landmark and face bounding boxes with the default value')
|
||||
average_range_minus = []
|
||||
average_range_plus = []
|
||||
for fb in tqdm(batches):
|
||||
results = inference_topdown(model, np.asarray(fb)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
# get bounding boxes by face detetion
|
||||
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||
|
||||
# adjust the bounding box refer to landmark
|
||||
# Add the bounding box to a tuple and append it to the coordinates list
|
||||
for j, f in enumerate(bbox):
|
||||
if f is None: # no face in the image
|
||||
coords_list += [coord_placeholder]
|
||||
continue
|
||||
|
||||
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||
average_range_minus.append(range_minus)
|
||||
average_range_plus.append(range_plus)
|
||||
if upperbondrange != 0:
|
||||
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||
|
||||
text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
|
||||
return text_range
|
||||
|
||||
|
||||
def get_landmark_and_bbox(img_list,upperbondrange =0):
|
||||
frames = read_imgs(img_list)
|
||||
batch_size_fa = 1
|
||||
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||
coords_list = []
|
||||
landmarks = []
|
||||
if upperbondrange != 0:
|
||||
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||
else:
|
||||
print('get key_landmark and face bounding boxes with the default value')
|
||||
average_range_minus = []
|
||||
average_range_plus = []
|
||||
for fb in tqdm(batches):
|
||||
results = inference_topdown(model, np.asarray(fb)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
# get bounding boxes by face detetion
|
||||
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||
|
||||
# adjust the bounding box refer to landmark
|
||||
# Add the bounding box to a tuple and append it to the coordinates list
|
||||
for j, f in enumerate(bbox):
|
||||
if f is None: # no face in the image
|
||||
coords_list += [coord_placeholder]
|
||||
continue
|
||||
|
||||
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||
average_range_minus.append(range_minus)
|
||||
average_range_plus.append(range_plus)
|
||||
if upperbondrange != 0:
|
||||
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||
half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
|
||||
min_upper_bond = 0
|
||||
upper_bond = max(min_upper_bond, half_face_coord[1] - half_face_dist)
|
||||
|
||||
f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
|
||||
x1, y1, x2, y2 = f_landmark
|
||||
|
||||
if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
|
||||
coords_list += [f]
|
||||
w,h = f[2]-f[0], f[3]-f[1]
|
||||
print("error bbox:",f)
|
||||
else:
|
||||
coords_list += [f_landmark]
|
||||
|
||||
print("********************************************bbox_shift parameter adjustment**********************************************************")
|
||||
print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
|
||||
print("*************************************************************************************************************************************")
|
||||
return coords_list,frames
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
|
||||
crop_coord_path = "./coord_face.pkl"
|
||||
coords_list,full_frames = get_landmark_and_bbox(img_list)
|
||||
with open(crop_coord_path, 'wb') as f:
|
||||
pickle.dump(coords_list, f)
|
||||
|
||||
for bbox, frame in zip(coords_list,full_frames):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
print('Cropped shape', crop_frame.shape)
|
||||
|
||||
#cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
|
||||
print(coords_list)
|
||||
337
models/MuseTalk/musetalk/utils/training_utils.py
Normal file
337
models/MuseTalk/musetalk/utils/training_utils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import WhisperModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange
|
||||
|
||||
from musetalk.models.syncnet import SyncNet
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
|
||||
from musetalk.loss.basic_loss import Interpolate
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
from musetalk.data.dataset import PortraitDataset
|
||||
from musetalk.utils.utils import (
|
||||
get_image_pred,
|
||||
process_audio_features,
|
||||
process_and_save_images
|
||||
)
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
):
|
||||
model_pred = self.unet(
|
||||
input_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_prompts
|
||||
).sample
|
||||
return model_pred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
|
||||
"""Initialize models and optimizers"""
|
||||
model_dict = {
|
||||
'vae': None,
|
||||
'unet': None,
|
||||
'net': None,
|
||||
'wav2vec': None,
|
||||
'optimizer': None,
|
||||
'lr_scheduler': None,
|
||||
'scheduler_max_steps': None,
|
||||
'trainable_params': None
|
||||
}
|
||||
|
||||
model_dict['vae'] = AutoencoderKL.from_pretrained(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
subfolder=cfg.vae_type,
|
||||
)
|
||||
|
||||
unet_config_file = os.path.join(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
cfg.unet_sub_folder + "/musetalk.json"
|
||||
)
|
||||
|
||||
with open(unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
model_dict['unet'] = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if not cfg.random_init_unet:
|
||||
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
|
||||
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
|
||||
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
|
||||
model_dict['unet'].load_state_dict(checkpoint)
|
||||
|
||||
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
|
||||
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
|
||||
|
||||
model_dict['vae'].requires_grad_(False)
|
||||
model_dict['unet'].requires_grad_(True)
|
||||
|
||||
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
model_dict['net'] = Net(model_dict['unet'])
|
||||
|
||||
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
|
||||
device="cuda", dtype=weight_dtype).eval()
|
||||
model_dict['wav2vec'].requires_grad_(False)
|
||||
|
||||
if cfg.solver.gradient_checkpointing:
|
||||
model_dict['unet'].enable_gradient_checkpointing()
|
||||
|
||||
if cfg.solver.scale_lr:
|
||||
learning_rate = (
|
||||
cfg.solver.learning_rate
|
||||
* cfg.solver.gradient_accumulation_steps
|
||||
* cfg.data.train_bs
|
||||
* accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
learning_rate = cfg.solver.learning_rate
|
||||
|
||||
if cfg.solver.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
|
||||
if accelerator.is_main_process:
|
||||
print('trainable params')
|
||||
for n, p in model_dict['net'].named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
model_dict['optimizer'] = optimizer_cls(
|
||||
model_dict['trainable_params'],
|
||||
lr=learning_rate,
|
||||
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
||||
weight_decay=cfg.solver.adam_weight_decay,
|
||||
eps=cfg.solver.adam_epsilon,
|
||||
)
|
||||
|
||||
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
|
||||
model_dict['lr_scheduler'] = get_scheduler(
|
||||
cfg.solver.lr_scheduler,
|
||||
optimizer=model_dict['optimizer'],
|
||||
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
|
||||
num_training_steps=model_dict['scheduler_max_steps'],
|
||||
)
|
||||
|
||||
return model_dict
|
||||
|
||||
def initialize_dataloaders(cfg):
|
||||
"""Initialize training and validation dataloaders"""
|
||||
dataloader_dict = {
|
||||
'train_dataset': None,
|
||||
'val_dataset': None,
|
||||
'train_dataloader': None,
|
||||
'val_dataloader': None
|
||||
}
|
||||
|
||||
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['train_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=cfg.data.num_workers,
|
||||
)
|
||||
|
||||
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['val_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
return dataloader_dict
|
||||
|
||||
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
|
||||
"""Initialize loss functions and discriminators"""
|
||||
loss_dict = {
|
||||
'L1_loss': nn.L1Loss(reduction='mean'),
|
||||
'discriminator': None,
|
||||
'mouth_discriminator': None,
|
||||
'optimizer_D': None,
|
||||
'mouth_optimizer_D': None,
|
||||
'scheduler_D': None,
|
||||
'mouth_scheduler_D': None,
|
||||
'disc_scales': None,
|
||||
'discriminator_full': None,
|
||||
'mouth_discriminator_full': None
|
||||
}
|
||||
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
loss_dict['discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
|
||||
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
|
||||
loss_dict['optimizer_D'] = optim.AdamW(
|
||||
loss_dict['discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
|
||||
loss_dict['mouth_optimizer_D'] = optim.AdamW(
|
||||
loss_dict['mouth_discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['mouth_optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
return loss_dict
|
||||
|
||||
def initialize_syncnet(cfg, accelerator, weight_dtype):
|
||||
"""Initialize SyncNet model"""
|
||||
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
|
||||
if cfg.data.n_sample_frames != 16:
|
||||
raise ValueError(
|
||||
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
|
||||
)
|
||||
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
|
||||
syncnet = SyncNet(OmegaConf.to_container(
|
||||
syncnet_config.model)).to(accelerator.device)
|
||||
print(
|
||||
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
|
||||
checkpoint = torch.load(
|
||||
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
|
||||
syncnet.load_state_dict(checkpoint["state_dict"])
|
||||
syncnet.to(dtype=weight_dtype)
|
||||
syncnet.requires_grad_(False)
|
||||
syncnet.eval()
|
||||
return syncnet
|
||||
return None
|
||||
|
||||
def initialize_vgg(cfg, accelerator):
|
||||
"""Initialize VGG model"""
|
||||
if cfg.loss_params.vgg_loss > 0:
|
||||
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
|
||||
pyramid = vgg_face.ImagePyramide(
|
||||
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(
|
||||
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
|
||||
return vgg_IN, pyramid, downsampler
|
||||
return None, None, None
|
||||
|
||||
def validation(
|
||||
cfg,
|
||||
val_dataloader,
|
||||
net,
|
||||
vae,
|
||||
wav2vec,
|
||||
accelerator,
|
||||
save_dir,
|
||||
global_step,
|
||||
weight_dtype,
|
||||
syncnet_score=1,
|
||||
):
|
||||
"""Validation function for model evaluation"""
|
||||
net.eval() # Set the model to evaluation mode
|
||||
for batch in val_dataloader:
|
||||
# The same ref_latents
|
||||
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
bsz, num_frames, c, h, w = ref_pixel_values.shape
|
||||
|
||||
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
|
||||
# audio feature for unet
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'b f c h w-> (b f) c h w'
|
||||
)
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'(b f) c h w -> (b f) (c h) w',
|
||||
b=bsz
|
||||
)
|
||||
# different masked_latents
|
||||
image_pred_train = get_image_pred(
|
||||
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
image_pred_infer = get_image_pred(
|
||||
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
|
||||
process_and_save_images(
|
||||
batch,
|
||||
image_pred_train,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
cfg.num_images_to_keep,
|
||||
syncnet_score
|
||||
)
|
||||
# only infer 1 image in validation
|
||||
break
|
||||
net.train() # Set the model back to training mode
|
||||
319
models/MuseTalk/musetalk/utils/utils.py
Normal file
319
models/MuseTalk/musetalk/utils/utils.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Union, List
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import shutil
|
||||
import os.path as osp
|
||||
|
||||
from musetalk.models.vae import VAE
|
||||
from musetalk.models.unet import UNet,PositionalEncoding
|
||||
|
||||
|
||||
def load_all_model(
|
||||
unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
|
||||
vae_type="sd-vae",
|
||||
unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
|
||||
device=None,
|
||||
):
|
||||
vae = VAE(
|
||||
model_path = os.path.join("models", vae_type),
|
||||
)
|
||||
print(f"load unet model from {unet_model_path}")
|
||||
unet = UNet(
|
||||
unet_config=unet_config,
|
||||
model_path=unet_model_path,
|
||||
device=device
|
||||
)
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
return vae, unet, pe
|
||||
|
||||
def get_file_type(video_path):
|
||||
_, ext = os.path.splitext(video_path)
|
||||
|
||||
if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||
return 'image'
|
||||
elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
|
||||
return 'video'
|
||||
else:
|
||||
return 'unsupported'
|
||||
|
||||
def get_video_fps(video_path):
|
||||
video = cv2.VideoCapture(video_path)
|
||||
fps = video.get(cv2.CAP_PROP_FPS)
|
||||
video.release()
|
||||
return fps
|
||||
|
||||
def datagen(
|
||||
whisper_chunks,
|
||||
vae_encode_latents,
|
||||
batch_size=8,
|
||||
delay_frame=0,
|
||||
device="cuda:0",
|
||||
):
|
||||
whisper_batch, latent_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(vae_encode_latents)
|
||||
latent = vae_encode_latents[idx]
|
||||
whisper_batch.append(w)
|
||||
latent_batch.append(latent)
|
||||
|
||||
if len(latent_batch) >= batch_size:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, latent_batch
|
||||
whisper_batch, latent_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(latent_batch) > 0:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch.to(device), latent_batch.to(device)
|
||||
|
||||
def cast_training_params(
|
||||
model: Union[torch.nn.Module, List[torch.nn.Module]],
|
||||
dtype=torch.float32,
|
||||
):
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
for m in model:
|
||||
for param in m.parameters():
|
||||
# only upcast trainable parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(dtype)
|
||||
|
||||
def rand_log_normal(
|
||||
shape,
|
||||
loc=0.,
|
||||
scale=1.,
|
||||
device='cpu',
|
||||
dtype=torch.float32,
|
||||
generator=None
|
||||
):
|
||||
"""Draws samples from an lognormal distribution."""
|
||||
rnd_normal = torch.randn(
|
||||
shape, device=device, dtype=dtype, generator=generator) # N(0, I)
|
||||
sigma = (rnd_normal * scale + loc).exp()
|
||||
return sigma
|
||||
|
||||
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
|
||||
# Initialize lists to store the results for each image in the batch
|
||||
mouth_real_list = []
|
||||
mouth_generated_list = []
|
||||
|
||||
# Process each image in the batch
|
||||
for b in range(frames.shape[0]):
|
||||
# Find the non-zero area in the face mask
|
||||
non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
|
||||
# If there are no non-zero indices, skip this image
|
||||
if non_zero_indices.numel() == 0:
|
||||
continue
|
||||
|
||||
min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
|
||||
non_zero_indices[:, 1])
|
||||
min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
|
||||
non_zero_indices[:, 2])
|
||||
|
||||
# Crop the frames and image_pred according to the non-zero area
|
||||
frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
|
||||
image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
|
||||
# Resize the cropped images to 256*256
|
||||
frames_resized = F.interpolate(frames_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
|
||||
# Append the resized images to the result lists
|
||||
mouth_real_list.append(frames_resized)
|
||||
mouth_generated_list.append(image_pred_resized)
|
||||
|
||||
# Convert the lists to tensors if they are not empty
|
||||
mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
|
||||
mouth_generated = torch.cat(
|
||||
mouth_generated_list, dim=0) if mouth_generated_list else None
|
||||
|
||||
return mouth_real, mouth_generated
|
||||
|
||||
def get_image_pred(pixel_values,
|
||||
ref_pixel_values,
|
||||
audio_prompts,
|
||||
vae,
|
||||
net,
|
||||
weight_dtype):
|
||||
with torch.no_grad():
|
||||
bsz, num_frames, c, h, w = pixel_values.shape
|
||||
|
||||
masked_pixel_values = pixel_values.clone()
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
|
||||
masked_frames = rearrange(
|
||||
masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
masked_latents = vae.encode(masked_frames).latent_dist.mode()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
masked_latents = masked_latents.float()
|
||||
|
||||
ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
|
||||
ref_latents = vae.encode(ref_frames).latent_dist.mode()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
ref_latents = ref_latents.float()
|
||||
|
||||
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
input_latents = input_latents.to(weight_dtype)
|
||||
timesteps = torch.tensor([0], device=input_latents.device)
|
||||
latents_pred = net(
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
)
|
||||
latents_pred = (1 / vae.config.scaling_factor) * latents_pred
|
||||
image_pred = vae.decode(latents_pred).sample
|
||||
image_pred = image_pred.float()
|
||||
|
||||
return image_pred
|
||||
|
||||
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
|
||||
with torch.no_grad():
|
||||
audio_feature_length_per_frame = 2 * \
|
||||
(cfg.data.audio_padding_length_left +
|
||||
cfg.data.audio_padding_length_right + 1)
|
||||
audio_feats = batch['audio_feature'].to(weight_dtype)
|
||||
audio_feats = wav2vec.encoder(
|
||||
audio_feats, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
|
||||
|
||||
start_ts = batch['audio_offset']
|
||||
step_ts = batch['audio_step']
|
||||
audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
|
||||
audio_feats,
|
||||
torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
|
||||
audio_prompts = []
|
||||
for bb in range(bsz):
|
||||
audio_feats_list = []
|
||||
for f in range(num_frames):
|
||||
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
|
||||
audio_clip = audio_feats[bb:bb+1,
|
||||
cur_t: cur_t+audio_feature_length_per_frame]
|
||||
|
||||
audio_feats_list.append(audio_clip)
|
||||
audio_feats_list = torch.stack(audio_feats_list, 1)
|
||||
audio_prompts.append(audio_feats_list)
|
||||
audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
|
||||
return audio_prompts
|
||||
|
||||
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
|
||||
save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
|
||||
|
||||
if total_limit is not None:
|
||||
checkpoints = os.listdir(save_dir)
|
||||
checkpoints = [d for d in checkpoints if d.endswith(".pth")]
|
||||
checkpoints = [d for d in checkpoints if name in d]
|
||||
checkpoints = sorted(
|
||||
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(checkpoints) >= total_limit:
|
||||
num_to_remove = len(checkpoints) - total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(
|
||||
f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(
|
||||
save_dir, removing_checkpoint)
|
||||
os.remove(removing_checkpoint)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
|
||||
unwarp_net = accelerator.unwrap_model(net)
|
||||
save_checkpoint(
|
||||
unwarp_net.unet,
|
||||
save_dir,
|
||||
global_step,
|
||||
name="unet",
|
||||
total_limit=cfg.total_limit,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
def delete_additional_ckpt(base_path, num_keep):
|
||||
dirs = []
|
||||
for d in os.listdir(base_path):
|
||||
if d.startswith("checkpoint-"):
|
||||
dirs.append(d)
|
||||
num_tot = len(dirs)
|
||||
if num_tot <= num_keep:
|
||||
return
|
||||
# ensure ckpt is sorted and delete the ealier!
|
||||
del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
|
||||
for d in del_dirs:
|
||||
path_to_dir = osp.join(base_path, d)
|
||||
if osp.exists(path_to_dir):
|
||||
shutil.rmtree(path_to_dir)
|
||||
|
||||
def seed_everything(seed):
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed % (2**32))
|
||||
random.seed(seed)
|
||||
|
||||
def process_and_save_images(
|
||||
batch,
|
||||
image_pred,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
num_images_to_keep=10,
|
||||
syncnet_score=1
|
||||
):
|
||||
# Rearrange the tensors
|
||||
print("image_pred.shape: ", image_pred.shape)
|
||||
pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
|
||||
pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Create masked pixel values
|
||||
masked_pixel_values = batch["pixel_values_vid"].clone()
|
||||
_, _, _, h, _ = batch["pixel_values_vid"].shape
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Keep only the specified number of images
|
||||
pixel_values = pixel_values[:num_images_to_keep, :, :, :]
|
||||
masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
|
||||
pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
|
||||
image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
|
||||
image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
|
||||
|
||||
# Concatenate images
|
||||
concat = torch.cat([
|
||||
masked_pixel_values * 0.5 + 0.5,
|
||||
pixel_values_ref_img * 0.5 + 0.5,
|
||||
image_pred * 0.5 + 0.5,
|
||||
pixel_values * 0.5 + 0.5,
|
||||
image_pred_infer * 0.5 + 0.5,
|
||||
], dim=2)
|
||||
print("concat.shape: ", concat.shape)
|
||||
|
||||
# Create the save directory if it doesn't exist
|
||||
os.makedirs(f'{save_dir}/samples/', exist_ok=True)
|
||||
|
||||
# Try to save the concatenated image
|
||||
try:
|
||||
# Concatenate images horizontally and convert to numpy array
|
||||
final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
|
||||
# Save the image
|
||||
cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
|
||||
print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
|
||||
except Exception as e:
|
||||
print(f"Failed to save image: {e}")
|
||||
128
models/MuseTalk/musetalk/whisper/audio2feature.py
Normal file
128
models/MuseTalk/musetalk/whisper/audio2feature.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
from .whisper import load_model
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
|
||||
class Audio2Feature():
|
||||
def __init__(self,
|
||||
whisper_model_type="tiny",
|
||||
model_path="./models/whisper/tiny.pt"):
|
||||
self.whisper_model_type = whisper_model_type
|
||||
self.model = load_model(model_path) #
|
||||
|
||||
def get_sliced_feature(self,
|
||||
feature_array,
|
||||
vid_idx,
|
||||
audio_feat_length=[2,2],
|
||||
fps=25):
|
||||
"""
|
||||
Get sliced features based on a given index
|
||||
:param feature_array:
|
||||
:param start_idx: the start index of the feature
|
||||
:param audio_feat_length:
|
||||
:return:
|
||||
"""
|
||||
length = len(feature_array)
|
||||
selected_feature = []
|
||||
selected_idx = []
|
||||
|
||||
center_idx = int(vid_idx*50/fps)
|
||||
left_idx = center_idx-audio_feat_length[0]*2
|
||||
right_idx = center_idx + (audio_feat_length[1]+1)*2
|
||||
|
||||
for idx in range(left_idx,right_idx):
|
||||
idx = max(0, idx)
|
||||
idx = min(length-1, idx)
|
||||
x = feature_array[idx]
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(idx)
|
||||
|
||||
selected_feature = np.concatenate(selected_feature, axis=0)
|
||||
selected_feature = selected_feature.reshape(-1, 384)# 50*384
|
||||
return selected_feature,selected_idx
|
||||
|
||||
def get_sliced_feature_sparse(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
|
||||
"""
|
||||
Get sliced features based on a given index
|
||||
:param feature_array:
|
||||
:param start_idx: the start index of the feature
|
||||
:param audio_feat_length:
|
||||
:return:
|
||||
"""
|
||||
length = len(feature_array)
|
||||
selected_feature = []
|
||||
selected_idx = []
|
||||
|
||||
for dt in range(-audio_feat_length[0],audio_feat_length[1]+1):
|
||||
left_idx = int((vid_idx+dt)*50/fps)
|
||||
if left_idx<1 or left_idx>length-1:
|
||||
left_idx = max(0, left_idx)
|
||||
left_idx = min(length-1, left_idx)
|
||||
|
||||
x = feature_array[left_idx]
|
||||
x = x[np.newaxis,:,:]
|
||||
x = np.repeat(x, 2, axis=0)
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(left_idx)
|
||||
selected_idx.append(left_idx)
|
||||
else:
|
||||
x = feature_array[left_idx-1:left_idx+1]
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(left_idx-1)
|
||||
selected_idx.append(left_idx)
|
||||
selected_feature = np.concatenate(selected_feature, axis=0)
|
||||
selected_feature = selected_feature.reshape(-1, 384)# 50*384
|
||||
return selected_feature,selected_idx
|
||||
|
||||
|
||||
def feature2chunks(self,feature_array,fps,audio_feat_length = [2,2]):
|
||||
whisper_chunks = []
|
||||
whisper_idx_multiplier = 50./fps
|
||||
i = 0
|
||||
print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||
while 1:
|
||||
start_idx = int(i * whisper_idx_multiplier)
|
||||
selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i,audio_feat_length=audio_feat_length,fps=fps)
|
||||
#print(f"i:{i},selected_idx {selected_idx}")
|
||||
whisper_chunks.append(selected_feature)
|
||||
i += 1
|
||||
if start_idx>len(feature_array):
|
||||
break
|
||||
|
||||
return whisper_chunks
|
||||
|
||||
def audio2feat(self,audio_path):
|
||||
# get the sample rate of the audio
|
||||
result = self.model.transcribe(audio_path)
|
||||
embed_list = []
|
||||
for emb in result['segments']:
|
||||
encoder_embeddings = emb['encoder_embeddings']
|
||||
encoder_embeddings = encoder_embeddings.transpose(0,2,1,3)
|
||||
encoder_embeddings = encoder_embeddings.squeeze(0)
|
||||
start_idx = int(emb['start'])
|
||||
end_idx = int(emb['end'])
|
||||
emb_end_idx = int((end_idx - start_idx)/2)
|
||||
embed_list.append(encoder_embeddings[:emb_end_idx])
|
||||
concatenated_array = np.concatenate(embed_list, axis=0)
|
||||
return concatenated_array
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_processor = Audio2Feature(model_path="../../models/whisper/whisper_tiny.pt")
|
||||
audio_path = "./test.mp3"
|
||||
array = audio_processor.audio2feat(audio_path)
|
||||
print(array.shape)
|
||||
fps = 25
|
||||
whisper_idx_multiplier = 50./fps
|
||||
|
||||
i = 0
|
||||
print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||
while 1:
|
||||
start_idx = int(i * whisper_idx_multiplier)
|
||||
selected_feature,selected_idx = audio_processor.get_sliced_feature(feature_array= array,vid_idx = i,audio_feat_length=[2,2],fps=fps)
|
||||
print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
|
||||
i += 1
|
||||
if start_idx>len(array):
|
||||
break
|
||||
116
models/MuseTalk/musetalk/whisper/whisper/__init__.py
Normal file
116
models/MuseTalk/musetalk/whisper/whisper/__init__.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .transcribe import transcribe
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model.")
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.getenv(
|
||||
"XDG_CACHE_HOME",
|
||||
os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with (io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")) as fp:
|
||||
checkpoint = torch.load(fp, map_location=device)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
return model.to(device)
|
||||
4
models/MuseTalk/musetalk/whisper/whisper/__main__.py
Normal file
4
models/MuseTalk/musetalk/whisper/whisper/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .transcribe import cli
|
||||
|
||||
|
||||
cli()
|
||||
50001
models/MuseTalk/musetalk/whisper/whisper/assets/gpt2/merges.txt
Normal file
50001
models/MuseTalk/musetalk/whisper/whisper/assets/gpt2/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -0,0 +1 @@
|
||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
BIN
models/MuseTalk/musetalk/whisper/whisper/assets/mel_filters.npz
Normal file
BIN
models/MuseTalk/musetalk/whisper/whisper/assets/mel_filters.npz
Normal file
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{"<|endoftext|>": 50257}
|
||||
50000
models/MuseTalk/musetalk/whisper/whisper/assets/multilingual/merges.txt
Normal file
50000
models/MuseTalk/musetalk/whisper/whisper/assets/multilingual/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -0,0 +1 @@
|
||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
125
models/MuseTalk/musetalk/whisper/whisper/audio.py
Normal file
125
models/MuseTalk/musetalk/whisper/whisper/audio.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
|
||||
sr: int
|
||||
The sample rate to resample the audio if necessary
|
||||
|
||||
Returns
|
||||
-------
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
try:
|
||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||
)
|
||||
except ffmpeg.Error as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length))
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
||||
else:
|
||||
if array.shape[axis] > length:
|
||||
array = array.take(indices=range(length), axis=axis)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = np.pad(array, pad_widths)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
|
||||
magnitudes = stft[:, :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
729
models/MuseTalk/musetalk/whisper/whisper/decoding.py
Normal file
729
models/MuseTalk/musetalk/whisper/whisper/decoding.py
Normal file
@@ -0,0 +1,729 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : Tensor, shape = (n_audio,)
|
||||
ids of the most probable language tokens, which appears after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||
|
||||
# options for ranking generations (either beams or best-of-N samples)
|
||||
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||
|
||||
# prompt, prefix, and token suppression
|
||||
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: Tensor
|
||||
language: str
|
||||
encoder_embeddings: np.ndarray
|
||||
decoder_embeddings: np.ndarray
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rearrange_kv_cache(self, source_indices) -> None:
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_caching(self) -> None:
|
||||
"""Clean up any resources or hooks after decoding is finished"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchInference(Inference):
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
||||
if not self.kv_cache:
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
return_val = self.model.decoder(tokens, audio_features,
|
||||
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
||||
return return_val
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
for module, tensor in self.kv_cache.items():
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[module] = tensor[source_indices].detach()
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
True if all sequences has reached the end of text
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
temperature = self.temperature
|
||||
if temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||
|
||||
completed = (tokens[:, -1] == self.eot).all()
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None: # for the first update
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
|
||||
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens[idx].tolist()
|
||||
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||
sequence = tuple(prefix + [token.item()])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
|
||||
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
|
||||
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||
self.inference.rearrange_kv_cache(source_indices)
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
if len(sequences) >= self.beam_size:
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
]
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||
"""Apply any filtering or masking to logits in-place
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int]):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
logits[:, self.suppress_tokens] = -np.inf
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
self.decoder = BeamSearchDecoder(
|
||||
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
)
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
prefix = self.options.prefix
|
||||
prompt = self.options.prompt
|
||||
|
||||
if prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||
)
|
||||
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
result = self.model.encoder(mel, include_embeddings)
|
||||
if include_embeddings:
|
||||
audio_features, embeddings = result
|
||||
else:
|
||||
audio_features = result
|
||||
|
||||
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||
|
||||
if include_embeddings:
|
||||
return audio_features, embeddings
|
||||
else:
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||
assert audio_features.shape[0] == tokens.shape[0]
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
embeddings = []
|
||||
for i in range(self.sample_len):
|
||||
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
|
||||
|
||||
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
token_embeddings = token_embeddings[:, :, -1]
|
||||
|
||||
# Append embeddings together
|
||||
embeddings.append(token_embeddings)
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||
for logit_filter in self.logit_filters:
|
||||
logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
if completed:
|
||||
embeddings = embeddings[:-1]
|
||||
embeddings = np.stack(embeddings, 2)
|
||||
self.inference.cleanup_caching()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs, embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
# encoder forward pass
|
||||
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
|
||||
audio_features, encoder_embeddings = forward_pass
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||
]
|
||||
|
||||
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
|
||||
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||
|
||||
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
encoder_embeddings=encoder_embeddings,
|
||||
decoder_embeddings=decoder_embeddings
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||
A tensor containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
if single:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
290
models/MuseTalk/musetalk/whisper/whisper/model.py
Normal file
290
models/MuseTalk/musetalk/whisper/whisper/model.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache.get(self.key, self.key(xa))
|
||||
v = kv_cache.get(self.value, self.value(xa))
|
||||
|
||||
wv = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv)
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
|
||||
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor, include_embeddings: bool = False):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
include_embeddings: bool
|
||||
whether to include intermediate steps in the output
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = [x.cpu().detach().numpy()]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
if include_embeddings:
|
||||
embeddings.append(x.cpu().detach().numpy())
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = np.stack(embeddings, axis=1)
|
||||
return x, embeddings
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
||||
the encoded audio features to be attended on
|
||||
include_embeddings : bool
|
||||
Whether to include intermediate values in the output to this function
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = [x.cpu().detach().numpy()]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
if include_embeddings:
|
||||
embeddings.append(x.cpu().detach().numpy())
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = np.stack(embeddings, axis=1)
|
||||
return logits, embeddings
|
||||
else:
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder.forward(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder.forward(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab == 51865
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||
cache[module] = output # save as-is, for the first token or cross attention
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
||||
@@ -0,0 +1,2 @@
|
||||
from .basic import BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer
|
||||
@@ -0,0 +1,71 @@
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
"""
|
||||
Replace any other markers, symbols, and punctuations with a space,
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
|
||||
"""
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = self.clean(s).lower()
|
||||
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
1742
models/MuseTalk/musetalk/whisper/whisper/normalizers/english.json
Normal file
1742
models/MuseTalk/musetalk/whisper/whisper/normalizers/english.json
Normal file
File diff suppressed because it is too large
Load Diff
543
models/MuseTalk/musetalk/whisper/whisper/normalizers/english.py
Normal file
543
models/MuseTalk/musetalk/whisper/whisper/normalizers/english.py
Normal file
@@ -0,0 +1,543 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
331
models/MuseTalk/musetalk/whisper/whisper/tokenizer.py
Normal file
331
models/MuseTalk/musetalk/whisper/whisper/tokenizer.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"iw": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||
|
||||
tokenizer: "GPT2TokenizerFast"
|
||||
language: Optional[str]
|
||||
sot_sequence: Tuple[int]
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
outputs = [[]]
|
||||
for token in tokens:
|
||||
if token >= self.timestamp_begin:
|
||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||
outputs.append(timestamp)
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
)
|
||||
)
|
||||
candidate = f"<|{self.language}|>"
|
||||
if candidate in additional_tokens:
|
||||
return additional_tokens[candidate]
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
):
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def _get_single_token_id(self, text) -> int:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||
|
||||
specials = [
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
]
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||
return tokenizer
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
language: Optional[str] = None,
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
tokenizer_name = "multilingual"
|
||||
task = task or "transcribe"
|
||||
language = language or "en"
|
||||
else:
|
||||
tokenizer_name = "gpt2"
|
||||
task = None
|
||||
language = None
|
||||
|
||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||
sot: int = all_special_ids[1]
|
||||
translate: int = all_special_ids[-6]
|
||||
transcribe: int = all_special_ids[-5]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(language))
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
||||
207
models/MuseTalk/musetalk/whisper/whisper/transcribe.py
Normal file
207
models/MuseTalk/musetalk/whisper/whisper/transcribe.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
force_extraction: bool = False,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
all_segments = []
|
||||
def add_segment(
|
||||
*, start: float, end: float, encoder_embeddings
|
||||
):
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"start": start,
|
||||
"end": end,
|
||||
"encoder_embeddings":encoder_embeddings,
|
||||
}
|
||||
)
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
seek = 0
|
||||
previous_seek_value = seek
|
||||
sample_skip = 3000 #
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
# seek是开始的帧数
|
||||
end_seek = min(seek + sample_skip, num_frames)
|
||||
segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
single = segment.ndim == 2
|
||||
if single:
|
||||
segment = segment.unsqueeze(0)
|
||||
if dtype == torch.float16:
|
||||
segment = segment.half()
|
||||
audio_features, embeddings = model.encoder(segment, include_embeddings = True)
|
||||
|
||||
encoder_embeddings = embeddings
|
||||
#print(f"encoder_embeddings shape {encoder_embeddings.shape}")
|
||||
add_segment(
|
||||
start=seek,
|
||||
end=end_seek,
|
||||
#text_tokens=tokens,
|
||||
#result=result,
|
||||
encoder_embeddings=encoder_embeddings,
|
||||
)
|
||||
seek+=sample_skip
|
||||
|
||||
return dict(segments=all_segments)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||
if temperature_increment_on_fallback is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
threads = args.pop("threads")
|
||||
if threads > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
|
||||
# save TXT
|
||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||
write_txt(result["segments"], file=txt)
|
||||
|
||||
# save VTT
|
||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||
write_vtt(result["segments"], file=vtt)
|
||||
|
||||
# save SRT
|
||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
87
models/MuseTalk/musetalk/whisper/whisper/utils.py
Normal file
87
models/MuseTalk/musetalk/whisper/whisper/utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import zlib
|
||||
from typing import Iterator, TextIO
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
assert x % y == 0
|
||||
return x // y
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
else:
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
|
||||
def optional_int(string):
|
||||
return None if string == "None" else int(string)
|
||||
|
||||
|
||||
def optional_float(string):
|
||||
return None if string == "None" else float(string)
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
|
||||
|
||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||
for segment in transcript:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in transcript:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
"""
|
||||
Write a transcript to a file in SRT format.
|
||||
|
||||
Example usage:
|
||||
from pathlib import Path
|
||||
from whisper.utils import write_srt
|
||||
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
# save SRT
|
||||
audio_basename = Path(audio_path).stem
|
||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
"""
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
157
models/MuseTalk/musetalk_api.py
Normal file
157
models/MuseTalk/musetalk_api.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
MuseTalk API 服务
|
||||
|
||||
这个脚本将 MuseTalk 封装为 FastAPI 服务,
|
||||
可以独立部署在 GPU 服务器上。
|
||||
|
||||
用法:
|
||||
python musetalk_api.py --port 8001
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import tempfile
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import uvicorn
|
||||
|
||||
# 添加 MuseTalk 路径
|
||||
MUSETALK_DIR = Path(__file__).parent
|
||||
sys.path.insert(0, str(MUSETALK_DIR))
|
||||
|
||||
app = FastAPI(
|
||||
title="MuseTalk API",
|
||||
description="唇形同步推理服务",
|
||||
version="0.1.0"
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局模型实例 (懒加载)
|
||||
_model = None
|
||||
|
||||
|
||||
def get_model():
|
||||
"""懒加载 MuseTalk 模型"""
|
||||
global _model
|
||||
if _model is None:
|
||||
print("🔄 加载 MuseTalk 模型...")
|
||||
# TODO: 根据 MuseTalk 实际 API 调整
|
||||
# from musetalk.inference import MuseTalkInference
|
||||
# _model = MuseTalkInference()
|
||||
print("✅ MuseTalk 模型加载完成")
|
||||
return _model
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"name": "MuseTalk API", "status": "ok"}
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "gpu": True}
|
||||
|
||||
|
||||
@app.post("/lipsync")
|
||||
async def lipsync(
|
||||
video: UploadFile = File(..., description="输入视频文件"),
|
||||
audio: UploadFile = File(..., description="音频文件"),
|
||||
fps: int = Form(25, description="输出帧率")
|
||||
):
|
||||
"""
|
||||
唇形同步推理
|
||||
|
||||
Args:
|
||||
video: 输入视频 (静态人物)
|
||||
audio: 驱动音频
|
||||
fps: 输出帧率
|
||||
|
||||
Returns:
|
||||
生成的视频文件
|
||||
"""
|
||||
# 创建临时目录
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# 保存上传的文件
|
||||
video_path = tmpdir / "input_video.mp4"
|
||||
audio_path = tmpdir / "input_audio.wav"
|
||||
output_path = tmpdir / "output.mp4"
|
||||
|
||||
with open(video_path, "wb") as f:
|
||||
shutil.copyfileobj(video.file, f)
|
||||
with open(audio_path, "wb") as f:
|
||||
shutil.copyfileobj(audio.file, f)
|
||||
|
||||
try:
|
||||
# 执行唇形同步
|
||||
model = get_model()
|
||||
|
||||
# TODO: 调用实际的 MuseTalk 推理
|
||||
# result = model.inference(
|
||||
# source_video=str(video_path),
|
||||
# driving_audio=str(audio_path),
|
||||
# output_path=str(output_path),
|
||||
# fps=fps
|
||||
# )
|
||||
|
||||
# 临时: 使用 subprocess 调用 MuseTalk CLI
|
||||
import subprocess
|
||||
cmd = [
|
||||
sys.executable, "-m", "scripts.inference",
|
||||
"--video_path", str(video_path),
|
||||
"--audio_path", str(audio_path),
|
||||
"--output_path", str(output_path),
|
||||
]
|
||||
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(MUSETALK_DIR),
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"MuseTalk 推理失败: {result.stderr}")
|
||||
|
||||
if not output_path.exists():
|
||||
raise RuntimeError("输出文件不存在")
|
||||
|
||||
# 返回生成的视频
|
||||
# 需要先复制到持久化位置
|
||||
final_output = Path("outputs") / f"lipsync_{video.filename}"
|
||||
final_output.parent.mkdir(exist_ok=True)
|
||||
shutil.copy(output_path, final_output)
|
||||
|
||||
return FileResponse(
|
||||
final_output,
|
||||
media_type="video/mp4",
|
||||
filename=f"lipsync_{video.filename}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--port", type=int, default=8001)
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"🚀 MuseTalk API 启动在 http://{args.host}:{args.port}")
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
20
models/MuseTalk/requirements.txt
Normal file
20
models/MuseTalk/requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
diffusers==0.30.2
|
||||
accelerate==0.28.0
|
||||
numpy==1.23.5
|
||||
tensorflow==2.12.0
|
||||
tensorboard==2.12.0
|
||||
opencv-python==4.9.0.80
|
||||
soundfile==0.12.1
|
||||
transformers==4.39.2
|
||||
huggingface_hub==0.30.2
|
||||
librosa==0.11.0
|
||||
einops==0.8.1
|
||||
gradio==5.24.0
|
||||
|
||||
gdown
|
||||
requests
|
||||
imageio[ffmpeg]
|
||||
|
||||
omegaconf
|
||||
ffmpeg-python
|
||||
moviepy
|
||||
1
models/MuseTalk/scripts/__init__.py
Normal file
1
models/MuseTalk/scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
314
models/MuseTalk/scripts/inference.py
Normal file
314
models/MuseTalk/scripts/inference.py
Normal file
@@ -0,0 +1,314 @@
|
||||
import os
|
||||
import cv2
|
||||
import math
|
||||
import copy
|
||||
import torch
|
||||
import glob
|
||||
import shutil
|
||||
import pickle
|
||||
import argparse
|
||||
import numpy as np
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import WhisperModel
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Try imports, handle if running as script vs module
|
||||
try:
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
except ImportError:
|
||||
# If running from root directory
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def run_ffmpeg(cmd):
|
||||
print(f"Executing: {cmd}")
|
||||
try:
|
||||
# Use shell=True to support the command string format used
|
||||
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing ffmpeg: {cmd}")
|
||||
print(f"Return code: {e.returncode}")
|
||||
print(f"Stdout: {e.stdout}")
|
||||
print(f"Stderr: {e.stderr}")
|
||||
return False
|
||||
|
||||
@torch.no_grad()
|
||||
def main(args):
|
||||
# Configure ffmpeg path
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
if args.use_float16:
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
# Initialize components
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else:
|
||||
fp = FaceParsing()
|
||||
|
||||
# TASK CONFIGURATION
|
||||
if args.video_path and args.audio_path:
|
||||
print(f"Using command line arguments. Video: {args.video_path}, Audio: {args.audio_path}")
|
||||
inference_config = {
|
||||
"task_cmd": {
|
||||
"video_path": args.video_path,
|
||||
"audio_path": args.audio_path
|
||||
}
|
||||
}
|
||||
if args.output_path:
|
||||
args.output_vid_name = args.output_path
|
||||
else:
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print("Loaded inference config:", inference_config)
|
||||
|
||||
for task_id in inference_config:
|
||||
try:
|
||||
video_path = inference_config[task_id]["video_path"]
|
||||
audio_path = inference_config[task_id]["audio_path"]
|
||||
if "result_name" in inference_config[task_id]:
|
||||
args.output_vid_name = inference_config[task_id]["result_name"]
|
||||
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0
|
||||
else:
|
||||
bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
os.makedirs(result_img_save_path, exist_ok=True)
|
||||
|
||||
if args.output_vid_name is None:
|
||||
output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
|
||||
else:
|
||||
if os.path.isabs(args.output_vid_name) or "/" in args.output_vid_name or "\\" in args.output_vid_name:
|
||||
output_vid_name = args.output_vid_name
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
|
||||
if get_file_type(video_path) == "video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
os.makedirs(save_dir_full, exist_ok=True)
|
||||
cmd = f"ffmpeg -y -v warning -i {video_path} -start_number 0 {save_dir_full}/%08d.png"
|
||||
if not run_ffmpeg(cmd):
|
||||
raise RuntimeError("FFmpeg failed to extract frames")
|
||||
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
elif get_file_type(video_path) == "image":
|
||||
input_img_list = [video_path]
|
||||
fps = args.fps
|
||||
elif os.path.isdir(video_path):
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
else:
|
||||
raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
|
||||
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features, device, weight_dtype, whisper, librosa_length, fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("Using saved coordinates")
|
||||
with open(crop_coord_save_path, 'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("Extracting landmarks...")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
|
||||
print(f"Number of frames: {len(frame_list)}")
|
||||
sys.stdout.flush()
|
||||
|
||||
print("Processing latents...")
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
print(f"Starting inference with {len(input_latent_list)} latents...")
|
||||
sys.stdout.flush()
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
total = int(np.ceil(float(video_num) / batch_size))
|
||||
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
print(f"Inference complete. Generated {len(res_frame_list)} frames. Padding to original size...")
|
||||
sys.stdout.flush()
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list, disable=True)): # Disable tqdm to avoid output issues
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
if args.version == "v15":
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
else:
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
|
||||
|
||||
# VIDEO SYNTHESIS
|
||||
temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
|
||||
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {temp_vid_path}"
|
||||
print(f"Generating Video from {len(res_frame_list)} frames...")
|
||||
sys.stdout.flush()
|
||||
if not run_ffmpeg(cmd_img2video):
|
||||
print(f"FAILED to generate video from frames at {result_img_save_path}. Keeping frames.")
|
||||
continue # Skip to next task or stop
|
||||
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {temp_vid_path} {output_vid_name}"
|
||||
print("Combining Audio...")
|
||||
if not run_ffmpeg(cmd_combine_audio):
|
||||
print(f"FAILED to combine audio. Temp video at {temp_vid_path}.")
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
print("Cleaning up temporary files...")
|
||||
try:
|
||||
shutil.rmtree(result_img_save_path)
|
||||
os.remove(temp_vid_path)
|
||||
shutil.rmtree(save_dir_full)
|
||||
if not args.saved_coord:
|
||||
os.remove(crop_coord_save_path)
|
||||
except Exception as e:
|
||||
print(f"Warning: Cleanup failed: {e}")
|
||||
|
||||
print(f"Results saved to {output_vid_name}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\n=== ERROR OCCURRED ===")
|
||||
print(f"Exception type: {type(e).__name__}")
|
||||
print(f"Exception message: {e}")
|
||||
print(f"Full traceback:\n{traceback.format_exc()}")
|
||||
print(f"=== END ERROR ===")
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
|
||||
|
||||
# NEW ARGUMENTS
|
||||
parser.add_argument("--video_path", type=str, default=None, help="Input video path")
|
||||
parser.add_argument("--audio_path", type=str, default=None, help="Input audio path")
|
||||
parser.add_argument("--output_path", type=str, default=None, help="Output video path (alias for output_vid_name)")
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
334
models/MuseTalk/scripts/preprocess.py
Normal file
334
models/MuseTalk/scripts/preprocess.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from typing import Tuple, List, Union
|
||||
import decord
|
||||
import json
|
||||
import cv2
|
||||
from musetalk.utils.face_detection import FaceAlignment,LandmarksType
|
||||
from mmpose.apis import inference_topdown, init_model
|
||||
from mmpose.structures import merge_data_samples
|
||||
import sys
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
ffmpeg_path = "./ffmpeg-4.4-amd64-static/"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
class AnalyzeFace:
|
||||
def __init__(self, device: Union[str, torch.device], config_file: str, checkpoint_file: str):
|
||||
"""
|
||||
Initialize the AnalyzeFace class with the given device, config file, and checkpoint file.
|
||||
|
||||
Parameters:
|
||||
device (Union[str, torch.device]): The device to run the models on ('cuda' or 'cpu').
|
||||
config_file (str): Path to the mmpose model configuration file.
|
||||
checkpoint_file (str): Path to the mmpose model checkpoint file.
|
||||
"""
|
||||
self.device = device
|
||||
self.dwpose = init_model(config_file, checkpoint_file, device=self.device)
|
||||
self.facedet = FaceAlignment(LandmarksType._2D, flip_input=False, device=self.device)
|
||||
|
||||
def __call__(self, im: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
|
||||
"""
|
||||
Detect faces and keypoints in the given image.
|
||||
|
||||
Parameters:
|
||||
im (np.ndarray): The input image.
|
||||
maxface (bool): Whether to detect the maximum face. Default is True.
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], np.ndarray]: A tuple containing the bounding boxes and keypoints.
|
||||
"""
|
||||
try:
|
||||
# Ensure the input image has the correct shape
|
||||
if im.ndim == 3:
|
||||
im = np.expand_dims(im, axis=0)
|
||||
elif im.ndim != 4 or im.shape[0] != 1:
|
||||
raise ValueError("Input image must have shape (1, H, W, C)")
|
||||
|
||||
bbox = self.facedet.get_detections_for_batch(np.asarray(im))
|
||||
results = inference_topdown(self.dwpose, np.asarray(im)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
return face_land_mark, bbox
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error during face analysis: {e}")
|
||||
return np.array([]),[]
|
||||
|
||||
def convert_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
|
||||
"""
|
||||
Convert video files to a specified format and save them to the destination path.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the converted video files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
org_vid_path = os.path.join(org_path, vid)
|
||||
dst_vid_path = os.path.join(dst_path, vid)
|
||||
|
||||
if org_vid_path != dst_vid_path:
|
||||
cmd = [
|
||||
"ffmpeg", "-hide_banner", "-y", "-i", org_vid_path,
|
||||
"-r", "25", "-crf", "15", "-c:v", "libx264",
|
||||
"-pix_fmt", "yuv420p", dst_vid_path
|
||||
]
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
if idx % 1000 == 0:
|
||||
print(f"### {idx} videos converted ###")
|
||||
|
||||
def segment_video(org_path: str, dst_path: str, vid_list: List[str], segment_duration: int = 30) -> None:
|
||||
"""
|
||||
Segment video files into smaller clips of specified duration.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the segmented video files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
segment_duration (int): The duration of each segment in seconds. Default is 30 seconds.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
input_file = os.path.join(org_path, vid)
|
||||
original_filename = os.path.basename(input_file)
|
||||
|
||||
command = [
|
||||
'ffmpeg', '-i', input_file, '-c', 'copy', '-map', '0',
|
||||
'-segment_time', str(segment_duration), '-f', 'segment',
|
||||
'-reset_timestamps', '1',
|
||||
os.path.join(dst_path, f'clip%03d_{original_filename}')
|
||||
]
|
||||
|
||||
subprocess.run(command, check=True)
|
||||
|
||||
def extract_audio(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
"""
|
||||
Extract audio from video files and save as WAV format.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the extracted audio files will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for idx, vid in enumerate(vid_list):
|
||||
if vid.endswith('.mp4'):
|
||||
video_path = os.path.join(org_path, vid)
|
||||
audio_output_path = os.path.join(dst_path, os.path.splitext(vid)[0] + ".wav")
|
||||
try:
|
||||
command = [
|
||||
'ffmpeg', '-hide_banner', '-y', '-i', video_path,
|
||||
'-vn', '-acodec', 'pcm_s16le', '-f', 'wav',
|
||||
'-ar', '16000', '-ac', '1', audio_output_path,
|
||||
]
|
||||
|
||||
subprocess.run(command, check=True)
|
||||
print(f"Audio saved to: {audio_output_path}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error extracting audio from {vid}: {e}")
|
||||
|
||||
def split_data(video_files: List[str], val_list_hdtf: List[str]) -> (List[str], List[str]):
|
||||
"""
|
||||
Split video files into training and validation sets based on val_list_hdtf.
|
||||
|
||||
Parameters:
|
||||
video_files (List[str]): A list of video file names.
|
||||
val_list_hdtf (List[str]): A list of validation file identifiers.
|
||||
|
||||
Returns:
|
||||
(List[str], List[str]): A tuple containing the training and validation file lists.
|
||||
"""
|
||||
val_files = [f for f in video_files if any(val_id in f for val_id in val_list_hdtf)]
|
||||
train_files = [f for f in video_files if f not in val_files]
|
||||
return train_files, val_files
|
||||
|
||||
def save_list_to_file(file_path: str, data_list: List[str]) -> None:
|
||||
"""
|
||||
Save a list of strings to a file, each string on a new line.
|
||||
|
||||
Parameters:
|
||||
file_path (str): The path to the file where the list will be saved.
|
||||
data_list (List[str]): The list of strings to save.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
with open(file_path, 'w') as file:
|
||||
for item in data_list:
|
||||
file.write(f"{item}\n")
|
||||
|
||||
def generate_train_list(cfg):
|
||||
train_file_path = cfg.video_clip_file_list_train
|
||||
val_file_path = cfg.video_clip_file_list_val
|
||||
val_list_hdtf = cfg.val_list_hdtf
|
||||
|
||||
meta_list = os.listdir(cfg.meta_root)
|
||||
|
||||
sorted_meta_list = sorted(meta_list)
|
||||
train_files, val_files = split_data(meta_list, val_list_hdtf)
|
||||
|
||||
save_list_to_file(train_file_path, train_files)
|
||||
save_list_to_file(val_file_path, val_files)
|
||||
|
||||
print(val_list_hdtf)
|
||||
|
||||
def analyze_video(org_path: str, dst_path: str, vid_list: List[str]) -> None:
|
||||
"""
|
||||
Convert video files to a specified format and save them to the destination path.
|
||||
|
||||
Parameters:
|
||||
org_path (str): The directory containing the original video files.
|
||||
dst_path (str): The directory where the meta json will be saved.
|
||||
vid_list (List[str]): A list of video file names to process.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
||||
|
||||
analyze_face = AnalyzeFace(device, config_file, checkpoint_file)
|
||||
|
||||
for vid in tqdm(vid_list, desc="Processing videos"):
|
||||
#vid = "clip005_WDA_BernieSanders_000.mp4"
|
||||
#print(vid)
|
||||
if vid.endswith('.mp4'):
|
||||
vid_path = os.path.join(org_path, vid)
|
||||
wav_path = vid_path.replace(".mp4",".wav")
|
||||
vid_meta = os.path.join(dst_path, os.path.splitext(vid)[0] + ".json")
|
||||
if os.path.exists(vid_meta):
|
||||
continue
|
||||
print('process video {}'.format(vid))
|
||||
|
||||
total_bbox_list = []
|
||||
total_pts_list = []
|
||||
isvalid = True
|
||||
|
||||
# process
|
||||
try:
|
||||
cap = decord.VideoReader(vid_path, fault_tol=1)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
continue
|
||||
|
||||
total_frames = len(cap)
|
||||
for frame_idx in range(total_frames):
|
||||
frame = cap[frame_idx]
|
||||
if frame_idx==0:
|
||||
video_height,video_width,_ = frame.shape
|
||||
frame_bgr = cv2.cvtColor(frame.asnumpy(), cv2.COLOR_BGR2RGB)
|
||||
pts_list, bbox_list = analyze_face(frame_bgr)
|
||||
|
||||
if len(bbox_list)>0 and None not in bbox_list:
|
||||
bbox = bbox_list[0]
|
||||
else:
|
||||
isvalid = False
|
||||
bbox = []
|
||||
print(f"set isvalid to False as broken img in {frame_idx} of {vid}")
|
||||
break
|
||||
|
||||
#print(pts_list)
|
||||
if len(pts_list)>0 and pts_list is not None:
|
||||
pts = pts_list.tolist()
|
||||
else:
|
||||
isvalid = False
|
||||
pts = []
|
||||
break
|
||||
|
||||
if frame_idx==0:
|
||||
x1,y1,x2,y2 = bbox
|
||||
face_height, face_width = y2-y1,x2-x1
|
||||
|
||||
total_pts_list.append(pts)
|
||||
total_bbox_list.append(bbox)
|
||||
|
||||
meta_data = {
|
||||
"mp4_path": vid_path,
|
||||
"wav_path": wav_path,
|
||||
"video_size": [video_height, video_width],
|
||||
"face_size": [face_height, face_width],
|
||||
"frames": total_frames,
|
||||
"face_list": total_bbox_list,
|
||||
"landmark_list": total_pts_list,
|
||||
"isvalid":isvalid,
|
||||
}
|
||||
with open(vid_meta, 'w') as f:
|
||||
json.dump(meta_data, f, indent=4)
|
||||
|
||||
|
||||
|
||||
def main(cfg):
|
||||
# Ensure all necessary directories exist
|
||||
os.makedirs(cfg.video_root_25fps, exist_ok=True)
|
||||
os.makedirs(cfg.video_audio_clip_root, exist_ok=True)
|
||||
os.makedirs(cfg.meta_root, exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_file_list), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_clip_file_list_train), exist_ok=True)
|
||||
os.makedirs(os.path.dirname(cfg.video_clip_file_list_val), exist_ok=True)
|
||||
|
||||
vid_list = os.listdir(cfg.video_root_raw)
|
||||
sorted_vid_list = sorted(vid_list)
|
||||
|
||||
# Save video file list
|
||||
with open(cfg.video_file_list, 'w') as file:
|
||||
for vid in sorted_vid_list:
|
||||
file.write(vid + '\n')
|
||||
|
||||
# 1. Convert videos to 25 FPS
|
||||
convert_video(cfg.video_root_raw, cfg.video_root_25fps, sorted_vid_list)
|
||||
|
||||
# 2. Segment videos into 30-second clips
|
||||
segment_video(cfg.video_root_25fps, cfg.video_audio_clip_root, vid_list, segment_duration=cfg.clip_len_second)
|
||||
|
||||
# 3. Extract audio
|
||||
clip_vid_list = os.listdir(cfg.video_audio_clip_root)
|
||||
extract_audio(cfg.video_audio_clip_root, cfg.video_audio_clip_root, clip_vid_list)
|
||||
|
||||
# 4. Generate video metadata
|
||||
analyze_video(cfg.video_audio_clip_root, cfg.meta_root, clip_vid_list)
|
||||
|
||||
# 5. Generate training and validation set lists
|
||||
generate_train_list(cfg)
|
||||
print("done")
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./configs/training/preprocess.yaml")
|
||||
args = parser.parse_args()
|
||||
config = OmegaConf.load(args.config)
|
||||
|
||||
main(config)
|
||||
|
||||
409
models/MuseTalk/scripts/realtime_inference.py
Normal file
409
models/MuseTalk/scripts/realtime_inference.py
Normal file
@@ -0,0 +1,409 @@
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
import sys
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
import json
|
||||
from transformers import WhisperModel
|
||||
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.utils import datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs
|
||||
from musetalk.utils.blending import get_image_prepare_material, get_image_blending
|
||||
from musetalk.utils.utils import load_all_model
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
|
||||
import shutil
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import subprocess
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
|
||||
cap = cv2.VideoCapture(vid_path)
|
||||
count = 0
|
||||
while True:
|
||||
if count > cut_frame:
|
||||
break
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
|
||||
count += 1
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
def osmakedirs(path_list):
|
||||
for path in path_list:
|
||||
os.makedirs(path) if not os.path.exists(path) else None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
class Avatar:
|
||||
def __init__(self, avatar_id, video_path, bbox_shift, batch_size, preparation):
|
||||
self.avatar_id = avatar_id
|
||||
self.video_path = video_path
|
||||
self.bbox_shift = bbox_shift
|
||||
# 根据版本设置不同的基础路径
|
||||
if args.version == "v15":
|
||||
self.base_path = f"./results/{args.version}/avatars/{avatar_id}"
|
||||
else: # v1
|
||||
self.base_path = f"./results/avatars/{avatar_id}"
|
||||
|
||||
self.avatar_path = self.base_path
|
||||
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
|
||||
self.coords_path = f"{self.avatar_path}/coords.pkl"
|
||||
self.latents_out_path = f"{self.avatar_path}/latents.pt"
|
||||
self.video_out_path = f"{self.avatar_path}/vid_output/"
|
||||
self.mask_out_path = f"{self.avatar_path}/mask"
|
||||
self.mask_coords_path = f"{self.avatar_path}/mask_coords.pkl"
|
||||
self.avatar_info_path = f"{self.avatar_path}/avator_info.json"
|
||||
self.avatar_info = {
|
||||
"avatar_id": avatar_id,
|
||||
"video_path": video_path,
|
||||
"bbox_shift": bbox_shift,
|
||||
"version": args.version
|
||||
}
|
||||
self.preparation = preparation
|
||||
self.batch_size = batch_size
|
||||
self.idx = 0
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
if self.preparation:
|
||||
if os.path.exists(self.avatar_path):
|
||||
response = input(f"{self.avatar_id} exists, Do you want to re-create it ? (y/n)")
|
||||
if response.lower() == "y":
|
||||
shutil.rmtree(self.avatar_path)
|
||||
print("*********************************")
|
||||
print(f" creating avator: {self.avatar_id}")
|
||||
print("*********************************")
|
||||
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
|
||||
self.prepare_material()
|
||||
else:
|
||||
self.input_latent_list_cycle = torch.load(self.latents_out_path)
|
||||
with open(self.coords_path, 'rb') as f:
|
||||
self.coord_list_cycle = pickle.load(f)
|
||||
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
self.frame_list_cycle = read_imgs(input_img_list)
|
||||
with open(self.mask_coords_path, 'rb') as f:
|
||||
self.mask_coords_list_cycle = pickle.load(f)
|
||||
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
self.mask_list_cycle = read_imgs(input_mask_list)
|
||||
else:
|
||||
print("*********************************")
|
||||
print(f" creating avator: {self.avatar_id}")
|
||||
print("*********************************")
|
||||
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
|
||||
self.prepare_material()
|
||||
else:
|
||||
if not os.path.exists(self.avatar_path):
|
||||
print(f"{self.avatar_id} does not exist, you should set preparation to True")
|
||||
sys.exit()
|
||||
|
||||
with open(self.avatar_info_path, "r") as f:
|
||||
avatar_info = json.load(f)
|
||||
|
||||
if avatar_info['bbox_shift'] != self.avatar_info['bbox_shift']:
|
||||
response = input(f" 【bbox_shift】 is changed, you need to re-create it ! (c/continue)")
|
||||
if response.lower() == "c":
|
||||
shutil.rmtree(self.avatar_path)
|
||||
print("*********************************")
|
||||
print(f" creating avator: {self.avatar_id}")
|
||||
print("*********************************")
|
||||
osmakedirs([self.avatar_path, self.full_imgs_path, self.video_out_path, self.mask_out_path])
|
||||
self.prepare_material()
|
||||
else:
|
||||
sys.exit()
|
||||
else:
|
||||
self.input_latent_list_cycle = torch.load(self.latents_out_path)
|
||||
with open(self.coords_path, 'rb') as f:
|
||||
self.coord_list_cycle = pickle.load(f)
|
||||
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
self.frame_list_cycle = read_imgs(input_img_list)
|
||||
with open(self.mask_coords_path, 'rb') as f:
|
||||
self.mask_coords_list_cycle = pickle.load(f)
|
||||
input_mask_list = glob.glob(os.path.join(self.mask_out_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_mask_list = sorted(input_mask_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
self.mask_list_cycle = read_imgs(input_mask_list)
|
||||
|
||||
def prepare_material(self):
|
||||
print("preparing data materials ... ...")
|
||||
with open(self.avatar_info_path, "w") as f:
|
||||
json.dump(self.avatar_info, f)
|
||||
|
||||
if os.path.isfile(self.video_path):
|
||||
video2imgs(self.video_path, self.full_imgs_path, ext='png')
|
||||
else:
|
||||
print(f"copy files in {self.video_path}")
|
||||
files = os.listdir(self.video_path)
|
||||
files.sort()
|
||||
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||
for filename in files:
|
||||
shutil.copyfile(f"{self.video_path}/{filename}", f"{self.full_imgs_path}/{filename}")
|
||||
input_img_list = sorted(glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]')))
|
||||
|
||||
print("extracting landmarks...")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, self.bbox_shift)
|
||||
input_latent_list = []
|
||||
idx = -1
|
||||
# maker if the bbox is not sufficient
|
||||
coord_placeholder = (0.0, 0.0, 0.0, 0.0)
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
idx = idx + 1
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if args.version == "v15":
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
coord_list[idx] = [x1, y1, x2, y2] # 更新coord_list中的bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
resized_crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(resized_crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
self.frame_list_cycle = frame_list + frame_list[::-1]
|
||||
self.coord_list_cycle = coord_list + coord_list[::-1]
|
||||
self.input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
self.mask_coords_list_cycle = []
|
||||
self.mask_list_cycle = []
|
||||
|
||||
for i, frame in enumerate(tqdm(self.frame_list_cycle)):
|
||||
cv2.imwrite(f"{self.full_imgs_path}/{str(i).zfill(8)}.png", frame)
|
||||
|
||||
x1, y1, x2, y2 = self.coord_list_cycle[i]
|
||||
if args.version == "v15":
|
||||
mode = args.parsing_mode
|
||||
else:
|
||||
mode = "raw"
|
||||
mask, crop_box = get_image_prepare_material(frame, [x1, y1, x2, y2], fp=fp, mode=mode)
|
||||
|
||||
cv2.imwrite(f"{self.mask_out_path}/{str(i).zfill(8)}.png", mask)
|
||||
self.mask_coords_list_cycle += [crop_box]
|
||||
self.mask_list_cycle.append(mask)
|
||||
|
||||
with open(self.mask_coords_path, 'wb') as f:
|
||||
pickle.dump(self.mask_coords_list_cycle, f)
|
||||
|
||||
with open(self.coords_path, 'wb') as f:
|
||||
pickle.dump(self.coord_list_cycle, f)
|
||||
|
||||
torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
|
||||
|
||||
def process_frames(self, res_frame_queue, video_len, skip_save_images):
|
||||
print(video_len)
|
||||
while True:
|
||||
if self.idx >= video_len - 1:
|
||||
break
|
||||
try:
|
||||
start = time.time()
|
||||
res_frame = res_frame_queue.get(block=True, timeout=1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
bbox = self.coord_list_cycle[self.idx % (len(self.coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(self.frame_list_cycle[self.idx % (len(self.frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
except:
|
||||
continue
|
||||
mask = self.mask_list_cycle[self.idx % (len(self.mask_list_cycle))]
|
||||
mask_crop_box = self.mask_coords_list_cycle[self.idx % (len(self.mask_coords_list_cycle))]
|
||||
combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
|
||||
|
||||
if skip_save_images is False:
|
||||
cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png", combine_frame)
|
||||
self.idx = self.idx + 1
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, audio_path, out_vid_name, fps, skip_save_images):
|
||||
os.makedirs(self.avatar_path + '/tmp', exist_ok=True)
|
||||
print("start inference")
|
||||
############################################## extract audio feature ##############################################
|
||||
start_time = time.time()
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path, weight_dtype=weight_dtype)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
|
||||
############################################## inference batch by batch ##############################################
|
||||
video_num = len(whisper_chunks)
|
||||
res_frame_queue = queue.Queue()
|
||||
self.idx = 0
|
||||
# Create a sub-thread and start it
|
||||
process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
|
||||
process_thread.start()
|
||||
|
||||
gen = datagen(whisper_chunks,
|
||||
self.input_latent_list_cycle,
|
||||
self.batch_size)
|
||||
start_time = time.time()
|
||||
res_frame_list = []
|
||||
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=int(np.ceil(float(video_num) / self.batch_size)))):
|
||||
audio_feature_batch = pe(whisper_batch.to(device))
|
||||
latent_batch = latent_batch.to(device=device, dtype=unet.model.dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_feature_batch).sample
|
||||
pred_latents = pred_latents.to(device=device, dtype=vae.vae.dtype)
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_queue.put(res_frame)
|
||||
# Close the queue and sub-thread after all tasks are completed
|
||||
process_thread.join()
|
||||
|
||||
if args.skip_save_images is True:
|
||||
print('Total process time of {} frames without saving images = {}s'.format(
|
||||
video_num,
|
||||
time.time() - start_time))
|
||||
else:
|
||||
print('Total process time of {} frames including saving images = {}s'.format(
|
||||
video_num,
|
||||
time.time() - start_time))
|
||||
|
||||
if out_vid_name is not None and args.skip_save_images is False:
|
||||
# optional
|
||||
cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
|
||||
print(cmd_img2video)
|
||||
os.system(cmd_img2video)
|
||||
|
||||
output_vid = os.path.join(self.video_out_path, out_vid_name + ".mp4") # on
|
||||
cmd_combine_audio = f"ffmpeg -y -v warning -i {audio_path} -i {self.avatar_path}/temp.mp4 {output_vid}"
|
||||
print(cmd_combine_audio)
|
||||
os.system(cmd_combine_audio)
|
||||
|
||||
os.remove(f"{self.avatar_path}/temp.mp4")
|
||||
shutil.rmtree(f"{self.avatar_path}/tmp")
|
||||
print(f"result is save to {output_vid}")
|
||||
print("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
|
||||
'''
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Version of MuseTalk: v1 or v15")
|
||||
parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
|
||||
parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
|
||||
parser.add_argument("--unet_config", type=str, default="./models/musetalk/musetalk.json", help="Path to UNet configuration file")
|
||||
parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
|
||||
parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
|
||||
parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
|
||||
parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
|
||||
parser.add_argument("--result_dir", default='./results', help="Directory for output results")
|
||||
parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
|
||||
parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
|
||||
parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
|
||||
parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
|
||||
parser.add_argument("--batch_size", type=int, default=20, help="Batch size for inference")
|
||||
parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
|
||||
parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
|
||||
parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
|
||||
parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
|
||||
parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
|
||||
parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
|
||||
parser.add_argument("--skip_save_images",
|
||||
action="store_true",
|
||||
help="Whether skip saving images for better generation speed calculation",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure ffmpeg path
|
||||
if not fast_check_ffmpeg():
|
||||
print("Adding ffmpeg to PATH")
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Set computing device
|
||||
device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model weights
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=args.unet_model_path,
|
||||
vae_type=args.vae_type,
|
||||
unet_config=args.unet_config,
|
||||
device=device
|
||||
)
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
pe = pe.half().to(device)
|
||||
vae.vae = vae.vae.half().to(device)
|
||||
unet.model = unet.model.half().to(device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(args.whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# Initialize face parser with configurable parameters based on version
|
||||
if args.version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
else: # v1
|
||||
fp = FaceParsing()
|
||||
|
||||
inference_config = OmegaConf.load(args.inference_config)
|
||||
print(inference_config)
|
||||
|
||||
for avatar_id in inference_config:
|
||||
data_preparation = inference_config[avatar_id]["preparation"]
|
||||
video_path = inference_config[avatar_id]["video_path"]
|
||||
if args.version == "v15":
|
||||
bbox_shift = 0
|
||||
else:
|
||||
bbox_shift = inference_config[avatar_id]["bbox_shift"]
|
||||
avatar = Avatar(
|
||||
avatar_id=avatar_id,
|
||||
video_path=video_path,
|
||||
bbox_shift=bbox_shift,
|
||||
batch_size=args.batch_size,
|
||||
preparation=data_preparation)
|
||||
|
||||
audio_clips = inference_config[avatar_id]["audio_clips"]
|
||||
for audio_num, audio_path in audio_clips.items():
|
||||
print("Inferring using:", audio_path)
|
||||
avatar.inference(audio_path,
|
||||
audio_num,
|
||||
args.fps,
|
||||
args.skip_save_images)
|
||||
572
models/MuseTalk/scripts/server.py
Normal file
572
models/MuseTalk/scripts/server.py
Normal file
@@ -0,0 +1,572 @@
|
||||
"""
|
||||
MuseTalk v1.5 常驻推理服务 (优化版 v2)
|
||||
- 端口: 8011
|
||||
- GPU: 从 backend/.env 读取 MUSETALK_GPU_ID (默认 0)
|
||||
- 架构: FastAPI + lifespan (与 LatentSync server.py 同模式)
|
||||
|
||||
优化项 (vs v1):
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
|
||||
4. cv2.VideoWriter 直写视频 (跳过逐帧 PNG 写盘)
|
||||
5. batch_size 8→32
|
||||
6. 每阶段计时
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import copy
|
||||
import time
|
||||
import glob
|
||||
import shutil
|
||||
import tempfile
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
# --- 自动加载 GPU 配置 (必须在 torch 导入前) ---
|
||||
def load_gpu_config():
|
||||
"""尝试从后端 .env 文件读取 MUSETALK_GPU_ID"""
|
||||
try:
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
env_path = current_dir.parent.parent.parent / "backend" / ".env"
|
||||
|
||||
target_gpu = "0" # 默认 GPU 0
|
||||
|
||||
if env_path.exists():
|
||||
print(f"📖 读取配置文件: {env_path}")
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("MUSETALK_GPU_ID="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
target_gpu = val
|
||||
print(f"⚙️ 发现配置 MUSETALK_GPU_ID={target_gpu}")
|
||||
break
|
||||
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
|
||||
print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
|
||||
else:
|
||||
print(f"ℹ️ 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")
|
||||
|
||||
load_gpu_config()
|
||||
|
||||
# --- 性能优化: 限制 CPU 线程数 ---
|
||||
os.environ["OMP_NUM_THREADS"] = "8"
|
||||
os.environ["MKL_NUM_THREADS"] = "8"
|
||||
os.environ["TORCH_NUM_THREADS"] = "8"
|
||||
print("⚙️ 已限制 PyTorch CPU 线程数为 8,防止系统卡顿")
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from transformers import WhisperModel
|
||||
|
||||
# 添加项目根目录到 sys.path (MuseTalk 根目录)
|
||||
musetalk_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.insert(0, str(musetalk_root))
|
||||
|
||||
from musetalk.utils.blending import get_image, get_image_blending, get_image_prepare_material
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
|
||||
# --- 从 .env 读取额外配置 ---
|
||||
def load_env_config():
|
||||
"""读取 MuseTalk 相关环境变量"""
|
||||
config = {
|
||||
"batch_size": 32,
|
||||
"version": "v15",
|
||||
"use_float16": True,
|
||||
}
|
||||
try:
|
||||
env_path = musetalk_root.parent.parent / "backend" / ".env"
|
||||
if env_path.exists():
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("MUSETALK_BATCH_SIZE="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["batch_size"] = int(val)
|
||||
elif line.startswith("MUSETALK_VERSION="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["version"] = val
|
||||
elif line.startswith("MUSETALK_USE_FLOAT16="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip().lower()
|
||||
config["use_float16"] = val in ("true", "1", "yes")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 读取额外配置失败: {e}")
|
||||
return config
|
||||
|
||||
env_config = load_env_config()
|
||||
|
||||
# 全局模型缓存
|
||||
models = {}
|
||||
|
||||
# ===================== 优化参数 =====================
|
||||
DETECT_EVERY = 5 # 人脸检测降频: 每 N 帧检测一次
|
||||
BLEND_CACHE_EVERY = 5 # BiSeNet mask 缓存: 每 N 帧更新一次
|
||||
# ====================================================
|
||||
|
||||
|
||||
def run_ffmpeg(cmd):
|
||||
"""执行 FFmpeg 命令"""
|
||||
print(f"Executing: {cmd}")
|
||||
try:
|
||||
result = subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"Error executing ffmpeg: {cmd}")
|
||||
print(f"Return code: {e.returncode}")
|
||||
print(f"Stderr: {e.stderr[:500]}")
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""启动时加载所有模型,只做一次"""
|
||||
print("⏳ 正在加载 MuseTalk v1.5 模型...")
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
version = env_config["version"]
|
||||
use_float16 = env_config["use_float16"]
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
print(f"🖥️ 正在使用 GPU: {gpu_name}")
|
||||
else:
|
||||
print("⚠️ 警告: 未检测到 GPU,将使用 CPU 进行推理 (速度极慢)")
|
||||
|
||||
# 根据版本选择模型路径
|
||||
models_dir = musetalk_root / "models"
|
||||
if version == "v15":
|
||||
unet_model_path = str(models_dir / "musetalkV15" / "unet.pth")
|
||||
unet_config = str(models_dir / "musetalk" / "config.json")
|
||||
else:
|
||||
unet_model_path = str(models_dir / "musetalk" / "pytorch_model.bin")
|
||||
unet_config = str(models_dir / "musetalk" / "musetalk.json")
|
||||
|
||||
# 切换工作目录(load_all_model 使用相对路径加载 VAE)
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(str(musetalk_root))
|
||||
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path=unet_model_path,
|
||||
vae_type="sd-vae",
|
||||
unet_config=unet_config,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if use_float16 and torch.cuda.is_available():
|
||||
print("⚡ 使用 float16 半精度加速")
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
# Whisper
|
||||
whisper_dir = str(models_dir / "whisper")
|
||||
audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
|
||||
weight_dtype = unet.model.dtype
|
||||
whisper = WhisperModel.from_pretrained(whisper_dir)
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# FaceParsing
|
||||
if version == "v15":
|
||||
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
|
||||
else:
|
||||
fp = FaceParsing()
|
||||
|
||||
# 恢复工作目录
|
||||
os.chdir(original_cwd)
|
||||
|
||||
models["vae"] = vae
|
||||
models["unet"] = unet
|
||||
models["pe"] = pe
|
||||
models["whisper"] = whisper
|
||||
models["audio_processor"] = audio_processor
|
||||
models["fp"] = fp
|
||||
models["device"] = device
|
||||
models["weight_dtype"] = weight_dtype
|
||||
models["version"] = version
|
||||
models["timesteps"] = torch.tensor([0], device=device)
|
||||
|
||||
print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
|
||||
print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
|
||||
f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}")
|
||||
yield
|
||||
models.clear()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
class LipSyncRequest(BaseModel):
|
||||
video_path: str
|
||||
audio_path: str
|
||||
video_out_path: str
|
||||
batch_size: int = 32
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
return {"status": "ok", "model_loaded": "unet" in models}
|
||||
|
||||
|
||||
@app.post("/lipsync")
|
||||
async def generate_lipsync(req: LipSyncRequest):
|
||||
if "unet" not in models:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
if not os.path.exists(req.video_path):
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
|
||||
if not os.path.exists(req.audio_path):
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")
|
||||
|
||||
print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = _run_inference(req)
|
||||
elapsed = time.time() - start_time
|
||||
print(f"✅ 推理完成,耗时 {elapsed:.1f}s ({elapsed/60:.1f}min)")
|
||||
return result
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# 降频人脸检测: 每 N 帧检测一次, 中间帧线性插值 bbox
|
||||
# =====================================================================
|
||||
def _detect_faces_subsampled(frames, detect_every=5):
|
||||
"""
|
||||
降频人脸检测:
|
||||
- 每 detect_every 帧运行 DWPose + FaceAlignment
|
||||
- 中间帧线性插值 bbox 坐标
|
||||
- 对于口播视频 (人脸几乎不动), 插值误差可忽略
|
||||
"""
|
||||
from mmpose.apis import inference_topdown
|
||||
from mmpose.structures import merge_data_samples
|
||||
import musetalk.utils.preprocessing as _prep
|
||||
|
||||
n = len(frames)
|
||||
if n == 0:
|
||||
return []
|
||||
|
||||
# 确定需要检测的帧索引
|
||||
sampled_indices = list(range(0, n, detect_every))
|
||||
if sampled_indices[-1] != n - 1:
|
||||
sampled_indices.append(n - 1)
|
||||
|
||||
print(f" 检测 {len(sampled_indices)}/{n} 帧 (每{detect_every}帧)")
|
||||
|
||||
# 在采样帧上运行检测
|
||||
detected = {}
|
||||
for idx in tqdm(sampled_indices, desc="人脸检测"):
|
||||
frame = frames[idx]
|
||||
try:
|
||||
results = inference_topdown(_prep.model, frame)
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark = keypoints[0][23:91].astype(np.int32)
|
||||
|
||||
bbox_list = _prep.fa.get_detections_for_batch(np.array([frame]))
|
||||
|
||||
if bbox_list[0] is None:
|
||||
detected[idx] = coord_placeholder
|
||||
continue
|
||||
|
||||
half_face_coord = face_land_mark[29].copy()
|
||||
half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
|
||||
upper_bond = max(0, half_face_coord[1] - half_face_dist)
|
||||
|
||||
f_landmark = (
|
||||
int(np.min(face_land_mark[:, 0])),
|
||||
int(upper_bond),
|
||||
int(np.max(face_land_mark[:, 0])),
|
||||
int(np.max(face_land_mark[:, 1])),
|
||||
)
|
||||
x1, y1, x2, y2 = f_landmark
|
||||
if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
|
||||
detected[idx] = bbox_list[0] if bbox_list[0] is not None else coord_placeholder
|
||||
else:
|
||||
detected[idx] = f_landmark
|
||||
except Exception as e:
|
||||
print(f"⚠️ 帧 {idx} 检测失败: {e}")
|
||||
detected[idx] = coord_placeholder
|
||||
|
||||
# 插值填充所有帧
|
||||
coord_list = [None] * n
|
||||
for idx in sampled_indices:
|
||||
coord_list[idx] = detected[idx]
|
||||
|
||||
for i in range(n):
|
||||
if coord_list[i] is not None:
|
||||
continue
|
||||
|
||||
# 找前后已检测的帧
|
||||
prev_idx = max(j for j in sampled_indices if j < i)
|
||||
next_idx = min(j for j in sampled_indices if j > i)
|
||||
|
||||
prev_bbox = detected[prev_idx]
|
||||
next_bbox = detected[next_idx]
|
||||
|
||||
if prev_bbox == coord_placeholder and next_bbox == coord_placeholder:
|
||||
coord_list[i] = coord_placeholder
|
||||
elif prev_bbox == coord_placeholder:
|
||||
coord_list[i] = next_bbox
|
||||
elif next_bbox == coord_placeholder:
|
||||
coord_list[i] = prev_bbox
|
||||
else:
|
||||
alpha = (i - prev_idx) / (next_idx - prev_idx)
|
||||
coord_list[i] = tuple(
|
||||
int(a * (1 - alpha) + b * alpha)
|
||||
for a, b in zip(prev_bbox, next_bbox)
|
||||
)
|
||||
|
||||
return coord_list
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# 核心推理 (优化版)
|
||||
# =====================================================================
|
||||
@torch.no_grad()
|
||||
def _run_inference(req: LipSyncRequest) -> dict:
|
||||
"""
|
||||
优化版推理逻辑:
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧, 中间插值)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新)
|
||||
4. cv2.VideoWriter 直写 (跳过逐帧 PNG)
|
||||
5. 每阶段计时
|
||||
"""
|
||||
vae = models["vae"]
|
||||
unet = models["unet"]
|
||||
pe = models["pe"]
|
||||
whisper = models["whisper"]
|
||||
audio_processor = models["audio_processor"]
|
||||
fp = models["fp"]
|
||||
device = models["device"]
|
||||
weight_dtype = models["weight_dtype"]
|
||||
version = models["version"]
|
||||
timesteps = models["timesteps"]
|
||||
batch_size = req.batch_size or env_config["batch_size"]
|
||||
|
||||
video_path = req.video_path
|
||||
audio_path = req.audio_path
|
||||
output_vid_path = req.video_out_path
|
||||
|
||||
os.makedirs(os.path.dirname(output_vid_path), exist_ok=True)
|
||||
|
||||
t_total = time.time()
|
||||
timings = {}
|
||||
|
||||
# ===== Phase 1: 读取视频帧 (cv2.VideoCapture, 跳过 ffmpeg→PNG) =====
|
||||
t0 = time.time()
|
||||
if get_file_type(video_path) == "video":
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
|
||||
frames = []
|
||||
while True:
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
frames.append(frame)
|
||||
cap.release()
|
||||
elif get_file_type(video_path) == "image":
|
||||
frames = [cv2.imread(video_path)]
|
||||
fps = 25.0
|
||||
else:
|
||||
raise ValueError(f"不支持的文件类型: {video_path}")
|
||||
|
||||
timings["1_read"] = time.time() - t0
|
||||
print(f"📹 读取 {len(frames)} 帧, FPS={fps} [{timings['1_read']:.1f}s]")
|
||||
|
||||
if not frames:
|
||||
raise RuntimeError("视频帧为空")
|
||||
|
||||
# ===== Phase 2: Whisper 音频特征 =====
|
||||
t0 = time.time()
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features, device, weight_dtype, whisper, librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
)
|
||||
timings["2_whisper"] = time.time() - t0
|
||||
print(f"🎵 Whisper 特征 [{timings['2_whisper']:.1f}s]")
|
||||
|
||||
# ===== Phase 3: 人脸检测 (降频) =====
|
||||
t0 = time.time()
|
||||
coord_list = _detect_faces_subsampled(frames, detect_every=DETECT_EVERY)
|
||||
timings["3_face"] = time.time() - t0
|
||||
print(f"🔍 人脸检测 [{timings['3_face']:.1f}s]")
|
||||
|
||||
# ===== Phase 4: VAE 潜空间编码 =====
|
||||
t0 = time.time()
|
||||
input_latent_list = []
|
||||
extra_margin = 10
|
||||
for bbox, frame in zip(coord_list, frames):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if version == "v15":
|
||||
y2 = min(y2 + extra_margin, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
timings["4_vae"] = time.time() - t0
|
||||
print(f"🧠 VAE 编码 [{timings['4_vae']:.1f}s]")
|
||||
|
||||
# ===== Phase 5: UNet 批量推理 =====
|
||||
t0 = time.time()
|
||||
|
||||
# 循环帧序列 (引用, 不复制数据)
|
||||
frame_list_cycle = frames + frames[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
video_num = len(whisper_chunks)
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
total_batches = int(np.ceil(float(video_num) / batch_size))
|
||||
print(f"🚀 推理: {video_num} 帧, batch={batch_size}, {total_batches} 批")
|
||||
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total_batches)):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
||||
pred_latents = unet.model(
|
||||
latent_batch, timesteps,
|
||||
encoder_hidden_states=audio_feature_batch
|
||||
).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
timings["5_unet"] = time.time() - t0
|
||||
print(f"✅ UNet 推理: {len(res_frame_list)} 帧 [{timings['5_unet']:.1f}s]")
|
||||
|
||||
# ===== Phase 6: 合成 (缓存 BiSeNet mask + cv2.VideoWriter) =====
|
||||
t0 = time.time()
|
||||
|
||||
h, w = frames[0].shape[:2]
|
||||
temp_raw_path = output_vid_path + ".raw.mp4"
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
writer = cv2.VideoWriter(temp_raw_path, fourcc, fps, (w, h))
|
||||
|
||||
if not writer.isOpened():
|
||||
raise RuntimeError(f"cv2.VideoWriter 打开失败: {temp_raw_path}")
|
||||
|
||||
cached_mask = None
|
||||
cached_crop_box = None
|
||||
blend_mode = "jaw" if version == "v15" else "raw"
|
||||
|
||||
for i in tqdm(range(len(res_frame_list)), desc="合成"):
|
||||
res_frame = res_frame_list[i]
|
||||
bbox = coord_list_cycle[i % len(coord_list_cycle)]
|
||||
ori_frame = frame_list_cycle[i % len(frame_list_cycle)].copy()
|
||||
|
||||
x1, y1, x2, y2 = bbox
|
||||
if version == "v15":
|
||||
y2 = min(y2 + extra_margin, ori_frame.shape[0])
|
||||
adjusted_bbox = (x1, y1, x2, y2)
|
||||
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
except Exception:
|
||||
writer.write(ori_frame)
|
||||
continue
|
||||
|
||||
# 每 N 帧更新 BiSeNet 人脸解析 mask, 其余帧复用缓存
|
||||
if i % BLEND_CACHE_EVERY == 0 or cached_mask is None:
|
||||
try:
|
||||
cached_mask, cached_crop_box = get_image_prepare_material(
|
||||
ori_frame, adjusted_bbox, mode=blend_mode, fp=fp)
|
||||
except Exception:
|
||||
# 如果 prepare 失败, 用完整方式
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
writer.write(combine_frame)
|
||||
continue
|
||||
|
||||
try:
|
||||
combine_frame = get_image_blending(
|
||||
ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
|
||||
except Exception:
|
||||
# blending 失败时 fallback 到完整方式
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
|
||||
writer.write(combine_frame)
|
||||
|
||||
writer.release()
|
||||
timings["6_blend"] = time.time() - t0
|
||||
print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
|
||||
|
||||
# ===== Phase 7: FFmpeg 重编码 H.264 + 合并音频 =====
|
||||
t0 = time.time()
|
||||
cmd = (
|
||||
f"ffmpeg -y -v warning -i {temp_raw_path} -i {audio_path} "
|
||||
f"-c:v libx264 -crf 18 -pix_fmt yuv420p "
|
||||
f"-c:a copy -shortest {output_vid_path}"
|
||||
)
|
||||
if not run_ffmpeg(cmd):
|
||||
raise RuntimeError("FFmpeg 重编码+音频合并失败")
|
||||
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_raw_path):
|
||||
os.unlink(temp_raw_path)
|
||||
|
||||
timings["7_encode"] = time.time() - t0
|
||||
print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
|
||||
|
||||
# ===== 汇总 =====
|
||||
total_time = time.time() - t_total
|
||||
print(f"\n⏱️ 总耗时: {total_time:.1f}s ({total_time/60:.1f}min)")
|
||||
for k, v in timings.items():
|
||||
pct = v / total_time * 100
|
||||
print(f" {k}: {v:.1f}s ({pct:.0f}%)")
|
||||
|
||||
if not os.path.exists(output_vid_path):
|
||||
raise RuntimeError("输出文件未生成")
|
||||
|
||||
return {"status": "success", "output_path": output_vid_path}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8011)
|
||||
33
models/MuseTalk/test_ffmpeg.py
Normal file
33
models/MuseTalk/test_ffmpeg.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
def test_ffmpeg(ffmpeg_path):
|
||||
print(f"Testing ffmpeg path: {ffmpeg_path}")
|
||||
|
||||
# Choose path separator based on operating system
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
|
||||
# Add ffmpeg path to environment variable
|
||||
os.environ["PATH"] = f"{ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
|
||||
try:
|
||||
# Try to run ffmpeg
|
||||
result = subprocess.run(["ffmpeg", "-version"], capture_output=True, text=True)
|
||||
print("FFmpeg test successful!")
|
||||
print("FFmpeg version information:")
|
||||
print(result.stdout)
|
||||
return True
|
||||
except Exception as e:
|
||||
print("FFmpeg test failed!")
|
||||
print(f"Error message: {str(e)}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Default ffmpeg path, can be modified as needed
|
||||
default_path = r"ffmpeg-master-latest-win64-gpl-shared\bin"
|
||||
|
||||
# Use command line argument if provided, otherwise use default path
|
||||
ffmpeg_path = sys.argv[1] if len(sys.argv) > 1 else default_path
|
||||
|
||||
test_ffmpeg(ffmpeg_path)
|
||||
580
models/MuseTalk/train.py
Normal file
580
models/MuseTalk/train.py
Normal file
@@ -0,0 +1,580 @@
|
||||
import argparse
|
||||
import diffusers
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
import warnings
|
||||
import random
|
||||
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import LoggerType
|
||||
from accelerate import InitProcessGroupKwargs
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
from diffusers.utils import check_min_version
|
||||
from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from musetalk.utils.utils import (
|
||||
delete_additional_ckpt,
|
||||
seed_everything,
|
||||
get_mouth_region,
|
||||
process_audio_features,
|
||||
save_models
|
||||
)
|
||||
from musetalk.loss.basic_loss import set_requires_grad
|
||||
from musetalk.loss.syncnet import get_sync_loss
|
||||
from musetalk.utils.training_utils import (
|
||||
initialize_models_and_optimizers,
|
||||
initialize_dataloaders,
|
||||
initialize_loss_functions,
|
||||
initialize_syncnet,
|
||||
initialize_vgg,
|
||||
validation
|
||||
)
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
warnings.filterwarnings("ignore")
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
def main(cfg):
|
||||
exp_name = cfg.exp_name
|
||||
save_dir = f"{cfg.output_dir}/{exp_name}"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
kwargs = DistributedDataParallelKwargs()
|
||||
process_group_kwargs = InitProcessGroupKwargs(
|
||||
timeout=timedelta(seconds=5400))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
|
||||
log_with=["tensorboard", LoggerType.TENSORBOARD],
|
||||
project_dir=os.path.join(save_dir, "./tensorboard"),
|
||||
kwargs_handlers=[kwargs, process_group_kwargs],
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
else:
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
diffusers.utils.logging.set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if cfg.seed is not None:
|
||||
print('cfg.seed', cfg.seed, accelerator.process_index)
|
||||
seed_everything(cfg.seed + accelerator.process_index)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
|
||||
model_dict = initialize_models_and_optimizers(cfg, accelerator, weight_dtype)
|
||||
dataloader_dict = initialize_dataloaders(cfg)
|
||||
loss_dict = initialize_loss_functions(cfg, accelerator, model_dict['scheduler_max_steps'])
|
||||
syncnet = initialize_syncnet(cfg, accelerator, weight_dtype)
|
||||
vgg_IN, pyramid, downsampler = initialize_vgg(cfg, accelerator)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader'] = accelerator.prepare(
|
||||
model_dict['net'], model_dict['optimizer'], model_dict['lr_scheduler'], dataloader_dict['train_dataloader'], dataloader_dict['val_dataloader']
|
||||
)
|
||||
print("length train/val", len(dataloader_dict['train_dataloader']), len(dataloader_dict['val_dataloader']))
|
||||
|
||||
# Calculate training steps and epochs
|
||||
num_update_steps_per_epoch = math.ceil(
|
||||
len(dataloader_dict['train_dataloader']) / cfg.solver.gradient_accumulation_steps
|
||||
)
|
||||
num_train_epochs = math.ceil(
|
||||
cfg.solver.max_train_steps / num_update_steps_per_epoch
|
||||
)
|
||||
|
||||
# Initialize trackers on the main process
|
||||
if accelerator.is_main_process:
|
||||
run_time = datetime.now().strftime("%Y%m%d-%H%M")
|
||||
accelerator.init_trackers(
|
||||
cfg.exp_name,
|
||||
init_kwargs={"mlflow": {"run_name": run_time}},
|
||||
)
|
||||
|
||||
# Calculate total batch size
|
||||
total_batch_size = (
|
||||
cfg.data.train_bs
|
||||
* accelerator.num_processes
|
||||
* cfg.solver.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
# Log training information
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f"Num Epochs = {num_train_epochs}")
|
||||
logger.info(f"Instantaneous batch size per device = {cfg.data.train_bs}")
|
||||
logger.info(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
logger.info(
|
||||
f"Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}")
|
||||
logger.info(f"Total optimization steps = {cfg.solver.max_train_steps}")
|
||||
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Load checkpoint if resuming training
|
||||
if cfg.resume_from_checkpoint:
|
||||
resume_dir = save_dir
|
||||
dirs = os.listdir(resume_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
if len(dirs) > 0:
|
||||
path = dirs[-1]
|
||||
accelerator.load_state(os.path.join(resume_dir, path))
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
global_step = int(path.split("-")[1])
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = global_step % num_update_steps_per_epoch
|
||||
|
||||
# Initialize progress bar
|
||||
progress_bar = tqdm(
|
||||
range(global_step, cfg.solver.max_train_steps),
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
# Log model types
|
||||
print("log type of models")
|
||||
print("unet", model_dict['unet'].dtype)
|
||||
print("vae", model_dict['vae'].dtype)
|
||||
print("wav2vec", model_dict['wav2vec'].dtype)
|
||||
|
||||
def get_ganloss_weight(step):
|
||||
"""Calculate GAN loss weight based on training step"""
|
||||
if step < cfg.discriminator_train_params.start_gan:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0
|
||||
|
||||
# Training loop
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
# Set models to training mode
|
||||
model_dict['unet'].train()
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
loss_dict['discriminator'].train()
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
loss_dict['mouth_discriminator'].train()
|
||||
|
||||
# Initialize loss accumulators
|
||||
train_loss = 0.0
|
||||
train_loss_D = 0.0
|
||||
train_loss_D_mouth = 0.0
|
||||
l1_loss_accum = 0.0
|
||||
vgg_loss_accum = 0.0
|
||||
gan_loss_accum = 0.0
|
||||
gan_loss_accum_mouth = 0.0
|
||||
fm_loss_accum = 0.0
|
||||
sync_loss_accum = 0.0
|
||||
adapted_weight_accum = 0.0
|
||||
|
||||
t_data_start = time.time()
|
||||
for step, batch in enumerate(dataloader_dict['train_dataloader']):
|
||||
t_data = time.time() - t_data_start
|
||||
t_model_start = time.time()
|
||||
|
||||
with torch.no_grad():
|
||||
# Process input data
|
||||
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
||||
accelerator.device,
|
||||
non_blocking=True
|
||||
)
|
||||
bsz, num_frames, c, h, w = pixel_values.shape
|
||||
|
||||
# Process reference images
|
||||
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
||||
accelerator.device,
|
||||
non_blocking=True
|
||||
)
|
||||
|
||||
# Get face mask for GAN
|
||||
pixel_values_face_mask = batch['pixel_values_face_mask']
|
||||
|
||||
# Process audio features
|
||||
audio_prompts = process_audio_features(cfg, batch, model_dict['wav2vec'], bsz, num_frames, weight_dtype)
|
||||
|
||||
# Initialize adapted weight
|
||||
adapted_weight = 1
|
||||
|
||||
# Process sync loss if enabled
|
||||
if cfg.loss_params.sync_loss > 0:
|
||||
mels = batch['mel']
|
||||
# Prepare frames for latentsync (combine channels and frames)
|
||||
gt_frames = rearrange(pixel_values, 'b f c h w-> b (f c) h w')
|
||||
# Use lower half of face for latentsync
|
||||
height = gt_frames.shape[2]
|
||||
gt_frames = gt_frames[:, :, height // 2:, :]
|
||||
|
||||
# Get audio embeddings
|
||||
audio_embed = syncnet.get_audio_embed(mels)
|
||||
|
||||
# Calculate adapted weight based on audio-visual similarity
|
||||
if cfg.use_adapted_weight:
|
||||
vision_embed_gt = syncnet.get_vision_embed(gt_frames)
|
||||
image_audio_sim_gt = F.cosine_similarity(
|
||||
audio_embed,
|
||||
vision_embed_gt,
|
||||
dim=1
|
||||
)[0]
|
||||
|
||||
if image_audio_sim_gt < 0.05 or image_audio_sim_gt > 0.65:
|
||||
if cfg.adapted_weight_type == "cut_off":
|
||||
adapted_weight = 0.0 # Skip this batch
|
||||
print(
|
||||
f"\nThe i-a similarity in step {global_step} is {image_audio_sim_gt}, set adapted_weight to {adapted_weight}.")
|
||||
elif cfg.adapted_weight_type == "linear":
|
||||
adapted_weight = image_audio_sim_gt
|
||||
else:
|
||||
print(f"unknown adapted_weight_type: {cfg.adapted_weight_type}")
|
||||
adapted_weight = 1
|
||||
|
||||
# Random frame selection for memory efficiency
|
||||
max_start = 16 - cfg.num_backward_frames
|
||||
frames_left_index = random.randint(0, max_start) if max_start > 0 else 0
|
||||
frames_right_index = frames_left_index + cfg.num_backward_frames
|
||||
else:
|
||||
frames_left_index = 0
|
||||
frames_right_index = cfg.data.n_sample_frames
|
||||
|
||||
# Extract frames for backward pass
|
||||
pixel_values_backward = pixel_values[:, frames_left_index:frames_right_index, ...]
|
||||
ref_pixel_values_backward = ref_pixel_values[:, frames_left_index:frames_right_index, ...]
|
||||
pixel_values_face_mask_backward = pixel_values_face_mask[:, frames_left_index:frames_right_index, ...]
|
||||
audio_prompts_backward = audio_prompts[:, frames_left_index:frames_right_index, ...]
|
||||
|
||||
# Encode target images
|
||||
frames = rearrange(pixel_values_backward, 'b f c h w-> (b f) c h w')
|
||||
latents = model_dict['vae'].encode(frames).latent_dist.mode()
|
||||
latents = latents * model_dict['vae'].config.scaling_factor
|
||||
latents = latents.float()
|
||||
|
||||
# Create masked images
|
||||
masked_pixel_values = pixel_values_backward.clone()
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
masked_frames = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
masked_latents = model_dict['vae'].encode(masked_frames).latent_dist.mode()
|
||||
masked_latents = masked_latents * model_dict['vae'].config.scaling_factor
|
||||
masked_latents = masked_latents.float()
|
||||
|
||||
# Encode reference images
|
||||
ref_frames = rearrange(ref_pixel_values_backward, 'b f c h w-> (b f) c h w')
|
||||
ref_latents = model_dict['vae'].encode(ref_frames).latent_dist.mode()
|
||||
ref_latents = ref_latents * model_dict['vae'].config.scaling_factor
|
||||
ref_latents = ref_latents.float()
|
||||
|
||||
# Prepare face mask and audio features
|
||||
pixel_values_face_mask_backward = rearrange(
|
||||
pixel_values_face_mask_backward,
|
||||
"b f c h w -> (b f) c h w"
|
||||
)
|
||||
audio_prompts_backward = rearrange(
|
||||
audio_prompts_backward,
|
||||
'b f c h w-> (b f) c h w'
|
||||
)
|
||||
audio_prompts_backward = rearrange(
|
||||
audio_prompts_backward,
|
||||
'(b f) c h w -> (b f) (c h) w',
|
||||
b=bsz
|
||||
)
|
||||
|
||||
# Apply reference dropout (currently inactive)
|
||||
dropout = nn.Dropout(p=cfg.ref_dropout_rate)
|
||||
ref_latents = dropout(ref_latents)
|
||||
|
||||
# Prepare model inputs
|
||||
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
input_latents = input_latents.to(weight_dtype)
|
||||
timesteps = torch.tensor([0], device=input_latents.device)
|
||||
|
||||
# Forward pass
|
||||
latents_pred = model_dict['net'](
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts_backward,
|
||||
)
|
||||
latents_pred = (1 / model_dict['vae'].config.scaling_factor) * latents_pred
|
||||
image_pred = model_dict['vae'].decode(latents_pred).sample
|
||||
|
||||
# Convert to float
|
||||
image_pred = image_pred.float()
|
||||
frames = frames.float()
|
||||
|
||||
# Calculate L1 loss
|
||||
l1_loss = loss_dict['L1_loss'](frames, image_pred)
|
||||
l1_loss_accum += l1_loss.item()
|
||||
loss = cfg.loss_params.l1_loss * l1_loss * adapted_weight
|
||||
|
||||
# Process mouth GAN loss if enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
frames_mouth, image_pred_mouth = get_mouth_region(
|
||||
frames,
|
||||
image_pred,
|
||||
pixel_values_face_mask_backward
|
||||
)
|
||||
pyramide_real_mouth = pyramid(downsampler(frames_mouth))
|
||||
pyramide_generated_mouth = pyramid(downsampler(image_pred_mouth))
|
||||
|
||||
# Process VGG loss if enabled
|
||||
if cfg.loss_params.vgg_loss > 0:
|
||||
pyramide_real = pyramid(downsampler(frames))
|
||||
pyramide_generated = pyramid(downsampler(image_pred))
|
||||
|
||||
loss_IN = 0
|
||||
for scale in cfg.loss_params.pyramid_scale:
|
||||
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
||||
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
loss_IN += weight * value
|
||||
loss_IN /= sum(cfg.loss_params.vgg_layer_weight)
|
||||
loss += loss_IN * cfg.loss_params.vgg_loss * adapted_weight
|
||||
vgg_loss_accum += loss_IN.item()
|
||||
|
||||
# Process GAN loss if enabled
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
set_requires_grad(loss_dict['discriminator'], False)
|
||||
loss_G = 0.
|
||||
discriminator_maps_generated = loss_dict['discriminator'](pyramide_generated)
|
||||
discriminator_maps_real = loss_dict['discriminator'](pyramide_real)
|
||||
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
gan_loss_accum += loss_G.item()
|
||||
|
||||
loss += loss_G * cfg.loss_params.gan_loss * get_ganloss_weight(global_step) * adapted_weight
|
||||
|
||||
# Process feature matching loss if enabled
|
||||
if cfg.loss_params.fm_loss[0] > 0:
|
||||
L_feature_matching = 0.
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'feature_maps_%s' % scale
|
||||
for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
|
||||
value = torch.abs(a - b).mean()
|
||||
L_feature_matching += value * cfg.loss_params.fm_loss[i]
|
||||
loss += L_feature_matching * adapted_weight
|
||||
fm_loss_accum += L_feature_matching.item()
|
||||
|
||||
# Process mouth GAN loss if enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
set_requires_grad(loss_dict['mouth_discriminator'], False)
|
||||
loss_G = 0.
|
||||
mouth_discriminator_maps_generated = loss_dict['mouth_discriminator'](pyramide_generated_mouth)
|
||||
mouth_discriminator_maps_real = loss_dict['mouth_discriminator'](pyramide_real_mouth)
|
||||
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - mouth_discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
gan_loss_accum_mouth += loss_G.item()
|
||||
|
||||
loss += loss_G * cfg.loss_params.mouth_gan_loss * get_ganloss_weight(global_step) * adapted_weight
|
||||
|
||||
# Process feature matching loss for mouth if enabled
|
||||
if cfg.loss_params.fm_loss[0] > 0:
|
||||
L_feature_matching = 0.
|
||||
for scale in loss_dict['disc_scales']:
|
||||
key = 'feature_maps_%s' % scale
|
||||
for i, (a, b) in enumerate(zip(mouth_discriminator_maps_real[key], mouth_discriminator_maps_generated[key])):
|
||||
value = torch.abs(a - b).mean()
|
||||
L_feature_matching += value * cfg.loss_params.fm_loss[i]
|
||||
loss += L_feature_matching * adapted_weight
|
||||
fm_loss_accum += L_feature_matching.item()
|
||||
|
||||
# Process sync loss if enabled
|
||||
if cfg.loss_params.sync_loss > 0:
|
||||
pred_frames = rearrange(
|
||||
image_pred, '(b f) c h w-> b (f c) h w', f=pixel_values_backward.shape[1])
|
||||
pred_frames = pred_frames[:, :, height // 2 :, :]
|
||||
sync_loss, image_audio_sim_pred = get_sync_loss(
|
||||
audio_embed,
|
||||
gt_frames,
|
||||
pred_frames,
|
||||
syncnet,
|
||||
adapted_weight,
|
||||
frames_left_index=frames_left_index,
|
||||
frames_right_index=frames_right_index,
|
||||
)
|
||||
sync_loss_accum += sync_loss.item()
|
||||
loss += sync_loss * cfg.loss_params.sync_loss * adapted_weight
|
||||
|
||||
# Backward pass
|
||||
avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss += avg_loss.item()
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Train discriminator if GAN loss is enabled
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
set_requires_grad(loss_dict['discriminator'], True)
|
||||
loss_D = loss_dict['discriminator_full'](frames, image_pred.detach())
|
||||
avg_loss_D = accelerator.gather(loss_D.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss_D += avg_loss_D.item() / 1
|
||||
loss_D = loss_D * get_ganloss_weight(global_step) * adapted_weight
|
||||
accelerator.backward(loss_D)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
loss_dict['discriminator'].parameters(), cfg.solver.max_grad_norm)
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
loss_dict['optimizer_D'].step()
|
||||
loss_dict['scheduler_D'].step()
|
||||
loss_dict['optimizer_D'].zero_grad()
|
||||
|
||||
# Train mouth discriminator if mouth GAN loss is enabled
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
set_requires_grad(loss_dict['mouth_discriminator'], True)
|
||||
mouth_loss_D = loss_dict['mouth_discriminator_full'](
|
||||
frames_mouth, image_pred_mouth.detach())
|
||||
avg_mouth_loss_D = accelerator.gather(
|
||||
mouth_loss_D.repeat(cfg.data.train_bs)).mean()
|
||||
train_loss_D_mouth += avg_mouth_loss_D.item() / 1
|
||||
mouth_loss_D = mouth_loss_D * get_ganloss_weight(global_step) * adapted_weight
|
||||
accelerator.backward(mouth_loss_D)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
loss_dict['mouth_discriminator'].parameters(), cfg.solver.max_grad_norm)
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
loss_dict['mouth_optimizer_D'].step()
|
||||
loss_dict['mouth_scheduler_D'].step()
|
||||
loss_dict['mouth_optimizer_D'].zero_grad()
|
||||
|
||||
# Update main model
|
||||
if (global_step + 1) % cfg.solver.gradient_accumulation_steps == 0:
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(
|
||||
model_dict['trainable_params'],
|
||||
cfg.solver.max_grad_norm,
|
||||
)
|
||||
model_dict['optimizer'].step()
|
||||
model_dict['lr_scheduler'].step()
|
||||
model_dict['optimizer'].zero_grad()
|
||||
|
||||
# Update progress and log metrics
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({
|
||||
"train_loss": train_loss,
|
||||
"train_loss_D": train_loss_D,
|
||||
"train_loss_D_mouth": train_loss_D_mouth,
|
||||
"l1_loss": l1_loss_accum,
|
||||
"vgg_loss": vgg_loss_accum,
|
||||
"gan_loss": gan_loss_accum,
|
||||
"fm_loss": fm_loss_accum,
|
||||
"sync_loss": sync_loss_accum,
|
||||
"adapted_weight": adapted_weight_accum,
|
||||
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
|
||||
}, step=global_step)
|
||||
|
||||
# Reset loss accumulators
|
||||
train_loss = 0.0
|
||||
l1_loss_accum = 0.0
|
||||
vgg_loss_accum = 0.0
|
||||
gan_loss_accum = 0.0
|
||||
fm_loss_accum = 0.0
|
||||
sync_loss_accum = 0.0
|
||||
adapted_weight_accum = 0.0
|
||||
train_loss_D = 0.0
|
||||
train_loss_D_mouth = 0.0
|
||||
|
||||
# Run validation if needed
|
||||
if global_step % cfg.val_freq == 0 or global_step == 10:
|
||||
try:
|
||||
validation(
|
||||
cfg,
|
||||
dataloader_dict['val_dataloader'],
|
||||
model_dict['net'],
|
||||
model_dict['vae'],
|
||||
model_dict['wav2vec'],
|
||||
accelerator,
|
||||
save_dir,
|
||||
global_step,
|
||||
weight_dtype,
|
||||
syncnet_score=adapted_weight,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"An error occurred during validation: {e}")
|
||||
|
||||
# Save checkpoint if needed
|
||||
if global_step % cfg.checkpointing_steps == 0:
|
||||
save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
|
||||
try:
|
||||
start_time = time.time()
|
||||
if accelerator.is_main_process:
|
||||
save_models(
|
||||
accelerator,
|
||||
model_dict['net'],
|
||||
save_dir,
|
||||
global_step,
|
||||
cfg,
|
||||
logger=logger
|
||||
)
|
||||
delete_additional_ckpt(save_dir, cfg.total_limit)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 300:
|
||||
print(f"Skipping storage as it took too long in step {global_step}.")
|
||||
else:
|
||||
print(f"Resume states saved at {save_dir} successfully in {elapsed_time}s.")
|
||||
except Exception as e:
|
||||
print(f"Error when saving model in step {global_step}:", e)
|
||||
|
||||
# Update progress bar
|
||||
t_model = time.time() - t_model_start
|
||||
logs = {
|
||||
"step_loss": loss.detach().item(),
|
||||
"lr": model_dict['lr_scheduler'].get_last_lr()[0],
|
||||
"td": f"{t_data:.2f}s",
|
||||
"tm": f"{t_model:.2f}s",
|
||||
}
|
||||
t_data_start = time.time()
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= cfg.solver.max_train_steps:
|
||||
break
|
||||
|
||||
# Save model after each epoch
|
||||
if (epoch + 1) % cfg.save_model_epoch_interval == 0:
|
||||
try:
|
||||
start_time = time.time()
|
||||
if accelerator.is_main_process:
|
||||
save_models(accelerator, model_dict['net'], save_dir, global_step, cfg)
|
||||
accelerator.save_state(save_path)
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 120:
|
||||
print(f"Skipping storage as it took too long in step {global_step}.")
|
||||
else:
|
||||
print(f"Model saved successfully in {elapsed_time}s.")
|
||||
except Exception as e:
|
||||
print(f"Error when saving model in step {global_step}:", e)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# End training
|
||||
accelerator.end_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
|
||||
args = parser.parse_args()
|
||||
config = OmegaConf.load(args.config)
|
||||
main(config)
|
||||
34
models/MuseTalk/train.sh
Normal file
34
models/MuseTalk/train.sh
Normal file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
# MuseTalk Training Script
|
||||
# This script combines both training stages for the MuseTalk model
|
||||
# Usage: sh train.sh [stage1|stage2]
|
||||
# Example: sh train.sh stage1 # To run stage 1 training
|
||||
# Example: sh train.sh stage2 # To run stage 2 training
|
||||
|
||||
# Check if stage argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Error: Please specify the training stage"
|
||||
echo "Usage: ./train.sh [stage1|stage2]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
STAGE=$1
|
||||
|
||||
# Validate stage argument
|
||||
if [ "$STAGE" != "stage1" ] && [ "$STAGE" != "stage2" ]; then
|
||||
echo "Error: Invalid stage. Must be either 'stage1' or 'stage2'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Launch distributed training using accelerate
|
||||
# --config_file: Path to the GPU configuration file
|
||||
# --main_process_port: Port number for the main process, used for distributed training communication
|
||||
# train.py: Training script
|
||||
# --config: Path to the training configuration file
|
||||
echo "Starting $STAGE training..."
|
||||
accelerate launch --config_file ./configs/training/gpu.yaml \
|
||||
--main_process_port 29502 \
|
||||
train.py --config ./configs/training/$STAGE.yaml
|
||||
|
||||
echo "Training completed for $STAGE"
|
||||
Reference in New Issue
Block a user