340 lines
12 KiB
Python
340 lines
12 KiB
Python
import asyncio
import os
import shutil
import time
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional

import httpx
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
from loguru import logger
from pydantic import BaseModel

from app.core.config import settings
from app.core.deps import get_current_user
from app.services.lipsync_service import LipSyncService
from app.services.storage import storage_service
from app.services.tts_service import TTSService
from app.services.video_service import VideoService
from app.services.voice_clone_service import voice_clone_service
|
||
|
||
# Router collecting the endpoints of this module: generation, task polling,
# service health checks and management of generated videos.
router: APIRouter = APIRouter()
|
||
|
||
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``tts_mode`` selects the speech backend. ``"voiceclone"`` uses voice
    cloning and then requires ``ref_audio_id`` and ``ref_text``. Any other
    value falls back to EdgeTTS with ``voice``.
    """

    # Script that will be spoken in the generated video
    text: str
    # EdgeTTS voice name (only used when not in voice-clone mode)
    voice: str = "zh-CN-YunxiNeural"
    # Source material to lip-sync: a URL or a local path
    material_path: str

    # --- voice-clone mode fields ---
    tts_mode: str = "edgetts"  # "edgetts" | "voiceclone"
    ref_audio_id: Optional[str] = None  # storage path of the reference audio
    ref_text: Optional[str] = None  # transcript of the reference audio
|
||
|
||
# In-memory task registry: task_id -> task state dict. It lives only in this
# process, so tasks are lost on restart and not shared across processes.
tasks: dict = {}

# Lazily created LipSync client (see _get_lipsync_service), plus the cached
# result and timestamp of its last health check (see _check_lipsync_ready).
_lipsync_service: Optional[LipSyncService] = None
_lipsync_ready: Optional[bool] = None
_lipsync_last_check: float = 0
|
||
|
||
def _get_lipsync_service() -> LipSyncService:
    """Return the shared LipSyncService, constructing it on the first call.

    Reusing one instance avoids re-initializing the client on every request.
    """
    global _lipsync_service
    if _lipsync_service is not None:
        return _lipsync_service
    _lipsync_service = LipSyncService()
    return _lipsync_service
|
||
|
||
async def _check_lipsync_ready(force: bool = False, cache_ttl: float = 300.0) -> bool:
    """Return whether the LipSync service reports itself ready.

    The result of ``check_health()`` is cached for ``cache_ttl`` seconds
    (default 5 minutes), so each generation does not trigger a fresh probe.

    Args:
        force: bypass the cache and always query the service.
        cache_ttl: maximum age, in seconds, of a cached result.

    Returns:
        The service's ``"ready"`` flag as a bool (False when absent).
    """
    global _lipsync_ready, _lipsync_last_check

    # Monotonic clock: the cache age must not be skewed by wall-clock changes.
    now = time.monotonic()
    if not force and _lipsync_ready is not None and (now - _lipsync_last_check) < cache_ttl:
        return _lipsync_ready

    health = await _get_lipsync_service().check_health()
    _lipsync_ready = bool(health.get("ready", False))
    _lipsync_last_check = now
    logger.info(f"[LipSync] Health check: ready={_lipsync_ready}")
    return _lipsync_ready
|
||
|
||
async def _download_material(path_or_url: str, temp_path: Path):
    """Fetch a material file (input video or reference audio) into ``temp_path``.

    ``http://`` / ``https://`` URLs are streamed to disk chunk by chunk, which
    keeps memory usage flat. Anything else is treated as a local path. Relative
    paths are resolved against ``settings.BASE_DIR.parent``.

    Raises:
        FileNotFoundError: the local path does not exist.
        httpx.HTTPStatusError: the URL answered with an error status.
        httpx.TimeoutException: connecting took too long or the stream stalled.
    """
    # NOTE(review): path_or_url can come straight from the request body
    # (GenerateRequest.material_path). As written, a caller can make the
    # server copy any readable local file or fetch any reachable URL.
    # Consider restricting it to the upload dir / storage host.
    if path_or_url.startswith(("http://", "https://")):
        # httpx's read timeout applies to each socket read, not to the whole
        # transfer, so large files are fine. Bounding connect/read only stops
        # a dead host or a stalled stream from hanging the task forever.
        timeout = httpx.Timeout(None, connect=30.0, read=300.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream("GET", path_or_url) as resp:
                resp.raise_for_status()
                with open(temp_path, "wb") as f:
                    async for chunk in resp.aiter_bytes():
                        f.write(chunk)
        return

    # Local file (legacy relative path or absolute path)
    src = Path(path_or_url)
    if not src.is_absolute():
        src = settings.BASE_DIR.parent / path_or_url
    if not src.exists():
        raise FileNotFoundError(f"Material not found: {path_or_url}")

    # shutil.copy blocks. Run it in the default executor so copying a large
    # video does not stall the event loop.
    await asyncio.get_running_loop().run_in_executor(None, shutil.copy, src, temp_path)
|
||
|
||
async def _generate_speech(
    task_id: str,
    req: GenerateRequest,
    temp_dir: Path,
    audio_path: Path,
    temp_files: list,
) -> None:
    """Synthesize ``req.text`` into ``audio_path`` using the backend chosen by ``req.tts_mode``.

    ``"voiceclone"`` downloads the reference audio from the ``ref-audios`` bucket
    and calls the voice-clone service. Any other mode uses EdgeTTS with
    ``req.voice``. Temp files created here are appended to ``temp_files`` so
    the caller's cleanup removes them.

    Raises:
        ValueError: voice-clone mode without ``ref_audio_id`` or ``ref_text``.
    """
    task = tasks[task_id]
    if req.tts_mode != "voiceclone":
        # EdgeTTS mode (default)
        task["message"] = "正在生成语音 (EdgeTTS)..."
        await TTSService().generate_audio(req.text, req.voice, str(audio_path))
        return

    # Voice-clone mode
    if not req.ref_audio_id or not req.ref_text:
        raise ValueError("声音克隆模式需要提供参考音频和参考文字")

    task["message"] = "正在下载参考音频..."
    ref_audio_local = temp_dir / f"{task_id}_ref.wav"
    temp_files.append(ref_audio_local)
    ref_audio_url = await storage_service.get_signed_url(
        bucket="ref-audios",
        path=req.ref_audio_id,
    )
    await _download_material(ref_audio_url, ref_audio_local)

    task["message"] = "正在克隆声音 (Qwen3-TTS)..."
    await voice_clone_service.generate_audio(
        text=req.text,
        ref_audio_path=str(ref_audio_local),
        ref_text=req.ref_text,
        output_path=str(audio_path),
        language="Chinese",
    )


async def _process_video_generation(task_id: str, req: GenerateRequest, user_id: str):
    """Run the whole generation pipeline for one task (scheduled as a background task).

    Pipeline stages:

    1. Download the material.
    2. Run TTS.
    3. Run LipSync (LatentSync), or copy the material unchanged when the
       service is not ready.
    4. Compose the video with the audio.
    5. Upload the result to ``{user_id}/{task_id}_output.mp4`` in the outputs
       bucket and store a signed URL.

    Status, progress and messages are written to ``tasks[task_id]``. Failures
    are not raised. Instead they are recorded on the task: status ``"failed"``,
    a message, and the traceback under ``"error"``. All temp files are removed
    in every case.
    """
    temp_files: list = []  # every temp path created here; removed in `finally`
    try:
        task = tasks[task_id]
        loop = asyncio.get_running_loop()
        start_time = time.monotonic()

        task.update(status="processing", progress=5, message="正在下载素材...")

        temp_dir = settings.UPLOAD_DIR / "temp"
        temp_dir.mkdir(parents=True, exist_ok=True)

        # 0. Download material
        input_material_path = temp_dir / f"{task_id}_input.mp4"
        temp_files.append(input_material_path)
        await _download_material(req.material_path, input_material_path)

        # 1. TTS: progress 5% -> 25%
        task.update(message="正在生成语音...", progress=10)
        audio_path = temp_dir / f"{task_id}_audio.wav"
        temp_files.append(audio_path)

        # Timed separately so the figure excludes the material download. In
        # voice-clone mode it does include the reference-audio download.
        tts_start = time.monotonic()
        await _generate_speech(task_id, req, temp_dir, audio_path, temp_files)
        logger.info(f"[Pipeline] TTS completed in {time.monotonic() - tts_start:.1f}s")
        task["progress"] = 25

        # 2. LipSync: progress 25% -> 85%
        task.update(message="正在合成唇形 (LatentSync)...", progress=30)
        lipsync = _get_lipsync_service()
        lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
        temp_files.append(lipsync_video_path)

        lipsync_start = time.monotonic()
        if await _check_lipsync_ready():  # cached health-check result
            logger.info("[LipSync] Starting LatentSync inference...")
            task.update(progress=35, message="正在运行 LatentSync 推理...")
            await lipsync.generate(str(input_material_path), str(audio_path), str(lipsync_video_path))
        else:
            # Degrade gracefully: keep the original footage without lip sync
            logger.warning("[LipSync] LatentSync not ready, copying original video")
            task["message"] = "唇形同步不可用,使用原始视频..."
            # Copy off the event loop, since videos can be large
            await loop.run_in_executor(
                None, shutil.copy, str(input_material_path), str(lipsync_video_path)
            )
        logger.info(f"[Pipeline] LipSync completed in {time.monotonic() - lipsync_start:.1f}s")
        task["progress"] = 85

        # 3. Composition: progress 85% -> 100%
        task.update(message="正在合成最终视频...", progress=90)
        final_output_local_path = temp_dir / f"{task_id}_output.mp4"
        temp_files.append(final_output_local_path)
        await VideoService().compose(
            str(lipsync_video_path), str(audio_path), str(final_output_local_path)
        )

        # 4. Upload to storage. The user_id prefix isolates each user's outputs.
        task.update(message="正在上传结果...", progress=95)
        storage_path = f"{user_id}/{task_id}_output.mp4"
        # Read off the event loop: the final video may be large
        file_data = await loop.run_in_executor(None, final_output_local_path.read_bytes)
        await storage_service.upload_file(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=storage_path,
            file_data=file_data,
            content_type="video/mp4",
        )
        del file_data  # drop the video bytes as soon as the upload is done

        signed_url = await storage_service.get_signed_url(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=storage_path,
        )

        # Measured after the upload so the reported time covers the whole job
        total_time = time.monotonic() - start_time
        logger.info(f"[Pipeline] Total generation time: {total_time:.1f}s")

        task.update(
            status="completed",
            progress=100,
            message=f"生成完成!耗时 {total_time:.0f} 秒",
            output=storage_path,
            download_url=signed_url,
        )

    except Exception as e:
        tasks[task_id]["status"] = "failed"
        tasks[task_id]["message"] = f"错误: {str(e)}"
        tasks[task_id]["error"] = traceback.format_exc()
        logger.exception(f"Generate video failed (task {task_id}): {e}")

    finally:
        # Best-effort cleanup: a failed unlink must not mask the task result
        for path in temp_files:
            try:
                path.unlink(missing_ok=True)
            except OSError as exc:
                logger.warning(f"Error cleaning up {path}: {exc}")
|
||
|
||
@router.post("/generate")
async def generate_video(
    req: GenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Register a new generation task for the caller and run it in the background.

    Responds immediately with ``{"task_id": ...}``. The task's progress is then
    available through the ``/tasks/{task_id}`` endpoint.
    """
    owner_id = current_user["id"]
    new_task_id = str(uuid.uuid4())
    tasks[new_task_id] = dict(
        status="pending",
        task_id=new_task_id,
        progress=0,
        user_id=owner_id,
    )
    background_tasks.add_task(_process_video_generation, new_task_id, req, owner_id)
    return {"task_id": new_task_id}
