代码修复

This commit is contained in:
Kevin Wong
2026-01-15 17:26:55 +08:00
parent e2282195b4
commit e2a3a88e23
3 changed files with 309 additions and 4 deletions

View File

@@ -0,0 +1,305 @@
import os
import cv2
import math
import copy
import torch
import glob
import shutil
import pickle
import argparse
import numpy as np
import subprocess
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import sys
import traceback
# Try imports, handle if running as script vs module
try:
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
except ImportError:
# If running from root directory
from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
def fast_check_ffmpeg():
    """Report whether an ``ffmpeg`` executable can be run from the current PATH.

    Runs ``ffmpeg -version`` with its output captured.

    Returns:
        bool: True if the command started and exited with status 0,
        False if it could not be launched or exited with an error.
    """
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (OSError, subprocess.SubprocessError):
        # OSError: the executable is missing or not runnable.
        # SubprocessError: it ran but returned a non-zero status.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return False
    return True
def run_ffmpeg(cmd):
    """Run an ffmpeg command and report whether it succeeded.

    Args:
        cmd: Either a command string, which is executed through the shell
            (the format the existing callers use), or a sequence of
            arguments, which is executed directly without a shell so no
            quoting is needed.

    Returns:
        bool: True if the process exited with status 0. False if it exited
        with a non-zero status or could not be started; diagnostics are
        printed in both failure cases.
    """
    print(f"Executing: {cmd}")
    # Only plain strings need the shell; argument lists are passed straight to the OS.
    use_shell = isinstance(cmd, str)
    try:
        subprocess.run(cmd, shell=use_shell, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"Error executing ffmpeg: {cmd}")
        print(f"Return code: {e.returncode}")
        print(f"Stdout: {e.stdout}")
        print(f"Stderr: {e.stderr}")
        return False
    except OSError as e:
        # Raised when the executable itself cannot be launched (argument-list form).
        print(f"Error executing ffmpeg: {cmd}")
        print(f"Unable to start process: {e}")
        return False
    return True
def _quote_arg(path):
    """Return *path* quoted so the shell hands it to ffmpeg as a single argument."""
    text = str(path)
    if sys.platform == "win32":
        # cmd.exe only understands double quotes, and Windows paths cannot contain them.
        return f'"{text}"'
    import shlex  # local import: only this helper needs it
    return shlex.quote(text)


