Update
@@ -65,14 +65,15 @@ async def lifespan(app: FastAPI):
     # --- Model loading logic (based on inference.py) ---
     print("⏳ Loading LatentSync models...")
 
-    # Default config paths (relative to the repo root)
-    unet_config_path = "configs/unet/stage2_512.yaml"
-    ckpt_path = "checkpoints/latentsync_unet.pt"
+    # Use absolute paths so the service can be started from any directory
+    latentsync_root = Path(__file__).resolve().parent.parent  # scripts -> LatentSync root
+    unet_config_path = latentsync_root / "configs" / "unet" / "stage2_512.yaml"
+    ckpt_path = latentsync_root / "checkpoints" / "latentsync_unet.pt"
 
-    if not os.path.exists(unet_config_path):
-        print(f"⚠️ Config file not found: {unet_config_path}; make sure you are running from the models/LatentSync root directory")
+    if not unet_config_path.exists():
+        print(f"⚠️ Config file not found: {unet_config_path}")
 
-    config = OmegaConf.load(unet_config_path)
+    config = OmegaConf.load(str(unet_config_path))
 
     # Check GPU
     is_fp16_supported = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] > 7
@@ -85,13 +86,13 @@ async def lifespan(app: FastAPI):
     else:
         print("⚠️ Warning: no GPU detected, inference will run on the CPU (extremely slow)")
 
-    scheduler = DDIMScheduler.from_pretrained("configs")
+    scheduler = DDIMScheduler.from_pretrained(str(latentsync_root / "configs"))
 
     # Whisper Model
     if config.model.cross_attention_dim == 768:
-        whisper_path = "checkpoints/whisper/small.pt"
+        whisper_path = str(latentsync_root / "checkpoints" / "whisper" / "small.pt")
     else:
-        whisper_path = "checkpoints/whisper/tiny.pt"
+        whisper_path = str(latentsync_root / "checkpoints" / "whisper" / "tiny.pt")
 
     audio_encoder = Audio2Feature(
         model_path=whisper_path,
@@ -108,7 +109,7 @@ async def lifespan(app: FastAPI):
     # UNet
     unet, _ = UNet3DConditionModel.from_pretrained(
         OmegaConf.to_container(config.model),
-        ckpt_path,
+        str(ckpt_path),
         device="cpu",  # Load to CPU first to save memory during init
     )
     unet = unet.to(dtype=dtype)
@@ -129,6 +130,7 @@ async def lifespan(app: FastAPI):
     models["pipeline"] = pipeline
     models["config"] = config
     models["dtype"] = dtype
+    models["latentsync_root"] = latentsync_root
 
     print("✅ LatentSync models loaded, service is ready!")
     yield
@@ -167,6 +169,7 @@ async def generate_lipsync(req: LipSyncRequest):
     pipeline = models["pipeline"]
     config = models["config"]
     dtype = models["dtype"]
+    latentsync_root = models["latentsync_root"]
 
     # Set seed
     if req.seed != -1:
@@ -185,7 +188,7 @@ async def generate_lipsync(req: LipSyncRequest):
         weight_dtype=dtype,
         width=config.data.resolution,
         height=config.data.resolution,
-        mask_image_path=config.data.mask_image_path,
+        mask_image_path=str(latentsync_root / config.data.mask_image_path),
         temp_dir=req.temp_dir,
     )
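
The whole change follows one pattern: every path that was previously resolved against the current working directory is now anchored to the LatentSync repo root derived from `__file__`, and only converted with `str(...)` at call sites that expect string paths (OmegaConf, the UNet loader). Below is a minimal, self-contained sketch of that pattern in a FastAPI lifespan, assuming the `<LatentSync root>/scripts/<server file>` layout implied by the diff; the path checks and the `models` dict mirror the diff, while the app wiring and names are illustrative rather than the actual server code.

```python
# Minimal sketch (illustrative, not the actual server): resolve all repo paths
# relative to this file instead of the process working directory. Assumed
# layout: <LatentSync root>/scripts/<this file>, as in the diff above.
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI

# scripts/ -> LatentSync root; independent of the directory uvicorn was started from
LATENTSYNC_ROOT = Path(__file__).resolve().parent.parent

models = {}  # shared state populated once at startup


@asynccontextmanager
async def lifespan(app: FastAPI):
    unet_config_path = LATENTSYNC_ROOT / "configs" / "unet" / "stage2_512.yaml"
    ckpt_path = LATENTSYNC_ROOT / "checkpoints" / "latentsync_unet.pt"

    for p in (unet_config_path, ckpt_path):
        if not p.exists():
            print(f"⚠️ Missing file: {p}")

    # Heavy model loading would go here; keep the root around so request
    # handlers can resolve config-relative paths (e.g. mask_image_path) too.
    models["latentsync_root"] = LATENTSYNC_ROOT
    yield
    models.clear()


app = FastAPI(lifespan=lifespan)
```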