diff --git a/Docs/DEPLOY_MANUAL.md b/Docs/DEPLOY_MANUAL.md
index ff70f62..7fb2a28 100644
--- a/Docs/DEPLOY_MANUAL.md
+++ b/Docs/DEPLOY_MANUAL.md
@@ -135,7 +135,7 @@ uvicorn app.main:app --host 0.0.0.0 --port 8006
 
 ```bash
 cd /home/rongye/ProgramFiles/ViGent/frontend
-npm run dev -- --host 0.0.0.0
+npm run dev -- --host 0.0.0.0 --port 3002
 ```
 
 ---
diff --git a/frontend/package.json b/frontend/package.json
index fe2e542..2d12704 100644
--- a/frontend/package.json
+++ b/frontend/package.json
@@ -3,9 +3,9 @@
   "version": "0.1.0",
   "private": true,
   "scripts": {
-    "dev": "next dev",
+    "dev": "next dev -p 3002",
     "build": "next build",
-    "start": "next start",
+    "start": "next start -p 3002",
     "lint": "eslint"
   },
   "dependencies": {
@@ -23,4 +23,4 @@
     "tailwindcss": "^4",
     "typescript": "^5"
   }
-}
+}
\ No newline at end of file
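With the dev port now pinned in both the manual and `package.json`, a quick smoke test can confirm the change took effect. A minimal sketch, assuming the dev server is already running and that `localhost:3002` is the right host/port for your setup:

```bash
# Should print 200 once `npm run dev` is serving on the new port.
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:3002
```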
diff --git a/models/MuseTalk/scripts/inference.py b/models/MuseTalk/scripts/inference.py
new file mode 100644
index 0000000..aa12e2a
--- /dev/null
+++ b/models/MuseTalk/scripts/inference.py
@@ -0,0 +1,305 @@
+import os
+import sys
+import cv2
+import math
+import copy
+import glob
+import shutil
+import pickle
+import argparse
+import traceback
+import subprocess
+import numpy as np
+import torch
+from tqdm import tqdm
+from omegaconf import OmegaConf
+from transformers import WhisperModel
+
+# Try package imports first; when the script is launched directly from the
+# repository root, put the repo root on sys.path and retry.
+try:
+    from musetalk.utils.blending import get_image
+    from musetalk.utils.face_parsing import FaceParsing
+    from musetalk.utils.audio_processor import AudioProcessor
+    from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+    from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+except ImportError:
+    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    from musetalk.utils.blending import get_image
+    from musetalk.utils.face_parsing import FaceParsing
+    from musetalk.utils.audio_processor import AudioProcessor
+    from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+    from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+
+
+def fast_check_ffmpeg():
+    try:
+        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
+        return True
+    except (subprocess.CalledProcessError, FileNotFoundError):
+        return False
+
+
+def run_ffmpeg(cmd):
+    print(f"Executing: {cmd}")
+    try:
+        # shell=True because callers pass a single command string
+        subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
+        return True
+    except subprocess.CalledProcessError as e:
+        print(f"Error executing ffmpeg: {cmd}")
+        print(f"Return code: {e.returncode}")
+        print(f"Stdout: {e.stdout}")
+        print(f"Stderr: {e.stderr}")
+        return False
+
+
+@torch.no_grad()
+def main(args):
+    # Configure ffmpeg path
+    if not fast_check_ffmpeg():
+        print("Adding ffmpeg to PATH")
+        path_separator = ';' if sys.platform == 'win32' else ':'
+        os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
+        if not fast_check_ffmpeg():
+            print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
+
+    # Set computing device
+    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
+
+    # Load model weights
+    vae, unet, pe = load_all_model(
+        unet_model_path=args.unet_model_path,
+        vae_type=args.vae_type,
+        unet_config=args.unet_config,
+        device=device
+    )
+    timesteps = torch.tensor([0], device=device)
+
+    if args.use_float16:
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        unet.model = unet.model.half()
+
+    pe = pe.to(device)
+    vae.vae = vae.vae.to(device)
+    unet.model = unet.model.to(device)
+
+    # Initialize components
+    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
+    weight_dtype = unet.model.dtype
+    whisper = WhisperModel.from_pretrained(args.whisper_dir)
+    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+    whisper.requires_grad_(False)
+
+    if args.version == "v15":
+        fp = FaceParsing(
+            left_cheek_width=args.left_cheek_width,
+            right_cheek_width=args.right_cheek_width
+        )
+    else:
+        fp = FaceParsing()
+
+    # Task configuration: explicit command-line paths take precedence over
+    # the YAML inference config.
+    if args.video_path and args.audio_path:
+        print(f"Using command line arguments. Video: {args.video_path}, Audio: {args.audio_path}")
+        inference_config = {
+            "task_cmd": {
+                "video_path": args.video_path,
+                "audio_path": args.audio_path
+            }
+        }
+        if args.output_path:
+            args.output_vid_name = args.output_path
+    else:
+        inference_config = OmegaConf.load(args.inference_config)
+        print("Loaded inference config:", inference_config)
+
+    for task_id in inference_config:
+        try:
+            video_path = inference_config[task_id]["video_path"]
+            audio_path = inference_config[task_id]["audio_path"]
+            if "result_name" in inference_config[task_id]:
+                args.output_vid_name = inference_config[task_id]["result_name"]
+
+            if args.version == "v15":
+                bbox_shift = 0
+            else:
+                bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
+
+            input_basename = os.path.splitext(os.path.basename(video_path))[0]
+            audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
+            output_basename = f"{input_basename}_{audio_basename}"
+
+            temp_dir = os.path.join(args.result_dir, f"{args.version}")
+            os.makedirs(temp_dir, exist_ok=True)
+
+            result_img_save_path = os.path.join(temp_dir, output_basename)
+            crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename + ".pkl")
+            os.makedirs(result_img_save_path, exist_ok=True)
+
+            if args.output_vid_name is None:
+                output_vid_name = os.path.join(temp_dir, output_basename + ".mp4")
+            else:
+                # Treat the name as a full path if it is absolute or contains a separator
+                if os.path.isabs(args.output_vid_name) or "/" in args.output_vid_name or "\\" in args.output_vid_name:
+                    output_vid_name = args.output_vid_name
+                else:
+                    output_vid_name = os.path.join(temp_dir, args.output_vid_name)
+
+            save_dir_full = None  # only set for video inputs; checked during cleanup
+            if get_file_type(video_path) == "video":
+                save_dir_full = os.path.join(temp_dir, input_basename)
+                os.makedirs(save_dir_full, exist_ok=True)
+                cmd = f'ffmpeg -y -v warning -i "{video_path}" -start_number 0 "{save_dir_full}/%08d.png"'
+                if not run_ffmpeg(cmd):
+                    raise RuntimeError("FFmpeg failed to extract frames")
+
+                input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
+                fps = get_video_fps(video_path)
+            elif get_file_type(video_path) == "image":
+                input_img_list = [video_path]
+                fps = args.fps
+            elif os.path.isdir(video_path):
+                input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
+                input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+                fps = args.fps
+            else:
+                raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
+
+            # Extract per-frame audio features with Whisper
+            whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+            whisper_chunks = audio_processor.get_whisper_chunk(
+                whisper_input_features, device, weight_dtype, whisper, librosa_length, fps=fps,
+                audio_padding_length_left=args.audio_padding_length_left,
+                audio_padding_length_right=args.audio_padding_length_right,
+            )
+
+            if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
+                print("Using saved coordinates")
+                with open(crop_coord_save_path, 'rb') as f:
+                    coord_list = pickle.load(f)
+                frame_list = read_imgs(input_img_list)
+            else:
+                print("Extracting landmarks...")
+                coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
+                with open(crop_coord_save_path, 'wb') as f:
+                    pickle.dump(coord_list, f)
+
+            print(f"Number of frames: {len(frame_list)}")
+
+            # Encode each detected face crop into a VAE latent
+            input_latent_list = []
+            for bbox, frame in zip(coord_list, frame_list):
+                if bbox == coord_placeholder:
+                    continue
+                x1, y1, x2, y2 = bbox
+                if args.version == "v15":
+                    y2 = y2 + args.extra_margin
+                    y2 = min(y2, frame.shape[0])
+                crop_frame = frame[y1:y2, x1:x2]
+                crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
+                latents = vae.get_latents_for_unet(crop_frame)
+                input_latent_list.append(latents)
+
+            # Mirror the sequences so the source frames can loop past the end
+            # when the audio is longer than the input video
+            frame_list_cycle = frame_list + frame_list[::-1]
+            coord_list_cycle = coord_list + coord_list[::-1]
+            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
+
+            print("Starting inference")
+            video_num = len(whisper_chunks)
+            batch_size = args.batch_size
+            gen = datagen(
+                whisper_chunks=whisper_chunks,
+                vae_encode_latents=input_latent_list_cycle,
+                batch_size=batch_size,
+                delay_frame=0,
+                device=device,
+            )
+
+            res_frame_list = []
+            total = int(np.ceil(float(video_num) / batch_size))
+
+            for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total)):
+                audio_feature_batch = pe(whisper_batch)
+                latent_batch = latent_batch.to(dtype=unet.model.dtype)
+                pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
+                recon = vae.decode_latents(pred_latents)
+                for res_frame in recon:
+                    res_frame_list.append(res_frame)
+
+            print("Padding generated images to original video size")
+            for i, res_frame in enumerate(tqdm(res_frame_list)):
+                bbox = coord_list_cycle[i % len(coord_list_cycle)]
+                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
+                x1, y1, x2, y2 = bbox
+                if args.version == "v15":
+                    y2 = y2 + args.extra_margin
+                    # clamp against this frame's height, not a stale loop variable
+                    y2 = min(y2, ori_frame.shape[0])
+                try:
+                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+                except Exception:
+                    # skip frames whose bbox degenerates to an empty region
+                    continue
+
+                if args.version == "v15":
+                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
+                else:
+                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
+                cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png", combine_frame)
+
+            # VIDEO SYNTHESIS
+            temp_vid_path = f"{temp_dir}/temp_{input_basename}_{audio_basename}.mp4"
+
+            cmd_img2video = f'ffmpeg -y -v warning -r {fps} -f image2 -i "{result_img_save_path}/%08d.png" -vcodec libx264 -vf format=yuv420p -crf 18 "{temp_vid_path}"'
+            print("Generating Video...")
+            if not run_ffmpeg(cmd_img2video):
+                print(f"FAILED to generate video from frames at {result_img_save_path}. Keeping frames.")
+                continue  # skip to the next task
+
+            cmd_combine_audio = f'ffmpeg -y -v warning -i "{audio_path}" -i "{temp_vid_path}" "{output_vid_name}"'
+            print("Combining Audio...")
+            if not run_ffmpeg(cmd_combine_audio):
+                print(f"FAILED to combine audio. Temp video at {temp_vid_path}.")
+                continue
+
+            # Clean up
+            print("Cleaning up temporary files...")
+            try:
+                shutil.rmtree(result_img_save_path)
+                os.remove(temp_vid_path)
+                if save_dir_full is not None:
+                    shutil.rmtree(save_dir_full)
+                if not args.saved_coord:
+                    os.remove(crop_coord_save_path)
+            except Exception as e:
+                print(f"Warning: Cleanup failed: {e}")
+
+            print(f"Results saved to {output_vid_name}")
+
+        except Exception as e:
+            print("Error occurred during processing:", e)
+            traceback.print_exc()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Directory containing the ffmpeg binary (prepended to PATH)")
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
+    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
+    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing the Whisper model")
+    parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
+    parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
+    parser.add_argument("--result_dir", default='./results', help="Directory for output results")
+    parser.add_argument("--extra_margin", type=int, default=10, help="Extra margin for face cropping")
+    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
+    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
+    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for inference")
+    parser.add_argument("--output_vid_name", type=str, default=None, help="Name of output video file")
+    parser.add_argument("--use_saved_coord", action="store_true", help="Use saved coordinates to save time")
+    parser.add_argument("--saved_coord", action="store_true", help="Save coordinates for future use")
+    parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
+    parser.add_argument("--parsing_mode", default='jaw', help="Face blending parsing mode")
+    parser.add_argument("--left_cheek_width", type=int, default=90, help="Width of left cheek region")
+    parser.add_argument("--right_cheek_width", type=int, default=90, help="Width of right cheek region")
+    parser.add_argument("--version", type=str, default="v15", choices=["v1", "v15"], help="Model version to use")
+
+    # New arguments for direct command-line invocation
+    parser.add_argument("--video_path", type=str, default=None, help="Input video path")
+    parser.add_argument("--audio_path", type=str, default=None, help="Input audio path")
+    parser.add_argument("--output_path", type=str, default=None, help="Output video path (alias for output_vid_name)")
+
+    args = parser.parse_args()
+    main(args)
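The added `--video_path`/`--audio_path`/`--output_path` flags let the script run a single task straight from the command line instead of a YAML task file. A minimal sketch of such an invocation, with illustrative input paths and the model weights at the argparse defaults above:

```bash
# Direct single-task run from the MuseTalk root; omitting the
# --video_path/--audio_path pair falls back to --inference_config.
python scripts/inference.py \
    --video_path data/video/demo.mp4 \
    --audio_path data/audio/demo.wav \
    --output_path results/demo_lipsync.mp4 \
    --use_float16
```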