@torch.no_grad()
def main(args):
    """Run lip-sync inference for every configured task and write the result videos.

    Tasks come from ``args.video_path`` / ``args.audio_path`` when both are
    set (a single task, with ``args.output_path`` as the output name if
    given); otherwise they are loaded from the file named by
    ``args.inference_config``.

    For each task the function extracts frames, computes audio features,
    obtains face boxes (from a pickle cache or by landmark detection),
    runs the UNet per batch, blends the generated crops back into the
    original frames, and uses ffmpeg to encode the frames and add the audio.

    Args:
        args: Parsed command-line namespace. Reads the model options
            (``unet_model_path``, ``vae_type``, ``unet_config``,
            ``whisper_dir``, ``use_float16``, ``gpu_id``, ``version``),
            the task options (``video_path``, ``audio_path``,
            ``output_path``, ``output_vid_name``, ``inference_config``,
            ``result_dir``) and the processing options (``bbox_shift``,
            ``extra_margin``, ``fps``, ``batch_size``, ``parsing_mode``,
            cheek widths, audio padding, ``use_saved_coord``,
            ``saved_coord``, ``ffmpeg_path``).

    Returns:
        None. An exception inside one task is printed with its traceback
        and processing moves on to the next task.
    """
    # Configure ffmpeg path
    if not fast_check_ffmpeg():
        print("Adding ffmpeg to PATH")
        os.environ["PATH"] = f"{args.ffmpeg_path}{os.pathsep}{os.environ['PATH']}"
        if not fast_check_ffmpeg():
            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")

    # Set computing device
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    # Load model weights
    vae, unet, pe = load_all_model(
        unet_model_path=args.unet_model_path,
        vae_type=args.vae_type,
        unet_config=args.unet_config,
        device=device
    )
    timesteps = torch.tensor([0], device=device)

    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)

    # Initialize components
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(args.whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    if args.version == "v15":
        fp = FaceParsing(
            left_cheek_width=args.left_cheek_width,
            right_cheek_width=args.right_cheek_width
        )
    else:
        fp = FaceParsing()

    # TASK CONFIGURATION
    # The fallback output name is kept in a local so that one task's
    # "result_name" cannot leak into later tasks through `args`.
    default_output_name = args.output_vid_name
    if args.video_path and args.audio_path:
        print(f"Using command line arguments. Video: {args.video_path}, Audio: {args.audio_path}")
        inference_config = {
            "task_cmd": {
                "video_path": args.video_path,
                "audio_path": args.audio_path
            }
        }
        if args.output_path:
            default_output_name = args.output_path
    else:
        inference_config = OmegaConf.load(args.inference_config)
        print("Loaded inference config:", inference_config)

    for task_id in inference_config:
        # Reset per task: only video inputs create an extracted-frames directory,
        # and cleanup must never act on one left over from an earlier task.
        save_dir_full = None
        try:
            task_cfg = inference_config[task_id]
            video_path = task_cfg["video_path"]
            audio_path = task_cfg["audio_path"]

            task_output_name = default_output_name
            if "result_name" in task_cfg:
                task_output_name = task_cfg["result_name"]

            if args.version == "v15":
                bbox_shift = 0
            else:
                bbox_shift = task_cfg.get("bbox_shift", args.bbox_shift)

            input_basename = os.path.basename(video_path).split('.')[0]
            audio_basename = os.path.basename(audio_path).split('.')[0]
            output_basename = f"{input_basename}_{audio_basename}"

            temp_dir = os.path.join(args.result_dir, f"{args.version}")
            os.makedirs(temp_dir, exist_ok=True)

            result_img_save_path = os.path.join(temp_dir, output_basename)
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
            os.makedirs(result_img_save_path, exist_ok=True)

            if task_output_name is None:
                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
            elif os.path.isabs(task_output_name) or "/" in task_output_name or "\\" in task_output_name:
                # A name that already contains a directory part is used as given.
                output_vid_name = task_output_name
            else:
                output_vid_name = os.path.join(temp_dir, task_output_name)

            # Collect the input frames
            file_type = get_file_type(video_path)
            if file_type == "video":
                save_dir_full = os.path.join(temp_dir, input_basename)
                os.makedirs(save_dir_full, exist_ok=True)
                frame_pattern = _quote_arg(f"{save_dir_full}/%08d.png")
                cmd = f"ffmpeg -y -v warning -i {_quote_arg(video_path)} -start_number 0 {frame_pattern}"
                if not run_ffmpeg(cmd):
                    raise RuntimeError("FFmpeg failed to extract frames")
                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
                fps = get_video_fps(video_path)
            elif file_type == "image":
                input_img_list = [video_path]
                fps = args.fps
            elif os.path.isdir(video_path):
                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
                # Frame files in a directory are expected to have integer names.
                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
                fps = args.fps
            else:
                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")

            # Audio features
            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features, device, weight_dtype, whisper, librosa_length, fps=fps,
                audio_padding_length_left=args.audio_padding_length_left,
                audio_padding_length_right=args.audio_padding_length_right,
            )

            # Face boxes: reuse the cache when asked to, otherwise detect and cache
            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                # NOTE(review): pickle.load can execute arbitrary code; only load
                # coordinate caches that this script wrote itself.
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks...")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)

            print(f"Number of frames: {len(frame_list)}")

            # Encode every usable face crop into a latent
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            if not input_latent_list:
                # Fail with a clear message instead of an obscure error further down.
                raise ValueError(f"No usable face crops found in {video_path}")

            # Forward + reversed copies, indexed modulo their length below
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )

            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            for whisper_batch, latent_batch in tqdm(gen, total=total):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)
                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
                recon = vae.decode_latents(pred_latents)
                res_frame_list.extend(recon)

            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i%(len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    # Clamp against the frame being composited, not a leftover
                    # loop variable from the latent-encoding loop above.
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                except Exception:
                    # Best effort: a box that cannot be resized to (for example
                    # a zero-width one) is skipped rather than aborting the task.
                    continue

                if args.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # VIDEO SYNTHESIS
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            result_pattern = _quote_arg(f"{result_img_save_path}/%08d.png")
            cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_pattern} -vcodec libx264 -vf format=yuv420p -crf 18 {_quote_arg(temp_vid_path)}"
            print("Generating Video...")
            if not run_ffmpeg(cmd_img2video):
                print(f"FAILED to generate video from frames at {result_img_save_path}. Keeping frames.")
                continue  # Skip to next task, leaving the frames on disk

            cmd_combine_audio = f"ffmpeg -y -v warning -i {_quote_arg(audio_path)} -i {_quote_arg(temp_vid_path)} {_quote_arg(output_vid_name)}"
            print("Combining Audio...")
            if not run_ffmpeg(cmd_combine_audio):
                print(f"FAILED to combine audio. Temp video at {temp_vid_path}.")
                continue

            # Clean up
            print("Cleaning up temporary files...")
            try:
                shutil.rmtree(result_img_save_path)
                os.remove(temp_vid_path)
                if save_dir_full is not None:
                    shutil.rmtree(save_dir_full)
                if not args.saved_coord:
                    os.remove(crop_coord_save_path)
            except Exception as e:
                print(f"Warning: Cleanup failed: {e}")

            print(f"Results saved to {output_vid_name}")
        except Exception as e:
            # Top-level task boundary: report and carry on with the next task.
            print("Error occurred during processing:", e)
            traceback.print_exc()
