修改成Subprocess调用
This commit is contained in:
@@ -1,75 +1,46 @@
|
||||
"""
|
||||
唇形同步服务
|
||||
支持本地 MuseTalk 推理 (Python API) 或远程 MuseTalk API
|
||||
通过 subprocess 调用 MuseTalk conda 环境进行推理
|
||||
配置为使用 GPU1 (CUDA:1)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置 MuseTalk 使用 GPU1 (在导入 torch 之前设置)
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(settings.MUSETALK_GPU_ID))
|
||||
|
||||
|
||||
class LipSyncService:
|
||||
"""唇形同步服务 - MuseTalk 集成"""
|
||||
"""唇形同步服务 - MuseTalk 集成 (Subprocess 方式)"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_local = settings.MUSETALK_LOCAL
|
||||
self.api_url = settings.MUSETALK_API_URL
|
||||
self.version = settings.MUSETALK_VERSION
|
||||
self.musetalk_dir = settings.MUSETALK_DIR
|
||||
self.gpu_id = settings.MUSETALK_GPU_ID
|
||||
|
||||
# 模型相关 (懒加载)
|
||||
self._model_loaded = False
|
||||
self._vae = None
|
||||
self._unet = None
|
||||
self._pe = None
|
||||
self._whisper = None
|
||||
self._audio_processor = None
|
||||
self._face_parser = None
|
||||
self._device = None
|
||||
# Conda 环境 Python 路径
|
||||
# 根据服务器实际情况调整
|
||||
self.conda_python = Path.home() / "ProgramFiles" / "miniconda3" / "envs" / "musetalk" / "bin" / "python"
|
||||
|
||||
# 运行时检测
|
||||
self._gpu_available: Optional[bool] = None
|
||||
self._weights_available: Optional[bool] = None
|
||||
|
||||
def _check_gpu(self) -> bool:
|
||||
"""检查 GPU 是否可用"""
|
||||
if self._gpu_available is not None:
|
||||
return self._gpu_available
|
||||
|
||||
try:
|
||||
import torch
|
||||
self._gpu_available = torch.cuda.is_available()
|
||||
if self._gpu_available:
|
||||
device_name = torch.cuda.get_device_name(0)
|
||||
logger.info(f"✅ GPU 可用: {device_name}")
|
||||
else:
|
||||
logger.warning("⚠️ GPU 不可用,将使用 Fallback 模式")
|
||||
except ImportError:
|
||||
self._gpu_available = False
|
||||
logger.warning("⚠️ PyTorch 未安装,将使用 Fallback 模式")
|
||||
|
||||
return self._gpu_available
|
||||
|
||||
def _check_weights(self) -> bool:
|
||||
"""检查模型权重是否存在"""
|
||||
if self._weights_available is not None:
|
||||
return self._weights_available
|
||||
|
||||
# 检查关键权重文件
|
||||
required_dirs = [
|
||||
self.musetalk_dir / "models" / "musetalkV15",
|
||||
self.musetalk_dir / "models" / "whisper",
|
||||
self.musetalk_dir / "models" / "sd-vae-ft-mse",
|
||||
]
|
||||
|
||||
self._weights_available = all(d.exists() for d in required_dirs)
|
||||
@@ -82,85 +53,13 @@ class LipSyncService:
|
||||
|
||||
return self._weights_available
|
||||
|
||||
def _load_models(self):
|
||||
"""懒加载 MuseTalk 模型 (Python API 方式)"""
|
||||
if self._model_loaded:
|
||||
return True
|
||||
|
||||
if not self._check_gpu() or not self._check_weights():
|
||||
def _check_conda_env(self) -> bool:
|
||||
"""检查 conda 环境是否可用"""
|
||||
if not self.conda_python.exists():
|
||||
logger.warning(f"⚠️ Conda Python 不存在: {self.conda_python}")
|
||||
return False
|
||||
|
||||
logger.info("🔄 加载 MuseTalk 模型到 GPU...")
|
||||
|
||||
try:
|
||||
# 添加 MuseTalk 到 Python 路径
|
||||
if str(self.musetalk_dir) not in sys.path:
|
||||
sys.path.insert(0, str(self.musetalk_dir))
|
||||
logger.debug(f"Added to sys.path: {self.musetalk_dir}")
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import WhisperModel
|
||||
|
||||
# 导入 MuseTalk 模块
|
||||
from musetalk.utils.utils import load_all_model
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
|
||||
# 设置设备 (CUDA_VISIBLE_DEVICES=1 后,可见设备变为 cuda:0)
|
||||
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# 加载模型
|
||||
unet_model_path = str(self.musetalk_dir / "models" / "musetalkV15" / "unet.pth")
|
||||
unet_config = str(self.musetalk_dir / "models" / "musetalk" / "config.json")
|
||||
whisper_dir = str(self.musetalk_dir / "models" / "whisper")
|
||||
|
||||
self._vae, self._unet, self._pe = load_all_model(
|
||||
unet_model_path=unet_model_path,
|
||||
vae_type="sd-vae",
|
||||
unet_config=unet_config,
|
||||
device=self._device
|
||||
)
|
||||
|
||||
# 使用半精度加速
|
||||
if settings.MUSETALK_USE_FLOAT16:
|
||||
self._pe = self._pe.half()
|
||||
self._vae.vae = self._vae.vae.half()
|
||||
self._unet.model = self._unet.model.half()
|
||||
|
||||
# 移动到 GPU
|
||||
self._pe = self._pe.to(self._device)
|
||||
self._vae.vae = self._vae.vae.to(self._device)
|
||||
self._unet.model = self._unet.model.to(self._device)
|
||||
|
||||
# 加载 Whisper
|
||||
weight_dtype = self._unet.model.dtype
|
||||
self._whisper = WhisperModel.from_pretrained(whisper_dir)
|
||||
self._whisper = self._whisper.to(device=self._device, dtype=weight_dtype).eval()
|
||||
self._whisper.requires_grad_(False)
|
||||
|
||||
# 音频处理器
|
||||
self._audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
|
||||
|
||||
# 人脸解析器 (v15 版本支持更多参数)
|
||||
if self.version == "v15":
|
||||
self._face_parser = FaceParsing(
|
||||
left_cheek_width=90,
|
||||
right_cheek_width=90
|
||||
)
|
||||
else:
|
||||
self._face_parser = FaceParsing()
|
||||
|
||||
self._model_loaded = True
|
||||
logger.info("✅ MuseTalk 模型加载完成")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MuseTalk 模型加载失败: {e}")
|
||||
import traceback
|
||||
logger.debug(traceback.format_exc())
|
||||
return False
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
video_path: str,
|
||||
@@ -172,224 +71,105 @@ class LipSyncService:
|
||||
logger.info(f"🎬 唇形同步任务: {Path(video_path).name} + {Path(audio_path).name}")
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 决定使用哪种模式
|
||||
if self.use_local:
|
||||
if self._load_models():
|
||||
return await self._local_generate_api(video_path, audio_path, output_path, fps)
|
||||
else:
|
||||
logger.warning("⚠️ 本地推理失败,尝试 subprocess 方式")
|
||||
return await self._local_generate_subprocess(video_path, audio_path, output_path, fps)
|
||||
return await self._local_generate(video_path, audio_path, output_path, fps)
|
||||
else:
|
||||
return await self._remote_generate(video_path, audio_path, output_path, fps)
|
||||
|
||||
async def _local_generate_api(
|
||||
async def _local_generate(
|
||||
self,
|
||||
video_path: str,
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
fps: int
|
||||
) -> str:
|
||||
"""使用 Python API 进行本地推理"""
|
||||
import torch
|
||||
import cv2
|
||||
import copy
|
||||
import glob
|
||||
import pickle
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
"""使用 subprocess 调用 MuseTalk conda 环境"""
|
||||
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
from musetalk.utils.blending import get_image
|
||||
|
||||
logger.info("🔄 开始 MuseTalk 推理 (Python API)...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmpdir = Path(tmpdir)
|
||||
result_img_dir = tmpdir / "frames"
|
||||
result_img_dir.mkdir()
|
||||
|
||||
# 1. 提取视频帧
|
||||
logger.info("📹 提取视频帧...")
|
||||
if get_file_type(video_path) == "video":
|
||||
frames_dir = tmpdir / "input_frames"
|
||||
frames_dir.mkdir()
|
||||
cmd = f'ffmpeg -v fatal -i "{video_path}" -start_number 0 "{frames_dir}/%08d.png"'
|
||||
subprocess.run(cmd, shell=True, check=True)
|
||||
input_img_list = sorted(glob.glob(str(frames_dir / "*.png")))
|
||||
video_fps = get_video_fps(video_path)
|
||||
else:
|
||||
input_img_list = [video_path]
|
||||
video_fps = fps
|
||||
|
||||
# 2. 提取音频特征
|
||||
logger.info("🎵 提取音频特征...")
|
||||
whisper_input_features, librosa_length = self._audio_processor.get_audio_feature(audio_path)
|
||||
weight_dtype = self._unet.model.dtype
|
||||
whisper_chunks = self._audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
self._device,
|
||||
weight_dtype,
|
||||
self._whisper,
|
||||
librosa_length,
|
||||
fps=video_fps,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
)
|
||||
|
||||
# 3. 预处理图像
|
||||
logger.info("🧑 检测人脸关键点...")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)
|
||||
|
||||
# 4. 编码潜在表示
|
||||
logger.info("🔢 编码图像潜在表示...")
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if self.version == "v15":
|
||||
y2 = min(y2 + 10, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
|
||||
latents = self._vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# 循环帧列表
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
# 5. 批量推理
|
||||
logger.info("🤖 执行 MuseTalk 推理...")
|
||||
timesteps = torch.tensor([0], device=self._device)
|
||||
batch_size = settings.MUSETALK_BATCH_SIZE
|
||||
video_num = len(whisper_chunks)
|
||||
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=self._device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
total = int(np.ceil(float(video_num) / batch_size))
|
||||
|
||||
with torch.no_grad():
|
||||
for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total, desc="推理")):
|
||||
audio_feature_batch = self._pe(whisper_batch)
|
||||
latent_batch = latent_batch.to(dtype=self._unet.model.dtype)
|
||||
pred_latents = self._unet.model(
|
||||
latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
|
||||
).sample
|
||||
recon = self._vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
# 6. 合成结果帧
|
||||
logger.info("🖼️ 合成结果帧...")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list, desc="合成")):
|
||||
bbox = coord_list_cycle[i % len(coord_list_cycle)]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
|
||||
x1, y1, x2, y2 = bbox
|
||||
|
||||
if self.version == "v15":
|
||||
y2 = min(y2 + 10, ori_frame.shape[0])
|
||||
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
if self.version == "v15":
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, [x1, y1, x2, y2],
|
||||
mode="jaw", fp=self._face_parser
|
||||
)
|
||||
else:
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self._face_parser)
|
||||
|
||||
cv2.imwrite(str(result_img_dir / f"{i:08d}.png"), combine_frame)
|
||||
|
||||
# 7. 合成视频
|
||||
logger.info("🎬 合成最终视频...")
|
||||
temp_video = tmpdir / "temp_video.mp4"
|
||||
cmd_video = f'ffmpeg -y -v warning -r {video_fps} -f image2 -i "{result_img_dir}/%08d.png" -vcodec libx264 -vf format=yuv420p -crf 18 "{temp_video}"'
|
||||
subprocess.run(cmd_video, shell=True, check=True)
|
||||
|
||||
# 8. 添加音频
|
||||
cmd_audio = f'ffmpeg -y -v warning -i "{audio_path}" -i "{temp_video}" -c:v copy -c:a aac -shortest "{output_path}"'
|
||||
subprocess.run(cmd_audio, shell=True, check=True)
|
||||
|
||||
logger.info(f"✅ 唇形同步完成: {output_path}")
|
||||
return output_path
|
||||
|
||||
async def _local_generate_subprocess(
|
||||
self,
|
||||
video_path: str,
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
fps: int
|
||||
) -> str:
|
||||
"""使用 subprocess 调用 MuseTalk CLI"""
|
||||
logger.info("🔄 使用 subprocess 调用 MuseTalk...")
|
||||
|
||||
# 如果权重不存在,直接 fallback
|
||||
if not self._check_weights():
|
||||
logger.warning("⚠️ 权重不存在,使用 Fallback 模式")
|
||||
# 检查前置条件
|
||||
if not self._check_conda_env():
|
||||
logger.warning("⚠️ Conda 环境不可用,使用 Fallback")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
if not self._check_weights():
|
||||
logger.warning("⚠️ 模型权重不存在,使用 Fallback")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
logger.info("🔄 调用 MuseTalk 推理 (subprocess)...")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# 创建临时配置文件
|
||||
config_path = Path(tmpdir) / "inference_config.yaml"
|
||||
tmpdir = Path(tmpdir)
|
||||
|
||||
# 创建推理配置文件
|
||||
config_path = tmpdir / "inference_config.yaml"
|
||||
result_dir = tmpdir / "results"
|
||||
result_dir.mkdir()
|
||||
|
||||
# 配置文件内容
|
||||
config_content = f"""
|
||||
task1:
|
||||
task_0:
|
||||
video_path: "{video_path}"
|
||||
audio_path: "{audio_path}"
|
||||
result_name: "output.mp4"
|
||||
"""
|
||||
config_path.write_text(config_content)
|
||||
|
||||
result_dir = Path(tmpdir) / "results"
|
||||
result_dir.mkdir()
|
||||
|
||||
# 构建命令
|
||||
cmd = [
|
||||
sys.executable, "-m", "scripts.inference",
|
||||
"--version", self.version,
|
||||
str(self.conda_python),
|
||||
"-m", "scripts.inference",
|
||||
"--inference_config", str(config_path),
|
||||
"--result_dir", str(result_dir),
|
||||
"--gpu_id", "0", # 因为 CUDA_VISIBLE_DEVICES 已设置
|
||||
"--version", self.version,
|
||||
"--gpu_id", "0", # CUDA_VISIBLE_DEVICES 设置后,可见设备为 0
|
||||
"--batch_size", str(settings.MUSETALK_BATCH_SIZE),
|
||||
]
|
||||
|
||||
if settings.MUSETALK_USE_FLOAT16:
|
||||
cmd.append("--use_float16")
|
||||
|
||||
# 设置环境变量
|
||||
env = os.environ.copy()
|
||||
env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
|
||||
|
||||
logger.info(f"🖥️ 执行命令: {' '.join(cmd[:6])}...")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
cwd=str(self.musetalk_dir),
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env={**os.environ, "CUDA_VISIBLE_DEVICES": str(settings.MUSETALK_GPU_ID)}
|
||||
timeout=600 # 10分钟超时
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.error(f"MuseTalk CLI 失败: {result.stderr}")
|
||||
logger.error(f"MuseTalk 推理失败:\n{result.stderr}")
|
||||
# Fallback
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
logger.info(f"MuseTalk 输出:\n{result.stdout[-500:]}")
|
||||
|
||||
# 查找输出文件
|
||||
output_files = list(result_dir.rglob("*.mp4"))
|
||||
if output_files:
|
||||
shutil.copy(output_files[0], output_path)
|
||||
logger.info(f"✅ 唇形同步完成: {output_path}")
|
||||
return output_path
|
||||
else:
|
||||
logger.warning("⚠️ 未找到输出文件,使用 Fallback")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("⏰ MuseTalk 推理超时")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 推理异常: {e}")
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
async def _remote_generate(
|
||||
@@ -404,7 +184,6 @@ task1:
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
# 上传文件
|
||||
with open(video_path, "rb") as vf, open(audio_path, "rb") as af:
|
||||
files = {
|
||||
"video": (Path(video_path).name, vf, "video/mp4"),
|
||||
@@ -419,30 +198,46 @@ task1:
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# 保存响应视频
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
logger.info(f"✅ 远程推理完成: {output_path}")
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError(f"API 错误: {response.status_code} - {response.text}")
|
||||
raise RuntimeError(f"API 错误: {response.status_code}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"远程 API 调用失败: {e}")
|
||||
# Fallback
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
async def check_health(self) -> bool:
|
||||
async def check_health(self) -> dict:
|
||||
"""健康检查"""
|
||||
if self.use_local:
|
||||
gpu_ok = self._check_gpu()
|
||||
conda_ok = self._check_conda_env()
|
||||
weights_ok = self._check_weights()
|
||||
return gpu_ok and weights_ok
|
||||
else:
|
||||
|
||||
# 检查 GPU
|
||||
gpu_ok = False
|
||||
gpu_name = "Unknown"
|
||||
if conda_ok:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(f"{self.api_url}/health")
|
||||
return response.status_code == 200
|
||||
result = subprocess.run(
|
||||
[str(self.conda_python), "-c",
|
||||
"import torch; print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A')"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env={**os.environ, "CUDA_VISIBLE_DEVICES": str(self.gpu_id)},
|
||||
timeout=10
|
||||
)
|
||||
gpu_name = result.stdout.strip()
|
||||
gpu_ok = gpu_name != "N/A" and result.returncode == 0
|
||||
except:
|
||||
return False
|
||||
pass
|
||||
|
||||
return {
|
||||
"conda_env": conda_ok,
|
||||
"weights": weights_ok,
|
||||
"gpu": gpu_ok,
|
||||
"gpu_name": gpu_name,
|
||||
"gpu_id": self.gpu_id,
|
||||
"ready": conda_ok and weights_ok and gpu_ok
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user