优化代码

This commit is contained in:
Kevin Wong
2026-01-21 10:30:32 +08:00
parent 1890cea3ee
commit cbf840f472
72 changed files with 112399 additions and 92 deletions

View File

@@ -139,6 +139,45 @@ CUDA_VISIBLE_DEVICES=1 python -m scripts.inference \
---
---
## 步骤 7: 性能优化 (预加载模型服务)
为了消除每次生成视频时 30-40秒 的模型加载时间,建议运行常驻服务。
### 1. 安装服务依赖
```bash
conda activate latentsync
pip install fastapi uvicorn
```
### 2. 启动服务
**前台运行 (测试)**:
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
# 启动服务 (端口 8007) - 会自动读取 backend/.env 中的 GPU 配置
python -m scripts.server
```
**后台运行 (推荐)**:
```bash
nohup python -m scripts.server > server.log 2>&1 &
```
### 3. 更新配置
修改 `ViGent2/backend/.env`:
```bash
LATENTSYNC_USE_SERVER=True
```
现在,后端通过 API 调用本地常驻服务,生成速度将显著提升。
---
## 故障排除
### CUDA 内存不足

View File

@@ -0,0 +1,23 @@
audio:
num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality
rescale: true # Whether to rescale audio prior to preprocessing
rescaling_max: 0.9 # Rescaling value
use_lws:
false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_fft is not a multiple of hop_size!!
n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms: null
signal_normalization: true
allow_clipping_in_normalization: true
symmetric_mels: true
max_abs_value: 4.0
preemphasize: true # whether to apply filter
preemphasis: 0.97 # filter coefficient.
min_level_db: -100
ref_level_db: 20
fmin: 55
fmax: 7600

View File

@@ -0,0 +1,12 @@
{
"_class_name": "DDIMScheduler",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"steps_offset": 1,
"trained_betas": null,
"skip_prk_steps": true
}

View File

@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024]
downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 0, 0, 0]
dropout: 0.0
visual_encoder: # input (64, 32, 32)
in_channels: 64
block_out_channels: [64, 128, 256, 256, 512, 1024]
downsample_factors: [2, 2, 2, 1, 2, 2]
attn_blocks: [0, 0, 0, 0, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 1200
batch_size: 120 # 40
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: true
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: false
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42

View File

@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
visual_encoder: # input (48, 128, 256)
in_channels: 48
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 256 # 256
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42

View File

@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 1, 1, 0, 0]
dropout: 0.0
visual_encoder: # input (48, 128, 256)
in_channels: 48
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
attn_blocks: [0, 0, 0, 0, 1, 1, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: checkpoints/stable_syncnet.pt
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 256 # 256
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42

View File

@@ -0,0 +1,44 @@
model:
audio_encoder: # input (1, 80, 80)
in_channels: 1
block_out_channels: [64, 128, 256, 256, 512, 1024]
downsample_factors: [2, 2, 2, 2, 2, 2]
dropout: 0.0
visual_encoder: # input (75, 128, 256)
in_channels: 75
block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 64 # 64
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 25
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42

View File

@@ -0,0 +1,96 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 24
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: false
use_syncnet: false
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: false
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

View File

@@ -0,0 +1,96 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 8
num_workers: 12 # 12
num_frames: 16
resolution: 512
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: false
use_syncnet: false
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: false
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

View File

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
trainable_modules:
- motion_modules.
- attentions.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

View File

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 512
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
trainable_modules:
- motion_modules.
- attentions.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

View File

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 0
inference_steps: 20
trainable_modules:
- motion_modules.
- attn2.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: true
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

View File

@@ -0,0 +1,139 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
from ..utils.util import gather_video_paths_recursively
from ..utils.image_processor import ImageProcessor
from ..utils.audio import melspectrogram
import math
from pathlib import Path
from decord import AudioReader, VideoReader, cpu
class SyncNetDataset(Dataset):
    """Dataset producing (frames, mel window, label) samples for SyncNet training.

    Every sample is drawn at random: ``__getitem__`` ignores the requested index,
    picks a random video, and returns a mel-spectrogram window together with either
    the frames starting at the same position (label 1) or frames starting at a
    different position of the same video (label 0).
    """

    # Device id handed to decord's ``cpu()`` context. ``worker_init_fn`` overrides it
    # per DataLoader worker; the default keeps the dataset usable when that hook is
    # never called (e.g. ``num_workers=0``), where the attribute used to be missing and
    # ``__getitem__`` looped forever on the swallowed AttributeError.
    worker_id: int = 0

    def __init__(self, data_dir: str, fileslist: str, config):
        """Collect the video paths and read the sampling parameters from ``config``.

        Args:
            data_dir: Directory searched recursively for videos; used only when
                ``fileslist`` is empty.
            fileslist: Text file with one video path per line; takes precedence over
                ``data_dir`` when non-empty.
            config: Configuration object exposing ``data.resolution``,
                ``data.num_frames``, ``data.audio_sample_rate``, ``data.video_fps`` and
                ``data.audio_mel_cache_dir``.

        Raises:
            ValueError: If both ``data_dir`` and ``fileslist`` are empty strings.
        """
        if fileslist != "":
            with open(fileslist) as file:
                self.video_paths = [line.rstrip() for line in file]
        elif data_dir != "":
            self.video_paths = gather_video_paths_recursively(data_dir)
        else:
            raise ValueError("data_dir and fileslist cannot be both empty")

        self.resolution = config.data.resolution
        self.num_frames = config.data.num_frames

        # 16 mel frames per 5 video frames. This equals the 80 mel-frames-per-second
        # rate used in crop_audio_window when video_fps is 25.
        self.mel_window_length = math.ceil(self.num_frames / 5 * 16)

        self.audio_sample_rate = config.data.audio_sample_rate
        self.video_fps = config.data.video_fps
        self.image_processor = ImageProcessor(resolution=config.data.resolution)
        self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
        Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)

    def __len__(self):
        """Return the number of known video paths."""
        return len(self.video_paths)

    def read_audio(self, video_path: str):
        """Decode the audio track of ``video_path`` and return its mel-spectrogram tensor."""
        ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
        original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
        return torch.from_numpy(original_mel)

    def crop_audio_window(self, original_mel, start_index):
        """Slice the mel window that starts at video frame ``start_index``.

        The slice is taken along the last axis and a leading axis of size 1 is added.
        The result may be shorter than ``mel_window_length`` near the end of the audio;
        callers are expected to check the length.
        """
        start_idx = int(80.0 * (start_index / float(self.video_fps)))
        end_idx = start_idx + self.mel_window_length
        return original_mel[:, start_idx:end_idx].unsqueeze(0)

    def get_frames(self, video_reader: "VideoReader"):
        """Read one random window of frames plus a second window with a different start.

        The second ("wrong") window only has to start at a different index, so the two
        windows may overlap. The caller must ensure the video is longer than
        ``num_frames``, otherwise no different start index exists.

        Returns:
            Tuple ``(frames, wrong_frames, start_idx)`` where both frame batches are
            NumPy arrays and ``start_idx`` is the first frame index of ``frames``.
        """
        total_num_frames = len(video_reader)
        last_valid_start = total_num_frames - self.num_frames

        start_idx = random.randint(0, last_valid_start)
        frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)

        # Re-draw until the mismatched window starts somewhere else.
        wrong_start_idx = random.randint(0, last_valid_start)
        while wrong_start_idx == start_idx:
            wrong_start_idx = random.randint(0, last_valid_start)
        wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)

        frames = video_reader.get_batch(frames_index).asnumpy()
        wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()

        return frames, wrong_frames, start_idx

    def worker_init_fn(self, worker_id):
        """Record the DataLoader worker id; it is used as the decord device id."""
        self.worker_id = worker_id

    def _load_mel(self, video_path: str):
        """Return the full mel-spectrogram for ``video_path``, using the on-disk cache.

        A cache file that fails to load is reported, deleted and regenerated.
        """
        mel_cache_path = os.path.join(
            self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
        )

        if os.path.isfile(mel_cache_path):
            try:
                return torch.load(mel_cache_path, weights_only=True)
            except Exception as e:
                print(f"{type(e).__name__} - {e} - {mel_cache_path}")
                os.remove(mel_cache_path)

        original_mel = self.read_audio(video_path)
        torch.save(original_mel, mel_cache_path)
        return original_mel

    def __getitem__(self, idx):
        """Return one randomly drawn training sample.

        ``idx`` is ignored: a random video is chosen on every attempt, and the attempt
        is repeated until a sample can be built (videos that are too short, windows
        that run past the audio, and processing errors are skipped).

        Returns:
            dict with keys ``frames`` (processed frames), ``audio_samples`` (mel
            window) and ``y`` (1.0 for matching frames, 0.0 for mismatched ones).

        Raises:
            IndexError: If the dataset holds no video paths, since retrying could
                never succeed.
        """
        if len(self) == 0:
            raise IndexError("SyncNetDataset contains no video paths")

        while True:
            # Reset per attempt so the error handler never sees a stale reader or an
            # unbound path left over from a previous iteration.
            video_path = None
            vr = None
            try:
                idx = random.randint(0, len(self) - 1)

                # Get video file path
                video_path = self.video_paths[idx]

                vr = VideoReader(video_path, ctx=cpu(self.worker_id))

                # Need room for two windows with different start indices.
                if len(vr) < 2 * self.num_frames:
                    continue

                frames, wrong_frames, start_idx = self.get_frames(vr)

                original_mel = self._load_mel(video_path)
                mel = self.crop_audio_window(original_mel, start_idx)

                # Window ran past the end of the audio; try another sample.
                if mel.shape[-1] != self.mel_window_length:
                    continue

                if random.choice([True, False]):
                    y = torch.ones(1).float()
                    chosen_frames = frames
                else:
                    y = torch.zeros(1).float()
                    chosen_frames = wrong_frames

                chosen_frames = self.image_processor.process_images(chosen_frames)

                vr.seek(0)  # avoid memory leak
                break

            except Exception as e:  # Handle the exception of face not detected
                print(f"{type(e).__name__} - {e} - {video_path}")
                if vr is not None:
                    vr.seek(0)  # avoid memory leak

        sample = dict(frames=chosen_frames, audio_samples=mel, y=y)

        return sample

View File

@@ -0,0 +1,152 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu
import torch.nn.functional as F
from pathlib import Path
class UNetDataset(Dataset):
    """Dataset producing ground-truth, masked and reference frame windows for UNet training.

    Every sample is drawn at random: ``__getitem__`` ignores the requested index and
    picks a random video. When audio loading is enabled, the matching mel-spectrogram
    window is returned as well.
    """

    # Device id handed to decord's ``cpu()`` context. ``worker_init_fn`` overrides it
    # per DataLoader worker; the default keeps the dataset usable when that hook is
    # never called (e.g. ``num_workers=0``), where the attribute used to be missing and
    # ``__getitem__`` looped forever on the swallowed AttributeError.
    worker_id: int = 0

    def __init__(self, train_data_dir: str, config):
        """Collect the video paths and read the sampling parameters from ``config``.

        Args:
            train_data_dir: Directory whose ``.mp4`` files are used (non-recursive);
                consulted only when ``config.data.train_fileslist`` is empty.
            config: Configuration object exposing ``data.train_fileslist``,
                ``data.resolution``, ``data.num_frames``, ``data.audio_sample_rate``,
                ``data.video_fps``, ``data.mask_image_path``,
                ``data.audio_mel_cache_dir``, ``model.add_audio_layer`` and
                ``run.use_syncnet``.

        Raises:
            ValueError: If both the files list and ``train_data_dir`` are empty strings.
        """
        if config.data.train_fileslist != "":
            with open(config.data.train_fileslist) as file:
                self.video_paths = [line.rstrip() for line in file]
        elif train_data_dir != "":
            self.video_paths = [
                os.path.join(train_data_dir, file)
                for file in os.listdir(train_data_dir)
                if file.endswith(".mp4")
            ]
        else:
            raise ValueError("data_dir and fileslist cannot be both empty")

        self.resolution = config.data.resolution
        self.num_frames = config.data.num_frames

        # 16 mel frames per 5 video frames. This equals the 80 mel-frames-per-second
        # rate used in crop_audio_window when video_fps is 25.
        self.mel_window_length = math.ceil(self.num_frames / 5 * 16)

        self.audio_sample_rate = config.data.audio_sample_rate
        self.video_fps = config.data.video_fps
        self.image_processor = ImageProcessor(
            self.resolution, mask_image=load_fixed_mask(self.resolution, config.data.mask_image_path)
        )
        # Mel windows are only loaded when both the audio layer and SyncNet are enabled.
        self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
        self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
        Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)

    def __len__(self):
        """Return the number of known video paths."""
        return len(self.video_paths)

    def read_audio(self, video_path: str):
        """Decode the audio track of ``video_path`` and return its mel-spectrogram tensor."""
        ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
        original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
        return torch.from_numpy(original_mel)

    def crop_audio_window(self, original_mel, start_index):
        """Slice the mel window that starts at video frame ``start_index``.

        The slice is taken along the last axis and a leading axis of size 1 is added.
        The result may be shorter than ``mel_window_length`` near the end of the audio;
        callers are expected to check the length.
        """
        start_idx = int(80.0 * (start_index / float(self.video_fps)))
        end_idx = start_idx + self.mel_window_length
        return original_mel[:, start_idx:end_idx].unsqueeze(0)

    def get_frames(self, video_reader: "VideoReader"):
        """Read a random ground-truth window and a non-overlapping reference window.

        The reference start index is re-drawn until it lies at least ``num_frames``
        away from the ground-truth start. The caller must ensure the video is long
        enough for such a start index to exist (``__getitem__`` requires
        ``3 * num_frames`` frames).

        Returns:
            Tuple ``(gt_frames, ref_frames, start_idx)`` where both frame batches are
            NumPy arrays and ``start_idx`` is the first frame index of ``gt_frames``.
        """
        total_num_frames = len(video_reader)
        last_valid_start = total_num_frames - self.num_frames

        start_idx = random.randint(0, last_valid_start)
        gt_frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)

        # Re-draw while the reference window would overlap the ground-truth window.
        ref_start_idx = random.randint(0, last_valid_start)
        while start_idx - self.num_frames < ref_start_idx < start_idx + self.num_frames:
            ref_start_idx = random.randint(0, last_valid_start)
        ref_frames_index = np.arange(ref_start_idx, ref_start_idx + self.num_frames, dtype=int)

        gt_frames = video_reader.get_batch(gt_frames_index).asnumpy()
        ref_frames = video_reader.get_batch(ref_frames_index).asnumpy()

        return gt_frames, ref_frames, start_idx

    def worker_init_fn(self, worker_id):
        """Record the DataLoader worker id; it is used as the decord device id."""
        self.worker_id = worker_id

    def _load_mel(self, video_path: str):
        """Return the full mel-spectrogram for ``video_path``, using the on-disk cache.

        A cache file that fails to load is reported, deleted and regenerated.
        """
        mel_cache_path = os.path.join(
            self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
        )

        if os.path.isfile(mel_cache_path):
            try:
                return torch.load(mel_cache_path, weights_only=True)
            except Exception as e:
                print(f"{type(e).__name__} - {e} - {mel_cache_path}")
                os.remove(mel_cache_path)

        original_mel = self.read_audio(video_path)
        torch.save(original_mel, mel_cache_path)
        return original_mel

    def __getitem__(self, idx):
        """Return one randomly drawn training sample.

        ``idx`` is ignored: a random video is chosen on every attempt, and the attempt
        is repeated until a sample can be built (videos that are too short, windows
        that run past the audio, and processing errors are skipped).

        Returns:
            dict with keys ``gt_pixel_values``, ``masked_pixel_values``,
            ``ref_pixel_values``, ``mel`` (an empty list when audio loading is
            disabled), ``masks``, ``video_path`` and ``start_idx``.

        Raises:
            IndexError: If the dataset holds no video paths, since retrying could
                never succeed.
        """
        if len(self) == 0:
            raise IndexError("UNetDataset contains no video paths")

        while True:
            # Reset per attempt so the error handler never sees a stale reader or an
            # unbound path left over from a previous iteration.
            video_path = None
            vr = None
            try:
                idx = random.randint(0, len(self) - 1)

                # Get video file path
                video_path = self.video_paths[idx]

                vr = VideoReader(video_path, ctx=cpu(self.worker_id))

                # Need room for a ground-truth window and a non-overlapping reference window.
                if len(vr) < 3 * self.num_frames:
                    continue

                gt_frames, ref_frames, start_idx = self.get_frames(vr)

                if self.load_audio_data:
                    original_mel = self._load_mel(video_path)
                    mel = self.crop_audio_window(original_mel, start_idx)

                    # Window ran past the end of the audio; try another sample.
                    if mel.shape[-1] != self.mel_window_length:
                        continue
                else:
                    mel = []

                gt_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
                    gt_frames, affine_transform=False
                )  # (f, c, h, w)
                ref_pixel_values = self.image_processor.process_images(ref_frames)

                vr.seek(0)  # avoid memory leak
                break

            except Exception as e:  # Handle the exception of face not detected
                print(f"{type(e).__name__} - {e} - {video_path}")
                if vr is not None:
                    vr.seek(0)  # avoid memory leak

        sample = dict(
            gt_pixel_values=gt_pixel_values,
            masked_pixel_values=masked_pixel_values,
            ref_pixel_values=ref_pixel_values,
            mel=mel,
            masks=masks,
            video_path=video_path,
            start_idx=start_idx,
        )

        return sample

View File

@@ -0,0 +1,280 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.attention import FeedForward, AdaLayerNorm
from einops import rearrange, repeat
@dataclass
class Transformer3DModelOutput(BaseOutput):
    """Output container returned by ``Transformer3DModel.forward`` when ``return_dict`` is true."""

    # Transformer result; forward() rearranges it to the "b c f h w" layout before returning.
    sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
    """Transformer applied frame by frame to a 5-D ``(b, c, f, h, w)`` tensor.

    The frame axis is folded into the batch axis, each frame is normalized, projected
    to ``num_attention_heads * attention_head_dim`` channels, flattened to a token
    sequence, passed through ``num_layers`` ``BasicTransformerBlock`` layers, projected
    back to ``in_channels`` and added to the input as a residual.
    """

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,
        add_audio_layer=False,
    ):
        """Build the normalization, projection and transformer layers.

        Args:
            num_attention_heads: Number of attention heads per block.
            attention_head_dim: Channels per attention head.
            in_channels: Channel count of the input (and output) tensor.
            num_layers: Number of ``BasicTransformerBlock`` layers.
            dropout: Dropout probability forwarded to the blocks.
            norm_num_groups: Group count of the input ``GroupNorm``.
            cross_attention_dim: Feature size of ``encoder_hidden_states``.
            attention_bias: Whether the attention projections use a bias.
            activation_fn: Activation name forwarded to the feed-forward layers.
            num_embeds_ada_norm: Forwarded to the blocks; selects ``AdaLayerNorm`` when set.
            use_linear_projection: Use ``nn.Linear`` instead of 1x1 ``nn.Conv2d`` for
                the input/output projections.
            only_cross_attention: Accepted and stored in the config, but not used here.
            upcast_attention: Forwarded to the attention layers.
            add_audio_layer: Whether each block gets a cross-attention layer.
        """
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    add_audio_layer=add_audio_layer,
                )
                for _ in range(num_layers)
            ]
        )

        # Define output layers
        if use_linear_projection:
            # Maps the block output (inner_dim) back to in_channels. The arguments used to
            # be swapped, which only worked when both sizes happened to be equal; in that
            # case the parameter shapes are unchanged, so existing checkpoints still load.
            self.proj_out = nn.Linear(inner_dim, in_channels)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """Run the transformer over every frame of ``hidden_states``.

        Args:
            hidden_states: 5-D tensor laid out as ``(b, c, f, h, w)``.
            encoder_hidden_states: Optional conditioning passed to every block.
            timestep: Optional timestep passed to every block.
            return_dict: Return a ``Transformer3DModelOutput`` when true, otherwise a
                one-element tuple.

        Returns:
            The input plus the transformer residual, in ``(b, c, f, h, w)`` layout,
            wrapped as described by ``return_dict``.
        """
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        # Fold frames into the batch axis so each frame is processed as an image.
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, width = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            # Conv projection first, then flatten the spatial grid into tokens.
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
        else:
            # Flatten to tokens first (still `channel` wide), then project linearly.
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length,
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
            hidden_states = self.proj_out(hidden_states)
        else:
            # proj_out returns tokens of width `channel`, so the grid is rebuilt with it.
            hidden_states = self.proj_out(hidden_states)
            hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual

        # Restore the separate frame axis.
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
    """Self-attention, optional cross-attention and feed-forward, each with a residual add.

    The cross-attention stage exists only when ``add_audio_layer`` is true, and it
    runs only when ``encoder_hidden_states`` is supplied to ``forward``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        upcast_attention: bool = False,
        add_audio_layer=False,
    ):
        """Create the normalization, attention and feed-forward sub-layers.

        Args:
            dim: Token feature size.
            num_attention_heads: Number of attention heads.
            attention_head_dim: Channels per attention head.
            dropout: Dropout probability for attention and feed-forward layers.
            cross_attention_dim: Feature size of the cross-attention context.
            activation_fn: Activation name for the feed-forward layer.
            num_embeds_ada_norm: When not ``None``, ``AdaLayerNorm`` replaces
                ``nn.LayerNorm`` in front of the attention layers.
            attention_bias: Whether the attention projections use a bias.
            upcast_attention: Forwarded to the attention layers.
            add_audio_layer: Whether to build the cross-attention layer (``attn2``).
        """
        super().__init__()
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.add_audio_layer = add_audio_layer

        def make_pre_attention_norm():
            # Both attention stages use the same kind of normalization layer.
            if self.use_ada_layer_norm:
                return AdaLayerNorm(dim, num_embeds_ada_norm)
            return nn.LayerNorm(dim)

        # Self-attn
        self.norm1 = make_pre_attention_norm()
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
        )

        # Cross-attn
        if not add_audio_layer:
            self.attn2 = None
        else:
            self.norm2 = make_pre_attention_norm()
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

    def forward(
        self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
    ):
        """Apply the block to ``hidden_states`` and return the updated tensor.

        Args:
            hidden_states: Token tensor fed to the attention layers.
            encoder_hidden_states: Optional cross-attention context. A 4-D tensor is
                flattened from ``(b, f, s, d)`` to ``((b f), s, d)`` first.
            timestep: Passed to the ``AdaLayerNorm`` layers when they are in use.
            attention_mask: Forwarded to both attention layers.
            video_length: Accepted for interface compatibility; not used here.
        """
        # Self-attention stage
        if self.use_ada_layer_norm:
            normed = self.norm1(hidden_states, timestep)
        else:
            normed = self.norm1(hidden_states)
        hidden_states = self.attn1(normed, attention_mask=attention_mask) + hidden_states

        # Cross-attention stage, only when the layer exists and a context was given
        run_cross_attention = self.attn2 is not None and encoder_hidden_states is not None
        if run_cross_attention:
            if encoder_hidden_states.dim() == 4:
                encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")

            if self.use_ada_layer_norm:
                normed = self.norm2(hidden_states, timestep)
            else:
                normed = self.norm2(hidden_states)

            cross_out = self.attn2(
                normed, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
            )
            hidden_states = cross_out + hidden_states

        # Feed-forward
        ff_out = self.ff(self.norm3(hidden_states))
        hidden_states = ff_out + hidden_states

        return hidden_states
class Attention(nn.Module):
    """Multi-head attention built on ``F.scaled_dot_product_attention``.

    Acts as self-attention when ``encoder_hidden_states`` is ``None`` and as
    cross-attention otherwise. Queries are projected from ``query_dim`` and
    keys/values from ``cross_attention_dim`` (which defaults to ``query_dim``).

    ``upcast_attention``, ``upcast_softmax`` and ``scale`` are stored as
    attributes but are not read by ``forward``.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        norm_num_groups: Optional[int] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.scale = dim_head**-0.5
        self.heads = heads
        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None
        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        # to_out[0] is the output projection, to_out[1] the dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

    def split_heads(self, tensor):
        """Reshape ``(batch, seq, heads * head_dim)`` to ``(batch, heads, seq, head_dim)``."""
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, self.heads, dim // self.heads)
        tensor = tensor.permute(0, 2, 1, 3)
        return tensor

    def concat_heads(self, tensor):
        """Inverse of ``split_heads``: ``(batch, heads, seq, head_dim)`` to ``(batch, seq, heads * head_dim)``."""
        batch_size, heads, seq_len, head_dim = tensor.shape
        tensor = tensor.permute(0, 2, 1, 3)
        tensor = tensor.reshape(batch_size, seq_len, heads * head_dim)
        return tensor

    def _broadcast_attention_mask(self, attention_mask, query):
        """Make ``attention_mask`` broadcastable to ``(batch, heads, q_len, k_len)``.

        Accepted layouts:
          * ``(batch, k_len)``                  -> ``(batch, 1, 1, k_len)``
          * ``(batch, q_len | 1, k_len)``       -> ``(batch, 1, q_len | 1, k_len)``
          * ``(batch * heads, q_len | 1, k_len)`` (legacy per-head layout)
                                                -> ``(batch, heads, q_len | 1, k_len)``
          * 4-D masks are passed through unchanged.

        Floating-point (additive) masks are cast to the query dtype, since the
        fused kernel expects the two to match; boolean masks are left as-is.
        """
        if attention_mask is None:
            return None
        batch_size = query.shape[0]
        if attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            if self.heads > 1 and attention_mask.shape[0] == batch_size * self.heads:
                attention_mask = attention_mask.reshape(batch_size, self.heads, *attention_mask.shape[1:])
            else:
                attention_mask = attention_mask[:, None, :, :]
        if attention_mask.is_floating_point():
            attention_mask = attention_mask.to(query.dtype)
        return attention_mask

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Attend from ``hidden_states`` to ``encoder_hidden_states`` (or to itself).

        Args:
            hidden_states: tensor of shape ``(batch, q_len, query_dim)``.
            encoder_hidden_states: optional ``(batch, k_len, cross_attention_dim)``
                context; when ``None`` the layer performs self-attention.
            attention_mask: optional boolean or additive float mask, see
                ``_broadcast_attention_mask`` for the accepted shapes.

        Returns:
            Tensor of shape ``(batch, q_len, query_dim)``.
        """
        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = self.to_q(hidden_states)
        query = self.split_heads(query)
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)
        key = self.split_heads(key)
        value = self.split_heads(value)
        # The previous code padded/repeated the mask for a (batch * heads, seq, dim)
        # layout, but query/key/value are (batch, heads, seq, head_dim) here, so any
        # non-None mask ended up with a shape the kernel could not broadcast.
        attention_mask = self._broadcast_attention_mask(attention_mask, query)
        # PyTorch's native fused scaled-dot-product attention
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        hidden_states = self.concat_heads(hidden_states)
        # linear proj
        hidden_states = self.to_out[0](hidden_states)
        # dropout
        hidden_states = self.to_out[1](hidden_states)
        return hidden_states

View File

@@ -0,0 +1,313 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
# Note: the motion module is not used in the final version of LatentSync.
# The project started from the AnimateDiff codebase and experimented with its motion module,
# but the results were poor, so we decided to keep the code here only for possible future use.
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.attention import FeedForward
from .attention import Attention
from einops import rearrange, repeat
import math
from .utils import zero_module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    """Output container exposing a single ``sample`` tensor field."""

    # Annotated as a float tensor; its shape is not established in this file.
    sample: torch.FloatTensor
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    """Build the motion module selected by ``motion_module_type``.

    Args:
        in_channels: channel count forwarded to the module constructor.
        motion_module_type: name of the module to build; only ``"Vanilla"``
            is supported.
        motion_module_kwargs: extra keyword arguments forwarded to the module
            constructor.

    Returns:
        A ``VanillaTemporalModule`` instance.

    Raises:
        ValueError: if ``motion_module_type`` is not a supported name.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    # Name the offending value so a bad config is easy to diagnose.
    raise ValueError(f"Unknown motion_module_type {motion_module_type!r}; only 'Vanilla' is supported.")
class VanillaTemporalModule(nn.Module):
    """Thin wrapper around ``TemporalTransformer3DModel``.

    When ``zero_initialize`` is true the transformer's output projection is
    passed through ``zero_module`` right after construction.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()
        # Per-head width: channels split across heads, then divided by the extra factor.
        head_dim = in_channels // num_attention_heads // temporal_attention_dim_div
        transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=head_dim,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )
        self.temporal_transformer = transformer
        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        """Run the temporal transformer on ``input_tensor``.

        ``temb`` and ``anchor_frame_idx`` are accepted but not used.
        """
        return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
class TemporalTransformer3DModel(nn.Module):
    """Stack of ``TemporalTransformerBlock`` layers applied to a 5-D tensor.

    The input is group-normalised, flattened to tokens, projected to the inner
    width, run through the blocks, projected back and added to the input.
    """

    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,
        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        # Every block is built from the same configuration.
        block_config = dict(
            dim=inner_dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            attention_block_types=attention_block_types,
            dropout=dropout,
            norm_num_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            activation_fn=activation_fn,
            attention_bias=attention_bias,
            upcast_attention=upcast_attention,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )
        self.transformer_blocks = nn.ModuleList(
            [TemporalTransformerBlock(**block_config) for _ in range(num_layers)]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Transform a ``(b, c, f, h, w)`` tensor and return one of the same layout.

        ``attention_mask`` is accepted but is not forwarded to the blocks.
        """
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]

        # Fold frames into the batch axis so 2-D ops apply per frame.
        frames = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        batch, channel, height, width = frames.shape
        residual = frames

        # (b f, c, h, w) -> (b f, h*w, c) token layout.
        tokens = self.norm(frames).permute(0, 2, 3, 1).reshape(batch, height * width, channel)
        tokens = self.proj_in(tokens)

        # Transformer blocks
        for block in self.transformer_blocks:
            tokens = block(tokens, encoder_hidden_states=encoder_hidden_states, video_length=video_length)

        # Back to image layout, add the residual, then unfold the frame axis.
        tokens = self.proj_out(tokens)
        restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
        output = restored + residual
        return rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
class TemporalTransformerBlock(nn.Module):
    """A sequence of ``VersatileAttention`` layers followed by a feed-forward layer.

    One attention layer (with its own ``LayerNorm``) is created per entry of
    ``attention_block_types``. The part of each name before the first ``_`` is
    the attention mode; names ending in ``_Cross`` receive ``cross_attention_dim``.
    """

    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()
        self.attention_blocks = nn.ModuleList()
        self.norms = nn.ModuleList()

        for block_name in attention_block_types:
            mode = block_name.split("_")[0]
            context_dim = cross_attention_dim if block_name.endswith("_Cross") else None
            self.attention_blocks.append(
                VersatileAttention(
                    attention_mode=mode,
                    cross_attention_dim=context_dim,
                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            self.norms.append(nn.LayerNorm(dim))

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Apply every attention layer, then the feed-forward layer, each with a residual.

        ``attention_mask`` is accepted but not passed to the attention layers.
        """
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            # Only cross-attention layers see the encoder states.
            context = encoder_hidden_states if attention_block.is_cross_attention else None
            attended = attention_block(
                norm(hidden_states),
                encoder_hidden_states=context,
                video_length=video_length,
            )
            hidden_states = attended + hidden_states

        return self.ff(self.ff_norm(hidden_states)) + hidden_states
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added along dimension 1 of the input.

    A table ``pe`` of shape ``(1, max_len, d_model)`` is precomputed and stored
    as a buffer: even feature indices hold sines, odd indices hold cosines.

    Args:
        d_model: feature width of the inputs (may be odd or even).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the table; inputs longer than this
            along dimension 1 cannot be encoded.
    """

    def __init__(self, d_model, dropout=0.0, max_len=24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so the
        # cosine part must drop the last frequency (previously a shape error).
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions, then apply dropout."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
class VersatileAttention(Attention):
    """``Attention`` applied along the frame axis instead of the token axis.

    Only ``attention_mode == "Temporal"`` is supported. The remaining
    positional/keyword arguments are forwarded to ``Attention.__init__``.

    Args:
        attention_mode: must be ``"Temporal"``.
        cross_frame_attention_mode: accepted but not used by this class.
        temporal_position_encoding: when true, a ``PositionalEncoding`` is
            added to the frame sequence before attention.
        temporal_position_encoding_max_len: table length of that encoding.
    """

    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"
        self.attention_mode = attention_mode
        # ``cross_attention_dim`` is optional for the parent class, so it may be
        # absent from kwargs; treat a missing key like an explicit ``None``.
        self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
        self.pos_encoder = (
            PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
            if (temporal_position_encoding and attention_mode == "Temporal")
            else None
        )

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        """Attend across frames for every spatial token.

        ``hidden_states`` arrives as ``((b f), s, c)``; it is regrouped to
        ``((b s), f, c)`` so the sequence axis is the frame axis, run through
        the parent attention, and regrouped back to ``((b f), s, c)``.

        Raises:
            NotImplementedError: if ``attention_mode`` is not ``"Temporal"``.
        """
        if self.attention_mode != "Temporal":
            raise NotImplementedError

        s = hidden_states.shape[1]
        hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=video_length)
        if self.pos_encoder is not None:
            hidden_states = self.pos_encoder(hidden_states)

        # The blocks visible in this file are all "Temporal_Self", so this
        # context branch is normally not taken.
        if encoder_hidden_states is not None:
            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b s) n c", s=s)

        # The attention computation itself is identical to the parent class;
        # delegate instead of keeping a second copy of it in sync.
        hidden_states = super().forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )

        hidden_states = rearrange(hidden_states, "(b s) f c -> (b f) s c", s=s)
        return hidden_states

View File

@@ -0,0 +1,228 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
    """2-D convolution applied independently to every frame of a 5-D tensor."""

    def forward(self, x):
        """Convolve a ``(b, c, f, h, w)`` tensor frame by frame.

        Returns a ``(b, c_out, f, h_out, w_out)`` tensor.
        """
        batch, channels, frames, height, width = x.shape
        # (b, c, f, h, w) -> (b*f, c, h, w): frames become extra batch entries.
        stacked = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        convolved = super().forward(stacked)
        # (b*f, c_out, h_out, w_out) -> (b, c_out, f, h_out, w_out)
        unstacked = convolved.reshape(batch, frames, *convolved.shape[1:])
        return unstacked.permute(0, 2, 1, 3, 4)
class InflatedGroupNorm(nn.GroupNorm):
    """Group normalisation applied independently to every frame of a 5-D tensor."""

    def forward(self, x):
        """Normalise a ``(b, c, f, h, w)`` tensor frame by frame; the layout is preserved."""
        batch, channels, frames, height, width = x.shape
        # (b, c, f, h, w) -> (b*f, c, h, w): statistics are computed per frame.
        stacked = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        normed = super().forward(stacked)
        # (b*f, c, h, w) -> (b, c, f, h, w)
        return normed.reshape(batch, frames, channels, height, width).permute(0, 2, 1, 3, 4)
class Upsample3D(nn.Module):
    """Nearest-neighbour upsampling of a 5-D tensor, optionally followed by a convolution.

    Args:
        channels: expected size of dimension 1 of the input.
        use_conv: when true, an ``InflatedConv3d`` (kernel 3, padding 1) is
            applied after interpolation; when false the interpolated tensor
            is returned as-is.
        use_conv_transpose: not supported; raises ``NotImplementedError``.
        out_channels: output channels of the convolution (defaults to ``channels``).
        name: stored on the instance; not otherwise used here.
    """

    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
        if use_conv_transpose:
            raise NotImplementedError
        # Always define ``self.conv`` so ``forward`` can test for it; previously
        # it was left undefined when ``use_conv`` was false and ``forward`` crashed.
        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
        else:
            self.conv = None

    def forward(self, hidden_states, output_size=None):
        """Upsample ``hidden_states``.

        Without ``output_size`` the last two dimensions are doubled and
        dimension 2 is kept (``scale_factor=[1.0, 2.0, 2.0]``); with
        ``output_size`` the interpolation targets exactly that size.
        """
        assert hidden_states.shape[1] == self.channels
        if self.use_conv_transpose:
            raise NotImplementedError
        # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)
        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()
        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)
        if self.conv is not None:
            hidden_states = self.conv(hidden_states)
        return hidden_states
class Downsample3D(nn.Module):
    """Stride-2 ``InflatedConv3d`` downsampling; only the convolutional variant exists.

    Constructing with ``use_conv=False`` raises ``NotImplementedError``.
    """

    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        self.name = name

        # Guard clause: there is no non-convolutional fallback.
        if not use_conv:
            raise NotImplementedError
        self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=2, padding=padding)

    def forward(self, hidden_states):
        """Downsample ``hidden_states``; ``padding == 0`` is not supported."""
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError
        return self.conv(hidden_states)
class ResnetBlock3D(nn.Module):
    """Residual block: norm -> act -> conv, time-embedding injection, norm -> act -> dropout -> conv.

    The convolutions are ``InflatedConv3d``. The normalisation layers are
    ``InflatedGroupNorm`` when ``use_inflated_groupnorm`` is true, otherwise
    plain ``GroupNorm``.

    Args:
        in_channels / out_channels: channel counts (``out_channels`` defaults
            to ``in_channels``).
        temb_channels: width of the time embedding, or ``None`` to build no
            time projection.
        non_linearity: ``"swish"``, ``"mish"`` or ``"silu"``.
        time_embedding_norm: ``"default"`` (additive) or ``"scale_shift"``.
        output_scale_factor: the summed output is divided by this value.
        use_in_shortcut: force or suppress the 1x1 shortcut convolution; by
            default it is created when the channel counts differ.
        pre_norm: accepted for compatibility; the attribute is always ``True``.

    Raises:
        ValueError: for an unknown ``non_linearity`` or ``time_embedding_norm``.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=False,
    ):
        super().__init__()
        # Fail at construction time; previously an unknown name left
        # ``self.nonlinearity`` undefined and only surfaced inside ``forward``.
        if non_linearity not in ("swish", "mish", "silu"):
            raise ValueError(
                f"unknown non_linearity : {non_linearity!r}; expected 'swish', 'mish' or 'silu'"
            )
        # The ``pre_norm`` argument has always been overridden; keep that behaviour.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor
        if groups_out is None:
            groups_out = groups
        assert use_inflated_groupnorm is not None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None
        # NOTE(review): with temb_channels=None and "scale_shift" the width used
        # below is undefined — confirm that combination is never configured.
        if self.time_embedding_norm == "scale_shift":
            self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
        else:
            self.double_len_linear = None
        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if non_linearity == "swish":
            # Plain function reference instead of a lambda wrapper (same result).
            self.nonlinearity = F.silu
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()
        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        """Apply the block to ``input_tensor`` using the optional embedding ``temb``.

        A 2-D ``temb`` is projected and broadcast over the three trailing
        dimensions. Any other ``temb`` is treated as 3-D: dimensions 1 and 2
        are swapped for the projection and swapped back afterwards, then it is
        broadcast over the two trailing dimensions.
        """
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)
        if temb is not None:
            if temb.dim() == 2:
                # input (1, 1280)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                temb = temb[:, :, None, None, None]  # unsqueeze
            else:
                # input (1, 1280, 16)
                temb = temb.permute(0, 2, 1)
                temb = self.time_emb_proj(self.nonlinearity(temb))
                if self.double_len_linear is not None:
                    temb = self.double_len_linear(self.nonlinearity(temb))
                temb = temb.permute(0, 2, 1)
                temb = temb[:, :, :, None, None]
        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)
        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)
        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
        return output_tensor
class Mish(torch.nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, hidden_states):
        gate = torch.tanh(F.softplus(hidden_states))
        return hidden_states * gate

View File

@@ -0,0 +1,233 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from einops import rearrange
from torch.nn import functional as F
from .attention import Attention
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import FeedForward
from einops import rearrange
class StableSyncNet(nn.Module):
    """Two ``DownEncoder2D`` towers producing unit-length visual and audio embeddings.

    ``config`` must contain ``"audio_encoder"`` and ``"visual_encoder"``
    sections, each with ``in_channels``, ``block_out_channels``,
    ``downsample_factors``, ``dropout`` and ``attn_blocks``. The module puts
    itself in eval mode at the end of construction.
    """

    def __init__(self, config, gradient_checkpointing=False):
        super().__init__()
        self.audio_encoder = self._build_encoder(config["audio_encoder"], gradient_checkpointing)
        self.visual_encoder = self._build_encoder(config["visual_encoder"], gradient_checkpointing)
        self.eval()

    @staticmethod
    def _build_encoder(encoder_config, gradient_checkpointing):
        """Create one encoder tower from its config section."""
        return DownEncoder2D(
            in_channels=encoder_config["in_channels"],
            block_out_channels=encoder_config["block_out_channels"],
            downsample_factors=encoder_config["downsample_factors"],
            dropout=encoder_config["dropout"],
            attn_blocks=encoder_config["attn_blocks"],
            gradient_checkpointing=gradient_checkpointing,
        )

    def forward(self, image_sequences, audio_sequences):
        """Return ``(vision_embeds, audio_embeds)``, each flattened per sample and L2-normalised."""
        vision_features = self.visual_encoder(image_sequences)  # (b, c, 1, 1)
        audio_features = self.audio_encoder(audio_sequences)  # (b, c, 1, 1)

        # Collapse everything after the batch axis: (b, c)
        vision_flat = vision_features.reshape(vision_features.shape[0], -1)
        audio_flat = audio_features.reshape(audio_features.shape[0], -1)

        # Make them unit vectors
        vision_embeds = F.normalize(vision_flat, p=2, dim=1)
        audio_embeds = F.normalize(audio_flat, p=2, dim=1)
        return vision_embeds, audio_embeds
class ResnetBlock2D(nn.Module):
    """2-D residual block with an optional strided-convolution downsample.

    Args:
        in_channels / out_channels: channel counts; a 1x1 shortcut
            convolution is added when they differ.
        dropout: dropout probability before the second convolution.
        norm_num_groups: group count of both ``GroupNorm`` layers.
        eps: epsilon of both ``GroupNorm`` layers.
        act_fn: ``"relu"`` or ``"silu"``.
        downsample_factor: ``1`` disables downsampling; an int or a
            2-element list/tuple is used as the stride of a kernel-3,
            padding-0 convolution applied after explicit zero padding.

    Raises:
        ValueError: if ``act_fn`` is not a supported name.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        eps: float = 1e-6,
        act_fn: str = "silu",
        downsample_factor=2,
    ):
        super().__init__()
        # Fail at construction time; previously an unknown name left
        # ``self.act_fn`` undefined and only surfaced inside ``forward``.
        if act_fn not in ("relu", "silu"):
            raise ValueError(f"unknown act_fn : {act_fn!r}; expected 'relu' or 'silu'")
        self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        else:
            self.act_fn = nn.SiLU()
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None
        if isinstance(downsample_factor, list):
            downsample_factor = tuple(downsample_factor)
        if downsample_factor == 1:
            self.downsample_conv = None
        else:
            self.downsample_conv = nn.Conv2d(
                out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
            )
            # F.pad order is (left, right, top, bottom), i.e. last dimension first.
            self.pad = (0, 1, 0, 1)
            if isinstance(downsample_factor, tuple):
                # An axis with stride 1 is padded on both sides so its size is kept.
                if downsample_factor[0] == 1:
                    self.pad = (0, 1, 1, 1)  # The padding order is from back to front
                elif downsample_factor[1] == 1:
                    self.pad = (1, 1, 0, 1)

    def forward(self, input_tensor):
        """Apply the residual block, then the optional padded downsample convolution."""
        hidden_states = input_tensor
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)
        hidden_states += input_tensor
        if self.downsample_conv is not None:
            hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
            hidden_states = self.downsample_conv(hidden_states)
        return hidden_states
class AttentionBlock2D(nn.Module):
    """Self-attention plus feed-forward over the spatial positions of a 4-D feature map.

    The map is group-normalised, passed through a 1x1 convolution, flattened
    to ``(b, h*w, c)`` tokens, run through 8-head self-attention and a GEGLU
    feed-forward layer (both residual), reshaped back, passed through a second
    1x1 convolution and added to the input.
    """

    def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
        super().__init__()
        self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
        self.norm2 = nn.LayerNorm(query_dim)
        self.norm3 = nn.LayerNorm(query_dim)
        self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
        self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
        self.attn = Attention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)

    def forward(self, hidden_states):
        """Transform a ``(b, c, h, w)`` tensor and return one of the same shape."""
        assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
        _, _, height, width = hidden_states.shape
        residual = hidden_states

        # Image layout -> token layout.
        tokens = self.conv_in(self.norm1(hidden_states))
        tokens = rearrange(tokens, "b c h w -> b (h w) c")

        # Self-attention and feed-forward, each with a residual connection.
        tokens = self.attn(self.norm2(tokens), attention_mask=None) + tokens
        tokens = self.ff(self.norm3(tokens)) + tokens

        # Token layout -> image layout.
        feature_map = rearrange(tokens, "b (h w) c -> b c h w", h=height, w=width).contiguous()
        feature_map = self.conv_out(feature_map)
        return feature_map + residual
class DownEncoder2D(nn.Module):
    """Convolutional encoder: input conv, a chain of down blocks, GroupNorm and ReLU.

    One ``ResnetBlock2D`` is created per entry of ``block_out_channels``, using
    the matching entry of ``downsample_factors``. It is followed by an
    ``AttentionBlock2D`` when the matching entry of ``attn_blocks`` equals 1.
    With ``gradient_checkpointing`` enabled every block is run through
    ``torch.utils.checkpoint``.

    ``layers_per_block`` is stored but does not affect construction here.
    """

    def __init__(
        self,
        in_channels=4 * 16,
        block_out_channels=[64, 128, 256, 256],
        downsample_factors=[2, 2, 2, 2],
        layers_per_block=2,
        norm_num_groups=32,
        attn_blocks=[1, 1, 1, 1],
        dropout: float = 0.0,
        act_fn="silu",
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.gradient_checkpointing = gradient_checkpointing

        # in
        first_width = block_out_channels[0]
        self.conv_in = nn.Conv2d(in_channels, first_width, kernel_size=3, stride=1, padding=1)

        # down
        self.down_blocks = nn.ModuleList([])
        previous_width = first_width
        for index, width in enumerate(block_out_channels):
            self.down_blocks.append(
                ResnetBlock2D(
                    in_channels=previous_width,
                    out_channels=width,
                    downsample_factor=downsample_factors[index],
                    norm_num_groups=norm_num_groups,
                    dropout=dropout,
                    act_fn=act_fn,
                )
            )
            if attn_blocks[index] == 1:
                self.down_blocks.append(AttentionBlock2D(query_dim=width, dropout=dropout))
            previous_width = width

        # out
        self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.act_fn_out = nn.ReLU()

    def _run_block(self, block, hidden_states):
        """Run one block, through activation checkpointing when it is enabled."""
        if self.gradient_checkpointing:
            return torch.utils.checkpoint.checkpoint(block, hidden_states, use_reentrant=False)
        return block(hidden_states)

    def forward(self, hidden_states):
        """Encode ``hidden_states`` and return the normalised, ReLU-activated features."""
        hidden_states = self.conv_in(hidden_states)

        # down
        for down_block in self.down_blocks:
            hidden_states = self._run_block(down_block, hidden_states)

        # post-process
        return self.act_fn_out(self.norm_out(hidden_states))

View File

@@ -0,0 +1,512 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm
from ..utils.util import zero_rank_log
from .utils import zero_module
# Module-level logger, obtained from the diffusers logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@dataclass
class UNet3DConditionOutput(BaseOutput):
    """Output container exposing a single ``sample`` tensor field."""

    # Presumably the tensor returned by UNet3DConditionModel.forward — TODO confirm
    # (that method is outside the visible part of the file).
    sample: torch.FloatTensor
class UNet3DConditionModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
mid_block_type: str = "UNetMidBlock3DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
use_inflated_groupnorm=False,
# Additional
use_motion_module=False,
motion_module_resolutions=(1, 2, 4, 8),
motion_module_mid_block=False,
motion_module_decoder_only=False,
motion_module_type=None,
motion_module_kwargs={},
add_audio_layer=False,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
self.use_motion_module = use_motion_module
self.add_audio_layer = add_audio_layer
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
res = 2**i
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module
and (res in motion_module_resolutions)
and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock3DCrossAttn":
self.mid_block = UNetMidBlock3DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the videos
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
res = 2 ** (3 - i)
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if use_inflated_groupnorm:
self.conv_norm_out = InflatedGroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
else:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
def set_attention_slice(self, slice_size):
    r"""
    Configure sliced attention on every sliceable layer found below this model.

    With sliced attention, a layer splits its input into slices and computes attention in several steps,
    which saves some memory in exchange for a small speed decrease.

    Args:
        slice_size (`str` or `int` or `list(int)`):
            `"auto"` gives each layer half of its `sliceable_head_dim`, `"max"` gives every layer a slice
            size of 1, a single number is applied to every layer, and a list provides one entry per
            sliceable layer in traversal order (an entry may be `None`).

    Raises:
        ValueError: if the number of entries does not match the number of sliceable layers, or an entry is
            larger than the `sliceable_head_dim` of its layer.
    """

    def iter_sliceable(module: torch.nn.Module):
        # Pre-order walk: a module exposing `set_attention_slice` comes before its descendants.
        if hasattr(module, "set_attention_slice"):
            yield module
        for child in module.children():
            yield from iter_sliceable(child)

    # The model itself is not a candidate; only what hangs below its direct children is.
    sliceable_layers = [layer for top in self.children() for layer in iter_sliceable(top)]
    sliceable_head_dims = [layer.sliceable_head_dim for layer in sliceable_layers]
    layer_count = len(sliceable_head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        slice_size = [head_dim // 2 for head_dim in sliceable_head_dims]
    elif slice_size == "max":
        # smallest slice possible
        slice_size = [1] * layer_count
    elif not isinstance(slice_size, list):
        # a single value is shared by all layers
        slice_size = [slice_size] * layer_count

    if len(slice_size) != layer_count:
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
        )

    for size, dim in zip(slice_size, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand each layer its own entry, in the same order the layers were discovered.
    for layer, size in zip(sliceable_layers, slice_size):
        layer.set_attention_slice(size)
def _set_gradient_checkpointing(self, module, value=False):
    """Set ``module.gradient_checkpointing`` to ``value`` when ``module`` is one of the 3D down/up blocks.

    Modules of any other type are left untouched.
    """
    checkpointable_blocks = (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
    if not isinstance(module, checkpointable_blocks):
        return
    module.gradient_checkpointing = value
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor = None,
    class_labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    # support controlnet
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
    r"""
    Run the UNet: time (and optional class) embedding, input convolution, down blocks, mid block, up blocks
    and the output normalisation / activation / convolution.

    Args:
        sample (`torch.FloatTensor`): noisy input tensor.
            NOTE(review): presumably laid out as (batch, channel, frames, height, width), since 4-D
            ControlNet residuals are unsqueezed at dim 2 before being added — confirm against callers.
        timestep (`torch.Tensor` or `float` or `int`): a Python number, a 0-dim tensor or a tensor; it is
            expanded to `sample.shape[0]` before being embedded.
        encoder_hidden_states (`torch.Tensor`): conditioning passed to every down block, the mid block and
            every up block.
        class_labels (`torch.Tensor`, *optional*): required when `self.class_embedding` is set; embedded and
            added to the time embedding.
        attention_mask (`torch.Tensor`, *optional*): when given, converted to `(1 - mask) * -10000.0`,
            unsqueezed at dim 1 and passed to the mid block and to blocks with `has_cross_attention`.
        down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): residuals added, index by
            index, to the skip tensors collected from the down path.
        mid_block_additional_residual (`torch.Tensor`, *optional*): residual added to the mid block output.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `UNet3DConditionOutput` instead of a plain tuple.

    Returns:
        `UNet3DConditionOutput` or `tuple`:
        `UNet3DConditionOutput` if `return_dict` is True, otherwise a 1-tuple whose only element is the
        sample tensor.

    Raises:
        ValueError: if `self.class_embedding` is set and `class_labels` is `None`.
    """
    # By default samples have to be at least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = False
    upsample_size = None

    # Only the last two dimensions of `sample` are checked against the upsampling factor.
    if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
        logger.info("Forward upsample size to force interpolation output size.")
        forward_upsample_size = True

    # prepare attention_mask: entries equal to 1 become 0, entries equal to 0 become -10000
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # This would be a good case for the `match` statement (Python 3.10+)
        # On "mps" devices 32-bit dtypes are used instead of 64-bit ones.
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        # 0-dim tensor: add a leading dimension and move it to the sample's device
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=self.dtype)
    emb = self.time_embedding(t_emb)

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when num_class_embeds > 0")

        # With the "timestep" class-embed type the labels go through the same projection as timesteps.
        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

    # pre-process
    sample = self.conv_in(sample)

    # down: collect every block's residual outputs; the up path consumes them from the end
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
            )
        else:
            sample, res_samples = downsample_block(
                hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
            )

        down_block_res_samples += res_samples

    # support controlnet: the tuple is turned into a list so entries can be replaced by index
    down_block_res_samples = list(down_block_res_samples)
    if down_block_additional_residuals is not None:
        for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
            if down_block_additional_residual.dim() == 4:  # broadcast: insert a size-1 dim at index 2
                down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
            down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual

    # mid
    sample = self.mid_block(
        sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
    )

    # support controlnet
    if mid_block_additional_residual is not None:
        if mid_block_additional_residual.dim() == 4:  # broadcast: insert a size-1 dim at index 2
            mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
        sample = sample + mid_block_additional_residual

    # up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        # Take one skip tensor per resnet of this block from the end of the list, then drop them.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                encoder_hidden_states=encoder_hidden_states,
            )

    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet3DConditionOutput(sample=sample)
def load_state_dict(self, state_dict, strict=True):
    """Load ``state_dict`` after dropping entries whose shapes conflict with this model's config.

    The following entries are removed before delegating to the parent implementation:

    * ``conv_in.weight`` / ``conv_in.bias`` when the weight's dim 1 differs from ``config.in_channels``;
    * ``conv_out.weight`` / ``conv_out.bias`` when the weight's dim 0 differs from ``config.out_channels``;
    * every ``attn2.to_k.`` / ``attn2.to_v.`` tensor whose dim 1 differs from
      ``config.cross_attention_dim``.

    Keys that are absent are simply skipped, and 1-D tensors under ``attn2.to_k.`` / ``attn2.to_v.``
    have no dim 1 to compare and are kept. The caller's mapping is not modified; the filtering is done
    on a shallow copy.

    Args:
        state_dict: mapping from parameter names to tensors.
        strict: forwarded unchanged to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    # Local import so the file-level import block does not need to change.
    from collections import OrderedDict

    # Shallow copy (tensors are shared); carry over a `_metadata` attribute if the mapping has one.
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = OrderedDict(state_dict)
    if metadata is not None:
        state_dict._metadata = metadata

    def drop_conv_if_mismatched(prefix, axis, expected):
        # Remove a convolution's weight and bias when the weight's `axis` size is not `expected`.
        weight = state_dict.get(f"{prefix}.weight")
        if weight is not None and weight.shape[axis] != expected:
            state_dict.pop(f"{prefix}.weight", None)
            state_dict.pop(f"{prefix}.bias", None)

    # If the loaded checkpoint's in_channels or out_channels are different from config
    drop_conv_if_mismatched("conv_in", 1, self.config.in_channels)
    drop_conv_if_mismatched("conv_out", 0, self.config.out_channels)

    # If the loaded checkpoint's cross_attention_dim is different from config
    cross_attention_dim = self.config.cross_attention_dim
    keys_to_remove = [
        key
        for key, tensor in state_dict.items()
        if ("attn2.to_k." in key or "attn2.to_v." in key)
        and len(tensor.shape) > 1
        and tensor.shape[1] != cross_attention_dim
    ]
    for key in keys_to_remove:
        del state_dict[key]

    return super().load_state_dict(state_dict=state_dict, strict=strict)
@classmethod
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
    """Build the model from ``model_config`` and optionally restore weights from a checkpoint.

    Args:
        model_config: configuration passed to ``cls.from_config``.
        ckpt_path: checkpoint file to load; the empty string skips loading.
        device: device the model is moved to; also used as ``map_location`` when loading.

    Returns:
        ``(unet, resume_global_step)`` where ``resume_global_step`` is the checkpoint's
        ``"global_step"`` entry when present and ``0`` otherwise.
    """
    unet = cls.from_config(model_config).to(device)
    resume_global_step = 0

    # Nothing to restore: hand back the freshly configured model.
    if ckpt_path == "":
        return unet, resume_global_step

    zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)

    if "global_step" in ckpt:
        zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
        resume_global_step = ckpt["global_step"]

    # Non-strict load: keys that do not line up with the model are tolerated.
    unet.load_state_dict(ckpt["state_dict"], strict=False)

    # Drop the checkpoint reference before asking CUDA to release cached memory.
    del ckpt
    torch.cuda.empty_cache()

    return unet, resume_global_step

View File

@@ -0,0 +1,777 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
import torch
from torch import nn
from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
):
    """Instantiate the down block named by ``down_block_type``.

    A leading ``"UNetRes"`` in the name is ignored. ``"DownBlock3D"`` and ``"CrossAttnDownBlock3D"``
    are supported; the attention-related arguments are only passed to the latter.

    Raises:
        ValueError: if the name is unknown, or ``cross_attention_dim`` is ``None`` for
            ``"CrossAttnDownBlock3D"``.
    """
    legacy_prefix = "UNetRes"
    if down_block_type.startswith(legacy_prefix):
        down_block_type = down_block_type[len(legacy_prefix) :]

    # Arguments understood by both block flavours.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        temb_channels=temb_channels,
        add_downsample=add_downsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        downsample_padding=downsample_padding,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
    )

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type != "CrossAttnDownBlock3D":
        raise ValueError(f"{down_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")

    return CrossAttnDownBlock3D(
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        add_audio_layer=add_audio_layer,
        **shared_kwargs,
    )
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    use_inflated_groupnorm=False,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    add_audio_layer=False,
):
    """Instantiate the up block named by ``up_block_type``.

    A leading ``"UNetRes"`` in the name is ignored. ``"UpBlock3D"`` and ``"CrossAttnUpBlock3D"`` are
    supported; the attention-related arguments are only passed to the latter.

    Raises:
        ValueError: if the name is unknown, or ``cross_attention_dim`` is ``None`` for
            ``"CrossAttnUpBlock3D"``.
    """
    legacy_prefix = "UNetRes"
    if up_block_type.startswith(legacy_prefix):
        up_block_type = up_block_type[len(legacy_prefix) :]

    # Arguments understood by both block flavours.
    shared_kwargs = dict(
        num_layers=num_layers,
        in_channels=in_channels,
        out_channels=out_channels,
        prev_output_channel=prev_output_channel,
        temb_channels=temb_channels,
        add_upsample=add_upsample,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        use_inflated_groupnorm=use_inflated_groupnorm,
        use_motion_module=use_motion_module,
        motion_module_type=motion_module_type,
        motion_module_kwargs=motion_module_kwargs,
    )

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type != "CrossAttnUpBlock3D":
        raise ValueError(f"{up_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")

    return CrossAttnUpBlock3D(
        cross_attention_dim=cross_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        add_audio_layer=add_audio_layer,
        **shared_kwargs,
    )
class UNetMidBlock3DCrossAttn(nn.Module):
    """Middle UNet stage working at a constant ``in_channels`` width.

    It holds one leading resnet followed by ``num_layers`` groups of
    (transformer, optional motion module, resnet).
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def build_resnet():
            # Every resnet of this block is configured identically.
            return ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )

        # there is always at least one resnet
        resnet_list = [build_resnet()]
        attention_list = []
        motion_list = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError

            attention_list.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    add_audio_layer=add_audio_layer,
                )
            )

            # A `None` placeholder keeps the three lists aligned when motion modules are disabled.
            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=in_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

            resnet_list.append(build_resnet())

        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)
        self.motion_modules = nn.ModuleList(motion_list)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Apply the leading resnet, then each (attention, motion module, resnet) group in order.

        ``attention_mask`` is accepted but not used by this method.
        """
        leading_resnet = self.resnets[0]
        trailing_resnets = self.resnets[1:]

        hidden_states = leading_resnet(hidden_states, temb)
        for attn, motion_module, resnet in zip(self.attentions, self.motion_modules, trailing_resnets):
            attn_output = attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )
            hidden_states = attn_output[0]
            if motion_module is not None:
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
            hidden_states = resnet(hidden_states, temb)

        return hidden_states
class CrossAttnDownBlock3D(nn.Module):
    """Down stage in which every resnet is followed by a transformer and an optional motion module.

    ``forward`` returns the final hidden states together with a tuple of the hidden states recorded
    after each layer and, when a downsampler exists, after downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Settings shared by every resnet of the block; only the input width varies.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        resnet_list, attention_list, motion_list = [], [], []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = out_channels if layer_idx else in_channels
            resnet_list.append(ResnetBlock3D(in_channels=layer_in_channels, **resnet_common))

            if dual_cross_attention:
                raise NotImplementedError

            attention_list.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    add_audio_layer=add_audio_layer,
                )
            )

            # A `None` placeholder keeps the three lists aligned when motion modules are disabled.
            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)
        self.motion_modules = nn.ModuleList(motion_list)

        downsamplers = None
        if add_downsample:
            downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        self.downsamplers = downsamplers

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        """Run all layers and the optional downsampler.

        ``attention_mask`` is accepted but not used by this method.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a tuple.
        """

        def as_checkpoint_fn(module, return_dict=None):
            # Adapter for torch.utils.checkpoint: forwards positional inputs and, when asked,
            # pins `return_dict` on the wrapped module call.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        recorded_states = []
        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            if not (torch.is_grad_enabled() and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                if motion_module is not None:
                    hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet), hidden_states, temb, use_reentrant=False
                )
                attn_output = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    use_reentrant=False,
                )
                hidden_states = attn_output[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        as_checkpoint_fn(motion_module),
                        hidden_states,
                        temb,
                        encoder_hidden_states,
                        use_reentrant=False,
                    )

            recorded_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            recorded_states.append(hidden_states)

        return hidden_states, tuple(recorded_states)
class DownBlock3D(nn.Module):
    """Down stage made of resnets, each followed by an optional motion module, without attention.

    ``forward`` returns the final hidden states together with a tuple of the hidden states recorded
    after each layer and, when a downsampler exists, after downsampling.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        # Settings shared by every resnet of the block; only the input width varies.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        resnet_list, motion_list = [], []
        for layer_idx in range(num_layers):
            # Only the first resnet sees the block's input width.
            layer_in_channels = out_channels if layer_idx else in_channels
            resnet_list.append(ResnetBlock3D(in_channels=layer_in_channels, **resnet_common))

            # A `None` placeholder keeps both lists aligned when motion modules are disabled.
            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

        self.resnets = nn.ModuleList(resnet_list)
        self.motion_modules = nn.ModuleList(motion_list)

        downsamplers = None
        if add_downsample:
            downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        self.downsamplers = downsamplers

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        """Run all layers and the optional downsampler.

        ``encoder_hidden_states`` is only passed on to the motion modules.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a tuple.
        """

        def as_checkpoint_fn(module):
            # Adapter for torch.utils.checkpoint: forwards positional inputs to the module.
            def run(*inputs):
                return module(*inputs)

            return run

        recorded_states = []
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if not (torch.is_grad_enabled() and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                if motion_module is not None:
                    hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet), hidden_states, temb, use_reentrant=False
                )
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        as_checkpoint_fn(motion_module),
                        hidden_states,
                        temb,
                        encoder_hidden_states,
                        use_reentrant=False,
                    )

            recorded_states.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            recorded_states.append(hidden_states)

        return hidden_states, tuple(recorded_states)
class CrossAttnUpBlock3D(nn.Module):
    """Up stage in which every resnet is followed by a transformer and an optional motion module.

    Each layer first concatenates the current hidden states with one skip tensor (taken from the end
    of ``res_hidden_states_tuple``) along dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        add_audio_layer=False,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Settings shared by every resnet of the block; only the input width varies.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        last_layer = num_layers - 1
        resnet_list, attention_list, motion_list = [], [], []
        for layer_idx in range(num_layers):
            # Width of the skip tensor concatenated in this layer, and of the incoming hidden states.
            skip_width = in_channels if layer_idx == last_layer else out_channels
            hidden_width = prev_output_channel if layer_idx == 0 else out_channels
            resnet_list.append(ResnetBlock3D(in_channels=hidden_width + skip_width, **resnet_common))

            if dual_cross_attention:
                raise NotImplementedError

            attention_list.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    add_audio_layer=add_audio_layer,
                )
            )

            # A `None` placeholder keeps the three lists aligned when motion modules are disabled.
            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)
        self.motion_modules = nn.ModuleList(motion_list)

        upsamplers = None
        if add_upsample:
            upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        self.upsamplers = upsamplers

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        """Run all layers, consuming one skip tensor per layer, then the optional upsampler.

        ``attention_mask`` is accepted but not used by this method.
        """

        def as_checkpoint_fn(module, return_dict=None):
            # Adapter for torch.utils.checkpoint: forwards positional inputs and, when asked,
            # pins `return_dict` on the wrapped module call.
            def run(*inputs):
                if return_dict is None:
                    return module(*inputs)
                return module(*inputs, return_dict=return_dict)

            return run

        layers = zip(self.resnets, self.attentions, self.motion_modules)
        for layer_idx, (resnet, attn, motion_module) in enumerate(layers):
            # Skip tensors are consumed from the end: last one first.
            skip_states = res_hidden_states_tuple[-1 - layer_idx]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if not (torch.is_grad_enabled() and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
                if motion_module is not None:
                    hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet), hidden_states, temb, use_reentrant=False
                )
                attn_output = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                    use_reentrant=False,
                )
                hidden_states = attn_output[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        as_checkpoint_fn(motion_module),
                        hidden_states,
                        temb,
                        encoder_hidden_states,
                        use_reentrant=False,
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
class UpBlock3D(nn.Module):
    """Up stage made of resnets, each followed by an optional motion module, without attention.

    Each layer first concatenates the current hidden states with one skip tensor (taken from the end
    of ``res_hidden_states_tuple``) along dim 1.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=False,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        # Settings shared by every resnet of the block; only the input width varies.
        resnet_common = dict(
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=resnet_eps,
            groups=resnet_groups,
            dropout=dropout,
            time_embedding_norm=resnet_time_scale_shift,
            non_linearity=resnet_act_fn,
            output_scale_factor=output_scale_factor,
            pre_norm=resnet_pre_norm,
            use_inflated_groupnorm=use_inflated_groupnorm,
        )

        last_layer = num_layers - 1
        resnet_list, motion_list = [], []
        for layer_idx in range(num_layers):
            # Width of the skip tensor concatenated in this layer, and of the incoming hidden states.
            skip_width = in_channels if layer_idx == last_layer else out_channels
            hidden_width = prev_output_channel if layer_idx == 0 else out_channels
            resnet_list.append(ResnetBlock3D(in_channels=hidden_width + skip_width, **resnet_common))

            # A `None` placeholder keeps both lists aligned when motion modules are disabled.
            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

        self.resnets = nn.ModuleList(resnet_list)
        self.motion_modules = nn.ModuleList(motion_list)

        upsamplers = None
        if add_upsample:
            upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        self.upsamplers = upsamplers

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        """Run all layers, consuming one skip tensor per layer, then the optional upsampler.

        ``encoder_hidden_states`` is only passed on to the motion modules.
        """

        def as_checkpoint_fn(module):
            # Adapter for torch.utils.checkpoint: forwards positional inputs to the module.
            def run(*inputs):
                return module(*inputs)

            return run

        for layer_idx, (resnet, motion_module) in enumerate(zip(self.resnets, self.motion_modules)):
            # Skip tensors are consumed from the end: last one first.
            skip_states = res_hidden_states_tuple[-1 - layer_idx]
            hidden_states = torch.cat([hidden_states, skip_states], dim=1)

            if not (torch.is_grad_enabled() and self.gradient_checkpointing):
                hidden_states = resnet(hidden_states, temb)
                if motion_module is not None:
                    hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
            else:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_checkpoint_fn(resnet), hidden_states, temb, use_reentrant=False
                )
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        as_checkpoint_fn(motion_module),
                        hidden_states,
                        temb,
                        encoder_hidden_states,
                        use_reentrant=False,
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module."""
    for weight in module.parameters():
        # Zero through a detached handle onto the parameter, as the in-place fill is not meant
        # to be tracked by autograd.
        detached = weight.detach()
        detached.zero_()
    return module

View File

@@ -0,0 +1,90 @@
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.
from torch import nn
from torch.nn import functional as F
class Wav2LipSyncNet(nn.Module):
    """Two-tower sync network: one conv encoder for frames, one for audio.

    Both encoders reduce their input to a (b, 1024, 1, 1) feature map which is
    flattened and L2-normalized, so the two embeddings can be compared directly.

    Args:
        act_fn: activation name passed to every ``Conv2d`` block except the
            final layers of each encoder, which always use ``"relu"``.
    """

    def __init__(self, act_fn="leaky"):
        super().__init__()

        def conv(cin, cout, kernel_size, stride, padding, residual=False, act=act_fn):
            # Thin wrapper so the layer tables below stay compact.
            return Conv2d(
                cin, cout, kernel_size=kernel_size, stride=stride, padding=padding, residual=residual, act_fn=act
            )

        def residual(channels):
            # Shape-preserving 3x3 residual block.
            return conv(channels, channels, 3, 1, 1, residual=True)

        # input image sequences: (15, 128, 256)
        visual_layers = [
            conv(15, 32, (7, 7), 1, 3),  # (128, 256)
            conv(32, 64, 5, (1, 2), 1),  # (126, 127)
            residual(64),
            residual(64),
            conv(64, 128, 3, 2, 1),  # (63, 64)
            residual(128),
            residual(128),
            residual(128),
            conv(128, 256, 3, 3, 1),  # (21, 22)
            residual(256),
            residual(256),
            conv(256, 512, 3, 2, 1),  # (11, 11)
            residual(512),
            residual(512),
            conv(512, 1024, 3, 2, 1),  # (6, 6)
            residual(1024),
            residual(1024),
            conv(1024, 1024, 3, 2, 1, act="relu"),  # (3, 3)
            conv(1024, 1024, 3, 1, 0, act="relu"),  # (1, 1)
            conv(1024, 1024, 1, 1, 0, act="relu"),
        ]
        self.visual_encoder = nn.Sequential(*visual_layers)

        # input audio sequences: (1, 80, 16)
        audio_layers = [
            conv(1, 32, 3, 1, 1),
            residual(32),
            residual(32),
            conv(32, 64, 3, (3, 1), 1),  # (27, 16)
            residual(64),
            residual(64),
            conv(64, 128, 3, 3, 1),  # (9, 6)
            residual(128),
            residual(128),
            conv(128, 256, 3, (3, 2), 1),  # (3, 3)
            residual(256),
            residual(256),
            conv(256, 512, 3, 1, 1),
            residual(512),
            residual(512),
            conv(512, 1024, 3, 1, 0, act="relu"),  # (1, 1)
            conv(1024, 1024, 1, 1, 0, act="relu"),
        ]
        self.audio_encoder = nn.Sequential(*audio_layers)

    @staticmethod
    def _embed(encoder, sequences):
        """Run ``encoder`` and return flattened, unit-length embeddings of shape (b, c)."""
        features = encoder(sequences)  # (b, c, 1, 1)
        features = features.reshape(features.shape[0], -1)  # (b, c)
        return F.normalize(features, p=2, dim=1)

    def forward(self, image_sequences, audio_sequences):
        """Return ``(vision_embeds, audio_embeds)``, each L2-normalized along dim 1."""
        vision_embeds = self._embed(self.visual_encoder, image_sequences)
        audio_embeds = self._embed(self.audio_encoder, audio_sequences)
        return vision_embeds, audio_embeds
class Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> activation block with an optional residual add.

    Args:
        cin: number of input channels.
        cout: number of output channels.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        residual: when true, the block input is added to the normalized conv
            output before the activation (input and output shapes must match).
        act_fn: one of ``"relu"``, ``"tanh"``, ``"silu"`` or ``"leaky"``.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.

    Raises:
        ValueError: if ``act_fn`` is not one of the supported names.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "tanh":
            self.act_fn = nn.Tanh()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        elif act_fn == "leaky":
            self.act_fn = nn.LeakyReLU(0.2, inplace=True)
        else:
            # Previously an unknown name left ``self.act_fn`` unset and only
            # failed later in ``forward`` with an opaque AttributeError.
            raise ValueError(
                f"Unsupported act_fn {act_fn!r}; expected one of 'relu', 'tanh', 'silu', 'leaky'."
            )
        self.residual = residual

    def forward(self, x):
        """Apply conv + batch norm, optionally add ``x``, then the activation."""
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act_fn(out)

View File

@@ -0,0 +1,477 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
import inspect
import math
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess
import numpy as np
import torch
import torchvision
from torchvision import transforms
from packaging import version
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import deprecate, logging
from einops import rearrange
import cv2
from ..models.unet import UNet3DConditionModel
from ..utils.util import read_video, read_audio, write_video, check_ffmpeg_installed
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..whisper.audio2feature import Audio2Feature
import tqdm
import soundfile as sf
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LipsyncPipeline(DiffusionPipeline):
_optional_components = []
def __init__(
    self,
    vae: AutoencoderKL,
    audio_encoder: Audio2Feature,
    unet: UNet3DConditionModel,
    scheduler: Union[
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
    ],
):
    """Register the pipeline components, patching outdated configs first.

    Three legacy-config situations are detected; each emits a deprecation
    notice and replaces the component's ``_internal_dict`` with a corrected
    copy: scheduler ``steps_offset != 1``, scheduler ``clip_sample is True``,
    and an old UNet config with ``sample_size < 64``.

    Args:
        vae: autoencoder registered as ``self.vae``.
        audio_encoder: audio feature extractor registered as ``self.audio_encoder``.
        unet: denoising network registered as ``self.unet``.
        scheduler: diffusion scheduler registered as ``self.scheduler``.
    """
    super().__init__()

    def _with_override(config, key, value):
        # Copy the config, change one entry and freeze it again.
        patched = dict(config)
        patched[key] = value
        return FrozenDict(patched)

    # NOTE: ``scheduler.config`` is re-read for every check so that a fix
    # applied by an earlier check is visible to the later ones.
    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
        steps_offset_message = (
            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
            " file"
        )
        deprecate("steps_offset!=1", "1.0.0", steps_offset_message, standard_warn=False)
        scheduler._internal_dict = _with_override(scheduler.config, "steps_offset", 1)

    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
        clip_sample_message = (
            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
        )
        deprecate("clip_sample not set", "1.0.0", clip_sample_message, standard_warn=False)
        scheduler._internal_dict = _with_override(scheduler.config, "clip_sample", False)

    unet_is_pre_0_9_0 = hasattr(unet.config, "_diffusers_version") and (
        version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
    )
    unet_sample_size_below_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    if unet_is_pre_0_9_0 and unet_sample_size_below_64:
        sample_size_message = (
            "The configuration file of the unet has set the default `sample_size` to smaller than"
            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
            " in the config might lead to incorrect results in future versions. If you have downloaded this"
            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
            " the `unet/config.json` file"
        )
        deprecate("sample_size<64", "1.0.0", sample_size_message, standard_warn=False)
        unet._internal_dict = _with_override(unet.config, "sample_size", 64)

    self.register_modules(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    )

    # Spatial down-scaling between pixel space and latent space.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    self.set_progress_bar_config(desc="Steps")
def enable_vae_slicing(self):
    """Turn on sliced processing in the VAE by delegating to ``self.vae.enable_slicing()``."""
    self.vae.enable_slicing()
def disable_vae_slicing(self):
    """Turn off sliced processing in the VAE by delegating to ``self.vae.disable_slicing()``."""
    self.vae.disable_slicing()
@property
def _execution_device(self):
    """Device on which the pipeline's modules actually execute.

    Returns ``self.device`` unless the pipeline reports the ``meta`` device and
    the UNet carries an ``_hf_hook``; in that case the first sub-module hook
    with a non-``None`` ``execution_device`` decides. Falls back to
    ``self.device`` when no such hook is found.
    """
    on_meta_with_hook = self.device == torch.device("meta") and hasattr(self.unet, "_hf_hook")
    if not on_meta_with_hook:
        return self.device

    for module in self.unet.modules():
        hook = getattr(module, "_hf_hook", None)
        target = getattr(hook, "execution_device", None)
        if target is not None:
            return torch.device(target)
    return self.device
def decode_latents(self, latents):
    """Decode video latents back to pixel space, one frame per batch entry.

    The latents are un-scaled with the VAE's ``scaling_factor`` and
    ``shift_factor``, the frame axis is folded into the batch axis
    (``b c f h w -> (b f) c h w``) and the result is passed to ``vae.decode``.
    """
    vae_config = self.vae.config
    unscaled = latents / vae_config.scaling_factor + vae_config.shift_factor
    frames_as_batch = rearrange(unscaled, "b c f h w -> (b f) c h w")
    return self.vae.decode(frames_as_batch).sample
def prepare_extra_step_kwargs(self, generator, eta):
    """Collect the optional keyword arguments accepted by ``self.scheduler.step``.

    Schedulers do not share one ``step`` signature, so ``eta`` and
    ``generator`` are only forwarded when the signature names them.
    ``eta`` corresponds to η in the DDIM paper (https://arxiv.org/abs/2010.02502)
    and should be between [0, 1].

    Returns:
        dict with a subset of the keys ``"eta"`` and ``"generator"``.
    """
    step_params = inspect.signature(self.scheduler.step).parameters
    candidates = {"eta": eta, "generator": generator}
    return {name: value for name, value in candidates.items() if name in step_params}
def check_inputs(self, height, width, callback_steps):
    """Validate the resolution and callback interval before running inference.

    Raises:
        AssertionError: if ``height`` and ``width`` differ.
        ValueError: if either dimension is not divisible by 8, or if
            ``callback_steps`` is not a positive integer (``None`` included).
    """
    assert height == width, "Height and width must be equal"

    dims_divisible_by_8 = height % 8 == 0 and width % 8 == 0
    if not dims_divisible_by_8:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # ``None`` is not an int, so a single isinstance test covers it as well.
    is_positive_int = isinstance(callback_steps, int) and callback_steps > 0
    if not is_positive_int:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )
def prepare_latents(self, num_frames, num_channels_latents, height, width, dtype, device, generator):
    """Create the initial noise latents of shape (1, c, num_frames, h, w).

    A single frame of Gaussian noise is sampled and repeated along the frame
    axis, so every frame starts from the same noise. The result is scaled by
    the scheduler's ``init_noise_sigma``.
    """
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    single_frame_shape = (1, num_channels_latents, 1, latent_height, latent_width)  # (b, c, f, h, w)

    # For an "mps" target the noise is sampled on CPU and then moved over.
    sampling_device = "cpu" if device.type == "mps" else device
    noise = torch.randn(single_frame_shape, generator=generator, device=sampling_device, dtype=dtype).to(device)
    noise = noise.repeat(1, 1, num_frames, 1, 1)

    # scale the initial noise by the standard deviation required by the scheduler
    return noise * self.scheduler.init_noise_sigma
def prepare_mask_latents(
    self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
):
    """Bring the mask and the masked frames into latent space for the UNet input.

    The mask is resized to the latent resolution; the masked frames are
    encoded by the VAE and normalized with its ``shift_factor`` /
    ``scaling_factor``. Both are reshaped from ``f c h w`` to ``1 c f h w``
    (batch size 1 is assumed) and duplicated along the batch axis when
    classifier-free guidance is active.

    Returns:
        ``(mask, masked_image_latents)``.
    """
    # Resize before the dtype conversion to avoid breaking with cpu_offload
    # and half precision.
    latent_size = (height // self.vae_scale_factor, width // self.vae_scale_factor)
    mask = torch.nn.functional.interpolate(mask, size=latent_size)
    masked_image = masked_image.to(device=device, dtype=dtype)

    # Encode the masked frames so they can be concatenated with the latents.
    posterior = self.vae.encode(masked_image).latent_dist
    masked_image_latents = posterior.sample(generator=generator)
    vae_config = self.vae.config
    masked_image_latents = (masked_image_latents - vae_config.shift_factor) * vae_config.scaling_factor

    # Align device/dtype to prevent errors when concatenating with the latent model input.
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
    mask = mask.to(device=device, dtype=dtype)

    # assume batch size = 1
    mask = rearrange(mask, "f c h w -> 1 c f h w")
    masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")

    if do_classifier_free_guidance:
        mask = torch.cat([mask] * 2)
        masked_image_latents = torch.cat([masked_image_latents] * 2)
    return mask, masked_image_latents
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
    """Encode reference frames into latents shaped ``1 c f h w``.

    The frames are encoded by the VAE, normalized with its ``shift_factor`` /
    ``scaling_factor``, and duplicated along the batch axis when
    classifier-free guidance is active.
    """
    images = images.to(device=device, dtype=dtype)
    posterior = self.vae.encode(images).latent_dist
    latents = posterior.sample(generator=generator)
    vae_config = self.vae.config
    latents = (latents - vae_config.shift_factor) * vae_config.scaling_factor
    latents = rearrange(latents, "f c h w -> 1 c f h w")
    if do_classifier_free_guidance:
        latents = torch.cat([latents] * 2)
    return latents
def set_progress_bar_config(self, **kwargs):
    """Merge ``kwargs`` into ``self._progress_bar_config``, creating the dict on first use."""
    try:
        config = self._progress_bar_config
    except AttributeError:
        config = self._progress_bar_config = {}
    config.update(kwargs)
@staticmethod
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
# Paste the surrounding pixels back, because we only want to change the mouth region
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
masks = masks.to(device=device, dtype=weight_dtype)
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
return combined_pixel_values
@staticmethod
def pixel_values_to_images(pixel_values: torch.Tensor):
    """Convert ``f c h w`` pixel tensors in [-1, 1] to a uint8 NumPy array laid out ``f h w c``."""
    channels_last = rearrange(pixel_values, "f c h w -> f h w c")
    unit_range = (channels_last / 2 + 0.5).clamp(0, 1)
    as_uint8 = (unit_range * 255).to(torch.uint8)
    return as_uint8.cpu().numpy()
def affine_transform_video(self, video_frames: np.ndarray):
    """Run ``self.image_processor.affine_transform`` on every frame.

    Returns:
        ``(faces, boxes, affine_matrices)`` where ``faces`` is the per-frame
        face tensors stacked along a new first axis and the other two are
        per-frame lists.
    """
    print(f"Affine transforming {len(video_frames)} faces...")
    per_frame = []
    for frame in tqdm.tqdm(video_frames):
        face, box, affine_matrix = self.image_processor.affine_transform(frame)
        per_frame.append((face, box, affine_matrix))

    faces = torch.stack([item[0] for item in per_frame])
    boxes = [item[1] for item in per_frame]
    affine_matrices = [item[2] for item in per_frame]
    return faces, boxes, affine_matrices
def restore_video(self, faces: torch.Tensor, video_frames: np.ndarray, boxes: list, affine_matrices: list):
    """Paste the generated faces back into their source frames.

    Each face is resized (bicubic, antialiased) to the size of its bounding
    box and handed to ``self.image_processor.restorer.restore_img`` together
    with the source frame and that frame's affine matrix.

    Returns:
        The restored frames stacked along axis 0.
    """
    video_frames = video_frames[: len(faces)]
    print(f"Restoring {len(faces)} faces...")

    restored = []
    for idx, face in enumerate(tqdm.tqdm(faces)):
        left, top, right, bottom = boxes[idx]
        box_size = (int(bottom - top), int(right - left))  # (height, width)
        resized_face = torchvision.transforms.functional.resize(
            face, size=box_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        restored.append(
            self.image_processor.restorer.restore_img(video_frames[idx], resized_face, affine_matrices[idx])
        )
    return np.stack(restored, axis=0)
def loop_video(self, whisper_chunks: list, video_frames: np.ndarray):
    """Match the video length to the number of audio chunks.

    If the audio is not longer than the video, the video is truncated. If it
    is longer, the video is extended by playing it forward and backward
    alternately (ping-pong) until it covers every chunk, then cut to length.

    Returns:
        ``(video_frames, faces, boxes, affine_matrices)``, each with one entry
        per audio chunk.
    """
    target_len = len(whisper_chunks)

    if target_len <= len(video_frames):
        video_frames = video_frames[:target_len]
        faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
        return video_frames, faces, boxes, affine_matrices

    # Audio is longer than the video: transform once, then replay the results.
    faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
    num_loops = math.ceil(target_len / len(video_frames))

    frame_segments = []
    face_segments = []
    looped_boxes = []
    looped_matrices = []
    for loop_idx in range(num_loops):
        play_forward = loop_idx % 2 == 0
        frame_segments.append(video_frames if play_forward else video_frames[::-1])
        face_segments.append(faces if play_forward else faces.flip(0))
        looped_boxes += boxes if play_forward else boxes[::-1]
        looped_matrices += affine_matrices if play_forward else affine_matrices[::-1]

    video_frames = np.concatenate(frame_segments, axis=0)[:target_len]
    faces = torch.cat(face_segments, dim=0)[:target_len]
    return video_frames, faces, looped_boxes[:target_len], looped_matrices[:target_len]
@torch.no_grad()
def __call__(
    self,
    video_path: str,
    audio_path: str,
    video_out_path: str,
    num_frames: int = 16,
    video_fps: int = 25,
    audio_sample_rate: int = 16000,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 20,
    guidance_scale: float = 1.5,
    weight_dtype: Optional[torch.dtype] = torch.float16,
    eta: float = 0.0,
    mask_image_path: str = "latentsync/utils/mask.png",
    temp_dir: str = "temp",
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    **kwargs,
):
    """Lip-sync ``video_path`` to ``audio_path`` and write the result to ``video_out_path``.

    Args:
        video_path: source video, read with ``read_video``.
        audio_path: driving audio, used both for feature extraction and as
            the soundtrack of the output.
        video_out_path: path of the muxed output file written by ffmpeg.
        num_frames: number of frames denoised together per inference chunk.
        video_fps: frame rate used to chunk the audio features and to write
            the video.
        audio_sample_rate: sample rate used to trim and write the audio.
        height, width: working resolution; each defaults to
            ``unet.config.sample_size * vae_scale_factor``. Must be equal and
            divisible by 8 (see ``check_inputs``).
        num_inference_steps: number of scheduler timesteps.
        guidance_scale: classifier-free guidance weight; values ``<= 1``
            disable guidance.
        weight_dtype: dtype used for latents and model inputs.
        eta: forwarded to the scheduler step if it accepts ``eta``.
        mask_image_path: path of the fixed mask image.
        temp_dir: scratch directory; it is deleted and recreated on every call.
        generator: RNG forwarded to noise sampling and the scheduler.
        callback: optional ``callback(step, timestep, latents)``.
        callback_steps: call ``callback`` when ``step % callback_steps == 0``.
        **kwargs: accepted for compatibility and ignored.

    Raises:
        subprocess.CalledProcessError: if the final ffmpeg mux fails.
    """
    is_train = self.unet.training
    self.unet.eval()

    check_ffmpeg_installed()

    # 0. Define call parameters
    device = self._execution_device

    # 1. Default height and width to unet.
    # This must happen before the mask and the image processor are built:
    # both need a concrete resolution, and ``height`` may still be None here.
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 2. Check inputs
    self.check_inputs(height, width, callback_steps)

    mask_image = load_fixed_mask(height, mask_image_path)
    # NOTE(review): the image processor is pinned to "cuda" regardless of
    # ``device``; kept as-is because ImageProcessor is defined elsewhere.
    self.image_processor = ImageProcessor(height, device="cuda", mask_image=mask_image)
    self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 4. Prepare extra step kwargs.
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    whisper_feature = self.audio_encoder.audio2feat(audio_path)
    whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)

    audio_samples = read_audio(audio_path)
    video_frames = read_video(video_path, use_decord=False)

    video_frames, faces, boxes, affine_matrices = self.loop_video(whisper_chunks, video_frames)

    synced_video_frames = []

    num_channels_latents = self.vae.config.latent_channels

    # Prepare latent variables
    all_latents = self.prepare_latents(
        len(whisper_chunks),
        num_channels_latents,
        height,
        width,
        weight_dtype,
        device,
        generator,
    )

    num_inferences = math.ceil(len(whisper_chunks) / num_frames)
    for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
        if self.unet.add_audio_layer:
            audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
            audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
            if do_classifier_free_guidance:
                null_audio_embeds = torch.zeros_like(audio_embeds)
                audio_embeds = torch.cat([null_audio_embeds, audio_embeds])
        else:
            audio_embeds = None
        inference_faces = faces[i * num_frames : (i + 1) * num_frames]
        latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
        ref_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
            inference_faces, affine_transform=False
        )

        # 7. Prepare mask latent variables
        mask_latents, masked_image_latents = self.prepare_mask_latents(
            masks,
            masked_pixel_values,
            height,
            width,
            weight_dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 8. Prepare image latents
        ref_latents = self.prepare_image_latents(
            ref_pixel_values,
            device,
            weight_dtype,
            generator,
            do_classifier_free_guidance,
        )

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for j, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                unet_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                unet_input = self.scheduler.scale_model_input(unet_input, t)

                # concat latents, mask, masked_image_latents in the channel dimension
                unet_input = torch.cat([unet_input, mask_latents, masked_image_latents, ref_latents], dim=1)

                # predict the noise residual
                noise_pred = self.unet(unet_input, t, encoder_hidden_states=audio_embeds).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and j % callback_steps == 0:
                        callback(j, t, latents)

        # Recover the pixel values
        decoded_latents = self.decode_latents(latents)
        decoded_latents = self.paste_surrounding_pixels_back(
            decoded_latents, ref_pixel_values, 1 - masks, device, weight_dtype
        )
        synced_video_frames.append(decoded_latents)

    synced_video_frames = self.restore_video(torch.cat(synced_video_frames), video_frames, boxes, affine_matrices)

    # Trim the audio to the duration of the generated video.
    audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
    audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()

    if is_train:
        self.unet.train()

    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir, exist_ok=True)

    temp_video_path = os.path.join(temp_dir, "video.mp4")
    temp_audio_path = os.path.join(temp_dir, "audio.wav")
    write_video(temp_video_path, synced_video_frames, fps=video_fps)
    sf.write(temp_audio_path, audio_samples, audio_sample_rate)

    # Pass an argument list (no shell) so paths containing spaces or shell
    # metacharacters are handled literally, and fail loudly if ffmpeg fails.
    command = [
        "ffmpeg",
        "-y",
        "-loglevel",
        "error",
        "-nostdin",
        "-i",
        temp_video_path,
        "-i",
        temp_audio_path,
        "-c:v",
        "libx264",
        "-crf",
        "18",
        "-c:a",
        "aac",
        "-q:v",
        "0",
        "-q:a",
        "0",
        video_out_path,
    ]
    subprocess.run(command, check=True)

View File

@@ -0,0 +1,67 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from einops import rearrange
from .third_party.VideoMAEv2.utils import load_videomae_model
from ..utils.util import check_model_and_download
class TREPALoss:
    """MSE between L2-normalized VideoMAE features of generated and real clips.

    The feature extractor is loaded once, put in eval mode, cast to float16
    and frozen.

    Args:
        device: device the feature extractor is loaded onto.
        ckpt_path: checkpoint path, passed to ``check_model_and_download``
            before the model is loaded from it.
        with_cp: forwarded to ``load_videomae_model``.
    """

    def __init__(
        self,
        device="cuda",
        ckpt_path="checkpoints/auxiliary/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
        with_cp=False,
    ):
        check_model_and_download(ckpt_path)
        feature_extractor = load_videomae_model(device, ckpt_path, with_cp)
        self.model = feature_extractor.eval().to(dtype=torch.float16)
        self.model.requires_grad_(False)

    def _extract_features(self, videos, num_frames):
        """Resize ``videos`` (b, c, f, h, w) to 224x224 and return unit-length model features."""
        frames = rearrange(videos.clone(), "b c f h w -> (b f) c h w")
        frames = F.interpolate(frames, size=(224, 224), mode="bicubic")
        clips = rearrange(frames, "(b f) c h w -> b c f h w", f=num_frames)

        # Because input pixel range is [-1, 1], and model expects pixel range to be [0, 1]
        clips = (clips / 2 + 0.5).clamp(0, 1)

        features = self.model.forward_features(clips)
        return F.normalize(features, p=2, dim=1)

    def __call__(self, videos_fake, videos_real):
        """Return the feature-space MSE between ``videos_fake`` and ``videos_real``.

        Both inputs are (b, c, f, h, w); the frame count is taken from
        ``videos_fake`` and applied to both.
        """
        num_frames = videos_fake.shape[2]
        feats_fake = self._extract_features(videos_fake, num_frames)
        feats_real = self._extract_features(videos_real, num_frames)
        return F.mse_loss(feats_fake, feats_real)
if __name__ == "__main__":
    # Smoke test: compute the loss on two random half-precision clips on the GPU.
    torch.manual_seed(42)

    # input shape: (b, c, f, h, w)
    clip_shape = (2, 3, 16, 256, 256)
    videos_fake = torch.randn(*clip_shape, requires_grad=True).to(device="cuda", dtype=torch.float16)
    videos_real = torch.randn(*clip_shape, requires_grad=True).to(device="cuda", dtype=torch.float16)

    trepa_loss = TREPALoss(device="cuda", with_cp=True)
    loss = trepa_loss(videos_fake, videos_real)
    print(loss)

View File

@@ -0,0 +1,82 @@
import os
import torch
import requests
from tqdm import tqdm
from torchvision import transforms
from .videomaev2_finetune import vit_giant_patch14_224
def to_normalized_float_tensor(vid):
    """Move the last axis of a 4-D tensor to the front and scale it to float32 in [0, 1].

    The input is permuted with ``(3, 0, 1, 2)``, cast to float32 and divided
    by 255.
    """
    channels_first = vid.permute(3, 0, 1, 2)
    return channels_first.to(torch.float32) / 255
# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation="bilinear"):
    """Spatially resize ``vid`` with ``torch.nn.functional.interpolate``.

    Args:
        vid: tensor whose last two axes are the spatial ones.
        size: an ``int`` sets the shorter spatial side to that length (aspect
            ratio kept via a scale factor); anything else is passed through as
            the explicit output size.
        interpolation: interpolation mode; bilinear by default because the
            input is not treated as a mini-batch at this level.
    """
    if isinstance(size, int):
        shorter_side = min(vid.shape[-2:])
        target_size, scale_factor = None, float(size) / shorter_side
    else:
        target_size, scale_factor = size, None
    return torch.nn.functional.interpolate(
        vid, size=target_size, scale_factor=scale_factor, mode=interpolation, align_corners=False
    )
class ToFloatTensorInZeroOne:
    """Callable transform that applies ``to_normalized_float_tensor`` to a video."""

    def __call__(self, vid):
        return to_normalized_float_tensor(vid)
class Resize:
    """Callable transform that applies ``resize`` with a fixed target ``size``."""

    def __init__(self, size):
        self.size = size

    def __call__(self, vid):
        return resize(vid, self.size)
def preprocess_videomae(videos):
    """Normalize and resize a NumPy batch of videos to 224x224 and stack the results.

    Each entry along the first axis is passed through
    ``ToFloatTensorInZeroOne`` followed by ``Resize((224, 224))``.
    """
    pipeline = transforms.Compose([ToFloatTensorInZeroOne(), Resize((224, 224))])
    clips = torch.from_numpy(videos)
    return torch.stack([pipeline(clip) for clip in clips])
def load_videomae_model(device, ckpt_path=None, with_cp=False):
    """Build the VideoMAEv2 ViT-giant model, load its checkpoint and move it to ``device``.

    Args:
        device: target device for the returned model.
        ckpt_path: checkpoint file. Defaults to
            ``vit_g_hybrid_pt_1200e_ssv2_ft.pth`` next to this module. If the
            file does not exist it is downloaded first.
        with_cp: forwarded to ``vit_giant_patch14_224``.

    Returns:
        The model with the checkpoint weights loaded, on ``device``.

    Raises:
        requests.HTTPError: if the checkpoint download returns an error status.
    """
    if ckpt_path is None:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        ckpt_path = os.path.join(current_dir, "vit_g_hybrid_pt_1200e_ssv2_ft.pth")

    if not os.path.exists(ckpt_path):
        # download the ckpt to the path
        ckpt_url = "https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth"

        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)

        # Download to a side file and rename only on success, so an interrupted
        # transfer can never be mistaken for a complete checkpoint later.
        partial_path = ckpt_path + ".part"
        block_size = 1024 * 1024
        try:
            with requests.get(ckpt_url, stream=True, allow_redirects=True, timeout=(10, 60)) as response:
                # Without this an HTTP error page would be saved as the checkpoint.
                response.raise_for_status()
                total_size = int(response.headers.get("content-length", 0))
                with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
                    with open(partial_path, "wb") as fw:
                        for data in response.iter_content(block_size):
                            progress_bar.update(len(data))
                            fw.write(data)
            os.replace(partial_path, ckpt_path)
        finally:
            # After a successful rename the side file is gone and this is a no-op.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    model = vit_giant_patch14_224(
        img_size=224,
        pretrained=False,
        num_classes=174,
        all_frames=16,
        tubelet_size=2,
        drop_path_rate=0.3,
        use_mean_pooling=True,
        with_cp=with_cp,
    )

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # The state dict may be nested under a "model" or "module" key.
    for model_key in ["model", "module"]:
        if model_key in ckpt:
            ckpt = ckpt[model_key]
            break
    model.load_state_dict(ckpt)

    del ckpt
    torch.cuda.empty_cache()
    return model.to(device)

View File

@@ -0,0 +1,543 @@
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------'
from functools import partial
import math
import warnings
import numpy as np
import collections.abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from itertools import repeat
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a normal truncated to ``[a, b]``.

    Sampling uses the inverse-CDF method: draw uniformly between the CDF
    values of the two bounds, map through ``erfinv``, then scale and shift.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Returns:
        The same ``tensor`` object.
    """

    def standard_normal_cdf(value):
        # Cumulative distribution function of the standard normal.
        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # CDF values of the lower and upper cutoffs.
        lower_cdf = standard_normal_cdf((a - mean) / std)
        upper_cdf = standard_normal_cdf((b - mean) / std)

        # Uniform in [2l-1, 2u-1], then the inverse error function yields a
        # truncated standard normal.
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1).erfinv_()

        # Transform to the requested mean and std.
        tensor.mul_(std * math.sqrt(2.0)).add_(mean)

        # Clamp to guard against values that fall just outside the range.
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    Values follow a normal distribution with the given ``mean`` and ``std``,
    restricted to the interval ``[a, b]``. The sampling method works best when
    ``a <= mean <= b``.

    Args:
        tensor: an n-dimensional ``torch.Tensor``
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The filled ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def _ntuple(n):
    """Return a function that expands a scalar into an ``n``-tuple.

    The returned function passes any iterable through unchanged (strings
    included) and repeats every other value ``n`` times.
    """

    def parse(x):
        return x if isinstance(x, collections.abc.Iterable) else tuple(repeat(x, n))

    return parse


to_2tuple = _ntuple(2)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Stochastic depth: randomly zero whole samples and rescale the survivors.

    Adapted from timm codebase.

    Args:
        x: input tensor; the first axis is treated as the sample axis.
        drop_prob: probability of dropping each sample.
        training: the input is returned untouched unless this is true.

    Returns:
        ``x`` itself when ``drop_prob == 0`` or ``training`` is false.
        Otherwise a tensor where each sample is either zeroed or divided by
        ``1 - drop_prob``. With ``drop_prob >= 1`` every sample is dropped.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Dividing by a zero keep probability would give inf * 0 = NaN;
        # dropping everything is the well-defined result.
        return x * 0.0
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor
def _cfg(url="", **kwargs):
    """Build a model config dict from fixed defaults, with ``kwargs`` overriding or extending them."""
    config = {
        "url": url,
        "num_classes": 400,
        "input_size": (3, 224, 224),
        "pool_size": None,
        "crop_pct": 0.9,
        "interpolation": "bicubic",
        "mean": (0.5, 0.5, 0.5),
        "std": (0.5, 0.5, 0.5),
    }
    config.update(kwargs)
    return config
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Module wrapper around ``drop_path`` that passes ``self.training`` along.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f"p={self.drop_prob}"
class Mlp(nn.Module):
    """Two-layer feed-forward block: ``fc1 -> activation -> fc2 -> dropout``.

    Args:
        in_features: input width.
        hidden_features: hidden width; falls back to ``in_features``.
        out_features: output width; falls back to ``in_features``.
        act_layer: activation module class, instantiated without arguments.
        drop: dropout probability applied after ``fc2``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.fc1(x))
        # No dropout is applied between the activation and fc2; the upstream
        # code leaves that call commented out.
        return self.drop(self.fc2(hidden))
class CosAttention(nn.Module):
    """Multi-head self-attention using cosine similarity with a learnable logit scale.

    Queries and keys are L2-normalized before the dot product; the similarity
    is multiplied by ``exp(scale)``, with ``scale`` clamped to at most 4.6052
    (= log(1 / 0.01)).

    Args:
        dim: input/output embedding width.
        num_heads: number of attention heads.
        qkv_bias: when true, learn biases for q and v (the k bias is fixed at zero).
        qk_scale: when ``None`` a per-head learnable log-scale initialised to
            log(10) is created; otherwise the value is stored as-is.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        attn_head_dim: overrides the per-head width ``dim // num_heads``.
    """

    def __init__(
        self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads

        # DO NOT RENAME [self.scale] (for no weight decay)
        # NOTE(review): a non-None ``qk_scale`` is stored as given, but
        # ``forward`` passes ``self.scale`` to ``torch.clamp`` — confirm that
        # callers only ever use the default.
        self.scale = (
            nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
            if qk_scale is None
            else qk_scale
        )

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, _ = x.shape

        bias = None
        if self.q_bias is not None:
            # q and v carry learnable biases; k always gets a zero bias.
            zero_k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
            bias = torch.cat((self.q_bias, zero_k_bias, self.v_bias))

        qkv = F.linear(input=x, weight=self.qkv.weight, bias=bias)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        cosine_sim = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)

        # torch.log(torch.tensor(1. / 0.01)) = 4.6052
        logit_scale = torch.clamp(self.scale, max=4.6052).exp()
        weights = self.attn_drop((cosine_sim * logit_scale).softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.proj_drop(self.proj(merged))
class Attention(nn.Module):
    """Multi-head self-attention built on ``F.scaled_dot_product_attention``.

    Args:
        dim: Input/output channel size.
        num_heads: Number of attention heads.
        qkv_bias: If True, learn a bias for queries and values (keys get zeros).
        qk_scale: Optional softmax scale overriding ``head_dim ** -0.5``.
        attn_drop: Dropout probability on the attention weights (training only).
        proj_drop: Dropout probability after the output projection.
        attn_head_dim: Optional per-head width overriding ``dim // num_heads``.
    """

    def __init__(
        self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, _ = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            # Keys carry no learnable bias, so a zero block sits between q and v.
            qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # The fused kernel scales logits by head_dim ** -0.5. When a custom
        # ``qk_scale`` was configured, fold the ratio into q so the effective
        # scale equals ``self.scale``; the default path is left untouched.
        default_scale = q.shape[-1] ** -0.5
        if self.scale != default_scale:
            q = q * (self.scale / default_scale)

        # Attention dropout is only active in training mode, mirroring nn.Dropout.
        dropout_p = self.attn_drop.p if self.training else 0.0

        # Use PyTorch native implementation of FlashAttention-2
        attn = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        x = attn.transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop, attn_head_dim: Forwarded to the attention module.
        drop: Dropout probability for the attention projection and the MLP.
        drop_path: Stochastic-depth rate; ``nn.Identity`` is used when it is 0.
        init_values: Initial value of the per-channel residual scales
            (``gamma_1``/``gamma_2``). ``None`` or a non-positive value disables them.
        act_layer: Activation factory for the MLP.
        norm_layer: Normalization factory, called with ``dim``.
        cos_attn: Use ``CosAttention`` instead of ``Attention``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        init_values=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attn_head_dim=None,
        cos_attn=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        attn_cls = CosAttention if cos_attn else Attention
        self.attn = attn_cls(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            attn_head_dim=attn_head_dim,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        # ``init_values`` defaults to None; comparing None with 0 raises
        # TypeError, so treat None the same as "no residual scaling".
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
class PatchEmbed(nn.Module):
    """Turn a video tensor into a token sequence with a strided 3D convolution.

    Each token covers ``tubelet_size`` frames and one ``patch_size`` spatial
    patch; ``num_patches`` is the token count for the configured image size
    and ``num_frames``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid_rows = img_size[0] // patch_size[0]
        grid_cols = img_size[1] // patch_size[1]
        temporal_steps = num_frames // tubelet_size

        self.img_size = img_size
        self.tubelet_size = tubelet_size
        self.patch_size = patch_size
        self.num_patches = (grid_rows * grid_cols) * temporal_steps

        # Kernel and stride are identical, so tubelets do not overlap.
        cell = (self.tubelet_size, patch_size[0], patch_size[1])
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=cell,
            stride=cell,
        )

    def forward(self, x, **kwargs):
        _, _, _, H, W = x.shape
        size_matches = H == self.img_size[0] and W == self.img_size[1]
        assert size_matches, f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # b, c, l -> b, l, c
        # [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
        tokens = self.proj(x)
        tokens = tokens.flatten(2)
        return tokens.transpose(1, 2)
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sine/cosine position table.

    Returns a float tensor of shape ``(1, n_position, d_hid)``: even channels
    hold ``sin`` and odd channels hold ``cos`` of
    ``position / 10000 ** (2 * (channel // 2) / d_hid)``.
    """
    # TODO: make it with torch instead of numpy

    # The divisor depends only on the channel, so compute it once per channel.
    divisors = [np.power(10000, 2 * (channel // 2) / d_hid) for channel in range(d_hid)]
    table = np.array([[position / divisor for divisor in divisors] for position in range(n_position)])

    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1

    encoded = torch.tensor(table, dtype=torch.float, requires_grad=False)
    return encoded.unsqueeze(0)
class VisionTransformer(nn.Module):
    """Video Vision Transformer: tubelet patch embedding, ``Block`` stack, pooled head.

    Args:
        img_size, patch_size, in_chans, embed_dim, all_frames, tubelet_size:
            Forwarded to ``PatchEmbed`` (``all_frames`` as ``num_frames``).
        num_classes: Output size of the linear head; ``<= 0`` uses ``nn.Identity``.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, init_values, cos_attn:
            Forwarded to every ``Block``.
        drop_rate: Dropout after the position embedding and inside the blocks.
        attn_drop_rate: Attention dropout inside the blocks.
        drop_path_rate: Maximum stochastic-depth rate (linearly increasing per block).
        head_drop_rate: Dropout applied before the head.
        norm_layer: Normalization factory.
        use_learnable_pos_emb: Learn the position embedding instead of using
            the fixed sinusoid table.
        init_scale: Factor the head weights and bias are multiplied by after init.
        use_mean_pooling: Mean-pool tokens and apply ``fc_norm``; otherwise
            apply ``norm`` to the first token.
        with_cp: Run blocks under activation checkpointing.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        head_drop_rate=0.0,
        norm_layer=nn.LayerNorm,
        init_values=0.0,
        use_learnable_pos_emb=False,
        init_scale=0.0,
        all_frames=16,
        tubelet_size=2,
        use_mean_pooling=True,
        with_cp=False,
        cos_attn=False,
    ):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.tubelet_size = tubelet_size
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            num_frames=all_frames,
            tubelet_size=tubelet_size,
        )
        num_patches = self.patch_embed.num_patches
        self.with_cp = with_cp

        if use_learnable_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        else:
            # Fixed sine-cosine table; kept as a plain tensor attribute and
            # moved to the input's device in forward_features.
            self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    init_values=init_values,
                    cos_attn=cos_attn,
                )
                for i in range(depth)
            ]
        )
        self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
        self.head_dropout = nn.Dropout(head_drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

        # With num_classes <= 0 the head is nn.Identity and has no parameters
        # to rescale, so only touch a real linear head.
        if isinstance(self.head, nn.Linear):
            self.head.weight.data.mul_(init_scale)
            self.head.bias.data.mul_(init_scale)
        self.num_frames = all_frames

    def _init_weights(self, m):
        """Truncated-normal init for linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        return {"pos_embed", "cls_token"}

    def get_classifier(self):
        """Return the classification head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=""):
        """Replace the head with a fresh one for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def interpolate_pos_encoding(self, t):
        """Return the position embedding resized along time for a ``t``-frame input.

        The stored table covers ``num_frames // tubelet_size`` temporal steps on
        the patch grid of ``patch_embed``. When the input has a different number
        of temporal steps the table is trilinearly resampled along time only.
        """
        base_steps = self.num_frames // self.tubelet_size
        target_steps = t // self.tubelet_size
        if base_steps == target_steps:
            return self.pos_embed

        dim = self.pos_embed.shape[-1]
        # Derive the token grid from the patch embedding instead of assuming
        # the 8 x 16 x 16 layout of one specific configuration.
        grid_h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]
        grid_w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]
        patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, base_steps, grid_h, grid_w)

        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        padded_steps = target_steps + 0.1
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(padded_steps / base_steps, 1, 1),
            mode="trilinear",
        )
        assert int(padded_steps) == patch_pos_embed.shape[-3]
        patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
        return patch_pos_embed

    def forward_features(self, x):
        """Embed ``x``, add position information, run the blocks and pool the tokens."""
        # [1, 3, 16, 224, 224]
        B = x.size(0)
        T = x.size(2)

        # [1, 2048, 1408]
        x = self.patch_embed(x)

        if self.pos_embed is not None:
            # The embedding is detached here, so it receives no gradient from this path.
            x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
        x = self.pos_drop(x)

        for blk in self.blocks:
            if self.with_cp:
                x = cp.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        if self.fc_norm is not None:
            return self.fc_norm(x.mean(1))
        else:
            return self.norm(x[:, 0])

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head_dropout(x)
        x = self.head(x)
        return x
def vit_giant_patch14_224(pretrained=False, **kwargs):
    """Create the giant (patch 14, 1408-d, 40 blocks, 16 heads) ``VisionTransformer``.

    ``pretrained`` is accepted but not used here; extra keyword arguments are
    forwarded to ``VisionTransformer``. The result carries ``default_cfg``.
    """
    giant_settings = dict(
        patch_size=14,
        embed_dim=1408,
        depth=40,
        num_heads=16,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    model = VisionTransformer(**giant_settings, **kwargs)
    model.default_cfg = _cfg()
    return model

View File

@@ -0,0 +1,469 @@
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from .videomaev2_finetune import (
Block,
PatchEmbed,
_cfg,
get_sinusoid_encoding_table,
)
from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
def trunc_normal_(tensor, mean=0., std=1.):
    """Call the finetune module's ``trunc_normal_`` with the bounds fixed to ``[-std, std]``."""
    lower_bound, upper_bound = -std, std
    __call_trunc_normal_(tensor, mean=mean, std=std, a=lower_bound, b=upper_bound)
class PretrainVisionTransformerEncoder(nn.Module):
    """Masked-autoencoder encoder: embeds a video and processes only the visible tokens.

    Args:
        img_size, patch_size, in_chans, embed_dim, all_frames, tubelet_size:
            Forwarded to ``PatchEmbed`` (``all_frames`` as ``num_frames``).
        num_classes: Output size of the head; ``<= 0`` uses ``nn.Identity``.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, cos_attn: Forwarded to every ``Block``.
        drop_rate, attn_drop_rate: Dropout rates inside the blocks.
        drop_path_rate: Maximum stochastic-depth rate (linearly increasing per block).
        norm_layer: Normalization factory.
        init_values: Initial residual scale for the blocks; ``None`` disables it.
        use_learnable_pos_emb: Learn the position embedding instead of using
            the fixed sinusoid table.
        with_cp: Run blocks under activation checkpointing.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=0,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=None,
                 tubelet_size=2,
                 use_learnable_pos_emb=False,
                 with_cp=False,
                 all_frames=16,
                 cos_attn=False):
        super().__init__()
        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            num_frames=all_frames,
            tubelet_size=tubelet_size)
        num_patches = self.patch_embed.num_patches
        self.with_cp = with_cp

        if use_learnable_pos_emb:
            # One row per patch token: forward_features adds this tensor to the
            # patch embeddings directly, so an extra row would not broadcast.
            self.pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
        else:
            # sine-cosine positional embeddings
            self.pos_embed = get_sinusoid_encoding_table(
                num_patches, embed_dim)

        # Block compares init_values with 0, which fails for None; 0.0 selects
        # the same "no residual scaling" branch.
        block_init_values = 0.0 if init_values is None else init_values

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=block_init_values,
                cos_attn=cos_attn) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if use_learnable_pos_emb:
            trunc_normal_(self.pos_embed, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, mask):
        """Embed ``x`` and encode the tokens where ``mask`` is False.

        ``mask`` is used as a boolean index over the embedded tokens
        (presumably shape ``[B, num_patches]`` — confirm against the caller);
        every sample must keep the same number of visible tokens for the
        reshape to succeed.
        """
        x = self.patch_embed(x)

        x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()

        B, _, C = x.shape
        x_vis = x[~mask].reshape(B, -1, C)  # ~mask means visible

        for blk in self.blocks:
            if self.with_cp:
                x_vis = cp.checkpoint(blk, x_vis, use_reentrant=False)
            else:
                x_vis = blk(x_vis)

        x_vis = self.norm(x_vis)
        return x_vis

    def forward(self, x, mask):
        x = self.forward_features(x, mask)
        x = self.head(x)
        return x
class PretrainVisionTransformerDecoder(nn.Module):
    """Masked-autoencoder decoder: a ``Block`` stack followed by a pixel-prediction head.

    Args:
        patch_size: Spatial patch size; with ``tubelet_size`` it fixes the
            required ``num_classes`` (``3 * tubelet_size * patch_size ** 2``).
        num_classes: Output size of the head (values predicted per token).
        embed_dim: Channel size of the decoder tokens.
        depth: Number of transformer blocks.
        num_heads, mlp_ratio, qkv_bias, qk_scale, cos_attn: Forwarded to every ``Block``.
        drop_rate, attn_drop_rate: Dropout rates inside the blocks.
        drop_path_rate: Maximum stochastic-depth rate (linearly increasing per block).
        norm_layer: Normalization factory.
        init_values: Initial residual scale for the blocks; ``None`` disables it.
        num_patches: Accepted for interface compatibility; not used here.
        with_cp: Run blocks under activation checkpointing.
    """

    def __init__(self,
                 patch_size=16,
                 num_classes=768,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 init_values=None,
                 num_patches=196,
                 tubelet_size=2,
                 with_cp=False,
                 cos_attn=False):
        super().__init__()
        self.num_classes = num_classes
        assert num_classes == 3 * tubelet_size * patch_size**2
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.with_cp = with_cp

        # Block compares init_values with 0, which fails for None; 0.0 selects
        # the same "no residual scaling" branch.
        block_init_values = 0.0 if init_values is None else init_values

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=block_init_values,
                cos_attn=cos_attn) for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform init for linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the head module."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the head with a fresh one for ``num_classes`` outputs."""
        self.num_classes = num_classes
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, return_token_num):
        """Run the blocks and project tokens through ``norm`` and ``head``.

        When ``return_token_num > 0`` only the last ``return_token_num`` tokens
        are projected; otherwise all tokens are.
        """
        for blk in self.blocks:
            if self.with_cp:
                x = cp.checkpoint(blk, x, use_reentrant=False)
            else:
                x = blk(x)

        if return_token_num > 0:
            # only return the mask tokens predict pixels
            x = self.head(self.norm(x[:, -return_token_num:]))
        else:
            # [B, N, 3*16^2]
            x = self.head(self.norm(x))
        return x
class PretrainVisionTransformer(nn.Module):
    """Masked video autoencoder: encoder on visible tokens, decoder predicting masked ones.

    Encoder-specific settings use the ``encoder_`` prefix and decoder-specific
    ones the ``decoder_`` prefix; the remaining block settings are shared.
    ``num_classes`` and ``in_chans`` are accepted only so that timm-style
    factories do not fail and are otherwise unused.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        encoder_in_chans=3,
        encoder_num_classes=0,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        decoder_num_classes=1536,  # decoder_num_classes=768
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=8,
        mlp_ratio=4.,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.,
        attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_layer=nn.LayerNorm,
        init_values=0.,
        use_learnable_pos_emb=False,
        tubelet_size=2,
        num_classes=0,  # avoid the error from create_fn in timm
        in_chans=0,  # avoid the error from create_fn in timm
        with_cp=False,
        all_frames=16,
        cos_attn=False,
    ):
        super().__init__()
        # Block compares init_values with 0, which fails for None; 0.0 selects
        # the same "no residual scaling" branch.
        if init_values is None:
            init_values = 0.
        self.encoder = PretrainVisionTransformerEncoder(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=encoder_in_chans,
            num_classes=encoder_num_classes,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
            use_learnable_pos_emb=use_learnable_pos_emb,
            with_cp=with_cp,
            all_frames=all_frames,
            cos_attn=cos_attn)

        self.decoder = PretrainVisionTransformerDecoder(
            patch_size=patch_size,
            num_patches=self.encoder.patch_embed.num_patches,
            num_classes=decoder_num_classes,
            embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
            with_cp=with_cp,
            cos_attn=cos_attn)

        self.encoder_to_decoder = nn.Linear(
            encoder_embed_dim, decoder_embed_dim, bias=False)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed sinusoid table in the decoder's channel size.
        self.pos_embed = get_sinusoid_encoding_table(
            self.encoder.patch_embed.num_patches, decoder_embed_dim)

        trunc_normal_(self.mask_token, std=.02)

    def _init_weights(self, m):
        """Xavier-uniform init for linear weights; zeros/ones for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_num_layers(self):
        """Return the number of encoder blocks.

        This module has no ``blocks`` attribute of its own, so the count is
        delegated to the encoder (the decoder depth is not included).
        """
        return self.encoder.get_num_layers()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token', 'mask_token'}

    def forward(self, x, mask, decode_mask=None):
        """Encode visible tokens and decode predictions for the selected masked tokens.

        Args:
            x: Input video batch passed to the encoder.
            mask: Boolean index over tokens; False marks visible tokens
                (presumably shape ``[B, num_patches]`` — confirm against the caller).
            decode_mask: Optional boolean index; tokens where it is False are
                decoded. When None, the tokens where ``mask`` is True are decoded.
        """
        decode_vis = mask if decode_mask is None else ~decode_mask

        x_vis = self.encoder(x, mask)  # [B, N_vis, C_e]
        x_vis = self.encoder_to_decoder(x_vis)  # [B, N_vis, C_d]
        B, N_vis, C = x_vis.shape

        # we don't unshuffle the correct visible token order,
        # but shuffle the pos embedding accordingly.
        expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
            x.device).clone().detach()
        pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
        pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)

        # [B, N, C_d]
        x_full = torch.cat(
            [x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
        # NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
        x = self.decoder(x_full, pos_emd_mask.shape[1])

        return x
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
    """Build the small pre-training model (encoder 384-d, 12 blocks, 6 heads; decoder 192-d, 3 heads).

    Args:
        pretrained: If True, load weights from the checkpoint at ``init_ckpt``.
        **kwargs: Extra ``PretrainVisionTransformer`` arguments. ``init_ckpt``
            (checkpoint path) is consumed here and not forwarded to the model.

    Raises:
        KeyError: If ``pretrained`` is True and ``init_ckpt`` was not given.
    """
    # ``init_ckpt`` is a loader option, not a model argument; forwarding it
    # would make the constructor fail with an unexpected keyword.
    has_ckpt = "init_ckpt" in kwargs
    init_ckpt = kwargs.pop("init_ckpt", None)
    if pretrained and not has_ckpt:
        raise KeyError("init_ckpt")

    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=384,
        encoder_depth=12,
        encoder_num_heads=6,
        encoder_num_classes=0,
        decoder_num_classes=1536,  # 16 * 16 * 3 * 2
        decoder_embed_dim=192,
        decoder_num_heads=3,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
    """Build the base pre-training model (encoder 768-d, 12 blocks, 12 heads; decoder 384-d, 6 heads).

    Args:
        pretrained: If True, load weights from the checkpoint at ``init_ckpt``.
        **kwargs: Extra ``PretrainVisionTransformer`` arguments. ``init_ckpt``
            (checkpoint path) is consumed here and not forwarded to the model.

    Raises:
        KeyError: If ``pretrained`` is True and ``init_ckpt`` was not given.
    """
    # ``init_ckpt`` is a loader option, not a model argument; forwarding it
    # would make the constructor fail with an unexpected keyword.
    has_ckpt = "init_ckpt" in kwargs
    init_ckpt = kwargs.pop("init_ckpt", None)
    if pretrained and not has_ckpt:
        raise KeyError("init_ckpt")

    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_num_classes=0,
        decoder_num_classes=1536,  # 16 * 16 * 3 * 2
        decoder_embed_dim=384,
        decoder_num_heads=6,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
    """Build the large pre-training model (encoder 1024-d, 24 blocks, 16 heads; decoder 512-d, 8 heads).

    Args:
        pretrained: If True, load weights from the checkpoint at ``init_ckpt``.
        **kwargs: Extra ``PretrainVisionTransformer`` arguments. ``init_ckpt``
            (checkpoint path) is consumed here and not forwarded to the model.

    Raises:
        KeyError: If ``pretrained`` is True and ``init_ckpt`` was not given.
    """
    # ``init_ckpt`` is a loader option, not a model argument; forwarding it
    # would make the constructor fail with an unexpected keyword.
    has_ckpt = "init_ckpt" in kwargs
    init_ckpt = kwargs.pop("init_ckpt", None)
    if pretrained and not has_ckpt:
        raise KeyError("init_ckpt")

    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_num_classes=0,
        decoder_num_classes=1536,  # 16 * 16 * 3 * 2
        decoder_embed_dim=512,
        decoder_num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
    """Build the huge pre-training model (encoder 1280-d, 32 blocks, 16 heads; decoder 512-d, 8 heads).

    Args:
        pretrained: If True, load weights from the checkpoint at ``init_ckpt``.
        **kwargs: Extra ``PretrainVisionTransformer`` arguments. ``init_ckpt``
            (checkpoint path) is consumed here and not forwarded to the model.

    Raises:
        KeyError: If ``pretrained`` is True and ``init_ckpt`` was not given.
    """
    # ``init_ckpt`` is a loader option, not a model argument; forwarding it
    # would make the constructor fail with an unexpected keyword.
    has_ckpt = "init_ckpt" in kwargs
    init_ckpt = kwargs.pop("init_ckpt", None)
    if pretrained and not has_ckpt:
        raise KeyError("init_ckpt")

    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=16,
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_num_classes=0,
        decoder_num_classes=1536,  # 16 * 16 * 3 * 2
        decoder_embed_dim=512,
        decoder_num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
    """Build the giant pre-training model (patch 14; encoder 1408-d, 40 blocks, 16 heads; decoder 512-d, 8 heads).

    Args:
        pretrained: If True, load weights from the checkpoint at ``init_ckpt``.
        **kwargs: Extra ``PretrainVisionTransformer`` arguments. ``init_ckpt``
            (checkpoint path) is consumed here and not forwarded to the model.

    Raises:
        KeyError: If ``pretrained`` is True and ``init_ckpt`` was not given.
    """
    # ``init_ckpt`` is a loader option, not a model argument; forwarding it
    # would make the constructor fail with an unexpected keyword.
    has_ckpt = "init_ckpt" in kwargs
    init_ckpt = kwargs.pop("init_ckpt", None)
    if pretrained and not has_ckpt:
        raise KeyError("init_ckpt")

    model = PretrainVisionTransformer(
        img_size=224,
        patch_size=14,
        encoder_embed_dim=1408,
        encoder_depth=40,
        encoder_num_heads=16,
        encoder_num_classes=0,
        decoder_num_classes=1176,  # 14 * 14 * 3 * 2,
        decoder_embed_dim=512,
        decoder_num_heads=8,
        mlp_ratio=48 / 11,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(init_ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model

View File

@@ -0,0 +1,321 @@
import os
import math
import os.path as osp
import random
import pickle
import warnings
import glob
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
                   batch_size=16, num_workers=8):
    """Create a ``VideoData`` wrapper from the arguments and return its dataloader."""
    video_data = VideoData(
        data_path=data_path,
        image_folder=image_folder,
        resolution=resolution,
        sequence_length=sequence_length,
        sample_every_n_frames=sample_every_n_frames,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return video_data._dataloader()
def is_image_file(filename):
    """Return True if ``filename`` ends with one of the suffixes in ``IMG_EXTENSIONS``."""
    for extension in IMG_EXTENSIONS:
        if filename.endswith(extension):
            return True
    return False
def get_parent_dir(path):
    """Return the name (last component) of the directory that contains ``path``."""
    containing_dir = osp.dirname(path)
    return osp.basename(containing_dir)
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
    """Convert a THWC video to a square, [0, 1]-scaled CTHW float tensor.

    Steps, in order: scale values by 1/255, keep the first ``sequence_length``
    frames (if given), keep every ``sample_every_n_frames``-th frame, resize so
    the shorter side equals ``resolution`` (antialiased bilinear), center-crop
    to ``resolution x resolution``.

    Returns:
        ``{'video': tensor}`` with the tensor laid out as CTHW.
    """
    # video: THWC, {0, ..., 255}
    assert in_channels == 3
    frames = video.permute(0, 3, 1, 2).float() / 255.  # TCHW
    num_frames, _, src_h, src_w = frames.shape

    # temporal crop
    if sequence_length is not None:
        assert sequence_length <= num_frames
        frames = frames[:sequence_length]

    # skip frames
    if sample_every_n_frames > 1:
        frames = frames[::sample_every_n_frames]

    # scale shorter side to resolution
    ratio = resolution / min(src_h, src_w)
    if src_h < src_w:
        out_h, out_w = resolution, math.ceil(src_w * ratio)
    else:
        out_h, out_w = math.ceil(src_h * ratio), resolution
    frames = F.interpolate(frames, size=(out_h, out_w), mode='bilinear',
                           align_corners=False, antialias=True)

    # center crop
    _, _, resized_h, resized_w = frames.shape
    left = (resized_w - resolution) // 2
    top = (resized_h - resolution) // 2
    frames = frames[:, :, top:top + resolution, left:left + resolution]

    return {'video': frames.permute(1, 0, 2, 3).contiguous()}  # CTHW
def preprocess_image(image):
    """Wrap a NumPy array as a torch tensor with ``torch.from_numpy``.

    The values are passed through unchanged; no rescaling is applied here.
    """
    return torch.from_numpy(image)
class VideoData(data.Dataset):
    """Factory that builds the dataset and dataloader for a video collection.

    Args:
        data_path: Path to the folder with video frames or videos.
        image_folder: If True, the data is stored as images in folders.
        resolution: Resolution of the returned videos.
        sequence_length: Length of extracted video sequences.
        sample_every_n_frames: Sample every n frames from the video.
        batch_size: Batch size.
        num_workers: Number of workers for the dataloader.
        shuffle: If True, shuffle the data.
    """

    def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
                 sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
        super().__init__()
        self.data_path = data_path
        self.image_folder = image_folder
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle

    def _dataset(self):
        """Build and return ``FrameDataset`` (image folders) or ``VideoDataset`` (video files)."""
        dataset_cls = FrameDataset if self.image_folder else VideoDataset
        return dataset_cls(
            self.data_path,
            self.sequence_length,
            resolution=self.resolution,
            sample_every_n_frames=self.sample_every_n_frames,
        )

    def _dataloader(self):
        """Build the dataset and wrap it in a ``DataLoader``.

        A ``DistributedSampler`` is attached when ``torch.distributed`` is
        initialized; loader-level shuffling is requested only without a sampler.
        """
        dataset = self._dataset()

        sampler = None
        if dist.is_initialized():
            sampler = data.distributed.DistributedSampler(
                dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
            )

        shuffle_in_loader = sampler is None and self.shuffle is True
        return data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle_in_loader,
        )
class VideoDataset(data.Dataset):
    """
    Generic dataset for video files stored in folders.
    Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
    The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
    Each item is the dict returned by `preprocess` for one randomly chosen clip.

    Args:
        data_folder: Path to the folder with corresponding videos stored.
        sequence_length: Length of extracted video sequences.
        resolution: Resolution of the returned videos.
        sample_every_n_frames: Sample every n frames from the video.
    """

    def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
        super().__init__()
        self.sequence_length = sequence_length
        self.resolution = resolution
        self.sample_every_n_frames = sample_every_n_frames

        folder = data_folder
        # Collect files extension by extension, keeping the VID_EXTENSIONS order.
        files = [
            path
            for ext in VID_EXTENSIONS
            for path in glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
        ]

        # NOTE: this silences warnings process-wide, not only for this dataset.
        warnings.filterwarnings('ignore')
        cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
        if not osp.exists(cache_file):
            clips = VideoClips(files, sequence_length, num_workers=4)
            # Caching is best-effort: a failure to write must not stop loading.
            try:
                with open(cache_file, 'wb') as cache_handle:
                    pickle.dump(clips.metadata, cache_handle)
            except Exception:
                print(f"Failed to save metadata to {cache_file}")
        else:
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # the cache is expected to be one written by this class.
            with open(cache_file, 'rb') as cache_handle:
                metadata = pickle.load(cache_handle)
            clips = VideoClips(files, sequence_length,
                               _precomputed_metadata=metadata)

        self._clips = clips
        # instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
        self._clips.get_clip_location = self.get_random_clip_from_video

    def get_random_clip_from_video(self, idx: int) -> tuple:
        '''
        Sample a random clip starting index from the video.

        Args:
            idx: Index of the video.

        Returns:
            ``(video_index, clip_index)`` for a video that has at least one clip.
        '''
        # Note that some videos may not contain enough frames, we skip those videos here.
        while self._clips.clips[idx].shape[0] <= 0:
            idx += 1
        n_clip = self._clips.clips[idx].shape[0]
        clip_id = random.randint(0, n_clip - 1)
        return idx, clip_id

    def __len__(self):
        return self._clips.num_videos()

    def __getitem__(self, idx):
        resolution = self.resolution
        # Keep trying subsequent indices until a clip can be decoded.
        while True:
            try:
                video, _, _, idx = self._clips.get_clip(idx)
            except Exception as e:
                print(idx, e)
                idx = (idx + 1) % self._clips.num_clips()
                continue
            break

        return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
class FrameDataset(data.Dataset):
    """
    Generic dataset for videos stored as images. The loader iterates over all the folders and subfolders
    in the provided directory. Each leaf folder is assumed to contain frames from a single video.

    Args:
        data_folder: path to the folder with video frames. The folder
            should contain folders with frames from each video.
        sequence_length: length of extracted video sequences
        resolution: resolution of the returned videos
        sample_every_n_frames: sample every n frames from the video
    """

    def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
        self.resolution = resolution
        self.sequence_length = sequence_length
        self.sample_every_n_frames = sample_every_n_frames
        self.data_all = self.load_video_frames(data_folder)
        self.video_num = len(self.data_all)

    def __getitem__(self, index):
        batch_data = self.getTensor(index)
        return_list = {'video': batch_data}

        return return_list

    def load_video_frames(self, dataroot: str) -> list:
        '''
        Collect, per directory under ``dataroot``, the sorted list of image paths.

        File names are ordered by the integer after the last underscore of the
        part before the first dot. Directories whose names cannot be parsed are
        reported and skipped; directories with fewer than
        ``sequence_length * sample_every_n_frames`` images are dropped.

        Args:
            dataroot: The root directory containing the video frames.
        Returns:
            A list with one list of frame paths per usable directory.
        '''
        data_all = []
        min_frames = max(0, self.sequence_length * self.sample_every_n_frames)
        for root, _, filenames in os.walk(dataroot):
            try:
                frames = sorted(filenames, key=lambda item: int(item.split('.')[0].split('_')[-1]))
            except ValueError:
                # Unparseable names: skip this directory instead of continuing
                # with an undefined (or the previous directory's) frame list.
                print(root, filenames)
                continue

            if len(frames) < min_frames:
                continue
            frames = [
                os.path.join(root, item) for item in frames
                if is_image_file(item)
            ]
            # A directory with exactly the required number of frames is usable;
            # empty directories are never kept.
            if frames and len(frames) >= min_frames:
                data_all.append(frames)

        return data_all

    def getTensor(self, index: int) -> torch.Tensor:
        '''
        Load a random window of frames from the video at the given index.

        Frames are center-cropped to a square (based on the first frame's
        size), resized to ``resolution`` and scaled by 1/255.

        Args:
            index: The index of the video to read.
        Returns:
            A CTHW float tensor of the selected frames.
        '''
        video = self.data_all[index]
        video_len = len(video)

        # load the entire video when sequence_length = -1, while the sample_every_n_frames has to be 1
        if self.sequence_length == -1:
            assert self.sample_every_n_frames == 1
            start_idx = 0
            end_idx = video_len
        else:
            n_frames_interval = self.sequence_length * self.sample_every_n_frames
            start_idx = random.randint(0, video_len - n_frames_interval)
            end_idx = start_idx + n_frames_interval

        with Image.open(video[0]) as first_frame:
            h, w = first_frame.height, first_frame.width

        cropsize = None
        if h > w:
            half = (h - w) // 2
            cropsize = (0, half, w, half + w)  # left, upper, right, lower
        elif w > h:
            half = (w - h) // 2
            cropsize = (half, 0, half + h, h)

        images = []
        for i in range(start_idx, end_idx, self.sample_every_n_frames):
            # Close each file as soon as its pixels have been copied out.
            with Image.open(video[i]) as raw:
                frame = raw.crop(cropsize) if cropsize is not None else raw
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                frame = frame.resize(
                    (self.resolution, self.resolution),
                    Image.LANCZOS)
                pixels = np.asarray(frame, dtype=np.float32)
            pixels = pixels / 255.
            img_tensor = preprocess_image(pixels).unsqueeze(0)
            images.append(img_tensor)

        video_clip = torch.cat(images).permute(3, 0, 1, 2)
        return video_clip

    def __len__(self):
        return self.video_num

View File

@@ -0,0 +1,161 @@
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import os
import random
import torch
import pickle
import numpy as np
from typing import List, Tuple
def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (including ``torch.cuda``) and set ``PYTHONHASHSEED``."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
class FeatureStats:
    '''
    Accumulator for 2-D feature batches and/or their running sums.
    Args:
        capture_all: Keep every appended batch so they can be concatenated later.
        capture_mean_cov: Maintain the running sums needed for mean and covariance.
        max_items: Maximum number of rows to accept; ``None`` means no limit.
    '''

    def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
        '''
        Store the configuration and start with empty statistics.
        '''
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        # The following are allocated lazily by set_num_features().
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features: int):
        '''
        Fix the feature dimensionality and allocate the accumulators on first use.
        Args:
            num_features: Number of feature dimensions.
        '''
        if self.num_features is None:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
            return
        # Later calls must agree with the dimensionality seen first.
        assert num_features == self.num_features

    def is_full(self) -> bool:
        '''
        Tell whether ``max_items`` rows have already been recorded.
        Returns:
            True if a limit is set and has been reached, False otherwise.
        '''
        if self.max_items is None:
            return False
        return self.num_items >= self.max_items

    def append(self, x: np.ndarray):
        '''
        Record a batch of features and update the running sums.
        Rows that would exceed ``max_items`` are dropped.
        Args:
            x: New features to record; must be 2-D (rows are items).
        '''
        batch = np.asarray(x, dtype=np.float32)
        assert batch.ndim == 2
        if self.max_items is not None:
            remaining = self.max_items - self.num_items
            if batch.shape[0] > remaining:
                if remaining <= 0:
                    return
                batch = batch[:remaining]
        self.set_num_features(batch.shape[1])
        self.num_items += batch.shape[0]
        if self.capture_all:
            self.all_features.append(batch)
        if self.capture_mean_cov:
            # Accumulate in float64 to limit round-off in the sums.
            batch64 = batch.astype(np.float64)
            self.raw_mean += batch64.sum(axis=0)
            self.raw_cov += batch64.T @ batch64

    def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
        '''
        Record a batch of PyTorch features, first collecting the batches of all ranks.
        Args:
            x: New features to record; must be a 2-D tensor.
            rank: Rank of the current GPU.
            num_gpus: Total number of GPUs.
        '''
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            per_rank = []
            for src in range(num_gpus):
                replica = x.clone()
                torch.distributed.broadcast(replica, src=src)
                per_rank.append(replica)
            x = torch.stack(per_rank, dim=1).flatten(0, 1)  # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self) -> np.ndarray:
        '''
        Get all the stored features as a NumPy array.
        Returns:
            Concatenation of the stored batches along axis 0.
        '''
        assert self.capture_all
        stored_batches = self.all_features
        return np.concatenate(stored_batches, axis=0)

    def get_all_torch(self) -> torch.Tensor:
        '''
        Get all the stored features as a PyTorch tensor.
        Returns:
            Concatenation of the stored batches.
        '''
        stacked = self.get_all()
        return torch.from_numpy(stacked)

    def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
        '''
        Derive mean and covariance from the running sums.
        Returns:
            ``(mean, cov)`` where ``cov`` is the mean outer product minus ``outer(mean, mean)``.
        '''
        assert self.capture_mean_cov
        count = self.num_items
        mean = self.raw_mean / count
        second_moment = self.raw_cov / count
        return mean, second_moment - np.outer(mean, mean)

    def save(self, pkl_file: str):
        '''
        Pickle this object's attribute dictionary to a file.
        Args:
            pkl_file: Path to the pickle file.
        '''
        with open(pkl_file, 'wb') as handle:
            pickle.dump(self.__dict__, handle)

    @staticmethod
    def load(pkl_file: str) -> 'FeatureStats':
        '''
        Rebuild a FeatureStats object from a file written by ``save``.
        Args:
            pkl_file: Path to the pickle file.
        '''
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(pkl_file, 'rb') as handle:
            state = pickle.load(handle)
        obj = FeatureStats(capture_all=state['capture_all'], max_items=state['max_items'])
        # Overwrite every attribute (including capture_mean_cov) with the saved state.
        obj.__dict__.update(state)
        print('Loaded %d features from %s' % (obj.num_items, pkl_file))
        return obj

View File

@@ -0,0 +1,145 @@
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
import numpy as np
import cv2
import torch
from einops import rearrange
import kornia
class AlignRestore(object):
    """Warp a face onto a fixed 3-point template and paste a processed face back.

    Args:
        align_points: Number of alignment landmarks; only ``3`` is supported.
        resolution: Scales the template; the template is defined for 256 and
            multiplied by ``resolution / 256 * 2.8``.
        device: Torch device used for the warps and masks.
        dtype: Torch dtype used for images, matrices and masks.

    Raises:
        ValueError: If ``align_points`` is not 3.
    """

    def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16):
        if align_points != 3:
            # No template exists for other point counts; previously this surfaced
            # later as an AttributeError on ``self.face_size``.
            raise ValueError(f"AlignRestore only supports align_points=3, got: {align_points}")
        self.upscale_factor = 1
        ratio = resolution / 256 * 2.8
        self.crop_ratio = (ratio, ratio)
        # Template positions of the three landmarks, scaled to the crop size.
        self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
        self.face_template = self.face_template * ratio
        # (width, height) of the aligned crop.
        self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
        # Smoothed translation bias carried over between consecutive calls.
        self.p_bias = None
        self.device = device
        self.dtype = dtype
        self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype)
        self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype)

    def align_warp_face(self, img, landmarks3, smooth=True):
        """Warp ``img`` so that ``landmarks3`` land on the face template.

        Args:
            img: NumPy image laid out as (h, w, c).
            landmarks3: Three (x, y) landmark points in ``img``.
            smooth: Whether to blend the translation bias with the previous call's.

        Returns:
            ``(cropped_face, affine_matrix)``: the crop as a uint8 (h, w, c) NumPy
            array and the affine matrix as a tensor with a leading batch axis.
        """
        affine_matrix, self.p_bias = self.transformation_from_points(
            landmarks3, self.face_template, smooth, self.p_bias
        )
        img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0)
        affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
        cropped_face = kornia.geometry.transform.warp_affine(
            img,
            affine_matrix,
            (self.face_size[1], self.face_size[0]),
            mode="bilinear",
            padding_mode="fill",
            fill_value=self.fill_value,
        )
        cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8)
        return cropped_face, affine_matrix

    def restore_img(self, input_img, face, affine_matrix):
        """Paste ``face`` back into ``input_img`` using the inverse of ``affine_matrix``.

        Args:
            input_img: NumPy image laid out as (h, w, c).
            face: Tensor holding the processed face; it is mapped with
                ``x / 2 + 0.5`` and clamped to [0, 1] before scaling to 0-255
                (presumably values in [-1, 1]; confirm against the caller).
            affine_matrix: Matrix returned by ``align_warp_face`` (tensor), or the
                same matrix as a NumPy array without the batch axis.

        Returns:
            The blended image as a uint8 (h, w, c) NumPy array.
        """
        h, w, _ = input_img.shape
        if isinstance(affine_matrix, np.ndarray):
            affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
        inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix)
        face = face.to(dtype=self.dtype).unsqueeze(0)
        inv_face = kornia.geometry.transform.warp_affine(
            face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value
        ).squeeze(0)
        inv_face = (inv_face / 2 + 0.5).clamp(0, 1) * 255
        input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w")
        inv_mask = kornia.geometry.transform.warp_affine(
            self.mask, inv_affine_matrix, (h, w), padding_mode="zeros"
        )  # (1, 1, h_up, w_up)
        inv_mask_erosion = kornia.morphology.erosion(
            inv_mask,
            torch.ones(
                (int(2 * self.upscale_factor), int(2 * self.upscale_factor)), device=self.device, dtype=self.dtype
            ),
        )
        inv_mask_erosion_t = inv_mask_erosion.squeeze(0).expand_as(inv_face)
        pasted_face = inv_mask_erosion_t * inv_face
        # Edge width of the blend is derived from the pasted face area.
        total_face_area = torch.sum(inv_mask_erosion.float())
        w_edge = int(total_face_area**0.5) // 20
        erosion_radius = w_edge * 2
        # This step will consume a large amount of GPU memory.
        # inv_mask_center = kornia.morphology.erosion(
        #     inv_mask_erosion, torch.ones((erosion_radius, erosion_radius), device=self.device, dtype=self.dtype)
        # )
        # Run on CPU to avoid consuming a large amount of GPU memory.
        inv_mask_erosion = inv_mask_erosion.squeeze().cpu().numpy().astype(np.float32)
        inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
        inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...]
        blur_size = w_edge * 2 + 1
        sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8
        inv_soft_mask = kornia.filters.gaussian_blur2d(
            inv_mask_center, (blur_size, blur_size), (sigma, sigma)
        ).squeeze(0)
        inv_soft_mask_3d = inv_soft_mask.expand_as(inv_face)
        # Soft-blend the pasted face over the original frame.
        img_back = inv_soft_mask_3d * pasted_face + (1 - inv_soft_mask_3d) * input_img
        img_back = rearrange(img_back, "c h w -> h w c").contiguous().to(dtype=torch.uint8)
        img_back = img_back.cpu().numpy()
        return img_back

    def transformation_from_points(self, points1: torch.Tensor, points0: torch.Tensor, smooth=True, p_bias=None):
        """Estimate the similarity transform mapping ``points1`` onto ``points0``.

        Args:
            points1: Source points, NumPy array or tensor with one (x, y) per row.
            points0: Target points in the same layout.
            smooth: When true, add a translation bias computed from the third
                point, blended 0.2 / 0.8 with ``p_bias`` if one is given.
            p_bias: Bias returned by the previous call, or ``None``.

        Returns:
            ``(M, p_bias)``: the 2x3 matrix as a NumPy array and the bias to pass
            to the next call (unchanged when ``smooth`` is false).
        """
        if isinstance(points0, np.ndarray):
            points2 = torch.tensor(points0, device=self.device, dtype=torch.float32)
        else:
            points2 = points0.clone()
        if isinstance(points1, np.ndarray):
            points1_tensor = torch.tensor(points1, device=self.device, dtype=torch.float32)
        else:
            points1_tensor = points1.clone()
        # Centre and scale both point sets before estimating the rotation.
        c1 = torch.mean(points1_tensor, dim=0)
        c2 = torch.mean(points2, dim=0)
        points1_centered = points1_tensor - c1
        points2_centered = points2 - c2
        s1 = torch.std(points1_centered)
        s2 = torch.std(points2_centered)
        points1_normalized = points1_centered / s1
        points2_normalized = points2_centered / s2
        covariance = torch.matmul(points1_normalized.T, points2_normalized)
        U, S, V = torch.svd(covariance.float())
        R = torch.matmul(V, U.T)
        det = torch.det(R.float())
        if det < 0:
            # Flip the last axis so R is a proper rotation, not a reflection.
            V[:, -1] = -V[:, -1]
            R = torch.matmul(V, U.T)
        sR = (s2 / s1) * R
        T = c2.reshape(2, 1) - (s2 / s1) * torch.matmul(R, c1.reshape(2, 1))
        M = torch.cat((sR, T), dim=1)
        if smooth:
            bias = points2_normalized[2] - points1_normalized[2]
            if p_bias is None:
                p_bias = bias
            else:
                bias = p_bias * 0.2 + bias * 0.8
            p_bias = bias
            M[:, 2] = M[:, 2] + bias
        return M.cpu().numpy(), p_bias

View File

@@ -0,0 +1,194 @@
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from omegaconf import OmegaConf
import torch
# Audio hyper-parameters are read once, at import time, from a path that is
# relative to the current working directory.
audio_config_path = "configs/audio.yaml"
config = OmegaConf.load(audio_config_path)
def load_wav(path, sr):
    """Load the audio at ``path`` with librosa at sample rate ``sr``; return only the samples."""
    samples = librosa.core.load(path, sr=sr)[0]
    return samples
def save_wav(wav, path, sr):
    """Peak-normalise ``wav`` to the int16 range and write it as a WAV file.

    The input is not modified; scaling happens on a copy.

    Args:
        wav: Audio samples (array-like).
        path: Destination file path.
        sr: Sample rate written to the file header.
    """
    samples = np.asarray(wav)
    # The 0.01 floor keeps near-silent input from being amplified without bound.
    scaled = samples * (32767 / max(0.01, np.max(np.abs(samples))))
    # proposed by @dsmiller
    wavfile.write(path, sr, scaled.astype(np.int16))
def save_wavenet_wav(wav, path, sr):
    """Write ``wav`` to ``path`` as a WAV file without rescaling.

    ``librosa.output.write_wav`` no longer exists in current librosa releases,
    so the samples are written with ``scipy.io.wavfile`` instead.

    Args:
        wav: Audio samples (array-like).
        path: Destination file path.
        sr: Sample rate written to the file header.
    """
    samples = np.asarray(wav)
    # scipy expects (n_samples, n_channels); transpose channel-first stereo.
    if samples.ndim > 1 and samples.shape[0] == 2:
        samples = samples.T
    wavfile.write(path, sr, samples)
def preemphasis(wav, k, preemphasize=True):
    """Apply the filter ``y[n] = x[n] - k * x[n-1]``; return ``wav`` untouched when disabled."""
    if not preemphasize:
        return wav
    numerator, denominator = [1, -k], [1]
    return signal.lfilter(numerator, denominator, wav)
def inv_preemphasis(wav, k, inv_preemphasize=True):
    """Apply the filter ``y[n] = x[n] + k * y[n-1]``; return ``wav`` untouched when disabled."""
    if not inv_preemphasize:
        return wav
    numerator, denominator = [1], [1, -k]
    return signal.lfilter(numerator, denominator, wav)
def get_hop_size():
    """Return the configured hop size, deriving it from ``frame_shift_ms`` when unset."""
    configured = config.audio.hop_size
    if configured is not None:
        return configured
    assert config.audio.frame_shift_ms is not None
    return int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
def linearspectrogram(wav):
    """Compute the dB-scaled magnitude spectrogram of ``wav``.

    The signal is pre-emphasised, transformed with ``_stft``, converted to dB and
    offset by ``ref_level_db``; it is normalised when the config enables it.
    """
    emphasized = preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize)
    magnitudes = np.abs(_stft(emphasized))
    level_db = _amp_to_db(magnitudes) - config.audio.ref_level_db
    if not config.audio.signal_normalization:
        return level_db
    return _normalize(level_db)
def melspectrogram(wav):
    """Compute the dB-scaled mel spectrogram of ``wav``.

    Same pipeline as ``linearspectrogram`` with the magnitudes projected onto
    the mel basis before the dB conversion.
    """
    emphasized = preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize)
    mel_magnitudes = _linear_to_mel(np.abs(_stft(emphasized)))
    level_db = _amp_to_db(mel_magnitudes) - config.audio.ref_level_db
    if not config.audio.signal_normalization:
        return level_db
    return _normalize(level_db)
def _lws_processor():
    """Build an ``lws`` processor from the audio config (``lws`` is imported lazily)."""
    import lws

    audio_cfg = config.audio
    return lws.lws(audio_cfg.n_fft, get_hop_size(), fftsize=audio_cfg.win_size, mode="speech")
def _stft(y):
    """Short-time Fourier transform of ``y`` using lws or librosa, as configured.

    Args:
        y: 1-D audio signal.

    Returns:
        The complex STFT; the lws result is transposed with ``.T``.
    """
    if config.audio.use_lws:
        # _lws_processor() reads the config itself and accepts no arguments.
        return _lws_processor().stft(y).T
    return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
##########################################################
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
    """Compute number of time frames of spectrogram"""
    pad = fsize - fshift
    whole_shifts = (length + pad * 2 - fsize) // fshift
    # One extra frame is needed when the length is not a multiple of the shift.
    extra = 1 if length % fshift == 0 else 2
    return whole_shifts + extra
def pad_lr(x, fsize, fshift):
    """Compute left and right padding"""
    n_samples = len(x)
    frame_count = num_frames(n_samples, fsize, fshift)
    pad = fsize - fshift
    padded_len = n_samples + 2 * pad
    # Extra right padding so the last frame is complete.
    remainder = (frame_count - 1) * fshift + fsize - padded_len
    return pad, pad + remainder
##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    """Return ``(0, right)`` where ``right`` pads ``x`` up to the next multiple of ``fshift``.

    ``fsize`` is accepted for signature parity and is not used.
    """
    n_samples = x.shape[0]
    target_len = (n_samples // fshift + 1) * fshift
    return 0, target_len - n_samples
# Conversions
# Mel filter bank, built on first use by _linear_to_mel().
_mel_basis = None


def _linear_to_mel(spectogram):
    """Project a linear spectrogram onto the cached mel basis."""
    global _mel_basis
    basis = _mel_basis
    if basis is None:
        basis = _mel_basis = _build_mel_basis()
    return np.dot(basis, spectogram)
def _build_mel_basis():
    """Create the mel filter bank described by the audio config."""
    audio_cfg = config.audio
    # fmax may not exceed half the sample rate.
    assert audio_cfg.fmax <= audio_cfg.sample_rate // 2
    mel_kwargs = dict(
        sr=audio_cfg.sample_rate,
        n_fft=audio_cfg.n_fft,
        n_mels=audio_cfg.num_mels,
        fmin=audio_cfg.fmin,
        fmax=audio_cfg.fmax,
    )
    return librosa.filters.mel(**mel_kwargs)
def _amp_to_db(x):
    """Convert amplitudes to decibels, flooring them at ``min_level_db``."""
    floor = np.exp(config.audio.min_level_db / 20 * np.log(10))
    floored = np.maximum(floor, x)
    return 20 * np.log10(floored)
def _db_to_amp(x):
    """Convert decibels to amplitudes: ``10 ** (x * 0.05)``."""
    exponent = (x) * 0.05
    return np.power(10.0, exponent)
def _normalize(S):
    """Map a dB spectrogram into the configured output range.

    With ``symmetric_mels`` the range is ``[-max_abs_value, max_abs_value]``,
    otherwise ``[0, max_abs_value]``. Values are clipped to that range when
    ``allow_clipping_in_normalization`` is set; otherwise the input is asserted
    to lie within ``[min_level_db, 0]``.
    """
    max_abs = config.audio.max_abs_value
    min_db = config.audio.min_level_db
    clipping = config.audio.allow_clipping_in_normalization
    if not clipping:
        assert S.max() <= 0 and S.min() - min_db >= 0
    unit = (S - min_db) / (-min_db)
    if config.audio.symmetric_mels:
        scaled = (2 * max_abs) * unit - max_abs
        lower = -max_abs
    else:
        scaled = max_abs * unit
        lower = 0
    if clipping:
        return np.clip(scaled, lower, max_abs)
    return scaled
def _denormalize(D):
    """Invert ``_normalize``: map normalised values back to decibels.

    When clipping is enabled the input is first clipped to the normalised range.
    """
    max_abs = config.audio.max_abs_value
    min_db = config.audio.min_level_db
    symmetric = config.audio.symmetric_mels
    if config.audio.allow_clipping_in_normalization:
        D = np.clip(D, -max_abs if symmetric else 0, max_abs)
    if symmetric:
        return ((D + max_abs) * -min_db / (2 * max_abs)) + min_db
    return (D * -min_db / max_abs) + min_db
def get_melspec_overlap(audio_samples, melspec_length=52):
    """Cut the mel spectrogram of ``audio_samples`` into overlapping windows.

    Windows are ``melspec_length`` frames wide and start every 3 frames.

    Args:
        audio_samples: Tensor of audio samples (converted with ``.numpy()``).
        melspec_length: Width of each window in spectrogram frames.

    Returns:
        The stacked windows, each carrying a leading singleton axis.
    """
    mel = torch.from_numpy(melspectrogram(audio_samples.numpy()))
    # A window starting at ``start`` is kept while start + length < frames - 3.
    last_start_bound = mel.shape[1] - 3 - melspec_length
    windows = [
        mel[:, start : start + melspec_length].unsqueeze(0)
        for start in range(0, last_start_bound, 3)
    ]
    return torch.stack(windows)

View File

@@ -0,0 +1,157 @@
# We modified the original AVReader class of decord to solve the problem of memory leak.
# For more details, refer to: https://github.com/dmlc/decord/issues/208
import numpy as np
from decord.video_reader import VideoReader
from decord.audio_reader import AudioReader
from decord.ndarray import cpu
from decord import ndarray as _nd
from decord.bridge import bridge_out
class AVReader(object):
    """Individual audio video reader with convenient indexing function.

    The video reader is rewound with ``seek(0)`` after every read; see the
    memory-leak note at the top of this file.

    Parameters
    ----------
    uri: str
        Path of file.
    ctx: decord.Context
        The context to decode the file, can be decord.cpu() or decord.gpu().
    sample_rate: int, default is 44100
        Desired output sample rate of the audio, unchanged if `-1` is specified.
    mono: bool, default is True
        Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
    width : int, default is -1
        Desired output width of the video, unchanged if `-1` is specified.
    height : int, default is -1
        Desired output height of the video, unchanged if `-1` is specified.
    num_threads : int, default is 0
        Number of decoding thread, auto if `0` is specified.
    fault_tol : int, default is -1
        The threshold of corupted and recovered frames. This is to prevent silent fault
        tolerance when for example 50% frames of a video cannot be decoded and duplicate
        frames are returned. You may find the fault tolerant feature sweet in many cases,
        but not for training models. Say `N = # recovered frames`
        If `fault_tol` < 0, nothing will happen.
        If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
        If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
    """

    def __init__(
        self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
    ):
        self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
        self.__audio_reader.add_padding()
        # A file-like object was consumed by the audio reader; rewind it for the video reader.
        if hasattr(uri, "read"):
            uri.seek(0)
        self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
        self.__video_reader.seek(0)

    def __len__(self):
        """Get length of the video. Note that sometimes FFMPEG reports inaccurate number of frames,
        we always follow what FFMPEG reports.

        Returns
        -------
        int
            The number of frames in the video file.
        """
        return len(self.__video_reader)

    def __getitem__(self, idx):
        """Get audio samples and video frame at `idx`.

        Parameters
        ----------
        idx : int or slice
            The frame index, can be negative which means it will index backwards,
            or slice of frame indices.

        Returns
        -------
        (ndarray/list of ndarray, ndarray)
            First element is samples of shape CxS or a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.
            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.

        Raises
        ------
        IndexError
            If an integer `idx` is out of range.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        if isinstance(idx, slice):
            return self.get_batch(range(*idx.indices(len(self.__video_reader))))
        if idx < 0:
            idx += len(self.__video_reader)
        if idx >= len(self.__video_reader) or idx < 0:
            raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
        audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
        audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
        audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
        results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
        self.__video_reader.seek(0)
        return results

    def get_batch(self, indices):
        """Get entire batch of audio samples and video frames.

        Parameters
        ----------
        indices : list of integers
            A list of frame indices. If negative indices detected, the indices will be indexed from backward

        Returns
        -------
        (list of ndarray, ndarray)
            First element is a list of length N containing samples of shape CxS,
            where N is the number of frames, C is the number of channels,
            S is the number of samples of the corresponding frame.
            Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
            where N is the length of the slice.
        """
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = self._validate_indices(indices)
        audio_arr = []
        prev_video_idx = None
        prev_audio_end_idx = None
        for idx in list(indices):
            frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
            # timestamp and sample conversion could have some error that could cause non-continuous audio
            # we detect if retrieving continuous frame and make the audio continuous
            # (compare against None: frame index 0 is a valid previous frame)
            if prev_video_idx is not None and idx == prev_video_idx + 1:
                audio_start_idx = prev_audio_end_idx
            else:
                audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
            audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
            audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
            prev_video_idx = idx
            prev_audio_end_idx = audio_end_idx
        results = (audio_arr, self.__video_reader.get_batch(indices))
        self.__video_reader.seek(0)
        return results

    def _get_slice(self, sl):
        """Return the audio of all frames in `sl` concatenated along axis 1, plus the frame batch."""
        audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
        for idx in list(sl):
            audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
            audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
            audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
            audio_arr = np.concatenate(
                (audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
            )
        results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
        self.__video_reader.seek(0)
        return results

    def _validate_indices(self, indices):
        """Validate int64 integers and convert negative integers to positive by backward search"""
        assert self.__video_reader is not None and self.__audio_reader is not None
        indices = np.array(indices, dtype=np.int64)
        # process negative indices
        indices[indices < 0] += len(self.__video_reader)
        if not (indices >= 0).all():
            # Subtract the length again so the message shows the indices as the caller passed them.
            raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] - len(self.__video_reader)))
        if not (indices < len(self.__video_reader)).all():
            raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
        return indices