if __name__ == "__main__":
    # Script entry point: declare every command-line option in one table,
    # register them in order, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser()

    option_table = [
        # Tooling, device and model locations
        ("--ffmpeg_path", dict(type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")),
        ("--gpu_id", dict(type=int, default=0, help="GPU ID to use")),
        ("--vae_type", dict(type=str, default="sd-vae", help="Type of VAE model")),
        ("--unet_config", dict(type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")),
        ("--unet_model_path", dict(type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")),
        ("--whisper_dir", dict(type=str, default="./models/whisper", help="Directory containing Whisper model")),
        # Task description and processing parameters
        ("--inference_config", dict(type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")),
        ("--bbox_shift", dict(type=int, default=0, help="Bounding box shift value")),
        ("--result_dir", dict(default='./results', help="Directory for output results")),
        ("--extra_margin", dict(type=int, default=10, help="Extra margin for face cropping")),
        ("--fps", dict(type=int, default=25, help="Video frames per second")),
        ("--audio_padding_length_left", dict(type=int, default=2, help="Left padding length for audio")),
        ("--audio_padding_length_right", dict(type=int, default=2, help="Right padding length for audio")),
        ("--batch_size", dict(type=int, default=8, help="Batch size for inference")),
        ("--output_vid_name", dict(type=str, default=None, help="Name of output video file")),
        # Switches
        ("--use_saved_coord", dict(action="store_true", help='Use saved coordinates to save time')),
        ("--saved_coord", dict(action="store_true", help='Save coordinates for future use')),
        ("--use_float16", dict(action="store_true", help="Use float16 for faster inference")),
        # Face blending and model version
        ("--parsing_mode", dict(default='jaw', help="Face blending parsing mode")),
        ("--left_cheek_width", dict(type=int, default=90, help="Width of left cheek region")),
        ("--right_cheek_width", dict(type=int, default=90, help="Width of right cheek region")),
        ("--version", dict(type=str, default="v15", choices=["v1", "v15"], help="Model version to use")),
        # Direct single-task input/output paths
        ("--video_path", dict(type=str, default=None, help="Input video path")),
        ("--audio_path", dict(type=str, default=None, help="Input audio path")),
        ("--output_path", dict(type=str, default=None, help="Output video path (alias for output_vid_name)")),
    ]
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    args = cli.parse_args()
    main(args)