import os
import cv2
import math
import copy
import torch
import glob
import shlex
import shutil
import pickle
import argparse
import numpy as np
import subprocess
from tqdm import tqdm
from omegaconf import OmegaConf
from transformers import WhisperModel
import sys
import traceback

from musetalk.utils.blending import get_image
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


def fast_check_ffmpeg():
    """Return True when `ffmpeg -version` can be launched and exits with status 0."""
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        return True
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing / not runnable; SubprocessError: non-zero exit.
        return False


def run_ffmpeg(cmd):
    """Run an ffmpeg command and report whether it succeeded.

    Args:
        cmd: Either a list of arguments (executed directly, without a shell, so
            paths containing spaces or shell metacharacters are passed through
            verbatim) or a single command string (executed through the shell).

    Returns:
        True if the process exited with status 0, False otherwise. Failures are
        printed (return code, stdout, stderr) instead of being raised.
    """
    use_shell = isinstance(cmd, str)
    printable = cmd if use_shell else " ".join(shlex.quote(str(part)) for part in cmd)
    print(f"Executing: {printable}")
    try:
        subprocess.run(cmd, shell=use_shell, check=True, capture_output=True, text=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error executing ffmpeg: {printable}")
        print(f"Return code: {e.returncode}")
        print(f"Stdout: {e.stdout}")
        print(f"Stderr: {e.stderr}")
        return False
    except OSError as e:
        # Without a shell, a missing executable surfaces as OSError rather than
        # as a non-zero exit status.
        print(f"Error executing ffmpeg: {printable}")
        print(f"Launch failure: {e}")
        return False


def _ensure_ffmpeg_on_path(ffmpeg_path):
    """Prepend `ffmpeg_path` to PATH when ffmpeg is not already runnable; warn if still missing."""
    if fast_check_ffmpeg():
        return
    print("Adding ffmpeg to PATH")
    os.environ["PATH"] = f"{ffmpeg_path}{os.pathsep}{os.environ.get('PATH', '')}"
    if not fast_check_ffmpeg():
        print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")


def _resolve_output_path(requested_name, temp_dir, output_basename):
    """Map the requested output name to a path.

    None -> `<temp_dir>/<output_basename>.mp4`; a name that is absolute or
    contains a path separator is used as given; a bare file name is placed
    inside `temp_dir`.
    """
    if requested_name is None:
        return os.path.join(temp_dir, output_basename + ".mp4")
    if os.path.isabs(requested_name) or "/" in requested_name or "\\" in requested_name:
        return requested_name
    return os.path.join(temp_dir, requested_name)


def _collect_input_frames(video_path, temp_dir, input_basename, default_fps):
    """Build the list of input frame image paths for one task.

    Returns:
        (input_img_list, fps, extracted_dir). `extracted_dir` is the directory
        this call filled with frames extracted from a video, or None when the
        input was an image or an existing image directory (nothing to clean up).

    Raises:
        RuntimeError: ffmpeg failed to extract frames from a video.
        ValueError: `video_path` is not a video, an image or a directory.
    """
    file_type = get_file_type(video_path)
    if file_type == "video":
        frames_dir = os.path.join(temp_dir, input_basename)
        os.makedirs(frames_dir, exist_ok=True)
        cmd = [
            "ffmpeg", "-y", "-v", "warning",
            "-i", video_path,
            "-start_number", "0",
            f"{frames_dir}/%08d.png",
        ]
        if not run_ffmpeg(cmd):
            raise RuntimeError("FFmpeg failed to extract frames")
        input_img_list = sorted(glob.glob(os.path.join(frames_dir, '*.[jpJP][pnPN]*[gG]')))
        return input_img_list, get_video_fps(video_path), frames_dir
    if file_type == "image":
        return [video_path], default_fps, None
    if os.path.isdir(video_path):
        input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
        # Frame files in a directory are ordered by their integer file stem.
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        return input_img_list, default_fps, None
    raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")


@torch.no_grad()
def main(args):
    """Run lip-sync inference for every task and write one video per task.

    Tasks come either from `--video_path`/`--audio_path` (a single task) or from
    the YAML file given by `--inference_config`. For each task the input frames
    are collected, audio features are chunked, face crops are encoded, the UNet
    predicts new crops, the crops are blended back into the frames, and ffmpeg
    assembles the frames and muxes the audio. A failure in one task is printed
    with its traceback and the loop continues with the next task.

    Args:
        args: Parsed command-line namespace (see the argument parser below).
    """
    # Configure ffmpeg path
    _ensure_ffmpeg_on_path(args.ffmpeg_path)

    # Set computing device
    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")

    # Load model weights
    vae, unet, pe = load_all_model(
        unet_model_path=args.unet_model_path,
        vae_type=args.vae_type,
        unet_config=args.unet_config,
        device=device
    )
    timesteps = torch.tensor([0], device=device)

    if args.use_float16:
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()

    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)

    # Initialize components
    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(args.whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)

    if args.version == "v15":
        fp = FaceParsing(
            left_cheek_width=args.left_cheek_width,
            right_cheek_width=args.right_cheek_width
        )
    else:
        fp = FaceParsing()

    # TASK CONFIGURATION
    if args.video_path and args.audio_path:
        print(f"Using command line arguments. Video: {args.video_path}, Audio: {args.audio_path}")
        inference_config = {
            "task_cmd": {
                "video_path": args.video_path,
                "audio_path": args.audio_path
            }
        }
        if args.output_path:
            args.output_vid_name = args.output_path
    else:
        inference_config = OmegaConf.load(args.inference_config)
        print("Loaded inference config:", inference_config)

    # Output name used by tasks that do not specify their own "result_name".
    default_output_name = args.output_vid_name

    for task_id in inference_config:
        try:
            task = inference_config[task_id]
            video_path = task["video_path"]
            audio_path = task["audio_path"]

            # Keep the per-task name local so it cannot leak into later tasks.
            requested_name = default_output_name
            if "result_name" in task:
                requested_name = task["result_name"]

            if args.version == "v15":
                bbox_shift = 0
            else:
                bbox_shift = task.get("bbox_shift", args.bbox_shift)

            input_basename = os.path.basename(video_path).split('.')[0]
            audio_basename = os.path.basename(audio_path).split('.')[0]
            output_basename = f"{input_basename}_{audio_basename}"

            temp_dir = os.path.join(args.result_dir, f"{args.version}")
            os.makedirs(temp_dir, exist_ok=True)

            result_img_save_path = os.path.join(temp_dir, output_basename)
            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
            os.makedirs(result_img_save_path, exist_ok=True)

            output_vid_name = _resolve_output_path(requested_name, temp_dir, output_basename)
            output_dir = os.path.dirname(output_vid_name)
            if output_dir:
                # ffmpeg does not create missing parent directories.
                os.makedirs(output_dir, exist_ok=True)

            input_img_list, fps, extracted_frames_dir = _collect_input_frames(
                video_path, temp_dir, input_basename, args.fps
            )

            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
            whisper_chunks = audio_processor.get_whisper_chunk(
                whisper_input_features,
                device,
                weight_dtype,
                whisper,
                librosa_length,
                fps=fps,
                audio_padding_length_left=args.audio_padding_length_left,
                audio_padding_length_right=args.audio_padding_length_right,
            )

            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
                print("Using saved coordinates")
                # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
                # file; only use coordinate caches written by this script.
                with open(crop_coord_save_path, 'rb') as f:
                    coord_list = pickle.load(f)
                frame_list = read_imgs(input_img_list)
            else:
                print("Extracting landmarks...")
                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
                with open(crop_coord_save_path, 'wb') as f:
                    pickle.dump(coord_list, f)

            print(f"Number of frames: {len(frame_list)}")

            # NOTE(review): frames whose bbox equals coord_placeholder contribute no
            # latent here, while coord_list_cycle/frame_list_cycle below still contain
            # them, so latent and frame indices can drift apart — confirm intended.
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    y2 = min(y2, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256,256), interpolation=cv2.INTER_LANCZOS4)
                latents = vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            # Forward sequence followed by its reverse, indexed modulo its length below.
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            print("Starting inference")
            video_num = len(whisper_chunks)
            batch_size = args.batch_size
            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=device,
            )

            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            for whisper_batch, latent_batch in tqdm(gen, total=total):
                audio_feature_batch = pe(whisper_batch)
                latent_batch = latent_batch.to(dtype=unet.model.dtype)
                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
                recon = vae.decode_latents(pred_latents)
                for res_frame in recon:
                    res_frame_list.append(res_frame)

            print("Padding generated images to original video size")
            for i, res_frame in enumerate(tqdm(res_frame_list)):
                bbox = coord_list_cycle[i%(len(coord_list_cycle))]
                ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
                x1, y1, x2, y2 = bbox
                if args.version == "v15":
                    y2 = y2 + args.extra_margin
                    # Clamp against the frame actually being blended.
                    y2 = min(y2, ori_frame.shape[0])
                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2-x1, y2-y1))
                except Exception:
                    # Best effort: a frame that cannot be resized to its box is skipped.
                    continue

                if args.version == "v15":
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)

            # VIDEO SYNTHESIS
            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
            cmd_img2video = [
                "ffmpeg", "-y", "-v", "warning",
                "-r", str(fps),
                "-f", "image2",
                "-i", f"{result_img_save_path}/%08d.png",
                "-vcodec", "libx264",
                "-vf", "format=yuv420p",
                "-crf", "18",
                temp_vid_path,
            ]
            print("Generating Video...")
            if not run_ffmpeg(cmd_img2video):
                print(f"FAILED to generate video from frames at {result_img_save_path}. Keeping frames.")
                continue  # Skip to next task

            cmd_combine_audio = [
                "ffmpeg", "-y", "-v", "warning",
                "-i", audio_path,
                "-i", temp_vid_path,
                output_vid_name,
            ]
            print("Combining Audio...")
            if not run_ffmpeg(cmd_combine_audio):
                print(f"FAILED to combine audio. Temp video at {temp_vid_path}.")
                continue

            # Clean up
            print("Cleaning up temporary files...")
            try:
                shutil.rmtree(result_img_save_path)
                os.remove(temp_vid_path)
                # Only remove the frame directory this task extracted itself.
                if extracted_frames_dir is not None:
                    shutil.rmtree(extracted_frames_dir)
                if not args.saved_coord:
                    os.remove(crop_coord_save_path)
            except OSError as e:
                print(f"Warning: Cleanup failed: {e}")

            print(f"Results saved to {output_vid_name}")

        except Exception as e:
            print("Error occurred during processing:", e)
            traceback.print_exc()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
    parser.add_argument("--use_saved_coord", action="store_true", help='Use saved coordinates to save time')
    parser.add_argument("--saved_coord", action="store_true", help='Save coordinates for future use')
    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
    parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")

    # NEW ARGUMENTS
    parser.add_argument("--video_path", type=str, default=None, help="Input video path")
    parser.add_argument("--audio_path", type=str, default=None, help="Input audio path")
    parser.add_argument("--output_path", type=str, default=None, help="Output video path (alias for output_vid_name)")

    args = parser.parse_args()
    main(args)