Files
ViGent2/models/MuseTalk/scripts/server.py
Kevin Wong 0e3502c6f0 更新
2026-02-27 16:11:34 +08:00

573 lines
20 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
MuseTalk v1.5 常驻推理服务 (优化版 v2)
- 端口: 8011
- GPU: 从 backend/.env 读取 MUSETALK_GPU_ID (默认 0)
- 架构: FastAPI + lifespan (与 LatentSync server.py 同模式)
优化项 (vs v1):
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
4. cv2.VideoWriter 直写视频 (跳过逐帧 PNG 写盘)
5. batch_size 8→32
6. 每阶段计时
"""
import os
import sys
import math
import copy
import time
import glob
import shutil
import tempfile
import subprocess
from pathlib import Path
# --- 自动加载 GPU 配置 (必须在 torch 导入前) ---
def load_gpu_config(env_path=None):
    """尝试从后端 .env 文件读取 MUSETALK_GPU_ID

    Must be called BEFORE importing torch so CUDA_VISIBLE_DEVICES takes effect.

    Args:
        env_path: optional explicit path to the .env file; when None the
            default location ``../../../backend/.env`` (relative to this
            script) is used.
    """
    try:
        if env_path is None:
            current_dir = Path(__file__).resolve().parent
            env_path = current_dir.parent.parent.parent / "backend" / ".env"
        else:
            env_path = Path(env_path)
        target_gpu = "0"  # default: GPU 0
        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("MUSETALK_GPU_ID="):
                        # maxsplit=1 keeps everything after the FIRST '=' so a
                        # value containing '=' is not truncated; then strip an
                        # inline '#' comment.
                        val = line.split("=", 1)[1].strip().split("#")[0].strip()
                        if val:
                            target_gpu = val
                            print(f"⚙️ 发现配置 MUSETALK_GPU_ID={target_gpu}")
                        break
        # Respect an externally-set CUDA_VISIBLE_DEVICES; only set it ourselves
        # when the caller did not.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f" 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")
    except Exception as e:
        # Best-effort: never block startup on a config read failure.
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")
load_gpu_config()

# --- 性能优化: 限制 CPU 线程数 ---
# Cap the CPU thread pools before torch/numpy are imported so the limits
# are honored by their thread-pool initialization.
for _thread_env in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "TORCH_NUM_THREADS"):
    os.environ[_thread_env] = "8"
print("⚙️ 已限制 PyTorch CPU 线程数为 8防止系统卡顿")
import cv2
import torch
import pickle
import numpy as np
from tqdm import tqdm
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
from transformers import WhisperModel
# 添加项目根目录到 sys.path (MuseTalk 根目录)
musetalk_root = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(musetalk_root))
from musetalk.utils.blending import get_image, get_image_blending, get_image_prepare_material
from musetalk.utils.face_parsing import FaceParsing
from musetalk.utils.audio_processor import AudioProcessor
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
# --- 从 .env 读取额外配置 ---
def load_env_config(env_path=None):
    """读取 MuseTalk 相关环境变量

    Args:
        env_path: optional explicit path to the .env file; when None the
            default location ``backend/.env`` next to the project root is used.

    Returns:
        dict with keys ``batch_size`` (int), ``version`` (str) and
        ``use_float16`` (bool); defaults are kept for missing/empty entries
        or on any read/parse failure.
    """
    config = {
        "batch_size": 32,
        "version": "v15",
        "use_float16": True,
    }
    try:
        if env_path is None:
            env_path = musetalk_root.parent.parent / "backend" / ".env"
        else:
            env_path = Path(env_path)
        if env_path.exists():
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    # maxsplit=1 so values containing '=' are not truncated;
                    # split("#") strips inline comments.
                    if line.startswith("MUSETALK_BATCH_SIZE="):
                        val = line.split("=", 1)[1].strip().split("#")[0].strip()
                        if val:
                            config["batch_size"] = int(val)
                    elif line.startswith("MUSETALK_VERSION="):
                        val = line.split("=", 1)[1].strip().split("#")[0].strip()
                        if val:
                            config["version"] = val
                    elif line.startswith("MUSETALK_USE_FLOAT16="):
                        val = line.split("=", 1)[1].strip().split("#")[0].strip().lower()
                        # An empty value keeps the default (True) instead of
                        # silently forcing False.
                        if val:
                            config["use_float16"] = val in ("true", "1", "yes")
    except Exception as e:
        print(f"⚠️ 读取额外配置失败: {e}")
    return config
env_config = load_env_config()  # read once at import; shared module-wide
# Global model cache; populated by lifespan() at startup, cleared at shutdown.
models = {}
# ===================== tuning parameters =====================
DETECT_EVERY = 5  # face-detection subsampling: run the detector every N frames
BLEND_CACHE_EVERY = 5  # BiSeNet mask cache: refresh the parsing mask every N frames
# ====================================================
def run_ffmpeg(cmd):
    """执行 FFmpeg 命令

    Args:
        cmd: full shell command string.

    Returns:
        True when the command exits with status 0, otherwise prints
        diagnostics and returns False.

    NOTE(review): runs with ``shell=True`` on an interpolated command string;
    callers must ensure any embedded paths are shell-safe (quote or avoid
    spaces/metacharacters) — verify call sites.
    """
    print(f"Executing: {cmd}")
    try:
        subprocess.run(cmd, shell=True, check=True, capture_output=True, text=True)
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error executing ffmpeg: {cmd}")
        print(f"Return code: {e.returncode}")
        # stderr can be very long; show only the first 500 chars.
        print(f"Stderr: {e.stderr[:500]}")
        return False
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load every model exactly once at startup.

    Loads the VAE / UNet / positional encoder via load_all_model, the Whisper
    audio encoder and the BiSeNet face parser, then publishes everything in
    the module-level ``models`` dict.  On shutdown the dict is cleared and
    the CUDA cache released.
    """
    print("⏳ 正在加载 MuseTalk v1.5 模型...")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    version = env_config["version"]
    use_float16 = env_config["use_float16"]
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name}")
    else:
        print("⚠️ 警告: 未检测到 GPU将使用 CPU 进行推理 (速度极慢)")
    # Select weight files according to the model version (v15 vs legacy).
    models_dir = musetalk_root / "models"
    if version == "v15":
        unet_model_path = str(models_dir / "musetalkV15" / "unet.pth")
        unet_config = str(models_dir / "musetalk" / "config.json")
    else:
        unet_model_path = str(models_dir / "musetalk" / "pytorch_model.bin")
        unet_config = str(models_dir / "musetalk" / "musetalk.json")
    # load_all_model loads the VAE with relative paths, so temporarily chdir
    # into the MuseTalk root for the whole loading phase.
    original_cwd = os.getcwd()
    os.chdir(str(musetalk_root))
    vae, unet, pe = load_all_model(
        unet_model_path=unet_model_path,
        vae_type="sd-vae",
        unet_config=unet_config,
        device=device,
    )
    # Optional fp16 cast (GPU only) before moving modules to the device.
    if use_float16 and torch.cuda.is_available():
        print("⚡ 使用 float16 半精度加速")
        pe = pe.half()
        vae.vae = vae.vae.half()
        unet.model = unet.model.half()
    pe = pe.to(device)
    vae.vae = vae.vae.to(device)
    unet.model = unet.model.to(device)
    # Whisper audio encoder, cast to the UNet's dtype; inference only.
    whisper_dir = str(models_dir / "whisper")
    audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)
    weight_dtype = unet.model.dtype
    whisper = WhisperModel.from_pretrained(whisper_dir)
    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
    whisper.requires_grad_(False)
    # BiSeNet-based FaceParsing; v15 widens the cheek margins.
    if version == "v15":
        fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
    else:
        fp = FaceParsing()
    # Restore the working directory before serving requests.
    os.chdir(original_cwd)
    models["vae"] = vae
    models["unet"] = unet
    models["pe"] = pe
    models["whisper"] = whisper
    models["audio_processor"] = audio_processor
    models["fp"] = fp
    models["device"] = device
    models["weight_dtype"] = weight_dtype
    models["version"] = version
    # Single diffusion timestep 0 (MuseTalk runs the UNet once per frame).
    models["timesteps"] = torch.tensor([0], device=device)
    print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
    print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
          f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}")
    yield
    # Shutdown: drop all model references and free GPU memory.
    models.clear()
    torch.cuda.empty_cache()
