优化代码

This commit is contained in:
Kevin Wong
2026-01-21 10:30:32 +08:00
parent 1890cea3ee
commit cbf840f472
72 changed files with 112399 additions and 92 deletions

View File

@@ -0,0 +1,196 @@
import os
import argparse
from pathlib import Path
# --- Auto-load GPU config (must run before torch is imported) ---
def load_gpu_config():
    """Read LATENTSYNC_GPU_ID from the backend .env and export CUDA_VISIBLE_DEVICES.

    Must be called before ``import torch`` so the device mask takes effect.
    An externally provided CUDA_VISIBLE_DEVICES always wins over the .env
    value. Best-effort: any failure is logged and defaults are used.
    """
    try:
        # Path: scripts/server.py -> scripts -> LatentSync -> models -> ViGent2 -> backend -> .env
        current_dir = Path(__file__).resolve().parent
        env_path = current_dir.parent.parent.parent / "backend" / ".env"
        target_gpu = "1"  # default fallback
        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("LATENTSYNC_GPU_ID="):
                        # split("=", 1): keep values that themselves contain '='
                        # (plain split("=")[1] would silently truncate them),
                        # then strip any trailing inline "#" comment.
                        val = line.split("=", 1)[1].split("#")[0].strip()
                        if val:
                            target_gpu = val
                            print(f"⚙️ 发现配置 LATENTSYNC_GPU_ID={target_gpu}")
                        break
        # Respect an externally supplied device mask; only set our own otherwise.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f" 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")
    except Exception as e:
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")


load_gpu_config()
import torch
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, DDIMScheduler
from latentsync.models.unet import UNet3DConditionModel
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
from latentsync.whisper.audio2feature import Audio2Feature
from accelerate.utils import set_seed
from DeepCache import DeepCacheSDHelper
# Global model cache, populated by lifespan() at startup and read by the handlers.
models = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the LatentSync stack once at startup.

    Populates the module-level ``models`` dict with the assembled
    LipsyncPipeline, its OmegaConf config and the chosen weight dtype;
    on shutdown the cache is cleared and CUDA memory is released.
    """
    # --- Model loading logic (based on inference.py) ---
    print("⏳ 正在加载 LatentSync 模型...")
    # Default paths, relative to the models/LatentSync repo root.
    unet_config_path = "configs/unet/stage2_512.yaml"
    ckpt_path = "checkpoints/latentsync_unet.pt"
    if not os.path.exists(unet_config_path):
        # NOTE(review): only warns here; OmegaConf.load below will raise if missing.
        print(f"⚠️ 找不到配置文件: {unet_config_path},请确保在 models/LatentSync 根目录运行")
    config = OmegaConf.load(unet_config_path)
    # fp16 only on GPUs with compute capability > 7; otherwise fall back to fp32.
    is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
    dtype = torch.float16 if is_fp16_supported else torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name} (CUDA_VISIBLE_DEVICES 已生效)")
    else:
        print("⚠️ 警告: 未检测到 GPU将使用 CPU 进行推理 (速度极慢)")
    scheduler = DDIMScheduler.from_pretrained("configs")
    # Whisper audio encoder: checkpoint chosen by the UNet's cross-attention width.
    if config.model.cross_attention_dim == 768:
        whisper_path = "checkpoints/whisper/small.pt"
    else:
        whisper_path = "checkpoints/whisper/tiny.pt"
    audio_encoder = Audio2Feature(
        model_path=whisper_path,
        device=device,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )
    # VAE with the SD-1.x latent scaling factor; shift disabled.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0
    # UNet: weights are loaded from ckpt_path using the config's model section.
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        ckpt_path,
        device="cpu",  # Load to CPU first to save memory during init
    )
    unet = unet.to(dtype=dtype)
    # Assemble the full lipsync pipeline and move it to the target device.
    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=scheduler,
    ).to(device)
    # DeepCache (enabled by default) reuses cached UNet features every 3 steps.
    helper = DeepCacheSDHelper(pipe=pipeline)
    helper.set_params(cache_interval=3, cache_branch_id=0)
    helper.enable()
    models["pipeline"] = pipeline
    models["config"] = config
    models["dtype"] = dtype
    print("✅ LatentSync 模型加载完成,服务就绪!")
    yield
    # Shutdown: drop cached references and free GPU memory.
    models.clear()
    torch.cuda.empty_cache()
# Register the lifespan handler so models load at startup and unload at shutdown.
app = FastAPI(lifespan=lifespan)
class LipSyncRequest(BaseModel):
    """Request body for POST /lipsync (all paths are server-local)."""

    video_path: str  # source video to re-animate
    audio_path: str  # driving audio track
    video_out_path: str  # where the synced video will be written
    inference_steps: int = 20  # diffusion sampling steps
    guidance_scale: float = 1.5  # classifier-free guidance strength
    seed: int = 1247  # -1 requests a non-deterministic seed
    temp_dir: str = "temp"  # scratch directory for intermediate files
@app.get("/health")
def health_check():
    """Liveness/readiness probe: reports whether the pipeline is loaded."""
    pipeline_ready = "pipeline" in models
    return {"status": "ok", "model_loaded": pipeline_ready}
@app.post("/lipsync")
async def generate_lipsync(req: LipSyncRequest):
    """Run lip-sync inference: drive the video's mouth motion with the audio.

    Validates inputs, seeds the RNG, invokes the cached LipsyncPipeline and
    returns the output path on success. Raises 503 before the model is
    loaded, 404 for missing input files, 500 on inference failure.
    """
    if "pipeline" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not os.path.exists(req.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
    if not os.path.exists(req.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")
    print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")
    try:
        pipeline = models["pipeline"]
        config = models["config"]
        dtype = models["dtype"]
        # Set seed: -1 asks for a fresh non-deterministic seed.
        if req.seed != -1:
            set_seed(req.seed)
        else:
            torch.seed()
        # Run inference; the pipeline writes the result to video_out_path itself.
        pipeline(
            video_path=req.video_path,
            audio_path=req.audio_path,
            video_out_path=req.video_out_path,
            num_frames=config.data.num_frames,
            num_inference_steps=req.inference_steps,
            guidance_scale=req.guidance_scale,
            weight_dtype=dtype,
            width=config.data.resolution,
            height=config.data.resolution,
            mask_image_path=config.data.mask_image_path,
            temp_dir=req.temp_dir,
        )
        if os.path.exists(req.video_out_path):
            return {"status": "success", "output_path": req.video_out_path}
        else:
            # NOTE(review): this HTTPException is caught by the broad except
            # below and re-raised with detail=str(e); status stays 500.
            raise HTTPException(status_code=500, detail="Output file generation failed")
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Standalone entry point: serve the API on all interfaces at port 8007.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8007)