205 lines
7.1 KiB
Python
205 lines
7.1 KiB
Python
|
||
import os
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
# --- Auto-load GPU configuration (must run BEFORE torch is imported) ---
def load_gpu_config():
    """Read LATENTSYNC_GPU_ID from the backend .env file and export it as
    CUDA_VISIBLE_DEVICES, unless that variable is already set externally.

    Best-effort by design: any failure falls back to GPU id "1" and never
    prevents server startup.
    """
    try:
        # Path: scripts/server.py -> scripts -> LatentSync -> models -> ViGent2 -> backend -> .env
        current_dir = Path(__file__).resolve().parent
        env_path = current_dir.parent.parent.parent / "backend" / ".env"

        target_gpu = "1"  # default fallback

        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("LATENTSYNC_GPU_ID="):
                        # split("=", 1): don't truncate values that themselves
                        # contain '='; then drop any inline "# comment" suffix.
                        val = line.split("=", 1)[1].strip().split("#")[0].strip()
                        if val:
                            target_gpu = val
                            print(f"⚙️ 发现配置 LATENTSYNC_GPU_ID={target_gpu}")
                        break  # first matching key wins

        # Respect an externally provided CUDA_VISIBLE_DEVICES.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f"ℹ️ 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")

    except Exception as e:
        # Deliberately broad: GPU auto-config must never crash the server.
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")
|
||
|
||
load_gpu_config()

# --- Performance: cap CPU thread pools (set before torch is imported) ---
# Keeps PyTorch from claiming every CPU core (56 threads) and starving the
# Backend, Frontend and SSH sessions.
# NOTE(review): TORCH_NUM_THREADS does not appear to be an env var torch
# reads (torch.set_num_threads is the documented API) — confirm it has any
# effect; OMP/MKL vars are the ones that matter.
for _thread_env in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "TORCH_NUM_THREADS"):
    os.environ[_thread_env] = "8"
print("⚙️ 已限制 PyTorch CPU 线程数为 8,防止系统卡顿")
|
||
|
||
import torch
|
||
from contextlib import asynccontextmanager
|
||
from fastapi import FastAPI, HTTPException
|
||
from pydantic import BaseModel
|
||
from omegaconf import OmegaConf
|
||
from diffusers import AutoencoderKL, DDIMScheduler
|
||
from latentsync.models.unet import UNet3DConditionModel
|
||
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
||
from latentsync.whisper.audio2feature import Audio2Feature
|
||
from accelerate.utils import set_seed
|
||
from DeepCache import DeepCacheSDHelper
|
||
|
||
# Global model cache: populated by the lifespan hook at startup
# ("pipeline", "config", "dtype" keys), read by the request handlers.
models = {}
|
||
|
||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the LatentSync models once at startup,
    publish them in the module-level ``models`` cache, and release them on
    shutdown.
    """
    # --- Model loading logic (based on inference.py) ---
    print("⏳ 正在加载 LatentSync 模型...")

    # Default config paths, relative to the repo root (models/LatentSync)
    unet_config_path = "configs/unet/stage2_512.yaml"
    ckpt_path = "checkpoints/latentsync_unet.pt"

    # NOTE(review): this only warns — OmegaConf.load below will still raise
    # if the file is actually missing.
    if not os.path.exists(unet_config_path):
        print(f"⚠️ 找不到配置文件: {unet_config_path},请确保在 models/LatentSync 根目录运行")

    config = OmegaConf.load(unet_config_path)

    # Check GPU: fp16 only when compute capability major > 7.
    # NOTE(review): "> 7" excludes 7.x cards (V100/T4) — confirm whether
    # ">= 7" was intended.
    is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
    dtype = torch.float16 if is_fp16_supported else torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if torch.cuda.is_available():
        # Device 0 is the first *visible* device, i.e. the one selected by
        # load_gpu_config() via CUDA_VISIBLE_DEVICES.
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name} (CUDA_VISIBLE_DEVICES 已生效)")
    else:
        print("⚠️ 警告: 未检测到 GPU,将使用 CPU 进行推理 (速度极慢)")

    scheduler = DDIMScheduler.from_pretrained("configs")

    # Whisper model: pick the checkpoint whose feature dim matches the UNet's
    # cross-attention dim.
    if config.model.cross_attention_dim == 768:
        whisper_path = "checkpoints/whisper/small.pt"
    else:
        whisper_path = "checkpoints/whisper/tiny.pt"

    audio_encoder = Audio2Feature(
        model_path=whisper_path,
        device=device,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    # VAE, with Stable Diffusion 1.x latent scaling
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # UNet
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        ckpt_path,
        device="cpu",  # Load to CPU first to save memory during init
    )
    unet = unet.to(dtype=dtype)

    # Pipeline
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to(device)

    # DeepCache (enabled by default): reuses cached UNet features across
    # denoising steps to speed up inference.
    helper = DeepCacheSDHelper(pipe=pipeline)
    helper.set_params(cache_interval=3, cache_branch_id=0)
    helper.enable()

    models["pipeline"] = pipeline
    models["config"] = config
    models["dtype"] = dtype

    print("✅ LatentSync 模型加载完成,服务就绪!")
    yield
    # Clean up on shutdown
    models.clear()
    torch.cuda.empty_cache()
|
||
|
||
# FastAPI application; model load/unload is handled by the lifespan hook.
app = FastAPI(lifespan=lifespan)
|
||
|
||
class LipSyncRequest(BaseModel):
    """Request body for POST /lipsync. All paths are server-local."""

    video_path: str  # input video file (must exist)
    audio_path: str  # driving audio file (must exist)
    video_out_path: str  # where the lip-synced video will be written
    inference_steps: int = 20  # DDIM denoising steps
    guidance_scale: float = 1.5  # classifier-free guidance strength
    seed: int = 1247  # RNG seed; -1 requests a non-deterministic seed
    temp_dir: str = "temp"  # scratch directory for intermediate files
|
||
|
||
@app.get("/health")
|
||
def health_check():
|
||
return {"status": "ok", "model_loaded": "pipeline" in models}
|
||
|
||
@app.post("/lipsync")
|
||
async def generate_lipsync(req: LipSyncRequest):
|
||
if "pipeline" not in models:
|
||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||
|
||
if not os.path.exists(req.video_path):
|
||
raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
|
||
if not os.path.exists(req.audio_path):
|
||
raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")
|
||
|
||
print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
|
||
|
||
try:
|
||
pipeline = models["pipeline"]
|
||
config = models["config"]
|
||
dtype = models["dtype"]
|
||
|
||
# Set seed
|
||
if req.seed != -1:
|
||
set_seed(req.seed)
|
||
else:
|
||
torch.seed()
|
||
|
||
# Run Inference
|
||
pipeline(
|
||
video_path=req.video_path,
|
||
audio_path=req.audio_path,
|
||
video_out_path=req.video_out_path,
|
||
num_frames=config.data.num_frames,
|
||
num_inference_steps=req.inference_steps,
|
||
guidance_scale=req.guidance_scale,
|
||
weight_dtype=dtype,
|
||
width=config.data.resolution,
|
||
height=config.data.resolution,
|
||
mask_image_path=config.data.mask_image_path,
|
||
temp_dir=req.temp_dir,
|
||
)
|
||
|
||
if os.path.exists(req.video_out_path):
|
||
return {"status": "success", "output_path": req.video_out_path}
|
||
else:
|
||
raise HTTPException(status_code=500, detail="Output file generation failed")
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=8007)
|