修改成Subprocess调用

This commit is contained in:
Kevin Wong
2026-01-14 17:07:17 +08:00
parent 6470f45aa0
commit b1c3d25298

View File

@@ -1,75 +1,46 @@
"""
唇形同步服务
支持本地 MuseTalk 推理 (Python API) 或远程 MuseTalk API
通过 subprocess 调用 MuseTalk conda 环境进行推理
配置为使用 GPU1 (CUDA:1)
"""
import os
import sys
import shutil
import subprocess
import tempfile
import httpx
from pathlib import Path
from loguru import logger
from typing import Optional, Any
from typing import Optional
from app.core.config import settings
# 设置 MuseTalk 使用 GPU1 (在导入 torch 之前设置)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(settings.MUSETALK_GPU_ID))
class LipSyncService:
    """唇形同步服务 - MuseTalk 集成 (Subprocess 方式)

    Lip-sync service backed by MuseTalk. Local inference is delegated to a
    dedicated ``musetalk`` conda environment via ``python -m scripts.inference``
    (see ``_local_generate``); alternatively a remote MuseTalk HTTP API can be
    used (``_remote_generate``). On any failure the service degrades gracefully
    by copying the input video to the output path, so callers always get a
    playable file.

    NOTE(review): this class was reconstructed from a corrupted diff view; a few
    lines hidden by hunk markers (the ``generate`` / ``_remote_generate``
    signatures and the remote upload payload) are best-effort guesses — confirm
    against repository history.
    """

    def __init__(self):
        # Mode selection and MuseTalk installation details, all from settings.
        self.use_local = settings.MUSETALK_LOCAL
        self.api_url = settings.MUSETALK_API_URL
        self.version = settings.MUSETALK_VERSION  # "v15" enables v1.5-specific CLI options
        self.musetalk_dir = settings.MUSETALK_DIR
        self.gpu_id = settings.MUSETALK_GPU_ID

        # Python interpreter of the MuseTalk conda environment.
        # Adjust to match the server's actual conda install location.
        self.conda_python = (
            Path.home() / "ProgramFiles" / "miniconda3" / "envs" / "musetalk" / "bin" / "python"
        )

        # Lazily-computed runtime probes (None = not probed yet).
        self._gpu_available: Optional[bool] = None
        self._weights_available: Optional[bool] = None

    def _check_gpu(self) -> bool:
        """Return True if a CUDA GPU is visible to *this* process (cached)."""
        if self._gpu_available is not None:
            return self._gpu_available
        try:
            import torch
            self._gpu_available = torch.cuda.is_available()
            if self._gpu_available:
                device_name = torch.cuda.get_device_name(0)
                logger.info(f"✅ GPU 可用: {device_name}")
            else:
                logger.warning("⚠️ GPU 不可用,将使用 Fallback 模式")
        except ImportError:
            # PyTorch may not be installed in the web-service environment at
            # all — inference happens in the conda env, not here.
            self._gpu_available = False
            logger.warning("⚠️ PyTorch 未安装,将使用 Fallback 模式")
        return self._gpu_available

    def _check_conda_env(self) -> bool:
        """Return True if the MuseTalk conda interpreter exists on disk."""
        if not self.conda_python.exists():
            logger.warning(f"⚠️ Conda Python 不存在: {self.conda_python}")
            return False
        return True

    def _check_weights(self) -> bool:
        """Return True if all required MuseTalk weight directories exist (cached)."""
        if self._weights_available is not None:
            return self._weights_available
        # Presence of these three directories is used as a proxy for a complete
        # model download (UNet, Whisper, VAE).
        required_dirs = [
            self.musetalk_dir / "models" / "musetalkV15",
            self.musetalk_dir / "models" / "whisper",
            self.musetalk_dir / "models" / "sd-vae-ft-mse",
        ]
        self._weights_available = all(d.exists() for d in required_dirs)
        return self._weights_available

    async def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int = 25,  # NOTE(review): tail of this signature was lost in the diff — confirm default
    ) -> str:
        """Generate a lip-synced video and return the output path.

        Dispatches to local subprocess inference or the remote API depending on
        ``settings.MUSETALK_LOCAL``. Always returns a usable path: on failure
        the input video is copied to ``output_path`` unchanged.
        """
        logger.info(f"🎬 唇形同步任务: {Path(video_path).name} + {Path(audio_path).name}")
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        if self.use_local:
            return await self._local_generate(video_path, audio_path, output_path, fps)
        return await self._remote_generate(video_path, audio_path, output_path, fps)

    async def _local_generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int,
    ) -> str:
        """使用 subprocess 调用 MuseTalk conda 环境

        Runs ``scripts.inference`` inside the conda env with a temporary
        one-task YAML config, then copies the produced mp4 to ``output_path``.
        ``fps`` is not passed to the CLI here — MuseTalk reads the frame rate
        from the input video itself.
        """
        # Preconditions: conda env and model weights must exist, else fall back.
        if not self._check_conda_env():
            logger.warning("⚠️ Conda 环境不可用,使用 Fallback")
            shutil.copy(video_path, output_path)
            return output_path
        if not self._check_weights():
            logger.warning("⚠️ 模型权重不存在,使用 Fallback")
            shutil.copy(video_path, output_path)
            return output_path

        logger.info("🔄 调用 MuseTalk 推理 (subprocess)...")
        with tempfile.TemporaryDirectory() as tmp:
            tmpdir = Path(tmp)
            # Single-task inference config consumed by scripts.inference.
            # NOTE(review): YAML indentation reconstructed — source had lost it.
            config_path = tmpdir / "inference_config.yaml"
            result_dir = tmpdir / "results"
            result_dir.mkdir()
            config_content = f"""
task_0:
  video_path: "{video_path}"
  audio_path: "{audio_path}"
  result_name: "output.mp4"
"""
            config_path.write_text(config_content)

            cmd = [
                str(self.conda_python),
                "-m", "scripts.inference",
                "--inference_config", str(config_path),
                "--result_dir", str(result_dir),
                "--version", self.version,
                # CUDA_VISIBLE_DEVICES 设置后,可见设备为 0
                "--gpu_id", "0",
                "--batch_size", str(settings.MUSETALK_BATCH_SIZE),
            ]
            if settings.MUSETALK_USE_FLOAT16:
                cmd.append("--use_float16")

            # Pin the child process to the configured physical GPU.
            env = os.environ.copy()
            env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)

            logger.info(f"🖥️ 执行命令: {' '.join(cmd[:6])}...")
            try:
                result = subprocess.run(
                    cmd,
                    cwd=str(self.musetalk_dir),  # scripts.inference resolves paths relative to the repo
                    env=env,
                    capture_output=True,
                    text=True,
                    timeout=600,  # 10分钟超时
                )
                if result.returncode != 0:
                    logger.error(f"MuseTalk 推理失败:\n{result.stderr}")
                    # Fallback: hand back the original video untouched.
                    shutil.copy(video_path, output_path)
                    return output_path

                logger.info(f"MuseTalk 输出:\n{result.stdout[-500:]}")
                # The CLI writes the result somewhere under result_dir; take
                # the first mp4 found.
                output_files = list(result_dir.rglob("*.mp4"))
                if output_files:
                    shutil.copy(output_files[0], output_path)
                    logger.info(f"✅ 唇形同步完成: {output_path}")
                    return output_path

                logger.warning("⚠️ 未找到输出文件,使用 Fallback")
                shutil.copy(video_path, output_path)
                return output_path
            except subprocess.TimeoutExpired:
                logger.error("⏰ MuseTalk 推理超时")
                shutil.copy(video_path, output_path)
                return output_path
            except Exception as e:
                logger.error(f"❌ 推理异常: {e}")
                shutil.copy(video_path, output_path)
                return output_path

    async def _remote_generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int,
    ) -> str:
        """Call a remote MuseTalk HTTP API and save the returned video.

        NOTE(review): the signature, endpoint path and audio MIME type were
        hidden by the corrupted diff — verify them against the API server.
        """
        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                # Upload both media files as multipart form data.
                with open(video_path, "rb") as vf, open(audio_path, "rb") as af:
                    files = {
                        "video": (Path(video_path).name, vf, "video/mp4"),
                        "audio": (Path(audio_path).name, af, "audio/wav"),
                    }
                    response = await client.post(
                        f"{self.api_url}/lipsync",
                        files=files,
                        data={"fps": fps},
                    )
                if response.status_code == 200:
                    # Response body is the rendered video itself.
                    with open(output_path, "wb") as f:
                        f.write(response.content)
                    logger.info(f"✅ 远程推理完成: {output_path}")
                    return output_path
                raise RuntimeError(f"API 错误: {response.status_code}")
        except Exception as e:
            logger.error(f"远程 API 调用失败: {e}")
            # Fallback
            shutil.copy(video_path, output_path)
            return output_path

    async def check_health(self) -> dict:
        """健康检查

        Probes the conda env, the model weights, and (via the conda python,
        not this process) CUDA availability. Returns a readiness report dict.
        """
        conda_ok = self._check_conda_env()
        weights_ok = self._check_weights()

        # 检查 GPU
        gpu_ok = False
        gpu_name = "Unknown"
        if conda_ok:
            try:
                result = subprocess.run(
                    [str(self.conda_python), "-c",
                     "import torch; print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A')"],
                    capture_output=True,
                    text=True,
                    env={**os.environ, "CUDA_VISIBLE_DEVICES": str(self.gpu_id)},
                    timeout=10,
                )
                gpu_name = result.stdout.strip()
                gpu_ok = gpu_name != "N/A" and result.returncode == 0
            except Exception:
                # Best-effort probe: report GPU as unavailable on any failure.
                # (Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit propagate.)
                pass
        return {
            "conda_env": conda_ok,
            "weights": weights_ok,
            "gpu": gpu_ok,
            "gpu_name": gpu_name,
            "gpu_id": self.gpu_id,
            "ready": conda_ok and weights_ok and gpu_ok,
        }