View File

@@ -0,0 +1,115 @@
from insightface.app import FaceAnalysis
import numpy as np
import torch
# Side length of the square detection size handed to FaceAnalysis.prepare().
INSIGHTFACE_DETECT_SIZE = 512
class FaceDetector:
    """Find the largest acceptable face in a frame and return its box and 106 landmarks."""

    def __init__(self, device="cuda"):
        """Create the insightface ``FaceAnalysis`` app for the given CUDA device string."""
        detect_size = (INSIGHTFACE_DETECT_SIZE, INSIGHTFACE_DETECT_SIZE)
        self.app = FaceAnalysis(
            allowed_modules=["detection", "landmark_2d_106"],
            root="checkpoints/auxiliary",
            providers=["CUDAExecutionProvider"],
        )
        self.app.prepare(ctx_id=cuda_to_int(device), det_size=detect_size)

    def __call__(self, frame, threshold=0.5):
        """Detect the main face in ``frame``.

        Candidates narrower than 50 px, shorter than 80 px, with a width/height
        ratio outside [0.2, 1.5], or scoring below ``threshold`` are ignored; the
        largest remaining box wins.

        Args:
            frame: Image array whose shape is (height, width, channels).
            threshold: Minimum detection score.

        Returns:
            ``((x1, y1, x2, y2), landmarks)`` or ``(None, None)`` if no face qualifies.
        """
        frame_h, frame_w, _ = frame.shape
        detections = self.app.get(frame)
        if len(detections) == 0:
            return None, None

        best_face = None
        best_area = 0
        for candidate in detections:
            box = candidate.bbox.astype(np.int_).tolist()
            w, h = box[2] - box[0], box[3] - box[1]
            if w < 50 or h < 80:
                continue
            if w / h > 1.5 or w / h < 0.2:
                continue
            if candidate.det_score < threshold:
                continue
            area = w * h
            if area > best_area:
                best_area = area
                best_face = candidate
        if best_face is None:
            return None, None

        face = best_face
        lmk = np.round(face.landmark_2d_106).astype(np.int_)
        half_face_coord = np.mean([lmk[74], lmk[73]], axis=0)  # lmk[73]
        sub_lmk = lmk[LMK_ADAPT_ORIGIN_ORDER]
        # Mirror the lower-face extent above the reference point to get the top edge.
        half_face_dist = np.max(sub_lmk[:, 1]) - half_face_coord[1]
        upper_bound = half_face_coord[1] - half_face_dist  # *0.94
        x1 = np.min(sub_lmk[:, 0])
        y1 = int(upper_bound)
        x2 = np.max(sub_lmk[:, 0])
        y2 = np.max(sub_lmk[:, 1])
        # Fall back to the detector's box when the landmark box is degenerate.
        if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
            x1, y1, x2, y2 = face.bbox.astype(np.int_).tolist()
        # Pad the box; each step uses the already-updated coordinates.
        y2 += int((x2 - x1) * 0.1)
        x1 -= int((x2 - x1) * 0.05)
        x2 += int((x2 - x1) * 0.05)
        # Clamp to the frame.
        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(frame_w, x2)
        y2 = min(frame_h, y2)
        return (x1, y1, x2, y2), lmk