|
||
|
||
@router.get("/tasks/{task_id}")
async def get_task(task_id: str, current_user: dict = Depends(get_current_user)):
    """Return the state dict of one generation task owned by the caller.

    Task state includes the signed download URL and, on failure, a server
    traceback, so it is only returned to the task's owner. Unknown ids and
    other users' tasks both yield ``{"status": "not_found"}``, which avoids
    revealing which ids exist.
    """
    task = tasks.get(task_id)
    if task is None or task.get("user_id") != current_user["id"]:
        return {"status": "not_found"}
    return task
|
||
|
||
@router.get("/tasks")
async def list_tasks(current_user: dict = Depends(get_current_user)):
    """List the in-memory generation tasks that belong to the caller.

    Other users' tasks are excluded: they carry signed download URLs and
    error tracebacks.
    """
    user_id = current_user["id"]
    return {"tasks": [task for task in tasks.values() if task.get("user_id") == user_id]}
|
||
|
||
@router.get("/lipsync/health")
async def lipsync_health():
    """Return the LipSync service's ``check_health()`` payload.

    This performs a live check every time. It neither reads nor updates the
    readiness cache used by the generation pipeline.
    """
    return await _get_lipsync_service().check_health()
|
||
|
||
|
||
@router.get("/voiceclone/health")
async def voiceclone_health():
    """Return the voice-clone service's ``check_health()`` payload unchanged."""
    health = await voice_clone_service.check_health()
    return health
