123 lines
5.1 KiB
Python
123 lines
5.1 KiB
Python
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from latentsync.utils.util import read_video, write_video
|
|
from torchvision import transforms
|
|
import cv2
|
|
from einops import rearrange
|
|
import torch
|
|
import numpy as np
|
|
from typing import Union
|
|
from .affine_transform import AlignRestore
|
|
from .face_detector import FaceDetector
|
|
|
|
|
|
def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor:
    """Load the fixed mask image from disk as a channel-first tensor.

    The image is read with OpenCV, converted from BGR to RGB channel order,
    resized to a square of side ``resolution`` with Lanczos interpolation,
    divided by 255 and rearranged from (h, w, c) to (c, h, w).

    Args:
        resolution: Side length in pixels of the returned square mask.
        mask_image_path: Path of the mask image file.

    Returns:
        The mask as a (c, h, w) tensor of pixel values divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``mask_image_path``.
    """
    mask_image = cv2.imread(mask_image_path)
    if mask_image is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path instead of letting
        # cvtColor raise an opaque OpenCV assertion error.
        raise FileNotFoundError(f"Could not read mask image: {mask_image_path}")
    mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
    mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
    mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
    return mask_image
|
|
|
|
|
|
class ImageProcessor:
    """Resize, normalize, mask and face-align image frames.

    Holds a bicubic resize transform, a ``Normalize([0.5], [0.5])`` transform,
    an ``AlignRestore`` instance, a fixed mask and (unless the device is
    ``"cpu"``) a ``FaceDetector``.
    """

    def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
        """Build the transforms and helpers used by the processing methods.

        Args:
            resolution: Side length in pixels of the square output images.
            device: Device passed to ``AlignRestore`` and ``FaceDetector``.
                With ``"cpu"`` no face detector is created, so
                ``affine_transform`` is unavailable.
            mask_image: Mask multiplied onto normalized frames. When ``None``
                the mask is loaded with ``load_fixed_mask(resolution)``.
        """
        self.resolution = resolution
        self.resize = transforms.Resize(
            (resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
        )
        # In-place normalization is safe within this class: it is only ever
        # applied to the fresh tensor produced by `x / 255.0`.
        self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)

        self.restorer = AlignRestore(resolution=resolution, device=device)

        if mask_image is None:
            self.mask_image = load_fixed_mask(resolution)
        else:
            self.mask_image = mask_image

        if device == "cpu":
            self.face_detector = None
        else:
            self.face_detector = FaceDetector(device=device)

    @staticmethod
    def _to_hwc_array(image):
        """Return a (c, h, w) tensor as an (h, w, c) numpy array; pass anything else through."""
        if isinstance(image, torch.Tensor):
            return image.detach().permute(1, 2, 0).contiguous().cpu().numpy()
        return image

    @staticmethod
    def _to_channel_first_frames(images):
        """Return a batch of frames as a tensor, moving a trailing size-3 axis to position 1."""
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        # A last axis of size 3 is taken to be the channel axis (f, h, w, c).
        if images.shape[3] == 3:
            images = rearrange(images, "f h w c -> f c h w")
        return images

    def affine_transform(self, image: np.ndarray) -> tuple:
        """Detect the face in ``image`` and warp it to an aligned square crop.

        Args:
            image: Frame to align. It must provide ``.copy()``; the warped
                result is treated as an (h, w, c) array, so an (h, w, c) numpy
                frame is presumably expected — confirm against ``FaceDetector``.

        Returns:
            A tuple ``(face, box, affine_matrix)``: ``face`` is the warped face
            resized to ``resolution`` and rearranged to a (c, h, w) tensor,
            ``box`` is ``[0, 0, width, height]`` of the warped face before that
            resize, and ``affine_matrix`` is the matrix returned by
            ``AlignRestore.align_warp_face``.

        Raises:
            NotImplementedError: If no face detector exists (device ``"cpu"``).
            RuntimeError: If the detector returns no bounding box.
        """
        if self.face_detector is None:
            raise NotImplementedError("Using the CPU for face detection is not supported")
        bbox, landmark_2d_106 = self.face_detector(image)
        if bbox is None:
            raise RuntimeError("Face not detected")

        pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0)  # left eyebrow center
        pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0)  # right eyebrow center
        pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0)  # nose center

        # The three reference points are rounded to whole pixels before alignment.
        landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose])

        # A copy is warped so the caller's frame is not handed to the restorer.
        face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
        box = [0, 0, face.shape[1], face.shape[0]]  # x1, y1, x2, y2
        face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4)
        face = rearrange(torch.from_numpy(face), "h w c -> c h w")
        return face, box, affine_matrix

    def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
        """Normalize one frame and apply the fixed mask.

        Args:
            image: A single (c, h, w) frame with values on a 0-255 scale.
            affine_transform: When true, the frame is face-aligned with
                ``affine_transform`` instead of being resized.

        Returns:
            A tuple ``(pixel_values, masked_pixel_values, mask)`` where
            ``pixel_values`` is the frame scaled by 1/255 and normalized with
            mean 0.5 / std 0.5, ``masked_pixel_values`` is that tensor
            multiplied by the mask, and ``mask`` is the first channel of the
            mask with its channel axis kept.
        """
        if affine_transform:
            # affine_transform works on an (h, w, c) array (it calls
            # `image.copy()`), so a channel-first tensor is converted first.
            image, _, _ = self.affine_transform(self._to_hwc_array(image))
        else:
            image = self.resize(image)
        pixel_values = self.normalize(image / 255.0)
        masked_pixel_values = pixel_values * self.mask_image
        return pixel_values, masked_pixel_values, self.mask_image[0:1]

    def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
        """Run ``preprocess_fixed_mask_image`` on every frame and stack the results.

        Args:
            images: Batch of frames, either (f, c, h, w) or (f, h, w, 3).
            affine_transform: Forwarded to ``preprocess_fixed_mask_image``.

        Returns:
            A tuple of three stacked tensors: pixel values, masked pixel values
            and masks, each with the frame axis first.
        """
        images = self._to_channel_first_frames(images)

        results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]

        pixel_values_list, masked_pixel_values_list, masks_list = zip(*results)
        return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)

    def process_images(self, images: Union[torch.Tensor, np.ndarray]):
        """Resize a batch of frames and normalize them, without masking.

        Args:
            images: Batch of frames, either (f, c, h, w) or (f, h, w, 3).

        Returns:
            The resized batch scaled by 1/255 and normalized with mean 0.5 / std 0.5.
        """
        images = self._to_channel_first_frames(images)
        images = self.resize(images)
        pixel_values = self.normalize(images / 255.0)
        return pixel_values