def cuda_to_int(cuda_str: str) -> int:
    """
    Return the device index X of a "cuda:X" string; a bare "cuda" maps to 0.

    Raises:
        ValueError: If the string names a non-CUDA device.
    """
    if cuda_str != "cuda":
        device = torch.device(cuda_str)
        if device.type != "cuda":
            raise ValueError(f"Device type must be 'cuda', got: {device.type}")
        return device.index
    return 0
LMK_ADAPT_ORIGIN_ORDER = [
1,
10,
12,
14,
16,
3,
5,
7,
0,
23,
21,
19,
32,
30,
28,
26,
17,
43,
48,
49,
51,
50,
102,
103,
104,
105,
101,
73,
74,
86,
]

View File

@@ -0,0 +1,122 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from latentsync.utils.util import read_video, write_video
from torchvision import transforms
import cv2
from einops import rearrange
import torch
import numpy as np
from typing import Union
from .affine_transform import AlignRestore
from .face_detector import FaceDetector
def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor:
    """Load the mask image, resize it to a square and return it channel-first.

    Args:
        resolution: Side length of the returned mask.
        mask_image_path: Path of the mask image file.

    Returns:
        Tensor laid out as (c, h, w) with values divided by 255.

    Raises:
        FileNotFoundError: If the image cannot be read from ``mask_image_path``.
    """
    mask_image = cv2.imread(mask_image_path)
    if mask_image is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f"Could not read mask image: {mask_image_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image
