"""
MuseTalk v1.5 resident inference service (optimized v2).

- Port: 8011
- GPU: read MUSETALK_GPU_ID from backend/.env (default 0)
- Architecture: FastAPI + lifespan (same pattern as the LatentSync server.py)

Optimizations (vs v1):
1. Read frames directly with cv2.VideoCapture (skips ffmpeg -> PNG -> imread)
2. Subsampled face detection (detect every N frames, interpolate bboxes between)
3. BiSeNet mask caching (refresh every N frames, reuse in between)
4. Write video directly with cv2.VideoWriter (skips per-frame PNG writes)
5. batch_size 8 -> 32
6. Per-phase timing
"""

import os
import sys
import math
import copy
import time
import glob
import shlex
import shutil
import tempfile
import subprocess
from pathlib import Path


def _env_value(line):
    """Extract the value from a ``KEY=value  # comment`` line of a .env file.

    Only the first ``=`` separates key from value, so values that themselves
    contain ``=`` are preserved. A trailing ``# comment`` and one layer of
    surrounding quotes are stripped.
    """
    raw = line.split("=", 1)[1].split("#")[0].strip()
    if len(raw) >= 2 and raw[0] == raw[-1] and raw[0] in ("'", '"'):
        raw = raw[1:-1].strip()
    return raw


# --- Auto-load the GPU configuration (must happen before torch is imported) ---
def load_gpu_config():
    """Set CUDA_VISIBLE_DEVICES from MUSETALK_GPU_ID in the backend .env file.

    Falls back to GPU "0" when the file or the key is missing. An externally
    provided CUDA_VISIBLE_DEVICES is never overwritten. Any error is logged
    and swallowed so that startup can continue with default settings.
    """
    try:
        current_dir = Path(__file__).resolve().parent
        env_path = current_dir.parent.parent.parent / "backend" / ".env"

        target_gpu = "0"  # default: GPU 0

        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("MUSETALK_GPU_ID="):
                        val = _env_value(line)
                        if val:
                            target_gpu = val
                            print(f"⚙️ 发现配置 MUSETALK_GPU_ID={target_gpu}")
                            break

        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f"ℹ️ 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")
    except Exception as e:
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")


load_gpu_config()

# --- Performance: cap the number of CPU threads ---
os.environ["OMP_NUM_THREADS"] = "8"
os.environ["MKL_NUM_THREADS"] = "8"
os.environ["TORCH_NUM_THREADS"] = "8"
print("⚙️ 已限制 PyTorch CPU 线程数为 8,防止系统卡顿")

import cv2
import torch
import pickle
import numpy as np
from tqdm import tqdm
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from transformers import WhisperModel

# Put the project root (the MuseTalk root directory) on sys.path
musetalk_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(musetalk_root))

from musetalk.utils.blending import (
    get_image,
    get_image_blending,
    get_image_blending_fast,
    get_image_prepare_material,
)
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder


# --- Read additional settings from .env ---
def load_env_config():
    """Read MuseTalk-related settings from the backend .env file.

    Returns:
        dict with keys ``batch_size`` (int), ``version`` (str) and
        ``use_float16`` (bool). Defaults are used for anything not present
        in the file; a read/parse error is logged and the values gathered so
        far are returned.
    """
    config = {
        "batch_size": 32,
        "version": "v15",
        "use_float16": True,
    }
    try:
        env_path = musetalk_root.parent.parent / "backend" / ".env"
        if env_path.exists():
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("MUSETALK_BATCH_SIZE="):
                        val = _env_value(line)
                        if val:
                            config["batch_size"] = int(val)
                    elif line.startswith("MUSETALK_VERSION="):
                        val = _env_value(line)
                        if val:
                            config["version"] = val
                    elif line.startswith("MUSETALK_USE_FLOAT16="):
                        val = _env_value(line).lower()
                        config["use_float16"] = val in ("true", "1", "yes")
    except Exception as e:
        print(f"⚠️ 读取额外配置失败: {e}")
    return config


env_config = load_env_config()

# Global model cache, filled by lifespan() at startup
models = {}

# ===================== Tuning parameters =====================
DETECT_EVERY = 5        # face-detection subsampling: detect once every N frames
BLEND_CACHE_EVERY = 5   # BiSeNet mask cache: refresh once every N frames
# =============================================================


