更新
This commit is contained in:
@@ -4,14 +4,14 @@ MuseTalk v1.5 常驻推理服务 (优化版 v2)
|
||||
- GPU: 从 backend/.env 读取 MUSETALK_GPU_ID (默认 0)
|
||||
- 架构: FastAPI + lifespan (与 LatentSync server.py 同模式)
|
||||
|
||||
优化项 (vs v1):
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
|
||||
4. cv2.VideoWriter 直写视频 (跳过逐帧 PNG 写盘)
|
||||
5. batch_size 8→32
|
||||
6. 每阶段计时
|
||||
"""
|
||||
优化项 (vs v1):
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧检测, 中间插值 bbox)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新, 中间复用)
|
||||
4. FFmpeg rawvideo 管道直编码 (去掉中间有损 mp4v)
|
||||
5. batch_size 8→32
|
||||
6. 每阶段计时
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
@@ -84,17 +84,28 @@ from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
|
||||
|
||||
# --- 从 .env 读取额外配置 ---
|
||||
def load_env_config():
|
||||
"""读取 MuseTalk 相关环境变量"""
|
||||
config = {
|
||||
"batch_size": 32,
|
||||
"version": "v15",
|
||||
"use_float16": True,
|
||||
}
|
||||
try:
|
||||
env_path = musetalk_root.parent.parent / "backend" / ".env"
|
||||
if env_path.exists():
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
def load_env_config():
|
||||
"""读取 MuseTalk 相关环境变量"""
|
||||
config = {
|
||||
"batch_size": 32,
|
||||
"version": "v15",
|
||||
"use_float16": True,
|
||||
"detect_every": 5,
|
||||
"blend_cache_every": 5,
|
||||
"audio_padding_left": 2,
|
||||
"audio_padding_right": 2,
|
||||
"extra_margin": 15,
|
||||
"delay_frame": 0,
|
||||
"blend_mode": "auto",
|
||||
"faceparsing_left_cheek_width": 90,
|
||||
"faceparsing_right_cheek_width": 90,
|
||||
"encode_crf": 18,
|
||||
"encode_preset": "medium",
|
||||
}
|
||||
try:
|
||||
env_path = musetalk_root.parent.parent / "backend" / ".env"
|
||||
if env_path.exists():
|
||||
with open(env_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line.startswith("MUSETALK_BATCH_SIZE="):
|
||||
@@ -105,22 +116,78 @@ def load_env_config():
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["version"] = val
|
||||
elif line.startswith("MUSETALK_USE_FLOAT16="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip().lower()
|
||||
config["use_float16"] = val in ("true", "1", "yes")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 读取额外配置失败: {e}")
|
||||
return config
|
||||
|
||||
env_config = load_env_config()
|
||||
elif line.startswith("MUSETALK_USE_FLOAT16="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip().lower()
|
||||
config["use_float16"] = val in ("true", "1", "yes")
|
||||
elif line.startswith("MUSETALK_DETECT_EVERY="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["detect_every"] = max(1, int(val))
|
||||
elif line.startswith("MUSETALK_BLEND_CACHE_EVERY="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["blend_cache_every"] = max(1, int(val))
|
||||
elif line.startswith("MUSETALK_AUDIO_PADDING_LEFT="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["audio_padding_left"] = max(0, int(val))
|
||||
elif line.startswith("MUSETALK_AUDIO_PADDING_RIGHT="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["audio_padding_right"] = max(0, int(val))
|
||||
elif line.startswith("MUSETALK_EXTRA_MARGIN="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["extra_margin"] = max(0, int(val))
|
||||
elif line.startswith("MUSETALK_DELAY_FRAME="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["delay_frame"] = int(val)
|
||||
elif line.startswith("MUSETALK_BLEND_MODE="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip().lower()
|
||||
if val in ("auto", "jaw", "raw"):
|
||||
config["blend_mode"] = val
|
||||
elif line.startswith("MUSETALK_FACEPARSING_LEFT_CHEEK_WIDTH="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["faceparsing_left_cheek_width"] = max(0, int(val))
|
||||
elif line.startswith("MUSETALK_FACEPARSING_RIGHT_CHEEK_WIDTH="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["faceparsing_right_cheek_width"] = max(0, int(val))
|
||||
elif line.startswith("MUSETALK_ENCODE_CRF="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip()
|
||||
if val:
|
||||
config["encode_crf"] = min(51, max(0, int(val)))
|
||||
elif line.startswith("MUSETALK_ENCODE_PRESET="):
|
||||
val = line.split("=")[1].strip().split("#")[0].strip().lower()
|
||||
if val in (
|
||||
"ultrafast", "superfast", "veryfast", "faster", "fast",
|
||||
"medium", "slow", "slower", "veryslow"
|
||||
):
|
||||
config["encode_preset"] = val
|
||||
except Exception as e:
|
||||
print(f"⚠️ 读取额外配置失败: {e}")
|
||||
return config
|
||||
|
||||
env_config = load_env_config()
|
||||
|
||||
# 全局模型缓存
|
||||
models = {}
|
||||
|
||||
# ===================== 优化参数 =====================
|
||||
DETECT_EVERY = 5 # 人脸检测降频: 每 N 帧检测一次
|
||||
BLEND_CACHE_EVERY = 5 # BiSeNet mask 缓存: 每 N 帧更新一次
|
||||
# ====================================================
|
||||
# ===================== 优化参数 =====================
|
||||
DETECT_EVERY = int(env_config["detect_every"]) # 人脸检测降频: 每 N 帧检测一次
|
||||
BLEND_CACHE_EVERY = int(env_config["blend_cache_every"]) # BiSeNet mask 缓存: 每 N 帧更新一次
|
||||
AUDIO_PADDING_LEFT = int(env_config["audio_padding_left"])
|
||||
AUDIO_PADDING_RIGHT = int(env_config["audio_padding_right"])
|
||||
EXTRA_MARGIN = int(env_config["extra_margin"])
|
||||
DELAY_FRAME = int(env_config["delay_frame"])
|
||||
BLEND_MODE = str(env_config["blend_mode"])
|
||||
FACEPARSING_LEFT_CHEEK_WIDTH = int(env_config["faceparsing_left_cheek_width"])
|
||||
FACEPARSING_RIGHT_CHEEK_WIDTH = int(env_config["faceparsing_right_cheek_width"])
|
||||
ENCODE_CRF = int(env_config["encode_crf"])
|
||||
ENCODE_PRESET = str(env_config["encode_preset"])
|
||||
# ====================================================
|
||||
|
||||
|
||||
def run_ffmpeg(cmd):
|
||||
@@ -191,11 +258,14 @@ async def lifespan(app: FastAPI):
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
# FaceParsing
|
||||
if version == "v15":
|
||||
fp = FaceParsing(left_cheek_width=90, right_cheek_width=90)
|
||||
else:
|
||||
fp = FaceParsing()
|
||||
# FaceParsing
|
||||
if version == "v15":
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=FACEPARSING_LEFT_CHEEK_WIDTH,
|
||||
right_cheek_width=FACEPARSING_RIGHT_CHEEK_WIDTH,
|
||||
)
|
||||
else:
|
||||
fp = FaceParsing()
|
||||
|
||||
# 恢复工作目录
|
||||
os.chdir(original_cwd)
|
||||
@@ -211,9 +281,13 @@ async def lifespan(app: FastAPI):
|
||||
models["version"] = version
|
||||
models["timesteps"] = torch.tensor([0], device=device)
|
||||
|
||||
print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
|
||||
print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
|
||||
f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}")
|
||||
print("✅ MuseTalk v1.5 模型加载完成,服务就绪!")
|
||||
print(f"⚙️ 优化参数: batch_size={env_config['batch_size']}, "
|
||||
f"detect_every={DETECT_EVERY}, blend_cache_every={BLEND_CACHE_EVERY}, "
|
||||
f"audio_padding=({AUDIO_PADDING_LEFT},{AUDIO_PADDING_RIGHT}), extra_margin={EXTRA_MARGIN}, "
|
||||
f"delay_frame={DELAY_FRAME}, blend_mode={BLEND_MODE}, "
|
||||
f"faceparsing_cheek=({FACEPARSING_LEFT_CHEEK_WIDTH},{FACEPARSING_RIGHT_CHEEK_WIDTH}), "
|
||||
f"encode=libx264/{ENCODE_PRESET}/crf{ENCODE_CRF}")
|
||||
yield
|
||||
models.clear()
|
||||
torch.cuda.empty_cache()
|
||||
@@ -354,15 +428,15 @@ def _detect_faces_subsampled(frames, detect_every=5):
|
||||
# 核心推理 (优化版)
|
||||
# =====================================================================
|
||||
@torch.no_grad()
|
||||
def _run_inference(req: LipSyncRequest) -> dict:
|
||||
"""
|
||||
优化版推理逻辑:
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧, 中间插值)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新)
|
||||
4. cv2.VideoWriter 直写 (跳过逐帧 PNG)
|
||||
5. 每阶段计时
|
||||
def _run_inference(req: LipSyncRequest) -> dict:
|
||||
"""
|
||||
优化版推理逻辑:
|
||||
1. cv2.VideoCapture 直读帧 (跳过 ffmpeg→PNG→imread)
|
||||
2. 人脸检测降频 (每 N 帧, 中间插值)
|
||||
3. BiSeNet mask 缓存 (每 N 帧更新)
|
||||
4. FFmpeg rawvideo 管道直编码 (无中间有损文件)
|
||||
5. 每阶段计时
|
||||
"""
|
||||
vae = models["vae"]
|
||||
unet = models["unet"]
|
||||
pe = models["pe"]
|
||||
@@ -411,12 +485,12 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
# ===== Phase 2: Whisper 音频特征 =====
|
||||
t0 = time.time()
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features, device, weight_dtype, whisper, librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features, device, weight_dtype, whisper, librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=AUDIO_PADDING_LEFT,
|
||||
audio_padding_length_right=AUDIO_PADDING_RIGHT,
|
||||
)
|
||||
timings["2_whisper"] = time.time() - t0
|
||||
print(f"🎵 Whisper 特征 [{timings['2_whisper']:.1f}s]")
|
||||
|
||||
@@ -427,12 +501,12 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
print(f"🔍 人脸检测 [{timings['3_face']:.1f}s]")
|
||||
|
||||
# ===== Phase 4: VAE 潜空间编码 =====
|
||||
t0 = time.time()
|
||||
input_latent_list = []
|
||||
extra_margin = 15
|
||||
for bbox, frame in zip(coord_list, frames):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
t0 = time.time()
|
||||
input_latent_list = []
|
||||
extra_margin = EXTRA_MARGIN
|
||||
for bbox, frame in zip(coord_list, frames):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
if version == "v15":
|
||||
y2 = min(y2 + extra_margin, frame.shape[0])
|
||||
@@ -453,13 +527,13 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
video_num = len(whisper_chunks)
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=DELAY_FRAME,
|
||||
device=device,
|
||||
)
|
||||
|
||||
res_frame_list = []
|
||||
total_batches = int(np.ceil(float(video_num) / batch_size))
|
||||
@@ -479,21 +553,44 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
timings["5_unet"] = time.time() - t0
|
||||
print(f"✅ UNet 推理: {len(res_frame_list)} 帧 [{timings['5_unet']:.1f}s]")
|
||||
|
||||
# ===== Phase 6: 合成 (cv2.VideoWriter + 纯 numpy blending) =====
|
||||
t0 = time.time()
|
||||
|
||||
h, w = frames[0].shape[:2]
|
||||
temp_raw_path = output_vid_path + ".raw.mp4"
|
||||
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
writer = cv2.VideoWriter(temp_raw_path, fourcc, fps, (w, h))
|
||||
|
||||
if not writer.isOpened():
|
||||
raise RuntimeError(f"cv2.VideoWriter 打开失败: {temp_raw_path}")
|
||||
|
||||
cached_mask = None
|
||||
cached_crop_box = None
|
||||
blend_mode = "jaw" if version == "v15" else "raw"
|
||||
# ===== Phase 6: 合成并写入 FFmpeg rawvideo 管道 =====
|
||||
t0 = time.time()
|
||||
|
||||
h, w = frames[0].shape[:2]
|
||||
ffmpeg_cmd = [
|
||||
"ffmpeg", "-y", "-v", "warning",
|
||||
"-f", "rawvideo",
|
||||
"-pix_fmt", "bgr24",
|
||||
"-s", f"{w}x{h}",
|
||||
"-r", str(fps),
|
||||
"-i", "-",
|
||||
"-i", audio_path,
|
||||
"-c:v", "libx264", "-preset", ENCODE_PRESET, "-crf", str(ENCODE_CRF), "-pix_fmt", "yuv420p",
|
||||
"-c:a", "copy", "-shortest",
|
||||
output_vid_path,
|
||||
]
|
||||
ffmpeg_proc = subprocess.Popen(
|
||||
ffmpeg_cmd,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
pipe_in = ffmpeg_proc.stdin
|
||||
if pipe_in is None:
|
||||
raise RuntimeError("FFmpeg 管道初始化失败")
|
||||
|
||||
def _write_pipe_frame(frame: np.ndarray):
|
||||
try:
|
||||
pipe_in.write(np.ascontiguousarray(frame, dtype=np.uint8).tobytes())
|
||||
except BrokenPipeError as exc:
|
||||
raise RuntimeError("FFmpeg 管道写入失败") from exc
|
||||
|
||||
cached_mask = None
|
||||
cached_crop_box = None
|
||||
if BLEND_MODE == "auto":
|
||||
blend_mode = "jaw" if version == "v15" else "raw"
|
||||
else:
|
||||
blend_mode = BLEND_MODE
|
||||
|
||||
for i in tqdm(range(len(res_frame_list)), desc="合成"):
|
||||
res_frame = res_frame_list[i]
|
||||
@@ -503,26 +600,26 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
x1, y1, x2, y2 = bbox
|
||||
if version == "v15":
|
||||
y2 = min(y2 + extra_margin, ori_frame.shape[0])
|
||||
adjusted_bbox = (x1, y1, x2, y2)
|
||||
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
except Exception:
|
||||
writer.write(ori_frame)
|
||||
continue
|
||||
adjusted_bbox = (x1, y1, x2, y2)
|
||||
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
except Exception:
|
||||
_write_pipe_frame(ori_frame)
|
||||
continue
|
||||
|
||||
# 每 N 帧更新 BiSeNet 人脸解析 mask, 其余帧复用缓存
|
||||
if i % BLEND_CACHE_EVERY == 0 or cached_mask is None:
|
||||
try:
|
||||
cached_mask, cached_crop_box = get_image_prepare_material(
|
||||
ori_frame, adjusted_bbox, mode=blend_mode, fp=fp)
|
||||
except Exception:
|
||||
# 如果 prepare 失败, 用完整方式
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
writer.write(combine_frame)
|
||||
continue
|
||||
except Exception:
|
||||
# 如果 prepare 失败, 用完整方式
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
_write_pipe_frame(combine_frame)
|
||||
continue
|
||||
|
||||
try:
|
||||
combine_frame = get_image_blending_fast(
|
||||
@@ -532,35 +629,25 @@ def _run_inference(req: LipSyncRequest) -> dict:
|
||||
try:
|
||||
combine_frame = get_image_blending(
|
||||
ori_frame, res_frame, adjusted_bbox, cached_mask, cached_crop_box)
|
||||
except Exception:
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
|
||||
writer.write(combine_frame)
|
||||
|
||||
writer.release()
|
||||
timings["6_blend"] = time.time() - t0
|
||||
print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
|
||||
|
||||
# ===== Phase 7: FFmpeg H.264 编码 + 合并音频 =====
|
||||
t0 = time.time()
|
||||
cmd = [
|
||||
"ffmpeg", "-y", "-v", "warning",
|
||||
"-i", temp_raw_path, "-i", audio_path,
|
||||
"-c:v", "libx264", "-crf", "18", "-pix_fmt", "yuv420p",
|
||||
"-c:a", "copy", "-shortest",
|
||||
output_vid_path
|
||||
]
|
||||
if not run_ffmpeg(cmd):
|
||||
raise RuntimeError("FFmpeg 重编码+音频合并失败")
|
||||
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_raw_path):
|
||||
os.unlink(temp_raw_path)
|
||||
|
||||
timings["7_encode"] = time.time() - t0
|
||||
print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
|
||||
except Exception:
|
||||
combine_frame = get_image(
|
||||
ori_frame, res_frame, list(adjusted_bbox),
|
||||
mode=blend_mode, fp=fp)
|
||||
|
||||
_write_pipe_frame(combine_frame)
|
||||
|
||||
pipe_in.close()
|
||||
timings["6_blend"] = time.time() - t0
|
||||
print(f"🎨 合成 [{timings['6_blend']:.1f}s]")
|
||||
|
||||
# ===== Phase 7: 等待 FFmpeg 编码完成 =====
|
||||
t0 = time.time()
|
||||
return_code = ffmpeg_proc.wait()
|
||||
if return_code != 0:
|
||||
raise RuntimeError("FFmpeg 编码+音频合并失败")
|
||||
|
||||
timings["7_encode"] = time.time() - t0
|
||||
print(f"🔊 编码+音频 [{timings['7_encode']:.1f}s]")
|
||||
|
||||
# ===== 汇总 =====
|
||||
total_time = time.time() - t_total
|
||||
|
||||
Reference in New Issue
Block a user