class ImageProcessor:
    """Resize, normalise, mask and (optionally) face-align image frames.

    Args:
        resolution: Side length of the square output images.
        device: ``"cpu"`` disables face detection; any other value creates a FaceDetector.
        mask_image: Mask to apply; when ``None`` the fixed mask is loaded from disk.
    """

    def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
        self.restorer = AlignRestore(resolution=resolution, device=device)
        self.mask_image = load_fixed_mask(resolution) if mask_image is None else mask_image
        self.face_detector = None if device == "cpu" else FaceDetector(device=device)

    def affine_transform(self, image: torch.Tensor) -> np.ndarray:
        """Detect the face in ``image`` and warp it onto the alignment template.

        Returns:
            ``(face, box, affine_matrix)``: the aligned face as a (c, h, w) tensor,
            the ``[x1, y1, x2, y2]`` extent of the crop before resizing, and the
            affine matrix used for the warp.

        Raises:
            NotImplementedError: If no face detector exists (CPU device).
            RuntimeError: If no face is detected.
        """
        detector = self.face_detector
        if detector is None:
            raise NotImplementedError("Using the CPU for face detection is not supported")
        bbox, landmark_2d_106 = detector(image)
        if bbox is None:
            raise RuntimeError("Face not detected")
        left_brow = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0)  # left eyebrow center
        right_brow = np.mean(landmark_2d_106[101:106], axis=0)  # right eyebrow center
        nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0)  # nose center
        landmarks3 = np.round([left_brow, right_brow, nose])
        face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
        box = [0, 0, face.shape[1], face.shape[0]]  # x1, y1, x2, y2
        target_size = (self.resolution, self.resolution)
        face = cv2.resize(face, target_size, interpolation=cv2.INTER_LANCZOS4)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Prepare one image: align or resize it, normalise it and apply the mask.

        Returns:
            ``(pixel_values, masked_pixel_values, mask)`` where ``mask`` is the first
            channel of the mask image.
        """
        if affine_transform:
            image = self.affine_transform(image)[0]
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Run ``preprocess_fixed_mask_image`` on every frame and stack the results.

        Returns:
            Stacked pixel values, masked pixel values and masks.
        """
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        # Frames given channel-last are moved to channel-first.
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
        results = []
        for image in images:
            results.append(self.preprocess_fixed_mask_image(image, affine_transform=affine_transform))
        pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize a batch of frames and normalise it; no masking is applied."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        # Frames given channel-last are moved to channel-first.
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
        resized = self.resize(images)
        return self.normalize(resized / 255.0)