|
|
|
|
|
|
class VideoProcessor:
    """Apply the ``ImageProcessor`` face alignment to every frame of a video file."""

    def __init__(self, resolution: int = 512, device: str = "cpu"):
        """Create the underlying ``ImageProcessor`` with the given resolution and device."""
        self.image_processor = ImageProcessor(resolution, device)

    def affine_transform_video(self, video_path):
        """Read a video and return its face-aligned frames.

        Args:
            video_path: Path of the video, read with ``read_video`` without
                changing its frame rate.

        Returns:
            The aligned faces stacked along the frame axis and rearranged from
            (f, c, h, w) to an (f, h, w, c) numpy array.
        """
        # Only the warped face is kept; the box and affine matrix are discarded.
        aligned_faces = [
            self.image_processor.affine_transform(source_frame)[0]
            for source_frame in read_video(video_path, change_fps=False)
        ]
        stacked_faces = torch.stack(aligned_faces)
        return rearrange(stacked_faces, "f c h w -> f h w c").numpy()
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual check: face-align the demo clip on the GPU at 256 px and write
    # the aligned frames out as a 25 fps video.
    processor = VideoProcessor(resolution=256, device="cuda")
    aligned_frames = processor.affine_transform_video("assets/demo2_video.mp4")
    write_video("output.mp4", aligned_frames, fps=25)
|