This commit is contained in:
Kevin Wong
2026-03-02 16:35:16 +08:00
parent abf005f225
commit 48bc78fe38
19 changed files with 1092 additions and 316 deletions

View File

@@ -4,14 +4,14 @@ MuseTalk v1.5 常驻推理服务 (优化版 v2)
- GPU: 从 backend/.env 读取 MUSETALK_GPU_ID (默认 0)
- 架构: FastAPI + lifespan (与 LatentSync server.py 同模式)
优化项 (vs v1):
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
4. cv2.VideoWriter 直写视频 (跳过逐帧 PNG 写盘)
5. batch_size 8→32
6. 每阶段计时
"""
优化项 (vs v1):
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
4. FFmpeg rawvideo 管道直编码 (去掉中间有损 mp4v)
5. batch_size 8→32
6. 每阶段计时
"""
import os
import sys
@@ -84,17 +84,28 @@ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
# --- 从 .env 读取额外配置 ---
def load_env_config():
"""读取 MuseTalk 相关环境变量"""
config = {
"batch_size": 32,
"version": "v15",
"use_float16": True,
}
try:
env_path = musetalk_root.parent.parent / "backend" / ".env"
if env_path.exists():
with open(env_path, "r", encoding="utf-8") as f:
def load_env_config():
"""读取 MuseTalk 相关环境变量"""
config = {
"batch_size": 32,
"version": "v15",
"use_float16": True,
"detect_every": 5,
"blend_cache_every": 5,
"audio_padding_left": 2,
"audio_padding_right": 2,
"extra_margin": 15,
"delay_frame": 0,
"blend_mode": "auto",
"faceparsing_left_cheek_width": 90,
"faceparsing_right_cheek_width": 90,
"encode_crf": 18,
"encode_preset": "medium",
}
try:
env_path = musetalk_root.parent.parent / "backend" / ".env"
if env_path.exists():
with open(env_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line.startswith("MUSETALK_BATCH_SIZE="):
@@ -105,22 +116,78 @@ def load_env_config():
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["version"] = val
elif line.startswith("MUSETALK_USE_FLOAT16="):
val = line.split("=")[1].strip().split("#")[0].strip().lower()
config["use_float16"] = val in ("true", "1", "yes")
except Exception as e:
print(f"⚠️ 读取额外配置失败: {e}")
return config
env_config = load_env_config()
elif line.startswith("MUSETALK_USE_FLOAT16="):
val = line.split("=")[1].strip().split("#")[0].strip().lower()
config["use_float16"] = val in ("true", "1", "yes")
elif line.startswith("MUSETALK_DETECT_EVERY="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["detect_every"] = max(1, int(val))
elif line.startswith("MUSETALK_BLEND_CACHE_EVERY="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["blend_cache_every"] = max(1, int(val))
elif line.startswith("MUSETALK_AUDIO_PADDING_LEFT="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["audio_padding_left"] = max(0, int(val))
elif line.startswith("MUSETALK_AUDIO_PADDING_RIGHT="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["audio_padding_right"] = max(0, int(val))
elif line.startswith("MUSETALK_EXTRA_MARGIN="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["extra_margin"] = max(0, int(val))
elif line.startswith("MUSETALK_DELAY_FRAME="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["delay_frame"] = int(val)
elif line.startswith("MUSETALK_BLEND_MODE="):
val = line.split("=")[1].strip().split("#")[0].strip().lower()
if val in ("auto", "jaw", "raw"):
config["blend_mode"] = val
elif line.startswith("MUSETALK_FACEPARSING_LEFT_CHEEK_WIDTH="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["faceparsing_left_cheek_width"] = max(0, int(val))
elif line.startswith("MUSETALK_FACEPARSING_RIGHT_CHEEK_WIDTH="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["faceparsing_right_cheek_width"] = max(0, int(val))
elif line.startswith("MUSETALK_ENCODE_CRF="):
val = line.split("=")[1].strip().split("#")[0].strip()
if val:
config["encode_crf"] = min(51, max(0, int(val)))
elif line.startswith("MUSETALK_ENCODE_PRESET="):
val = line.split("=")[1].strip().split("#")[0].strip().lower()
if val in (
"ultrafast", "superfast", "veryfast", "faster", "fast",
"medium", "slow", "slower", "veryslow"
):
config["encode_preset"] = val
except Exception as e:
print(f"⚠️ 读取额外配置失败: {e}")
return config
# Resolve all tunable settings from backend/.env once at import time.
env_config = load_env_config()
# Global model cache — populated during the FastAPI lifespan startup
# (vae/unet/pe/whisper/fp handles are stored here and cleared on shutdown).
models = {}
# ===================== Tuning parameters =====================
# Face detection runs only every DETECT_EVERY frames (bounding boxes
# are interpolated in between); the BiSeNet blending mask is likewise
# recomputed only every BLEND_CACHE_EVERY frames and reused otherwise.
DETECT_EVERY, BLEND_CACHE_EVERY = 5, 5
# =============================================================
# ===================== Tuning parameters =====================
# All knobs below come from the .env-backed config dict produced by
# load_env_config(); the int()/str() casts are a defensive normalization
# in case a caller ever injects raw string values into env_config.

# Frame-subsampling rates: detect faces / rebuild the BiSeNet mask only
# every N frames, reusing (or interpolating) results in between.
DETECT_EVERY = int(env_config["detect_every"])
BLEND_CACHE_EVERY = int(env_config["blend_cache_every"])

# Audio-feature windowing and frame alignment.
AUDIO_PADDING_LEFT = int(env_config["audio_padding_left"])
AUDIO_PADDING_RIGHT = int(env_config["audio_padding_right"])
DELAY_FRAME = int(env_config["delay_frame"])

# Face-crop and blending behaviour.
EXTRA_MARGIN = int(env_config["extra_margin"])
BLEND_MODE = str(env_config["blend_mode"])
FACEPARSING_LEFT_CHEEK_WIDTH = int(env_config["faceparsing_left_cheek_width"])
FACEPARSING_RIGHT_CHEEK_WIDTH = int(env_config["faceparsing_right_cheek_width"])

# libx264 output encoding quality/speed trade-off.
ENCODE_CRF = int(env_config["encode_crf"])
ENCODE_PRESET = str(env_config["encode_preset"])
# =============================================================
def run_ffmpeg(cmd):
@@ -191,11 +258,14 @@ async def lifespan(app: FastAPI):
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
whisper.requires_grad_(False)
# FaceParsing
if version == "v15":
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
else:
fp = FaceParsing()
# FaceParsing
if version == "v15":
fp = FaceParsing(
left_cheek_width=FACEPARSING_LEFT_CHEEK_WIDTH,
right_cheek_width=FACEPARSING_RIGHT_CHEEK_WIDTH,
)
else:
fp = FaceParsing()
# 恢复工作目录
os.chdir(original_cwd)
@@ -211,9 +281,13 @@ async def lifespan(app: FastAPI):
models["version"] = version
models["timesteps"] = torch.tensor([0], device=device)
print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}")
print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}, "
f"audio_padding=({AUDIO_PADDING_LEFT},{AUDIO_PADDING_RIGHT}), extra_margin={EXTRA_MARGIN}, "
f"delay_frame={DELAY_FRAME}, blend_mode={BLEND_MODE}, "
f"faceparsing_cheek=({FACEPARSING_LEFT_CHEEK_WIDTH},{FACEPARSING_RIGHT_CHEEK_WIDTH}), "
f"encode=libx264/{ENCODE_PRESET}/crf{ENCODE_CRF}")
yield
models.clear()
torch.cuda.empty_cache()
@@ -354,15 +428,15 @@ def _detect_faces_subsampled(frames, detect_every=5):
# 核心推理 (优化版)
# =====================================================================
@torch.no_grad()
def _run_inference(req: LipSyncRequest) -> dict:
"""
优化版推理逻辑:
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
2. 人脸检测降频 (每 N 帧, 中间插值)
3. BiSeNet mask 缓存 (每 N 帧更新)
4. cv2.VideoWriter 直写 (跳过逐帧 PNG)
5. 每阶段计时
def _run_inference(req: LipSyncRequest) -> dict:
"""
优化版推理逻辑:
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
2. 人脸检测降频 (每 N 帧, 中间插值)
3. BiSeNet mask 缓存 (每 N 帧更新)
4. FFmpeg rawvideo 管道直编码 (无中间有损文件)
5. 每阶段计时
"""
vae = models["vae"]
unet = models["unet"]
pe = models["pe"]
@@ -411,12 +485,12 @@ def _run_inference(req: LipSyncRequest) -> dict:
# ===== Phase 2: Whisper 音频特征 =====
t0 = time.time()
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features, device, weight_dtype, whisper, librosa_length,
fps=fps,
audio_padding_length_left=2,
audio_padding_length_right=2,
)
whisper_chunks = audio_processor.get_whisper_chunk(
whisper_input_features, device, weight_dtype, whisper, librosa_length,
fps=fps,
audio_padding_length_left=AUDIO_PADDING_LEFT,
audio_padding_length_right=AUDIO_PADDING_RIGHT,
)
timings["2_whisper"] = time.time() - t0
print(f"🎵 Whisper 特征 [{timings['2_whisper']:.1f}s]")
@@ -427,12 +501,12 @@ def _run_inference(req: LipSyncRequest) -> dict:
print(f"🔍 人脸检测 [{timings['3_face']:.1f}s]")
# ===== Phase 4: VAE 潜空间编码 =====
t0 = time.time()
input_latent_list = []
extra_margin = 15
for bbox, frame in zip(coord_list, frames):
if bbox == coord_placeholder:
continue
t0 = time.time()
input_latent_list = []
extra_margin = EXTRA_MARGIN
for bbox, frame in zip(coord_list, frames):
if bbox == coord_placeholder:
continue
x1, y1, x2, y2 = bbox
if version == "v15":
y2 = min(y2 + extra_margin, frame.shape[0])
@@ -453,13 +527,13 @@ def _run_inference(req: LipSyncRequest) -> dict:
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
video_num = len(whisper_chunks)
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=0,
device=device,
)
gen = datagen(
whisper_chunks=whisper_chunks,
vae_encode_latents=input_latent_list_cycle,
batch_size=batch_size,
delay_frame=DELAY_FRAME,
device=device,
)
res_frame_list = []
total_batches = int(np.ceil(float(video_num) / batch_size))
@@ -479,21 +553,44 @@ def _run_inference(req: LipSyncRequest) -> dict:
timings["5_unet"] = time.time() - t0
print(f"✅ UNet 推理: {len(res_frame_list)} 帧 [{timings['5_unet']:.1f}s]")
# ===== Phase 6: 合成 (cv2.VideoWriter + 纯 numpy blending) =====
t0 = time.time()
h, w = frames[0].shape[:2]
temp_raw_path = output_vid_path + ".raw.mp4"
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(temp_raw_path, fourcc, fps, (w, h))
if not writer.isOpened():
raise RuntimeError(f"cv2.VideoWriter 打开失败: {temp_raw_path}")
cached_mask = None
cached_crop_box = None
blend_mode = "jaw" if version == "v15" else "raw"
# ===== Phase 6: 合成并写入 FFmpeg rawvideo 管道 =====
t0 = time.time()
h, w = frames[0].shape[:2]
ffmpeg_cmd = [
"ffmpeg", "-y", "-v", "warning",
"-f", "rawvideo",
"-pix_fmt", "bgr24",
"-s", f"{w}x{h}",
"-r", str(fps),
"-i", "-",
"-i", audio_path,
"-c:v", "libx264", "-preset", ENCODE_PRESET, "-crf", str(ENCODE_CRF), "-pix_fmt", "yuv420p",
"-c:a", "copy", "-shortest",
output_vid_path,
]
ffmpeg_proc = subprocess.Popen(
ffmpeg_cmd,
stdin=subprocess.PIPE,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
pipe_in = ffmpeg_proc.stdin
if pipe_in is None:
raise RuntimeError("FFmpeg 管道初始化失败")
def _write_pipe_frame(frame: np.ndarray):
try:
pipe_in.write(np.ascontiguousarray(frame, dtype=np.uint8).tobytes())
except BrokenPipeError as exc:
raise RuntimeError("FFmpeg 管道写入失败") from exc
cached_mask = None
cached_crop_box = None
if BLEND_MODE == "auto":
blend_mode = "jaw" if version == "v15" else "raw"
else:
blend_mode = BLEND_MODE
for i in tqdm(range(len(res_frame_list)), desc="合成"):
res_frame = res_frame_list[i]
@@ -503,26 +600,26 @@ def _run_inference(req: LipSyncRequest) -> dict:
x1, y1, x2, y2 = bbox
if version == "v15":
y2 = min(y2 + extra_margin, ori_frame.shape[0])
adjusted_bbox = (x1, y1, x2, y2)
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except Exception:
writer.write(ori_frame)
continue
adjusted_bbox = (x1, y1, x2, y2)
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except Exception:
_write_pipe_frame(ori_frame)
continue
# 每 N 帧更新 BiSeNet 人脸解析 mask, 其余帧复用缓存
if i % BLEND_CACHE_EVERY == 0 or cached_mask is None:
try:
cached_mask, cached_crop_box = get_image_prepare_material(
ori_frame, adjusted_bbox, mode=blend_mode, fp=fp)
except Exception:
# 如果 prepare 失败, 用完整方式
combine_frame = get_image(
ori_frame, res_frame, list(adjusted_bbox),
mode=blend_mode, fp=fp)
writer.write(combine_frame)
continue
except Exception:
# 如果 prepare 失败, 用完整方式
combine_frame = get_image(
ori_frame, res_frame, list(adjusted_bbox),
mode=blend_mode, fp=fp)
_write_pipe_frame(combine_frame)
continue
try:
combine_frame = get_image_blending_fast(
@@ -532,35 +629,25 @@ def _run_inference(req: LipSyncRequest) -> dict:
try:
combine_frame = get_image_blending(
ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
except Exception:
combine_frame = get_image(
ori_frame, res_frame, list(adjusted_bbox),
mode=blend_mode, fp=fp)
writer.write(combine_frame)
writer.release()
timings["6_blend"] = time.time() - t0
print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
# ===== Phase 7: FFmpeg H.264 编码 + 合并音频 =====
t0 = time.time()
cmd = [
"ffmpeg", "-y", "-v", "warning",
"-i", temp_raw_path, "-i", audio_path,
"-c:v", "libx264", "-crf", "18", "-pix_fmt", "yuv420p",
"-c:a", "copy", "-shortest",
output_vid_path
]
if not run_ffmpeg(cmd):
raise RuntimeError("FFmpeg 重编码+音频合并失败")
# 清理临时文件
if os.path.exists(temp_raw_path):
os.unlink(temp_raw_path)
timings["7_encode"] = time.time() - t0
print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
except Exception:
combine_frame = get_image(
ori_frame, res_frame, list(adjusted_bbox),
mode=blend_mode, fp=fp)
_write_pipe_frame(combine_frame)
pipe_in.close()
timings["6_blend"] = time.time() - t0
print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
# ===== Phase 7: 等待 FFmpeg 编码完成 =====
t0 = time.time()
return_code = ffmpeg_proc.wait()
if return_code != 0:
raise RuntimeError("FFmpeg 编码+音频合并失败")
timings["7_encode"] = time.time() - t0
print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
# ===== 汇总 =====
total_time = time.time() - t_total