class VideoProcessor:
    """Apply ImageProcessor's face alignment to every frame of a video file."""

    def __init__(self, resolution: int = 512, device: str = "cpu"):
        self.image_processor = ImageProcessor(resolution, device)

    def affine_transform_video(self, video_path):
        """Read ``video_path`` without changing its FPS and return the aligned frames.

        Returns:
            NumPy array laid out as (f, h, w, c).
        """
        aligned_frames = []
        for raw_frame in read_video(video_path, change_fps=False):
            aligned, _, _ = self.image_processor.affine_transform(raw_frame)
            aligned_frames.append(aligned)
        stacked = torch.stack(aligned_frames)
        return rearrange(stacked, "f c h w -> f h w c").numpy()
if __name__ == "__main__":
    # Manual check: align the demo video at resolution 256 on CUDA and write
    # the aligned frames to output.mp4 at 25 FPS.
    video_processor = VideoProcessor(256, "cuda")
    video_frames = video_processor.affine_transform_video("assets/demo2_video.mp4")
    write_video("output.mp4", video_frames, fps=25)

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -0,0 +1,289 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import json
from typing import Union
from pathlib import Path
import matplotlib.pyplot as plt
import imageio
import torch
import torch.nn as nn
import torchvision
import torch.distributed as dist
from torchvision import transforms
from einops import rearrange
import cv2
from decord import AudioReader, VideoReader
import shutil
import subprocess
# Machine epsilon for a float32 (single precision), as reported by np.finfo.
eps = np.finfo(np.float32).eps
def read_json(filepath: str):
    """Parse the JSON file at ``filepath`` and return the decoded object."""
    with open(filepath) as handle:
        return json.load(handle)
def read_video(video_path: str, change_fps=True, use_decord=True):
    """Read all frames of a video, optionally re-encoding it to 25 FPS first.

    When ``change_fps`` is true the ``temp`` directory in the current working
    directory is wiped and recreated, and ffmpeg writes ``temp/video.mp4``.

    Args:
        video_path: Path of the input video.
        change_fps: Re-encode to 25 FPS with ffmpeg before reading.
        use_decord: Read with decord when true, otherwise with OpenCV.

    Returns:
        The frames returned by ``read_video_decord`` or ``read_video_cv2``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    if change_fps:
        temp_dir = "temp"
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        os.makedirs(temp_dir, exist_ok=True)
        target_video_path = os.path.join(temp_dir, "video.mp4")
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters from being split or interpreted.
        command = [
            "ffmpeg",
            "-loglevel",
            "error",
            "-y",
            "-nostdin",
            "-i",
            str(video_path),
            "-r",
            "25",
            "-crf",
            "18",
            target_video_path,
        ]
        # Fail here rather than later on a missing or partial output file.
        subprocess.run(command, check=True)
    else:
        target_video_path = video_path
    if use_decord:
        return read_video_decord(target_video_path)
    return read_video_cv2(target_video_path)
def read_video_decord(video_path: str):
    """Decode every frame of ``video_path`` with decord and return them as a NumPy array."""
    reader = VideoReader(video_path)
    frames = reader[:].asnumpy()
    # Rewind the reader after the full read.
    reader.seek(0)
    return frames
def read_video_cv2(video_path: str):
    """Decode every frame of ``video_path`` with OpenCV, converted to RGB.

    Returns:
        Array of frames, or an empty array if the file cannot be opened.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        print("Error: Could not open video.")
        return np.array([])

    rgb_frames = []
    ok, bgr_frame = capture.read()
    # read() reports False once no further frame is available.
    while ok:
        rgb_frames.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
        ok, bgr_frame = capture.read()
    capture.release()
    return np.array(rgb_frames)
def read_audio(audio_path: str, audio_sample_rate: int = 16000):
    """Load a file's audio as a mono tensor at ``audio_sample_rate``.

    Raises:
        ValueError: If ``audio_path`` is ``None``.
    """
    if audio_path is None:
        raise ValueError("Audio path is required.")
    reader = AudioReader(audio_path, sample_rate=audio_sample_rate, mono=True)
    samples = torch.from_numpy(reader[:].asnumpy())
    # Drop the leading axis when it has size one.
    return samples.squeeze(0)
def write_video(video_output_path: str, video_frames: np.ndarray, fps: int):
    """Encode ``video_frames`` to ``video_output_path`` with imageio (libx264, CRF 13)."""
    writer_options = dict(
        fps=fps,
        codec="libx264",
        macro_block_size=None,
        ffmpeg_params=["-crf", "13"],
        ffmpeg_log_level="error",
    )
    with imageio.get_writer(video_output_path, **writer_options) as writer:
        for frame in video_frames:
            writer.append_data(frame)
def write_video_cv2(video_output_path: str, video_frames: np.ndarray, fps: int):
    """Encode RGB ``video_frames`` to ``video_output_path`` with OpenCV's ``mp4v`` codec.

    The frame size is taken from the first frame.
    """
    first_frame = video_frames[0]
    height, width = first_frame.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    # fourcc = cv2.VideoWriter_fourcc(*"vp09")
    out = cv2.VideoWriter(video_output_path, fourcc, fps, (width, height))
    for rgb_frame in video_frames:
        out.write(cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))
    out.release()
def init_dist(backend="nccl", **kwargs):
    """Initializes distributed environment.

    Reads the ``RANK`` environment variable, selects the CUDA device
    ``rank % device_count`` and initialises the process group.

    Returns:
        The local rank (index of the selected CUDA device).

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    global_rank = int(os.environ["RANK"])
    gpu_count = torch.cuda.device_count()
    if gpu_count == 0:
        raise RuntimeError("No GPUs available for training.")
    local_rank = global_rank % gpu_count
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, **kwargs)
    return local_rank
def zero_rank_print(s):
    """Print ``s`` prefixed with ``### `` only on rank 0 of an initialised process group."""
    if not dist.is_initialized() or dist.get_rank() != 0:
        return
    print("### " + s)
def zero_rank_log(logger, message: str):
    """Send ``message`` to ``logger.info`` only on rank 0 of an initialised process group."""
    if not dist.is_initialized() or dist.get_rank() != 0:
        return
    logger.info(message)
def check_video_fps(video_path: str):
    """Verify that the video at ``video_path`` reports exactly 25 FPS.

    Raises:
        ValueError: If the reported FPS is not 25.
    """
    cam = cv2.VideoCapture(video_path)
    try:
        fps = cam.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture handle even if the query fails.
        cam.release()
    if fps != 25:
        raise ValueError(f"Video FPS is not 25, it is {fps}. Please convert the video to 25 FPS.")
def one_step_sampling(ddim_scheduler, pred_noise, timesteps, x_t):
    """Estimate the original sample from ``x_t`` and the predicted noise in one step.

    Implements the "predicted x_0" of formula (12) from
    https://arxiv.org/abs/2010.02502 for the ``epsilon`` prediction type.

    Args:
        ddim_scheduler: Scheduler providing ``alphas_cumprod`` and ``config``.
        pred_noise: Predicted noise; its dtype is used for the coefficients.
        timesteps: Indices into ``alphas_cumprod``, one per batch element.
        x_t: Noisy sample; the coefficients are broadcast over four trailing axes.

    Returns:
        The predicted original sample, clamped to [-1, 1] if ``clip_sample`` is set.

    Raises:
        NotImplementedError: For any prediction type other than ``epsilon``.
    """
    alpha_bar = ddim_scheduler.alphas_cumprod[timesteps].to(dtype=pred_noise.dtype)
    if ddim_scheduler.config.prediction_type != "epsilon":
        raise NotImplementedError("This prediction type is not implemented yet")
    # Reshape to (batch, 1, 1, 1, 1) so the coefficients broadcast over x_t.
    alpha_bar = alpha_bar[:, None, None, None, None]
    beta_bar = 1 - alpha_bar
    x0_estimate = (x_t - beta_bar ** (0.5) * pred_noise) / alpha_bar ** (0.5)
    if ddim_scheduler.config.clip_sample:
        x0_estimate = torch.clamp(x0_estimate, -1, 1)
    return x0_estimate
def plot_loss_chart(save_path: str, *args):
    """Plot loss curves and save the figure to ``save_path``.

    Args:
        save_path: Output image path.
        *args: One sequence per curve: label at index 0, x values at index 1,
            y values at index 2.
    """
    plt.figure()
    for curve in args:
        label, steps, losses = curve[0], curve[1], curve[2]
        plt.plot(steps, losses, label=label)
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(save_path)
    # Close the figure to free memory
    plt.close()
# ANSI escape sequences: bright red foreground and attribute reset.
CRED = "\033[91m"
CEND = "\033[0m"


def red_text(text: str):
    """Wrap ``text`` in the ANSI codes that render it red."""
    return "".join((CRED, f"{text}", CEND))
# Element-wise binary cross-entropy used by cosine_loss().
log_loss = nn.BCELoss(reduction="none")


def cosine_loss(vision_embeds, audio_embeds, y):
    """Binary cross-entropy between the embeddings' cosine similarity and ``y``.

    Args:
        vision_embeds: Vision embeddings, one row per sample.
        audio_embeds: Audio embeddings in the same layout.
        y: Targets with a trailing singleton axis, matching ``sims.unsqueeze(1)``.

    Returns:
        The unreduced loss with singleton axes squeezed out.
    """
    similarity = nn.functional.cosine_similarity(vision_embeds, audio_embeds)
    # sims[sims!=sims] = 0 # remove nan
    # sims = sims.clamp(0, 1)
    per_sample = log_loss(similarity.unsqueeze(1), y)
    return per_sample.squeeze()
def save_image(image, save_path):
    """Save a (C, H, W) tensor as an image file.

    The tensor is mapped with ``x / 2 + 0.5``, clamped to [0, 1] and scaled to uint8.
    """
    unit_range = (image / 2 + 0.5).clamp(0, 1)
    as_uint8 = (unit_range * 255).to(torch.uint8)
    pil_image = transforms.ToPILImage()(as_uint8)
    pil_image.save(save_path)
    pil_image.close()
def gather_loss(loss, device):
    """Return the mean of a scalar loss across all distributed processes.

    Args:
        loss: Scalar tensor; read with ``.item()``.
        device: Device on which the reduction buffer is placed.

    Returns:
        The summed loss divided by ``dist.get_world_size()``, as a Python float.
    """
    total = torch.tensor(loss.item(), dtype=torch.float32).to(device)
    # In-place sum of every process's local value.
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    return total.item() / dist.get_world_size()
def gather_video_paths_recursively(input_dir):
    """Return every ``.mp4`` path found under ``input_dir``, descending into subdirectories.

    Prints a progress line, then delegates the walk to ``gather_video_paths``.
    """
    print(f"Recursively gathering video paths of {input_dir} ...")
    collected = []
    gather_video_paths(input_dir, collected)
    return collected
