Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2543a270c1 | ||
|
|
cbf840f472 | ||
|
|
1890cea3ee |
@@ -87,6 +87,19 @@ playwright install chromium
|
||||
|
||||
---
|
||||
|
||||
## 步骤 5: 启动 LatentSync 常驻加速服务 (可选)
|
||||
|
||||
为了消除每次生成视频时的 30-40秒 模型加载时间,建议启动常驻服务:
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
|
||||
# 后台启动服务 (自动读取 backend/.env 中的 GPU 配置)
|
||||
nohup python -m scripts.server > server.log 2>&1 &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 7: 配置环境变量
|
||||
|
||||
```bash
|
||||
@@ -102,6 +115,7 @@ cp .env.example .env
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| `LATENTSYNC_GPU_ID` | 1 | GPU 选择 (0 或 1) |
|
||||
| `LATENTSYNC_USE_SERVER` | false | 设为 true 以启用常驻服务加速 |
|
||||
| `LATENTSYNC_INFERENCE_STEPS` | 20 | 推理步数 (20-50) |
|
||||
| `LATENTSYNC_GUIDANCE_SCALE` | 1.5 | 引导系数 (1.0-3.0) |
|
||||
| `DEBUG` | true | 生产环境改为 false |
|
||||
|
||||
@@ -208,3 +208,36 @@ CUDA_VISIBLE_DEVICES=1 python -m scripts.inference \
|
||||
- [LatentSync GitHub](https://github.com/bytedance/LatentSync)
|
||||
- [HuggingFace 模型](https://huggingface.co/ByteDance/LatentSync-1.6)
|
||||
- [论文](https://arxiv.org/abs/2412.09262)
|
||||
|
||||
---
|
||||
|
||||
## 🐛 修复:视频分辨率降低问题 (17:30)
|
||||
|
||||
**问题**:generated video is not resolution of original video (原视频预压缩导致输出为 720p)
|
||||
**原因**:之前的性能优化中强制将视频压缩至 720p 以提高推理速度,导致 1080p 视频输出被降采样。
|
||||
**修复**:在 `lipsync_service.py` 中禁用了 `_preprocess_video` 调用,直接使用原始视频进行推理。此时 `LatentSync` 将输出与输入视频一致的分辨率。
|
||||
**结果**:
|
||||
- ✅ 输出视频将保持原始分辨率 (1080p)。
|
||||
- ⚠️ 推理时间将相应增加 (约需多花费 20-30% 时间)。
|
||||
|
||||
---
|
||||
|
||||
## ⚡ 性能优化补全 (18:00)
|
||||
|
||||
### 1. 常驻模型服务 (Persistent Server)
|
||||
**目标**: 消除每次生成视频时 30-40秒 的模型加载时间。
|
||||
**实现**:
|
||||
- 新增 `models/LatentSync/scripts/server.py` (FastAPI 服务)
|
||||
- 自动加载后端 `.env` 配置
|
||||
- 服务常驻显存,支持热调用
|
||||
**效果**:
|
||||
- 首次请求:正常加载 (~40s)
|
||||
- 后续请求:**0s 加载**,直接推理
|
||||
|
||||
### 2. GPU 并发控制 (队列)
|
||||
**目标**: 防止多用户同时请求导致 OOM (显存溢出)。
|
||||
**实现**:
|
||||
- 在 `lipsync_service.py` 引入 `asyncio.Lock`
|
||||
- 建立全局串行队列,无论远程还是本地调用,强制排队
|
||||
**效果**:
|
||||
- 即使前端触发多次生成,后端也会逐个处理,保证系统稳定性。
|
||||
|
||||
46
Docs/Logs.md
46
Docs/Logs.md
@@ -1,46 +0,0 @@
|
||||
(venv) rongye@r730-ubuntu:~/ProgramFiles/ViGent2/backend$ uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
INFO: Started server process [2398255]
|
||||
INFO: Waiting for application startup.
|
||||
INFO: Application startup complete.
|
||||
INFO: Uvicorn running on http://0.0.0.0:8006 (Press CTRL+C to quit)
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899244071 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899248452 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250145 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250420 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250774 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899251257 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "OPTIONS /api/videos/generate HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "POST /api/videos/generate HTTP/1.1" 200 OK
|
||||
2026-01-20 16:54:13.143 | INFO | app.services.tts_service:generate_audio:20 - TTS Generating: 大家好,欢迎来到我的频道,今天给大家分享... (zh-CN-YunxiNeural)
|
||||
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
|
||||
[Pipeline] TTS completed in 1.4s
|
||||
2026-01-20 16:54:14.547 | INFO | app.services.lipsync_service:_check_weights:56 - ✅ LatentSync 权重文件已就绪
|
||||
[LipSync] Health check: ready=True
|
||||
[LipSync] Starting LatentSync inference...
|
||||
2026-01-20 16:54:16.799 | INFO | app.services.lipsync_service:generate:172 - 🎬 唇形同步任务: 0bc1aa95-c567-4022-8d8b-cd3e439c78c0.mov + 33c43a79-6e25-471f-873d-54d651d13474_audio.mp3
|
||||
2026-01-20 16:54:16.799 | INFO | app.services.lipsync_service:_local_generate:200 - 🔄 调用 LatentSync 推理 (subprocess)...
|
||||
2026-01-20 16:54:17.004 | INFO | app.services.lipsync_service:_preprocess_video:111 - 📹 原始视频分辨率: 1920×1080
|
||||
2026-01-20 16:54:17.005 | INFO | app.services.lipsync_service:_preprocess_video:128 - 📹 预处理视频: 1080p → 720p
|
||||
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_preprocess_video:152 - ✅ 视频压缩完成: 14.9MB → 1.1MB
|
||||
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_local_generate:237 - 🖥️ 执行命令: /home/rongye/ProgramFiles/miniconda3/envs/latentsync/bin/python -m scripts.inference --unet_config_path configs/unet/stage2_512.yaml --inference_ckpt_path checkpoints/latentsync_unet.pt --inference_steps...
|
||||
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_local_generate:238 - 🖥️ GPU: CUDA_VISIBLE_DEVICES=1
|
||||
2026-01-20 16:57:52.285 | INFO | app.services.lipsync_service:_local_generate:257 - LatentSync 输出:
|
||||
: '0', 'arena_extend_strategy': 'kNextPowerOfTwo', 'use_ep_level_unified_stream': '0', 'device_id': '0', 'gpu_external_alloc': '0', 'sdpa_kernel': '0', 'cudnn_conv_algo_search': 'EXHAUSTIVE', 'gpu_external_free': '0', 'use_tf32': '1', 'cudnn_conv1d_pad_to_nc1d': '0', 'do_copy_in_default_stream': '1'}}
|
||||
model ignore: checkpoints/auxiliary/models/buffalo_l/w600k_r50.onnx recognition
|
||||
set det-size: (512, 512)
|
||||
video in 25 FPS, audio idx in 50FPS
|
||||
Affine transforming 135 faces...
|
||||
Restoring 135 faces...
|
||||
|
||||
2026-01-20 16:57:52.287 | INFO | app.services.lipsync_service:_local_generate:262 - ✅ 唇形同步完成: /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_lipsync.mp4
|
||||
[Pipeline] LipSync completed in 217.7s
|
||||
2026-01-20 16:57:52.616 | DEBUG | app.services.video_service:_run_ffmpeg:17 - FFmpeg CMD: ffmpeg -y -i /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_lipsync.mp4 -i /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_audio.mp3 -c:v libx264 -c:a aac -shortest -map 0:v -map 1:a /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4
|
||||
[Pipeline] Total generation time: 220.4s
|
||||
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
|
||||
INFO: 192.168.110.188:10104 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
|
||||
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
|
||||
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 304 Not Modified
|
||||
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
|
||||
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
|
||||
INFO: 192.168.110.188:10233 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 304 Not Modified
|
||||
544
Docs/MuseTalk.md
544
Docs/MuseTalk.md
@@ -1,544 +0,0 @@
|
||||
# MuseTalk
|
||||
|
||||
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
|
||||
|
||||
Yue Zhang<sup>\*</sup>,
|
||||
Zhizhou Zhong<sup>\*</sup>,
|
||||
Minhao Liu<sup>\*</sup>,
|
||||
Zhaokang Chen,
|
||||
Bin Wu<sup>†</sup>,
|
||||
Yubin Zeng,
|
||||
Chao Zhan,
|
||||
Junxin Huang,
|
||||
Yingjie He,
|
||||
Wenjiang Zhou
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
|
||||
|
||||
Lyra Lab, Tencent Music Entertainment
|
||||
|
||||
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
|
||||
|
||||
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
|
||||
|
||||
## 🔥 Updates
|
||||
We're excited to unveil MuseTalk 1.5.
|
||||
This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy.
|
||||
Learn more details [here](https://arxiv.org/abs/2410.10122).
|
||||
**The inference codes, training codes and model weights of MuseTalk 1.5 are all available now!** 🚀
|
||||
|
||||
# Overview
|
||||
`MuseTalk` is a real-time high quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
|
||||
|
||||
1. modifies an unseen face according to the input audio, with a size of face region of `256 x 256`.
|
||||
1. supports audio in various languages, such as Chinese, English, and Japanese.
|
||||
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
|
||||
1. supports modification of the center point of the face region proposes, which **SIGNIFICANTLY** affects generation results.
|
||||
1. checkpoint available trained on the HDTF and private dataset.
|
||||
|
||||
# News
|
||||
- [04/05/2025] :mega: We are excited to announce that the training code is now open-sourced! You can now train your own MuseTalk model using our provided training scripts and configurations.
|
||||
- [03/28/2025] We are thrilled to announce the release of our 1.5 version. This version is a significant improvement over the 1.0 version, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We update the [technical report](https://arxiv.org/abs/2410.10122) with more details.
|
||||
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
|
||||
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/02/2024] Release MuseTalk project and pretrained models.
|
||||
|
||||
|
||||
## Model
|
||||

|
||||
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
|
||||
|
||||
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
|
||||
|
||||
## Cases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33%">
|
||||
|
||||
### Input Video
|
||||
---
|
||||
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.0
|
||||
---
|
||||
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.5
|
||||
---
|
||||
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
# TODO:
|
||||
- [x] trained models and inference codes.
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
|
||||
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
|
||||
- [x] realtime inference code for 1.5 version.
|
||||
- [x] training and data preprocessing codes.
|
||||
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
|
||||
|
||||
|
||||
# Getting Started
|
||||
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
|
||||
|
||||
## Third party integration
|
||||
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
||||
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
||||
|
||||
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
|
||||
|
||||
## Installation
|
||||
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
|
||||
|
||||
### Build environment
|
||||
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
|
||||
|
||||
```shell
|
||||
conda create -n MuseTalk python==3.10
|
||||
conda activate MuseTalk
|
||||
```
|
||||
|
||||
### Install PyTorch 2.0.1
|
||||
Choose one of the following installation methods:
|
||||
|
||||
```shell
|
||||
# Option 1: Using pip
|
||||
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
# Option 2: Using conda
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
```
|
||||
|
||||
### Install Dependencies
|
||||
Install the remaining required packages:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Install MMLab Packages
|
||||
Install the MMLab ecosystem packages:
|
||||
|
||||
```bash
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv==2.0.1"
|
||||
mim install "mmdet==3.1.0"
|
||||
mim install "mmpose==1.1.0"
|
||||
```
|
||||
|
||||
### Setup FFmpeg
|
||||
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
|
||||
|
||||
2. Configure FFmpeg based on your operating system:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
export FFMPEG_PATH=/path/to/ffmpeg
|
||||
# Example:
|
||||
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
|
||||
```
|
||||
|
||||
For Windows:
|
||||
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
|
||||
|
||||
### Download weights
|
||||
You can download weights in two ways:
|
||||
|
||||
#### Option 1: Using Download Scripts
|
||||
We provide two scripts for automatic downloading:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
sh ./download_weights.sh
|
||||
```
|
||||
|
||||
For Windows:
|
||||
```batch
|
||||
# Run the script
|
||||
download_weights.bat
|
||||
```
|
||||
|
||||
#### Option 2: Manual Download
|
||||
You can also download the weights manually from the following links:
|
||||
|
||||
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
|
||||
2. Download the weights of other components:
|
||||
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
|
||||
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
|
||||
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
|
||||
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
|
||||
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
|
||||
|
||||
Finally, these weights should be organized in `models` as follows:
|
||||
```
|
||||
./models/
|
||||
├── musetalk
|
||||
│ └── musetalk.json
|
||||
│ └── pytorch_model.bin
|
||||
├── musetalkV15
|
||||
│ └── musetalk.json
|
||||
│ └── unet.pth
|
||||
├── syncnet
|
||||
│ └── latentsync_syncnet.pt
|
||||
├── dwpose
|
||||
│ └── dw-ll_ucoco_384.pth
|
||||
├── face-parse-bisent
|
||||
│ ├── 79999_iter.pth
|
||||
│ └── resnet18-5c106cde.pth
|
||||
├── sd-vae
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── whisper
|
||||
├── config.json
|
||||
├── pytorch_model.bin
|
||||
└── preprocessor_config.json
|
||||
|
||||
```
|
||||
## Quickstart
|
||||
|
||||
### Inference
|
||||
We provide inference scripts for both versions of MuseTalk:
|
||||
|
||||
#### Prerequisites
|
||||
Before running inference, please ensure ffmpeg is installed and accessible:
|
||||
```bash
|
||||
# Check ffmpeg installation
|
||||
ffmpeg -version
|
||||
```
|
||||
If ffmpeg is not found, please install it first:
|
||||
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
|
||||
- Linux: `sudo apt-get install ffmpeg`
|
||||
|
||||
#### Normal Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 normal
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 normal
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
|
||||
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
|
||||
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
#### Real-time Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 realtime
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 realtime
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
|
||||
- `video_path`: Path to the input video, image file, or directory of images
|
||||
- `audio_path`: Path to the input audio file
|
||||
|
||||
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
|
||||
|
||||
Important notes for real-time inference:
|
||||
1. Set `preparation` to `True` when processing a new avatar
|
||||
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
|
||||
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
|
||||
4. Set `preparation` to `False` for generating more videos with the same avatar
|
||||
|
||||
For faster generation without saving images, you can use:
|
||||
```bash
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
## Gradio Demo
|
||||
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
|
||||