|
||
|
||
|
||
def _iso_to_unix(value) -> int:
    """Convert an ISO-8601 timestamp (a trailing ``Z`` is allowed) to Unix seconds; 0 if missing or unparseable."""
    if not value:
        return 0
    try:
        return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp())
    except (AttributeError, TypeError, ValueError):
        return 0


@router.get("/generated")
async def list_generated_videos(current_user: dict = Depends(get_current_user)):
    """List the caller's generated videos from storage, newest first.

    Only ``*_output.mp4`` objects under the ``{user_id}/`` prefix of the
    outputs bucket are returned. ``path`` holds a signed, directly playable
    URL. ``created_at`` is a Unix timestamp (0 when unknown). Any storage
    error is logged and yields an empty list.
    """
    user_id = current_user["id"]
    try:
        # Only list objects under the caller's own prefix
        files_obj = await storage_service.list_files(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=user_id,
        )

        videos = []
        for obj in files_obj:
            name = obj.get("name")
            # Skip folder placeholders and anything that is not a pipeline output
            if not name or name == ".emptyFolderPlaceholder" or not name.endswith("_output.mp4"):
                continue

            signed_url = await storage_service.get_signed_url(
                bucket=storage_service.BUCKET_OUTPUTS,
                path=f"{user_id}/{name}",
            )

            # `metadata` and its `size` may be present but null; one such
            # entry must not fail the whole listing.
            metadata = obj.get("metadata") or {}
            size = metadata.get("size") or 0

            videos.append({
                "id": Path(name).stem,  # file name without the extension
                "name": name,
                "path": signed_url,  # directly playable URL
                "size_mb": size / (1024 * 1024),
                "created_at": _iso_to_unix(obj.get("created_at")),
            })

        # created_at is an int timestamp: sort newest first
        videos.sort(key=lambda video: video["created_at"], reverse=True)
        return {"videos": videos}

    except Exception as e:
        logger.exception(f"List generated videos failed: {e}")
        return {"videos": []}
|
||
|
||
|
||
@router.delete("/generated/{video_id}")
async def delete_generated_video(video_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the caller's generated videos.

    ``video_id`` is the file stem returned by ``GET /generated`` (normally
    ``<task_id>_output``). The deleted object is ``{user_id}/{video_id}.mp4``
    in the outputs bucket.

    Raises:
        HTTPException(400): ``video_id`` is not a single plain path segment.
        HTTPException(500): the storage call failed.
    """
    # Keep the id a single segment so it can only address the caller's own prefix
    if "/" in video_id or "\\" in video_id or ".." in video_id:
        raise HTTPException(400, "无效的视频ID")

    storage_path = f"{current_user['id']}/{video_id}.mp4"
    try:
        await storage_service.delete_file(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=storage_path,
        )
    except Exception as e:
        logger.exception(f"Delete generated video failed ({storage_path}): {e}")
        raise HTTPException(500, f"删除失败: {str(e)}") from e
    return {"success": True, "message": "视频已删除"}
|
||
|