def gather_video_paths(input_dir, paths):
    """Append the path of every ``.mp4`` entry under ``input_dir`` to ``paths``.

    Entries are visited in sorted order. An entry whose name ends with ``.mp4`` is
    recorded without checking whether it is a file; any other entry that is a
    directory is walked recursively.

    Args:
        input_dir: Directory to scan.
        paths: List extended in place with the discovered paths.
    """
    for entry in sorted(os.listdir(input_dir)):
        full_path = os.path.join(input_dir, entry)
        if entry.endswith(".mp4"):
            paths.append(full_path)
            continue
        if os.path.isdir(full_path):
            gather_video_paths(full_path, paths)
def count_video_time(video_path):
    """Return the duration of a video computed as frame count divided by FPS.

    Both values are the ones OpenCV reports for the file.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        ``frame_count / fps``.

    Raises:
        ZeroDivisionError: If the reported FPS is 0.
    """
    video = cv2.VideoCapture(video_path)
    try:
        frame_count = video.get(cv2.CAP_PROP_FRAME_COUNT)
        fps = video.get(cv2.CAP_PROP_FPS)
    finally:
        # Free the capture handle deterministically instead of relying on garbage collection.
        video.release()
    return frame_count / fps
def check_ffmpeg_installed():
    """Ensure the ``ffmpeg`` command can be executed.

    Runs ``ffmpeg -version`` through the shell with its output captured and discarded.

    Raises:
        FileNotFoundError: If the command exits with a non-zero status.
    """
    probe = subprocess.run("ffmpeg -version", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if probe.returncode != 0:
        raise FileNotFoundError("ffmpeg not found, please install it by:\n    $ conda install -c conda-forge ffmpeg")
def check_model_and_download(ckpt_path: str, huggingface_model_id: str = "ByteDance/LatentSync-1.5"):
    """Download a checkpoint with ``huggingface-cli`` when it is missing locally.

    The first component of ``ckpt_path`` is used as ``--local-dir`` and the remaining
    components as the file name inside the Hugging Face repository.

    The command is passed as an argument list rather than an interpolated shell
    string, so paths containing spaces or shell metacharacters are handled correctly.
    The download stays best-effort: a failing or missing CLI does not raise.

    Args:
        ckpt_path: Expected local checkpoint path.
        huggingface_model_id: Repository to download from.
    """
    if os.path.exists(ckpt_path):
        return

    ckpt_parts = Path(ckpt_path).parts
    download_cmd = [
        "huggingface-cli",
        "download",
        huggingface_model_id,
        str(Path(*ckpt_parts[1:])),
        "--local-dir",
        str(Path(ckpt_parts[0])),
    ]
    try:
        subprocess.run(download_cmd)
    except FileNotFoundError:
        # Without a shell, a missing executable raises instead of returning a non-zero
        # status; report it and carry on, as the shell-based call did.
        print("huggingface-cli not found, cannot download the checkpoint automatically")
class dummy_context:
    """Context manager that does nothing on entry or exit.

    ``__enter__`` yields ``None`` and ``__exit__`` returns ``None``, so exceptions
    raised inside the ``with`` body propagate unchanged.
    """

    def __enter__(self):
        return None

    def __exit__(self, *args):
        return None

View File

@@ -0,0 +1,167 @@
# Adapted from https://github.com/TMElyralab/MuseTalk/blob/main/musetalk/whisper/audio2feature.py
from .whisper import load_model
import numpy as np
import torch
import os
from pathlib import Path
class Audio2Feature:
    """Turn audio into Whisper encoder embeddings and slice them per video frame.

    Feature indices are addressed at 50 steps per second (``vid_idx * 50 / fps``).
    """

    def __init__(
        self,
        model_path="checkpoints/whisper/tiny.pt",
        device=None,
        audio_embeds_cache_dir=None,
        num_frames=16,
        audio_feat_length=None,
    ):
        """Load the Whisper model and prepare the optional embedding cache directory.

        Args:
            model_path: Model name or checkpoint path forwarded to ``load_model``.
            device: Device forwarded to ``load_model``.
            audio_embeds_cache_dir: Directory for cached embeddings; ``None`` or ``""``
                disables caching. The directory is created when missing.
            num_frames: Number of video frames per window in ``crop_overlap_audio_window``.
            audio_feat_length: Two-element ``[left, right]`` context size around each
                frame. Defaults to ``[2, 2]``.
        """
        self.model = load_model(model_path, device)
        self.audio_embeds_cache_dir = audio_embeds_cache_dir
        if audio_embeds_cache_dir is not None and audio_embeds_cache_dir != "":
            Path(audio_embeds_cache_dir).mkdir(parents=True, exist_ok=True)
        self.num_frames = num_frames
        self.embedding_dim = self.model.dims.n_audio_state
        # A fresh list per instance avoids sharing one mutable default between instances.
        self.audio_feat_length = [2, 2] if audio_feat_length is None else audio_feat_length

    def get_sliced_feature(self, feature_array, vid_idx, fps=25):
        """Collect the contiguous feature window centred on one video frame.

        Args:
            feature_array: Indexable sequence of tensors (each element is passed to
                ``torch.cat``).
            vid_idx: Video frame index.
            fps: Video frame rate used to map ``vid_idx`` to a feature index.

        Returns:
            ``(selected_feature, selected_idx)``: the concatenated features reshaped to
            ``(-1, embedding_dim)`` and the clamped feature indices that were used.
        """
        last_idx = len(feature_array) - 1
        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - self.audio_feat_length[0] * 2
        right_idx = center_idx + (self.audio_feat_length[1] + 1) * 2

        # Indices outside the array are clamped, which repeats the first/last feature.
        selected_idx = [min(last_idx, max(0, idx)) for idx in range(left_idx, right_idx)]
        selected_feature = torch.cat([feature_array[idx] for idx in selected_idx], dim=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        return selected_feature, selected_idx

    def get_sliced_feature_sparse(self, feature_array, vid_idx, fps=25):
        """Collect a strided feature window: two features per neighbouring video frame.

        Args:
            feature_array: NumPy array of features (sliced and concatenated with NumPy).
            vid_idx: Video frame index.
            fps: Video frame rate used to map frame indices to feature indices.

        Returns:
            ``(selected_feature, selected_idx)``: a tensor reshaped to
            ``(-1, embedding_dim)`` and the feature indices that were used.
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        for dt in range(-self.audio_feat_length[0], self.audio_feat_length[1] + 1):
            left_idx = int((vid_idx + dt) * 50 / fps)
            if left_idx < 1 or left_idx > length - 1:
                # At the borders, clamp and duplicate the single boundary feature.
                left_idx = max(0, left_idx)
                left_idx = min(length - 1, left_idx)

                x = feature_array[left_idx]
                x = x[np.newaxis, :, :]
                x = np.repeat(x, 2, axis=0)
                selected_feature.append(x)
                selected_idx.append(left_idx)
                selected_idx.append(left_idx)
            else:
                x = feature_array[left_idx - 1 : left_idx + 1]
                selected_feature.append(x)
                selected_idx.append(left_idx - 1)
                selected_idx.append(left_idx)

        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, self.embedding_dim)
        selected_feature = torch.from_numpy(selected_feature)
        return selected_feature, selected_idx

    def feature2chunks(self, feature_array, fps):
        """Slice the whole feature array into one window per video frame.

        Frames are produced until the mapped feature index exceeds
        ``len(feature_array)``; the window for that final frame is still included.

        Args:
            feature_array: Features accepted by ``get_sliced_feature``.
            fps: Video frame rate.

        Returns:
            List of per-frame feature tensors.
        """
        whisper_chunks = []
        whisper_idx_multiplier = 50.0 / fps
        i = 0
        print(f"video in {fps} FPS, audio idx in 50FPS")

        while True:
            start_idx = int(i * whisper_idx_multiplier)
            selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i, fps=fps)
            whisper_chunks.append(selected_feature)
            i += 1
            if start_idx > len(feature_array):
                break

        return whisper_chunks

    def _audio2feat(self, audio_path: str):
        """Run the model's ``transcribe`` on the audio and concatenate encoder embeddings.

        Args:
            audio_path: Path forwarded to ``self.model.transcribe``.

        Returns:
            Tensor of the per-segment encoder embeddings concatenated along axis 0.
        """
        result = self.model.transcribe(audio_path)
        embed_list = []
        for emb in result["segments"]:
            encoder_embeddings = emb["encoder_embeddings"]
            encoder_embeddings = encoder_embeddings.transpose(0, 2, 1, 3)
            encoder_embeddings = encoder_embeddings.squeeze(0)
            start_idx = int(emb["start"])
            end_idx = int(emb["end"])
            # Keep half as many embeddings as the segment spans
            # (presumably start/end count input frames at twice the encoder rate - TODO confirm).
            emb_end_idx = int((end_idx - start_idx) / 2)
            embed_list.append(encoder_embeddings[:emb_end_idx])
        concatenated_array = torch.from_numpy(np.concatenate(embed_list, axis=0))
        return concatenated_array

    def _cache_path(self, audio_path):
        """Return the cache file path for ``audio_path`` inside the cache directory.

        The name is ``<basename without extension>_embeds.pt``, so the cache file never
        shares a name with the source audio, whatever its extension.
        """
        stem = os.path.splitext(os.path.basename(audio_path))[0]
        return os.path.join(self.audio_embeds_cache_dir, f"{stem}_embeds.pt")

    def audio2feat(self, audio_path):
        """Return the audio embeddings, using the on-disk cache when one is configured.

        A cache file that fails to load is reported, deleted and regenerated.

        Args:
            audio_path: Path of the audio (or video) file to encode.

        Returns:
            The embedding tensor from the cache or from ``_audio2feat``.
        """
        if self.audio_embeds_cache_dir == "" or self.audio_embeds_cache_dir is None:
            return self._audio2feat(audio_path)

        audio_embeds_cache_path = self._cache_path(audio_path)

        if os.path.isfile(audio_embeds_cache_path):
            try:
                return torch.load(audio_embeds_cache_path, weights_only=True)
            except Exception as e:
                # Any load failure is treated as a corrupt cache entry: drop it and recompute.
                print(f"{type(e).__name__} - {e} - {audio_embeds_cache_path}")
                os.remove(audio_embeds_cache_path)

        audio_feat = self._audio2feat(audio_path)
        torch.save(audio_feat, audio_embeds_cache_path)
        return audio_feat

    def crop_overlap_audio_window(self, audio_feat, start_index):
        """Stack the feature windows of ``num_frames`` consecutive frames at 25 FPS.

        Args:
            audio_feat: Features accepted by ``get_sliced_feature``.
            start_index: Index of the first video frame of the window.

        Returns:
            Tensor stacking one sliced feature per frame along a new first axis.
        """
        selected_feature_list = [
            self.get_sliced_feature(feature_array=audio_feat, vid_idx=i, fps=25)[0]
            for i in range(start_index, start_index + self.num_frames)
        ]
        mel_overlap = torch.stack(selected_feature_list)
        return mel_overlap
if __name__ == "__main__":
    # Demo: encode one audio file and print the feature window chosen for each video frame.
    audio_encoder = Audio2Feature(model_path="checkpoints/whisper/tiny.pt")
    audio_path = "assets/demo1_audio.wav"
    array = audio_encoder.audio2feat(audio_path)
    print(array.shape)

    fps = 25
    whisper_idx_multiplier = 50.0 / fps
    print(f"video in {fps} FPS, audio idx in 50FPS")

    i = 0
    start_idx = 0
    # The body always runs once; it stops after the first frame whose feature index
    # lies beyond the end of the array.
    while start_idx <= len(array):
        start_idx = int(i * whisper_idx_multiplier)
        selected_feature, selected_idx = audio_encoder.get_sliced_feature(feature_array=array, vid_idx=i, fps=fps)
        print(f"video idx {i},\t audio idx {selected_idx},\t shape {selected_feature.shape}")
        i += 1

View File

@@ -0,0 +1,122 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import Whisper, ModelDimensions
from .transcribe import transcribe
# Maps each official model name to its download URL. The second-to-last path segment
# of every URL is the hex SHA-256 digest that `_download` checks the file against.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
    """Return the official model names, in the order they are registered in ``_MODELS``."""
    return [model_name for model_name in _MODELS]
def load_model(
    name: str, device: Optional[Union[str, torch.device]] = None, download_root: str = None, in_memory: bool = False
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into; defaults to "cuda" when available, else "cpu"
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
        (or "$XDG_CACHE_HOME/whisper" when XDG_CACHE_HOME is set)
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance

    Raises
    ------
    RuntimeError
        if `name` is neither an official model name nor an existing file
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        # Always use a dedicated "whisper" subdirectory, including when XDG_CACHE_HOME is
        # set, so model files are not written directly into the shared cache root.
        cache_home = os.getenv("XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache"))
        download_root = os.path.join(cache_home, "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
    elif os.path.isfile(name):
        if in_memory:
            with open(name, "rb") as raw:
                checkpoint_file = raw.read()
        else:
            checkpoint_file = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp:
        checkpoint = torch.load(fp, map_location=device, weights_only=True)
    # Drop the (possibly large) in-memory copy as soon as it has been deserialized.
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    del checkpoint
    torch.cuda.empty_cache()

    return model.to(device)

View File

@@ -0,0 +1,4 @@
# Import the `cli` callable from the sibling `transcribe` module and invoke it
# as soon as this module is executed.
from .transcribe import cli

cli()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@@ -0,0 +1 @@
{"unk_token": "<|endoftext|>", "bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "gpt2", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1 @@
{"<|endoftext|>": 50257}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
{"bos_token": "<|endoftext|>", "eos_token": "<|endoftext|>", "unk_token": "<|endoftext|>"}

View File

@@ -0,0 +1 @@
{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "add_prefix_space": false, "model_max_length": 1024, "special_tokens_map_file": null, "name_or_path": "multilingual", "errors": "replace", "tokenizer_class": "GPT2Tokenizer"}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,125 @@
import os
from functools import lru_cache
from typing import Union
import ffmpeg
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000  # default target sample rate (Hz) for load_audio
N_FFT = 400  # FFT size and Hann window length passed to torch.stft
N_MELS = 80  # default number of mel filters; mel_filters() asserts this value
HOP_LENGTH = 160  # hop between STFT frames, in samples
CHUNK_LENGTH = 30  # chunk duration in seconds (N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE)
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000: number of samples in a chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000: number of frames in a mel spectrogram input
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Decode an audio file into a mono float32 waveform at the requested sample rate.

    Decoding, down-mixing and resampling are delegated to an ``ffmpeg`` subprocess,
    so both the ffmpeg CLI and the `ffmpeg-python` package must be installed.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.

    Raises
    ------
    RuntimeError
        If ffmpeg fails; the message carries ffmpeg's stderr output.
    """
    try:
        # Ask ffmpeg for raw signed 16-bit little-endian PCM, one channel, on stdout.
        pcm_bytes, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    samples = np.frombuffer(pcm_bytes, np.int16).flatten()
    # Scale int16 samples into [-1, 1).
    return samples.astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.

    Parameters
    ----------
    array
        A torch Tensor or NumPy array.

    length: int
        Target size along `axis`.

    axis: int
        Axis to trim or zero-pad (padding is appended at the end).

    Returns
    -------
    An array of the same kind as the input whose size along `axis` equals `length`.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            # The index must live on the same device as the tensor being indexed;
            # a CPU index would fail for tensors on any other device.
            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # F.pad expects a flat list of pads ordered from the last dimension backwards.
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor:
    """
    Load the mel filterbank matrix used to project an STFT onto a Mel spectrogram.

    The matrix is read from ``assets/mel_filters.npz`` next to this module, which
    removes the runtime dependency on librosa. The file was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )

    Results are memoized per ``(device, n_mels)`` by ``lru_cache``.
    """
    assert n_mels == 80, f"Unsupported n_mels: {n_mels}"

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    filters_path = os.path.join(assets_dir, "mel_filters.npz")
    with np.load(filters_path) as archive:
        filterbank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(filterbank).to(device)
def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):
    """
    Compute the normalized log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if isinstance(audio, str):
        audio = load_audio(audio)
    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)

    hann = torch.hann_window(N_FFT).to(audio.device)
    spectrum = torch.stft(audio, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Power spectrum, with the final STFT frame dropped.
    power = spectrum[:, :-1].abs() ** 2

    mel_spec = mel_filters(audio.device, n_mels) @ power

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 (log10 units) below the peak, then rescale.
    floor = log_spec.max() - 8.0
    log_spec = torch.maximum(log_spec, floor)
    return (log_spec + 4.0) / 4.0

View File

@@ -0,0 +1,729 @@
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Iterable, Optional, Sequence, Union, TYPE_CHECKING
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None) -> Tuple[Tensor, List[dict]]:
    """
    Identify the spoken language of each audio input.

    Runs one decoder step on the start-of-transcript token and keeps only the logits
    of language tokens. This happens outside the main decode loop so that it does not
    interfere with kv-caching.

    Parameters
    ----------
    model : Whisper
        model providing `encoder`, `logits`, `dims` and `is_multilingual`
    mel : Tensor
        a single (2-D) or batched input; inputs whose last two dimensions already equal
        (n_audio_ctx, n_audio_state) are treated as encoded features and skip the encoder
    tokenizer : Tokenizer
        tokenizer to use; built from `model.is_multilingual` when None

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.

    For a single (unbatched) input both results are returned without the batch dimension.

    Raises
    ------
    ValueError
        if the tokenizer has no language token in its start-of-transcript sequence
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual)

    has_language_token = tokenizer.language is not None and tokenizer.language_token in tokenizer.sot_sequence
    if not has_language_token:
        raise ValueError(f"This model doesn't have language tokens so it can't perform lang id")

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # Only run the encoder when the input is not already a batch of encoded features.
    encoded_shape = (model.dims.n_audio_ctx, model.dims.n_audio_state)
    if mel.shape[-2:] != encoded_shape:
        mel = model.encoder(mel)

    # One forward pass over a single startoftranscript token per audio.
    n_audio = mel.shape[0]
    sot_tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(sot_tokens, mel)[:, 0]

    # Suppress every token that is not a language token.
    non_language = torch.ones(logits.shape[-1], dtype=torch.bool)
    non_language[list(tokenizer.all_language_tokens)] = False
    logits[:, non_language] = -np.inf

    language_tokens = logits.argmax(dim=-1)
    token_probs = logits.softmax(dim=-1).cpu()

    language_probs = []
    for i in range(n_audio):
        per_language = {}
        for token_id, code in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes):
            per_language[code] = token_probs[i, token_id].item()
        language_probs.append(per_language)

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable bundle of decoding settings."""

    # Whether to perform X->X "transcribe" or X->English "translate".
    task: str = "transcribe"
    # Language that the audio is in; uses detected language if None.
    language: Optional[str] = None

    # --- sampling-related options ---
    temperature: float = 0.0
    # Maximum number of tokens to sample.
    sample_len: Optional[int] = None
    # Number of independent samples to collect, when t > 0.
    best_of: Optional[int] = None
    # Number of beams in beam search, when t == 0.
    beam_size: Optional[int] = None
    # Patience in beam search (https://arxiv.org/abs/2204.05424).
    patience: Optional[float] = None

    # --- options for ranking generations (either beams or best-of-N samples) ---
    # "alpha" in Google NMT, None defaults to length norm.
    length_penalty: Optional[float] = None

    # --- prompt, prefix, and token suppression ---
    # Text or tokens for the previous context.
    prompt: Optional[Union[str, List[int]]] = None
    # Text or tokens to prefix the current context.
    prefix: Optional[Union[str, List[int]]] = None
    # This will suppress blank outputs.
    suppress_blank: bool = True
    # List of tokens ids (or comma-separated token ids) to suppress;
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`.
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"

    # --- timestamp sampling options ---
    # Use <|notimestamps|> to sample text tokens only.
    without_timestamps: bool = False
    # The initial timestamp cannot be later than this.
    max_initial_timestamp: Optional[float] = 1.0

    # --- implementation details ---
    # Use fp16 for most of the calculation.
    fp16: bool = True
@dataclass(frozen=True)
class DecodingResult:
    """Immutable record describing the outcome of decoding one audio input.

    The first four fields are required; the numeric statistics default to NaN,
    ``tokens`` to a fresh empty list and ``text`` to the empty string.
    """

    # Required fields.
    audio_features: Tensor
    language: str
    encoder_embeddings: np.ndarray
    decoder_embeddings: np.ndarray

    # Optional fields.
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""

    # Statistics; NaN until filled in.
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
class Inference:
    """Interface for the decoder forward pass used during decoding."""

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Run the decoder forward pass and return per-token logits."""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Reorder the key-value cache to follow the updated beams."""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Release any resources or hooks once decoding is finished (no-op by default)."""
        return None
class PyTorchInference(Inference):
    """Decoder forward pass backed by the model's key-value cache hooks."""

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

    def logits(self, tokens: Tensor, audio_features: Tensor, include_embeddings=False) -> Tensor:
        """Call the model's decoder, installing the kv-cache hooks on first use.

        Returns whatever ``self.model.decoder`` returns for the given arguments.
        """
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        # After the first forward pass only the newest token has to be fed.
        beyond_first_pass = tokens.shape[-1] > self.initial_token_length
        if beyond_first_pass:
            tokens = tokens[:, -1:]

        return self.model.decoder(
            tokens, audio_features, kv_cache=self.kv_cache, include_embeddings=include_embeddings
        )

    def cleanup_caching(self):
        """Remove every installed hook and reset the cache."""
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Reindex each cached tensor so it holds the selected sequences."""
        for module in list(self.kv_cache):
            self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
class SequenceRanker:
    """Interface for choosing one sample out of each group of candidates."""

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
        """
        Pick, for every group, the index of the sample to return as the final result.

        Both arguments are lists of groups: the sampled token sequences and their
        cumulative log probabilities.
        """
        raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Rank samples by cumulative log probability divided by a length penalty: the plain
    sequence length, or the Google NMT paper's penalty when `length_penalty` is set.
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        """Return, for each group, the index of the sample with the highest score."""

        def penalty_for(length):
            if self.length_penalty is None:
                return length
            # from the Google NMT paper
            return ((5 + length) / 6) ** self.length_penalty

        best_indices = []
        for group_logprobs, group_tokens in zip(sum_logprobs, tokens):
            group_scores = [
                logprob / penalty_for(len(sample)) for logprob, sample in zip(group_logprobs, group_tokens)
            ]
            best_indices.append(np.argmax(group_scores))
        return best_indices
class TokenDecoder:
    """Interface for strategies that choose the next token at each decoding step."""

    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""
        return None

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Select the next token for every sequence from the current trace and logits.

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences has reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finish the search and return the final candidate sequences.

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError
class GreedyDecoder(TokenDecoder):
    """Pick the arg-max token at temperature 0, otherwise sample from the scaled logits."""

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Append one token per sequence and accumulate its log probability in place.

        Sequences whose last token is already EOT keep emitting EOT and add nothing
        to ``sum_logprobs``.
        """
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        batch_rows = torch.arange(logprobs.shape[0])
        current_logprobs = logprobs[batch_rows, next_tokens]

        already_finished = tokens[:, -1] == self.eot
        # In-place update: the caller's tensor is modified.
        sum_logprobs += current_logprobs * ~already_finished
        next_tokens[already_finished] = self.eot

        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        """Pad one EOT onto every sequence and return the log probabilities as a list."""
        # make sure each sequence has at least one EOT token at the end
        padded = F.pad(tokens, (0, 1), value=self.eot)
        return padded, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
    """Beam-search token decoder.

    Keeps ``beam_size`` live sequences per audio and collects sequences ending in EOT
    until ``max_candidates`` (``round(beam_size * patience)``) are available per audio.
    """

    def __init__(self, beam_size: int, eot: int, inference: Inference, patience: Optional[float] = None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        # A missing (or zero) patience falls back to 1.0.
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        # Per-audio dict mapping a finished token tuple to its score; created on first update().
        self.finished_sequences = None

        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        """Forget the finished sequences collected for the previous input."""
        self.finished_sequences = None

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Advance every beam by one token.

        ``tokens`` holds ``n_audio * beam_size`` rows. ``sum_logprobs`` is overwritten
        in place with the scores of the surviving beams, and the inference kv cache is
        rearranged to match them.

        Returns the new token tensor and whether every audio has collected at least
        ``max_candidates`` finished sequences.

        Raises ValueError when the batch size is not a multiple of ``beam_size``.
        """
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                # Consider the beam_size + 1 most likely continuations of this beam.
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    # The row index in the new batch is the number of beams kept so far.
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        """Return the finished sequences and their scores for each audio.

        When fewer than ``beam_size`` sequences finished for an audio, the unfinished
        beams are added in descending score order with EOT appended until ``beam_size``
        is reached.
        """
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
class LogitFilter:
    """Interface for rules that mask or adjust logits before a token is chosen."""

    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Filter or mask ``logits`` in place.

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError
class SuppressBlank(LogitFilter):
    """Forbid a blank (" ") or EOT as the very first sampled token."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        """Mask the blank and EOT logits, but only at the first sampling position."""
        if tokens.shape[1] != self.sample_begin:
            return
        blocked = self.tokenizer.encode(" ") + [self.tokenizer.eot]
        logits[:, blocked] = -np.inf
class SuppressTokens(LogitFilter):
    """Unconditionally mask a fixed set of token ids at every step."""

    def __init__(self, suppress_tokens: Sequence[int]):
        # Materialize once so the ids can be used directly as an index.
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        """Set the logits of all suppressed ids to negative infinity, in place."""
        blocked = self.suppress_tokens
        logits[:, blocked] = -np.inf
class ApplyTimestampRules(LogitFilter):
    """Logit filter enforcing the timestamp-token rules while sampling with timestamps.

    Token ids at or above ``tokenizer.timestamp_begin`` are treated as timestamps.
    """

    def __init__(
        self, tokenizer: Tokenizer, sample_begin: int, max_initial_timestamp_index: Optional[int]
    ):
        self.tokenizer = tokenizer
        # Position in the token sequence where sampled tokens start.
        self.sample_begin = sample_begin
        # Largest allowed offset from timestamp_begin for the first token; None disables the limit.
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        """Mask ``logits`` in place according to the timestamp rules."""
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            seq = [t for t in tokens[k, self.sample_begin :].tolist()]
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            # With fewer than two sampled tokens the penultimate one counts as a timestamp.
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

        # apply the `max_initial_timestamp` option
        if tokens.shape[1] == self.sample_begin and self.max_initial_timestamp_index is not None:
            last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
            logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
    """
    A single decoding job: binds a model and a set of `DecodingOptions` to the tokenizer,
    the inference wrapper, the sequence ranker, the token decoder and the logit filters,
    and exposes `run()` to decode a batch of mel spectrograms.
    """

    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        """
        Parameters
        ----------
        model: Whisper
            the model used for encoding and decoding
        options: DecodingOptions
            decoding options; validated by `_verify_options`

        Raises
        ------
        ValueError
            if the options are mutually inconsistent (see `_verify_options`)
        """
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, language=language, task=options.task)
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        # number of candidate sequences kept per audio (beam width or best-of-n)
        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
            self.logit_filters.append(
                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        """Return `options` unchanged, raising ValueError on inconsistent combinations."""
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        """Build the token prefix fed to the decoder: [sot_prev + prompt] + sot sequence + prefix."""
        tokens = list(self.sot_sequence)
        prefix = self.options.prefix
        prompt = self.options.prompt

        if prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            )
            # the prompt goes *before* the sot sequence, introduced by <|startofprev|>
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        """
        Resolve `options.suppress_tokens` into a sorted, de-duplicated tuple of token ids.

        A comma-separated string is parsed into ints; the value -1 expands to the
        tokenizer's non-speech tokens. The start-of-transcript style special tokens
        (and the no-speech token, when present) are always added.
        """
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
            # work on a copy: extending the list stored on the options object would leak
            # the special tokens into every later use of the same options
            suppress_tokens = list(suppress_tokens)

        suppress_tokens.extend(
            [self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor, include_embeddings: bool = False):
        """
        Run the audio encoder on `mel`, or pass `mel` through if it is already encoded.

        Returns
        -------
        `audio_features`, or `(audio_features, embeddings)` when `include_embeddings` is set.
        `embeddings` is None when `mel` was already encoded, since no encoder pass is run.

        Raises
        ------
        TypeError
            if the features' dtype does not match the `fp16` option
        """
        if self.options.fp16:
            mel = mel.half()

        embeddings = None
        if mel.shape[-2:] == (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            result = self.model.encoder(mel, include_embeddings)
            if include_embeddings:
                audio_features, embeddings = result
            else:
                audio_features = result

        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
            # raise (not return) so the mismatch surfaces here instead of as an obscure
            # failure when the caller tries to use the exception object as a tensor
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        if include_embeddings:
            return audio_features, embeddings
        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        """
        Return `(languages, lang_probs)`; when no language was specified, also write the
        detected language token into `tokens` in place (right after the sot token).
        """
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        """
        Autoregressively sample up to `sample_len` tokens.

        Returns
        -------
        `(tokens, sum_logprobs, no_speech_probs, embeddings)` where `embeddings` stacks the
        per-step decoder embeddings of the last position along axis 2. The embeddings of the
        step on which decoding completed are dropped.

        Raises
        ------
        ValueError
            from `np.stack` if no embeddings remain to be stacked
        """
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        # bound before the loop so they exist even if the loop body never runs or fails early
        completed = False
        embeddings = []

        try:
            for i in range(self.sample_len):
                logits, token_embeddings = self.inference.logits(
                    tokens, audio_features, include_embeddings=True
                )

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]
                token_embeddings = token_embeddings[:, :, -1]

                # collect the embeddings of every step
                embeddings.append(token_embeddings)

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            # only cleanup belongs here: anything that can raise would mask the original
            # exception and leave the kv-cache hooks installed
            self.inference.cleanup_caching()

        if completed:
            embeddings = embeddings[:-1]
        embeddings = np.stack(embeddings, 2)

        return tokens, sum_logprobs, no_speech_probs, embeddings

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        """
        Decode a batch of mel spectrograms and return one `DecodingResult` per audio.

        For the "lang_id" task only language detection is performed. Note that the same
        `encoder_embeddings` / `decoder_embeddings` objects are attached to every result
        of the batch; they are not sliced per audio.
        """
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        # encoder forward pass
        forward_pass: Tuple[Tensor, np.ndarray] = self._get_audio_features(
            mel, include_embeddings=True
        )
        audio_features, encoder_embeddings = forward_pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs, decoder_embeddings = self._main_loop(
            audio_features, tokens
        )

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
                encoder_embeddings=encoder_embeddings,
                decoder_embeddings=decoder_embeddings,
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
        ]
@torch.no_grad()
def decode(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions()) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode one or several 30-second audio segments given as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        the Mel spectrogram(s); a 2-D tensor is treated as a single segment

    options: DecodingOptions
        a dataclass holding all options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        a single `DecodingResult` for a 2-D input, otherwise a list with one per segment
    """
    is_single_segment = mel.ndim == 2
    batch = mel.unsqueeze(0) if is_single_segment else mel

    results = DecodingTask(model, options).run(batch)

    return results[0] if is_single_segment else results

View File

@@ -0,0 +1,290 @@
from dataclasses import dataclass
from typing import Dict
from typing import Iterable, Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from .transcribe import transcribe as transcribe_function
from .decoding import detect_language as detect_language_function, decode as decode_function
@dataclass
class ModelDimensions:
    """Size hyper-parameters from which `Whisper` builds its encoder and decoder."""

    # --- audio encoder (forwarded to `AudioEncoder` by `Whisper.__init__`) ---
    n_mels: int  # input channels of the first convolution
    n_audio_ctx: int  # number of positions in the encoder's positional embedding
    n_audio_state: int  # encoder hidden width
    n_audio_head: int  # attention heads per encoder block
    n_audio_layer: int  # number of encoder blocks
    # --- text decoder (forwarded to `TextDecoder` by `Whisper.__init__`) ---
    n_vocab: int  # rows of the token embedding; also drives `Whisper.is_multilingual`
    n_text_ctx: int  # number of positions in the decoder's positional embedding / causal mask
    n_text_state: int  # decoder hidden width
    n_text_head: int  # attention heads per decoder block
    n_text_layer: int  # number of decoder blocks
class LayerNorm(nn.LayerNorm):
    """`nn.LayerNorm` that normalises in float32 and returns the input's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
class Linear(nn.Linear):
    """`nn.Linear` whose weight and bias are cast to the input's dtype for each call."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
class Conv1d(nn.Conv1d):
    """`nn.Conv1d` whose weight and bias are cast to the input's dtype for each call."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
def sinusoids(length, channels, max_timescale=10000):
    """
    Build a (length, channels) table of sinusoidal positional embeddings.

    The first half of the channels holds sines and the second half cosines, with
    geometrically spaced timescales from 1 up to `max_timescale`. `channels` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length)
    scaled_time = positions[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat((torch.sin(scaled_time), torch.cos(scaled_time)), dim=1)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention usable for self- or cross-attention."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        Attend from `x` to itself (when `xa` is None) or to `xa` (cross-attention).

        kv_cache : optional dict keyed by the key/value projection modules. For
            cross-attention, an entry found there is reused instead of projecting `xa` again.
        """
        q = self.query(x)

        if kv_cache is None or xa is None:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            # The membership test keeps the projection lazy: `dict.get(key, default)` would
            # evaluate `self.key(xa)` on every call even when the cached tensor is used.
            k = kv_cache[self.key] if self.key in kv_cache else self.key(xa)
            v = kv_cache[self.value] if self.value in kv_cache else self.value(xa)

        wv = self.qkv_attention(q, k, v, mask)
        return self.out(wv)

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None):
        """Split q/k/v into heads, apply (optionally masked) softmax attention, merge heads."""
        n_batch, n_ctx, n_state = q.shape
        # the 1/sqrt(d_head) factor is split evenly between q and k
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]

        # softmax in float32, then back to the working dtype
        w = F.softmax(qk.float(), dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, then an MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        # the cross-attention pair only exists for blocks that attend to encoder output
        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_width), nn.GELU(), Linear(hidden_width, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """Apply the residual sub-layers in order and return the updated hidden states."""
        self_attended = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        x = x + self_attended

        if self.cross_attn:
            cross_attended = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
            x = x + cross_attended

        feed_forward = self.mlp(self.mlp_ln(x))
        return x + feed_forward
class AudioEncoder(nn.Module):
    """Convolutional stem followed by a stack of self-attention blocks over audio frames."""

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        encoder_blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(encoder_blocks)
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor, include_embeddings: bool = False):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        include_embeddings: bool
            when true, also return the hidden states before the first block and after
            every block, as a numpy array stacked along axis 1
        """
        hidden = F.gelu(self.conv1(x))
        hidden = F.gelu(self.conv2(hidden))
        hidden = hidden.permute(0, 2, 1)

        assert hidden.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        hidden = (hidden + self.positional_embedding).to(hidden.dtype)

        # `trace` stays None unless the caller asked for the intermediate states
        trace = [hidden.cpu().detach().numpy()] if include_embeddings else None

        for block in self.blocks:
            hidden = block(hidden)
            if trace is not None:
                trace.append(hidden.cpu().detach().numpy())

        hidden = self.ln_post(hidden)

        if trace is None:
            return hidden
        return hidden, np.stack(trace, axis=1)
class TextDecoder(nn.Module):
    """Token decoder: embeddings, cross-attending blocks, and logits tied to the embedding."""

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        decoder_blocks = [
            ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)
        ]
        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(decoder_blocks)
        self.ln = LayerNorm(n_state)

        # causal mask: -inf strictly above the diagonal, 0 elsewhere
        causal_mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None, include_embeddings: bool = False):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        include_embeddings : bool
            when true, also return the hidden states before the first block and after
            every block, as a numpy array stacked along axis 1
        """
        # with a populated kv cache only the new tokens are passed in, so their positions
        # start after what is already cached
        if kv_cache:
            offset = next(iter(kv_cache.values())).shape[1]
        else:
            offset = 0

        positions = self.positional_embedding[offset : offset + x.shape[-1]]
        hidden = (self.token_embedding(x) + positions).to(xa.dtype)

        trace = [hidden.cpu().detach().numpy()] if include_embeddings else None

        for block in self.blocks:
            hidden = block(hidden, xa, mask=self.mask, kv_cache=kv_cache)
            if trace is not None:
                trace.append(hidden.cpu().detach().numpy())

        hidden = self.ln(hidden)
        output_weight = self.token_embedding.weight.to(hidden.dtype)
        logits = (hidden @ torch.transpose(output_weight, 0, 1)).float()

        if trace is None:
            return logits
        return logits, np.stack(trace, axis=1)
class Whisper(nn.Module):
    """Encoder-decoder speech model assembled from a `ModelDimensions` description."""

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            dims.n_vocab,
            dims.n_text_ctx,
            dims.n_text_state,
            dims.n_text_head,
            dims.n_text_layer,
        )

    def embed_audio(self, mel: torch.Tensor):
        """Encode a mel spectrogram into audio features."""
        return self.encoder.forward(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """Decoder logits for `tokens` given already-encoded `audio_features`."""
        return self.decoder.forward(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Encode `mel`, then decode `tokens` against the result."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab == 51865

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and
        value tensors calculated for the previous positions. This method registers forward hooks
        on every key/value projection of the decoder that save those tensors for later reuse.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = dict(cache) if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # an output longer than the text context can only be a cross-attention projection
            store_as_is = (
                module not in cache
                or output.shape[1] > self.decoder.positional_embedding.shape[0]
            )
            if store_as_is:
                cache[module] = output  # save as-is, for the first token or cross attention
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if not isinstance(layer, MultiHeadAttention):
                return
            for projection in (layer.key, layer.value):
                hooks.append(projection.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

View File

@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer
from .english import EnglishTextNormalizer

View File

@@ -0,0 +1,71 @@
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
# Manual replacements for letters that "NFKD" normalization leaves as a single code point.
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    # LATIN CAPITAL LETTER SHARP S; written as an escape because the literal character
    # had been lost, leaving an empty-string key that could never match a character
    "\u1e9e": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Normalize `s` to NFKD, then turn markers, symbols and punctuation into spaces and
    strip diacritics (category 'Mn' plus the manual `ADDITIONAL_DIACRITICS` mappings).
    Characters contained in `keep` are passed through untouched.
    """
    pieces = []
    for c in unicodedata.normalize("NFKD", s):
        if c in keep:
            pieces.append(c)
        elif c in ADDITIONAL_DIACRITICS:
            pieces.append(ADDITIONAL_DIACRITICS[c])
        else:
            category = unicodedata.category(c)
            if category == "Mn":
                continue  # combining mark: drop it entirely
            pieces.append(" " if category[0] in "MSP" else c)
    return "".join(pieces)