def run_ffmpeg(cmd):
    """Run an FFmpeg command given as an argument list or a command string.

    A string is tokenized with shell-like quoting rules (so quoted paths that
    contain spaces stay intact); no shell is ever invoked.

    Returns:
        True when the command exits with status 0, False when it fails or the
        executable cannot be found.
    """
    if isinstance(cmd, str):
        cmd = shlex.split(cmd)
    print(f"Executing: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True, capture_output=True, text=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error executing ffmpeg: {' '.join(cmd)}")
        print(f"Return code: {e.returncode}")
        print(f"Stderr: {(e.stderr or '')[:500]}")
        return False
    except FileNotFoundError as e:
        # The executable itself is missing; report it through the same
        # True/False contract instead of leaking an unrelated exception.
        print(f"Error executing ffmpeg: {' '.join(cmd)}")
        print(f"Executable not found: {e}")
        return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once at startup and release them at shutdown."""
    print("⏳ 正在加载 MuseTalk v1.5 模型...")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    version = env_config["version"]
    use_float16 = env_config["use_float16"]

    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name}")
    else:
        print("⚠️ 警告: 未检测到 GPU,将使用 CPU 进行推理 (速度极慢)")

    # Pick the model files according to the configured version
    models_dir = musetalk_root / "models"
    if version == "v15":
        unet_model_path = str(models_dir / "musetalkV15" / "unet.pth")
        unet_config = str(models_dir / "musetalk" / "config.json")
    else:
        unet_model_path = str(models_dir / "musetalk" / "pytorch_model.bin")
        unet_config = str(models_dir / "musetalk" / "musetalk.json")

    # Switch the working directory (load_all_model loads the VAE via a
    # relative path) and always switch back, even when loading fails.
    original_cwd = os.getcwd()
    os.chdir(str(musetalk_root))
    try:
        vae, unet, pe = load_all_model(
            unet_model_path=unet_model_path,
            vae_type="sd-vae",
            unet_config=unet_config,
            device=device,
        )

        if use_float16 and torch.cuda.is_available():
            print("⚡ 使用 float16 半精度加速")
            pe = pe.half()
            vae.vae = vae.vae.half()
            unet.model = unet.model.half()

        pe = pe.to(device)
        vae.vae = vae.vae.to(device)
        unet.model = unet.model.to(device)

        # Whisper
        whisper_dir = str(models_dir / "whisper")
        audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
        weight_dtype = unet.model.dtype
        whisper = WhisperModel.from_pretrained(whisper_dir)
        whisper = whisper.to(device=device, dtype=weight_dtype).eval()
        whisper.requires_grad_(False)

        # FaceParsing
        if version == "v15":
            fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
        else:
            fp = FaceParsing()
    finally:
        # Restore the working directory
        os.chdir(original_cwd)

    models["vae"] = vae
    models["unet"] = unet
    models["pe"] = pe
    models["whisper"] = whisper
    models["audio_processor"] = audio_processor
    models["fp"] = fp
    models["device"] = device
    models["weight_dtype"] = weight_dtype
    models["version"] = version
    models["timesteps"] = torch.tensor([0], device=device)

    print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
    print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
          f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}")
    yield
    models.clear()
    torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)


class LipSyncRequest(BaseModel):
    """Request body of POST /lipsync."""

    video_path: str
    audio_path: str
    video_out_path: str
    # None (or 0) means "use the batch size configured via MUSETALK_BATCH_SIZE".
    # A hard-coded default here would always win over the .env setting.
    batch_size: Optional[int] = None


@app.get("/health")
def health_check():
    """Report liveness and whether the models have been loaded."""
    return {"status": "ok", "model_loaded": "unet" in models}


@app.post("/lipsync")
async def generate_lipsync(req: LipSyncRequest):
    """Run lip-sync inference for one video/audio pair.

    Raises:
        HTTPException 503 when the models are not loaded, 404 when the video
        or audio file does not exist, 500 when inference fails.
    """
    if "unet" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not os.path.exists(req.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
    if not os.path.exists(req.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")

    print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
    start_time = time.time()

    try:
        # NOTE(review): _run_inference is synchronous, so this coroutine blocks
        # the event loop for the whole job. Requests are therefore processed
        # one at a time and /health cannot answer while a job is running.
        result = _run_inference(req)
        elapsed = time.time() - start_time
        print(f"✅ 推理完成,耗时 {elapsed:.1f}s ({elapsed/60:.1f}min)")
        return result
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e


# =====================================================================
# Subsampled face detection: detect every N frames, linearly interpolate
# the bbox for the frames in between
# =====================================================================
def _detect_faces_subsampled(frames, detect_every=5):
    """Detect face boxes on a subset of frames and interpolate the rest.

    - Runs the pose model and the face detector on every ``detect_every``-th
      frame, plus the last frame so every gap has two end points.
    - Frames in between get a linearly interpolated bbox.
    - Intended for talking-head footage where the face barely moves, so the
      interpolation error is negligible.

    Args:
        frames: sequence of frames (images as read by OpenCV).
        detect_every: sampling stride; values below 1 are treated as 1.

    Returns:
        A list with one entry per frame: an ``(x1, y1, x2, y2)`` tuple, or
        ``coord_placeholder`` where no face could be determined.
    """
    from mmpose.apis import inference_topdown
    from mmpose.structures import merge_data_samples
    import musetalk.utils.preprocessing as _prep

    n = len(frames)
    if n == 0:
        return []

    # A stride of 0 (or less) would make range() raise.
    detect_every = max(1, int(detect_every))

    # Indices of the frames that actually go through detection
    sampled_indices = list(range(0, n, detect_every))
    if sampled_indices[-1] != n - 1:
        sampled_indices.append(n - 1)

    print(f" 检测 {len(sampled_indices)}/{n} 帧 (每{detect_every}帧)")

    # Run detection on the sampled frames
    detected = {}
    for idx in tqdm(sampled_indices, desc="人脸检测"):
        frame = frames[idx]
        try:
            results = inference_topdown(_prep.model, frame)
            results = merge_data_samples(results)
            keypoints = results.pred_instances.keypoints
            face_land_mark = keypoints[0][23:91].astype(np.int32)

            bbox_list = _prep.fa.get_detections_for_batch(np.array([frame]))

            if bbox_list[0] is None:
                detected[idx] = coord_placeholder
                continue

            # Box derived from the landmarks: extend upwards from landmark 29
            # by the distance between that landmark and the lowest landmark.
            half_face_coord = face_land_mark[29].copy()
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = max(0, half_face_coord[1] - half_face_dist)

            f_landmark = (
                int(np.min(face_land_mark[:, 0])),
                int(upper_bond),
                int(np.max(face_land_mark[:, 0])),
                int(np.max(face_land_mark[:, 1])),
            )
            x1, y1, x2, y2 = f_landmark

            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
                # Degenerate landmark box: fall back to the detector's box
                # (known to be non-None at this point).
                detected[idx] = bbox_list[0]
            else:
                detected[idx] = f_landmark
        except Exception as e:
            print(f"⚠️ 帧 {idx} 检测失败: {e}")
            detected[idx] = coord_placeholder

    # Fill every frame: sampled frames keep their detection ...
    coord_list = [None] * n
    for idx in sampled_indices:
        coord_list[idx] = detected[idx]

    # ... and each gap between two consecutive sampled frames is interpolated.
    # Walking the pairs once avoids searching all sampled indices per frame.
    for prev_idx, next_idx in zip(sampled_indices, sampled_indices[1:]):
        prev_bbox = detected[prev_idx]
        next_bbox = detected[next_idx]
        prev_missing = prev_bbox == coord_placeholder
        next_missing = next_bbox == coord_placeholder
        span = next_idx - prev_idx

        for i in range(prev_idx + 1, next_idx):
            if prev_missing and next_missing:
                coord_list[i] = coord_placeholder
            elif prev_missing:
                coord_list[i] = next_bbox
            elif next_missing:
                coord_list[i] = prev_bbox
            else:
                alpha = (i - prev_idx) / span
                coord_list[i] = tuple(
                    int(a * (1 - alpha) + b * alpha)
                    for a, b in zip(prev_bbox, next_bbox)
                )

    return coord_list


# =====================================================================
# Core inference (optimized)
# =====================================================================
@torch.no_grad()
def _run_inference(req: LipSyncRequest) -> dict:
    """Run the full lip-sync pipeline for one request.

    Phases:
    1. Read frames directly with cv2.VideoCapture (skips ffmpeg -> PNG -> imread)
    2. Whisper audio features
    3. Subsampled face detection (every N frames, interpolated in between)
    4. VAE latent encoding of the face crops
    5. Batched UNet inference
    6. Blending with a cached BiSeNet mask, written through cv2.VideoWriter
    7. FFmpeg H.264 re-encode + audio mux

    Args:
        req: the request; ``req.batch_size`` of None/0 selects the configured
            default batch size.

    Returns:
        ``{"status": "success", "output_path": <path>}``.

    Raises:
        ValueError: unsupported input file type.
        RuntimeError: no frames, no detectable face, the video writer cannot
            be opened, FFmpeg fails, or the output file is missing.
    """
    vae = models["vae"]
    unet = models["unet"]
    pe = models["pe"]
    whisper = models["whisper"]
    audio_processor = models["audio_processor"]
    fp = models["fp"]
    device = models["device"]
    weight_dtype = models["weight_dtype"]
    version = models["version"]
    timesteps = models["timesteps"]

    batch_size = req.batch_size or env_config["batch_size"]
    video_path = req.video_path
    audio_path = req.audio_path
    output_vid_path = req.video_out_path

    # dirname is "" for a bare file name, and os.makedirs("") raises.
    output_dir = os.path.dirname(output_vid_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    t_total = time.time()
    timings = {}

    # ===== Phase 1: read video frames (cv2.VideoCapture, no ffmpeg -> PNG) =====
    t0 = time.time()
    file_type = get_file_type(video_path)
    if file_type == "video":
        cap = cv2.VideoCapture(video_path)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
            frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()
    elif file_type == "image":
        image = cv2.imread(video_path)
        # imread signals an unreadable file by returning None; treat that as
        # "no frames" so the check below reports it.
        frames = [] if image is None else [image]
        fps = 25.0
    else:
        raise ValueError(f"不支持的文件类型: {video_path}")
    timings["1_read"] = time.time() - t0
    print(f"📹 读取 {len(frames)} 帧, FPS={fps} [{timings['1_read']:.1f}s]")

    if not frames:
        raise RuntimeError("视频帧为空")

    # ===== Phase 2: Whisper audio features =====
    t0 = time.time()
    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
    whisper_chunks = audio_processor.get_whisper_chunk(
        whisper_input_features, device, weight_dtype, whisper, librosa_length,
        fps=fps,
        audio_padding_length_left=2,
        audio_padding_length_right=2,
    )
    timings["2_whisper"] = time.time() - t0
    print(f"🎵 Whisper 特征 [{timings['2_whisper']:.1f}s]")

    # ===== Phase 3: face detection (subsampled) =====
    t0 = time.time()
    coord_list = _detect_faces_subsampled(frames, detect_every=DETECT_EVERY)
    timings["3_face"] = time.time() - t0
    print(f"🔍 人脸检测 [{timings['3_face']:.1f}s]")

    # ===== Phase 4: VAE latent encoding =====
    t0 = time.time()
    input_latent_list = []
    extra_margin = 15
    for bbox, frame in zip(coord_list, frames):
        if bbox == coord_placeholder:
            # Keep one slot per frame. Dropping the slot would shift the
            # latents against frames/coord_list for every later frame.
            input_latent_list.append(None)
            continue
        x1, y1, x2, y2 = bbox
        if version == "v15":
            y2 = min(y2 + extra_margin, frame.shape[0])
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(crop_frame)
        input_latent_list.append(latents)

    fill_latent = next((lat for lat in input_latent_list if lat is not None), None)
    if fill_latent is None:
        raise RuntimeError("未检测到人脸,无法进行推理")
    # Face-less frames borrow the nearest earlier latent (or the first valid
    # one when they lead the video). Their UNet output is discarded in phase 6;
    # the slot only exists to keep the indices aligned.
    # NOTE(review): assumes datagen pairs audio chunk i with
    # latents[i % len(latents)] — confirm against musetalk.utils.utils.datagen.
    for k, lat in enumerate(input_latent_list):
        if lat is None:
            input_latent_list[k] = fill_latent
        else:
            fill_latent = lat

    timings["4_vae"] = time.time() - t0
    print(f"🧠 VAE 编码 [{timings['4_vae']:.1f}s]")

    # ===== Phase 5: batched UNet inference =====
    t0 = time.time()

    # Forward + reversed sequences (list concatenation keeps references to the
    # same frames; no pixel data is copied)
    frame_list_cycle = frames + frames[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

    video_num = len(whisper_chunks)
    gen = datagen(
        whisper_chunks=whisper_chunks,
        vae_encode_latents=input_latent_list_cycle,
        batch_size=batch_size,
        delay_frame=0,
        device=device,
    )

    res_frame_list = []
    total_batches = int(np.ceil(float(video_num) / batch_size))
    print(f"🚀 推理: {video_num} 帧, batch={batch_size}, {total_batches} 批")

    for whisper_batch, latent_batch in tqdm(gen, total=total_batches):
        audio_feature_batch = pe(whisper_batch)
        latent_batch = latent_batch.to(dtype=unet.model.dtype)
        pred_latents = unet.model(
            latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
        ).sample
        recon = vae.decode_latents(pred_latents)
        res_frame_list.extend(recon)

    timings["5_unet"] = time.time() - t0
    print(f"✅ UNet 推理: {len(res_frame_list)} 帧 [{timings['5_unet']:.1f}s]")

    # ===== Phase 6: compositing (cv2.VideoWriter + blending) =====
    temp_raw_path = output_vid_path + ".raw.mp4"
    try:
        t0 = time.time()

        h, w = frames[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(temp_raw_path, fourcc, fps, (w, h))

        if not writer.isOpened():
            raise RuntimeError(f"cv2.VideoWriter 打开失败: {temp_raw_path}")

        cached_mask = None
        cached_crop_box = None
        blend_mode = "jaw" if version == "v15" else "raw"

        try:
            for i, res_frame in enumerate(tqdm(res_frame_list, desc="合成")):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = frame_list_cycle[i % len(frame_list_cycle)].copy()

                # No face in this frame: pass the original frame through.
                if bbox == coord_placeholder:
                    writer.write(ori_frame)
                    continue

                x1, y1, x2, y2 = bbox
                if version == "v15":
                    y2 = min(y2 + extra_margin, ori_frame.shape[0])
                adjusted_bbox = (x1, y1, x2, y2)

                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except Exception:
                    # Unusable box (e.g. zero-sized): keep the original frame.
                    writer.write(ori_frame)
                    continue

                # Refresh the BiSeNet face-parsing mask every N frames and
                # reuse the cached one for the frames in between.
                if i % BLEND_CACHE_EVERY == 0 or cached_mask is None:
                    try:
                        cached_mask, cached_crop_box = get_image_prepare_material(
                            ori_frame, adjusted_bbox, mode=blend_mode, fp=fp)
                    except Exception as e:
                        # Preparing the mask failed: use the full (uncached)
                        # blending path for this frame.
                        print(f"⚠️ 帧 {i} mask 准备失败: {e},使用完整融合")
                        combine_frame = get_image(
                            ori_frame, res_frame, list(adjusted_bbox),
                            mode=blend_mode, fp=fp)
                        writer.write(combine_frame)
                        continue

                try:
                    combine_frame = get_image_blending_fast(
                        ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
                except Exception as e:
                    # Fast blending failed: fall back to the PIL-based path,
                    # and finally to the full blending path.
                    print(f"⚠️ 帧 {i} 快速融合失败: {e},使用备用融合")
                    try:
                        combine_frame = get_image_blending(
                            ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
                    except Exception:
                        combine_frame = get_image(
                            ori_frame, res_frame, list(adjusted_bbox),
                            mode=blend_mode, fp=fp)

                writer.write(combine_frame)
        finally:
            # Always finalize the container, even when blending raised.
            writer.release()

        timings["6_blend"] = time.time() - t0
        print(f"🎨 合成 [{timings['6_blend']:.1f}s]")

        # ===== Phase 7: FFmpeg H.264 encode + audio mux =====
        t0 = time.time()
        cmd = [
            "ffmpeg", "-y", "-v", "warning",
            "-i", temp_raw_path, "-i", audio_path,
            "-c:v", "libx264", "-crf", "18", "-pix_fmt", "yuv420p",
            "-c:a", "copy", "-shortest",
            output_vid_path
        ]
        if not run_ffmpeg(cmd):
            raise RuntimeError("FFmpeg 重编码+音频合并失败")
    finally:
        # Remove the intermediate file on success and on failure alike
        if os.path.exists(temp_raw_path):
            os.unlink(temp_raw_path)

    timings["7_encode"] = time.time() - t0
    print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")

    # ===== Summary =====
    total_time = time.time() - t_total
    print(f"\n⏱️ 总耗时: {total_time:.1f}s ({total_time/60:.1f}min)")
    for k, v in timings.items():
        pct = v / total_time * 100
        print(f" {k}: {v:.1f}s ({pct:.0f}%)")

    if not os.path.exists(output_vid_path):
        raise RuntimeError("输出文件未生成")

    return {"status": "success", "output_path": output_vid_path}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8011)