"""Lip-sync service.

Drives MuseTalk inference through a subprocess running inside a dedicated
conda environment (pinned to a configurable GPU via ``CUDA_VISIBLE_DEVICES``),
with an optional remote-API mode.  Every failure path degrades gracefully by
copying the input video through unchanged ("fallback" mode), so callers always
receive a playable file.

NOTE(review): reconstructed from a unified diff; regions elided by hunk
boundaries (parts of the ``generate`` / ``_remote_generate`` signatures and a
few log lines) were filled in conservatively — confirm against the repository.
"""
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

import httpx
from loguru import logger

from app.core.config import settings


class LipSyncService:
    """Lip-sync service — MuseTalk integration (subprocess invocation)."""

    def __init__(self):
        # Mode switches / endpoints come from the central settings object.
        self.use_local = settings.MUSETALK_LOCAL
        self.api_url = settings.MUSETALK_API_URL
        self.version = settings.MUSETALK_VERSION
        self.musetalk_dir = settings.MUSETALK_DIR
        self.gpu_id = settings.MUSETALK_GPU_ID

        # Python interpreter of the MuseTalk conda environment.  The default
        # matches the current server layout; it can be overridden with the
        # MUSETALK_PYTHON environment variable instead of editing the code.
        default_python = (
            Path.home() / "ProgramFiles" / "miniconda3"
            / "envs" / "musetalk" / "bin" / "python"
        )
        self.conda_python = Path(os.environ.get("MUSETALK_PYTHON", str(default_python)))

        # Cached runtime probe result (None = not probed yet).
        self._weights_available: Optional[bool] = None

    def _check_weights(self) -> bool:
        """Return True when all required MuseTalk model directories exist.

        The result is cached for the lifetime of the service instance.
        """
        if self._weights_available is not None:
            return self._weights_available

        required_dirs = [
            self.musetalk_dir / "models" / "musetalkV15",
            self.musetalk_dir / "models" / "whisper",
            self.musetalk_dir / "models" / "sd-vae-ft-mse",
        ]
        self._weights_available = all(d.exists() for d in required_dirs)
        if not self._weights_available:
            # Log exactly which directories are missing to ease deployment.
            missing = [str(d) for d in required_dirs if not d.exists()]
            logger.warning(f"⚠️ 缺少模型权重目录: {missing}")
        return self._weights_available

    def _check_conda_env(self) -> bool:
        """Return True when the conda-environment interpreter exists on disk."""
        if not self.conda_python.exists():
            logger.warning(f"⚠️ Conda Python 不存在: {self.conda_python}")
            return False
        return True

    @staticmethod
    def _fallback(video_path: str, output_path: str) -> str:
        """Degraded mode: pass the original video through unchanged.

        Centralizes the copy-through sequence that was previously duplicated
        at every failure site.
        """
        shutil.copy(video_path, output_path)
        return output_path

    async def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int = 25,  # NOTE(review): default not visible in the patch — confirm
    ) -> str:
        """Produce a lip-synced video of ``video_path`` driven by ``audio_path``.

        Dispatches to local subprocess inference or the remote API depending
        on configuration; always returns a path to a playable file (falling
        back to a plain copy of the input on any failure).
        """
        logger.info(f"🎬 唇形同步任务: {Path(video_path).name} + {Path(audio_path).name}")
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        if self.use_local:
            return await self._local_generate(video_path, audio_path, output_path, fps)
        return await self._remote_generate(video_path, audio_path, output_path, fps)

    async def _local_generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int,  # currently unused locally; MuseTalk derives fps from the video
    ) -> str:
        """Run MuseTalk via ``python -m scripts.inference`` in the conda env."""
        # Preconditions: interpreter and weights must both be present.
        if not self._check_conda_env():
            logger.warning("⚠️ Conda 环境不可用,使用 Fallback")
            return self._fallback(video_path, output_path)
        if not self._check_weights():
            logger.warning("⚠️ 模型权重不存在,使用 Fallback")
            return self._fallback(video_path, output_path)

        logger.info("🔄 调用 MuseTalk 推理 (subprocess)...")

        with tempfile.TemporaryDirectory() as tmp:
            tmpdir = Path(tmp)
            result_dir = tmpdir / "results"
            result_dir.mkdir()

            # MuseTalk reads its tasks from a small YAML config file.
            config_path = tmpdir / "inference_config.yaml"
            config_path.write_text(
                f"""
task_0:
 video_path: "{video_path}"
 audio_path: "{audio_path}"
"""
            )

            # Argument list (shell=False) — safe against special characters
            # in the media paths.
            cmd = [
                str(self.conda_python),
                "-m", "scripts.inference",
                "--inference_config", str(config_path),
                "--result_dir", str(result_dir),
                "--version", self.version,
                # After CUDA_VISIBLE_DEVICES is set below, the selected GPU is
                # always visible as device 0 inside the child process.
                "--gpu_id", "0",
                "--batch_size", str(settings.MUSETALK_BATCH_SIZE),
            ]
            if settings.MUSETALK_USE_FLOAT16:
                cmd.append("--use_float16")

            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

            logger.info(f"🖥️ 执行命令: {' '.join(cmd[:6])}...")

            try:
                result = subprocess.run(
                    cmd,
                    cwd=str(self.musetalk_dir),
                    env=env,
                    capture_output=True,
                    text=True,
                    timeout=600,  # 10-minute hard cap on inference
                )
            except subprocess.TimeoutExpired:
                logger.error("⏰ MuseTalk 推理超时")
                return self._fallback(video_path, output_path)
            except Exception as e:
                # e.g. interpreter removed between the check and the call.
                logger.error(f"❌ 推理异常: {e}")
                return self._fallback(video_path, output_path)

            if result.returncode != 0:
                logger.error(f"MuseTalk 推理失败:\n{result.stderr}")
                return self._fallback(video_path, output_path)

            logger.info(f"MuseTalk 输出:\n{result.stdout[-500:]}")

            # The inference script writes its mp4 somewhere under result_dir.
            output_files = list(result_dir.rglob("*.mp4"))
            if not output_files:
                logger.warning("⚠️ 未找到输出文件,使用 Fallback")
                return self._fallback(video_path, output_path)

            shutil.copy(output_files[0], output_path)
            logger.info(f"✅ 唇形同步完成: {output_path}")
            return output_path

    async def _remote_generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int,
    ) -> str:
        """Call a remote MuseTalk HTTP API with the video/audio pair."""
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                with open(video_path, "rb") as vf, open(audio_path, "rb") as af:
                    files = {
                        "video": (Path(video_path).name, vf, "video/mp4"),
                        "audio": (Path(audio_path).name, af, "audio/wav"),
                    }
                    # NOTE(review): endpoint path and form data were in an
                    # unchanged hunk region not visible in the patch —
                    # reconstructed conservatively; confirm against the repo.
                    response = await client.post(
                        f"{self.api_url}/generate",
                        files=files,
                        data={"fps": fps},
                    )

            if response.status_code == 200:
                with open(output_path, "wb") as f:
                    f.write(response.content)
                logger.info(f"✅ 远程推理完成: {output_path}")
                return output_path
            raise RuntimeError(f"API 错误: {response.status_code}")

        except Exception as e:
            logger.error(f"远程 API 调用失败: {e}")
            return self._fallback(video_path, output_path)

    async def check_health(self) -> dict:
        """Probe conda env, model weights and GPU; return a status report.

        Returns a dict with individual probe results plus an aggregate
        ``ready`` flag.  The GPU probe runs ``torch`` inside the conda
        environment so it reflects the actual inference runtime.
        """
        conda_ok = self._check_conda_env()
        weights_ok = self._check_weights()

        gpu_ok = False
        gpu_name = "Unknown"
        if conda_ok:
            try:
                result = subprocess.run(
                    [
                        str(self.conda_python), "-c",
                        "import torch; print(torch.cuda.get_device_name(0) "
                        "if torch.cuda.is_available() else 'N/A')",
                    ],
                    capture_output=True,
                    text=True,
                    env={**os.environ, "CUDA_VISIBLE_DEVICES": str(self.gpu_id)},
                    timeout=10,
                )
                # Check the exit status first: on failure stdout may be empty
                # or contain traceback noise rather than a device name.
                if result.returncode == 0:
                    gpu_name = result.stdout.strip() or "Unknown"
                    gpu_ok = gpu_name != "N/A"
            except Exception as e:  # was a bare ``except: pass`` — keep probe best-effort but log it
                logger.debug(f"GPU 探测失败: {e}")

        return {
            "conda_env": conda_ok,
            "weights": weights_ok,
            "gpu": gpu_ok,
            "gpu_name": gpu_name,
            "gpu_id": self.gpu_id,
            "ready": conda_ok and weights_ok and gpu_ok,
        }