def remove_symbols(s: str):
    """
    Normalize `s` to NFKC and replace every marker, symbol and punctuation character
    with a space; diacritics are kept.
    """
    cleaned = []
    for c in unicodedata.normalize("NFKC", s):
        is_symbol = unicodedata.category(c)[0] in "MSP"
        cleaned.append(" " if is_symbol else c)
    return "".join(cleaned)
class BasicTextNormalizer:
    """Language-agnostic normalizer: lowercases, drops bracketed text and symbols."""

    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        if remove_diacritics:
            self.clean = remove_symbols_and_diacritics
        else:
            self.clean = remove_symbols
        self.split_letters = split_letters

    def __call__(self, s: str):
        text = s.lower()
        text = re.sub(r"[<\[][^>\]]*[>\]]", "", text)  # remove words between brackets
        text = re.sub(r"\(([^)]+?)\)", "", text)  # remove words between parenthesis
        text = self.clean(text).lower()

        if self.split_letters:
            # put a space between every extended grapheme cluster
            graphemes = regex.findall(r"\X", text, regex.U)
            text = " ".join(graphemes)

        # replace any successive whitespace characters with a space
        return re.sub(r"\s+", " ", text)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,543 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union
from more_itertools import windowed
from .basic import remove_symbols_and_diacritics
class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into arabic numbers, while handling:

    - remove any commas
    - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
    - spell out `one` and `ones`
    - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
    """

    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            # irregular spelling; the generated form below is "nineth", which is kept
            # only so that previously accepted input keeps being accepted
            "ninth": (9, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {**self.multipliers_plural, **self.multipliers_ordinal}
        self.decimals = {*self.ones, *self.tens, *self.zeros}

        # words that put a sign in front of the number that follows them
        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        # words that follow a number but are written as a symbol in front of it
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            # the euro sign had been lost here (empty string), which made
            # "twenty euros" collapse to a bare "20"
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
        self.specials = {"and", "double", "triple", "point"}

        # every word that can take part in a number expression
        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}

    def process_words(self, words: List[str]) -> Iterator[str]:
        """
        Walk over `words` and yield the normalized tokens.

        `value` accumulates the number currently being built (an int, or a str once digits
        are being concatenated), and `prefix` holds a pending sign/currency symbol; both
        are flushed through `output()`.

        Raises
        ------
        ValueError
            if a word listed in `self.words` is not handled by any branch
        """
        prefix: Optional[str] = None
        value: Optional[Union[str, int]] = None
        skip = False

        def to_fraction(s: str):
            try:
                return Fraction(s)
            except ValueError:
                return None

        def output(result: Union[str, int]):
            # emit `result` with the pending prefix and reset the accumulator state
            nonlocal prefix, value
            result = str(result)
            if prefix is not None:
                result = prefix + result
            value = None
            prefix = None
            return result

        if len(words) == 0:
            return

        # sliding (previous, current, next) window; the None padding marks both ends.
        # `words` is non-empty here, so `padded` always has at least three items.
        padded = [None] + words + [None]
        for prev, current, next in zip(padded, padded[1:], padded[2:]):
            if skip:
                # the previous iteration already consumed this word
                skip = False
                continue

            next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
            has_prefix = current[0] in self.prefixes
            current_without_prefix = current[1:] if has_prefix else current
            if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
                # arabic numbers (potentially with signs and fractions)
                f = to_fraction(current_without_prefix)
                assert f is not None
                if value is not None:
                    if isinstance(value, str) and value.endswith("."):
                        # concatenate decimals / ip address components
                        value = str(value) + str(current)
                        continue
                    else:
                        yield output(value)

                prefix = current[0] if has_prefix else prefix
                if f.denominator == 1:
                    value = f.numerator  # store integers as int
                else:
                    value = current_without_prefix
            elif current not in self.words:
                # non-numeric words
                if value is not None:
                    yield output(value)
                yield output(current)
            elif current in self.zeros:
                value = str(value or "") + "0"
            elif current in self.ones:
                ones = self.ones[current]

                if value is None:
                    value = ones
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:  # replace the last zero with the digit
                        assert value[-1] == "0"
                        value = value[:-1] + str(ones)
                    else:
                        value = str(value) + str(ones)
                elif ones < 10:
                    if value % 10 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        value += ones
                    else:
                        value = str(value) + str(ones)
            elif current in self.ones_suffixed:
                # ordinal or cardinal; yield the number right away
                ones, suffix = self.ones_suffixed[current]
                if value is None:
                    yield output(str(ones) + suffix)
                elif isinstance(value, str) or prev in self.ones:
                    if prev in self.tens and ones < 10:
                        assert value[-1] == "0"
                        yield output(value[:-1] + str(ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                elif ones < 10:
                    if value % 10 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                else:  # eleven to nineteen
                    if value % 100 == 0:
                        yield output(str(value + ones) + suffix)
                    else:
                        yield output(str(value) + str(ones) + suffix)
                value = None
            elif current in self.tens:
                tens = self.tens[current]
                if value is None:
                    value = tens
                elif isinstance(value, str):
                    value = str(value) + str(tens)
                else:
                    if value % 100 == 0:
                        value += tens
                    else:
                        value = str(value) + str(tens)
            elif current in self.tens_suffixed:
                # ordinal or cardinal; yield the number right away
                tens, suffix = self.tens_suffixed[current]
                if value is None:
                    yield output(str(tens) + suffix)
                elif isinstance(value, str):
                    yield output(str(value) + str(tens) + suffix)
                else:
                    if value % 100 == 0:
                        yield output(str(value + tens) + suffix)
                    else:
                        yield output(str(value) + str(tens) + suffix)
            elif current in self.multipliers:
                multiplier = self.multipliers[current]
                if value is None:
                    value = multiplier
                elif isinstance(value, str) or value == 0:
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        value = p.numerator
                    else:
                        yield output(value)
                        value = multiplier
                else:
                    # only the part below one thousand is scaled by the multiplier
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
            elif current in self.multipliers_suffixed:
                multiplier, suffix = self.multipliers_suffixed[current]
                if value is None:
                    yield output(str(multiplier) + suffix)
                elif isinstance(value, str):
                    f = to_fraction(value)
                    p = f * multiplier if f is not None else None
                    if f is not None and p.denominator == 1:
                        yield output(str(p.numerator) + suffix)
                    else:
                        yield output(value)
                        yield output(str(multiplier) + suffix)
                else:  # int
                    before = value // 1000 * 1000
                    residual = value % 1000
                    value = before + residual * multiplier
                    yield output(str(value) + suffix)
                value = None
            elif current in self.preceding_prefixers:
                # apply prefix (positive, minus, etc.) if it precedes a number
                if value is not None:
                    yield output(value)

                if next in self.words or next_is_numeric:
                    prefix = self.preceding_prefixers[current]
                else:
                    yield output(current)
            elif current in self.following_prefixers:
                # apply prefix (dollars, cents, etc.) only after a number
                if value is not None:
                    prefix = self.following_prefixers[current]
                    yield output(value)
                else:
                    yield output(current)
            elif current in self.suffixers:
                # apply suffix symbols (percent -> '%')
                if value is not None:
                    suffix = self.suffixers[current]
                    if isinstance(suffix, dict):
                        # two-word suffix such as "per cent"
                        if next in suffix:
                            yield output(str(value) + suffix[next])
                            skip = True
                        else:
                            yield output(value)
                            yield output(current)
                    else:
                        yield output(str(value) + suffix)
                else:
                    yield output(current)
            elif current in self.specials:
                if next not in self.words and not next_is_numeric:
                    # apply special handling only if the next word can be numeric
                    if value is not None:
                        yield output(value)
                    yield output(current)
                elif current == "and":
                    # ignore "and" after hundreds, thousands, etc.
                    if prev not in self.multipliers:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "double" or current == "triple":
                    if next in self.ones or next in self.zeros:
                        repeats = 2 if current == "double" else 3
                        ones = self.ones.get(next, 0)
                        value = str(value or "") + str(ones) * repeats
                        skip = True
                    else:
                        if value is not None:
                            yield output(value)
                        yield output(current)
                elif current == "point":
                    if next in self.decimals or next_is_numeric:
                        value = str(value or "") + "."
                else:
                    # should all have been covered at this point
                    raise ValueError(f"Unexpected token: {current}")
            else:
                # all should have been covered at this point
                raise ValueError(f"Unexpected token: {current}")

        # flush whatever number is still being accumulated
        if value is not None:
            yield output(value)

    def preprocess(self, s: str):
        """Rewrite "<number> and a half" and fix spacing at digit/letter boundaries."""
        # replace "<number> and a half" with "<number> point five"
        results = []

        segments = re.split(r"\band\s+a\s+half\b", s)
        for i, segment in enumerate(segments):
            if len(segment.strip()) == 0:
                continue
            if i == len(segments) - 1:
                results.append(segment)
            else:
                results.append(segment)
                last_word = segment.rsplit(maxsplit=2)[-1]
                if last_word in self.decimals or last_word in self.multipliers:
                    results.append("point five")
                else:
                    # not preceded by a number: put the phrase back unchanged
                    results.append("and a half")

        s = " ".join(results)

        # put a space at number/letter boundary
        s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
        s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)

        # but remove spaces which could be a suffix
        s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)

        return s

    def postprocess(self, s: str):
        """Merge currency/cents pairs and write `1`/`1s` back as `one`/`ones`."""

        def combine_cents(m: Match):
            try:
                currency = m.group(1)
                integer = m.group(2)
                cents = int(m.group(3))
                return f"{currency}{integer}.{cents:02d}"
            except ValueError:
                # leave the matched text as it was (`m.string` would splice in the whole input)
                return m.group(0)

        def extract_cents(m: Match):
            try:
                return f"¢{int(m.group(1))}"
            except ValueError:
                return m.group(0)

        # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
        s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
        # the decimal point is escaped so that only a literal "0." is treated as cents
        s = re.sub(r"[€£$]0\.([0-9]{1,2})\b", extract_cents, s)

        # write "one(s)" instead of "1(s)", just for the readability
        s = re.sub(r"\b1(s?)\b", r"one\1", s)

        return s

    def __call__(self, s: str):
        """Normalize the numbers in `s`: preprocess, process word by word, postprocess."""
        s = self.preprocess(s)
        s = " ".join(word for word in self.process_words(s.split()) if word is not None)
        s = self.postprocess(s)

        return s
class EnglishSpellingNormalizer:
    """
    Applies British-American spelling mappings as listed in [1].

    [1] https://www.tysto.com/uk-us-spelling-list.html
    """

    def __init__(self):
        """Load the word mapping from `english.json`, located next to this module."""
        mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
        # context manager so the handle is closed deterministically; explicit encoding so
        # the result does not depend on the platform's default locale
        with open(mapping_path, encoding="utf-8") as mapping_file:
            self.mapping = json.load(mapping_file)

    def __call__(self, s: str):
        """Replace every whitespace-separated word that has a mapping; join with single spaces."""
        return " ".join(self.mapping.get(word, word) for word in s.split())
class EnglishTextNormalizer:
    """
    Callable that lower-cases English text and rewrites it into a canonical form.

    The steps in `__call__` are order-sensitive: bracketed spans and filler words are
    removed, contractions and title abbreviations are expanded, symbols are stripped,
    numbers and spellings are standardized, and whitespace is collapsed last.
    """

    def __init__(self):
        # filler words that are deleted outright (matched as whole words)
        self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
        # regex -> replacement, applied one at a time in insertion order by __call__;
        # the specific patterns must stay ahead of the generic suffix rules at the bottom
        self.replacers = {
            # common contractions
            r"\bwon't\b": "will not",
            r"\bcan't\b": "can not",
            r"\blet's\b": "let us",
            r"\bain't\b": "aint",
            r"\by'all\b": "you all",
            r"\bwanna\b": "want to",
            r"\bgotta\b": "got to",
            r"\bgonna\b": "going to",
            r"\bi'ma\b": "i am going to",
            r"\bimma\b": "i am going to",
            r"\bwoulda\b": "would have",
            r"\bcoulda\b": "could have",
            r"\bshoulda\b": "should have",
            r"\bma'am\b": "madam",
            # contractions in titles/prefixes
            r"\bmr\b": "mister ",
            r"\bmrs\b": "missus ",
            r"\bst\b": "saint ",
            r"\bdr\b": "doctor ",
            r"\bprof\b": "professor ",
            r"\bcapt\b": "captain ",
            r"\bgov\b": "governor ",
            r"\bald\b": "alderman ",
            r"\bgen\b": "general ",
            r"\bsen\b": "senator ",
            r"\brep\b": "representative ",
            r"\bpres\b": "president ",
            r"\brev\b": "reverend ",
            r"\bhon\b": "honorable ",
            r"\basst\b": "assistant ",
            r"\bassoc\b": "associate ",
            r"\blt\b": "lieutenant ",
            r"\bcol\b": "colonel ",
            r"\bjr\b": "junior ",
            r"\bsr\b": "senior ",
            r"\besq\b": "esquire ",
            # perfect tenses, ideally it should be any past participles, but it's harder..
            r"'d been\b": " had been",
            r"'s been\b": " has been",
            r"'d gone\b": " had gone",
            r"'s gone\b": " has gone",
            r"'d done\b": " had done",  # "'s done" is ambiguous
            r"'s got\b": " has got",
            # general contractions
            r"n't\b": " not",
            r"'re\b": " are",
            r"'s\b": " is",
            r"'d\b": " would",
            r"'ll\b": " will",
            r"'t\b": " not",
            r"'ve\b": " have",
            r"'m\b": " am",
        }
        self.standardize_numbers = EnglishNumberNormalizer()
        self.standardize_spellings = EnglishSpellingNormalizer()

    def __call__(self, s: str):
        """Return the normalized form of ``s``."""
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parenthesis
        s = re.sub(self.ignore_patterns, "", s)
        s = re.sub(r"\s+'", "'", s)  # standardize when there's a space before an apostrophe
        # expand contractions/abbreviations; dict order decides which rule wins
        for pattern, replacement in self.replacers.items():
            s = re.sub(pattern, replacement, s)
        s = re.sub(r"(\d),(\d)", r"\1\2", s)  # remove commas between digits
        s = re.sub(r"\.([^0-9]|$)", r" \1", s)  # remove periods not followed by numbers
        s = remove_symbols_and_diacritics(s, keep=".%$¢€£")  # keep some symbols for numerics
        s = self.standardize_numbers(s)
        s = self.standardize_spellings(s)
        # now remove prefix/suffix symbols that are not preceded/followed by numbers
        s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
        s = re.sub(r"([^0-9])%", r"\1 ", s)
        s = re.sub(r"\s+", " ", s)  # replace any successive whitespace characters with a space
        return s

View File

@@ -0,0 +1,331 @@
import os
from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import GPT2TokenizerFast
# Language code -> lower-case English language name.
# The key ORDER is significant: build_tokenizer() emits one "<|code|>" special token
# per key in this order, and get_tokenizer() computes a language's token id from the
# key's position (sot + 1 + index). Do not reorder or insert entries in the middle.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "iw": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
}
# Language name -> language code: the inverse of LANGUAGES, followed by a few
# alternative names (aliases) that map onto an existing code.
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
}
@dataclass(frozen=True)
class Tokenizer:
    """Immutable facade over a `GPT2TokenizerFast` exposing the special token ids by name"""

    tokenizer: "GPT2TokenizerFast"
    language: Optional[str]
    sot_sequence: Tuple[int]

    def encode(self, text, **kwargs):
        """Forward to the wrapped tokenizer's `encode`."""
        return self.tokenizer.encode(text, **kwargs)

    def decode(self, token_ids: Union[int, List[int], np.ndarray, torch.Tensor], **kwargs):
        """Forward to the wrapped tokenizer's `decode`."""
        return self.tokenizer.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, tokens) -> str:
        """
        Decode `tokens`, rendering every id at or above `timestamp_begin` as "<|S.SS|>".
        Each timestamp step is worth 0.02 seconds; the ids between two timestamps are
        decoded together as one run of text.
        """
        pieces = [[]]
        for token in tokens:
            if token >= self.timestamp_begin:
                stamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
                pieces += [stamp, []]
            else:
                pieces[-1].append(token)

        rendered = []
        for piece in pieces:
            if isinstance(piece, str):
                rendered.append(piece)
            else:
                rendered.append(self.tokenizer.decode(piece))
        return "".join(rendered)

    @property
    @lru_cache()
    def eot(self) -> int:
        return self.tokenizer.eos_token_id

    @property
    @lru_cache()
    def sot(self) -> int:
        return self._get_single_token_id("<|startoftranscript|>")

    @property
    @lru_cache()
    def sot_lm(self) -> int:
        return self._get_single_token_id("<|startoflm|>")

    @property
    @lru_cache()
    def sot_prev(self) -> int:
        return self._get_single_token_id("<|startofprev|>")

    @property
    @lru_cache()
    def no_speech(self) -> int:
        return self._get_single_token_id("<|nospeech|>")

    @property
    @lru_cache()
    def no_timestamps(self) -> int:
        return self._get_single_token_id("<|notimestamps|>")

    @property
    @lru_cache()
    def timestamp_begin(self) -> int:
        # first id after the last special token id
        return self.tokenizer.all_special_ids[-1] + 1

    @property
    @lru_cache()
    def language_token(self) -> int:
        """Token id of "<|{language}|>" for this tokenizer's `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        id_by_token = dict(
            zip(
                self.tokenizer.additional_special_tokens,
                self.tokenizer.additional_special_tokens_ids,
            )
        )
        wanted = f"<|{self.language}|>"
        if wanted not in id_by_token:
            raise KeyError(f"Language {self.language} not found in tokenizer.")
        return id_by_token[wanted]

    @property
    @lru_cache()
    def all_language_tokens(self) -> Tuple[int]:
        """Ids of the additional special tokens whose stripped name is a LANGUAGES code"""
        pairs = zip(
            self.tokenizer.additional_special_tokens,
            self.tokenizer.additional_special_tokens_ids,
        )
        return tuple(token_id for token, token_id in pairs if token.strip("<|>") in LANGUAGES)

    @property
    @lru_cache()
    def all_language_codes(self) -> Tuple[str]:
        codes = []
        for token_id in self.all_language_tokens:
            codes.append(self.decode([token_id]).strip("<|>"))
        return tuple(codes)

    @property
    @lru_cache()
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return (*self.sot_sequence, self.no_timestamps)

    @property
    @lru_cache()
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Sorted ids of the tokens to suppress so that speaker tags and non-speech
        annotations are not sampled, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        Basic punctuation (commas, periods, question marks, exclamation points, ...)
        is not part of the list.
        """
        symbols = list("\"#()*+/:;<=>@[\\]^_`{|}~「」『』")
        symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()

        # These may be one token or several depending on the tokenizer. When they are
        # several, the first token is suppressed, which is safe because they all lie in
        # U+2640..U+267F (miscellaneous symbols) and share their first two UTF-8 bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        encode = self.tokenizer.encode
        # hyphens and single quotes are fine inside a word, but not at the start of one
        suppressed = {encode(" -")[0], encode(" '")[0]}
        for symbol in [*symbols, *miscellaneous]:
            always_suppress = symbol in miscellaneous
            for prefix in ("", " "):
                ids = encode(prefix + symbol)
                if len(ids) == 1 or always_suppress:
                    suppressed.add(ids[0])

        return tuple(sorted(suppressed))

    def _get_single_token_id(self, text) -> int:
        """Encode `text` and return its id, asserting that it maps to exactly one token"""
        ids = self.tokenizer.encode(text)
        assert len(ids) == 1, f"{text} is not encoded as a single token"
        return ids[0]
@lru_cache(maxsize=None)
def build_tokenizer(name: str = "gpt2"):
    """Load the tokenizer stored under ``assets/<name>`` and register the special tokens.

    The special tokens are added in a fixed order: start-of-transcript, one
    "<|code|>" per LANGUAGES key, then the task and control markers. Results are
    cached per ``name``.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    assets_dir = os.path.join(os.path.dirname(__file__), "assets", name)
    tokenizer = GPT2TokenizerFast.from_pretrained(assets_dir)

    language_tags = [f"<|{lang}|>" for lang in LANGUAGES.keys()]
    control_tags = [
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
    ]
    specials = ["<|startoftranscript|>"] + language_tags + control_tags

    tokenizer.add_special_tokens({"additional_special_tokens": specials})
    return tokenizer
@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
    language: Optional[str] = None,
) -> Tokenizer:
    """Build a `Tokenizer` whose start-of-transcript sequence matches the request.

    ``language`` may be a LANGUAGES code or a name/alias from TO_LANGUAGE_CODE
    (case-insensitive); anything else raises ``ValueError``. For a multilingual
    model, ``task`` defaults to "transcribe" and ``language`` to "en"; otherwise
    both are forced to ``None`` and the "gpt2" tokenizer is used.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError(f"Unsupported language: {language}")
            language = TO_LANGUAGE_CODE[language]

    if not multilingual:
        tokenizer_name, task, language = "gpt2", None, None
    else:
        tokenizer_name = "multilingual"
        task = task or "transcribe"
        language = language or "en"

    tokenizer = build_tokenizer(name=tokenizer_name)
    special_ids: List[int] = tokenizer.all_special_ids
    # positions follow the order in which build_tokenizer registers the special tokens
    sot: int = special_ids[1]
    translate: int = special_ids[-6]
    transcribe: int = special_ids[-5]

    language_codes = tuple(LANGUAGES.keys())
    prefix = [sot]
    if language is not None:
        # language tokens directly follow <|startoftranscript|>, in LANGUAGES order
        prefix.append(sot + 1 + language_codes.index(language))
    if task is not None:
        prefix.append(transcribe if task == "transcribe" else translate)

    return Tokenizer(tokenizer=tokenizer, language=language, sot_sequence=tuple(prefix))

View File

@@ -0,0 +1,207 @@
import argparse
import os
import warnings
from typing import List, Optional, Tuple, Union, TYPE_CHECKING
import numpy as np
import torch
import tqdm
from .audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from .decoding import DecodingOptions, DecodingResult
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import exact_div, format_timestamp, optional_int, optional_float, str2bool, write_txt, write_vtt, write_srt
if TYPE_CHECKING:
from .model import Whisper
def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    force_extraction: bool = False,
    **decode_options,
):
    """
    Run the Whisper audio encoder over an audio input and collect its embeddings

    Despite its name, this function does not decode any text: it slides a fixed
    window over the log-mel spectrogram, feeds every window to `model.encoder`
    and records what the encoder returns as embeddings.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance; `model.device` and `model.encoder` are used

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    verbose: bool
        Only controls the progress bar, which is displayed when this is exactly False

    temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold,
    condition_on_previous_text, force_extraction:
        Accepted so that existing callers keep working; not read by this function

    decode_options: dict
        Only `fp16` is consulted (default True): it selects float16 inputs, except on a
        CPU model, where float32 is used and a warning is emitted

    Returns
    -------
    A dictionary with a single key "segments": a list with one entry per window, each a
    dictionary holding "start" and "end" (mel frame indices; "end" is clipped to the
    number of frames) and "encoder_embeddings" (as returned by the encoder).
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    mel = log_mel_spectrogram(audio)
    num_frames = mel.shape[-1]

    # mel frames consumed per encoder call
    # NOTE(review): presumably meant to equal N_FRAMES - confirm
    sample_skip = 3000

    all_segments = []

    # the progress bar is shown only when verbose is False
    with tqdm.tqdm(total=num_frames, unit="frames", disable=verbose is not False) as pbar:
        seek = 0  # first mel frame of the current window
        while seek < num_frames:
            end_seek = min(seek + sample_skip, num_frames)
            segment = pad_or_trim(mel[:, seek:seek + sample_skip], N_FRAMES).to(model.device).to(dtype)
            if segment.ndim == 2:
                # add a leading batch dimension
                segment = segment.unsqueeze(0)

            _, embeddings = model.encoder(segment, include_embeddings=True)
            all_segments.append(
                {
                    "start": seek,
                    "end": end_seek,
                    "encoder_embeddings": embeddings,
                }
            )

            # advance the bar by the number of real (unpadded) frames just processed
            pbar.update(end_seek - seek)
            seek += sample_skip

    return dict(segments=all_segments)
def cli():
    """Command-line entry point: parse arguments, load a model and process each audio file.

    For every input file the result of `transcribe` is written to
    `<output_dir>/<basename>.txt`, `.vtt` and `.srt`.

    NOTE(review): the writers imported from `.utils` read `segment['text']`, while
    `transcribe` in this module builds segments holding only "start", "end" and
    "encoder_embeddings" - confirm this entry point is still expected to work.
    """
    from . import available_models

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    # Work on the plain dict: options consumed here are popped, and whatever is left
    # is forwarded to transcribe() as keyword arguments.
    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    # model names ending in ".en" are English-only: override any other language choice
    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead.")
        args["language"] = "en"

    # expand the base temperature into a fallback schedule up to 1.0, unless disabled
    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    threads = args.pop("threads")
    if threads > 0:
        torch.set_num_threads(threads)

    from . import load_model
    model = load_model(model_name, device=device, download_root=model_dir)

    for audio_path in args.pop("audio"):
        result = transcribe(model, audio_path, temperature=temperature, **args)

        audio_basename = os.path.basename(audio_path)

        # save TXT
        with open(os.path.join(output_dir, audio_basename + ".txt"), "w", encoding="utf-8") as txt:
            write_txt(result["segments"], file=txt)

        # save VTT
        with open(os.path.join(output_dir, audio_basename + ".vtt"), "w", encoding="utf-8") as vtt:
            write_vtt(result["segments"], file=vtt)

        # save SRT
        with open(os.path.join(output_dir, audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)


if __name__ == '__main__':
    cli()

View File

@@ -0,0 +1,87 @@
import zlib
from typing import Iterator, TextIO
def exact_div(x, y):
    """Return ``x // y``, asserting that ``y`` divides ``x`` without remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
def str2bool(string):
    """Convert the exact strings "True" / "False" to booleans; raise ValueError otherwise."""
    str2val = {"True": True, "False": False}
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
def optional_int(string):
    """Parse ``string`` as an int, mapping the literal "None" to ``None``."""
    if string == "None":
        return None
    return int(string)
def optional_float(string):
    """Parse ``string`` as a float, mapping the literal "None" to ``None``."""
    if string == "None":
        return None
    return float(string)
def compression_ratio(text) -> float:
    """Return ``len(text)`` divided by the size of its zlib-compressed UTF-8 encoding."""
    compressed = zlib.compress(text.encode("utf-8"))
    return len(text) / len(compressed)
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = '.'):
    """Format a non-negative number of seconds as ``[HH:]MM:SS<marker>mmm``.

    The value is rounded to whole milliseconds. The hours field is emitted when it
    is non-zero or when ``always_include_hours`` is set.
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    hours, total_ms = divmod(total_ms, 3_600_000)
    minutes, total_ms = divmod(total_ms, 60_000)
    secs, millis = divmod(total_ms, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""
    return f"{hours_marker}{minutes:02d}:{secs:02d}{decimal_marker}{millis:03d}"
def write_txt(transcript: Iterator[dict], file: TextIO):
    """Write the stripped ``'text'`` of every segment to ``file``, one per line."""
    for entry in transcript:
        line = entry['text'].strip()
        print(line, file=file, flush=True)
def write_vtt(transcript: Iterator[dict], file: TextIO):
    """Write ``transcript`` to ``file`` as WebVTT: a header, then one cue per segment.

    Any "-->" inside a segment's text is written as "->" so it cannot be mistaken
    for the cue timing arrow.
    """
    print("WEBVTT\n", file=file)
    for cue in transcript:
        begin = format_timestamp(cue['start'])
        finish = format_timestamp(cue['end'])
        caption = cue['text'].strip().replace('-->', '->')
        print(f"{begin} --> {finish}\n{caption}\n", file=file, flush=True)
def write_srt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in SRT format.

    Cues are numbered from 1, timestamps always include hours and use "," as the
    decimal marker, and any "-->" inside the text is written as "->".

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for index, cue in enumerate(transcript, start=1):
        begin = format_timestamp(cue['start'], always_include_hours=True, decimal_marker=',')
        finish = format_timestamp(cue['end'], always_include_hours=True, decimal_marker=',')
        caption = cue['text'].strip().replace('-->', '->')
        print(f"{index}\n{begin} --> {finish}\n{caption}\n", file=file, flush=True)

View File

@@ -0,0 +1,120 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from omegaconf import OmegaConf
import torch
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from accelerate.utils import set_seed
from latentsync.whisper.audio2feature import Audio2Feature
from DeepCache import DeepCacheSDHelper
def main(config, args):
    """Assemble the lip-sync pipeline from ``config`` and run it once for ``args``.

    Raises RuntimeError when the input video or audio path does not exist, and
    NotImplementedError when ``config.model.cross_attention_dim`` is neither 768
    nor 384.
    """
    if not os.path.exists(args.video_path):
        raise RuntimeError(f"Video path '{args.video_path}' not found")
    if not os.path.exists(args.audio_path):
        raise RuntimeError(f"Audio path '{args.audio_path}' not found")

    # float16 is used only when a CUDA device reports a major compute capability above 7
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7:
        dtype = torch.float16
    else:
        dtype = torch.float32

    print(f"Input video path: {args.video_path}")
    print(f"Input audio path: {args.audio_path}")
    print(f"Loaded checkpoint path: {args.inference_ckpt_path}")

    scheduler = DDIMScheduler.from_pretrained("configs")

    # the Whisper checkpoint is chosen from the UNet's cross-attention width
    whisper_checkpoints = {
        768: "checkpoints/whisper/small.pt",
        384: "checkpoints/whisper/tiny.pt",
    }
    attention_dim = config.model.cross_attention_dim
    if attention_dim not in whisper_checkpoints:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")
    whisper_model_path = whisper_checkpoints[attention_dim]

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device="cuda",
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # the UNet is created on the CPU, cast, and moved to the GPU with the pipeline
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        args.inference_ckpt_path,
        device="cpu",
    )
    unet = unet.to(dtype=dtype)

    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    )
    pipeline = pipeline.to("cuda")

    # use DeepCache
    if args.enable_deepcache:
        helper = DeepCacheSDHelper(pipe=pipeline)
        helper.set_params(cache_interval=3, cache_branch_id=0)
        helper.enable()

    # a seed of -1 requests a fresh seed
    if args.seed == -1:
        torch.seed()
    else:
        set_seed(args.seed)

    print(f"Initial seed: {torch.initial_seed()}")

    run_options = dict(
        video_path=args.video_path,
        audio_path=args.audio_path,
        video_out_path=args.video_out_path,
        num_frames=config.data.num_frames,
        num_inference_steps=args.inference_steps,
        guidance_scale=args.guidance_scale,
        weight_dtype=dtype,
        width=config.data.resolution,
        height=config.data.resolution,
        mask_image_path=config.data.mask_image_path,
        temp_dir=args.temp_dir,
    )
    pipeline(**run_options)
# Script entry point: parse the command line, load the UNet config and run inference once.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
    parser.add_argument("--inference_ckpt_path", type=str, required=True)
    parser.add_argument("--video_path", type=str, required=True)
    parser.add_argument("--audio_path", type=str, required=True)
    parser.add_argument("--video_out_path", type=str, required=True)
    parser.add_argument("--inference_steps", type=int, default=20)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--temp_dir", type=str, default="temp")
    # main() treats a seed of -1 as "pick a fresh seed"
    parser.add_argument("--seed", type=int, default=1247)
    parser.add_argument("--enable_deepcache", action="store_true")
    args = parser.parse_args()

    config = OmegaConf.load(args.unet_config_path)

    main(config, args)

View File

@@ -0,0 +1,196 @@
import os
import argparse
from pathlib import Path
# --- Automatic GPU selection (must run before torch is imported) ---
def load_gpu_config(env_path=None, default_gpu="1"):
    """Export CUDA_VISIBLE_DEVICES from the backend's LATENTSYNC_GPU_ID setting.

    The first line of the .env file that starts with ``LATENTSYNC_GPU_ID=`` is used.
    An inline ``# comment`` and quotes around the value are removed. If the file or
    the setting is missing, or the value is empty, ``default_gpu`` is used. A
    CUDA_VISIBLE_DEVICES value already present in the environment is never changed.

    This is best effort: any error is reported on stdout and otherwise ignored.

    Args:
        env_path: Path of the .env file. Defaults to ``backend/.env`` three
            directories above the directory containing this script.
        default_gpu: Value used when no usable setting is found.
    """
    try:
        if env_path is None:
            # scripts/server.py -> scripts -> LatentSync -> models -> ViGent2 -> backend -> .env
            current_dir = Path(__file__).resolve().parent
            env_path = current_dir.parent.parent.parent / "backend" / ".env"
        env_path = Path(env_path)

        target_gpu = default_gpu

        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line.startswith("LATENTSYNC_GPU_ID="):
                        continue
                    val = line.split("=")[1].strip().split("#")[0].strip()
                    # .env values are often quoted; exporting the quotes themselves
                    # would make CUDA_VISIBLE_DEVICES unusable.
                    val = val.strip("'\"").strip()
                    if val:
                        target_gpu = val
                        print(f"⚙️ 发现配置 LATENTSYNC_GPU_ID={target_gpu}")
                    break

        # an externally provided CUDA_VISIBLE_DEVICES always takes precedence
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f" 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")

    except Exception as e:
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")


load_gpu_config()
import torch
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature
from accelerate.utils import set_seed
from DeepCache import DeepCacheSDHelper
# Process-wide cache filled by lifespan(): holds "pipeline", "config" and "dtype"
models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the LatentSync pipeline once at startup and release it at shutdown.

    Mirrors the model set-up in inference.py, with DeepCache always enabled. The
    config and checkpoint paths are relative, so the process must be started from
    the LatentSync root directory.

    Raises:
        NotImplementedError: if ``config.model.cross_attention_dim`` is neither
            768 nor 384 (same rule as inference.py).
    """
    # --- model loading (see inference.py) ---
    print("⏳ 正在加载 LatentSync 模型...")

    # default paths, relative to the LatentSync root directory
    unet_config_path = "configs/unet/stage2_512.yaml"
    ckpt_path = "checkpoints/latentsync_unet.pt"

    if not os.path.exists(unet_config_path):
        # warn only; OmegaConf.load below is still attempted on the missing path
        print(f"⚠️ 找不到配置文件: {unet_config_path},请确保在 models/LatentSync 根目录运行")

    config = OmegaConf.load(unet_config_path)

    # Check GPU
    is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
    dtype = torch.float16 if is_fp16_supported else torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name} (CUDA_VISIBLE_DEVICES 已生效)")
    else:
        print("⚠️ 警告: 未检测到 GPU将使用 CPU 进行推理 (速度极慢)")

    scheduler = DDIMScheduler.from_pretrained("configs")

    # Whisper model: chosen from the UNet's cross-attention width. Unsupported widths
    # are rejected, as in inference.py, instead of silently falling back to "tiny".
    if config.model.cross_attention_dim == 768:
        whisper_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_path,
        device=device,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    # VAE
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # UNet
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        ckpt_path,
        device="cpu",  # Load to CPU first to save memory during init
    )
    unet = unet.to(dtype=dtype)

    # Pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to(device)

    # DeepCache (enabled by default)
    helper = DeepCacheSDHelper(pipe=pipeline)
    helper.set_params(cache_interval=3, cache_branch_id=0)
    helper.enable()

    models["pipeline"] = pipeline
    models["config"] = config
    models["dtype"] = dtype

    print("✅ LatentSync 模型加载完成,服务就绪!")

    yield

    # Clean up if needed
    models.clear()
    torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)
