优化代码
This commit is contained in:
196
models/LatentSync/scripts/server.py
Normal file
196
models/LatentSync/scripts/server.py
Normal file
@@ -0,0 +1,196 @@
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# --- Auto-load GPU configuration (must run before torch is imported) ---
def load_gpu_config():
    """Read LATENTSYNC_GPU_ID from the backend .env file and export it as
    CUDA_VISIBLE_DEVICES, unless the caller already set that variable.

    Must be called before ``import torch`` so the device mask takes effect.
    Failures are non-fatal: on any error the process continues with defaults.
    """
    try:
        # Path: scripts/server.py -> scripts -> LatentSync -> models -> ViGent2 -> backend -> .env
        current_dir = Path(__file__).resolve().parent
        env_path = current_dir.parent.parent.parent / "backend" / ".env"

        target_gpu = "1"  # default fallback

        if env_path.exists():
            print(f"📖 读取配置文件: {env_path}")
            with open(env_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("LATENTSYNC_GPU_ID="):
                        # Bug fix: split("=", 1) keeps values that themselves
                        # contain '='; text after '#' is an inline comment.
                        val = line.split("=", 1)[1].split("#")[0].strip()
                        if val:
                            target_gpu = val
                            print(f"⚙️ 发现配置 LATENTSYNC_GPU_ID={target_gpu}")
                        break

        # Respect an externally supplied device mask; only set when absent.
        if "CUDA_VISIBLE_DEVICES" not in os.environ:
            os.environ["CUDA_VISIBLE_DEVICES"] = target_gpu
            print(f"✅ 已自动设置: CUDA_VISIBLE_DEVICES={target_gpu}")
        else:
            print(f"ℹ️ 检测到外部 CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']},跳过自动配置")

    except Exception as e:
        # Best-effort: GPU selection failure must not prevent server startup.
        print(f"⚠️ 读取 GPU 配置失败: {e},将使用默认设置")


load_gpu_config()
|
||||
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from omegaconf import OmegaConf
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
from latentsync.models.unet import UNet3DConditionModel
|
||||
from latentsync.pipelines.lipsync_pipeline import LipsyncPipeline
|
||||
from latentsync.whisper.audio2feature import Audio2Feature
|
||||
from accelerate.utils import set_seed
|
||||
from DeepCache import DeepCacheSDHelper
|
||||
|
||||
# Global cache of loaded model objects, shared with the request handlers.
models = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the LatentSync pipeline once at startup; release it at shutdown."""
    # --- Model loading (mirrors inference.py) ---
    print("⏳ 正在加载 LatentSync 模型...")

    # Default paths, relative to the models/LatentSync root directory.
    unet_config_path = "configs/unet/stage2_512.yaml"
    ckpt_path = "checkpoints/latentsync_unet.pt"

    if not os.path.exists(unet_config_path):
        print(f"⚠️ 找不到配置文件: {unet_config_path},请确保在 models/LatentSync 根目录运行")

    config = OmegaConf.load(unet_config_path)

    # Pick precision and device from the visible hardware; fp16 is only used
    # on GPUs with compute capability > 7.
    cuda_ok = torch.cuda.is_available()
    fp16_ok = cuda_ok and torch.cuda.get_device_capability()[0] > 7
    if fp16_ok:
        dtype = torch.float16
    else:
        dtype = torch.float32
    device = "cuda" if cuda_ok else "cpu"

    if cuda_ok:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"🖥️ 正在使用 GPU: {gpu_name} (CUDA_VISIBLE_DEVICES 已生效)")
    else:
        print("⚠️ 警告: 未检测到 GPU,将使用 CPU 进行推理 (速度极慢)")

    noise_scheduler = DDIMScheduler.from_pretrained("configs")

    # The whisper checkpoint depends on the UNet's cross-attention width.
    whisper_ckpt = (
        "checkpoints/whisper/small.pt"
        if config.model.cross_attention_dim == 768
        else "checkpoints/whisper/tiny.pt"
    )

    audio_encoder = Audio2Feature(
        model_path=whisper_ckpt,
        device=device,
        num_frames=config.data.num_frames,
        audio_feat_length=config.data.audio_feat_length,
    )

    # VAE with SD-style latent scaling factors.
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
    vae.config.scaling_factor = 0.18215
    vae.config.shift_factor = 0

    # The UNet is materialized on CPU first to reduce peak GPU memory use
    # during initialization, then cast to the chosen dtype.
    unet, _ = UNet3DConditionModel.from_pretrained(
        OmegaConf.to_container(config.model),
        ckpt_path,
        device="cpu",
    )
    unet = unet.to(dtype=dtype)

    pipeline = LipsyncPipeline(
        vae=vae,
        audio_encoder=audio_encoder,
        unet=unet,
        scheduler=noise_scheduler,
    ).to(device)

    # DeepCache is enabled by default to speed up repeated UNet passes.
    cache_helper = DeepCacheSDHelper(pipe=pipeline)
    cache_helper.set_params(cache_interval=3, cache_branch_id=0)
    cache_helper.enable()

    models["pipeline"] = pipeline
    models["config"] = config
    models["dtype"] = dtype

    print("✅ LatentSync 模型加载完成,服务就绪!")
    yield

    # Shutdown: drop references and free GPU memory.
    models.clear()
    torch.cuda.empty_cache()
|
||||
|
||||
# FastAPI application wired to the lifespan handler: models are loaded once
# at startup and released on shutdown.
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
class LipSyncRequest(BaseModel):
    """Request body for POST /lipsync."""

    video_path: str      # input video file; must exist on this host
    audio_path: str      # driving audio file; must exist on this host
    video_out_path: str  # path where the synced video will be written
    inference_steps: int = 20    # forwarded as num_inference_steps to the pipeline
    guidance_scale: float = 1.5  # forwarded as guidance_scale to the pipeline
    seed: int = 1247             # RNG seed; -1 requests a non-deterministic seed
    temp_dir: str = "temp"       # scratch directory for intermediate files
|
||||
|
||||
@app.get("/health")
def health_check():
    """Report process liveness and whether the lipsync pipeline is ready."""
    pipeline_ready = "pipeline" in models
    return {"status": "ok", "model_loaded": pipeline_ready}
|
||||
|
||||
@app.post("/lipsync")
async def generate_lipsync(req: LipSyncRequest):
    """Run LatentSync lip-sync inference for a single video/audio pair.

    Returns ``{"status": "success", "output_path": ...}`` on success.

    Raises:
        HTTPException 503: models are not loaded yet.
        HTTPException 404: input video or audio path does not exist.
        HTTPException 500: inference failed or produced no output file.
    """
    if "pipeline" not in models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not os.path.exists(req.video_path):
        raise HTTPException(status_code=404, detail=f"Video not found: {req.video_path}")
    if not os.path.exists(req.audio_path):
        raise HTTPException(status_code=404, detail=f"Audio not found: {req.audio_path}")

    print(f"🎬 收到任务: {Path(req.video_path).name} -> {Path(req.video_out_path).name}")

    try:
        pipeline = models["pipeline"]
        config = models["config"]
        dtype = models["dtype"]

        # Seed handling: -1 requests a fresh, non-deterministic seed.
        if req.seed != -1:
            set_seed(req.seed)
        else:
            torch.seed()

        # NOTE(review): this call is synchronous and blocks the event loop for
        # the entire inference; acceptable for a single-worker GPU service,
        # but consider run_in_executor if concurrent requests are expected.
        pipeline(
            video_path=req.video_path,
            audio_path=req.audio_path,
            video_out_path=req.video_out_path,
            num_frames=config.data.num_frames,
            num_inference_steps=req.inference_steps,
            guidance_scale=req.guidance_scale,
            weight_dtype=dtype,
            width=config.data.resolution,
            height=config.data.resolution,
            mask_image_path=config.data.mask_image_path,
            temp_dir=req.temp_dir,
        )

        if os.path.exists(req.video_out_path):
            return {"status": "success", "output_path": req.video_out_path}
        else:
            raise HTTPException(status_code=500, detail="Output file generation failed")

    except HTTPException:
        # Bug fix: without this clause the "output missing" HTTPException above
        # was caught by the generic handler below and re-wrapped, mangling its
        # detail into "500: Output file generation failed".
        raise
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
if __name__ == "__main__":
    # Run the service directly with uvicorn, listening on all interfaces
    # at port 8007 (the port the backend expects for LatentSync).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8007)
|
||||
Reference in New Issue
Block a user