app = FastAPI(lifespan=lifespan)


class LipSyncRequest(BaseModel):
    """Request payload for POST /lipsync."""
    video_path: str  # path of the source (template) video or single image
    audio_path: str  # path of the driving audio file
    video_out_path: str  # path where the result video is written
    batch_size: int = 32  # UNet batch size; NOTE(review): duplicates env_config default — confirm intended
@app.get("/health")
def health_check():
    """Liveness probe: reports whether the UNet has been loaded yet."""
    model_ready = "unet" in models
    return {"status": "ok", "model_loaded": model_ready}
@app.post("/lipsync")
async def generate_lipsync(req: LipSyncRequest):
    """Validate the request, run lip-sync inference, and report timing.

    Returns the dict produced by _run_inference; maps any inference failure
    to HTTP 500 (after printing the traceback), missing inputs to 404, and
    an unloaded model to 503.
    """
    if "unet" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not os.path.exists(req.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
    if not os.path.exists(req.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")
    print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
    started_at = time.time()
    try:
        result = _run_inference(req)
        wall = time.time() - started_at
        print(f"✅ 推理完成,耗时 {wall:.1f}s ({wall/60:.1f}min)")
        return result
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
# =====================================================================
# 降频人脸检测: 每 N 帧检测一次, 中间帧线性插值 bbox
# =====================================================================
def _detect_faces_subsampled(frames, detect_every=5):
    """
    Subsampled face detection:
    - run DWPose + FaceAlignment only every ``detect_every`` frames
    - linearly interpolate bbox coordinates for the frames in between
    - for talking-head footage (face nearly static) the interpolation error
      is negligible

    Returns a list with one bbox tuple (x1, y1, x2, y2) per input frame;
    frames where detection failed get ``coord_placeholder``.
    """
    from mmpose.apis import inference_topdown
    from mmpose.structures import merge_data_samples
    import musetalk.utils.preprocessing as _prep
    n = len(frames)
    if n == 0:
        return []
    # Frames that actually get a detector pass; always include the last frame
    # so every in-between frame has a bracketing sampled pair to interpolate.
    sampled_indices = list(range(0, n, detect_every))
    if sampled_indices[-1] != n - 1:
        sampled_indices.append(n - 1)
    print(f" 检测 {len(sampled_indices)}/{n} 帧 (每{detect_every}帧)")
    # Run detection on the sampled frames only.
    detected = {}
    for idx in tqdm(sampled_indices, desc="人脸检测"):
        frame = frames[idx]
        try:
            results = inference_topdown(_prep.model, frame)
            results = merge_data_samples(results)
            keypoints = results.pred_instances.keypoints
            # Slice 23:91 — assumes DWPose whole-body keypoint layout where
            # these indices are the face landmarks; TODO confirm upstream.
            face_land_mark = keypoints[0][23:91].astype(np.int32)
            bbox_list = _prep.fa.get_detections_for_batch(np.array([frame]))
            if bbox_list[0] is None:
                detected[idx] = coord_placeholder
                continue
            # Build the crop box: top edge mirrored around landmark 29 by the
            # distance to the lowest landmark, clamped to the image top.
            half_face_coord = face_land_mark[29].copy()
            half_face_dist = np.max(face_land_mark[:, 1]) - half_face_coord[1]
            upper_bond = max(0, half_face_coord[1] - half_face_dist)
            f_landmark = (
                int(np.min(face_land_mark[:, 0])),
                int(upper_bond),
                int(np.max(face_land_mark[:, 0])),
                int(np.max(face_land_mark[:, 1])),
            )
            x1, y1, x2, y2 = f_landmark
            # Degenerate landmark box -> fall back to the detector bbox.
            if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
                detected[idx] = bbox_list[0] if bbox_list[0] is not None else coord_placeholder
            else:
                detected[idx] = f_landmark
        except Exception as e:
            print(f"⚠️ 帧 {idx} 检测失败: {e}")
            detected[idx] = coord_placeholder
    # Fill all frames: copy sampled results, interpolate the rest.
    coord_list = [None] * n
    for idx in sampled_indices:
        coord_list[idx] = detected[idx]
    for i in range(n):
        if coord_list[i] is not None:
            continue
        # Nearest sampled frames before/after i (always exist, since both
        # index 0 and n-1 are sampled).
        prev_idx = max(j for j in sampled_indices if j < i)
        next_idx = min(j for j in sampled_indices if j > i)
        prev_bbox = detected[prev_idx]
        next_bbox = detected[next_idx]
        if prev_bbox == coord_placeholder and next_bbox == coord_placeholder:
            coord_list[i] = coord_placeholder
        elif prev_bbox == coord_placeholder:
            coord_list[i] = next_bbox
        elif next_bbox == coord_placeholder:
            coord_list[i] = prev_bbox
        else:
            # Linear interpolation of each of the four bbox coordinates.
            alpha = (i - prev_idx) / (next_idx - prev_idx)
            coord_list[i] = tuple(
                int(a * (1 - alpha) + b * alpha)
                for a, b in zip(prev_bbox, next_bbox)
            )
    return coord_list
# =====================================================================
# 核心推理 (优化版)
# =====================================================================
@torch.no_grad()
def _run_inference(req: LipSyncRequest) -> dict:
    """
    Optimized inference pipeline:
    1. cv2.VideoCapture reads frames directly (skips ffmpeg→PNG→imread)
    2. subsampled face detection (every N frames, interpolated in between)
    3. BiSeNet mask cache (refreshed every N frames)
    4. cv2.VideoWriter writes directly (skips per-frame PNGs)
    5. per-phase timing

    Returns {"status": "success", "output_path": ...}; raises RuntimeError /
    ValueError on failure (mapped to HTTP 500 by the endpoint).
    """
    vae = models["vae"]
    unet = models["unet"]
    pe = models["pe"]
    whisper = models["whisper"]
    audio_processor = models["audio_processor"]
    fp = models["fp"]
    device = models["device"]
    weight_dtype = models["weight_dtype"]
    version = models["version"]
    timesteps = models["timesteps"]
    # Per-request batch size wins; falls back to the .env-configured value
    # when the request carries a falsy batch_size.
    batch_size = req.batch_size or env_config["batch_size"]
    video_path = req.video_path
    audio_path = req.audio_path
    output_vid_path = req.video_out_path
    os.makedirs(os.path.dirname(output_vid_path), exist_ok=True)
    t_total = time.time()
    timings = {}
    # ===== Phase 1: read video frames (cv2.VideoCapture, skip ffmpeg→PNG) =====
    t0 = time.time()
    if get_file_type(video_path) == "video":
        cap = cv2.VideoCapture(video_path)
        # `or 25.0` guards against CAP_PROP_FPS returning 0/None.
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
        cap.release()
    elif get_file_type(video_path) == "image":
        frames = [cv2.imread(video_path)]
        fps = 25.0
    else:
        raise ValueError(f"不支持的文件类型: {video_path}")
    timings["1_read"] = time.time() - t0
    print(f"📹 读取 {len(frames)} 帧, FPS={fps} [{timings['1_read']:.1f}s]")
    if not frames:
        raise RuntimeError("视频帧为空")
    # ===== Phase 2: Whisper audio features =====
    t0 = time.time()
    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
    whisper_chunks = audio_processor.get_whisper_chunk(
        whisper_input_features, device, weight_dtype, whisper, librosa_length,
        fps=fps,
        audio_padding_length_left=2,
        audio_padding_length_right=2,
    )
    timings["2_whisper"] = time.time() - t0
    print(f"🎵 Whisper 特征 [{timings['2_whisper']:.1f}s]")
    # ===== Phase 3: face detection (subsampled) =====
    t0 = time.time()
    coord_list = _detect_faces_subsampled(frames, detect_every=DETECT_EVERY)
    timings["3_face"] = time.time() - t0
    print(f"🔍 人脸检测 [{timings['3_face']:.1f}s]")
    # ===== Phase 4: VAE latent encoding =====
    t0 = time.time()
    input_latent_list = []
    extra_margin = 10  # v15: extend the crop below the chin by 10 px
    for bbox, frame in zip(coord_list, frames):
        # NOTE(review): placeholder frames are skipped here, so latent indices
        # shift relative to coord_list when any frame has no face — Phase 6
        # indexes both cycles with the same i; confirm inputs always contain a
        # detectable face (this mirrors upstream MuseTalk behavior).
        if bbox == coord_placeholder:
            continue
        x1, y1, x2, y2 = bbox
        if version == "v15":
            y2 = min(y2 + extra_margin, frame.shape[0])
        crop_frame = frame[y1:y2, x1:x2]
        crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
        latents = vae.get_latents_for_unet(crop_frame)
        input_latent_list.append(latents)
    timings["4_vae"] = time.time() - t0
    print(f"🧠 VAE 编码 [{timings['4_vae']:.1f}s]")
    # ===== Phase 5: batched UNet inference =====
    t0 = time.time()
    # Ping-pong cycles (forward + reversed) so audio longer than the video
    # loops smoothly; these are reference copies, not data copies.
    frame_list_cycle = frames + frames[::-1]
    coord_list_cycle = coord_list + coord_list[::-1]
    input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
    video_num = len(whisper_chunks)
    gen = datagen(
        whisper_chunks=whisper_chunks,
        vae_encode_latents=input_latent_list_cycle,
        batch_size=batch_size,
        delay_frame=0,
        device=device,
    )
    res_frame_list = []
    total_batches = int(np.ceil(float(video_num) / batch_size))
    print(f"🚀 推理: {video_num} 帧, batch={batch_size}, {total_batches}")
    for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total_batches)):
        audio_feature_batch = pe(whisper_batch)
        latent_batch = latent_batch.to(dtype=unet.model.dtype)
        pred_latents = unet.model(
            latent_batch, timesteps,
            encoder_hidden_states=audio_feature_batch
        ).sample
        recon = vae.decode_latents(pred_latents)
        for res_frame in recon:
            res_frame_list.append(res_frame)
    timings["5_unet"] = time.time() - t0
    print(f"✅ UNet 推理: {len(res_frame_list)} 帧 [{timings['5_unet']:.1f}s]")
    # ===== Phase 6: compositing (cached BiSeNet mask + cv2.VideoWriter) =====
    t0 = time.time()
    h, w = frames[0].shape[:2]
    temp_raw_path = output_vid_path + ".raw.mp4"
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(temp_raw_path, fourcc, fps, (w, h))
    if not writer.isOpened():
        raise RuntimeError(f"cv2.VideoWriter 打开失败: {temp_raw_path}")
    cached_mask = None
    cached_crop_box = None
    blend_mode = "jaw" if version == "v15" else "raw"
    for i in tqdm(range(len(res_frame_list)), desc="合成"):
        res_frame = res_frame_list[i]
        bbox = coord_list_cycle[i % len(coord_list_cycle)]
        ori_frame = frame_list_cycle[i % len(frame_list_cycle)].copy()
        x1, y1, x2, y2 = bbox
        if version == "v15":
            y2 = min(y2 + extra_margin, ori_frame.shape[0])
        adjusted_bbox = (x1, y1, x2, y2)
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
        except Exception:
            # Degenerate bbox (e.g. placeholder) — emit the original frame.
            writer.write(ori_frame)
            continue
        # Refresh the BiSeNet face-parsing mask every N frames; reuse the
        # cached mask in between.
        if i % BLEND_CACHE_EVERY == 0 or cached_mask is None:
            try:
                cached_mask, cached_crop_box = get_image_prepare_material(
                    ori_frame, adjusted_bbox, mode=blend_mode, fp=fp)
            except Exception:
                # prepare failed — fall back to the full (slow) path.
                combine_frame = get_image(
                    ori_frame, res_frame, list(adjusted_bbox),
                    mode=blend_mode, fp=fp)
                writer.write(combine_frame)
                continue
        try:
            combine_frame = get_image_blending(
                ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
        except Exception:
            # blending failed — fall back to the full (slow) path.
            combine_frame = get_image(
                ori_frame, res_frame, list(adjusted_bbox),
                mode=blend_mode, fp=fp)
        writer.write(combine_frame)
    writer.release()
    timings["6_blend"] = time.time() - t0
    print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
    # ===== Phase 7: FFmpeg re-encode to H.264 + mux audio =====
    t0 = time.time()
    # NOTE(review): paths are interpolated unquoted into a shell command —
    # paths containing spaces or shell metacharacters will break; consider
    # shlex.quote. Verify what paths callers can supply.
    cmd = (
        f"ffmpeg -y -v warning -i {temp_raw_path} -i {audio_path} "
        f"-c:v libx264 -crf 18 -pix_fmt yuv420p "
        f"-c:a copy -shortest {output_vid_path}"
    )
    if not run_ffmpeg(cmd):
        raise RuntimeError("FFmpeg 重编码+音频合并失败")
    # Remove the intermediate mp4v file.
    if os.path.exists(temp_raw_path):
        os.unlink(temp_raw_path)
    timings["7_encode"] = time.time() - t0
    print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
    # ===== summary =====
    total_time = time.time() - t_total
    print(f"\n⏱️ 总耗时: {total_time:.1f}s ({total_time/60:.1f}min)")
    for k, v in timings.items():
        pct = v / total_time * 100
        print(f" {k}: {v:.1f}s ({pct:.0f}%)")
    if not os.path.exists(output_vid_path):
        raise RuntimeError("输出文件未生成")
    return {"status": "success", "output_path": output_vid_path}
if __name__ == "__main__":
    import uvicorn
    # Listen on all interfaces; port 8011 matches the module docstring.
    uvicorn.run(app, host="0.0.0.0", port=8011)