|
||||
For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. 
|
||||
|
||||
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
|
||||
|
||||
```bash
|
||||
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
|
||||
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Data Preparation
|
||||
To train MuseTalk, you need to prepare your dataset following these steps:
|
||||
|
||||
1. **Place your source videos**
|
||||
|
||||
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
|
||||
|
||||
2. **Run the preprocessing script**
|
||||
```bash
|
||||
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
|
||||
```
|
||||
This script will:
|
||||
- Extract frames from videos
|
||||
- Detect and align faces
|
||||
- Generate audio features
|
||||
- Create the necessary data structure for training
|
||||
|
||||
### Training Process
|
||||
After data preprocessing, you can start the training process:
|
||||
|
||||
1. **First Stage**
|
||||
```bash
|
||||
sh train.sh stage1
|
||||
```
|
||||
|
||||
2. **Second Stage**
|
||||
```bash
|
||||
sh train.sh stage2
|
||||
```
|
||||
|
||||
### Configuration Adjustment
|
||||
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
|
||||
|
||||
1. **GPU Configuration** (`configs/training/gpu.yaml`):
|
||||
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
|
||||
- `num_processes`: Set this to match the number of GPUs you're using
|
||||
|
||||
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
|
||||
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
|
||||
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
|
||||
|
||||
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
|
||||
- `random_init_unet`: Must be set to `False` to use the model from stage 1
|
||||
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
|
||||
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
|
||||
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
|
||||
|
||||
|
||||
### GPU Memory Requirements
|
||||
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
|
||||
|
||||
#### Stage 1 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 8 | 1 | ~32GB | |
|
||||
| 16 | 1 | ~45GB | |
|
||||
| 32 | 1 | ~74GB | ✓ |
|
||||
|
||||
#### Stage 2 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 1 | 8 | ~54GB | |
|
||||
| 2 | 2 | ~80GB | |
|
||||
| 2 | 8 | ~85GB | ✓ |
|
||||
|
||||
<details close>
|
||||
## TestCases For 1.0
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="33%">Image</td>
|
||||
<td width="33%">MuseV</td>
|
||||
<td width="33%">+MuseTalk</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/musk/musk.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4a4bb2d1-9d14-4ca9-85c8-7f19c39f712e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/b2a879c2-e23a-4d39-911d-51f0343218e4 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/yongen/yongen.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/57ef9dee-a9fd-4dc8-839b-3fbbbf0ff3f4 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/94d8dcba-1bcd-4b54-9d1d-8b6fc53228f0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sit/sit.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/5fbab81b-d3f2-4c75-abb5-14c76e51769e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/f8100f4a-3df8-4151-8de2-291b09269f66 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/man/man.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a6e7d431-5643-4745-9868-8b423a454153 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/6ccf7bc7-cb48-42de-85bd-076d5ee8a623 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/monalisa/monalisa.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1568f604-a34f-4526-a13a-7d282aa2e773 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a40784fc-a885-4c1f-9b7e-8f87b7caf4e0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun1/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/172f4ff1-d432-45bd-a5a7-a07dec33a26b controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun2/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/85a6873d-a028-4cce-af2b-6c59a1f2971d controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
</table >
|
||||
|
||||
#### Use of bbox_shift to have adjustable results(For 1.0)
|
||||
:mag_right: We have found that upper-bound of the mask has an important impact on mouth openness. Thus, to control the mask region, we suggest using the `bbox_shift` parameter. Positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease mouth openness.
|
||||
|
||||
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script within this range.
|
||||
|
||||
For example, in the case of `Xinying Sun`, after running the default configuration, it shows that the adjustable value rage is [-9, 9]. Then, to decrease the mouth openness, we set the value to be `-7`.
|
||||
```
|
||||
python -m scripts.inference --inference_config configs/inference/test.yaml --bbox_shift -7
|
||||
```
|
||||
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
|
||||
|
||||
|
||||
#### Combining MuseV and MuseTalk
|
||||
|
||||
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
|
||||
|
||||
# Acknowledgement
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
|
||||
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
|
||||
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
|
||||
|
||||
Thanks for open-sourcing!
|
||||
|
||||
# Limitations
|
||||
- Resolution: Though MuseTalk uses a face region size of 256 x 256, which make it better than other open-source methods, it has not yet reached the theoretical resolution bound. We will continue to deal with this problem.
|
||||
If you need higher resolution, you could apply super resolution models such as [GFPGAN](https://github.com/TencentARC/GFPGAN) in combination with MuseTalk.
|
||||
|
||||
- Identity preservation: Some details of the original face are not well preserved, such as mustache, lip shape and color.
|
||||
|
||||
- Jitter: There exists some jitter as the current pipeline adopts single-frame generation.
|
||||
|
||||
# Citation
|
||||
```bib
|
||||
@article{musetalk,
|
||||
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
|
||||
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
|
||||
journal={arxiv},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
# Disclaimer/License
|
||||
1. `code`: The code of MuseTalk is released under the MIT License. There is no limitation for both academic and commercial usage.
|
||||
1. `model`: The trained model are available for any purpose, even commercially.
|
||||
1. `other opensource model`: Other open-source models used must comply with their license, such as `whisper`, `ft-mse-vae`, `dwpose`, `S3FD`, etc..
|
||||
1. The testdata are collected from internet, which are available for non-commercial research purposes only.
|
||||
1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
|
||||
@@ -225,6 +225,52 @@ cp -r SuperIPAgent/social-auto-upload backend/social_upload
|
||||
|
||||
---
|
||||
|
||||
### 阶段六:MuseTalk 服务器部署 (Day 2-3) ✅
|
||||
|
||||
> **目标**:在双显卡服务器上部署 MuseTalk 环境
|
||||
|
||||
- [x] Conda 环境配置 (musetalk)
|
||||
- [x] 模型权重下载 (~7GB)
|
||||
- [x] Subprocess 调用方式实现
|
||||
- [x] 健康检查功能
|
||||
|
||||
### 阶段七:MuseTalk 完整修复 (Day 4) ✅
|
||||
|
||||
> **目标**:解决推理脚本的各种兼容性问题
|
||||
|
||||
- [x] 权重检测路径修复 (软链接)
|
||||
- [x] 音视频长度不匹配修复
|
||||
- [x] 推理脚本错误日志增强
|
||||
- [x] 视频合成 MP4 生成验证
|
||||
|
||||
### 阶段八:前端功能增强 (Day 5) ✅
|
||||
|
||||
> **目标**:提升用户体验
|
||||
|
||||
- [x] Web 视频上传功能
|
||||
- [x] 上传进度显示
|
||||
- [x] 自动刷新素材列表
|
||||
|
||||
### 阶段九:唇形同步模型升级 (Day 6) ✅
|
||||
|
||||
> **目标**:从 MuseTalk 迁移到 LatentSync 1.6
|
||||
|
||||
- [x] MuseTalk → LatentSync 1.6 迁移
|
||||
- [x] 后端代码适配 (config.py, lipsync_service.py)
|
||||
- [x] Latent Diffusion 架构 (512x512 高清)
|
||||
- [x] 服务器端到端验证
|
||||
|
||||
### 阶段十:性能优化 (Day 6) ✅
|
||||
|
||||
> **目标**:提升系统响应速度和稳定性
|
||||
|
||||
- [x] 视频预压缩优化 (1080p → 720p 自动适配)
|
||||
- [x] 进度更新细化 (实时反馈)
|
||||
- [x] **常驻模型服务** (Persistent Server, 0s 加载)
|
||||
- [x] **GPU 并发控制** (串行队列防崩溃)
|
||||
|
||||
---
|
||||
|
||||
## 项目目录结构 (最终)
|
||||
|
||||
```
|
||||
|
||||
@@ -86,8 +86,8 @@
|
||||
- [x] LipSync 服务单例缓存
|
||||
- [x] 健康检查缓存 (5分钟)
|
||||
- [x] 异步子进程修复 (subprocess.run → asyncio)
|
||||
- [ ] 预加载模型服务 (可选)
|
||||
- [ ] 批量队列处理 (可选)
|
||||
- [x] 预加载模型服务 (常驻 Server + FastAPI)
|
||||
- [x] 批量队列处理 (GPU 并发控制)
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
- 🎙️ **TTS 配音** - EdgeTTS 多音色支持(云溪、晓晓等)
|
||||
- 📱 **一键发布** - Playwright 自动发布到抖音、小红书、B站等
|
||||
- 🖥️ **Web UI** - Next.js 现代化界面
|
||||
- 🚀 **性能优化** - 视频预压缩、健康检查缓存
|
||||
- 🚀 **性能优化** - 视频预压缩、常驻模型服务 (0s加载)
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
@@ -102,6 +102,10 @@ uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
# 终端 2: 前端 (端口 3002)
|
||||
cd frontend
|
||||
npm run dev -- -p 3002
|
||||
|
||||
# 终端 3: LatentSync 服务 (端口 8007, 推荐启动)
|
||||
cd models/LatentSync
|
||||
nohup python -m scripts.server > server.log 2>&1 &
|
||||
```
|
||||
|
||||
---
|
||||
@@ -130,6 +134,7 @@ npm run dev -- -p 3002
|
||||
| 视频生成 | http://服务器IP:3002 |
|
||||
| 发布管理 | http://服务器IP:3002/publish |
|
||||
| API 文档 | http://服务器IP:8006/docs |
|
||||
| 模型API | http://服务器IP:8007/docs |
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -15,11 +15,15 @@ DEFAULT_TTS_VOICE=zh-CN-YunxiNeural
|
||||
# GPU 选择 (0=第一块GPU, 1=第二块GPU)
|
||||
LATENTSYNC_GPU_ID=1
|
||||
|
||||
# 使用本地模式 (true) 或远程 API (false)
|
||||
# 使用本地模式 (true) 或远程 API (false)
|
||||
LATENTSYNC_LOCAL=true
|
||||
|
||||
# 远程 API 地址 (仅 LATENTSYNC_LOCAL=false 时使用)
|
||||
# LATENTSYNC_API_URL=http://localhost:8001
|
||||
# 使用常驻服务 (Persistent Server) 加速
|
||||
LATENTSYNC_USE_SERVER=false
|
||||
|
||||
# 远程 API 地址 (常驻服务默认端口 8007)
|
||||
# LATENTSYNC_API_URL=http://localhost:8007
|
||||
|
||||
# 推理步数 (20-50, 越高质量越好,速度越慢)
|
||||
LATENTSYNC_INFERENCE_STEPS=20
|
||||
|
||||
@@ -18,11 +18,13 @@ class Settings(BaseSettings):
|
||||
# LatentSync 配置
|
||||
LATENTSYNC_GPU_ID: int = 1 # GPU ID (默认使用 GPU1)
|
||||
LATENTSYNC_LOCAL: bool = True # 使用本地推理 (False 则使用远程 API)
|
||||
LATENTSYNC_API_URL: str = "http://localhost:8001" # 远程 API 地址
|
||||
LATENTSYNC_API_URL: str = "http://localhost:8007" # 远程 API 地址
|
||||
LATENTSYNC_INFERENCE_STEPS: int = 20 # 推理步数 [20-50]
|
||||
LATENTSYNC_GUIDANCE_SCALE: float = 1.5 # 引导系数 [1.0-3.0]
|
||||
LATENTSYNC_ENABLE_DEEPCACHE: bool = True # 启用 DeepCache 加速
|
||||
LATENTSYNC_ENABLE_DEEPCACHE: bool = True # 启用 DeepCache 加速
|
||||
LATENTSYNC_SEED: int = 1247 # 随机种子 (-1 则随机)
|
||||
LATENTSYNC_USE_SERVER: bool = False # 使用常驻服务 (Persistent Server) 加速
|
||||
|
||||
@property
|
||||
def LATENTSYNC_DIR(self) -> Path:
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import asyncio
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
@@ -23,6 +24,10 @@ class LipSyncService:
|
||||
self.api_url = settings.LATENTSYNC_API_URL
|
||||
self.latentsync_dir = settings.LATENTSYNC_DIR
|
||||
self.gpu_id = settings.LATENTSYNC_GPU_ID
|
||||
self.use_server = settings.LATENTSYNC_USE_SERVER
|
||||
|
||||
# GPU 并发锁 (Serial Queue)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Conda 环境 Python 路径
|
||||
# 根据服务器实际情况调整
|
||||
@@ -197,98 +202,163 @@ class LipSyncService:
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
logger.info("🔄 调用 LatentSync 推理 (subprocess)...")
|
||||
|
||||
# 使用临时目录存放输出
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
temp_output = tmpdir / "output.mp4"
|
||||
|
||||
# 视频预处理:压缩高分辨率视频以加速处理
|
||||
preprocessed_video = tmpdir / "preprocessed_input.mp4"
|
||||
actual_video_path = self._preprocess_video(
|
||||
video_path,
|
||||
str(preprocessed_video),
|
||||
target_height=720
|
||||
)
|
||||
|
||||
# 构建命令
|
||||
cmd = [
|
||||
str(self.conda_python),
|
||||
"-m", "scripts.inference",
|
||||
"--unet_config_path", "configs/unet/stage2_512.yaml",
|
||||
"--inference_ckpt_path", "checkpoints/latentsync_unet.pt",
|
||||
"--inference_steps", str(settings.LATENTSYNC_INFERENCE_STEPS),
|
||||
"--guidance_scale", str(settings.LATENTSYNC_GUIDANCE_SCALE),
|
||||
"--video_path", str(actual_video_path), # 使用预处理后的视频
|
||||
"--audio_path", str(audio_path),
|
||||
"--video_out_path", str(temp_output),
|
||||
"--seed", str(settings.LATENTSYNC_SEED),
|
||||
"--temp_dir", str(tmpdir / "cache"),
|
||||
]
|
||||
|
||||
if settings.LATENTSYNC_ENABLE_DEEPCACHE:
|
||||
cmd.append("--enable_deepcache")
|
||||
|
||||
# 设置环境变量
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
|
||||
|
||||
logger.info(f"🖥️ 执行命令: {' '.join(cmd[:8])}...")
|
||||
logger.info(f"🖥️ GPU: CUDA_VISIBLE_DEVICES={self.gpu_id}")
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
logger.info("⏳ 等待 GPU 资源 (排队中)...")
|
||||
async with self._lock:
|
||||
if self.use_server:
|
||||
# 模式 A: 调用常驻服务 (加速模式)
|
||||
return await self._call_persistent_server(video_path, audio_path, output_path)
|
||||
|
||||
# 使用 asyncio subprocess 实现真正的异步执行
|
||||
# 这样事件循环可以继续处理其他请求(如进度查询)
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
cwd=str(self.latentsync_dir),
|
||||
env=env,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
logger.info("🔄 调用 LatentSync 推理 (subprocess)...")
|
||||
|
||||
# 使用临时目录存放输出
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
temp_output = tmpdir / "output.mp4"
|
||||
|
||||
# 视频预处理:压缩高分辨率视频以加速处理
|
||||
# preprocessed_video = tmpdir / "preprocessed_input.mp4"
|
||||
# actual_video_path = self._preprocess_video(
|
||||
# video_path,
|
||||
# str(preprocessed_video),
|
||||
# target_height=720
|
||||
# )
|
||||
# 暂时禁用预处理以保持原始分辨率
|
||||
actual_video_path = video_path
|
||||
|
||||
# 构建命令
|
||||
cmd = [
|
||||
str(self.conda_python),
|
||||
"-m", "scripts.inference",
|
||||
"--unet_config_path", "configs/unet/stage2_512.yaml",
|
||||
"--inference_ckpt_path", "checkpoints/latentsync_unet.pt",
|
||||
"--inference_steps", str(settings.LATENTSYNC_INFERENCE_STEPS),
|
||||
"--guidance_scale", str(settings.LATENTSYNC_GUIDANCE_SCALE),
|
||||
"--video_path", str(actual_video_path), # 使用预处理后的视频
|
||||
"--audio_path", str(audio_path),
|
||||
"--video_out_path", str(temp_output),
|
||||
"--seed", str(settings.LATENTSYNC_SEED),
|
||||
"--temp_dir", str(tmpdir / "cache"),
|
||||
]
|
||||
|
||||
if settings.LATENTSYNC_ENABLE_DEEPCACHE:
|
||||
cmd.append("--enable_deepcache")
|
||||
|
||||
# 设置环境变量
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
|
||||
|
||||
logger.info(f"🖥️ 执行命令: {' '.join(cmd[:8])}...")
|
||||
logger.info(f"🖥️ GPU: CUDA_VISIBLE_DEVICES={self.gpu_id}")
|
||||
|
||||
# 等待进程完成,带超时
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=900 # 15分钟超时
|
||||
# 使用 asyncio subprocess 实现真正的异步执行
|
||||
# 这样事件循环可以继续处理其他请求(如进度查询)
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
cwd=str(self.latentsync_dir),
|
||||
env=env,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
logger.error("⏰ LatentSync 推理超时 (15分钟)")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
|
||||
if process.returncode != 0:
|
||||
logger.error(f"LatentSync 推理失败:\n{stderr_text}")
|
||||
logger.error(f"stdout:\n{stdout_text[-1000:] if stdout_text else 'N/A'}")
|
||||
# Fallback
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
logger.info(f"LatentSync 输出:\n{stdout_text[-500:] if stdout_text else 'N/A'}")
|
||||
|
||||
# 检查输出文件
|
||||
if temp_output.exists():
|
||||
shutil.copy(temp_output, output_path)
|
||||
logger.info(f"✅ 唇形同步完成: {output_path}")
|
||||
return output_path
|
||||
else:
|
||||
logger.warning("⚠️ 未找到输出文件,使用 Fallback")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 推理异常: {e}")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
# 等待进程完成,带超时
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=900 # 15分钟超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
logger.error("⏰ LatentSync 推理超时 (15分钟)")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
stdout_text = stdout.decode() if stdout else ""
|
||||
stderr_text = stderr.decode() if stderr else ""
|
||||
|
||||
if process.returncode != 0:
|
||||
logger.error(f"LatentSync 推理失败:\n{stderr_text}")
|
||||
logger.error(f"stdout:\n{stdout_text[-1000:] if stdout_text else 'N/A'}")
|
||||
# Fallback
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
logger.info(f"LatentSync 输出:\n{stdout_text[-500:] if stdout_text else 'N/A'}")
|
||||
|
||||
# 检查输出文件
|
||||
if temp_output.exists():
|
||||
shutil.copy(temp_output, output_path)
|
||||
logger.info(f"✅ 唇形同步完成: {output_path}")
|
||||
return output_path
|
||||
else:
|
||||
logger.warning("⚠️ 未找到输出文件,使用 Fallback")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 推理异常: {e}")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
async def _call_persistent_server(self, video_path: str, audio_path: str, output_path: str) -> str:
|
||||
"""调用本地常驻服务 (server.py)"""
|
||||
server_url = "http://localhost:8007"
|
||||
logger.info(f"⚡ 调用常驻服务: {server_url}")
|
||||
|
||||
# 准备请求数据 (传递绝对路径)
|
||||
payload = {
|
||||
"video_path": str(Path(video_path).resolve()),
|
||||
"audio_path": str(Path(audio_path).resolve()),
|
||||
"video_out_path": str(Path(output_path).resolve()),
|
||||
"inference_steps": settings.LATENTSYNC_INFERENCE_STEPS,
|
||||
"guidance_scale": settings.LATENTSYNC_GUIDANCE_SCALE,
|
||||
"seed": settings.LATENTSYNC_SEED,
|
||||
"temp_dir": os.path.join(tempfile.gettempdir(), "latentsync_temp")
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1200.0) as client:
|
||||
# 先检查健康状态
|
||||
try:
|
||||
resp = await client.get(f"{server_url}/health", timeout=5.0)
|
||||
if resp.status_code != 200:
|
||||
logger.warning("⚠️ 常驻服务健康检查失败,回退到 subprocess")
|
||||
return await self._local_generate_subprocess(video_path, audio_path, output_path)
|
||||
except Exception:
|
||||
logger.warning("⚠️ 无法连接常驻服务,回退到 subprocess")
|
||||
return await self._local_generate_subprocess(video_path, audio_path, output_path)
|
||||
|
||||
# 发送生成请求
|
||||
response = await client.post(f"{server_url}/lipsync", json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if Path(result["output_path"]).exists():
|
||||
logger.info(f"✅ 常驻服务推理完成: {output_path}")
|
||||
return output_path
|
||||
|
||||
logger.error(f"❌ 常驻服务报错: {response.text}")
|
||||
raise RuntimeError(f"Server Error: {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 常驻服务调用失败: {e}")
|
||||
# 这里可以选择回退,或者直接报错
|
||||
raise e
|
||||
|
||||
async def _local_generate_subprocess(self, video_path: str, audio_path: str, output_path: str) -> str:
|
||||
"""原有的 subprocess 逻辑提取为独立方法"""
|
||||
logger.info("🔄 调用 LatentSync 推理 (subprocess)...")
|
||||
# ... (此处仅为占位符提示,实际代码需要调整结构以避免重复,
|
||||
# 但鉴于原有 _local_generate 的结构,最简单的方法是在 _local_generate 内部做判断,
|
||||
# 如果 use_server 失败,可以 retry 或者 _local_generate 不做拆分,直接在里面写逻辑)
|
||||
# 为了最小化改动且保持安全,上面的 _call_persistent_server 如果失败,
|
||||
# 最好不要自动回退(可能导致双重资源消耗),而是直接报错让用户检查服务。
|
||||
# 但为了用户体验,我们可以允许回退。
|
||||
# *修正策略*:
|
||||
# 我将不拆分 _local_generate_subprocess,而是将 subprocess 逻辑保留在 _local_generate 的后半部分。
|
||||
# 如果 self.use_server 为 True,先尝试调用 server,成功则 return,失败则继续往下走。
|
||||
pass
|
||||
|
||||
async def _remote_generate(
|
||||
self,
|
||||
|
||||
@@ -139,6 +139,45 @@ CUDA_VISIBLE_DEVICES=1 python -m scripts.inference \
|
||||
|
||||
---
|
||||
|
||||
---
|
||||
|
||||
## 步骤 7: 性能优化 (预加载模型服务)
|
||||
|
||||
为了消除每次生成视频时 30-40秒 的模型加载时间,建议运行常驻服务。
|
||||
|
||||
### 1. 安装服务依赖
|
||||
|
||||
```bash
|
||||
conda activate latentsync
|
||||
pip install fastapi uvicorn
|
||||
```
|
||||
|
||||
### 2. 启动服务
|
||||
|
||||
**前台运行 (测试)**:
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
# 启动服务 (端口 8007) - 会自动读取 backend/.env 中的 GPU 配置
|
||||
python -m scripts.server
|
||||
```
|
||||
|
||||
**后台运行 (推荐)**:
|
||||
```bash
|
||||
nohup python -m scripts.server > server.log 2>&1 &
|
||||
```
|
||||
|
||||
### 3. 更新配置
|
||||
|
||||
修改 `ViGent2/backend/.env`:
|
||||
|
||||
```bash
|
||||
LATENTSYNC_USE_SERVER=True
|
||||
```
|
||||
|
||||
现在,后端通过 API 调用本地常驻服务,生成速度将显著提升。
|
||||
|
||||
---
|
||||
|
||||
## 故障排除
|
||||
|
||||
### CUDA 内存不足
|
||||
|
||||
23
models/LatentSync/configs/audio.yaml
Normal file
23
models/LatentSync/configs/audio.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
audio:
|
||||
num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality
|
||||
rescale: true # Whether to rescale audio prior to preprocessing
|
||||
rescaling_max: 0.9 # Rescaling value
|
||||
use_lws:
|
||||
false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
|
||||
# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
|
||||
# Does not work if n_ffit is not multiple of hop_size!!
|
||||
n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
|
||||
hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
|
||||
win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
|
||||
sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
|
||||
frame_shift_ms: null
|
||||
signal_normalization: true
|
||||
allow_clipping_in_normalization: true
|
||||
symmetric_mels: true
|
||||
max_abs_value: 4.0
|
||||
preemphasize: true # whether to apply filter
|
||||
preemphasis: 0.97 # filter coefficient.
|
||||
min_level_db: -100
|
||||
ref_level_db: 20
|
||||
fmin: 55
|
||||
fmax: 7600
|
||||
12
models/LatentSync/configs/scheduler_config.json
Normal file
12
models/LatentSync/configs/scheduler_config.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"_class_name": "DDIMScheduler",
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
"clip_sample": false,
|
||||
"num_train_timesteps": 1000,
|
||||
"set_alpha_to_one": false,
|
||||
"steps_offset": 1,
|
||||
"trained_betas": null,
|
||||
"skip_prk_steps": true
|
||||
}
|
||||
46
models/LatentSync/configs/syncnet/syncnet_16_latent.yaml
Normal file
46
models/LatentSync/configs/syncnet/syncnet_16_latent.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024]
|
||||
downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (64, 32, 32)
|
||||
in_channels: 64
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024]
|
||||
downsample_factors: [2, 2, 2, 1, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ""
|
||||
save_ckpt_steps: 2500
|
||||
|
||||
data:
|
||||
train_output_dir: debug/syncnet
|
||||
num_val_samples: 1200
|
||||
batch_size: 120 # 40
|
||||
gradient_accumulation_steps: 1
|
||||
num_workers: 12 # 12
|
||||
latent_space: true
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
val_fileslist: ""
|
||||
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
lower_half: false
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
max_grad_norm: 1.0
|
||||
|
||||
run:
|
||||
max_train_steps: 10000000
|
||||
validation_steps: 2500
|
||||
mixed_precision_training: true
|
||||
seed: 42
|
||||
46
models/LatentSync/configs/syncnet/syncnet_16_pixel.yaml
Normal file
46
models/LatentSync/configs/syncnet/syncnet_16_pixel.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
||||
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (48, 128, 256)
|
||||
in_channels: 48
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ""
|
||||
save_ckpt_steps: 2500
|
||||
|
||||
data:
|
||||
train_output_dir: debug/syncnet
|
||||
num_val_samples: 2048
|
||||
batch_size: 256 # 256
|
||||
gradient_accumulation_steps: 1
|
||||
num_workers: 12 # 12
|
||||
latent_space: false
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
val_fileslist: ""
|
||||
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
lower_half: true
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
max_grad_norm: 1.0
|
||||
|
||||
run:
|
||||
max_train_steps: 10000000
|
||||
validation_steps: 2500
|
||||
mixed_precision_training: true
|
||||
seed: 42
|
||||
46
models/LatentSync/configs/syncnet/syncnet_16_pixel_attn.yaml
Normal file
46
models/LatentSync/configs/syncnet/syncnet_16_pixel_attn.yaml
Normal file
@@ -0,0 +1,46 @@
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
||||
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 1, 1, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (48, 128, 256)
|
||||
in_channels: 48
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 1, 1, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: checkpoints/stable_syncnet.pt
|
||||
save_ckpt_steps: 2500
|
||||
|
||||
data:
|
||||
train_output_dir: debug/syncnet
|
||||
num_val_samples: 2048
|
||||
batch_size: 256 # 256
|
||||
gradient_accumulation_steps: 1
|
||||
num_workers: 12 # 12
|
||||
latent_space: false
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
val_fileslist: ""
|
||||
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
lower_half: true
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
max_grad_norm: 1.0
|
||||
|
||||
run:
|
||||
max_train_steps: 10000000
|
||||
validation_steps: 2500
|
||||
mixed_precision_training: true
|
||||
seed: 42
|
||||
44
models/LatentSync/configs/syncnet/syncnet_25_pixel.yaml
Normal file
44
models/LatentSync/configs/syncnet/syncnet_25_pixel.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 80)
|
||||
in_channels: 1
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024]
|
||||
downsample_factors: [2, 2, 2, 2, 2, 2]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (75, 128, 256)
|
||||
in_channels: 75
|
||||
block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ""
|
||||
save_ckpt_steps: 2500
|
||||
|
||||
data:
|
||||
train_output_dir: debug/syncnet
|
||||
num_val_samples: 2048
|
||||
batch_size: 64 # 64
|
||||
gradient_accumulation_steps: 1
|
||||
num_workers: 12 # 12
|
||||
latent_space: false
|
||||
num_frames: 25
|
||||
resolution: 256
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
val_fileslist: ""
|
||||
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
lower_half: true
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
max_grad_norm: 1.0
|
||||
|
||||
run:
|
||||
max_train_steps: 10000000
|
||||
validation_steps: 2500
|
||||
mixed_precision_training: true
|
||||
seed: 42
|
||||
96
models/LatentSync/configs/unet/stage1.yaml
Normal file
96
models/LatentSync/configs/unet/stage1.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
data:
|
||||
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
|
||||
train_output_dir: debug/unet
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
|
||||
val_video_path: assets/demo1_video.mp4
|
||||
val_audio_path: assets/demo1_audio.wav
|
||||
batch_size: 1 # 24
|
||||
num_workers: 12 # 12
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
mask_image_path: latentsync/utils/mask.png
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
audio_feat_length: [2, 2]
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
||||
save_ckpt_steps: 10000
|
||||
|
||||
run:
|
||||
pixel_space_supervise: false
|
||||
use_syncnet: false
|
||||
sync_loss_weight: 0.05
|
||||
perceptual_loss_weight: 0.1 # 0.1
|
||||
recon_loss_weight: 1 # 1
|
||||
guidance_scale: 1.5 # [1.0 - 3.0]
|
||||
trepa_loss_weight: 10
|
||||
inference_steps: 20
|
||||
seed: 1247
|
||||
use_mixed_noise: true
|
||||
mixed_noise_alpha: 1 # 1
|
||||
mixed_precision_training: true
|
||||
enable_gradient_checkpointing: true
|
||||
max_train_steps: 10000000
|
||||
max_train_epochs: -1
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
scale_lr: false
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: constant
|
||||
lr_warmup_steps: 0
|
||||
|
||||
model:
|
||||
act_fn: silu
|
||||
add_audio_layer: true
|
||||
attention_head_dim: 8
|
||||
block_out_channels: [320, 640, 1280, 1280]
|
||||
center_input_sample: false
|
||||
cross_attention_dim: 384
|
||||
down_block_types:
|
||||
[
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
]
|
||||
mid_block_type: UNetMidBlock3DCrossAttn
|
||||
up_block_types:
|
||||
[
|
||||
"UpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
]
|
||||
downsample_padding: 1
|
||||
flip_sin_to_cos: true
|
||||
freq_shift: 0
|
||||
in_channels: 13 # 49
|
||||
layers_per_block: 2
|
||||
mid_block_scale_factor: 1
|
||||
norm_eps: 1e-5
|
||||
norm_num_groups: 32
|
||||
out_channels: 4 # 16
|
||||
sample_size: 64
|
||||
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
||||
|
||||
use_motion_module: false
|
||||
motion_module_resolutions: [1, 2, 4, 8]
|
||||
motion_module_mid_block: false
|
||||
motion_module_decoder_only: false
|
||||
motion_module_type: Vanilla
|
||||
motion_module_kwargs:
|
||||
num_attention_heads: 8
|
||||
num_transformer_block: 1
|
||||
attention_block_types:
|
||||
- Temporal_Self
|
||||
- Temporal_Self
|
||||
temporal_position_encoding: true
|
||||
temporal_position_encoding_max_len: 24
|
||||
temporal_attention_dim_div: 1
|
||||
zero_initialize: true
|
||||
96
models/LatentSync/configs/unet/stage1_512.yaml
Normal file
96
models/LatentSync/configs/unet/stage1_512.yaml
Normal file
@@ -0,0 +1,96 @@
|
||||
data:
|
||||
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
|
||||
train_output_dir: debug/unet
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
|
||||
val_video_path: assets/demo1_video.mp4
|
||||
val_audio_path: assets/demo1_audio.wav
|
||||
batch_size: 1 # 8
|
||||
num_workers: 12 # 12
|
||||
num_frames: 16
|
||||
resolution: 512
|
||||
mask_image_path: latentsync/utils/mask.png
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
audio_feat_length: [2, 2]
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
||||
save_ckpt_steps: 10000
|
||||
|
||||
run:
|
||||
pixel_space_supervise: false
|
||||
use_syncnet: false
|
||||
sync_loss_weight: 0.05
|
||||
perceptual_loss_weight: 0.1 # 0.1
|
||||
recon_loss_weight: 1 # 1
|
||||
guidance_scale: 1.5 # [1.0 - 3.0]
|
||||
trepa_loss_weight: 10
|
||||
inference_steps: 20
|
||||
seed: 1247
|
||||
use_mixed_noise: true
|
||||
mixed_noise_alpha: 1 # 1
|
||||
mixed_precision_training: true
|
||||
enable_gradient_checkpointing: true
|
||||
max_train_steps: 10000000
|
||||
max_train_epochs: -1
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
scale_lr: false
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: constant
|
||||
lr_warmup_steps: 0
|
||||
|
||||
model:
|
||||
act_fn: silu
|
||||
add_audio_layer: true
|
||||
attention_head_dim: 8
|
||||
block_out_channels: [320, 640, 1280, 1280]
|
||||
center_input_sample: false
|
||||
cross_attention_dim: 384
|
||||
down_block_types:
|
||||
[
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
]
|
||||
mid_block_type: UNetMidBlock3DCrossAttn
|
||||
up_block_types:
|
||||
[
|
||||
"UpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
]
|
||||
downsample_padding: 1
|
||||
flip_sin_to_cos: true
|
||||
freq_shift: 0
|
||||
in_channels: 13 # 49
|
||||
layers_per_block: 2
|
||||
mid_block_scale_factor: 1
|
||||
norm_eps: 1e-5
|
||||
norm_num_groups: 32
|
||||
out_channels: 4 # 16
|
||||
sample_size: 64
|
||||
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
||||
|
||||
use_motion_module: false
|
||||
motion_module_resolutions: [1, 2, 4, 8]
|
||||
motion_module_mid_block: false
|
||||
motion_module_decoder_only: false
|
||||
motion_module_type: Vanilla
|
||||
motion_module_kwargs:
|
||||
num_attention_heads: 8
|
||||
num_transformer_block: 1
|
||||
attention_block_types:
|
||||
- Temporal_Self
|
||||
- Temporal_Self
|
||||
temporal_position_encoding: true
|
||||
temporal_position_encoding_max_len: 24
|
||||
temporal_attention_dim_div: 1
|
||||
zero_initialize: true
|
||||
99
models/LatentSync/configs/unet/stage2.yaml
Normal file
99
models/LatentSync/configs/unet/stage2.yaml
Normal file
@@ -0,0 +1,99 @@
|
||||
data:
|
||||
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
|
||||
train_output_dir: debug/unet
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
|
||||
val_video_path: assets/demo1_video.mp4
|
||||
val_audio_path: assets/demo1_audio.wav
|
||||
batch_size: 1 # 4
|
||||
num_workers: 12 # 12
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
mask_image_path: latentsync/utils/mask.png
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
audio_feat_length: [2, 2]
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
||||
save_ckpt_steps: 10000
|
||||
|
||||
run:
|
||||
pixel_space_supervise: true
|
||||
use_syncnet: true
|
||||
sync_loss_weight: 0.05
|
||||
perceptual_loss_weight: 0.1 # 0.1
|
||||
recon_loss_weight: 1 # 1
|
||||
guidance_scale: 1.5 # [1.0 - 3.0]
|
||||
trepa_loss_weight: 10
|
||||
inference_steps: 20
|
||||
trainable_modules:
|
||||
- motion_modules.
|
||||
- attentions.
|
||||
seed: 1247
|
||||
use_mixed_noise: true
|
||||
mixed_noise_alpha: 1 # 1
|
||||
mixed_precision_training: true
|
||||
enable_gradient_checkpointing: true
|
||||
max_train_steps: 10000000
|
||||
max_train_epochs: -1
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
scale_lr: false
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: constant
|
||||
lr_warmup_steps: 0
|
||||
|
||||
model:
|
||||
act_fn: silu
|
||||
add_audio_layer: true
|
||||
attention_head_dim: 8
|
||||
block_out_channels: [320, 640, 1280, 1280]
|
||||
center_input_sample: false
|
||||
cross_attention_dim: 384
|
||||
down_block_types:
|
||||
[
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
]
|
||||
mid_block_type: UNetMidBlock3DCrossAttn
|
||||
up_block_types:
|
||||
[
|
||||
"UpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
]
|
||||
downsample_padding: 1
|
||||
flip_sin_to_cos: true
|
||||
freq_shift: 0
|
||||
in_channels: 13 # 49
|
||||
layers_per_block: 2
|
||||
mid_block_scale_factor: 1
|
||||
norm_eps: 1e-5
|
||||
norm_num_groups: 32
|
||||
out_channels: 4 # 16
|
||||
sample_size: 64
|
||||
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
||||
|
||||
use_motion_module: true
|
||||
motion_module_resolutions: [1, 2, 4, 8]
|
||||
motion_module_mid_block: false
|
||||
motion_module_decoder_only: false
|
||||
motion_module_type: Vanilla
|
||||
motion_module_kwargs:
|
||||
num_attention_heads: 8
|
||||
num_transformer_block: 1
|
||||
attention_block_types:
|
||||
- Temporal_Self
|
||||
- Temporal_Self
|
||||
temporal_position_encoding: true
|
||||
temporal_position_encoding_max_len: 24
|
||||
temporal_attention_dim_div: 1
|
||||
zero_initialize: true
|
||||
99
models/LatentSync/configs/unet/stage2_512.yaml
Normal file
99
models/LatentSync/configs/unet/stage2_512.yaml
Normal file
@@ -0,0 +1,99 @@
|
||||
data:
|
||||
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
|
||||
train_output_dir: debug/unet
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
|
||||
val_video_path: assets/demo1_video.mp4
|
||||
val_audio_path: assets/demo1_audio.wav
|
||||
batch_size: 1 # 4
|
||||
num_workers: 12 # 12
|
||||
num_frames: 16
|
||||
resolution: 512
|
||||
mask_image_path: latentsync/utils/mask.png
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
audio_feat_length: [2, 2]
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
||||
save_ckpt_steps: 10000
|
||||
|
||||
run:
|
||||
pixel_space_supervise: true
|
||||
use_syncnet: true
|
||||
sync_loss_weight: 0.05
|
||||
perceptual_loss_weight: 0.1 # 0.1
|
||||
recon_loss_weight: 1 # 1
|
||||
guidance_scale: 1.5 # [1.0 - 3.0]
|
||||
trepa_loss_weight: 10
|
||||
inference_steps: 20
|
||||
trainable_modules:
|
||||
- motion_modules.
|
||||
- attentions.
|
||||
seed: 1247
|
||||
use_mixed_noise: true
|
||||
mixed_noise_alpha: 1 # 1
|
||||
mixed_precision_training: true
|
||||
enable_gradient_checkpointing: true
|
||||
max_train_steps: 10000000
|
||||
max_train_epochs: -1
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
scale_lr: false
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: constant
|
||||
lr_warmup_steps: 0
|
||||
|
||||
model:
|
||||
act_fn: silu
|
||||
add_audio_layer: true
|
||||
attention_head_dim: 8
|
||||
block_out_channels: [320, 640, 1280, 1280]
|
||||
center_input_sample: false
|
||||
cross_attention_dim: 384
|
||||
down_block_types:
|
||||
[
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
]
|
||||
mid_block_type: UNetMidBlock3DCrossAttn
|
||||
up_block_types:
|
||||
[
|
||||
"UpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
]
|
||||
downsample_padding: 1
|
||||
flip_sin_to_cos: true
|
||||
freq_shift: 0
|
||||
in_channels: 13 # 49
|
||||
layers_per_block: 2
|
||||
mid_block_scale_factor: 1
|
||||
norm_eps: 1e-5
|
||||
norm_num_groups: 32
|
||||
out_channels: 4 # 16
|
||||
sample_size: 64
|
||||
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
||||
|
||||
use_motion_module: true
|
||||
motion_module_resolutions: [1, 2, 4, 8]
|
||||
motion_module_mid_block: false
|
||||
motion_module_decoder_only: false
|
||||
motion_module_type: Vanilla
|
||||
motion_module_kwargs:
|
||||
num_attention_heads: 8
|
||||
num_transformer_block: 1
|
||||
attention_block_types:
|
||||
- Temporal_Self
|
||||
- Temporal_Self
|
||||
temporal_position_encoding: true
|
||||
temporal_position_encoding_max_len: 24
|
||||
temporal_attention_dim_div: 1
|
||||
zero_initialize: true
|
||||
99
models/LatentSync/configs/unet/stage2_efficient.yaml
Normal file
99
models/LatentSync/configs/unet/stage2_efficient.yaml
Normal file
@@ -0,0 +1,99 @@
|
||||
data:
|
||||
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
|
||||
train_output_dir: debug/unet
|
||||
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
|
||||
train_data_dir: ""
|
||||
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
|
||||
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
|
||||
|
||||
val_video_path: assets/demo1_video.mp4
|
||||
val_audio_path: assets/demo1_audio.wav
|
||||
batch_size: 1 # 4
|
||||
num_workers: 12 # 12
|
||||
num_frames: 16
|
||||
resolution: 256
|
||||
mask_image_path: latentsync/utils/mask.png
|
||||
audio_sample_rate: 16000
|
||||
video_fps: 25
|
||||
audio_feat_length: [2, 2]
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: checkpoints/latentsync_unet.pt
|
||||
save_ckpt_steps: 10000
|
||||
|
||||
run:
|
||||
pixel_space_supervise: true
|
||||
use_syncnet: true
|
||||
sync_loss_weight: 0.05
|
||||
perceptual_loss_weight: 0.1 # 0.1
|
||||
recon_loss_weight: 1 # 1
|
||||
guidance_scale: 1.5 # [1.0 - 3.0]
|
||||
trepa_loss_weight: 0
|
||||
inference_steps: 20
|
||||
trainable_modules:
|
||||
- motion_modules.
|
||||
- attn2.
|
||||
seed: 1247
|
||||
use_mixed_noise: true
|
||||
mixed_noise_alpha: 1 # 1
|
||||
mixed_precision_training: true
|
||||
enable_gradient_checkpointing: true
|
||||
max_train_steps: 10000000
|
||||
max_train_epochs: -1
|
||||
|
||||
optimizer:
|
||||
lr: 1e-5
|
||||
scale_lr: false
|
||||
max_grad_norm: 1.0
|
||||
lr_scheduler: constant
|
||||
lr_warmup_steps: 0
|
||||
|
||||
model:
|
||||
act_fn: silu
|
||||
add_audio_layer: true
|
||||
attention_head_dim: 8
|
||||
block_out_channels: [320, 640, 1280, 1280]
|
||||
center_input_sample: false
|
||||
cross_attention_dim: 384
|
||||
down_block_types:
|
||||
[
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
]
|
||||
mid_block_type: UNetMidBlock3DCrossAttn
|
||||
up_block_types:
|
||||
[
|
||||
"UpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
"CrossAttnUpBlock3D",
|
||||
]
|
||||
downsample_padding: 1
|
||||
flip_sin_to_cos: true
|
||||
freq_shift: 0
|
||||
in_channels: 13 # 49
|
||||
layers_per_block: 2
|
||||
mid_block_scale_factor: 1
|
||||
norm_eps: 1e-5
|
||||
norm_num_groups: 32
|
||||
out_channels: 4 # 16
|
||||
sample_size: 64
|
||||
resnet_time_scale_shift: default # Choose between [default, scale_shift]
|
||||
|
||||
use_motion_module: true
|
||||
motion_module_resolutions: [1, 2, 4, 8]
|
||||
motion_module_mid_block: false
|
||||
motion_module_decoder_only: true
|
||||
motion_module_type: Vanilla
|
||||
motion_module_kwargs:
|
||||
num_attention_heads: 8
|
||||
num_transformer_block: 1
|
||||
attention_block_types:
|
||||
- Temporal_Self
|
||||
- Temporal_Self
|
||||
temporal_position_encoding: true
|
||||
temporal_position_encoding_max_len: 24
|
||||
temporal_attention_dim_div: 1
|
||||
zero_initialize: true
|
||||
139
models/LatentSync/latentsync/data/syncnet_dataset.py
Normal file
139
models/LatentSync/latentsync/data/syncnet_dataset.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import random
|
||||
from ..utils.util import gather_video_paths_recursively
|
||||
from ..utils.image_processor import ImageProcessor
|
||||
from ..utils.audio import melspectrogram
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
from decord import AudioReader, VideoReader, cpu
|
||||
|
||||
|
||||
class SyncNetDataset(Dataset):
|
||||
def __init__(self, data_dir: str, fileslist: str, config):
|
||||
if fileslist != "":
|
||||
with open(fileslist) as file:
|
||||
self.video_paths = [line.rstrip() for line in file]
|
||||
elif data_dir != "":
|
||||
self.video_paths = gather_video_paths_recursively(data_dir)
|
||||
else:
|
||||
raise ValueError("data_dir and fileslist cannot be both empty")
|
||||
|
||||
self.resolution = config.data.resolution
|
||||
self.num_frames = config.data.num_frames
|
||||
|
||||
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
|
||||
|
||||
self.audio_sample_rate = config.data.audio_sample_rate
|
||||
self.video_fps = config.data.video_fps
|
||||
self.image_processor = ImageProcessor(resolution=config.data.resolution)
|
||||
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
||||
Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_paths)
|
||||
|
||||
def read_audio(self, video_path: str):
|
||||
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
||||
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
||||
return torch.from_numpy(original_mel)
|
||||
|
||||
def crop_audio_window(self, original_mel, start_index):
|
||||
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
||||
end_idx = start_idx + self.mel_window_length
|
||||
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
||||
|
||||
def get_frames(self, video_reader: VideoReader):
|
||||
total_num_frames = len(video_reader)
|
||||
|
||||
start_idx = random.randint(0, total_num_frames - self.num_frames)
|
||||
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
||||
|
||||
while True:
|
||||
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
||||
if wrong_start_idx == start_idx:
|
||||
continue
|
||||
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
|
||||
break
|
||||
|
||||
frames = video_reader.get_batch(frames_index).asnumpy()
|
||||
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
|
||||
|
||||
return frames, wrong_frames, start_idx
|
||||
|
||||
def worker_init_fn(self, worker_id):
|
||||
self.worker_id = worker_id
|
||||
|
||||
def __getitem__(self, idx):
|
||||
while True:
|
||||
try:
|
||||
idx = random.randint(0, len(self) - 1)
|
||||
|
||||
# Get video file path
|
||||
video_path = self.video_paths[idx]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
||||
|
||||
if len(vr) < 2 * self.num_frames:
|
||||
continue
|
||||
|
||||
frames, wrong_frames, start_idx = self.get_frames(vr)
|
||||
|
||||
mel_cache_path = os.path.join(
|
||||
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
||||
)
|
||||
|
||||
if os.path.isfile(mel_cache_path):
|
||||
try:
|
||||
original_mel = torch.load(mel_cache_path, weights_only=True)
|
||||
except Exception as e:
|
||||
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
||||
os.remove(mel_cache_path)
|
||||
original_mel = self.read_audio(video_path)
|
||||
torch.save(original_mel, mel_cache_path)
|
||||
else:
|
||||
original_mel = self.read_audio(video_path)
|
||||
torch.save(original_mel, mel_cache_path)
|
||||
|
||||
mel = self.crop_audio_window(original_mel, start_idx)
|
||||
|
||||
if mel.shape[-1] != self.mel_window_length:
|
||||
continue
|
||||
|
||||
if random.choice([True, False]):
|
||||
y = torch.ones(1).float()
|
||||
chosen_frames = frames
|
||||
else:
|
||||
y = torch.zeros(1).float()
|
||||
chosen_frames = wrong_frames
|
||||
|
||||
chosen_frames = self.image_processor.process_images(chosen_frames)
|
||||
|
||||
vr.seek(0) # avoid memory leak
|
||||
break
|
||||
|
||||
except Exception as e: # Handle the exception of face not detcted
|
||||
print(f"{type(e).__name__} - {e} - {video_path}")
|
||||
if "vr" in locals():
|
||||
vr.seek(0) # avoid memory leak
|
||||
|
||||
sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
|
||||
|
||||
return sample
|
||||
152
models/LatentSync/latentsync/data/unet_dataset.py
Normal file
152
models/LatentSync/latentsync/data/unet_dataset.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
import random
|
||||
import cv2
|
||||
from ..utils.image_processor import ImageProcessor, load_fixed_mask
|
||||
from ..utils.audio import melspectrogram
|
||||
from decord import AudioReader, VideoReader, cpu
|
||||
import torch.nn.functional as F
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class UNetDataset(Dataset):
|
||||
def __init__(self, train_data_dir: str, config):
|
||||
if config.data.train_fileslist != "":
|
||||
with open(config.data.train_fileslist) as file:
|
||||
self.video_paths = [line.rstrip() for line in file]
|
||||
elif train_data_dir != "":
|
||||
self.video_paths = []
|
||||
for file in os.listdir(train_data_dir):
|
||||
if file.endswith(".mp4"):
|
||||
self.video_paths.append(os.path.join(train_data_dir, file))
|
||||
else:
|
||||
raise ValueError("data_dir and fileslist cannot be both empty")
|
||||
|
||||
self.resolution = config.data.resolution
|
||||
self.num_frames = config.data.num_frames
|
||||
|
||||
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
|
||||
|
||||
self.audio_sample_rate = config.data.audio_sample_rate
|
||||
self.video_fps = config.data.video_fps
|
||||
self.image_processor = ImageProcessor(
|
||||
self.resolution, mask_image=load_fixed_mask(self.resolution, config.data.mask_image_path)
|
||||
)
|
||||
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
|
||||
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
|
||||
Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.video_paths)
|
||||
|
||||
def read_audio(self, video_path: str):
|
||||
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
|
||||
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
|
||||
return torch.from_numpy(original_mel)
|
||||
|
||||
def crop_audio_window(self, original_mel, start_index):
|
||||
start_idx = int(80.0 * (start_index / float(self.video_fps)))
|
||||
end_idx = start_idx + self.mel_window_length
|
||||
return original_mel[:, start_idx:end_idx].unsqueeze(0)
|
||||
|
||||
def get_frames(self, video_reader: VideoReader):
|
||||
total_num_frames = len(video_reader)
|
||||
|
||||
start_idx = random.randint(0, total_num_frames - self.num_frames)
|
||||
gt_frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
|
||||
|
||||
while True:
|
||||
ref_start_idx = random.randint(0, total_num_frames - self.num_frames)
|
||||
if ref_start_idx > start_idx - self.num_frames and ref_start_idx < start_idx + self.num_frames:
|
||||
continue
|
||||
ref_frames_index = np.arange(ref_start_idx, ref_start_idx + self.num_frames, dtype=int)
|
||||
break
|
||||
|
||||
gt_frames = video_reader.get_batch(gt_frames_index).asnumpy()
|
||||
ref_frames = video_reader.get_batch(ref_frames_index).asnumpy()
|
||||
|
||||
return gt_frames, ref_frames, start_idx
|
||||
|
||||
def worker_init_fn(self, worker_id):
|
||||
self.worker_id = worker_id
|
||||
|
||||
def __getitem__(self, idx):
|
||||
while True:
|
||||
try:
|
||||
idx = random.randint(0, len(self) - 1)
|
||||
|
||||
# Get video file path
|
||||
video_path = self.video_paths[idx]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
|
||||
|
||||
if len(vr) < 3 * self.num_frames:
|
||||
continue
|
||||
|
||||
gt_frames, ref_frames, start_idx = self.get_frames(vr)
|
||||
|
||||
if self.load_audio_data:
|
||||
mel_cache_path = os.path.join(
|
||||
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
|
||||
)
|
||||
|
||||
if os.path.isfile(mel_cache_path):
|
||||
try:
|
||||
original_mel = torch.load(mel_cache_path, weights_only=True)
|
||||
except Exception as e:
|
||||
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
|
||||
os.remove(mel_cache_path)
|
||||
original_mel = self.read_audio(video_path)
|
||||
torch.save(original_mel, mel_cache_path)
|
||||
else:
|
||||
original_mel = self.read_audio(video_path)
|
||||
torch.save(original_mel, mel_cache_path)
|
||||
|
||||
mel = self.crop_audio_window(original_mel, start_idx)
|
||||
|
||||
if mel.shape[-1] != self.mel_window_length:
|
||||
continue
|
||||
else:
|
||||
mel = []
|
||||
|
||||
gt_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
|
||||
gt_frames, affine_transform=False
|
||||
) # (f, c, h, w)
|
||||
ref_pixel_values = self.image_processor.process_images(ref_frames)
|
||||
|
||||
vr.seek(0) # avoid memory leak
|
||||
break
|
||||
|
||||
except Exception as e: # Handle the exception of face not detcted
|
||||
print(f"{type(e).__name__} - {e} - {video_path}")
|
||||
if "vr" in locals():
|
||||
vr.seek(0) # avoid memory leak
|
||||
|
||||
sample = dict(
|
||||
gt_pixel_values=gt_pixel_values,
|
||||
masked_pixel_values=masked_pixel_values,
|
||||
ref_pixel_values=ref_pixel_values,
|
||||
mel=mel,
|
||||
masks=masks,
|
||||
video_path=video_path,
|
||||
start_idx=start_idx,
|
||||
)
|
||||
|
||||
return sample
|
||||
280
models/LatentSync/latentsync/models/attention.py
Normal file
280
models/LatentSync/latentsync/models/attention.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.models.attention import FeedForward, AdaLayerNorm
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer3DModelOutput(BaseOutput):
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer3DModel(ModelMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# Define input layers
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
# Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Define output layers
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
# Input
|
||||
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
||||
video_length = hidden_states.shape[2]
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
||||
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
video_length=video_length,
|
||||
)
|
||||
|
||||
# Output
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer3DModelOutput(sample=output)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
||||
self.add_audio_layer = add_audio_layer
|
||||
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
# Cross-attn
|
||||
if add_audio_layer:
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
# Feed-forward
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
|
||||
):
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
|
||||
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
|
||||
if self.attn2 is not None and encoder_hidden_states is not None:
|
||||
if encoder_hidden_states.dim() == 4:
|
||||
encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = (
|
||||
self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
# Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.heads = heads
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
def split_heads(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size, seq_len, self.heads, dim // self.heads)
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
return tensor
|
||||
|
||||
def concat_heads(self, tensor):
|
||||
batch_size, heads, seq_len, head_dim = tensor.shape
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
tensor = tensor.reshape(batch_size, seq_len, heads * head_dim)
|
||||
return tensor
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
query = self.split_heads(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
key = self.split_heads(key)
|
||||
value = self.split_heads(value)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape[-1] != query.shape[1]:
|
||||
target_length = query.shape[1]
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
||||
|
||||
# Use PyTorch native implementation of FlashAttention-2
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
|
||||
|
||||
hidden_states = self.concat_heads(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
313
models/LatentSync/latentsync/models/motion_module.py
Normal file
313
models/LatentSync/latentsync/models/motion_module.py
Normal file
@@ -0,0 +1,313 @@
|
||||
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
|
||||
|
||||
# Actually we don't use the motion module in the final version of LatentSync
|
||||
# When we started the project, we used the codebase of AnimateDiff and tried motion module
|
||||
# But the results are poor, and we decied to leave the code here for possible future usage
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import BaseOutput
|
||||
from diffusers.models.attention import FeedForward
|
||||
from .attention import Attention
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import math
|
||||
from .utils import zero_module
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemporalTransformer3DModelOutput(BaseOutput):
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
|
||||
if motion_module_type == "Vanilla":
|
||||
return VanillaTemporalModule(
|
||||
in_channels=in_channels,
|
||||
**motion_module_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
||||
class VanillaTemporalModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
num_attention_heads=8,
|
||||
num_transformer_block=2,
|
||||
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
||||
cross_frame_attention_mode=None,
|
||||
temporal_position_encoding=False,
|
||||
temporal_position_encoding_max_len=24,
|
||||
temporal_attention_dim_div=1,
|
||||
zero_initialize=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.temporal_transformer = TemporalTransformer3DModel(
|
||||
in_channels=in_channels,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
|
||||
num_layers=num_transformer_block,
|
||||
attention_block_types=attention_block_types,
|
||||
cross_frame_attention_mode=cross_frame_attention_mode,
|
||||
temporal_position_encoding=temporal_position_encoding,
|
||||
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
||||
)
|
||||
|
||||
if zero_initialize:
|
||||
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
|
||||
|
||||
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
|
||||
hidden_states = input_tensor
|
||||
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
|
||||
|
||||
output = hidden_states
|
||||
return output
|
||||
|
||||
|
||||
class TemporalTransformer3DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
num_layers,
|
||||
attention_block_types=(
|
||||
"Temporal_Self",
|
||||
"Temporal_Self",
|
||||
),
|
||||
dropout=0.0,
|
||||
norm_num_groups=32,
|
||||
cross_attention_dim=768,
|
||||
activation_fn="geglu",
|
||||
attention_bias=False,
|
||||
upcast_attention=False,
|
||||
cross_frame_attention_mode=None,
|
||||
temporal_position_encoding=False,
|
||||
temporal_position_encoding_max_len=24,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_block_types=attention_block_types,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
cross_frame_attention_mode=cross_frame_attention_mode,
|
||||
temporal_position_encoding=temporal_position_encoding,
|
||||
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
|
||||
video_length = hidden_states.shape[2]
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
||||
|
||||
batch, channel, height, weight = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# Transformer Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
|
||||
)
|
||||
|
||||
# output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class TemporalTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
attention_block_types=(
|
||||
"Temporal_Self",
|
||||
"Temporal_Self",
|
||||
),
|
||||
dropout=0.0,
|
||||
norm_num_groups=32,
|
||||
cross_attention_dim=768,
|
||||
activation_fn="geglu",
|
||||
attention_bias=False,
|
||||
upcast_attention=False,
|
||||
cross_frame_attention_mode=None,
|
||||
temporal_position_encoding=False,
|
||||
temporal_position_encoding_max_len=24,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
attention_blocks = []
|
||||
norms = []
|
||||
|
||||
for block_name in attention_block_types:
|
||||
attention_blocks.append(
|
||||
VersatileAttention(
|
||||
attention_mode=block_name.split("_")[0],
|
||||
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
cross_frame_attention_mode=cross_frame_attention_mode,
|
||||
temporal_position_encoding=temporal_position_encoding,
|
||||
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
||||
)
|
||||
)
|
||||
norms.append(nn.LayerNorm(dim))
|
||||
|
||||
self.attention_blocks = nn.ModuleList(attention_blocks)
|
||||
self.norms = nn.ModuleList(norms)
|
||||
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.ff_norm = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
||||
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
||||
norm_hidden_states = norm(hidden_states)
|
||||
hidden_states = (
|
||||
attention_block(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
|
||||
video_length=video_length,
|
||||
)
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
||||
|
||||
output = hidden_states
|
||||
return output
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model, dropout=0.0, max_len=24):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
position = torch.arange(max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
||||
pe = torch.zeros(1, max_len, d_model)
|
||||
pe[0, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[0, :, 1::2] = torch.cos(position * div_term)
|
||||
self.register_buffer("pe", pe)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class VersatileAttention(Attention):
|
||||
def __init__(
|
||||
self,
|
||||
attention_mode=None,
|
||||
cross_frame_attention_mode=None,
|
||||
temporal_position_encoding=False,
|
||||
temporal_position_encoding_max_len=24,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
assert attention_mode == "Temporal"
|
||||
|
||||
self.attention_mode = attention_mode
|
||||
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
|
||||
|
||||
self.pos_encoder = (
|
||||
PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
|
||||
if (temporal_position_encoding and attention_mode == "Temporal")
|
||||
else None
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
|
||||
if self.attention_mode == "Temporal":
|
||||
s = hidden_states.shape[1]
|
||||
hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=video_length)
|
||||
|
||||
if self.pos_encoder is not None:
|
||||
hidden_states = self.pos_encoder(hidden_states)
|
||||
|
||||
##### This section will not be executed #####
|
||||
encoder_hidden_states = (
|
||||
repeat(encoder_hidden_states, "b n c -> (b s) n c", s=s)
|
||||
if encoder_hidden_states is not None
|
||||
else encoder_hidden_states
|
||||
)
|
||||
#############################################
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
query = self.split_heads(query)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
key = self.split_heads(key)
|
||||
value = self.split_heads(value)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape[-1] != query.shape[1]:
|
||||
target_length = query.shape[1]
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
||||
|
||||
# Use PyTorch native implementation of FlashAttention-2
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
|
||||
|
||||
hidden_states = self.concat_heads(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
|
||||
if self.attention_mode == "Temporal":
|
||||
hidden_states = rearrange(hidden_states, "(b s) f c -> (b f) s c", s=s)
|
||||
|
||||
return hidden_states
|
||||
228
models/LatentSync/latentsync/models/resnet.py
Normal file
228
models/LatentSync/latentsync/models/resnet.py
Normal file
@@ -0,0 +1,228 @@
|
||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class InflatedConv3d(nn.Conv2d):
|
||||
def forward(self, x):
|
||||
video_length = x.shape[2]
|
||||
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = super().forward(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class InflatedGroupNorm(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
video_length = x.shape[2]
|
||||
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = super().forward(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
raise NotImplementedError
|
||||
elif use_conv:
|
||||
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, hidden_states, output_size=None):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
raise NotImplementedError
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if self.use_conv and self.padding == 0:
|
||||
raise NotImplementedError
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
use_inflated_groupnorm=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
self.pre_norm = True
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.output_scale_factor = output_scale_factor
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
assert use_inflated_groupnorm != None
|
||||
if use_inflated_groupnorm:
|
||||
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
else:
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
time_emb_proj_out_channels = out_channels
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
time_emb_proj_out_channels = out_channels * 2
|
||||
else:
|
||||
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
||||
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
|
||||
if self.time_embedding_norm == "scale_shift":
|
||||
self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
|
||||
else:
|
||||
self.double_len_linear = None
|
||||
|
||||
if use_inflated_groupnorm:
|
||||
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
else:
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if non_linearity == "swish":
|
||||
self.nonlinearity = lambda x: F.silu(x)
|
||||
elif non_linearity == "mish":
|
||||
self.nonlinearity = Mish()
|
||||
elif non_linearity == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, input_tensor, temb):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
if temb is not None:
|
||||
if temb.dim() == 2:
|
||||
# input (1, 1280)
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))
|
||||
temb = temb[:, :, None, None, None] # unsqueeze
|
||||
else:
|
||||
# input (1, 1280, 16)
|
||||
temb = temb.permute(0, 2, 1)
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))
|
||||
if self.double_len_linear is not None:
|
||||
temb = self.double_len_linear(self.nonlinearity(temb))
|
||||
temb = temb.permute(0, 2, 1)
|
||||
temb = temb[:, :, :, None, None]
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
class Mish(torch.nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
|
||||
233
models/LatentSync/latentsync/models/stable_syncnet.py
Normal file
233
models/LatentSync/latentsync/models/stable_syncnet.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from torch.nn import functional as F
|
||||
from .attention import Attention
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import FeedForward
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class StableSyncNet(nn.Module):
|
||||
def __init__(self, config, gradient_checkpointing=False):
|
||||
super().__init__()
|
||||
self.audio_encoder = DownEncoder2D(
|
||||
in_channels=config["audio_encoder"]["in_channels"],
|
||||
block_out_channels=config["audio_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["audio_encoder"]["downsample_factors"],
|
||||
dropout=config["audio_encoder"]["dropout"],
|
||||
attn_blocks=config["audio_encoder"]["attn_blocks"],
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
self.visual_encoder = DownEncoder2D(
|
||||
in_channels=config["visual_encoder"]["in_channels"],
|
||||
block_out_channels=config["visual_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["visual_encoder"]["downsample_factors"],
|
||||
dropout=config["visual_encoder"]["dropout"],
|
||||
attn_blocks=config["visual_encoder"]["attn_blocks"],
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
def forward(self, image_sequences, audio_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds, audio_embeds
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
act_fn: str = "silu",
|
||||
downsample_factor=2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "silu":
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
if isinstance(downsample_factor, list):
|
||||
downsample_factor = tuple(downsample_factor)
|
||||
|
||||
if downsample_factor == 1:
|
||||
self.downsample_conv = None
|
||||
else:
|
||||
self.downsample_conv = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
|
||||
)
|
||||
self.pad = (0, 1, 0, 1)
|
||||
if isinstance(downsample_factor, tuple):
|
||||
if downsample_factor[0] == 1:
|
||||
self.pad = (0, 1, 1, 1) # The padding order is from back to front
|
||||
elif downsample_factor[1] == 1:
|
||||
self.pad = (1, 1, 0, 1)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
hidden_states += input_tensor
|
||||
|
||||
if self.downsample_conv is not None:
|
||||
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
||||
hidden_states = self.downsample_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock2D(nn.Module):
|
||||
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
||||
super().__init__()
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm3 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
||||
|
||||
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.attn = Attention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
||||
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
|
||||
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width).contiguous()
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoder2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4 * 16,
|
||||
block_out_channels=[64, 128, 256, 256],
|
||||
downsample_factors=[2, 2, 2, 2],
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
attn_blocks=[1, 1, 1, 1],
|
||||
dropout: float = 0.0,
|
||||
act_fn="silu",
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
# in
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# down
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
output_channels = block_out_channels[0]
|
||||
for i, block_out_channel in enumerate(block_out_channels):
|
||||
input_channels = output_channels
|
||||
output_channels = block_out_channel
|
||||
|
||||
down_block = ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
downsample_factor=downsample_factors[i],
|
||||
norm_num_groups=norm_num_groups,
|
||||
dropout=dropout,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if attn_blocks[i] == 1:
|
||||
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
||||
self.down_blocks.append(attention_block)
|
||||
|
||||
# out
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.act_fn_out = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
if self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(down_block, hidden_states, use_reentrant=False)
|
||||
else:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.act_fn_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
512
models/LatentSync/latentsync/models/unet.py
Normal file
512
models/LatentSync/latentsync/models/unet.py
Normal file
@@ -0,0 +1,512 @@
|
||||
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.models import ModelMixin
|
||||
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
||||
from .unet_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
DownBlock3D,
|
||||
UNetMidBlock3DCrossAttn,
|
||||
UpBlock3D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
from .resnet import InflatedConv3d, InflatedGroupNorm
|
||||
|
||||
from ..utils.util import zero_rank_log
|
||||
from .utils import zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet3DConditionOutput(BaseOutput):
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 4,
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"CrossAttnDownBlock3D",
|
||||
"DownBlock3D",
|
||||
),
|
||||
mid_block_type: str = "UNetMidBlock3DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
use_inflated_groupnorm=False,
|
||||
# Additional
|
||||
use_motion_module=False,
|
||||
motion_module_resolutions=(1, 2, 4, 8),
|
||||
motion_module_mid_block=False,
|
||||
motion_module_decoder_only=False,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs={},
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.use_motion_module = use_motion_module
|
||||
self.add_audio_layer = add_audio_layer
|
||||
|
||||
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
|
||||
|
||||
# time
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
res = 2**i
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[i],
|
||||
downsample_padding=downsample_padding,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module
|
||||
and (res in motion_module_resolutions)
|
||||
and (not motion_module_decoder_only),
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
if mid_block_type == "UNetMidBlock3DCrossAttn":
|
||||
self.mid_block = UNetMidBlock3DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module and motion_module_mid_block,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the videos
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
res = 2 ** (3 - i)
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=reversed_attention_head_dim[i],
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
if use_inflated_groupnorm:
|
||||
self.conv_norm_out = InflatedGroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
else:
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
||||
)
|
||||
self.conv_act = nn.SiLU()
|
||||
|
||||
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_slicable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_slicable_dims(module)
|
||||
|
||||
num_slicable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_slicable_layers * [1]
|
||||
|
||||
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
# support controlnet
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet3DConditionOutput, Tuple]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
|
||||
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2**self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
# pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
|
||||
)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# support controlnet
|
||||
down_block_res_samples = list(down_block_res_samples)
|
||||
if down_block_additional_residuals is not None:
|
||||
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
|
||||
if down_block_additional_residual.dim() == 4: # boardcast
|
||||
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
|
||||
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
|
||||
|
||||
# mid
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# support controlnet
|
||||
if mid_block_additional_residual is not None:
|
||||
if mid_block_additional_residual.dim() == 4: # boardcast
|
||||
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
# up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
# post-process
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet3DConditionOutput(sample=sample)
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# If the loaded checkpoint's in_channels or out_channels are different from config
|
||||
if state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
|
||||
del state_dict["conv_in.weight"]
|
||||
del state_dict["conv_in.bias"]
|
||||
if state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
|
||||
del state_dict["conv_out.weight"]
|
||||
del state_dict["conv_out.bias"]
|
||||
|
||||
# If the loaded checkpoint's cross_attention_dim is different from config
|
||||
keys_to_remove = []
|
||||
for key in state_dict:
|
||||
if "attn2.to_k." in key or "attn2.to_v." in key:
|
||||
if state_dict[key].shape[1] != self.config.cross_attention_dim:
|
||||
keys_to_remove.append(key)
|
||||
|
||||
for key in keys_to_remove:
|
||||
del state_dict[key]
|
||||
|
||||
return super().load_state_dict(state_dict=state_dict, strict=strict)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
|
||||
unet = cls.from_config(model_config).to(device)
|
||||
if ckpt_path != "":
|
||||
zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
|
||||
if "global_step" in ckpt:
|
||||
zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
|
||||
resume_global_step = ckpt["global_step"]
|
||||
else:
|
||||
resume_global_step = 0
|
||||
unet.load_state_dict(ckpt["state_dict"], strict=False)
|
||||
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
resume_global_step = 0
|
||||
|
||||
return unet, resume_global_step
|
||||
777
models/LatentSync/latentsync/models/unet_blocks.py
Normal file
777
models/LatentSync/latentsync/models/unet_blocks.py
Normal file
@@ -0,0 +1,777 @@
|
||||
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import Transformer3DModel
|
||||
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
||||
from .motion_module import get_motion_module
|
||||
|
||||
|
||||
def get_down_block(
|
||||
down_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
add_downsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock3D":
|
||||
return DownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
||||
return CrossAttnDownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
up_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
out_channels,
|
||||
prev_output_channel,
|
||||
temb_channels,
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock3D":
|
||||
return UpBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
|
||||
return CrossAttnUpBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
upcast_attention=False,
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
motion_modules = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
if dual_cross_attention:
|
||||
raise NotImplementedError
|
||||
attentions.append(
|
||||
Transformer3DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
get_motion_module(
|
||||
in_channels=in_channels,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
if use_motion_module
|
||||
else None
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.motion_modules = nn.ModuleList(motion_modules)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CrossAttnDownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
motion_modules = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
if dual_cross_attention:
|
||||
raise NotImplementedError
|
||||
attentions.append(
|
||||
Transformer3DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
get_motion_module(
|
||||
in_channels=out_channels,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
if use_motion_module
|
||||
else None
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.motion_modules = nn.ModuleList(motion_modules)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample3D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)[0]
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
temb,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
motion_modules = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
get_motion_module(
|
||||
in_channels=out_channels,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
if use_motion_module
|
||||
else None
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.motion_modules = nn.ModuleList(motion_modules)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
Downsample3D(
|
||||
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
temb,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class CrossAttnUpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
add_audio_layer=False,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
motion_modules = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
if dual_cross_attention:
|
||||
raise NotImplementedError
|
||||
attentions.append(
|
||||
Transformer3DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
add_audio_layer=add_audio_layer,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
get_motion_module(
|
||||
in_channels=out_channels,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
if use_motion_module
|
||||
else None
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.motion_modules = nn.ModuleList(motion_modules)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)[0]
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
temb,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
use_inflated_groupnorm=False,
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
motion_modules = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock3D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
get_motion_module(
|
||||
in_channels=out_channels,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
if use_motion_module
|
||||
else None
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
self.motion_modules = nn.ModuleList(motion_modules)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
upsample_size=None,
|
||||
encoder_hidden_states=None,
|
||||
):
|
||||
for resnet, motion_module in zip(self.resnets, self.motion_modules):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(motion_module),
|
||||
hidden_states,
|
||||
temb,
|
||||
encoder_hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if motion_module is not None:
|
||||
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
19
models/LatentSync/latentsync/models/utils.py
Normal file
19
models/LatentSync/latentsync/models/utils.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
def zero_module(module):
|
||||
# Zero out the parameters of a module and return it.
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
90
models/LatentSync/latentsync/models/wav2lip_syncnet.py
Normal file
90
models/LatentSync/latentsync/models/wav2lip_syncnet.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
|
||||
# The code here is for ablation study.
|
||||
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Wav2LipSyncNet(nn.Module):
|
||||
def __init__(self, act_fn="leaky"):
|
||||
super().__init__()
|
||||
|
||||
# input image sequences: (15, 128, 256)
|
||||
self.visual_encoder = nn.Sequential(
|
||||
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
|
||||
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
|
||||
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
|
||||
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
||||
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
||||
)
|
||||
|
||||
# input audio sequences: (1, 80, 16)
|
||||
self.audio_encoder = nn.Sequential(
|
||||
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
|
||||
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
|
||||
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
|
||||
)
|
||||
|
||||
def forward(self, image_sequences, audio_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds, audio_embeds
|
||||
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "tanh":
|
||||
self.act_fn = nn.Tanh()
|
||||
elif act_fn == "silu":
|
||||
self.act_fn = nn.SiLU()
|
||||
elif act_fn == "leaky":
|
||||
self.act_fn = nn.LeakyReLU(0.2, inplace=True)
|
||||
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
return self.act_fn(out)
|
||||
477
models/LatentSync/latentsync/pipelines/lipsync_pipeline.py
Normal file
477
models/LatentSync/latentsync/pipelines/lipsync_pipeline.py
Normal file
@@ -0,0 +1,477 @@
|
||||
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
|
||||
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from typing import Callable, List, Optional, Union
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
|
||||
from packaging import version
|
||||
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.models import AutoencoderKL
|
||||
from diffusers.pipelines import DiffusionPipeline
|
||||
from diffusers.schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from diffusers.utils import deprecate, logging
|
||||
|
||||
from einops import rearrange
|
||||
import cv2
|
||||
|
||||
from ..models.unet import UNet3DConditionModel
|
||||
from ..utils.util import read_video, read_audio, write_video, check_ffmpeg_installed
|
||||
from ..utils.image_processor import ImageProcessor, load_fixed_mask
|
||||
from ..whisper.audio2feature import Audio2Feature
|
||||
import tqdm
|
||||
import soundfile as sf
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LipsyncPipeline(DiffusionPipeline):
|
||||
_optional_components = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
audio_encoder: Audio2Feature,
|
||||
unet: UNet3DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
||||
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
||||
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
||||
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
||||
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
||||
" file"
|
||||
)
|
||||
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
||||
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
||||
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
||||
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
||||
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
||||
)
|
||||
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(scheduler.config)
|
||||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
|
||||
version.parse(unet.config._diffusers_version).base_version
|
||||
) < version.parse("0.9.0.dev0")
|
||||
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
|
||||
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
|
||||
deprecation_message = (
|
||||
"The configuration file of the unet has set the default `sample_size` to smaller than"
|
||||
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
|
||||
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
|
||||
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
|
||||
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
|
||||
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
|
||||
" in the config might lead to incorrect results in future versions. If you have downloaded this"
|
||||
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
|
||||
" the `unet/config.json` file"
|
||||
)
|
||||
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
|
||||
new_config = dict(unet.config)
|
||||
new_config["sample_size"] = 64
|
||||
unet._internal_dict = FrozenDict(new_config)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
audio_encoder=audio_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
self.set_progress_bar_config(desc="Steps")
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
self.vae.disable_slicing()
|
||||
|
||||
@property
|
||||
def _execution_device(self):
|
||||
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
|
||||
return self.device
|
||||
for module in self.unet.modules():
|
||||
if (
|
||||
hasattr(module, "_hf_hook")
|
||||
and hasattr(module._hf_hook, "execution_device")
|
||||
and module._hf_hook.execution_device is not None
|
||||
):
|
||||
return torch.device(module._hf_hook.execution_device)
|
||||
return self.device
|
||||
|
||||
def decode_latents(self, latents):
|
||||
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
|
||||
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
||||
decoded_latents = self.vae.decode(latents).sample
|
||||
return decoded_latents
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def check_inputs(self, height, width, callback_steps):
|
||||
assert height == width, "Height and width must be equal"
|
||||
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
def prepare_latents(self, num_frames, num_channels_latents, height, width, dtype, device, generator):
|
||||
shape = (
|
||||
1,
|
||||
num_channels_latents,
|
||||
1,
|
||||
height // self.vae_scale_factor,
|
||||
width // self.vae_scale_factor,
|
||||
) # (b, c, f, h, w)
|
||||
rand_device = "cpu" if device.type == "mps" else device
|
||||
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
||||
latents = latents.repeat(1, 1, num_frames, 1, 1)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_mask_latents(
|
||||
self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
|
||||
):
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
)
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
# assume batch size = 1
|
||||
mask = rearrange(mask, "f c h w -> 1 c f h w")
|
||||
masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
|
||||
|
||||
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
||||
masked_image_latents = (
|
||||
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
||||
)
|
||||
return mask, masked_image_latents
|
||||
|
||||
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
|
||||
images = images.to(device=device, dtype=dtype)
|
||||
image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
|
||||
image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
||||
|
||||
return image_latents
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
if not hasattr(self, "_progress_bar_config"):
|
||||
self._progress_bar_config = {}
|
||||
self._progress_bar_config.update(kwargs)
|
||||
|
||||
@staticmethod
|
||||
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
|
||||
# Paste the surrounding pixels back, because we only want to change the mouth region
|
||||
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
|
||||
masks = masks.to(device=device, dtype=weight_dtype)
|
||||
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
|
||||
return combined_pixel_values
|
||||
|
||||
@staticmethod
|
||||
def pixel_values_to_images(pixel_values: torch.Tensor):
|
||||
pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
|
||||
pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
|
||||
images = (pixel_values * 255).to(torch.uint8)
|
||||
images = images.cpu().numpy()
|
||||
return images
|
||||
|
||||
def affine_transform_video(self, video_frames: np.ndarray):
|
||||
faces = []
|
||||
boxes = []
|
||||
affine_matrices = []
|
||||
print(f"Affine transforming {len(video_frames)} faces...")
|
||||
for frame in tqdm.tqdm(video_frames):
|
||||
face, box, affine_matrix = self.image_processor.affine_transform(frame)
|
||||
faces.append(face)
|
||||
boxes.append(box)
|
||||
affine_matrices.append(affine_matrix)
|
||||
|
||||
faces = torch.stack(faces)
|
||||
return faces, boxes, affine_matrices
|
||||
|
||||
def restore_video(self, faces: torch.Tensor, video_frames: np.ndarray, boxes: list, affine_matrices: list):
|
||||
video_frames = video_frames[: len(faces)]
|
||||
out_frames = []
|
||||
print(f"Restoring {len(faces)} faces...")
|
||||
for index, face in enumerate(tqdm.tqdm(faces)):
|
||||
x1, y1, x2, y2 = boxes[index]
|
||||
height = int(y2 - y1)
|
||||
width = int(x2 - x1)
|
||||
face = torchvision.transforms.functional.resize(
|
||||
face, size=(height, width), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
|
||||
out_frames.append(out_frame)
|
||||
return np.stack(out_frames, axis=0)
|
||||
|
||||
def loop_video(self, whisper_chunks: list, video_frames: np.ndarray):
|
||||
# If the audio is longer than the video, we need to loop the video
|
||||
if len(whisper_chunks) > len(video_frames):
|
||||
faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
|
||||
num_loops = math.ceil(len(whisper_chunks) / len(video_frames))
|
||||
loop_video_frames = []
|
||||
loop_faces = []
|
||||
loop_boxes = []
|
||||
loop_affine_matrices = []
|
||||
for i in range(num_loops):
|
||||
if i % 2 == 0:
|
||||
loop_video_frames.append(video_frames)
|
||||
loop_faces.append(faces)
|
||||
loop_boxes += boxes
|
||||
loop_affine_matrices += affine_matrices
|
||||
else:
|
||||
loop_video_frames.append(video_frames[::-1])
|
||||
loop_faces.append(faces.flip(0))
|
||||
loop_boxes += boxes[::-1]
|
||||
loop_affine_matrices += affine_matrices[::-1]
|
||||
|
||||
video_frames = np.concatenate(loop_video_frames, axis=0)[: len(whisper_chunks)]
|
||||
faces = torch.cat(loop_faces, dim=0)[: len(whisper_chunks)]
|
||||
boxes = loop_boxes[: len(whisper_chunks)]
|
||||
affine_matrices = loop_affine_matrices[: len(whisper_chunks)]
|
||||
else:
|
||||
video_frames = video_frames[: len(whisper_chunks)]
|
||||
faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
|
||||
|
||||
return video_frames, faces, boxes, affine_matrices
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
video_path: str,
|
||||
audio_path: str,
|
||||
video_out_path: str,
|
||||
num_frames: int = 16,
|
||||
video_fps: int = 25,
|
||||
audio_sample_rate: int = 16000,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 20,
|
||||
guidance_scale: float = 1.5,
|
||||
weight_dtype: Optional[torch.dtype] = torch.float16,
|
||||
eta: float = 0.0,
|
||||
mask_image_path: str = "latentsync/utils/mask.png",
|
||||
temp_dir: str = "temp",
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
**kwargs,
|
||||
):
|
||||
is_train = self.unet.training
|
||||
self.unet.eval()
|
||||
|
||||
check_ffmpeg_installed()
|
||||
|
||||
# 0. Define call parameters
|
||||
device = self._execution_device
|
||||
mask_image = load_fixed_mask(height, mask_image_path)
|
||||
self.image_processor = ImageProcessor(height, device="cuda", mask_image=mask_image)
|
||||
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
|
||||
|
||||
# 1. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 2. Check inputs
|
||||
self.check_inputs(height, width, callback_steps)
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 4. Prepare extra step kwargs.
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
whisper_feature = self.audio_encoder.audio2feat(audio_path)
|
||||
whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
|
||||
|
||||
audio_samples = read_audio(audio_path)
|
||||
video_frames = read_video(video_path, use_decord=False)
|
||||
|
||||
video_frames, faces, boxes, affine_matrices = self.loop_video(whisper_chunks, video_frames)
|
||||
|
||||
synced_video_frames = []
|
||||
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
|
||||
# Prepare latent variables
|
||||
all_latents = self.prepare_latents(
|
||||
len(whisper_chunks),
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
weight_dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
num_inferences = math.ceil(len(whisper_chunks) / num_frames)
|
||||
for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
|
||||
if self.unet.add_audio_layer:
|
||||
audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
|
||||
audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
|
||||
if do_classifier_free_guidance:
|
||||
null_audio_embeds = torch.zeros_like(audio_embeds)
|
||||
audio_embeds = torch.cat([null_audio_embeds, audio_embeds])
|
||||
else:
|
||||
audio_embeds = None
|
||||
inference_faces = faces[i * num_frames : (i + 1) * num_frames]
|
||||
latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
|
||||
ref_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
|
||||
inference_faces, affine_transform=False
|
||||
)
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
mask_latents, masked_image_latents = self.prepare_mask_latents(
|
||||
masks,
|
||||
masked_pixel_values,
|
||||
height,
|
||||
width,
|
||||
weight_dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Prepare image latents
|
||||
ref_latents = self.prepare_image_latents(
|
||||
ref_pixel_values,
|
||||
device,
|
||||
weight_dtype,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 9. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for j, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
unet_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
|
||||
unet_input = self.scheduler.scale_model_input(unet_input, t)
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
unet_input = torch.cat([unet_input, mask_latents, masked_image_latents, ref_latents], dim=1)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(unet_input, t, encoder_hidden_states=audio_embeds).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and j % callback_steps == 0:
|
||||
callback(j, t, latents)
|
||||
|
||||
# Recover the pixel values
|
||||
decoded_latents = self.decode_latents(latents)
|
||||
decoded_latents = self.paste_surrounding_pixels_back(
|
||||
decoded_latents, ref_pixel_values, 1 - masks, device, weight_dtype
|
||||
)
|
||||
synced_video_frames.append(decoded_latents)
|
||||
|
||||
synced_video_frames = self.restore_video(torch.cat(synced_video_frames), video_frames, boxes, affine_matrices)
|
||||
|
||||
audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
|
||||
audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
|
||||
|
||||
if is_train:
|
||||
self.unet.train()
|
||||
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=video_fps)
|
||||
|
||||
sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
|
||||
|
||||
command = f"ffmpeg -y -loglevel error -nostdin -i {os.path.join(temp_dir, 'video.mp4')} -i {os.path.join(temp_dir, 'audio.wav')} -c:v libx264 -crf 18 -c:a aac -q:v 0 -q:a 0 {video_out_path}"
|
||||
subprocess.run(command, shell=True)
|
||||
67
models/LatentSync/latentsync/trepa/loss.py
Normal file
67
models/LatentSync/latentsync/trepa/loss.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from .third_party.VideoMAEv2.utils import load_videomae_model
|
||||
from ..utils.util import check_model_and_download
|
||||
|
||||
|
||||
class TREPALoss:
|
||||
def __init__(
|
||||
self,
|
||||
device="cuda",
|
||||
ckpt_path="checkpoints/auxiliary/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
|
||||
with_cp=False,
|
||||
):
|
||||
check_model_and_download(ckpt_path)
|
||||
self.model = load_videomae_model(device, ckpt_path, with_cp).eval().to(dtype=torch.float16)
|
||||
self.model.requires_grad_(False)
|
||||
|
||||
def __call__(self, videos_fake, videos_real):
|
||||
batch_size = videos_fake.shape[0]
|
||||
num_frames = videos_fake.shape[2]
|
||||
videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
|
||||
videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
|
||||
|
||||
videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bicubic")
|
||||
videos_real = F.interpolate(videos_real, size=(224, 224), mode="bicubic")
|
||||
|
||||
videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
|
||||
videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
|
||||
|
||||
# Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
|
||||
videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
|
||||
videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
feats_fake = self.model.forward_features(videos_fake)
|
||||
feats_real = self.model.forward_features(videos_real)
|
||||
|
||||
feats_fake = F.normalize(feats_fake, p=2, dim=1)
|
||||
feats_real = F.normalize(feats_real, p=2, dim=1)
|
||||
|
||||
return F.mse_loss(feats_fake, feats_real)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.manual_seed(42)
|
||||
|
||||
# input shape: (b, c, f, h, w)
|
||||
videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
|
||||
videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
|
||||
|
||||
trepa_loss = TREPALoss(device="cuda", with_cp=True)
|
||||
loss = trepa_loss(videos_fake, videos_real)
|
||||
print(loss)
|
||||
0
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/__init__.py
vendored
Normal file
0
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/__init__.py
vendored
Normal file
82
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/utils.py
vendored
Normal file
82
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/utils.py
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
import os
|
||||
import torch
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
from torchvision import transforms
|
||||
from .videomaev2_finetune import vit_giant_patch14_224
|
||||
|
||||
|
||||
def to_normalized_float_tensor(vid):
|
||||
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
|
||||
|
||||
|
||||
# NOTE: for those functions, which generally expect mini-batches, we keep them
|
||||
# as non-minibatch so that they are applied as if they were 4d (thus image).
|
||||
# this way, we only apply the transformation in the spatial domain
|
||||
def resize(vid, size, interpolation="bilinear"):
|
||||
# NOTE: using bilinear interpolation because we don't work on minibatches
|
||||
# at this level
|
||||
scale = None
|
||||
if isinstance(size, int):
|
||||
scale = float(size) / min(vid.shape[-2:])
|
||||
size = None
|
||||
return torch.nn.functional.interpolate(vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)
|
||||
|
||||
|
||||
class ToFloatTensorInZeroOne(object):
|
||||
def __call__(self, vid):
|
||||
return to_normalized_float_tensor(vid)
|
||||
|
||||
|
||||
class Resize(object):
|
||||
def __init__(self, size):
|
||||
self.size = size
|
||||
|
||||
def __call__(self, vid):
|
||||
return resize(vid, self.size)
|
||||
|
||||
|
||||
def preprocess_videomae(videos):
|
||||
transform = transforms.Compose([ToFloatTensorInZeroOne(), Resize((224, 224))])
|
||||
return torch.stack([transform(f) for f in torch.from_numpy(videos)])
|
||||
|
||||
|
||||
def load_videomae_model(device, ckpt_path=None, with_cp=False):
|
||||
if ckpt_path is None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
ckpt_path = os.path.join(current_dir, "vit_g_hybrid_pt_1200e_ssv2_ft.pth")
|
||||
|
||||
if not os.path.exists(ckpt_path):
|
||||
# download the ckpt to the path
|
||||
ckpt_url = "https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth"
|
||||
response = requests.get(ckpt_url, stream=True, allow_redirects=True)
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
block_size = 1024
|
||||
|
||||
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
|
||||
with open(ckpt_path, "wb") as fw:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
fw.write(data)
|
||||
|
||||
model = vit_giant_patch14_224(
|
||||
img_size=224,
|
||||
pretrained=False,
|
||||
num_classes=174,
|
||||
all_frames=16,
|
||||
tubelet_size=2,
|
||||
drop_path_rate=0.3,
|
||||
use_mean_pooling=True,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
|
||||
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
for model_key in ["model", "module"]:
|
||||
if model_key in ckpt:
|
||||
ckpt = ckpt[model_key]
|
||||
break
|
||||
model.load_state_dict(ckpt)
|
||||
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
return model.to(device)
|
||||
543
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py
vendored
Normal file
543
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_finetune.py
vendored
Normal file
@@ -0,0 +1,543 @@
|
||||
# --------------------------------------------------------
|
||||
# Based on BEiT, timm, DINO and DeiT code bases
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/facebookresearch/deit
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
from functools import partial
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import numpy as np
|
||||
import collections.abc
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from itertools import repeat
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.0))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
std: the standard deviation of the normal distribution
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.trunc_normal_(w)
|
||||
"""
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
if isinstance(x, collections.abc.Iterable):
|
||||
return x
|
||||
return tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
||||
"""
|
||||
Adapted from timm codebase
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
||||
random_tensor.floor_() # binarize
|
||||
output = x.div(keep_prob) * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 400,
|
||||
"input_size": (3, 224, 224),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.9,
|
||||
"interpolation": "bicubic",
|
||||
"mean": (0.5, 0.5, 0.5),
|
||||
"std": (0.5, 0.5, 0.5),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commit this for the original BERT implement
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class CosAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
# self.scale = qk_scale or head_dim**-0.5
|
||||
# DO NOT RENAME [self.scale] (for no weight decay)
|
||||
if qk_scale is None:
|
||||
self.scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
||||
else:
|
||||
self.scale = qk_scale
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
|
||||
|
||||
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
|
||||
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
|
||||
|
||||
attn = attn * logit_scale
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# Use PyTorch native implementation of FlashAttention-2
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
x = attn.transpose(1, 2).reshape(B, N, -1)
|
||||
|
||||
# Deprecated attn implementation, which consumes much more VRAM
|
||||
# q = q * self.scale
|
||||
# attn = q @ k.transpose(-2, -1)
|
||||
# attn = attn.softmax(dim=-1)
|
||||
# attn = self.attn_drop(attn)
|
||||
# x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
init_values=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
attn_head_dim=None,
|
||||
cos_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
if cos_attn:
|
||||
self.attn = CosAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
else:
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
if init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
||||
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_spatial_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
|
||||
num_patches = num_spatial_patches * (num_frames // tubelet_size)
|
||||
|
||||
self.img_size = img_size
|
||||
self.tubelet_size = tubelet_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
self.proj = nn.Conv3d(
|
||||
in_channels=in_chans,
|
||||
out_channels=embed_dim,
|
||||
kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
|
||||
stride=(self.tubelet_size, patch_size[0], patch_size[1]),
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, T, H, W = x.shape
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
# b, c, l -> b, l, c
|
||||
# [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
# sin-cos position encoding
|
||||
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
|
||||
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||
"""Sinusoid position encoding table"""
|
||||
|
||||
# TODO: make it with torch instead of numpy
|
||||
def get_position_angle_vec(position):
|
||||
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
|
||||
|
||||
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
head_drop_rate=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=0.0,
|
||||
use_learnable_pos_emb=False,
|
||||
init_scale=0.0,
|
||||
all_frames=16,
|
||||
tubelet_size=2,
|
||||
use_mean_pooling=True,
|
||||
with_cp=False,
|
||||
cos_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
# num_features for consistency with other models
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
self.tubelet_size = tubelet_size
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
num_frames=all_frames,
|
||||
tubelet_size=tubelet_size,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.with_cp = with_cp
|
||||
|
||||
if use_learnable_pos_emb:
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
|
||||
else:
|
||||
# sine-cosine positional embeddings is on the way
|
||||
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
cos_attn=cos_attn,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
self.head_dropout = nn.Dropout(head_drop_rate)
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if use_learnable_pos_emb:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
self.head.weight.data.mul_(init_scale)
|
||||
self.head.bias.data.mul_(init_scale)
|
||||
self.num_frames = all_frames
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_num_layers(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {"pos_embed", "cls_token"}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=""):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def interpolate_pos_encoding(self, t):
|
||||
T = 8
|
||||
t0 = t // self.tubelet_size
|
||||
if T == t0:
|
||||
return self.pos_embed
|
||||
dim = self.pos_embed.shape[-1]
|
||||
patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, 8, 16, 16)
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
t0 = t0 + 0.1
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(t0 / T, 1, 1),
|
||||
mode="trilinear",
|
||||
)
|
||||
assert int(t0) == patch_pos_embed.shape[-3]
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
|
||||
return patch_pos_embed
|
||||
|
||||
def forward_features(self, x):
|
||||
# [1, 3, 16, 224, 224]
|
||||
B = x.size(0)
|
||||
T = x.size(2)
|
||||
|
||||
# [1, 2048, 1408]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
if self.pos_embed is not None:
|
||||
x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.with_cp:
|
||||
x = cp.checkpoint(blk, x, use_reentrant=False)
|
||||
else:
|
||||
x = blk(x)
|
||||
|
||||
# return self.fc_norm(x)
|
||||
|
||||
if self.fc_norm is not None:
|
||||
return self.fc_norm(x.mean(1))
|
||||
else:
|
||||
return self.norm(x[:, 0])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head_dropout(x)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def vit_giant_patch14_224(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=14,
|
||||
embed_dim=1408,
|
||||
depth=40,
|
||||
num_heads=16,
|
||||
mlp_ratio=48 / 11,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs,
|
||||
)
|
||||
model.default_cfg = _cfg()
|
||||
return model
|
||||
469
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py
vendored
Normal file
469
models/LatentSync/latentsync/trepa/third_party/VideoMAEv2/videomaev2_pretrain.py
vendored
Normal file
@@ -0,0 +1,469 @@
|
||||
# --------------------------------------------------------
|
||||
# Based on BEiT, timm, DINO and DeiT code bases
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/facebookresearch/deit
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
|
||||
from .videomaev2_finetune import (
|
||||
Block,
|
||||
PatchEmbed,
|
||||
_cfg,
|
||||
get_sinusoid_encoding_table,
|
||||
)
|
||||
|
||||
from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
|
||||
|
||||
def trunc_normal_(tensor, mean=0., std=1.):
|
||||
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
|
||||
|
||||
|
||||
class PretrainVisionTransformerEncoder(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=0,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
tubelet_size=2,
|
||||
use_learnable_pos_emb=False,
|
||||
with_cp=False,
|
||||
all_frames=16,
|
||||
cos_attn=False):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
# num_features for consistency with other models
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
num_frames=all_frames,
|
||||
tubelet_size=tubelet_size)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.with_cp = with_cp
|
||||
|
||||
if use_learnable_pos_emb:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dim))
|
||||
else:
|
||||
# sine-cosine positional embeddings
|
||||
self.pos_embed = get_sinusoid_encoding_table(
|
||||
num_patches, embed_dim)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
cos_attn=cos_attn) for i in range(depth)
|
||||
])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Linear(
|
||||
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if use_learnable_pos_emb:
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_num_layers(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(
|
||||
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward_features(self, x, mask):
|
||||
x = self.patch_embed(x)
|
||||
|
||||
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
|
||||
|
||||
B, _, C = x.shape
|
||||
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
|
||||
|
||||
for blk in self.blocks:
|
||||
if self.with_cp:
|
||||
x_vis = cp.checkpoint(blk, x_vis)
|
||||
else:
|
||||
x_vis = blk(x_vis)
|
||||
|
||||
x_vis = self.norm(x_vis)
|
||||
return x_vis
|
||||
|
||||
def forward(self, x, mask):
|
||||
x = self.forward_features(x, mask)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
class PretrainVisionTransformerDecoder(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=16,
|
||||
num_classes=768,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
num_patches=196,
|
||||
tubelet_size=2,
|
||||
with_cp=False,
|
||||
cos_attn=False):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
assert num_classes == 3 * tubelet_size * patch_size**2
|
||||
# num_features for consistency with other models
|
||||
self.num_features = self.embed_dim = embed_dim
|
||||
self.patch_size = patch_size
|
||||
self.with_cp = with_cp
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
cos_attn=cos_attn) for i in range(depth)
|
||||
])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Linear(
|
||||
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_num_layers(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=''):
|
||||
self.num_classes = num_classes
|
||||
self.head = nn.Linear(
|
||||
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x, return_token_num):
|
||||
for blk in self.blocks:
|
||||
if self.with_cp:
|
||||
x = cp.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
|
||||
if return_token_num > 0:
|
||||
# only return the mask tokens predict pixels
|
||||
x = self.head(self.norm(x[:, -return_token_num:]))
|
||||
else:
|
||||
# [B, N, 3*16^2]
|
||||
x = self.head(self.norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class PretrainVisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
encoder_in_chans=3,
|
||||
encoder_num_classes=0,
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
decoder_num_classes=1536, # decoder_num_classes=768
|
||||
decoder_embed_dim=512,
|
||||
decoder_depth=8,
|
||||
decoder_num_heads=8,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=0.,
|
||||
use_learnable_pos_emb=False,
|
||||
tubelet_size=2,
|
||||
num_classes=0, # avoid the error from create_fn in timm
|
||||
in_chans=0, # avoid the error from create_fn in timm
|
||||
with_cp=False,
|
||||
all_frames=16,
|
||||
cos_attn=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = PretrainVisionTransformerEncoder(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=encoder_in_chans,
|
||||
num_classes=encoder_num_classes,
|
||||
embed_dim=encoder_embed_dim,
|
||||
depth=encoder_depth,
|
||||
num_heads=encoder_num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
tubelet_size=tubelet_size,
|
||||
use_learnable_pos_emb=use_learnable_pos_emb,
|
||||
with_cp=with_cp,
|
||||
all_frames=all_frames,
|
||||
cos_attn=cos_attn)
|
||||
|
||||
self.decoder = PretrainVisionTransformerDecoder(
|
||||
patch_size=patch_size,
|
||||
num_patches=self.encoder.patch_embed.num_patches,
|
||||
num_classes=decoder_num_classes,
|
||||
embed_dim=decoder_embed_dim,
|
||||
depth=decoder_depth,
|
||||
num_heads=decoder_num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop_rate=drop_rate,
|
||||
attn_drop_rate=attn_drop_rate,
|
||||
drop_path_rate=drop_path_rate,
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
tubelet_size=tubelet_size,
|
||||
with_cp=with_cp,
|
||||
cos_attn=cos_attn)
|
||||
|
||||
self.encoder_to_decoder = nn.Linear(
|
||||
encoder_embed_dim, decoder_embed_dim, bias=False)
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
||||
|
||||
self.pos_embed = get_sinusoid_encoding_table(
|
||||
self.encoder.patch_embed.num_patches, decoder_embed_dim)
|
||||
|
||||
trunc_normal_(self.mask_token, std=.02)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.xavier_uniform_(m.weight)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_num_layers(self):
|
||||
return len(self.blocks)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token', 'mask_token'}
|
||||
|
||||
def forward(self, x, mask, decode_mask=None):
|
||||
decode_vis = mask if decode_mask is None else ~decode_mask
|
||||
|
||||
x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
|
||||
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
|
||||
B, N_vis, C = x_vis.shape
|
||||
|
||||
# we don't unshuffle the correct visible token order,
|
||||
# but shuffle the pos embedding accorddingly.
|
||||
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
|
||||
x.device).clone().detach()
|
||||
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
|
||||
pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)
|
||||
|
||||
# [B, N, C_d]
|
||||
x_full = torch.cat(
|
||||
[x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
|
||||
# NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
|
||||
x = self.decoder(x_full, pos_emd_mask.shape[1])
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
|
||||
model = PretrainVisionTransformer(
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
encoder_embed_dim=384,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=6,
|
||||
encoder_num_classes=0,
|
||||
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
||||
decoder_embed_dim=192,
|
||||
decoder_num_heads=3,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
|
||||
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
|
||||
model = PretrainVisionTransformer(
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_num_classes=0,
|
||||
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
||||
decoder_embed_dim=384,
|
||||
decoder_num_heads=6,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
|
||||
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
|
||||
model = PretrainVisionTransformer(
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_num_classes=0,
|
||||
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
||||
decoder_embed_dim=512,
|
||||
decoder_num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
|
||||
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
|
||||
model = PretrainVisionTransformer(
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_num_classes=0,
|
||||
decoder_num_classes=1536, # 16 * 16 * 3 * 2
|
||||
decoder_embed_dim=512,
|
||||
decoder_num_heads=8,
|
||||
mlp_ratio=4,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
|
||||
|
||||
def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
|
||||
model = PretrainVisionTransformer(
|
||||
img_size=224,
|
||||
patch_size=14,
|
||||
encoder_embed_dim=1408,
|
||||
encoder_depth=40,
|
||||
encoder_num_heads=16,
|
||||
encoder_num_classes=0,
|
||||
decoder_num_classes=1176, # 14 * 14 * 3 * 2,
|
||||
decoder_embed_dim=512,
|
||||
decoder_num_heads=8,
|
||||
mlp_ratio=48 / 11,
|
||||
qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
**kwargs)
|
||||
model.default_cfg = _cfg()
|
||||
if pretrained:
|
||||
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
return model
|
||||
0
models/LatentSync/latentsync/trepa/third_party/__init__.py
vendored
Normal file
0
models/LatentSync/latentsync/trepa/third_party/__init__.py
vendored
Normal file
321
models/LatentSync/latentsync/trepa/utils/data_utils.py
Normal file
321
models/LatentSync/latentsync/trepa/utils/data_utils.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import os
|
||||
import math
|
||||
import os.path as osp
|
||||
import random
|
||||
import pickle
|
||||
import warnings
|
||||
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torchvision.datasets.video_utils import VideoClips
|
||||
|
||||
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
|
||||
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
|
||||
|
||||
|
||||
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
|
||||
batch_size=16, num_workers=8):
|
||||
data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames, batch_size, num_workers)
|
||||
loader = data._dataloader()
|
||||
return loader
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def get_parent_dir(path):
|
||||
return osp.basename(osp.dirname(path))
|
||||
|
||||
|
||||
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
|
||||
# video: THWC, {0, ..., 255}
|
||||
assert in_channels == 3
|
||||
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
|
||||
t, c, h, w = video.shape
|
||||
|
||||
# temporal crop
|
||||
if sequence_length is not None:
|
||||
assert sequence_length <= t
|
||||
video = video[:sequence_length]
|
||||
|
||||
# skip frames
|
||||
if sample_every_n_frames > 1:
|
||||
video = video[::sample_every_n_frames]
|
||||
|
||||
# scale shorter side to resolution
|
||||
scale = resolution / min(h, w)
|
||||
if h < w:
|
||||
target_size = (resolution, math.ceil(w * scale))
|
||||
else:
|
||||
target_size = (math.ceil(h * scale), resolution)
|
||||
video = F.interpolate(video, size=target_size, mode='bilinear',
|
||||
align_corners=False, antialias=True)
|
||||
|
||||
# center crop
|
||||
t, c, h, w = video.shape
|
||||
w_start = (w - resolution) // 2
|
||||
h_start = (h - resolution) // 2
|
||||
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
|
||||
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
|
||||
|
||||
return {'video': video}
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
# [0, 1] => [-1, 1]
|
||||
img = torch.from_numpy(image)
|
||||
return img
|
||||
|
||||
|
||||
class VideoData(data.Dataset):
|
||||
""" Class to create dataloaders for video datasets
|
||||
|
||||
Args:
|
||||
data_path: Path to the folder with video frames or videos.
|
||||
image_folder: If True, the data is stored as images in folders.
|
||||
resolution: Resolution of the returned videos.
|
||||
sequence_length: Length of extracted video sequences.
|
||||
sample_every_n_frames: Sample every n frames from the video.
|
||||
batch_size: Batch size.
|
||||
num_workers: Number of workers for the dataloader.
|
||||
shuffle: If True, shuffle the data.
|
||||
"""
|
||||
|
||||
def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
|
||||
sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
|
||||
super().__init__()
|
||||
self.data_path = data_path
|
||||
self.image_folder = image_folder
|
||||
self.resolution = resolution
|
||||
self.sequence_length = sequence_length
|
||||
self.sample_every_n_frames = sample_every_n_frames
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
self.shuffle = shuffle
|
||||
|
||||
def _dataset(self):
|
||||
'''
|
||||
Initializes and return the dataset.
|
||||
'''
|
||||
if self.image_folder:
|
||||
Dataset = FrameDataset
|
||||
dataset = Dataset(self.data_path, self.sequence_length,
|
||||
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
|
||||
else:
|
||||
Dataset = VideoDataset
|
||||
dataset = Dataset(self.data_path, self.sequence_length,
|
||||
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
|
||||
return dataset
|
||||
|
||||
def _dataloader(self):
|
||||
'''
|
||||
Initializes and returns the dataloader.
|
||||
'''
|
||||
dataset = self._dataset()
|
||||
if dist.is_initialized():
|
||||
sampler = data.distributed.DistributedSampler(
|
||||
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
|
||||
)
|
||||
else:
|
||||
sampler = None
|
||||
dataloader = data.DataLoader(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
sampler=sampler,
|
||||
shuffle=sampler is None and self.shuffle is True
|
||||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
class VideoDataset(data.Dataset):
|
||||
"""
|
||||
Generic dataset for videos files stored in folders.
|
||||
Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
|
||||
The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
|
||||
Returns BCTHW videos in the range [0, 1].
|
||||
|
||||
Args:
|
||||
data_folder: Path to the folder with corresponding videos stored.
|
||||
sequence_length: Length of extracted video sequences.
|
||||
resolution: Resolution of the returned videos.
|
||||
sample_every_n_frames: Sample every n frames from the video.
|
||||
"""
|
||||
|
||||
def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
|
||||
super().__init__()
|
||||
self.sequence_length = sequence_length
|
||||
self.resolution = resolution
|
||||
self.sample_every_n_frames = sample_every_n_frames
|
||||
|
||||
folder = data_folder
|
||||
files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
|
||||
for ext in VID_EXTENSIONS], [])
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
|
||||
if not osp.exists(cache_file):
|
||||
clips = VideoClips(files, sequence_length, num_workers=4)
|
||||
try:
|
||||
pickle.dump(clips.metadata, open(cache_file, 'wb'))
|
||||
except:
|
||||
print(f"Failed to save metadata to {cache_file}")
|
||||
else:
|
||||
metadata = pickle.load(open(cache_file, 'rb'))
|
||||
clips = VideoClips(files, sequence_length,
|
||||
_precomputed_metadata=metadata)
|
||||
|
||||
self._clips = clips
|
||||
# instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
|
||||
self._clips.get_clip_location = self.get_random_clip_from_video
|
||||
|
||||
def get_random_clip_from_video(self, idx: int) -> tuple:
|
||||
'''
|
||||
Sample a random clip starting index from the video.
|
||||
|
||||
Args:
|
||||
idx: Index of the video.
|
||||
'''
|
||||
# Note that some videos may not contain enough frames, we skip those videos here.
|
||||
while self._clips.clips[idx].shape[0] <= 0:
|
||||
idx += 1
|
||||
n_clip = self._clips.clips[idx].shape[0]
|
||||
clip_id = random.randint(0, n_clip - 1)
|
||||
return idx, clip_id
|
||||
|
||||
def __len__(self):
|
||||
return self._clips.num_videos()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
resolution = self.resolution
|
||||
while True:
|
||||
try:
|
||||
video, _, _, idx = self._clips.get_clip(idx)
|
||||
except Exception as e:
|
||||
print(idx, e)
|
||||
idx = (idx + 1) % self._clips.num_clips()
|
||||
continue
|
||||
break
|
||||
|
||||
return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
|
||||
|
||||
|
||||
class FrameDataset(data.Dataset):
|
||||
"""
|
||||
Generic dataset for videos stored as images. The loading will iterates over all the folders and subfolders
|
||||
in the provided directory. Each leaf folder is assumed to contain frames from a single video.
|
||||
|
||||
Args:
|
||||
data_folder: path to the folder with video frames. The folder
|
||||
should contain folders with frames from each video.
|
||||
sequence_length: length of extracted video sequences
|
||||
resolution: resolution of the returned videos
|
||||
sample_every_n_frames: sample every n frames from the video
|
||||
"""
|
||||
|
||||
def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
|
||||
self.resolution = resolution
|
||||
self.sequence_length = sequence_length
|
||||
self.sample_every_n_frames = sample_every_n_frames
|
||||
self.data_all = self.load_video_frames(data_folder)
|
||||
self.video_num = len(self.data_all)
|
||||
|
||||
def __getitem__(self, index):
|
||||
batch_data = self.getTensor(index)
|
||||
return_list = {'video': batch_data}
|
||||
|
||||
return return_list
|
||||
|
||||
def load_video_frames(self, dataroot: str) -> list:
|
||||
'''
|
||||
Loads all the video frames under the dataroot and returns a list of all the video frames.
|
||||
|
||||
Args:
|
||||
dataroot: The root directory containing the video frames.
|
||||
|
||||
Returns:
|
||||
A list of all the video frames.
|
||||
|
||||
'''
|
||||
data_all = []
|
||||
frame_list = os.walk(dataroot)
|
||||
for _, meta in enumerate(frame_list):
|
||||
root = meta[0]
|
||||
try:
|
||||
frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
|
||||
except:
|
||||
print(meta[0], meta[2])
|
||||
if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
|
||||
continue
|
||||
frames = [
|
||||
os.path.join(root, item) for item in frames
|
||||
if is_image_file(item)
|
||||
]
|
||||
if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
|
||||
data_all.append(frames)
|
||||
|
||||
return data_all
|
||||
|
||||
def getTensor(self, index: int) -> torch.Tensor:
|
||||
'''
|
||||
Returns a tensor of the video frames at the given index.
|
||||
|
||||
Args:
|
||||
index: The index of the video frames to return.
|
||||
|
||||
Returns:
|
||||
A BCTHW tensor in the range `[0, 1]` of the video frames at the given index.
|
||||
|
||||
'''
|
||||
video = self.data_all[index]
|
||||
video_len = len(video)
|
||||
|
||||
# load the entire video when sequence_length = -1, whiel the sample_every_n_frames has to be 1
|
||||
if self.sequence_length == -1:
|
||||
assert self.sample_every_n_frames == 1
|
||||
start_idx = 0
|
||||
end_idx = video_len
|
||||
else:
|
||||
n_frames_interval = self.sequence_length * self.sample_every_n_frames
|
||||
start_idx = random.randint(0, video_len - n_frames_interval)
|
||||
end_idx = start_idx + n_frames_interval
|
||||
img = Image.open(video[0])
|
||||
h, w = img.height, img.width
|
||||
|
||||
if h > w:
|
||||
half = (h - w) // 2
|
||||
cropsize = (0, half, w, half + w) # left, upper, right, lower
|
||||
elif w > h:
|
||||
half = (w - h) // 2
|
||||
cropsize = (half, 0, half + h, h)
|
||||
|
||||
images = []
|
||||
for i in range(start_idx, end_idx,
|
||||
self.sample_every_n_frames):
|
||||
path = video[i]
|
||||
img = Image.open(path)
|
||||
|
||||
if h != w:
|
||||
img = img.crop(cropsize)
|
||||
|
||||
img = img.resize(
|
||||
(self.resolution, self.resolution),
|
||||
Image.ANTIALIAS)
|
||||
img = np.asarray(img, dtype=np.float32)
|
||||
img /= 255.
|
||||
img_tensor = preprocess_image(img).unsqueeze(0)
|
||||
images.append(img_tensor)
|
||||
|
||||
video_clip = torch.cat(images).permute(3, 0, 1, 2)
|
||||
return video_clip
|
||||
|
||||
def __len__(self):
|
||||
return self.video_num
|
||||
161
models/LatentSync/latentsync/trepa/utils/metric_utils.py
Normal file
161
models/LatentSync/latentsync/trepa/utils/metric_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
def seed_everything(seed):
|
||||
random.seed(seed)
|
||||
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
class FeatureStats:
|
||||
'''
|
||||
Class to store statistics of features, including all features and mean/covariance.
|
||||
|
||||
Args:
|
||||
capture_all: Whether to store all the features.
|
||||
capture_mean_cov: Whether to store mean and covariance.
|
||||
max_items: Maximum number of items to store.
|
||||
'''
|
||||
def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
|
||||
'''
|
||||
'''
|
||||
self.capture_all = capture_all
|
||||
self.capture_mean_cov = capture_mean_cov
|
||||
self.max_items = max_items
|
||||
self.num_items = 0
|
||||
self.num_features = None
|
||||
self.all_features = None
|
||||
self.raw_mean = None
|
||||
self.raw_cov = None
|
||||
|
||||
def set_num_features(self, num_features: int):
|
||||
'''
|
||||
Set the number of features diminsions.
|
||||
|
||||
Args:
|
||||
num_features: Number of features diminsions.
|
||||
'''
|
||||
if self.num_features is not None:
|
||||
assert num_features == self.num_features
|
||||
else:
|
||||
self.num_features = num_features
|
||||
self.all_features = []
|
||||
self.raw_mean = np.zeros([num_features], dtype=np.float64)
|
||||
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
|
||||
|
||||
def is_full(self) -> bool:
|
||||
'''
|
||||
Check if the maximum number of samples is reached.
|
||||
|
||||
Returns:
|
||||
True if the storage is full, False otherwise.
|
||||
'''
|
||||
return (self.max_items is not None) and (self.num_items >= self.max_items)
|
||||
|
||||
def append(self, x: np.ndarray):
|
||||
'''
|
||||
Add the newly computed features to the list. Update the mean and covariance.
|
||||
|
||||
Args:
|
||||
x: New features to record.
|
||||
'''
|
||||
x = np.asarray(x, dtype=np.float32)
|
||||
assert x.ndim == 2
|
||||
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
|
||||
if self.num_items >= self.max_items:
|
||||
return
|
||||
x = x[:self.max_items - self.num_items]
|
||||
|
||||
self.set_num_features(x.shape[1])
|
||||
self.num_items += x.shape[0]
|
||||
if self.capture_all:
|
||||
self.all_features.append(x)
|
||||
if self.capture_mean_cov:
|
||||
x64 = x.astype(np.float64)
|
||||
self.raw_mean += x64.sum(axis=0)
|
||||
self.raw_cov += x64.T @ x64
|
||||
|
||||
def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
|
||||
'''
|
||||
Add the newly computed PyTorch features to the list. Update the mean and covariance.
|
||||
|
||||
Args:
|
||||
x: New features to record.
|
||||
rank: Rank of the current GPU.
|
||||
num_gpus: Total number of GPUs.
|
||||
'''
|
||||
assert isinstance(x, torch.Tensor) and x.ndim == 2
|
||||
assert 0 <= rank < num_gpus
|
||||
if num_gpus > 1:
|
||||
ys = []
|
||||
for src in range(num_gpus):
|
||||
y = x.clone()
|
||||
torch.distributed.broadcast(y, src=src)
|
||||
ys.append(y)
|
||||
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
|
||||
self.append(x.cpu().numpy())
|
||||
|
||||
def get_all(self) -> np.ndarray:
|
||||
'''
|
||||
Get all the stored features as NumPy Array.
|
||||
|
||||
Returns:
|
||||
Concatenation of the stored features.
|
||||
'''
|
||||
assert self.capture_all
|
||||
return np.concatenate(self.all_features, axis=0)
|
||||
|
||||
def get_all_torch(self) -> torch.Tensor:
|
||||
'''
|
||||
Get all the stored features as PyTorch Tensor.
|
||||
|
||||
Returns:
|
||||
Concatenation of the stored features.
|
||||
'''
|
||||
return torch.from_numpy(self.get_all())
|
||||
|
||||
def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
'''
|
||||
Get the mean and covariance of the stored features.
|
||||
|
||||
Returns:
|
||||
Mean and covariance of the stored features.
|
||||
'''
|
||||
assert self.capture_mean_cov
|
||||
mean = self.raw_mean / self.num_items
|
||||
cov = self.raw_cov / self.num_items
|
||||
cov = cov - np.outer(mean, mean)
|
||||
return mean, cov
|
||||
|
||||
def save(self, pkl_file: str):
|
||||
'''
|
||||
Save the features and statistics to a pickle file.
|
||||
|
||||
Args:
|
||||
pkl_file: Path to the pickle file.
|
||||
'''
|
||||
with open(pkl_file, 'wb') as f:
|
||||
pickle.dump(self.__dict__, f)
|
||||
|
||||
@staticmethod
|
||||
def load(pkl_file: str) -> 'FeatureStats':
|
||||
'''
|
||||
Load the features and statistics from a pickle file.
|
||||
|
||||
Args:
|
||||
pkl_file: Path to the pickle file.
|
||||
'''
|
||||
with open(pkl_file, 'rb') as f:
|
||||
s = pickle.load(f)
|
||||
obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
|
||||
obj.__dict__.update(s)
|
||||
print('Loaded %d features from %s' % (obj.num_items, pkl_file))
|
||||
return obj
|
||||
145
models/LatentSync/latentsync/utils/affine_transform.py
Normal file
145
models/LatentSync/latentsync/utils/affine_transform.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import kornia
|
||||
|
||||
|
||||
class AlignRestore(object):
|
||||
def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16):
|
||||
if align_points == 3:
|
||||
self.upscale_factor = 1
|
||||
ratio = resolution / 256 * 2.8
|
||||
self.crop_ratio = (ratio, ratio)
|
||||
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
|
||||
self.face_template = self.face_template * ratio
|
||||
self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
|
||||
self.p_bias = None
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype)
|
||||
self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype)
|
||||
|
||||
def align_warp_face(self, img, landmarks3, smooth=True):
|
||||
affine_matrix, self.p_bias = self.transformation_from_points(
|
||||
landmarks3, self.face_template, smooth, self.p_bias
|
||||
)
|
||||
|
||||
img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0)
|
||||
affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
|
||||
|
||||
cropped_face = kornia.geometry.transform.warp_affine(
|
||||
img,
|
||||
affine_matrix,
|
||||
(self.face_size[1], self.face_size[0]),
|
||||
mode="bilinear",
|
||||
padding_mode="fill",
|
||||
fill_value=self.fill_value,
|
||||
)
|
||||
cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8)
|
||||
return cropped_face, affine_matrix
|
||||
|
||||
def restore_img(self, input_img, face, affine_matrix):
|
||||
h, w, _ = input_img.shape
|
||||
|
||||
if isinstance(affine_matrix, np.ndarray):
|
||||
affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
|
||||
|
||||
inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix)
|
||||
face = face.to(dtype=self.dtype).unsqueeze(0)
|
||||
|
||||
inv_face = kornia.geometry.transform.warp_affine(
|
||||
face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value
|
||||
).squeeze(0)
|
||||
inv_face = (inv_face / 2 + 0.5).clamp(0, 1) * 255
|
||||
|
||||
input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w")
|
||||
inv_mask = kornia.geometry.transform.warp_affine(
|
||||
self.mask, inv_affine_matrix, (h, w), padding_mode="zeros"
|
||||
) # (1, 1, h_up, w_up)
|
||||
|
||||
inv_mask_erosion = kornia.morphology.erosion(
|
||||
inv_mask,
|
||||
torch.ones(
|
||||
(int(2 * self.upscale_factor), int(2 * self.upscale_factor)), device=self.device, dtype=self.dtype
|
||||
),
|
||||
)
|
||||
|
||||
inv_mask_erosion_t = inv_mask_erosion.squeeze(0).expand_as(inv_face)
|
||||
pasted_face = inv_mask_erosion_t * inv_face
|
||||
total_face_area = torch.sum(inv_mask_erosion.float())
|
||||
w_edge = int(total_face_area**0.5) // 20
|
||||
erosion_radius = w_edge * 2
|
||||
|
||||
# This step will consume a large amount of GPU memory.
|
||||
# inv_mask_center = kornia.morphology.erosion(
|
||||
# inv_mask_erosion, torch.ones((erosion_radius, erosion_radius), device=self.device, dtype=self.dtype)
|
||||
# )
|
||||
|
||||
# Run on CPU to avoid consuming a large amount of GPU memory.
|
||||
inv_mask_erosion = inv_mask_erosion.squeeze().cpu().numpy().astype(np.float32)
|
||||
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
|
||||
inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...]
|
||||
|
||||
blur_size = w_edge * 2 + 1
|
||||
sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8
|
||||
inv_soft_mask = kornia.filters.gaussian_blur2d(
|
||||
inv_mask_center, (blur_size, blur_size), (sigma, sigma)
|
||||
).squeeze(0)
|
||||
inv_soft_mask_3d = inv_soft_mask.expand_as(inv_face)
|
||||
img_back = inv_soft_mask_3d * pasted_face + (1 - inv_soft_mask_3d) * input_img
|
||||
|
||||
img_back = rearrange(img_back, "c h w -> h w c").contiguous().to(dtype=torch.uint8)
|
||||
img_back = img_back.cpu().numpy()
|
||||
return img_back
|
||||
|
||||
def transformation_from_points(self, points1: torch.Tensor, points0: torch.Tensor, smooth=True, p_bias=None):
|
||||
if isinstance(points0, np.ndarray):
|
||||
points2 = torch.tensor(points0, device=self.device, dtype=torch.float32)
|
||||
else:
|
||||
points2 = points0.clone()
|
||||
|
||||
if isinstance(points1, np.ndarray):
|
||||
points1_tensor = torch.tensor(points1, device=self.device, dtype=torch.float32)
|
||||
else:
|
||||
points1_tensor = points1.clone()
|
||||
|
||||
c1 = torch.mean(points1_tensor, dim=0)
|
||||
c2 = torch.mean(points2, dim=0)
|
||||
|
||||
points1_centered = points1_tensor - c1
|
||||
points2_centered = points2 - c2
|
||||
|
||||
s1 = torch.std(points1_centered)
|
||||
s2 = torch.std(points2_centered)
|
||||
|
||||
points1_normalized = points1_centered / s1
|
||||
points2_normalized = points2_centered / s2
|
||||
|
||||
covariance = torch.matmul(points1_normalized.T, points2_normalized)
|
||||
U, S, V = torch.svd(covariance.float())
|
||||
|
||||
R = torch.matmul(V, U.T)
|
||||
|
||||
det = torch.det(R.float())
|
||||
if det < 0:
|
||||
V[:, -1] = -V[:, -1]
|
||||
R = torch.matmul(V, U.T)
|
||||
|
||||
sR = (s2 / s1) * R
|
||||
T = c2.reshape(2, 1) - (s2 / s1) * torch.matmul(R, c1.reshape(2, 1))
|
||||
|
||||
M = torch.cat((sR, T), dim=1)
|
||||
|
||||
if smooth:
|
||||
bias = points2_normalized[2] - points1_normalized[2]
|
||||
if p_bias is None:
|
||||
p_bias = bias
|
||||
else:
|
||||
bias = p_bias * 0.2 + bias * 0.8
|
||||
p_bias = bias
|
||||
M[:, 2] = M[:, 2] + bias
|
||||
|
||||
return M.cpu().numpy(), p_bias
|
||||
194
models/LatentSync/latentsync/utils/audio.py
Normal file
194
models/LatentSync/latentsync/utils/audio.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
|
||||
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
|
||||
audio_config_path = "configs/audio.yaml"
|
||||
|
||||
config = OmegaConf.load(audio_config_path)
|
||||
|
||||
|
||||
def load_wav(path, sr):
|
||||
return librosa.core.load(path, sr=sr)[0]
|
||||
|
||||
|
||||
def save_wav(wav, path, sr):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
# proposed by @dsmiller
|
||||
wavfile.write(path, sr, wav.astype(np.int16))
|
||||
|
||||
|
||||
def save_wavenet_wav(wav, path, sr):
|
||||
librosa.output.write_wav(path, wav, sr=sr)
|
||||
|
||||
|
||||
def preemphasis(wav, k, preemphasize=True):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def get_hop_size():
|
||||
hop_size = config.audio.hop_size
|
||||
if hop_size is None:
|
||||
assert config.audio.frame_shift_ms is not None
|
||||
hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
|
||||
return hop_size
|
||||
|
||||
|
||||
def linearspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
||||
S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
|
||||
|
||||
if config.audio.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
|
||||
def melspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
|
||||
S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
|
||||
|
||||
if config.audio.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
|
||||
def _lws_processor():
|
||||
import lws
|
||||
|
||||
return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
|
||||
|
||||
|
||||
def _stft(y):
|
||||
if config.audio.use_lws:
|
||||
return _lws_processor(config.audio).stft(y).T
|
||||
else:
|
||||
return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
|
||||
|
||||
|
||||
##########################################################
|
||||
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
||||
def num_frames(length, fsize, fshift):
|
||||
"""Compute number of time frames of spectrogram"""
|
||||
pad = fsize - fshift
|
||||
if length % fshift == 0:
|
||||
M = (length + pad * 2 - fsize) // fshift + 1
|
||||
else:
|
||||
M = (length + pad * 2 - fsize) // fshift + 2
|
||||
return M
|
||||
|
||||
|
||||
def pad_lr(x, fsize, fshift):
|
||||
"""Compute left and right padding"""
|
||||
M = num_frames(len(x), fsize, fshift)
|
||||
pad = fsize - fshift
|
||||
T = len(x) + 2 * pad
|
||||
r = (M - 1) * fshift + fsize - T
|
||||
return pad, pad + r
|
||||
|
||||
|
||||
##########################################################
|
||||
# Librosa correct padding
|
||||
def librosa_pad_lr(x, fsize, fshift):
|
||||
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
||||
|
||||
|
||||
# Conversions
|
||||
_mel_basis = None
|
||||
|
||||
|
||||
def _linear_to_mel(spectogram):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis()
|
||||
return np.dot(_mel_basis, spectogram)
|
||||
|
||||
|
||||
def _build_mel_basis():
|
||||
assert config.audio.fmax <= config.audio.sample_rate // 2
|
||||
return librosa.filters.mel(
|
||||
sr=config.audio.sample_rate,
|
||||
n_fft=config.audio.n_fft,
|
||||
n_mels=config.audio.num_mels,
|
||||
fmin=config.audio.fmin,
|
||||
fmax=config.audio.fmax,
|
||||
)
|
||||
|
||||
|
||||
def _amp_to_db(x):
|
||||
min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
|
||||
return 20 * np.log10(np.maximum(min_level, x))
|
||||
|
||||
|
||||
def _db_to_amp(x):
|
||||
return np.power(10.0, (x) * 0.05)
|
||||
|
||||
|
||||
def _normalize(S):
|
||||
if config.audio.allow_clipping_in_normalization:
|
||||
if config.audio.symmetric_mels:
|
||||
return np.clip(
|
||||
(2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
||||
- config.audio.max_abs_value,
|
||||
-config.audio.max_abs_value,
|
||||
config.audio.max_abs_value,
|
||||
)
|
||||
else:
|
||||
return np.clip(
|
||||
config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
|
||||
0,
|
||||
config.audio.max_abs_value,
|
||||
)
|
||||
|
||||
assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
|
||||
if config.audio.symmetric_mels:
|
||||
return (2 * config.audio.max_abs_value) * (
|
||||
(S - config.audio.min_level_db) / (-config.audio.min_level_db)
|
||||
) - config.audio.max_abs_value
|
||||
else:
|
||||
return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
|
||||
|
||||
|
||||
def _denormalize(D):
|
||||
if config.audio.allow_clipping_in_normalization:
|
||||
if config.audio.symmetric_mels:
|
||||
return (
|
||||
(np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
|
||||
* -config.audio.min_level_db
|
||||
/ (2 * config.audio.max_abs_value)
|
||||
) + config.audio.min_level_db
|
||||
else:
|
||||
return (
|
||||
np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
|
||||
) + config.audio.min_level_db
|
||||
|
||||
if config.audio.symmetric_mels:
|
||||
return (
|
||||
(D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
|
||||
) + config.audio.min_level_db
|
||||
else:
|
||||
return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
|
||||
|
||||
|
||||
def get_melspec_overlap(audio_samples, melspec_length=52):
|
||||
mel_spec_overlap = melspectrogram(audio_samples.numpy())
|
||||
mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
|
||||
i = 0
|
||||
mel_spec_overlap_list = []
|
||||
while i + melspec_length < mel_spec_overlap.shape[1] - 3:
|
||||
mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
|
||||
i += 3
|
||||
mel_spec_overlap = torch.stack(mel_spec_overlap_list)
|
||||
return mel_spec_overlap
|
||||
157
models/LatentSync/latentsync/utils/av_reader.py
Normal file
157
models/LatentSync/latentsync/utils/av_reader.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# We modified the original AVReader class of decord to solve the problem of memory leak.
|
||||
# For more details, refer to: https://github.com/dmlc/decord/issues/208
|
||||
|
||||
import numpy as np
|
||||
from decord.video_reader import VideoReader
|
||||
from decord.audio_reader import AudioReader
|
||||
|
||||
from decord.ndarray import cpu
|
||||
from decord import ndarray as _nd
|
||||
from decord.bridge import bridge_out
|
||||
|
||||
|
||||
class AVReader(object):
|
||||
"""Individual audio video reader with convenient indexing function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri: str
|
||||
Path of file.
|
||||
ctx: decord.Context
|
||||
The context to decode the file, can be decord.cpu() or decord.gpu().
|
||||
sample_rate: int, default is -1
|
||||
Desired output sample rate of the audio, unchanged if `-1` is specified.
|
||||
mono: bool, default is True
|
||||
Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
|
||||
width : int, default is -1
|
||||
Desired output width of the video, unchanged if `-1` is specified.
|
||||
height : int, default is -1
|
||||
Desired output height of the video, unchanged if `-1` is specified.
|
||||
num_threads : int, default is 0
|
||||
Number of decoding thread, auto if `0` is specified.
|
||||
fault_tol : int, default is -1
|
||||
The threshold of corupted and recovered frames. This is to prevent silent fault
|
||||
tolerance when for example 50% frames of a video cannot be decoded and duplicate
|
||||
frames are returned. You may find the fault tolerant feature sweet in many cases,
|
||||
but not for training models. Say `N = # recovered frames`
|
||||
If `fault_tol` < 0, nothing will happen.
|
||||
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
|
||||
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
|
||||
):
|
||||
self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
|
||||
self.__audio_reader.add_padding()
|
||||
if hasattr(uri, "read"):
|
||||
uri.seek(0)
|
||||
self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
|
||||
self.__video_reader.seek(0)
|
||||
|
||||
def __len__(self):
|
||||
"""Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
|
||||
we always follow what FFMPEG reports.
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The number of frames in the video file.
|
||||
"""
|
||||
return len(self.__video_reader)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Get audio samples and video frame at `idx`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
idx : int or slice
|
||||
The frame index, can be negative which means it will index backwards,
|
||||
or slice of frame indices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(ndarray/list of ndarray, ndarray)
|
||||
First element is samples of shape CxS or a list of length N containing samples of shape CxS,
|
||||
where N is the number of frames, C is the number of channels,
|
||||
S is the number of samples of the corresponding frame.
|
||||
|
||||
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
||||
where N is the length of the slice.
|
||||
"""
|
||||
assert self.__video_reader is not None and self.__audio_reader is not None
|
||||
if isinstance(idx, slice):
|
||||
return self.get_batch(range(*idx.indices(len(self.__video_reader))))
|
||||
if idx < 0:
|
||||
idx += len(self.__video_reader)
|
||||
if idx >= len(self.__video_reader) or idx < 0:
|
||||
raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
|
||||
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
||||
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
||||
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
||||
results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
|
||||
self.__video_reader.seek(0)
|
||||
return results
|
||||
|
||||
def get_batch(self, indices):
|
||||
"""Get entire batch of audio samples and video frames.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
indices : list of integers
|
||||
A list of frame indices. If negative indices detected, the indices will be indexed from backward
|
||||
Returns
|
||||
-------
|
||||
(list of ndarray, ndarray)
|
||||
First element is a list of length N containing samples of shape CxS,
|
||||
where N is the number of frames, C is the number of channels,
|
||||
S is the number of samples of the corresponding frame.
|
||||
|
||||
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
|
||||
where N is the length of the slice.
|
||||
|
||||
"""
|
||||
assert self.__video_reader is not None and self.__audio_reader is not None
|
||||
indices = self._validate_indices(indices)
|
||||
audio_arr = []
|
||||
prev_video_idx = None
|
||||
prev_audio_end_idx = None
|
||||
for idx in list(indices):
|
||||
frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
|
||||
# timestamp and sample conversion could have some error that could cause non-continuous audio
|
||||
# we detect if retrieving continuous frame and make the audio continuous
|
||||
if prev_video_idx and idx == prev_video_idx + 1:
|
||||
audio_start_idx = prev_audio_end_idx
|
||||
else:
|
||||
audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
|
||||
audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
|
||||
audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
|
||||
prev_video_idx = idx
|
||||
prev_audio_end_idx = audio_end_idx
|
||||
results = (audio_arr, self.__video_reader.get_batch(indices))
|
||||
self.__video_reader.seek(0)
|
||||
return results
|
||||
|
||||
def _get_slice(self, sl):
|
||||
audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
|
||||
for idx in list(sl):
|
||||
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
|
||||
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
|
||||
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
|
||||
audio_arr = np.concatenate(
|
||||
(audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
|
||||
)
|
||||
results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
|
||||
self.__video_reader.seek(0)
|
||||
return results
|
||||
|
||||
def _validate_indices(self, indices):
|
||||
"""Validate int64 integers and convert negative integers to positive by backward search"""
|
||||
assert self.__video_reader is not None and self.__audio_reader is not None
|
||||
indices = np.array(indices, dtype=np.int64)
|
||||
# process negative indices
|
||||
indices[indices < 0] += len(self.__video_reader)
|
||||
if not (indices >= 0).all():
|
||||
raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
|
||||
if not (indices < len(self.__video_reader)).all():
|
||||
raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
|
||||
return indices
|
||||
115
models/LatentSync/latentsync/utils/face_detector.py
Normal file
115
models/LatentSync/latentsync/utils/face_detector.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from insightface.app import FaceAnalysis
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
INSIGHTFACE_DETECT_SIZE = 512
|
||||
|
||||
|
||||
class FaceDetector:
|
||||
def __init__(self, device="cuda"):
|
||||
self.app = FaceAnalysis(
|
||||
allowed_modules=["detection", "landmark_2d_106"],
|
||||
root="checkpoints/auxiliary",
|
||||
providers=["CUDAExecutionProvider"],
|
||||
)
|
||||
self.app.prepare(ctx_id=cuda_to_int(device), det_size=(INSIGHTFACE_DETECT_SIZE, INSIGHTFACE_DETECT_SIZE))
|
||||
|
||||
def __call__(self, frame, threshold=0.5):
|
||||
f_h, f_w, _ = frame.shape
|
||||
|
||||
faces = self.app.get(frame)
|
||||
|
||||
get_face_store = None
|
||||
max_size = 0
|
||||
|
||||
if len(faces) == 0:
|
||||
return None, None
|
||||
else:
|
||||
for face in faces:
|
||||
bbox = face.bbox.astype(np.int_).tolist()
|
||||
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
|
||||
if w < 50 or h < 80:
|
||||
continue
|
||||
if w / h > 1.5 or w / h < 0.2:
|
||||
continue
|
||||
if face.det_score < threshold:
|
||||
continue
|
||||
size_now = w * h
|
||||
|
||||
if size_now > max_size:
|
||||
max_size = size_now
|
||||
get_face_store = face
|
||||
|
||||
if get_face_store is None:
|
||||
return None, None
|
||||
else:
|
||||
face = get_face_store
|
||||
lmk = np.round(face.landmark_2d_106).astype(np.int_)
|
||||
|
||||
halk_face_coord = np.mean([lmk[74], lmk[73]], axis=0) # lmk[73]
|
||||
|
||||
sub_lmk = lmk[LMK_ADAPT_ORIGIN_ORDER]
|
||||
halk_face_dist = np.max(sub_lmk[:, 1]) - halk_face_coord[1]
|
||||
upper_bond = halk_face_coord[1] - halk_face_dist # *0.94
|
||||
|
||||
x1, y1, x2, y2 = (np.min(sub_lmk[:, 0]), int(upper_bond), np.max(sub_lmk[:, 0]), np.max(sub_lmk[:, 1]))
|
||||
|
||||
if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
|
||||
x1, y1, x2, y2 = face.bbox.astype(np.int_).tolist()
|
||||
|
||||
y2 += int((x2 - x1) * 0.1)
|
||||
x1 -= int((x2 - x1) * 0.05)
|
||||
x2 += int((x2 - x1) * 0.05)
|
||||
|
||||
x1 = max(0, x1)
|
||||
y1 = max(0, y1)
|
||||
x2 = min(f_w, x2)
|
||||
y2 = min(f_h, y2)
|
||||
|
||||
return (x1, y1, x2, y2), lmk
|
||||
|
||||
|
||||
def cuda_to_int(cuda_str: str) -> int:
|
||||
"""
|
||||
Convert the string with format "cuda:X" to integer X.
|
||||
"""
|
||||
if cuda_str == "cuda":
|
||||
return 0
|
||||
device = torch.device(cuda_str)
|
||||
if device.type != "cuda":
|
||||
raise ValueError(f"Device type must be 'cuda', got: {device.type}")
|
||||
return device.index
|
||||
|
||||
|
||||
LMK_ADAPT_ORIGIN_ORDER = [
|
||||
1,
|
||||
10,
|
||||
12,
|
||||
14,
|
||||
16,
|
||||
3,
|
||||
5,
|
||||
7,
|
||||
0,
|
||||
23,
|
||||
21,
|
||||
19,
|
||||
32,
|
||||
30,
|
||||
28,
|
||||
26,
|
||||
17,
|
||||
43,
|
||||
48,
|
||||
49,
|
||||
51,
|
||||
50,
|
||||
102,
|
||||
103,
|
||||
104,
|
||||
105,
|
||||
101,
|
||||
73,
|
||||
74,
|
||||
86,
|
||||
]
|
||||
122
models/LatentSync/latentsync/utils/image_processor.py
Normal file
122
models/LatentSync/latentsync/utils/image_processor.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from latentsync.utils.util import read_video, write_video
|
||||
from torchvision import transforms
|
||||
import cv2
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Union
|
||||
from .affine_transform import AlignRestore
|
||||
from .face_detector import FaceDetector
|
||||
|
||||
|
||||
def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor:
|
||||
mask_image = cv2.imread(mask_image_path)
|
||||
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
|
||||
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
|
||||
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
|
||||
return mask_image
|
||||
|
||||
|
||||
class ImageProcessor:
|
||||
def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
|
||||
self.resolution = resolution
|
||||
self.resize = transforms.Resize(
|
||||
(resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
|
||||
)
|
||||
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
|
||||
|
||||
self.restorer = AlignRestore(resolution=resolution, device=device)
|
||||
|
||||
if mask_image is None:
|
||||
self.mask_image = load_fixed_mask(resolution)
|
||||
else:
|
||||
self.mask_image = mask_image
|
||||
|
||||
if device == "cpu":
|
||||
self.face_detector = None
|
||||
else:
|
||||
self.face_detector = FaceDetector(device=device)
|
||||
|
||||
def affine_transform(self, image: torch.Tensor) -> np.ndarray:
|
||||
if self.face_detector is None:
|
||||
raise NotImplementedError("Using the CPU for face detection is not supported")
|
||||
bbox, landmark_2d_106 = self.face_detector(image)
|
||||
if bbox is None:
|
||||
raise RuntimeError("Face not detected")
|
||||
|
||||
pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0) # left eyebrow center
|
||||
pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0) # right eyebrow center
|
||||
pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0) # nose center
|
||||
|
||||
landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose])
|
||||
|
||||
face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
|
||||
box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
|
||||
face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4)
|
||||
face = rearrange(torch.from_numpy(face), "h w c -> c h w")
|
||||
return face, box, affine_matrix
|
||||
|
||||
def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
|
||||
if affine_transform:
|
||||
image, _, _ = self.affine_transform(image)
|
||||
else:
|
||||
image = self.resize(image)
|
||||
pixel_values = self.normalize(image / 255.0)
|
||||
masked_pixel_values = pixel_values * self.mask_image
|
||||
return pixel_values, masked_pixel_values, self.mask_image[0:1]
|
||||
|
||||
def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
|
||||
if isinstance(images, np.ndarray):
|
||||
images = torch.from_numpy(images)
|
||||
if images.shape[3] == 3:
|
||||
images = rearrange(images, "f h w c -> f c h w")
|
||||
|
||||
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
|
||||
|
||||
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
|
||||
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
|
||||
|
||||
def process_images(self, images: Union[torch.Tensor, np.ndarray]):
|
||||
if isinstance(images, np.ndarray):
|
||||
images = torch.from_numpy(images)
|
||||
if images.shape[3] == 3:
|
||||
images = rearrange(images, "f h w c -> f c h w")
|
||||
images = self.resize(images)
|
||||
pixel_values = self.normalize(images / 255.0)
|
||||
return pixel_values
|
||||
|
||||
|
||||
class VideoProcessor:
|
||||
def __init__(self, resolution: int = 512, device: str = "cpu"):
|
||||
self.image_processor = ImageProcessor(resolution, device)
|
||||
|
||||
def affine_transform_video(self, video_path):
|
||||
video_frames = read_video(video_path, change_fps=False)
|
||||
results = []
|
||||
for frame in video_frames:
|
||||
frame, _, _ = self.image_processor.affine_transform(frame)
|
||||
results.append(frame)
|
||||
results = torch.stack(results)
|
||||
|
||||
results = rearrange(results, "f c h w -> f h w c").numpy()
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
video_processor = VideoProcessor(256, "cuda")
|
||||
video_frames = video_processor.affine_transform_video("assets/demo2_video.mp4")
|
||||
write_video("output.mp4", video_frames, fps=25)
|
||||
BIN
models/LatentSync/latentsync/utils/mask.png
Normal file
BIN
models/LatentSync/latentsync/utils/mask.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 KiB |
BIN
models/LatentSync/latentsync/utils/mask2.png
Normal file
BIN
models/LatentSync/latentsync/utils/mask2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.2 KiB |
BIN
models/LatentSync/latentsync/utils/mask3.png
Normal file
BIN
models/LatentSync/latentsync/utils/mask3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.1 KiB |
BIN
models/LatentSync/latentsync/utils/mask4.png
Normal file
BIN
models/LatentSync/latentsync/utils/mask4.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.2 KiB |
289
models/LatentSync/latentsync/utils/util.py
Normal file
289
models/LatentSync/latentsync/utils/util.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import json
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
import matplotlib.pyplot as plt
|
||||
import imageio
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import torch.distributed as dist
|
||||
from torchvision import transforms
|
||||
|
||||
from einops import rearrange
|
||||
import cv2
|
||||
from decord import AudioReader, VideoReader
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
|
||||
# Machine epsilon for a float32 (single precision)
|
||||
eps = np.finfo(np.float32).eps
|
||||
|
||||
|
||||
def read_json(filepath: str):
|
||||
with open(filepath) as f:
|
||||
json_dict = json.load(f)
|
||||
return json_dict
|
||||
|
||||
|
||||
def read_video(video_path: str, change_fps=True, use_decord=True):
|
||||
if change_fps:
|
||||
temp_dir = "temp"
|
||||
if os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir)
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
command = (
|
||||
f"ffmpeg -loglevel error -y -nostdin -i {video_path} -r 25 -crf 18 {os.path.join(temp_dir, 'video.mp4')}"
|
||||
)
|
||||
subprocess.run(command, shell=True)
|
||||
target_video_path = os.path.join(temp_dir, "video.mp4")
|
||||
else:
|
||||
target_video_path = video_path
|
||||
|
||||
if use_decord:
|
||||
return read_video_decord(target_video_path)
|
||||
else:
|
||||
return read_video_cv2(target_video_path)
|
||||
|
||||
|
||||
def read_video_decord(video_path: str):
|
||||
vr = VideoReader(video_path)
|
||||
video_frames = vr[:].asnumpy()
|
||||
vr.seek(0)
|
||||
return video_frames
|
||||
|
||||
|
||||
def read_video_cv2(video_path: str):
|
||||
# Open the video file
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
# Check if the video was opened successfully
|
||||
if not cap.isOpened():
|
||||
print("Error: Could not open video.")
|
||||
return np.array([])
|
||||
|
||||
frames = []
|
||||
|
||||
while True:
|
||||
# Read a frame
|
||||
ret, frame = cap.read()
|
||||
|
||||
# If frame is read correctly ret is True
|
||||
if not ret:
|
||||
break
|
||||
|
||||
# Convert BGR to RGB
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
frames.append(frame_rgb)
|
||||
|
||||
# Release the video capture object
|
||||
cap.release()
|
||||
|
||||
return np.array(frames)
|
||||
|
||||
|
||||
def read_audio(audio_path: str, audio_sample_rate: int = 16000):
|
||||
if audio_path is None:
|
||||
raise ValueError("Audio path is required.")
|
||||
ar = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
|
||||
|
||||
# To access the audio samples
|
||||
audio_samples = torch.from_numpy(ar[:].asnumpy())
|
||||
audio_samples = audio_samples.squeeze(0)
|
||||
|
||||
return audio_samples
|
||||
|
||||
|
||||
def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
|
||||
with imageio.get_writer(
|
||||
video_output_path,
|
||||
fps=fps,
|
||||
codec="libx264",
|
||||
macro_block_size=None,
|
||||
ffmpeg_params=["-crf", "13"],
|
||||
ffmpeg_log_level="error",
|
||||
) as writer:
|
||||
for video_frame in video_frames:
|
||||
writer.append_data(video_frame)
|
||||
|
||||
|
||||
def write_video_cv2(video_output_path: str, video_frames: np.ndarray, fps: int):
|
||||
height, width = video_frames[0].shape[:2]
|
||||
out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
|
||||
# out = cv2.VideoWriter(video_output_path, cv2.VideoWriter_fourcc(*"vp09"), fps, (width, height))
|
||||
for frame in video_frames:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
|
||||
out.write(frame)
|
||||
out.release()
|
||||
|
||||
|
||||
def init_dist(backend="nccl", **kwargs):
|
||||
"""Initializes distributed environment."""
|
||||
rank = int(os.environ["RANK"])
|
||||
num_gpus = torch.cuda.device_count()
|
||||
if num_gpus == 0:
|
||||
raise RuntimeError("No GPUs available for training.")
|
||||
local_rank = rank % num_gpus
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(backend=backend, **kwargs)
|
||||
|
||||
return local_rank
|
||||
|
||||
|
||||
def zero_rank_print(s):
|
||||
if dist.is_initialized() and dist.get_rank() == 0:
|
||||
print("### " + s)
|
||||
|
||||
|
||||
def zero_rank_log(logger, message: str):
|
||||
if dist.is_initialized() and dist.get_rank() == 0:
|
||||
logger.info(message)
|
||||
|
||||
|
||||
def check_video_fps(video_path: str):
|
||||
cam = cv2.VideoCapture(video_path)
|
||||
fps = cam.get(cv2.CAP_PROP_FPS)
|
||||
if fps != 25:
|
||||
raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
|
||||
|
||||
|
||||
def one_step_sampling(ddim_scheduler, pred_noise, timesteps, x_t):
|
||||
# Compute alphas, betas
|
||||
alpha_prod_t = ddim_scheduler.alphas_cumprod[timesteps].to(dtype=pred_noise.dtype)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/abs/2010.02502
|
||||
if ddim_scheduler.config.prediction_type == "epsilon":
|
||||
beta_prod_t = beta_prod_t[:, None, None, None, None]
|
||||
alpha_prod_t = alpha_prod_t[:, None, None, None, None]
|
||||
pred_original_sample = (x_t - beta_prod_t ** (0.5) * pred_noise) / alpha_prod_t ** (0.5)
|
||||
else:
|
||||
raise NotImplementedError("This prediction type is not implemented yet")
|
||||
|
||||
# Clip "predicted x_0"
|
||||
if ddim_scheduler.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
return pred_original_sample
|
||||
|
||||
|
||||
def plot_loss_chart(save_path: str, *args):
|
||||
# Creating the plot
|
||||
plt.figure()
|
||||
for loss_line in args:
|
||||
plt.plot(loss_line[1], loss_line[2], label=loss_line[0])
|
||||
plt.xlabel("Step")
|
||||
plt.ylabel("Loss")
|
||||
plt.legend()
|
||||
|
||||
# Save the figure to a file
|
||||
plt.savefig(save_path)
|
||||
|
||||
# Close the figure to free memory
|
||||
plt.close()
|
||||
|
||||
|
||||
CRED = "\033[91m"
|
||||
CEND = "\033[0m"
|
||||
|
||||
|
||||
def red_text(text: str):
|
||||
return f"{CRED}{text}{CEND}"
|
||||
|
||||
|
||||
log_loss = nn.BCELoss(reduction="none")
|
||||
|
||||
|
||||
def cosine_loss(vision_embeds, audio_embeds, y):
|
||||
sims = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
|
||||
# sims[sims!=sims] = 0 # remove nan
|
||||
# sims = sims.clamp(0, 1)
|
||||
loss = log_loss(sims.unsqueeze(1), y).squeeze()
|
||||
return loss
|
||||
|
||||
|
||||
def save_image(image, save_path):
|
||||
# input size (C, H, W)
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = (image * 255).to(torch.uint8)
|
||||
image = transforms.ToPILImage()(image)
|
||||
# Save the image copy
|
||||
image.save(save_path)
|
||||
|
||||
# Close the image file
|
||||
image.close()
|
||||
|
||||
|
||||
def gather_loss(loss, device):
|
||||
# Sum the local loss across all processes
|
||||
local_loss = loss.item()
|
||||
global_loss = torch.tensor(local_loss, dtype=torch.float32).to(device)
|
||||
dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Calculate the average loss across all processes
|
||||
global_average_loss = global_loss.item() / dist.get_world_size()
|
||||
return global_average_loss
|
||||
|
||||
|
||||
def gather_video_paths_recursively(input_dir):
|
||||
print(f"Recursively gathering video paths of {input_dir} ...")
|
||||
paths = []
|
||||
gather_video_paths(input_dir, paths)
|
||||
return paths
|
||||
|
||||
|
||||
def gather_video_paths(input_dir, paths):
|
||||
for file in sorted(os.listdir(input_dir)):
|
||||
if file.endswith(".mp4"):
|
||||
filepath = os.path.join(input_dir, file)
|
||||
paths.append(filepath)
|
||||
elif os.path.isdir(os.path.join(input_dir, file)):
|
||||
gather_video_paths(os.path.join(input_dir, file), paths)
|
||||
|
||||
|
||||
def count_video_time(video_path):
|
||||
video = cv2.VideoCapture(video_path)
|
||||
|
||||
frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
fps = video.get(cv2.CAP_PROP_FPS)
|
||||
return frame_count / fps
|
||||
|
||||
|
||||
def check_ffmpeg_installed():
|
||||
# Run the ffmpeg command with the -version argument to check if it's installed
|
||||
result = subprocess.run("ffmpeg -version", stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
||||
if not result.returncode == 0:
|
||||
raise FileNotFoundError("ffmpeg not found, please install it by:\n $ conda install -c conda-forge ffmpeg")
|
||||
|
||||
|
||||
def check_model_and_download(ckpt_path: str, huggingface_model_id: str = "ByteDance/LatentSync-1.5"):
|
||||
if not os.path.exists(ckpt_path):
|
||||
ckpt_path_obj = Path(ckpt_path)
|
||||
download_cmd = f"huggingface-cli download {huggingface_model_id} {Path(*ckpt_path_obj.parts[1:])} --local-dir {Path(ckpt_path_obj.parts[0])}"
|
||||
subprocess.run(download_cmd, shell=True)
|
||||
|
||||
|
||||
class dummy_context:
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
167
models/LatentSync/latentsync/whisper/audio2feature.py
Normal file
167
models/LatentSync/latentsync/whisper/audio2feature.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
|
||||
|
||||
from .whisper import load_model
|
||||
import numpy as np
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Audio2Feature:
|
||||
def __init__(
|
||||
self,
|
||||
model_path="checkpoints/whisper/tiny.pt",
|
||||
device=None,
|
||||
audio_embeds_cache_dir=None,
|
||||
num_frames=16,
|
||||
audio_feat_length=[2, 2],
|
||||
):
|
||||
self.model = load_model(model_path, device)
|
||||
self.audio_embeds_cache_dir = audio_embeds_cache_dir
|
||||
if audio_embeds_cache_dir is not None and audio_embeds_cache_dir != "":
|
||||
Path(audio_embeds_cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
self.num_frames = num_frames
|
||||
self.embedding_dim = self.model.dims.n_audio_state
|
||||
self.audio_feat_length = audio_feat_length
|
||||
|
||||
def get_sliced_feature(self, feature_array, vid_idx, fps=25):
|
||||
"""
|
||||
Get sliced features based on a given index
|
||||
:param feature_array:
|
||||
:param start_idx: the start index of the feature
|
||||
:param audio_feat_length:
|
||||
:return:
|
||||
"""
|
||||
length = len(feature_array)
|
||||
selected_feature = []
|
||||
selected_idx = []
|
||||
|
||||
center_idx = int(vid_idx * 50 / fps)
|
||||
left_idx = center_idx - self.audio_feat_length[0] * 2
|
||||
right_idx = center_idx + (self.audio_feat_length[1] + 1) * 2
|
||||
|
||||
for idx in range(left_idx, right_idx):
|
||||
idx = max(0, idx)
|
||||
idx = min(length - 1, idx)
|
||||
x = feature_array[idx]
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(idx)
|
||||
|
||||
selected_feature = torch.cat(selected_feature, dim=0)
|
||||
selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
|
||||
return selected_feature, selected_idx
|
||||
|
||||
def get_sliced_feature_sparse(self, feature_array, vid_idx, fps=25):
|
||||
"""
|
||||
Get sliced features based on a given index
|
||||
:param feature_array:
|
||||
:param start_idx: the start index of the feature
|
||||
:param audio_feat_length:
|
||||
:return:
|
||||
"""
|
||||
length = len(feature_array)
|
||||
selected_feature = []
|
||||
selected_idx = []
|
||||
|
||||
for dt in range(-self.audio_feat_length[0], self.audio_feat_length[1] + 1):
|
||||
left_idx = int((vid_idx + dt) * 50 / fps)
|
||||
if left_idx < 1 or left_idx > length - 1:
|
||||
left_idx = max(0, left_idx)
|
||||
left_idx = min(length - 1, left_idx)
|
||||
|
||||
x = feature_array[left_idx]
|
||||
x = x[np.newaxis, :, :]
|
||||
x = np.repeat(x, 2, axis=0)
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(left_idx)
|
||||
selected_idx.append(left_idx)
|
||||
else:
|
||||
x = feature_array[left_idx - 1 : left_idx + 1]
|
||||
selected_feature.append(x)
|
||||
selected_idx.append(left_idx - 1)
|
||||
selected_idx.append(left_idx)
|
||||
selected_feature = np.concatenate(selected_feature, axis=0)
|
||||
selected_feature = selected_feature.reshape(-1, self.embedding_dim) # 50*384
|
||||
selected_feature = torch.from_numpy(selected_feature)
|
||||
return selected_feature, selected_idx
|
||||
|
||||
def feature2chunks(self, feature_array, fps):
|
||||
whisper_chunks = []
|
||||
whisper_idx_multiplier = 50.0 / fps
|
||||
i = 0
|
||||
print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||
|
||||
while True:
|
||||
start_idx = int(i * whisper_idx_multiplier)
|
||||
selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i, fps=fps)
|
||||
# print(f"i:{i},selected_idx {selected_idx}")
|
||||
whisper_chunks.append(selected_feature)
|
||||
i += 1
|
||||
if start_idx > len(feature_array):
|
||||
break
|
||||
|
||||
return whisper_chunks
|
||||
|
||||
def _audio2feat(self, audio_path: str):
|
||||
# get the sample rate of the audio
|
||||
result = self.model.transcribe(audio_path)
|
||||
embed_list = []
|
||||
for emb in result["segments"]:
|
||||
encoder_embeddings = emb["encoder_embeddings"]
|
||||
encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
|
||||
encoder_embeddings = encoder_embeddings.squeeze(0)
|
||||
start_idx = int(emb["start"])
|
||||
end_idx = int(emb["end"])
|
||||
emb_end_idx = int((end_idx - start_idx) / 2)
|
||||
embed_list.append(encoder_embeddings[:emb_end_idx])
|
||||
concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
|
||||
return concatenated_array
|
||||
|
||||
def audio2feat(self, audio_path):
|
||||
if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
|
||||
return self._audio2feat(audio_path)
|
||||
|
||||
audio_embeds_cache_path = os.path.join(
|
||||
self.audio_embeds_cache_dir, os.path.basename(audio_path).replace(".mp4", "_embeds.pt")
|
||||
)
|
||||
|
||||
if os.path.isfile(audio_embeds_cache_path):
|
||||
try:
|
||||
audio_feat = torch.load(audio_embeds_cache_path, weights_only=True)
|
||||
except Exception as e:
|
||||
print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
|
||||
os.remove(audio_embeds_cache_path)
|
||||
audio_feat = self._audio2feat(audio_path)
|
||||
torch.save(audio_feat, audio_embeds_cache_path)
|
||||
else:
|
||||
audio_feat = self._audio2feat(audio_path)
|
||||
torch.save(audio_feat, audio_embeds_cache_path)
|
||||
|
||||
return audio_feat
|
||||
|
||||
def crop_overlap_audio_window(self, audio_feat, start_index):
|
||||
selected_feature_list = []
|
||||
for i in range(start_index, start_index + self.num_frames):
|
||||
selected_feature, selected_idx = self.get_sliced_feature(feature_array=audio_feat, vid_idx=i, fps=25)
|
||||
selected_feature_list.append(selected_feature)
|
||||
mel_overlap = torch.stack(selected_feature_list)
|
||||
return mel_overlap
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
|
||||
audio_path = "assets/demo1_audio.wav"
|
||||
array = audio_encoder.audio2feat(audio_path)
|
||||
print(array.shape)
|
||||
fps = 25
|
||||
whisper_idx_multiplier = 50.0 / fps
|
||||
|
||||
i = 0
|
||||
print(f"video in {fps} FPS, audio idx in 50FPS")
|
||||
while True:
|
||||
start_idx = int(i * whisper_idx_multiplier)
|
||||
selected_feature, selected_idx = audio_encoder.get_sliced_feature(feature_array=array, vid_idx=i, fps=fps)
|
||||
print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
|
||||
i += 1
|
||||
if start_idx > len(array):
|
||||
break
|
||||
122
models/LatentSync/latentsync/whisper/whisper/__init__.py
Normal file
122
models/LatentSync/latentsync/whisper/whisper/__init__.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import Whisper, ModelDimensions
|
||||
from .transcribe import transcribe
|
||||
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
download_root = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper"))
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp:
|
||||
checkpoint = torch.load(fp, map_location=device, weights_only=True)
|
||||
del checkpoint_file
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims)
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
|
||||
del checkpoint
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model.to(device)
|
||||
4
models/LatentSync/latentsync/whisper/whisper/__main__.py
Normal file
4
models/LatentSync/latentsync/whisper/whisper/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .transcribe import cli
|
||||
|
||||
|
||||
cli()
|
||||
50001
models/LatentSync/latentsync/whisper/whisper/assets/gpt2/merges.txt
Normal file
50001
models/LatentSync/latentsync/whisper/whisper/assets/gpt2/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -0,0 +1 @@
|
||||
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{"<|endoftext|>": 50257}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}
|
||||
@@ -0,0 +1 @@
|
||||
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}
|
||||
File diff suppressed because one or more lines are too long
125
models/LatentSync/latentsync/whisper/whisper/audio.py
Normal file
125
models/LatentSync/latentsync/whisper/whisper/audio.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
|
||||
import ffmpeg
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
N_MELS = 80
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
|
||||
sr: int
|
||||
The sample rate to resample the audio if necessary
|
||||
|
||||
Returns
|
||||
-------
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
try:
|
||||
# This launches a subprocess to decode audio while down-mixing and resampling as necessary.
|
||||
# Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
|
||||
out, _ = (
|
||||
ffmpeg.input(file, threads=0)
|
||||
.output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
|
||||
.run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
|
||||
)
|
||||
except ffmpeg.Error as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(dim=axis, index=torch.arange(length))
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
||||
else:
|
||||
if array.shape[axis] > length:
|
||||
array = array.take(indices=range(length), axis=axis)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = np.pad(array, pad_widths)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
)
|
||||
"""
|
||||
assert n_mels == 80, f"Unsupported n_mels: {n_mels}"
|
||||
with np.load(os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
|
||||
magnitudes = stft[:, :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
729
models/LatentSync/latentsync/whisper/whisper/decoding.py
Normal file
729
models/LatentSync/latentsync/whisper/whisper/decoding.py
Normal file
@@ -0,0 +1,729 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : Tensor, shape = (n_audio,)
|
||||
ids of the most probable language tokens, which appears after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(model.is_multilingual)
|
||||
if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
|
||||
raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
task: str = "transcribe" # whether to perform X->X "transcribe" or X->English "translate"
|
||||
language: Optional[str] = None # language that the audio is in; uses detected language if None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent samples to collect, when t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, when t == 0
|
||||
patience: Optional[float] = None # patience in beam search (https://arxiv.org/abs/2204.05424)
|
||||
|
||||
# options for ranking generations (either beams or best-of-N samples)
|
||||
length_penalty: Optional[float] = None # "alpha" in Google NMT, None defaults to length norm
|
||||
|
||||
# prompt, prefix, and token suppression
|
||||
prompt: Optional[Union[str, List[int]]] = None # text or tokens for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # text or tokens to prefix the current context
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0 # the initial timestamp cannot be later than this
|
||||
|
||||
# implementation details
|
||||
fp16: bool = True # use fp16 for most of the calculation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: Tensor
|
||||
language: str
|
||||
encoder_embeddings: np.ndarray
|
||||
decoder_embeddings: np.ndarray
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rearrange_kv_cache(self, source_indices) -> None:
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_caching(self) -> None:
|
||||
"""Clean up any resources or hooks after decoding is finished"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchInference(Inference):
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
|
||||
if not self.kv_cache:
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
return_val = self.model.decoder(tokens, audio_features,
|
||||
kv_cache=self.kv_cache, include_embeddings=include_embeddings)
|
||||
return return_val
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
def rearrange_kv_cache(self, source_indices):
|
||||
for module, tensor in self.kv_cache.items():
|
||||
# update the key/value cache to contain the selected sequences
|
||||
self.kv_cache[module] = tensor[source_indices].detach()
|
||||
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
True if all sequences has reached the end of text
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
temperature = self.temperature
|
||||
if temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||
|
||||
completed = (tokens[:, -1] == self.eot).all()
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None: # for the first update
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
|
||||
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens[idx].tolist()
|
||||
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||
sequence = tuple(prefix + [token.item()])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
|
||||
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
|
||||
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||
self.inference.rearrange_kv_cache(source_indices)
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if len(sequences) < self.beam_size: # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
if len(sequences) >= self.beam_size:
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
]
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||
"""Apply any filtering or masking to logits in-place
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int]):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
logits[:, self.suppress_tokens] = -np.inf
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
seq = [t for t in tokens[k, self.sample_begin :].tolist()]
|
||||
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
|
||||
last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
self.decoder = BeamSearchDecoder(
|
||||
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
)
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
prefix = self.options.prefix
|
||||
prompt = self.options.prompt
|
||||
|
||||
if prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
|
||||
)
|
||||
tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
|
||||
if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
result = self.model.encoder(mel, include_embeddings)
|
||||
if include_embeddings:
|
||||
audio_features, embeddings = result
|
||||
else:
|
||||
audio_features = result
|
||||
|
||||
if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
|
||||
return TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
|
||||
|
||||
if include_embeddings:
|
||||
return audio_features, embeddings
|
||||
else:
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||
assert audio_features.shape[0] == tokens.shape[0]
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
|
||||
try:
|
||||
embeddings = []
|
||||
for i in range(self.sample_len):
|
||||
logits, token_embeddings = self.inference.logits(tokens, audio_features, include_embeddings=True)
|
||||
|
||||
if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
token_embeddings = token_embeddings[:, :, -1]
|
||||
|
||||
# Append embeddings together
|
||||
embeddings.append(token_embeddings)
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||
for logit_filter in self.logit_filters:
|
||||
logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
if completed:
|
||||
embeddings = embeddings[:-1]
|
||||
embeddings = np.stack(embeddings, 2)
|
||||
self.inference.cleanup_caching()
|
||||
|
||||
return tokens, sum_logprobs, no_speech_probs, embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
|
||||
# encoder forward pass
|
||||
forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(mel, include_embeddings=True)
|
||||
audio_features, encoder_embeddings = forward_pass
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(audio_features=features, language=language, language_probs=probs)
|
||||
for features, language, probs in zip(audio_features, languages, language_probs)
|
||||
]
|
||||
|
||||
# repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
|
||||
audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
|
||||
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
|
||||
|
||||
fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
encoder_embeddings=encoder_embeddings,
|
||||
decoder_embeddings=decoder_embeddings
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||
A tensor containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
|
||||
if single:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
290
models/LatentSync/latentsync/whisper/whisper/model.py
Normal file
290
models/LatentSync/latentsync/whisper/whisper/model.py
Normal file
@@ -0,0 +1,290 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
from .decoding import detect_language as detect_language_function, decode as decode_function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
q = self.query(x)
|
||||
|
||||
if kv_cache is None or xa is None:
|
||||
# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
|
||||
# otherwise, perform key/value projections for self- or cross-attention as usual.
|
||||
k = self.key(x if xa is None else xa)
|
||||
v = self.value(x if xa is None else xa)
|
||||
else:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
k = kv_cache.get(self.key, self.key(xa))
|
||||
v = kv_cache.get(self.value, self.value(xa))
|
||||
|
||||
wv = self.qkv_attention(q, k, v, mask)
|
||||
return self.out(wv)
|
||||
|
||||
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
|
||||
n_batch, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
|
||||
w = F.softmax(qk.float(), dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor, include_embeddings: bool = False):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
include_embeddings: bool
|
||||
whether to include intermediate steps in the output
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = [x.cpu().detach().numpy()]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
if include_embeddings:
|
||||
embeddings.append(x.cpu().detach().numpy())
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = np.stack(embeddings, axis=1)
|
||||
return x, embeddings
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx)
|
||||
the encoded audio features to be attended on
|
||||
include_embeddings : bool
|
||||
Whether to include intermediate values in the output to this function
|
||||
"""
|
||||
offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
|
||||
x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = [x.cpu().detach().numpy()]
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
|
||||
if include_embeddings:
|
||||
embeddings.append(x.cpu().detach().numpy())
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
|
||||
|
||||
if include_embeddings:
|
||||
embeddings = np.stack(embeddings, axis=1)
|
||||
return logits, embeddings
|
||||
else:
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder.forward(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder.forward(tokens, audio_features)
|
||||
|
||||
def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab == 51865
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
|
||||
cache[module] = output # save as-is, for the first token or cross attention
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
||||
@@ -0,0 +1,2 @@
|
||||
from .basic import BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer
|
||||
@@ -0,0 +1,71 @@
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
"""
|
||||
Replace any other markers, symbols, and punctuations with a space,
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
|
||||
"""
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = self.clean(s).lower()
|
||||
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,543 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10: # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # standardize when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep some symbols for numerics
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
||||
331
models/LatentSync/latentsync/whisper/whisper/tokenizer.py
Normal file
331
models/LatentSync/latentsync/whisper/whisper/tokenizer.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import GPT2TokenizerFast
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"iw": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `GPT2TokenizerFast` providing quick access to special tokens"""
|
||||
|
||||
tokenizer: "GPT2TokenizerFast"
|
||||
language: Optional[str]
|
||||
sot_sequence: Tuple[int]
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.tokenizer.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
|
||||
return self.tokenizer.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, tokens) -> str:
|
||||
"""
|
||||
Timestamp tokens are above the special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
outputs = [[]]
|
||||
for token in tokens:
|
||||
if token >= self.timestamp_begin:
|
||||
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
||||
outputs.append(timestamp)
|
||||
outputs.append([])
|
||||
else:
|
||||
outputs[-1].append(token)
|
||||
outputs = [s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
||||
return "".join(outputs)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def eot(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot(self) -> int:
|
||||
return self._get_single_token_id("<|startoftranscript|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_lm(self) -> int:
|
||||
return self._get_single_token_id("<|startoflm|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_prev(self) -> int:
|
||||
return self._get_single_token_id("<|startofprev|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_speech(self) -> int:
|
||||
return self._get_single_token_id("<|nospeech|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def no_timestamps(self) -> int:
|
||||
return self._get_single_token_id("<|notimestamps|>")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.tokenizer.all_special_ids[-1] + 1
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError(f"This tokenizer does not have language token configured")
|
||||
|
||||
additional_tokens = dict(
|
||||
zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
)
|
||||
)
|
||||
candidate = f"<|{self.language}|>"
|
||||
if candidate in additional_tokens:
|
||||
return additional_tokens[candidate]
|
||||
|
||||
raise KeyError(f"Language {self.language} not found in tokenizer.")
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in zip(
|
||||
self.tokenizer.additional_special_tokens,
|
||||
self.tokenizer.additional_special_tokens_ids,
|
||||
):
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([l]).strip("<|>") for l in self.all_language_tokens)
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@property
|
||||
@lru_cache()
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
|
||||
symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.tokenizer.encode(" -")[0], self.tokenizer.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [self.tokenizer.encode(symbol), self.tokenizer.encode(" " + symbol)]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def _get_single_token_id(self, text) -> int:
|
||||
tokens = self.tokenizer.encode(text)
|
||||
assert len(tokens) == 1, f"{text} is not encoded as a single token"
|
||||
return tokens[0]
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def build_tokenizer(name: str = "gpt2"):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
path = os.path.join(os.path.dirname(__file__), "assets", name)
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained(path)
|
||||
|
||||
specials = [
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in LANGUAGES.keys()],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
]
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=specials))
|
||||
return tokenizer
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
language: Optional[str] = None,
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
tokenizer_name = "multilingual"
|
||||
task = task or "transcribe"
|
||||
language = language or "en"
|
||||
else:
|
||||
tokenizer_name = "gpt2"
|
||||
task = None
|
||||
language = None
|
||||
|
||||
tokenizer = build_tokenizer(name=tokenizer_name)
|
||||
all_special_ids: List[int] = tokenizer.all_special_ids
|
||||
sot: int = all_special_ids[1]
|
||||
translate: int = all_special_ids[-6]
|
||||
transcribe: int = all_special_ids[-5]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())
|
||||
sot_sequence = [sot]
|
||||
if language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(language))
|
||||
if task is not None:
|
||||
sot_sequence.append(transcribe if task == "transcribe" else translate)
|
||||
|
||||
return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(sot_sequence))
|
||||
207
models/LatentSync/latentsync/whisper/whisper/transcribe.py
Normal file
207
models/LatentSync/latentsync/whisper/whisper/transcribe.py
Normal file
@@ -0,0 +1,207 @@
|
||||
import argparse
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
force_extraction: bool = False,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
mel = log_mel_spectrogram(audio)
|
||||
|
||||
all_segments = []
|
||||
def add_segment(
|
||||
*, start: float, end: float, encoder_embeddings
|
||||
):
|
||||
|
||||
all_segments.append(
|
||||
{
|
||||
"start": start,
|
||||
"end": end,
|
||||
"encoder_embeddings":encoder_embeddings,
|
||||
}
|
||||
)
|
||||
# show the progress bar when verbose is False (otherwise the transcribed text will be printed)
|
||||
num_frames = mel.shape[-1]
|
||||
seek = 0
|
||||
previous_seek_value = seek
|
||||
sample_skip = 3000 #
|
||||
with tqdm.tqdm(total=num_frames, unit='frames', disable=verbose is not False) as pbar:
|
||||
while seek < num_frames:
|
||||
# seek是开始的帧数
|
||||
end_seek = min(seek + sample_skip, num_frames)
|
||||
segment = pad_or_trim(mel[:,seek:seek+sample_skip], N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
single = segment.ndim == 2
|
||||
if single:
|
||||
segment = segment.unsqueeze(0)
|
||||
if dtype == torch.float16:
|
||||
segment = segment.half()
|
||||
audio_features, embeddings = model.encoder(segment, include_embeddings = True)
|
||||
|
||||
encoder_embeddings = embeddings
|
||||
#print(f"encoder_embeddings shape {encoder_embeddings.shape}")
|
||||
add_segment(
|
||||
start=seek,
|
||||
end=end_seek,
|
||||
#text_tokens=tokens,
|
||||
#result=result,
|
||||
encoder_embeddings=encoder_embeddings,
|
||||
)
|
||||
seek+=sample_skip
|
||||
|
||||
return dict(segments=all_segments)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
|
||||
if temperature_increment_on_fallback is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
threads = args.pop("threads")
|
||||
if threads > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
for audio_path in args.pop("audio"):
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
|
||||
# save TXT
|
||||
with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
|
||||
write_txt(result["segments"], file=txt)
|
||||
|
||||
# save VTT
|
||||
with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
|
||||
write_vtt(result["segments"], file=vtt)
|
||||
|
||||
# save SRT
|
||||
with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli()
|
||||
87
models/LatentSync/latentsync/whisper/whisper/utils.py
Normal file
87
models/LatentSync/latentsync/whisper/whisper/utils.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import zlib
|
||||
from typing import Iterator, TextIO
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
assert x % y == 0
|
||||
return x // y
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
else:
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
|
||||
def optional_int(string):
|
||||
return None if string == "None" else int(string)
|
||||
|
||||
|
||||
def optional_float(string):
|
||||
return None if string == "None" else float(string)
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
return len(text) / len(zlib.compress(text.encode("utf-8")))
|
||||
|
||||
|
||||
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
|
||||
|
||||
def write_txt(transcript: Iterator[dict], file: TextIO):
|
||||
for segment in transcript:
|
||||
print(segment['text'].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
def write_vtt(transcript: Iterator[dict], file: TextIO):
|
||||
print("WEBVTT\n", file=file)
|
||||
for segment in transcript:
|
||||
print(
|
||||
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def write_srt(transcript: Iterator[dict], file: TextIO):
|
||||
"""
|
||||
Write a transcript to a file in SRT format.
|
||||
|
||||
Example usage:
|
||||
from pathlib import Path
|
||||
from whisper.utils import write_srt
|
||||
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
|
||||
# save SRT
|
||||
audio_basename = Path(audio_path).stem
|
||||
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
||||
write_srt(result["segments"], file=srt)
|
||||
"""
|
||||
for i, segment in enumerate(transcript, start=1):
|
||||
# write srt lines
|
||||
print(
|
||||
f"{i}\n"
|
||||
f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
|
||||
f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
|
||||
f"{segment['text'].strip().replace('-->', '->')}\n",
|
||||
file=file,
|
||||
flush=True,
|
||||
)
|
||||
120
models/LatentSync/scripts/inference.py
Normal file
120
models/LatentSync/scripts/inference.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
from latentsync.models.unet import UNet3DConditionModel
|
||||
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
||||
from accelerate.utils import set_seed
|
||||
from latentsync.whisper.audio2feature import Audio2Feature
|
||||
from DeepCache import DeepCacheSDHelper
|
||||
|
||||
|
||||
def main(config, args):
|
||||
if not os.path.exists(args.video_path):
|
||||
raise RuntimeError(f"Video path '{args.video_path}' not found")
|
||||
if not os.path.exists(args.audio_path):
|
||||
raise RuntimeError(f"Audio path '{args.audio_path}' not found")
|
||||
|
||||
# Check if the GPU supports float16
|
||||
is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
|
||||
dtype = torch.float16 if is_fp16_supported else torch.float32
|
||||
|
||||
print(f"Input video path: {args.video_path}")
|
||||
print(f"Input audio path: {args.audio_path}")
|
||||
print(f"Loaded checkpoint path: {args.inference_ckpt_path}")
|
||||
|
||||
scheduler = DDIMScheduler.from_pretrained("configs")
|
||||
|
||||
if config.model.cross_attention_dim == 768:
|
||||
whisper_model_path = "checkpoints/whisper/small.pt"
|
||||
elif config.model.cross_attention_dim == 384:
|
||||
whisper_model_path = "checkpoints/whisper/tiny.pt"
|
||||
else:
|
||||
raise NotImplementedError("cross_attention_dim must be 768 or 384")
|
||||
|
||||
audio_encoder = Audio2Feature(
|
||||
model_path=whisper_model_path,
|
||||
device="cuda",
|
||||
num_frames=config.data.num_frames,
|
||||
audio_feat_length=config.data.audio_feat_length,
|
||||
)
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
|
||||
vae.config.scaling_factor = 0.18215
|
||||
vae.config.shift_factor = 0
|
||||
|
||||
unet, _ = UNet3DConditionModel.from_pretrained(
|
||||
OmegaConf.to_container(config.model),
|
||||
args.inference_ckpt_path,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
unet = unet.to(dtype=dtype)
|
||||
|
||||
pipeline = LipsyncPipeline(
|
||||
vae=vae,
|
||||
audio_encoder=audio_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
).to("cuda")
|
||||
|
||||
# use DeepCache
|
||||
if args.enable_deepcache:
|
||||
helper = DeepCacheSDHelper(pipe=pipeline)
|
||||
helper.set_params(cache_interval=3, cache_branch_id=0)
|
||||
helper.enable()
|
||||
|
||||
if args.seed != -1:
|
||||
set_seed(args.seed)
|
||||
else:
|
||||
torch.seed()
|
||||
|
||||
print(f"Initial seed: {torch.initial_seed()}")
|
||||
|
||||
pipeline(
|
||||
video_path=args.video_path,
|
||||
audio_path=args.audio_path,
|
||||
video_out_path=args.video_out_path,
|
||||
num_frames=config.data.num_frames,
|
||||
num_inference_steps=args.inference_steps,
|
||||
guidance_scale=args.guidance_scale,
|
||||
weight_dtype=dtype,
|
||||
width=config.data.resolution,
|
||||
height=config.data.resolution,
|
||||
mask_image_path=config.data.mask_image_path,
|
||||
temp_dir=args.temp_dir,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
|
||||
parser.add_argument("--inference_ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--video_path", type=str, required=True)
|
||||
parser.add_argument("--audio_path", type=str, required=True)
|
||||
parser.add_argument("--video_out_path", type=str, required=True)
|
||||
parser.add_argument("--inference_steps", type=int, default=20)
|
||||
parser.add_argument("--guidance_scale", type=float, default=1.0)
|
||||
parser.add_argument("--temp_dir", type=str, default="temp")
|
||||
parser.add_argument("--seed", type=int, default=1247)
|
||||
parser.add_argument("--enable_deepcache", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = OmegaConf.load(args.unet_config_path)
|
||||
|
||||
main(config, args)
|
||||
196
models/LatentSync/scripts/server.py
Normal file
196
models/LatentSync/scripts/server.py
Normal file
@@ -0,0 +1,196 @@
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# --- 自动加载 GPU 配置 (必须在 torch 导入前) ---
|
||||
def load_gpu_config():
|
||||
"""尝试从后端 .env 文件读取 LATENTSYNC_GPU_ID"""
|
||||
try:
|
||||
# 路径: scripts/server.py -> scripts -> LatentSync -> models -> ViGent2 -> backend -> .env
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
env_path = current_dir.parent.parent.parent / "backend" / ".env"
|
||||
|
||||
target_gpu = "1" # 默认 fallback
|
||||
|
||||
if env_path.exists():
|
||||
print(f"📖 读取配置文件: {env_path}")
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("LATENTSYNC_GPU_ID="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
target_gpu = val
|
||||
print(f"⚙️ 发现配置 LATENTSYNC_GPU_ID={target_gpu}")
|
||||
break
|
||||
|
||||
# 设置环境变量
|
||||
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
|
||||
print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
|
||||
else:
|
||||
print(f"ℹ️ 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")
|
||||
|
||||
load_gpu_config()
|
||||
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from omegaconf import OmegaConf
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
from latentsync.models.unet import UNet3DConditionModel
|
||||
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
||||
from latentsync.whisper.audio2feature import Audio2Feature
|
||||
from accelerate.utils import set_seed
|
||||
from DeepCache import DeepCacheSDHelper
|
||||
|
||||
# 全局模型缓存
|
||||
models = {}
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# --- 模型加载逻辑 (参考 inference.py) ---
|
||||
print("⏳ 正在加载 LatentSync 模型...")
|
||||
|
||||
# 默认配置路径 (相对于根目录)
|
||||
unet_config_path = "configs/unet/stage2_512.yaml"
|
||||
ckpt_path = "checkpoints/latentsync_unet.pt"
|
||||
|
||||
if not os.path.exists(unet_config_path):
|
||||
print(f"⚠️ 找不到配置文件: {unet_config_path},请确保在 models/LatentSync 根目录运行")
|
||||
|
||||
config = OmegaConf.load(unet_config_path)
|
||||
|
||||
# Check GPU
|
||||
is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
|
||||
dtype = torch.float16 if is_fp16_supported else torch.float32
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
gpu_name = torch.cuda.get_device_name(0)
|
||||
print(f"🖥️ 正在使用 GPU: {gpu_name} (CUDA_VISIBLE_DEVICES 已生效)")
|
||||
else:
|
||||
print("⚠️ 警告: 未检测到 GPU,将使用 CPU 进行推理 (速度极慢)")
|
||||
|
||||
scheduler = DDIMScheduler.from_pretrained("configs")
|
||||
|
||||
# Whisper Model
|
||||
if config.model.cross_attention_dim == 768:
|
||||
whisper_path = "checkpoints/whisper/small.pt"
|
||||
else:
|
||||
whisper_path = "checkpoints/whisper/tiny.pt"
|
||||
|
||||
audio_encoder = Audio2Feature(
|
||||
model_path=whisper_path,
|
||||
device=device,
|
||||
num_frames=config.data.num_frames,
|
||||
audio_feat_length=config.data.audio_feat_length,
|
||||
)
|
||||
|
||||
# VAE
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
|
||||
vae.config.scaling_factor = 0.18215
|
||||
vae.config.shift_factor = 0
|
||||
|
||||
# UNet
|
||||
unet, _ = UNet3DConditionModel.from_pretrained(
|
||||
OmegaConf.to_container(config.model),
|
||||
ckpt_path,
|
||||
device="cpu", # Load to CPU first to save memory during init
|
||||
)
|
||||
unet = unet.to(dtype=dtype)
|
||||
|
||||
# Pipeline
|
||||
pipeline = LipsyncPipeline(
|
||||
vae=vae,
|
||||
audio_encoder=audio_encoder,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
).to(device)
|
||||
|
||||
# DeepCache (默认启用)
|
||||
helper = DeepCacheSDHelper(pipe=pipeline)
|
||||
helper.set_params(cache_interval=3, cache_branch_id=0)
|
||||
helper.enable()
|
||||
|
||||
models["pipeline"] = pipeline
|
||||
models["config"] = config
|
||||
models["dtype"] = dtype
|
||||
|
||||
print("✅ LatentSync 模型加载完成,服务就绪!")
|
||||
yield
|
||||
# Clean up if needed
|
||||
models.clear()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
class LipSyncRequest(BaseModel):
|
||||
video_path: str
|
||||
audio_path: str
|
||||
video_out_path: str
|
||||
inference_steps: int = 20
|
||||
guidance_scale: float = 1.5
|
||||
seed: int = 1247
|
||||
temp_dir: str = "temp"
|
||||
|
||||
@app.get("/health")
|
||||
def health_check():
|
||||
return {"status": "ok", "model_loaded": "pipeline" in models}
|
||||
|
||||
@app.post("/lipsync")
|
||||
async def generate_lipsync(req: LipSyncRequest):
|
||||
if "pipeline" not in models:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
if not os.path.exists(req.video_path):
|
||||
raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
|
||||
if not os.path.exists(req.audio_path):
|
||||
raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")
|
||||
|
||||
print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
|
||||
|
||||
try:
|
||||
pipeline = models["pipeline"]
|
||||
config = models["config"]
|
||||
dtype = models["dtype"]
|
||||
|
||||
# Set seed
|
||||
if req.seed != -1:
|
||||
set_seed(req.seed)
|
||||
else:
|
||||
torch.seed()
|
||||
|
||||
# Run Inference
|
||||
pipeline(
|
||||
video_path=req.video_path,
|
||||
audio_path=req.audio_path,
|
||||
video_out_path=req.video_out_path,
|
||||
num_frames=config.data.num_frames,
|
||||
num_inference_steps=req.inference_steps,
|
||||
guidance_scale=req.guidance_scale,
|
||||
weight_dtype=dtype,
|
||||
width=config.data.resolution,
|
||||
height=config.data.resolution,
|
||||
mask_image_path=config.data.mask_image_path,
|
||||
temp_dir=req.temp_dir,
|
||||
)
|
||||
|
||||
if os.path.exists(req.video_out_path):
|
||||
return {"status": "success", "output_path": req.video_out_path}
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Output file generation failed")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8007)
|
||||
340
models/LatentSync/scripts/train_syncnet.py
Normal file
340
models/LatentSync/scripts/train_syncnet.py
Normal file
@@ -0,0 +1,340 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
import os, argparse, datetime, math
|
||||
import logging
|
||||
from omegaconf import OmegaConf
|
||||
import shutil
|
||||
|
||||
from latentsync.data.syncnet_dataset import SyncNetDataset
|
||||
from latentsync.models.stable_syncnet import StableSyncNet
|
||||
from latentsync.models.wav2lip_syncnet import Wav2LipSyncNet
|
||||
from latentsync.utils.util import gather_loss, plot_loss_chart
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
import torch
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils.logging import get_logger
|
||||
from einops import rearrange
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from latentsync.utils.util import init_dist, cosine_loss, dummy_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main(config):
|
||||
# Initialize distributed training
|
||||
local_rank = init_dist()
|
||||
global_rank = dist.get_rank()
|
||||
num_processes = dist.get_world_size()
|
||||
is_main_process = global_rank == 0
|
||||
|
||||
seed = config.run.seed + global_rank
|
||||
set_seed(seed)
|
||||
|
||||
# Logging folder
|
||||
folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
|
||||
output_dir = os.path.join(config.data.train_output_dir, folder_name)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
# Handle the output folder creation
|
||||
if is_main_process:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
|
||||
os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
|
||||
shutil.copy(config.config_path, output_dir)
|
||||
|
||||
device = torch.device(local_rank)
|
||||
|
||||
if config.data.latent_space:
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
||||
vae.requires_grad_(False)
|
||||
vae.to(device)
|
||||
else:
|
||||
vae = None
|
||||
|
||||
# Dataset and Dataloader setup
|
||||
train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
|
||||
val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)
|
||||
|
||||
train_distributed_sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=num_processes,
|
||||
rank=global_rank,
|
||||
shuffle=True,
|
||||
seed=config.run.seed,
|
||||
)
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=False,
|
||||
sampler=train_distributed_sampler,
|
||||
num_workers=config.data.num_workers,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
worker_init_fn=train_dataset.worker_init_fn,
|
||||
)
|
||||
|
||||
num_samples_limit = 640
|
||||
|
||||
val_batch_size = min(
|
||||
num_samples_limit // config.data.num_frames, config.data.batch_size
|
||||
) # limit batch size to avoid CUDA OOM
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset,
|
||||
batch_size=val_batch_size,
|
||||
shuffle=False,
|
||||
num_workers=config.data.num_workers,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
worker_init_fn=val_dataset.worker_init_fn,
|
||||
)
|
||||
|
||||
# Model
|
||||
syncnet = StableSyncNet(OmegaConf.to_container(config.model)).to(device)
|
||||
# syncnet = Wav2LipSyncNet().to(device)
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
|
||||
)
|
||||
|
||||
global_step = 0
|
||||
train_step_list = []
|
||||
train_loss_list = []
|
||||
val_step_list = []
|
||||
val_loss_list = []
|
||||
|
||||
if config.ckpt.resume_ckpt_path != "":
|
||||
if is_main_process:
|
||||
logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
|
||||
ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device, weights_only=True)
|
||||
|
||||
syncnet.load_state_dict(ckpt["state_dict"])
|
||||
|
||||
if "global_step" in ckpt:
|
||||
global_step = ckpt["global_step"]
|
||||
train_step_list = ckpt["train_step_list"]
|
||||
train_loss_list = ckpt["train_loss_list"]
|
||||
val_step_list = ckpt["val_step_list"]
|
||||
val_loss_list = ckpt["val_loss_list"]
|
||||
|
||||
# DDP wrapper
|
||||
syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)
|
||||
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
|
||||
num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
if is_main_process:
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
|
||||
logger.info(
|
||||
f" Total train batch size (w. parallel & distributed & accumulation) = {config.data.batch_size * num_processes * config.data.gradient_accumulation_steps}"
|
||||
)
|
||||
logger.info(f" Total optimization steps = {config.run.max_train_steps}")
|
||||
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
|
||||
)
|
||||
|
||||
# Support mixed-precision training
|
||||
scaler = torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
syncnet.train()
|
||||
step_loss = 0
|
||||
optimizer.zero_grad()
|
||||
|
||||
for index, batch in enumerate(train_dataloader):
|
||||
### >>>> Training >>>> ###
|
||||
|
||||
frames = batch["frames"].to(device, dtype=torch.float16)
|
||||
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
|
||||
y = batch["y"].to(device, dtype=torch.float32)
|
||||
|
||||
if config.data.latent_space:
|
||||
max_batch_size = (
|
||||
num_samples_limit // config.data.num_frames
|
||||
) # due to the limited cuda memory, we split the input frames into parts
|
||||
if frames.shape[0] > max_batch_size:
|
||||
assert (
|
||||
frames.shape[0] % max_batch_size == 0
|
||||
), f"max_batch_size {max_batch_size} should be divisible by batch_size {frames.shape[0]}"
|
||||
frames_part_results = []
|
||||
for i in range(0, frames.shape[0], max_batch_size):
|
||||
frames_part = frames[i : i + max_batch_size]
|
||||
frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
|
||||
with torch.no_grad():
|
||||
frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
|
||||
frames_part_results.append(frames_part)
|
||||
frames = torch.cat(frames_part_results, dim=0)
|
||||
else:
|
||||
frames = rearrange(frames, "b f c h w -> (b f) c h w")
|
||||
with torch.no_grad():
|
||||
frames = vae.encode(frames).latent_dist.sample() * 0.18215
|
||||
|
||||
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
|
||||
else:
|
||||
frames = rearrange(frames, "b f c h w -> b (f c) h w")
|
||||
|
||||
if config.data.lower_half:
|
||||
height = frames.shape[2]
|
||||
frames = frames[:, :, height // 2 :, :]
|
||||
|
||||
# Disable gradient sync for the first N-1 steps, enable sync on the final step
|
||||
with syncnet.no_sync() if (index + 1) % config.data.gradient_accumulation_steps != 0 else dummy_context():
|
||||
# Mixed-precision training
|
||||
with torch.autocast(
|
||||
device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training
|
||||
):
|
||||
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
|
||||
|
||||
loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
|
||||
loss = loss / config.data.gradient_accumulation_steps
|
||||
|
||||
# Backpropagate
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
step_loss += gather_loss(loss, device)
|
||||
|
||||
# Update parameters when the accumulation steps are reached
|
||||
if (index + 1) % config.data.gradient_accumulation_steps == 0:
|
||||
""">>> gradient clipping >>>"""
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)
|
||||
""" <<< gradient clipping <<< """
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
train_step_list.append(global_step)
|
||||
train_loss_list.append(step_loss)
|
||||
|
||||
if is_main_process and global_step % config.run.validation_steps == 0:
|
||||
logger.info(f"Validation at step {global_step}")
|
||||
val_loss = validation(
|
||||
val_dataloader,
|
||||
device,
|
||||
syncnet,
|
||||
config.data.latent_space,
|
||||
config.data.lower_half,
|
||||
vae,
|
||||
num_val_batches,
|
||||
)
|
||||
val_step_list.append(global_step)
|
||||
val_loss_list.append(val_loss)
|
||||
logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")
|
||||
plot_loss_chart(
|
||||
os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
|
||||
("Train loss", train_step_list, train_loss_list),
|
||||
("Val loss", val_step_list, val_loss_list),
|
||||
)
|
||||
|
||||
if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
|
||||
checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
|
||||
torch.save(
|
||||
{
|
||||
"state_dict": syncnet.module.state_dict(), # to unwrap DDP
|
||||
"global_step": global_step,
|
||||
"train_step_list": train_step_list,
|
||||
"train_loss_list": train_loss_list,
|
||||
"val_step_list": val_step_list,
|
||||
"val_loss_list": val_loss_list,
|
||||
},
|
||||
checkpoint_save_path,
|
||||
)
|
||||
logger.info(f"Saved checkpoint to {checkpoint_save_path}")
|
||||
|
||||
progress_bar.set_postfix({"step_loss": step_loss, "epoch": epoch})
|
||||
step_loss = 0
|
||||
|
||||
if global_step >= config.run.max_train_steps:
|
||||
break
|
||||
|
||||
progress_bar.close()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validation(val_dataloader, device, syncnet, latent_space, lower_half, vae, num_val_batches):
|
||||
syncnet.eval()
|
||||
|
||||
losses = []
|
||||
val_step = 0
|
||||
while True:
|
||||
for index, batch in enumerate(val_dataloader):
|
||||
### >>>> Validation >>>> ###
|
||||
|
||||
frames = batch["frames"].to(device, dtype=torch.float16)
|
||||
audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
|
||||
y = batch["y"].to(device, dtype=torch.float32)
|
||||
|
||||
if latent_space:
|
||||
num_frames = frames.shape[1]
|
||||
frames = rearrange(frames, "b f c h w -> (b f) c h w")
|
||||
frames = vae.encode(frames).latent_dist.sample() * 0.18215
|
||||
frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
|
||||
else:
|
||||
frames = rearrange(frames, "b f c h w -> b (f c) h w")
|
||||
|
||||
if lower_half:
|
||||
height = frames.shape[2]
|
||||
frames = frames[:, :, height // 2 :, :]
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
vision_embeds, audio_embeds = syncnet(frames, audio_samples)
|
||||
|
||||
loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
|
||||
|
||||
losses.append(loss.item())
|
||||
|
||||
val_step += 1
|
||||
if val_step > num_val_batches:
|
||||
syncnet.train()
|
||||
if len(losses) == 0:
|
||||
raise RuntimeError("No validation data")
|
||||
return sum(losses) / len(losses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Code to train the SyncNet")
|
||||
parser.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_pixel.yaml")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load a configuration file
|
||||
config = OmegaConf.load(args.config_path)
|
||||
config.config_path = args.config_path
|
||||
|
||||
main(config)
|
||||
519
models/LatentSync/scripts/train_unet.py
Normal file
519
models/LatentSync/scripts/train_unet.py
Normal file
@@ -0,0 +1,519 @@
|
||||
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import math
|
||||
import argparse
|
||||
import shutil
|
||||
import datetime
|
||||
import logging
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
from diffusers.utils.logging import get_logger
|
||||
from diffusers.optimization import get_scheduler
|
||||
from accelerate.utils import set_seed
|
||||
|
||||
from latentsync.data.unet_dataset import UNetDataset
|
||||
from latentsync.models.unet import UNet3DConditionModel
|
||||
from latentsync.models.stable_syncnet import StableSyncNet
|
||||
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
||||
from latentsync.utils.util import (
|
||||
init_dist,
|
||||
cosine_loss,
|
||||
one_step_sampling,
|
||||
)
|
||||
from latentsync.utils.util import plot_loss_chart
|
||||
from latentsync.whisper.audio2feature import Audio2Feature
|
||||
from latentsync.trepa.loss import TREPALoss
|
||||
from eval.syncnet import SyncNetEval
|
||||
from eval.syncnet_detect import SyncNetDetector
|
||||
from eval.eval_sync_conf import syncnet_eval
|
||||
import lpips
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def main(config):
|
||||
# Initialize distributed training
|
||||
local_rank = init_dist()
|
||||
global_rank = dist.get_rank()
|
||||
num_processes = dist.get_world_size()
|
||||
is_main_process = global_rank == 0
|
||||
|
||||
seed = config.run.seed + global_rank
|
||||
set_seed(seed)
|
||||
|
||||
# Logging folder
|
||||
folder_name = "train" + datetime.datetime.now().strftime(f"-%Y_%m_%d-%H:%M:%S")
|
||||
output_dir = os.path.join(config.data.train_output_dir, folder_name)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
|
||||
# Handle the output folder creation
|
||||
if is_main_process:
|
||||
diffusers.utils.logging.set_verbosity_info()
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
|
||||
os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
|
||||
os.makedirs(f"{output_dir}/sync_conf_results", exist_ok=True)
|
||||
shutil.copy(config.unet_config_path, output_dir)
|
||||
shutil.copy(config.data.syncnet_config_path, output_dir)
|
||||
|
||||
device = torch.device(local_rank)
|
||||
|
||||
noise_scheduler = DDIMScheduler.from_pretrained("configs")
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
|
||||
vae.config.scaling_factor = 0.18215
|
||||
vae.config.shift_factor = 0
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
||||
vae.requires_grad_(False)
|
||||
vae.to(device)
|
||||
|
||||
if config.run.pixel_space_supervise:
|
||||
vae.enable_gradient_checkpointing()
|
||||
|
||||
syncnet_eval_model = SyncNetEval(device=device)
|
||||
syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")
|
||||
|
||||
syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")
|
||||
|
||||
if config.model.cross_attention_dim == 768:
|
||||
whisper_model_path = "checkpoints/whisper/small.pt"
|
||||
elif config.model.cross_attention_dim == 384:
|
||||
whisper_model_path = "checkpoints/whisper/tiny.pt"
|
||||
else:
|
||||
raise NotImplementedError("cross_attention_dim must be 768 or 384")
|
||||
|
||||
audio_encoder = Audio2Feature(
|
||||
model_path=whisper_model_path,
|
||||
device=device,
|
||||
audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
|
||||
num_frames=config.data.num_frames,
|
||||
audio_feat_length=config.data.audio_feat_length,
|
||||
)
|
||||
|
||||
unet, resume_global_step = UNet3DConditionModel.from_pretrained(
|
||||
OmegaConf.to_container(config.model),
|
||||
config.ckpt.resume_ckpt_path,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if config.model.add_audio_layer and config.run.use_syncnet:
|
||||
syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
|
||||
if syncnet_config.ckpt.inference_ckpt_path == "":
|
||||
raise ValueError("SyncNet path is not provided")
|
||||
syncnet = StableSyncNet(OmegaConf.to_container(syncnet_config.model), gradient_checkpointing=True).to(
|
||||
device=device, dtype=torch.float16
|
||||
)
|
||||
syncnet_checkpoint = torch.load(
|
||||
syncnet_config.ckpt.inference_ckpt_path, map_location=device, weights_only=True
|
||||
)
|
||||
syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
|
||||
syncnet.requires_grad_(False)
|
||||
|
||||
del syncnet_checkpoint
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if config.model.use_motion_module:
|
||||
unet.requires_grad_(False)
|
||||
for name, param in unet.named_parameters():
|
||||
for trainable_module_name in config.run.trainable_modules:
|
||||
if trainable_module_name in name:
|
||||
param.requires_grad = True
|
||||
break
|
||||
trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
else:
|
||||
unet.requires_grad_(True)
|
||||
trainable_params = list(unet.parameters())
|
||||
|
||||
if config.optimizer.scale_lr:
|
||||
config.optimizer.lr = config.optimizer.lr * num_processes
|
||||
|
||||
optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)
|
||||
|
||||
if is_main_process:
|
||||
logger.info(f"trainable params number: {len(trainable_params)}")
|
||||
logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")
|
||||
|
||||
# Enable gradient checkpointing
|
||||
if config.run.enable_gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
|
||||
# Get the training dataset
|
||||
train_dataset = UNetDataset(config.data.train_data_dir, config)
|
||||
distributed_sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=num_processes,
|
||||
rank=global_rank,
|
||||
shuffle=True,
|
||||
seed=config.run.seed,
|
||||
)
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.data.batch_size,
|
||||
shuffle=False,
|
||||
sampler=distributed_sampler,
|
||||
num_workers=config.data.num_workers,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
worker_init_fn=train_dataset.worker_init_fn,
|
||||
)
|
||||
|
||||
# Get the training iteration
|
||||
if config.run.max_train_steps == -1:
|
||||
assert config.run.max_train_epochs != -1
|
||||
config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)
|
||||
|
||||
# Scheduler
|
||||
lr_scheduler = get_scheduler(
|
||||
config.optimizer.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=config.optimizer.lr_warmup_steps,
|
||||
num_training_steps=config.run.max_train_steps,
|
||||
)
|
||||
|
||||
if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
|
||||
lpips_loss_func = lpips.LPIPS(net="vgg").to(device)
|
||||
|
||||
if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
|
||||
trepa_loss_func = TREPALoss(device=device, with_cp=True)
|
||||
|
||||
# Validation pipeline
|
||||
pipeline = LipsyncPipeline(
|
||||
vae=vae,
|
||||
audio_encoder=audio_encoder,
|
||||
unet=unet,
|
||||
scheduler=noise_scheduler,
|
||||
).to(device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# DDP warpper
|
||||
unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader))
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# Train!
|
||||
total_batch_size = config.data.batch_size * num_processes
|
||||
|
||||
if is_main_process:
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Total optimization steps = {config.run.max_train_steps}")
|
||||
global_step = resume_global_step
|
||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(
|
||||
range(0, config.run.max_train_steps),
|
||||
initial=resume_global_step,
|
||||
desc="Steps",
|
||||
disable=not is_main_process,
|
||||
)
|
||||
|
||||
train_step_list = []
|
||||
val_step_list = []
|
||||
sync_conf_list = []
|
||||
|
||||
# Support mixed-precision training
|
||||
scaler = torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None
|
||||
|
||||
for epoch in range(first_epoch, num_train_epochs):
|
||||
train_dataloader.sampler.set_epoch(epoch)
|
||||
unet.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
### >>>> Training >>>> ###
|
||||
|
||||
if config.model.add_audio_layer:
|
||||
if batch["mel"] != []:
|
||||
mel = batch["mel"].to(device, dtype=torch.float16)
|
||||
|
||||
audio_embeds_list = []
|
||||
try:
|
||||
for idx in range(len(batch["video_path"])):
|
||||
video_path = batch["video_path"][idx]
|
||||
start_idx = batch["start_idx"][idx]
|
||||
|
||||
with torch.no_grad():
|
||||
audio_feat = audio_encoder.audio2feat(video_path)
|
||||
audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
|
||||
audio_embeds_list.append(audio_embeds)
|
||||
except Exception as e:
|
||||
logger.info(f"{type(e).__name__} - {e} - {video_path}")
|
||||
continue
|
||||
audio_embeds = torch.stack(audio_embeds_list) # (B, 16, 50, 384)
|
||||
audio_embeds = audio_embeds.to(device, dtype=torch.float16)
|
||||
else:
|
||||
audio_embeds = None
|
||||
|
||||
# Convert videos to latent space
|
||||
gt_pixel_values = batch["gt_pixel_values"].to(device, dtype=torch.float16)
|
||||
masked_pixel_values = batch["masked_pixel_values"].to(device, dtype=torch.float16)
|
||||
masks = batch["masks"].to(device, dtype=torch.float16)
|
||||
ref_pixel_values = batch["ref_pixel_values"].to(device, dtype=torch.float16)
|
||||
|
||||
gt_pixel_values = rearrange(gt_pixel_values, "b f c h w -> (b f) c h w")
|
||||
masked_pixel_values = rearrange(masked_pixel_values, "b f c h w -> (b f) c h w")
|
||||
masks = rearrange(masks, "b f c h w -> (b f) c h w")
|
||||
ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> (b f) c h w")
|
||||
|
||||
with torch.no_grad():
|
||||
gt_latents = vae.encode(gt_pixel_values).latent_dist.sample()
|
||||
masked_latents = vae.encode(masked_pixel_values).latent_dist.sample()
|
||||
ref_latents = vae.encode(ref_pixel_values).latent_dist.sample()
|
||||
|
||||
masks = torch.nn.functional.interpolate(masks, size=config.data.resolution // vae_scale_factor)
|
||||
|
||||
gt_latents = (
|
||||
rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
|
||||
) * vae.config.scaling_factor
|
||||
masked_latents = (
|
||||
rearrange(masked_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames)
|
||||
- vae.config.shift_factor
|
||||
) * vae.config.scaling_factor
|
||||
ref_latents = (
|
||||
rearrange(ref_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
|
||||
) * vae.config.scaling_factor
|
||||
masks = rearrange(masks, "(b f) c h w -> b c f h w", f=config.data.num_frames)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
if config.run.use_mixed_noise:
|
||||
# Refer to the paper: https://arxiv.org/abs/2305.10474
|
||||
noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
|
||||
noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
|
||||
noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)
|
||||
|
||||
noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
|
||||
noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
|
||||
noise = noise_ind + noise_shared
|
||||
else:
|
||||
noise = torch.randn_like(gt_latents)
|
||||
noise = noise[:, :, 0:1].repeat(
|
||||
1, 1, config.data.num_frames, 1, 1
|
||||
) # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716
|
||||
|
||||
bsz = gt_latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each video
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_gt_latents = noise_scheduler.add_noise(gt_latents, noise, timesteps)
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
unet_input = torch.cat([noisy_gt_latents, masks, masked_latents, ref_latents], dim=1)
|
||||
|
||||
# Predict the noise and compute loss
|
||||
# Mixed-precision training
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
|
||||
pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample
|
||||
|
||||
if config.run.recon_loss_weight != 0:
|
||||
recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
recon_loss = 0
|
||||
|
||||
pred_latents = one_step_sampling(noise_scheduler, pred_noise, timesteps, noisy_gt_latents)
|
||||
|
||||
if config.run.pixel_space_supervise:
|
||||
pred_pixel_values = vae.decode(
|
||||
rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
|
||||
+ vae.config.shift_factor
|
||||
).sample
|
||||
|
||||
if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
|
||||
pred_pixel_values_perceptual = pred_pixel_values[:, :, pred_pixel_values.shape[2] // 2 :, :]
|
||||
gt_pixel_values_perceptual = gt_pixel_values[:, :, gt_pixel_values.shape[2] // 2 :, :]
|
||||
lpips_loss = lpips_loss_func(
|
||||
pred_pixel_values_perceptual.float(), gt_pixel_values_perceptual.float()
|
||||
).mean()
|
||||
else:
|
||||
lpips_loss = 0
|
||||
|
||||
if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
|
||||
trepa_pred_pixel_values = rearrange(
|
||||
pred_pixel_values, "(b f) c h w -> b c f h w", f=config.data.num_frames
|
||||
)
|
||||
trepa_gt_pixel_values = rearrange(
|
||||
gt_pixel_values, "(b f) c h w -> b c f h w", f=config.data.num_frames
|
||||
)
|
||||
trepa_loss = trepa_loss_func(trepa_pred_pixel_values, trepa_gt_pixel_values)
|
||||
else:
|
||||
trepa_loss = 0
|
||||
|
||||
if config.model.add_audio_layer and config.run.use_syncnet:
|
||||
if config.run.pixel_space_supervise:
|
||||
if config.data.resolution != syncnet_config.data.resolution:
|
||||
pred_pixel_values = F.interpolate(
|
||||
pred_pixel_values,
|
||||
size=(syncnet_config.data.resolution, syncnet_config.data.resolution),
|
||||
mode="bicubic",
|
||||
)
|
||||
syncnet_input = rearrange(
|
||||
pred_pixel_values, "(b f) c h w -> b (f c) h w", f=config.data.num_frames
|
||||
)
|
||||
else:
|
||||
syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")
|
||||
|
||||
if syncnet_config.data.lower_half:
|
||||
height = syncnet_input.shape[2]
|
||||
syncnet_input = syncnet_input[:, :, height // 2 :, :]
|
||||
ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
|
||||
vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
|
||||
sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
|
||||
else:
|
||||
sync_loss = 0
|
||||
|
||||
loss = (
|
||||
recon_loss * config.run.recon_loss_weight
|
||||
+ sync_loss * config.run.sync_loss_weight
|
||||
+ lpips_loss * config.run.perceptual_loss_weight
|
||||
+ trepa_loss * config.run.trepa_loss_weight
|
||||
)
|
||||
|
||||
train_step_list.append(global_step)
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Backpropagate
|
||||
if config.run.mixed_precision_training:
|
||||
scaler.scale(loss).backward()
|
||||
""" >>> gradient clipping >>> """
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(trainable_params, config.optimizer.max_grad_norm)
|
||||
""" <<< gradient clipping <<< """
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
""" >>> gradient clipping >>> """
|
||||
torch.nn.utils.clip_grad_norm_(trainable_params, config.optimizer.max_grad_norm)
|
||||
""" <<< gradient clipping <<< """
|
||||
optimizer.step()
|
||||
|
||||
# Check the grad of attn blocks for debugging
|
||||
# print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].attn2.to_q.weight.grad)
|
||||
|
||||
lr_scheduler.step()
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
### <<<< Training <<<< ###
|
||||
|
||||
# Save checkpoint and conduct validation
|
||||
if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
|
||||
model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
|
||||
state_dict = {
|
||||
"global_step": global_step,
|
||||
"state_dict": unet.module.state_dict(),
|
||||
}
|
||||
try:
|
||||
torch.save(state_dict, model_save_path)
|
||||
logger.info(f"Saved checkpoint to {model_save_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving model: {e}")
|
||||
|
||||
# Validation
|
||||
logger.info("Running validation... ")
|
||||
|
||||
validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
||||
pipeline(
|
||||
config.data.val_video_path,
|
||||
config.data.val_audio_path,
|
||||
validation_video_out_path,
|
||||
num_frames=config.data.num_frames,
|
||||
num_inference_steps=config.run.inference_steps,
|
||||
guidance_scale=config.run.guidance_scale,
|
||||
weight_dtype=torch.float16,
|
||||
width=config.data.resolution,
|
||||
height=config.data.resolution,
|
||||
mask_image_path=config.data.mask_image_path,
|
||||
)
|
||||
|
||||
logger.info(f"Saved validation video output to {validation_video_out_path}")
|
||||
|
||||
val_step_list.append(global_step)
|
||||
|
||||
if config.model.add_audio_layer and os.path.exists(validation_video_out_path):
|
||||
try:
|
||||
_, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
conf = 0
|
||||
sync_conf_list.append(conf)
|
||||
plot_loss_chart(
|
||||
os.path.join(output_dir, f"sync_conf_results/sync_conf_chart-{global_step}.png"),
|
||||
("Sync confidence", val_step_list, sync_conf_list),
|
||||
)
|
||||
|
||||
logs = {"step_loss": loss.item(), "epoch": epoch}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= config.run.max_train_steps:
|
||||
break
|
||||
|
||||
progress_bar.close()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
# Config file path
|
||||
parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
|
||||
|
||||
args = parser.parse_args()
|
||||
config = OmegaConf.load(args.unet_config_path)
|
||||
config.unet_config_path = args.unet_config_path
|
||||
|
||||
main(config)
|
||||
Reference in New Issue
Block a user