class LipSyncRequest(BaseModel):
    """Request body of POST /lipsync; the fields are forwarded to the pipeline call."""

    # Input video and audio paths; the handler answers 404 if either does not exist
    video_path: str
    audio_path: str
    # Where the pipeline writes the result; the handler checks it exists afterwards
    video_out_path: str
    inference_steps: int = 20
    guidance_scale: float = 1.5
    # -1 requests a fresh seed; any other value is passed to set_seed()
    seed: int = 1247
    temp_dir: str = "temp"
@app.get("/health")
def health_check():
    """Liveness probe; ``model_loaded`` tells whether the pipeline is in the model cache."""
    pipeline_ready = "pipeline" in models
    return {"status": "ok", "model_loaded": pipeline_ready}
@app.post("/lipsync")
async def generate_lipsync(req: LipSyncRequest):
    """Run one lip-sync job with the preloaded pipeline.

    Returns:
        ``{"status": "success", "output_path": ...}`` once the output file exists.

    Raises:
        HTTPException: 503 if the pipeline is not loaded; 404 if the input video or
            audio is missing; 500 if inference raises (detail is the error text) or
            no output file was produced.
    """
    if "pipeline" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not os.path.exists(req.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
    if not os.path.exists(req.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")

    print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")

    # NOTE(review): the pipeline call below is synchronous, so this coroutine does not
    # yield to the event loop while a job is running.
    try:
        pipeline = models["pipeline"]
        config = models["config"]
        dtype = models["dtype"]

        # Set seed (-1 requests a fresh one)
        if req.seed != -1:
            set_seed(req.seed)
        else:
            torch.seed()

        # Run Inference
        pipeline(
            video_path=req.video_path,
            audio_path=req.audio_path,
            video_out_path=req.video_out_path,
            num_frames=config.data.num_frames,
            num_inference_steps=req.inference_steps,
            guidance_scale=req.guidance_scale,
            weight_dtype=dtype,
            width=config.data.resolution,
            height=config.data.resolution,
            mask_image_path=config.data.mask_image_path,
            temp_dir=req.temp_dir,
        )
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Checked outside the try block so this HTTPException reaches the client as
    # written instead of being caught and re-wrapped by the handler above.
    if not os.path.exists(req.video_out_path):
        raise HTTPException(status_code=500, detail="Output file generation failed")

    return {"status": "success", "output_path": req.video_out_path}
# Script entry point: serve the app with uvicorn on port 8007.
# NOTE(review): host "0.0.0.0" listens on every network interface, and /lipsync
# accepts file paths from the request body - confirm this exposure is intended.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8007)

View File

@@ -0,0 +1,340 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tqdm.auto import tqdm
import os, argparse, datetime, math
import logging
from omegaconf import OmegaConf
import shutil
from latentsync.data.syncnet_dataset import SyncNetDataset
from latentsync.models.stable_syncnet import StableSyncNet
from latentsync.models.wav2lip_syncnet import Wav2LipSyncNet
from latentsync.utils.util import gather_loss, plot_loss_chart
from accelerate.utils import set_seed
import torch
from diffusers import AutoencoderKL
from diffusers.utils.logging import get_logger
from einops import rearrange
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from latentsync.utils.util import init_dist, cosine_loss, dummy_context
# Module-level logger, created through diffusers' logging helper.
logger = get_logger(__name__)
def main(config):
    """Train a StableSyncNet with DistributedDataParallel and gradient accumulation.

    Initializes the process group, optionally loads a frozen fp16 VAE (when
    ``config.data.latent_space`` is set), builds the train/val dataloaders,
    the model and an AdamW optimizer, optionally resumes from
    ``config.ckpt.resume_ckpt_path``, and then trains until
    ``config.run.max_train_steps`` optimizer updates have been performed.
    The main process periodically runs validation, plots loss charts and
    saves checkpoints into a timestamped folder under
    ``config.data.train_output_dir``.

    Args:
        config: OmegaConf configuration. Besides the ``data``, ``run``,
            ``model``, ``optimizer`` and ``ckpt`` sections it must carry
            ``config_path`` (the YAML file copied into the output folder).

    Raises:
        ValueError: If the dataloader yields fewer batches per epoch than
            ``config.data.gradient_accumulation_steps``, in which case no
            optimizer update could ever happen.
    """
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    # Each rank gets its own seed, derived from the configured base seed.
    seed = config.run.seed + global_rank
    set_seed(seed)

    grad_accum_steps = config.data.gradient_accumulation_steps
    mixed_precision = bool(config.run.mixed_precision_training)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime("-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Only the main process creates the output folders and copies the config.
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/loss_charts", exist_ok=True)
        shutil.copy(config.config_path, output_dir)

    device = torch.device(local_rank)

    if config.data.latent_space:
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
        vae.requires_grad_(False)
        vae.to(device)
    else:
        vae = None

    # Dataset and Dataloader setup
    train_dataset = SyncNetDataset(config.data.train_data_dir, config.data.train_fileslist, config)
    val_dataset = SyncNetDataset(config.data.val_data_dir, config.data.val_fileslist, config)

    train_distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=train_distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    num_samples_limit = 640

    val_batch_size = min(
        num_samples_limit // config.data.num_frames, config.data.batch_size
    )  # limit batch size to avoid CUDA OOM

    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=val_batch_size,
        shuffle=False,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=False,
        worker_init_fn=val_dataset.worker_init_fn,
    )

    # Model (alternative backbone: Wav2LipSyncNet().to(device))
    syncnet = StableSyncNet(OmegaConf.to_container(config.model)).to(device)

    optimizer = torch.optim.AdamW(
        list(filter(lambda p: p.requires_grad, syncnet.parameters())), lr=config.optimizer.lr
    )

    global_step = 0
    train_step_list = []
    train_loss_list = []
    val_step_list = []
    val_loss_list = []

    if config.ckpt.resume_ckpt_path != "":
        if is_main_process:
            logger.info(f"Load checkpoint from: {config.ckpt.resume_ckpt_path}")
        ckpt = torch.load(config.ckpt.resume_ckpt_path, map_location=device, weights_only=True)

        syncnet.load_state_dict(ckpt["state_dict"])

        # Checkpoints without "global_step" only restore the weights.
        if "global_step" in ckpt:
            global_step = ckpt["global_step"]
            train_step_list = ckpt["train_step_list"]
            train_loss_list = ckpt["train_loss_list"]
            val_step_list = ckpt["val_step_list"]
            val_loss_list = ckpt["val_loss_list"]

    # DDP wrapper
    syncnet = DDP(syncnet, device_ids=[local_rank], output_device=local_rank)

    # One optimizer update happens every `grad_accum_steps` batches, and a
    # trailing partial accumulation window is discarded by the zero_grad() at
    # the start of the next epoch, so use floor division here. Counting raw
    # batches instead would end training before max_train_steps is reached.
    num_update_steps_per_epoch = len(train_dataloader) // grad_accum_steps
    if num_update_steps_per_epoch == 0:
        raise ValueError(
            f"The train dataloader yields {len(train_dataloader)} batches per epoch, which is fewer than "
            f"gradient_accumulation_steps ({grad_accum_steps}); no optimizer update would ever run"
        )
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(
            f" Total train batch size (w. parallel & distributed & accumulation) = {config.data.batch_size * num_processes * config.data.gradient_accumulation_steps}"
        )
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")

    first_epoch = global_step // num_update_steps_per_epoch
    num_val_batches = config.data.num_val_samples // (num_processes * config.data.batch_size)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps), initial=global_step, desc="Steps", disable=not is_main_process
    )

    # A disabled GradScaler is a transparent pass-through (scale() returns the
    # loss unchanged, step() just calls optimizer.step()), so the loop below
    # works with and without mixed precision. Previously the scaler was None
    # when mixed precision was off and the first scaler.scale() call crashed.
    scaler = torch.amp.GradScaler("cuda", enabled=mixed_precision)

    for epoch in range(first_epoch, num_train_epochs):
        # A resumed run can reach the target before the epoch range is
        # exhausted; do not start another epoch in that case.
        if global_step >= config.run.max_train_steps:
            break

        train_dataloader.sampler.set_epoch(epoch)
        syncnet.train()
        step_loss = 0
        optimizer.zero_grad()

        for index, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            frames = batch["frames"].to(device, dtype=torch.float16)
            audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
            y = batch["y"].to(device, dtype=torch.float32)

            if config.data.latent_space:
                max_batch_size = (
                    num_samples_limit // config.data.num_frames
                )  # due to the limited cuda memory, we split the input frames into parts
                if frames.shape[0] > max_batch_size:
                    assert (
                        frames.shape[0] % max_batch_size == 0
                    ), f"batch_size {frames.shape[0]} should be divisible by max_batch_size {max_batch_size}"
                    frames_part_results = []
                    for i in range(0, frames.shape[0], max_batch_size):
                        frames_part = frames[i : i + max_batch_size]
                        frames_part = rearrange(frames_part, "b f c h w -> (b f) c h w")
                        with torch.no_grad():
                            frames_part = vae.encode(frames_part).latent_dist.sample() * 0.18215
                        frames_part_results.append(frames_part)
                    frames = torch.cat(frames_part_results, dim=0)
                else:
                    frames = rearrange(frames, "b f c h w -> (b f) c h w")
                    with torch.no_grad():
                        frames = vae.encode(frames).latent_dist.sample() * 0.18215

                # Fold the frame axis into the channel axis.
                frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=config.data.num_frames)
            else:
                frames = rearrange(frames, "b f c h w -> b (f c) h w")

            if config.data.lower_half:
                height = frames.shape[2]
                frames = frames[:, :, height // 2 :, :]

            if not mixed_precision:
                # With autocast disabled the fp16 inputs would reach the model
                # as-is; the model is created above without an explicit dtype,
                # so feed it float32 inputs on this path.
                frames = frames.float()
                audio_samples = audio_samples.float()

            # Disable gradient sync for the first N-1 steps, enable sync on the final step
            is_update_step = (index + 1) % grad_accum_steps == 0
            sync_context = dummy_context() if is_update_step else syncnet.no_sync()

            with sync_context:
                # Mixed-precision training
                with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=mixed_precision):
                    vision_embeds, audio_embeds = syncnet(frames, audio_samples)

                loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
                # Average over the accumulation window so the summed gradient
                # matches that of one large batch.
                loss = loss / grad_accum_steps

                # Backpropagate
                scaler.scale(loss).backward()

            step_loss += gather_loss(loss, device)

            # Update parameters when the accumulation steps are reached
            if is_update_step:
                # Gradients must be unscaled before clipping.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(syncnet.parameters(), config.optimizer.max_grad_norm)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                progress_bar.update(1)
                global_step += 1

                train_step_list.append(global_step)
                train_loss_list.append(step_loss)

                if is_main_process and global_step % config.run.validation_steps == 0:
                    logger.info(f"Validation at step {global_step}")
                    val_loss = validation(
                        val_dataloader,
                        device,
                        syncnet,
                        config.data.latent_space,
                        config.data.lower_half,
                        vae,
                        num_val_batches,
                    )
                    val_step_list.append(global_step)
                    val_loss_list.append(val_loss)
                    logger.info(f"Validation loss at step {global_step} is {val_loss:0.3f}")
                    plot_loss_chart(
                        os.path.join(output_dir, f"loss_charts/loss_chart-{global_step}.png"),
                        ("Train loss", train_step_list, train_loss_list),
                        ("Val loss", val_step_list, val_loss_list),
                    )

                if is_main_process and global_step % config.ckpt.save_ckpt_steps == 0:
                    checkpoint_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                    torch.save(
                        {
                            "state_dict": syncnet.module.state_dict(),  # to unwrap DDP
                            "global_step": global_step,
                            "train_step_list": train_step_list,
                            "train_loss_list": train_loss_list,
                            "val_step_list": val_step_list,
                            "val_loss_list": val_loss_list,
                        },
                        checkpoint_save_path,
                    )
                    logger.info(f"Saved checkpoint to {checkpoint_save_path}")

                progress_bar.set_postfix({"step_loss": step_loss, "epoch": epoch})
                step_loss = 0

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()
@torch.no_grad()
def validation(val_dataloader, device, syncnet, latent_space, lower_half, vae, num_val_batches):
    """Compute the mean validation loss over a fixed number of batches.

    Puts ``syncnet`` into eval mode, iterates ``val_dataloader`` (restarting
    it when exhausted) until ``num_val_batches`` batches have been evaluated
    (at least one batch is always evaluated), and restores train mode before
    returning or raising.

    Args:
        val_dataloader: Iterable of batches with ``"frames"``,
            ``"audio_samples"`` and ``"y"`` entries.
        device: Device the batch tensors are moved to.
        syncnet: Model called as ``syncnet(frames, audio_samples)``.
        latent_space: If true, frames are encoded with ``vae`` first.
        lower_half: If true, only the lower half of the frame height is used.
        vae: VAE used when ``latent_space`` is true; unused otherwise.
        num_val_batches: Number of batches to average over.

    Returns:
        The arithmetic mean of the per-batch losses as a Python float.

    Raises:
        RuntimeError: If the dataloader yields no batches at all.
    """
    syncnet.eval()
    losses = []

    try:
        while True:
            batches_this_pass = 0

            for batch in val_dataloader:
                batches_this_pass += 1

                ### >>>> Validation >>>> ###

                frames = batch["frames"].to(device, dtype=torch.float16)
                audio_samples = batch["audio_samples"].to(device, dtype=torch.float16)
                y = batch["y"].to(device, dtype=torch.float32)

                if latent_space:
                    num_frames = frames.shape[1]
                    frames = rearrange(frames, "b f c h w -> (b f) c h w")
                    frames = vae.encode(frames).latent_dist.sample() * 0.18215
                    frames = rearrange(frames, "(b f) c h w -> b (f c) h w", f=num_frames)
                else:
                    frames = rearrange(frames, "b f c h w -> b (f c) h w")

                if lower_half:
                    height = frames.shape[2]
                    frames = frames[:, :, height // 2 :, :]

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    vision_embeds, audio_embeds = syncnet(frames, audio_samples)

                loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), y).mean()
                losses.append(loss.item())

                # Stop after exactly num_val_batches batches (the previous
                # `>` comparison evaluated one batch too many).
                if len(losses) >= num_val_batches:
                    return sum(losses) / len(losses)

            # An empty dataloader would otherwise spin in this loop forever.
            if batches_this_pass == 0:
                raise RuntimeError("No validation data")
    finally:
        # Always hand the model back in train mode, even on errors.
        syncnet.train()
if __name__ == "__main__":
    # Command-line entry point: read the YAML config and start training.
    cli = argparse.ArgumentParser(description="Code to train the SyncNet")
    cli.add_argument("--config_path", type=str, default="configs/syncnet/syncnet_16_pixel.yaml")
    config_path = cli.parse_args().config_path

    # Load a configuration file and record where it came from, because
    # main() copies that file into the run's output folder.
    config = OmegaConf.load(config_path)
    config.config_path = config_path

    main(config)

View File

@@ -0,0 +1,519 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import argparse
import shutil
import datetime
import logging
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from einops import rearrange
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils.logging import get_logger
from diffusers.optimization import get_scheduler
from accelerate.utils import set_seed
from latentsync.data.unet_dataset import UNetDataset
from latentsync.models.unet import UNet3DConditionModel
from latentsync.models.stable_syncnet import StableSyncNet
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.utils.util import (
init_dist,
cosine_loss,
one_step_sampling,
)
from latentsync.utils.util import plot_loss_chart
from latentsync.whisper.audio2feature import Audio2Feature
from latentsync.trepa.loss import TREPALoss
from eval.syncnet import SyncNetEval
from eval.syncnet_detect import SyncNetDetector
from eval.eval_sync_conf import syncnet_eval
import lpips
# Module-level logger, created through diffusers' logging helper.
logger = get_logger(__name__)
def main(config):
    """Train the lip-sync UNet with DistributedDataParallel.

    Initializes the process group, loads the frozen fp16 VAE, the SyncNet
    evaluation models, the Whisper-based audio encoder and the UNet
    (optionally resumed from ``config.ckpt.resume_ckpt_path``), builds the
    dataloader, optimizer and LR scheduler, and trains until
    ``config.run.max_train_steps`` updates. Every
    ``config.ckpt.save_ckpt_steps`` steps the main process saves a checkpoint,
    renders a validation video and plots its sync confidence into a
    timestamped folder under ``config.data.train_output_dir``.

    Args:
        config: OmegaConf configuration. Besides the ``data``, ``run``,
            ``model``, ``optimizer`` and ``ckpt`` sections it must carry
            ``unet_config_path`` (the YAML file copied into the output folder).

    Raises:
        NotImplementedError: If ``config.model.cross_attention_dim`` is not
            768 or 384, or the scheduler uses ``v_prediction``.
        ValueError: If SyncNet supervision is enabled without a checkpoint
            path, or the scheduler's prediction type is unknown.
    """
    # Initialize distributed training
    local_rank = init_dist()
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    # Each rank gets its own seed, derived from the configured base seed.
    seed = config.run.seed + global_rank
    set_seed(seed)

    # Logging folder
    folder_name = "train" + datetime.datetime.now().strftime("-%Y_%m_%d-%H:%M:%S")
    output_dir = os.path.join(config.data.train_output_dir, folder_name)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Only the main process creates the output folders and copies the configs.
    if is_main_process:
        diffusers.utils.logging.set_verbosity_info()
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        os.makedirs(f"{output_dir}/val_videos", exist_ok=True)
        os.makedirs(f"{output_dir}/sync_conf_results", exist_ok=True)
        shutil.copy(config.unet_config_path, output_dir)
        shutil.copy(config.data.syncnet_config_path, output_dir)

    device = torch.device(local_rank)

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    # Spatial downsampling factor between pixel space and latent space.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    vae.requires_grad_(False)
    vae.to(device)

    if config.run.pixel_space_supervise:
        vae.enable_gradient_checkpointing()

    syncnet_eval_model = SyncNetEval(device=device)
    syncnet_eval_model.loadParameters("checkpoints/auxiliary/syncnet_v2.model")

    syncnet_detector = SyncNetDetector(device=device, detect_results_dir="detect_results")

    # The Whisper variant is selected by the UNet's cross-attention width.
    if config.model.cross_attention_dim == 768:
        whisper_model_path = "checkpoints/whisper/small.pt"
    elif config.model.cross_attention_dim == 384:
        whisper_model_path = "checkpoints/whisper/tiny.pt"
    else:
        raise NotImplementedError("cross_attention_dim must be 768 or 384")

    audio_encoder = Audio2Feature(
        model_path=whisper_model_path,
        device=device,
        audio_embeds_cache_dir=config.data.audio_embeds_cache_dir,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    unet, resume_global_step = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        config.ckpt.resume_ckpt_path,
        device=device,
    )

    if config.model.add_audio_layer and config.run.use_syncnet:
        syncnet_config = OmegaConf.load(config.data.syncnet_config_path)
        if syncnet_config.ckpt.inference_ckpt_path == "":
            raise ValueError("SyncNet path is not provided")
        syncnet = StableSyncNet(OmegaConf.to_container(syncnet_config.model), gradient_checkpointing=True).to(
            device=device, dtype=torch.float16
        )
        syncnet_checkpoint = torch.load(
            syncnet_config.ckpt.inference_ckpt_path, map_location=device, weights_only=True
        )
        syncnet.load_state_dict(syncnet_checkpoint["state_dict"])
        syncnet.requires_grad_(False)

        # Release the raw checkpoint as soon as the weights are loaded.
        del syncnet_checkpoint
        torch.cuda.empty_cache()

    if config.model.use_motion_module:
        # Freeze everything, then re-enable only the configured modules.
        unet.requires_grad_(False)
        for name, param in unet.named_parameters():
            for trainable_module_name in config.run.trainable_modules:
                if trainable_module_name in name:
                    param.requires_grad = True
                    break
        trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    else:
        unet.requires_grad_(True)
        trainable_params = list(unet.parameters())

    if config.optimizer.scale_lr:
        config.optimizer.lr = config.optimizer.lr * num_processes

    optimizer = torch.optim.AdamW(trainable_params, lr=config.optimizer.lr)

    if is_main_process:
        logger.info(f"trainable params number: {len(trainable_params)}")
        logger.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable gradient checkpointing
    if config.run.enable_gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Get the training dataset
    train_dataset = UNetDataset(config.data.train_data_dir, config)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=config.run.seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=config.data.num_workers,
        pin_memory=False,
        drop_last=True,
        worker_init_fn=train_dataset.worker_init_fn,
    )

    # Get the training iteration
    if config.run.max_train_steps == -1:
        assert config.run.max_train_epochs != -1
        config.run.max_train_steps = config.run.max_train_epochs * len(train_dataloader)

    # Scheduler
    lr_scheduler = get_scheduler(
        config.optimizer.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.optimizer.lr_warmup_steps,
        num_training_steps=config.run.max_train_steps,
    )

    # The pixel-space losses are only built when they will actually be used.
    if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
        lpips_loss_func = lpips.LPIPS(net="vgg").to(device)

    if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
        trepa_loss_func = TREPALoss(device=device, with_cp=True)

    # Validation pipeline (built on the raw UNet, before the DDP wrapper)
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)

    # DDP wrapper
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader))
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(config.run.max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = config.data.batch_size * num_processes

    if is_main_process:
        logger.info("***** Running training *****")
        logger.info(f" Num examples = {len(train_dataset)}")
        logger.info(f" Num Epochs = {num_train_epochs}")
        logger.info(f" Instantaneous batch size per device = {config.data.batch_size}")
        logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logger.info(f" Total optimization steps = {config.run.max_train_steps}")
    global_step = resume_global_step
    first_epoch = resume_global_step // num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(0, config.run.max_train_steps),
        initial=resume_global_step,
        desc="Steps",
        disable=not is_main_process,
    )

    # Validation steps and their sync confidences; kept the same length
    # because they are plotted against each other.
    val_step_list = []
    sync_conf_list = []

    # Support mixed-precision training
    scaler = torch.amp.GradScaler("cuda") if config.run.mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        # A resumed run can reach the target before the epoch range is
        # exhausted; without this guard each extra epoch would run one more
        # update before the inner break fires.
        if global_step >= config.run.max_train_steps:
            break

        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            ### >>>> Training >>>> ###

            if config.model.add_audio_layer:
                if batch["mel"] != []:
                    mel = batch["mel"].to(device, dtype=torch.float16)

                audio_embeds_list = []
                video_path = None  # keeps the error log below safe on early failures
                try:
                    for idx in range(len(batch["video_path"])):
                        video_path = batch["video_path"][idx]
                        start_idx = batch["start_idx"][idx]

                        with torch.no_grad():
                            audio_feat = audio_encoder.audio2feat(video_path)
                        audio_embeds = audio_encoder.crop_overlap_audio_window(audio_feat, start_idx)
                        audio_embeds_list.append(audio_embeds)
                except Exception as e:
                    # Best effort: skip the whole batch if any clip's audio fails.
                    # NOTE(review): under DDP a rank that skips a batch runs one
                    # backward pass fewer than its peers - confirm this cannot stall.
                    logger.info(f"{type(e).__name__} - {e} - {video_path}")
                    continue
                audio_embeds = torch.stack(audio_embeds_list)  # (B, 16, 50, 384)
                audio_embeds = audio_embeds.to(device, dtype=torch.float16)
            else:
                audio_embeds = None

            # Convert videos to latent space
            gt_pixel_values = batch["gt_pixel_values"].to(device, dtype=torch.float16)
            masked_pixel_values = batch["masked_pixel_values"].to(device, dtype=torch.float16)
            masks = batch["masks"].to(device, dtype=torch.float16)
            ref_pixel_values = batch["ref_pixel_values"].to(device, dtype=torch.float16)

            gt_pixel_values = rearrange(gt_pixel_values, "b f c h w -> (b f) c h w")
            masked_pixel_values = rearrange(masked_pixel_values, "b f c h w -> (b f) c h w")
            masks = rearrange(masks, "b f c h w -> (b f) c h w")
            ref_pixel_values = rearrange(ref_pixel_values, "b f c h w -> (b f) c h w")

            with torch.no_grad():
                gt_latents = vae.encode(gt_pixel_values).latent_dist.sample()
                masked_latents = vae.encode(masked_pixel_values).latent_dist.sample()
                ref_latents = vae.encode(ref_pixel_values).latent_dist.sample()

            # Resize the masks to the latent resolution.
            masks = F.interpolate(masks, size=config.data.resolution // vae_scale_factor)

            gt_latents = (
                rearrange(gt_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            masked_latents = (
                rearrange(masked_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames)
                - vae.config.shift_factor
            ) * vae.config.scaling_factor
            ref_latents = (
                rearrange(ref_latents, "(b f) c h w -> b c f h w", f=config.data.num_frames) - vae.config.shift_factor
            ) * vae.config.scaling_factor
            masks = rearrange(masks, "(b f) c h w -> b c f h w", f=config.data.num_frames)

            # Sample noise that we'll add to the latents
            if config.run.use_mixed_noise:
                # Refer to the paper: https://arxiv.org/abs/2305.10474
                noise_shared_std_dev = (config.run.mixed_noise_alpha**2 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_shared = torch.randn_like(gt_latents) * noise_shared_std_dev
                noise_shared = noise_shared[:, :, 0:1].repeat(1, 1, config.data.num_frames, 1, 1)

                noise_ind_std_dev = (1 / (1 + config.run.mixed_noise_alpha**2)) ** 0.5
                noise_ind = torch.randn_like(gt_latents) * noise_ind_std_dev
                noise = noise_ind + noise_shared
            else:
                noise = torch.randn_like(gt_latents)
                noise = noise[:, :, 0:1].repeat(
                    1, 1, config.data.num_frames, 1, 1
                )  # Using the same noise for all frames, refer to the paper: https://arxiv.org/abs/2308.09716

            bsz = gt_latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=gt_latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_gt_latents = noise_scheduler.add_noise(gt_latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Channel-wise concatenation of all UNet conditioning inputs.
            unet_input = torch.cat([noisy_gt_latents, masks, masked_latents, ref_latents], dim=1)

            # Predict the noise and compute loss
            # Mixed-precision training
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=config.run.mixed_precision_training):
                pred_noise = unet(unet_input, timesteps, encoder_hidden_states=audio_embeds).sample

            if config.run.recon_loss_weight != 0:
                recon_loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
            else:
                recon_loss = 0

            pred_latents = one_step_sampling(noise_scheduler, pred_noise, timesteps, noisy_gt_latents)

            if config.run.pixel_space_supervise:
                pred_pixel_values = vae.decode(
                    rearrange(pred_latents, "b c f h w -> (b f) c h w") / vae.config.scaling_factor
                    + vae.config.shift_factor
                ).sample

            if config.run.perceptual_loss_weight != 0 and config.run.pixel_space_supervise:
                # The perceptual loss only looks at the lower half of each frame.
                pred_pixel_values_perceptual = pred_pixel_values[:, :, pred_pixel_values.shape[2] // 2 :, :]
                gt_pixel_values_perceptual = gt_pixel_values[:, :, gt_pixel_values.shape[2] // 2 :, :]
                lpips_loss = lpips_loss_func(
                    pred_pixel_values_perceptual.float(), gt_pixel_values_perceptual.float()
                ).mean()
            else:
                lpips_loss = 0

            if config.run.trepa_loss_weight != 0 and config.run.pixel_space_supervise:
                trepa_pred_pixel_values = rearrange(
                    pred_pixel_values, "(b f) c h w -> b c f h w", f=config.data.num_frames
                )
                trepa_gt_pixel_values = rearrange(
                    gt_pixel_values, "(b f) c h w -> b c f h w", f=config.data.num_frames
                )
                trepa_loss = trepa_loss_func(trepa_pred_pixel_values, trepa_gt_pixel_values)
            else:
                trepa_loss = 0

            if config.model.add_audio_layer and config.run.use_syncnet:
                if config.run.pixel_space_supervise:
                    if config.data.resolution != syncnet_config.data.resolution:
                        pred_pixel_values = F.interpolate(
                            pred_pixel_values,
                            size=(syncnet_config.data.resolution, syncnet_config.data.resolution),
                            mode="bicubic",
                        )
                    syncnet_input = rearrange(
                        pred_pixel_values, "(b f) c h w -> b (f c) h w", f=config.data.num_frames
                    )
                else:
                    syncnet_input = rearrange(pred_latents, "b c f h w -> b (f c) h w")

                if syncnet_config.data.lower_half:
                    height = syncnet_input.shape[2]
                    syncnet_input = syncnet_input[:, :, height // 2 :, :]
                # Every generated clip is supposed to be in sync, hence all-ones targets.
                ones_tensor = torch.ones((config.data.batch_size, 1)).float().to(device=device)
                vision_embeds, audio_embeds = syncnet(syncnet_input, mel)
                sync_loss = cosine_loss(vision_embeds.float(), audio_embeds.float(), ones_tensor).mean()
            else:
                sync_loss = 0

            loss = (
                recon_loss * config.run.recon_loss_weight
                + sync_loss * config.run.sync_loss_weight
                + lpips_loss * config.run.perceptual_loss_weight
                + trepa_loss * config.run.trepa_loss_weight
            )

            optimizer.zero_grad()

            # Backpropagate
            if config.run.mixed_precision_training:
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(trainable_params, config.optimizer.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(trainable_params, config.optimizer.max_grad_norm)
                optimizer.step()

            # Check the grad of attn blocks for debugging
            # print(unet.module.up_blocks[3].attentions[2].transformer_blocks[0].attn2.to_q.weight.grad)

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Save checkpoint and conduct validation
            if is_main_process and (global_step % config.ckpt.save_ckpt_steps == 0):
                model_save_path = os.path.join(output_dir, f"checkpoints/checkpoint-{global_step}.pt")
                state_dict = {
                    "global_step": global_step,
                    "state_dict": unet.module.state_dict(),
                }
                try:
                    torch.save(state_dict, model_save_path)
                    logger.info(f"Saved checkpoint to {model_save_path}")
                except Exception as e:
                    # A failed save must not abort training; report and continue.
                    logger.error(f"Error saving model: {e}")

                # Validation
                logger.info("Running validation... ")

                validation_video_out_path = os.path.join(output_dir, f"val_videos/val_video_{global_step}.mp4")

                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    pipeline(
                        config.data.val_video_path,
                        config.data.val_audio_path,
                        validation_video_out_path,
                        num_frames=config.data.num_frames,
                        num_inference_steps=config.run.inference_steps,
                        guidance_scale=config.run.guidance_scale,
                        weight_dtype=torch.float16,
                        width=config.data.resolution,
                        height=config.data.resolution,
                        mask_image_path=config.data.mask_image_path,
                    )

                logger.info(f"Saved validation video output to {validation_video_out_path}")

                if config.model.add_audio_layer and os.path.exists(validation_video_out_path):
                    try:
                        _, conf = syncnet_eval(syncnet_eval_model, syncnet_detector, validation_video_out_path, "temp")
                    except Exception as e:
                        logger.info(e)
                        conf = 0
                    # Record step and confidence together; appending the step
                    # unconditionally let the two lists drift apart whenever a
                    # validation video was missing.
                    val_step_list.append(global_step)
                    sync_conf_list.append(conf)
                    plot_loss_chart(
                        os.path.join(output_dir, f"sync_conf_results/sync_conf_chart-{global_step}.png"),
                        ("Sync confidence", val_step_list, sync_conf_list),
                    )

            logs = {"step_loss": loss.item(), "epoch": epoch}
            progress_bar.set_postfix(**logs)

            if global_step >= config.run.max_train_steps:
                break

    progress_bar.close()
    dist.destroy_process_group()
if __name__ == "__main__":
    # Command-line entry point: read the UNet YAML config and start training.
    cli = argparse.ArgumentParser()
    # Config file path
    cli.add_argument("--unet_config_path", type=str, default="configs/unet.yaml")
    unet_config_path = cli.parse_args().unet_config_path

    # Record where the config came from, because main() copies that file
    # into the run's output folder.
    config = OmegaConf.load(unet_config_path)
    config.unet_config_path = unet_config_path

    main(config)