Init: 初版代码
This commit is contained in:
39
backend/.env.example
Normal file
39
backend/.env.example
Normal file
@@ -0,0 +1,39 @@
|
||||
# ViGent 环境配置示例
|
||||
# 复制此文件为 .env 并填入实际值
|
||||
|
||||
# 调试模式
|
||||
DEBUG=true
|
||||
|
||||
# Redis 配置 (Celery 任务队列)
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# =============== TTS 配置 ===============
|
||||
# 默认 TTS 音色
|
||||
DEFAULT_TTS_VOICE=zh-CN-YunxiNeural
|
||||
|
||||
# =============== MuseTalk 配置 ===============
|
||||
# GPU 选择 (0=第一块GPU, 1=第二块GPU)
|
||||
MUSETALK_GPU_ID=1
|
||||
|
||||
# 使用本地模式 (true) 或远程 API (false)
|
||||
MUSETALK_LOCAL=true
|
||||
|
||||
# 远程 API 地址 (仅 MUSETALK_LOCAL=false 时使用)
|
||||
# MUSETALK_API_URL=http://localhost:8001
|
||||
|
||||
# 模型版本 (v1 或 v15,推荐 v15)
|
||||
MUSETALK_VERSION=v15
|
||||
|
||||
# 推理批次大小 (根据 GPU 显存调整,RTX 3090 可用 8-16)
|
||||
MUSETALK_BATCH_SIZE=8
|
||||
|
||||
# 使用半精度加速 (推荐开启,减少显存占用)
|
||||
MUSETALK_USE_FLOAT16=true
|
||||
|
||||
# =============== 上传配置 ===============
|
||||
# 最大上传文件大小 (MB)
|
||||
MAX_UPLOAD_SIZE_MB=500
|
||||
|
||||
# =============== FFmpeg 配置 ===============
|
||||
# FFmpeg 路径 (如果不在系统 PATH 中)
|
||||
# FFMPEG_PATH=/usr/bin/ffmpeg
|
||||
0
backend/app/__init__.py
Normal file
0
backend/app/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
53
backend/app/api/materials.py
Normal file
53
backend/app/api/materials.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from app.core.config import settings
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/")
|
||||
async def upload_material(file: UploadFile = File(...)):
|
||||
if not file.filename.lower().endswith(('.mp4', '.mov', '.avi')):
|
||||
raise HTTPException(400, "Invalid format")
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
ext = Path(file.filename).suffix
|
||||
save_path = settings.UPLOAD_DIR / "materials" / f"{file_id}{ext}"
|
||||
|
||||
# Save file
|
||||
with open(save_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
# Calculate size
|
||||
size_mb = save_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": file.filename,
|
||||
"path": f"uploads/materials/{file_id}{ext}",
|
||||
"size_mb": size_mb,
|
||||
"type": "video"
|
||||
}
|
||||
|
||||
@router.get("/")
|
||||
async def list_materials():
|
||||
materials_dir = settings.UPLOAD_DIR / "materials"
|
||||
files = []
|
||||
if materials_dir.exists():
|
||||
for f in materials_dir.glob("*"):
|
||||
try:
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"id": f.stem,
|
||||
"name": f.name,
|
||||
"path": f"uploads/materials/{f.name}",
|
||||
"size_mb": stat.st_size / (1024 * 1024),
|
||||
"type": "video",
|
||||
"created_at": stat.st_ctime
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
# Sort by creation time desc
|
||||
files.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
||||
return {"materials": files}
|
||||
59
backend/app/api/publish.py
Normal file
59
backend/app/api/publish.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
发布管理 API
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from app.services.publish_service import PublishService
|
||||
|
||||
router = APIRouter()
|
||||
publish_service = PublishService()
|
||||
|
||||
class PublishRequest(BaseModel):
    """Request payload for publishing a rendered video to one platform."""

    video_path: str  # path of the video file to upload
    platform: str  # platform id, e.g. "douyin" (see PublishService.PLATFORMS)
    title: str  # post title
    tags: List[str] = []  # optional hashtag list
    description: str = ""  # optional post description
    publish_time: Optional[datetime] = None  # None = publish immediately


class PublishResponse(BaseModel):
    """Outcome of a publish attempt."""

    success: bool  # whether the publish reported success
    message: str  # human-readable status detail from the service
    platform: str  # echoes the requested platform id
    url: Optional[str] = None  # URL of the published post, when available
|
||||
|
||||
@router.post("/", response_model=PublishResponse)
|
||||
async def publish_video(request: PublishRequest, background_tasks: BackgroundTasks):
|
||||
try:
|
||||
result = await publish_service.publish(
|
||||
video_path=request.video_path,
|
||||
platform=request.platform,
|
||||
title=request.title,
|
||||
tags=request.tags,
|
||||
description=request.description,
|
||||
publish_time=request.publish_time
|
||||
)
|
||||
return PublishResponse(
|
||||
success=result.get("success", False),
|
||||
message=result.get("message", ""),
|
||||
platform=request.platform,
|
||||
url=result.get("url")
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"发布失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.get("/platforms")
|
||||
async def list_platforms():
|
||||
return {"platforms": [{"id": pid, **pinfo} for pid, pinfo in publish_service.PLATFORMS.items()]}
|
||||
|
||||
@router.get("/accounts")
|
||||
async def list_accounts():
|
||||
return {"accounts": publish_service.get_accounts()}
|
||||
|
||||
@router.post("/login/{platform}")
|
||||
async def login_platform(platform: str):
|
||||
return await publish_service.login(platform)
|
||||
85
backend/app/api/videos.py
Normal file
85
backend/app/api/videos.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
import traceback
|
||||
from app.services.tts_service import TTSService
|
||||
from app.services.video_service import VideoService
|
||||
from app.services.lipsync_service import LipSyncService
|
||||
from app.core.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Input for video generation: script text, TTS voice, source material."""

    text: str  # script to be spoken
    voice: str = "zh-CN-YunxiNeural"  # edge-tts voice id
    material_path: str  # source video; absolute, or relative to the project root


# In-memory task store: task_id -> status record. Lost on process restart.
tasks = {}
|
||||
|
||||
async def _process_video_generation(task_id: str, req: GenerateRequest):
    """Background pipeline: TTS -> lip-sync -> final composition.

    Mutates the module-level `tasks[task_id]` record with status/progress/
    message as it advances. Never raises: failures are recorded on the task
    record (status "failed") with the traceback in "error".
    """
    try:
        # Resolve the material path against the project root if it's relative.
        input_material_path = Path(req.material_path)
        if not input_material_path.is_absolute():
            input_material_path = settings.BASE_DIR.parent / req.material_path

        tasks[task_id]["status"] = "processing"
        tasks[task_id]["progress"] = 5
        tasks[task_id]["message"] = "Initializing generation..."

        # 1. TTS: synthesize the narration audio.
        tasks[task_id]["message"] = "Generating Audio (TTS)..."
        tts = TTSService()
        audio_path = settings.OUTPUT_DIR / f"{task_id}_audio.mp3"
        await tts.generate_audio(req.text, req.voice, str(audio_path))

        tasks[task_id]["progress"] = 30

        # 2. LipSync: drive the material video with the generated audio.
        tasks[task_id]["message"] = "Synthesizing Video (MuseTalk)..."
        lipsync = LipSyncService()
        lipsync_video_path = settings.OUTPUT_DIR / f"{task_id}_lipsync.mp4"

        # Only run lip-sync when the backend reports healthy; otherwise the
        # original material passes through unchanged.
        if await lipsync.check_health():
            await lipsync.generate(str(input_material_path), str(audio_path), str(lipsync_video_path))
        else:
            # Skip lipsync if not available
            import shutil
            shutil.copy(str(input_material_path), lipsync_video_path)

        tasks[task_id]["progress"] = 80

        # 3. Composition: mux the (possibly looped) video with the audio.
        tasks[task_id]["message"] = "Final compositing..."
        video = VideoService()
        final_output = settings.OUTPUT_DIR / f"{task_id}_output.mp4"
        await video.compose(str(lipsync_video_path), str(audio_path), str(final_output))

        tasks[task_id]["status"] = "completed"
        tasks[task_id]["progress"] = 100
        tasks[task_id]["message"] = "Generation Complete!"
        tasks[task_id]["output"] = str(final_output)
        # Served via the /outputs static mount.
        tasks[task_id]["download_url"] = f"/outputs/{final_output.name}"

    except Exception as e:
        tasks[task_id]["status"] = "failed"
        tasks[task_id]["message"] = f"Error: {str(e)}"
        tasks[task_id]["error"] = traceback.format_exc()
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_video(req: GenerateRequest, background_tasks: BackgroundTasks):
|
||||
task_id = str(uuid.uuid4())
|
||||
tasks[task_id] = {"status": "pending", "task_id": task_id}
|
||||
background_tasks.add_task(_process_video_generation, task_id, req)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
async def get_task(task_id: str):
|
||||
return tasks.get(task_id, {"status": "not_found"})
|
||||
|
||||
@router.get("/tasks")
|
||||
async def list_tasks():
|
||||
return {"tasks": list(tasks.values())}
|
||||
0
backend/app/core/__init__.py
Normal file
0
backend/app/core/__init__.py
Normal file
36
backend/app/core/config.py
Normal file
36
backend/app/core/config.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings; values can be overridden via env vars / .env."""

    # Base paths (this file lives in backend/app/core, so BASE_DIR is backend/app)
    BASE_DIR: Path = Path(__file__).resolve().parent.parent
    UPLOAD_DIR: Path = BASE_DIR.parent / "uploads"
    OUTPUT_DIR: Path = BASE_DIR.parent / "outputs"

    # Database / cache
    REDIS_URL: str = "redis://localhost:6379/0"
    DEBUG: bool = True

    # TTS configuration
    DEFAULT_TTS_VOICE: str = "zh-CN-YunxiNeural"
    MAX_UPLOAD_SIZE_MB: int = 500

    # MuseTalk configuration
    MUSETALK_GPU_ID: int = 1  # GPU id (defaults to the second GPU)
    MUSETALK_LOCAL: bool = True  # local inference (False = remote API)
    MUSETALK_API_URL: str = "http://localhost:8001"  # remote API address
    MUSETALK_VERSION: Literal["v1", "v15"] = "v15"  # model version
    MUSETALK_BATCH_SIZE: int = 8  # inference batch size
    MUSETALK_USE_FLOAT16: bool = True  # half-precision inference

    @property
    def MUSETALK_DIR(self) -> Path:
        """MuseTalk checkout directory (computed from BASE_DIR)."""
        return self.BASE_DIR.parent.parent / "models" / "MuseTalk"

    class Config:
        env_file = ".env"
        extra = "ignore"  # ignore unknown environment variables


settings = Settings()
|
||||
32
backend/app/main.py
Normal file
32
backend/app/main.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core import config
|
||||
from app.api import materials, videos, publish
|
||||
|
||||
# Shared settings instance (paths, feature flags).
settings = config.settings

app = FastAPI(title="ViGent TalkingHead Agent")

# CORS: wide open (any origin/method/header).
# NOTE(review): tighten allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ensure storage directories exist before routes/static mounts use them.
settings.UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
settings.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
(settings.UPLOAD_DIR / "materials").mkdir(exist_ok=True)

# Rendered videos are served directly as static files.
app.mount("/outputs", StaticFiles(directory=str(settings.OUTPUT_DIR)), name="outputs")

app.include_router(materials.router, prefix="/api/materials", tags=["Materials"])
app.include_router(videos.router, prefix="/api/videos", tags=["Videos"])
app.include_router(publish.router, prefix="/api/publish", tags=["Publish"])


@app.get("/health")
def health():
    """Liveness probe."""
    return {"status": "ok"}
|
||||
0
backend/app/services/__init__.py
Normal file
0
backend/app/services/__init__.py
Normal file
448
backend/app/services/lipsync_service.py
Normal file
448
backend/app/services/lipsync_service.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
唇形同步服务
|
||||
支持本地 MuseTalk 推理 (Python API) 或远程 MuseTalk API
|
||||
配置为使用 GPU1 (CUDA:1)
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional, Any
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Pin MuseTalk to the configured GPU. Must be set before torch is imported,
# since CUDA reads CUDA_VISIBLE_DEVICES at initialization time. setdefault
# keeps any value already present in the environment.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(settings.MUSETALK_GPU_ID))
|
||||
|
||||
|
||||
class LipSyncService:
    """Lip-sync service — MuseTalk integration.

    Execution strategies, in order, when `use_local` is True:
      1. in-process Python-API inference (`_local_generate_api`),
      2. a `scripts.inference` CLI subprocess (`_local_generate_subprocess`).
    When `use_local` is False, a remote MuseTalk HTTP API is used.
    Every strategy falls back to copying the input video unchanged rather
    than raising, so callers always get a playable output file.
    """

    def __init__(self):
        self.use_local = settings.MUSETALK_LOCAL
        self.api_url = settings.MUSETALK_API_URL
        self.version = settings.MUSETALK_VERSION
        self.musetalk_dir = settings.MUSETALK_DIR

        # Model handles — lazily populated by _load_models().
        self._model_loaded = False
        self._vae = None
        self._unet = None
        self._pe = None
        self._whisper = None
        self._audio_processor = None
        self._face_parser = None
        self._device = None

        # Cached runtime probes (None = not yet checked).
        self._gpu_available: Optional[bool] = None
        self._weights_available: Optional[bool] = None

    def _check_gpu(self) -> bool:
        """Check whether torch + a CUDA GPU are available; result is cached."""
        if self._gpu_available is not None:
            return self._gpu_available

        try:
            import torch
            self._gpu_available = torch.cuda.is_available()
            if self._gpu_available:
                # Device 0: CUDA_VISIBLE_DEVICES already remapped the chosen GPU.
                device_name = torch.cuda.get_device_name(0)
                logger.info(f"✅ GPU 可用: {device_name}")
            else:
                logger.warning("⚠️ GPU 不可用,将使用 Fallback 模式")
        except ImportError:
            self._gpu_available = False
            logger.warning("⚠️ PyTorch 未安装,将使用 Fallback 模式")

        return self._gpu_available

    def _check_weights(self) -> bool:
        """Check that the MuseTalk weight directories exist; result is cached."""
        if self._weights_available is not None:
            return self._weights_available

        # Key weight directories required for inference.
        required_dirs = [
            self.musetalk_dir / "models" / "musetalkV15",
            self.musetalk_dir / "models" / "whisper",
        ]

        self._weights_available = all(d.exists() for d in required_dirs)

        if self._weights_available:
            logger.info("✅ MuseTalk 权重文件已就绪")
        else:
            missing = [str(d) for d in required_dirs if not d.exists()]
            logger.warning(f"⚠️ 缺少权重文件: {missing}")

        return self._weights_available

    def _load_models(self):
        """Lazily load the MuseTalk models for Python-API inference.

        Returns True when all models are loaded onto the device, False on
        any failure (no GPU, missing weights, or an import/load error).
        """
        if self._model_loaded:
            return True

        if not self._check_gpu() or not self._check_weights():
            return False

        logger.info("🔄 加载 MuseTalk 模型到 GPU...")

        try:
            # Make the MuseTalk checkout importable.
            if str(self.musetalk_dir) not in sys.path:
                sys.path.insert(0, str(self.musetalk_dir))
                logger.debug(f"Added to sys.path: {self.musetalk_dir}")

            import torch
            from omegaconf import OmegaConf
            from transformers import WhisperModel

            # MuseTalk modules (resolvable only after the sys.path insert).
            from musetalk.utils.utils import load_all_model
            from musetalk.utils.audio_processor import AudioProcessor
            from musetalk.utils.face_parsing import FaceParsing

            # With CUDA_VISIBLE_DEVICES set, the selected GPU appears as cuda:0.
            self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

            # Weight locations inside the MuseTalk checkout.
            unet_model_path = str(self.musetalk_dir / "models" / "musetalkV15" / "unet.pth")
            unet_config = str(self.musetalk_dir / "models" / "musetalk" / "config.json")
            whisper_dir = str(self.musetalk_dir / "models" / "whisper")

            self._vae, self._unet, self._pe = load_all_model(
                unet_model_path=unet_model_path,
                vae_type="sd-vae",
                unet_config=unet_config,
                device=self._device
            )

            # Optional half precision to reduce VRAM usage.
            if settings.MUSETALK_USE_FLOAT16:
                self._pe = self._pe.half()
                self._vae.vae = self._vae.vae.half()
                self._unet.model = self._unet.model.half()

            # Move everything onto the inference device.
            self._pe = self._pe.to(self._device)
            self._vae.vae = self._vae.vae.to(self._device)
            self._unet.model = self._unet.model.to(self._device)

            # Whisper audio encoder, frozen, matching the UNet's dtype.
            weight_dtype = self._unet.model.dtype
            self._whisper = WhisperModel.from_pretrained(whisper_dir)
            self._whisper = self._whisper.to(device=self._device, dtype=weight_dtype).eval()
            self._whisper.requires_grad_(False)

            # Audio feature extractor.
            self._audio_processor = AudioProcessor(feature_extractor_path=whisper_dir)

            # Face parser (the v15 variant accepts extra tuning parameters).
            if self.version == "v15":
                self._face_parser = FaceParsing(
                    left_cheek_width=90,
                    right_cheek_width=90
                )
            else:
                self._face_parser = FaceParsing()

            self._model_loaded = True
            logger.info("✅ MuseTalk 模型加载完成")
            return True

        except Exception as e:
            logger.error(f"❌ MuseTalk 模型加载失败: {e}")
            import traceback
            logger.debug(traceback.format_exc())
            return False

    async def generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int = 25
    ) -> str:
        """Produce a lip-synced video at `output_path`; returns that path.

        Dispatches between in-process inference, the CLI subprocess, and the
        remote API depending on configuration and local availability.
        """
        logger.info(f"🎬 唇形同步任务: {Path(video_path).name} + {Path(audio_path).name}")
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        # Pick the execution strategy.
        if self.use_local:
            if self._load_models():
                return await self._local_generate_api(video_path, audio_path, output_path, fps)
            else:
                logger.warning("⚠️ 本地推理失败,尝试 subprocess 方式")
                return await self._local_generate_subprocess(video_path, audio_path, output_path, fps)
        else:
            return await self._remote_generate(video_path, audio_path, output_path, fps)

    async def _local_generate_api(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int
    ) -> str:
        """Run MuseTalk inference in-process (models must already be loaded).

        Pipeline: extract frames -> whisper audio features -> face landmarks
        -> VAE latents -> batched UNet inference -> blend mouths back into
        the original frames -> encode video and mux the audio with ffmpeg.
        """
        import torch
        import cv2
        import copy
        import glob
        import pickle
        import numpy as np
        from tqdm import tqdm

        from musetalk.utils.utils import get_file_type, get_video_fps, datagen
        from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
        from musetalk.utils.blending import get_image

        logger.info("🔄 开始 MuseTalk 推理 (Python API)...")

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpdir = Path(tmpdir)
            result_img_dir = tmpdir / "frames"
            result_img_dir.mkdir()

            # 1. Extract input frames.
            logger.info("📹 提取视频帧...")
            if get_file_type(video_path) == "video":
                frames_dir = tmpdir / "input_frames"
                frames_dir.mkdir()
                cmd = f'ffmpeg -v fatal -i "{video_path}" -start_number 0 "{frames_dir}/%08d.png"'
                subprocess.run(cmd, shell=True, check=True)
                input_img_list = sorted(glob.glob(str(frames_dir / "*.png")))
                video_fps = get_video_fps(video_path)
            else:
                # Single still image: animate it at the requested fps.
                input_img_list = [video_path]
                video_fps = fps

            # 2. Extract audio features.
            logger.info("🎵 提取音频特征...")
            whisper_input_features, librosa_length = self._audio_processor.get_audio_feature(audio_path)
            weight_dtype = self._unet.model.dtype
            whisper_chunks = self._audio_processor.get_whisper_chunk(
                whisper_input_features,
                self._device,
                weight_dtype,
                self._whisper,
                librosa_length,
                fps=video_fps,
                audio_padding_length_left=2,
                audio_padding_length_right=2,
            )

            # 3. Preprocess frames: face landmarks and bounding boxes.
            logger.info("🧑 检测人脸关键点...")
            coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift=0)

            # 4. Encode face crops into VAE latents.
            logger.info("🔢 编码图像潜在表示...")
            input_latent_list = []
            for bbox, frame in zip(coord_list, frame_list):
                # Frames without a detected face are skipped.
                if bbox == coord_placeholder:
                    continue
                x1, y1, x2, y2 = bbox
                if self.version == "v15":
                    # v15 extends the crop slightly below the chin.
                    y2 = min(y2 + 10, frame.shape[0])
                crop_frame = frame[y1:y2, x1:x2]
                crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
                latents = self._vae.get_latents_for_unet(crop_frame)
                input_latent_list.append(latents)

            # Ping-pong the lists so a short clip can cover long audio
            # without a visible jump at the loop seam.
            frame_list_cycle = frame_list + frame_list[::-1]
            coord_list_cycle = coord_list + coord_list[::-1]
            input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

            # 5. Batched inference.
            logger.info("🤖 执行 MuseTalk 推理...")
            timesteps = torch.tensor([0], device=self._device)
            batch_size = settings.MUSETALK_BATCH_SIZE
            video_num = len(whisper_chunks)

            gen = datagen(
                whisper_chunks=whisper_chunks,
                vae_encode_latents=input_latent_list_cycle,
                batch_size=batch_size,
                delay_frame=0,
                device=self._device,
            )

            res_frame_list = []
            total = int(np.ceil(float(video_num) / batch_size))

            with torch.no_grad():
                for i, (whisper_batch, latent_batch) in enumerate(tqdm(gen, total=total, desc="推理")):
                    audio_feature_batch = self._pe(whisper_batch)
                    latent_batch = latent_batch.to(dtype=self._unet.model.dtype)
                    pred_latents = self._unet.model(
                        latent_batch, timesteps, encoder_hidden_states=audio_feature_batch
                    ).sample
                    recon = self._vae.decode_latents(pred_latents)
                    for res_frame in recon:
                        res_frame_list.append(res_frame)

            # 6. Blend the generated mouth regions back into the originals.
            logger.info("🖼️ 合成结果帧...")
            for i, res_frame in enumerate(tqdm(res_frame_list, desc="合成")):
                bbox = coord_list_cycle[i % len(coord_list_cycle)]
                ori_frame = copy.deepcopy(frame_list_cycle[i % len(frame_list_cycle)])
                x1, y1, x2, y2 = bbox

                if self.version == "v15":
                    y2 = min(y2 + 10, ori_frame.shape[0])

                try:
                    res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
                except:
                    # Degenerate bbox (e.g. no face) — drop this frame.
                    continue

                if self.version == "v15":
                    combine_frame = get_image(
                        ori_frame, res_frame, [x1, y1, x2, y2],
                        mode="jaw", fp=self._face_parser
                    )
                else:
                    combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=self._face_parser)

                cv2.imwrite(str(result_img_dir / f"{i:08d}.png"), combine_frame)

            # 7. Encode the frames back into a video.
            logger.info("🎬 合成最终视频...")
            temp_video = tmpdir / "temp_video.mp4"
            cmd_video = f'ffmpeg -y -v warning -r {video_fps} -f image2 -i "{result_img_dir}/%08d.png" -vcodec libx264 -vf format=yuv420p -crf 18 "{temp_video}"'
            subprocess.run(cmd_video, shell=True, check=True)

            # 8. Mux in the driving audio.
            cmd_audio = f'ffmpeg -y -v warning -i "{audio_path}" -i "{temp_video}" -c:v copy -c:a aac -shortest "{output_path}"'
            subprocess.run(cmd_audio, shell=True, check=True)

        logger.info(f"✅ 唇形同步完成: {output_path}")
        return output_path

    async def _local_generate_subprocess(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int
    ) -> str:
        """Run MuseTalk via its CLI (`python -m scripts.inference`).

        Used when in-process loading failed. Falls back to copying the
        input video when weights are missing or the CLI fails.
        """
        logger.info("🔄 使用 subprocess 调用 MuseTalk...")

        # No weights at all: fall back immediately.
        if not self._check_weights():
            logger.warning("⚠️ 权重不存在,使用 Fallback 模式")
            shutil.copy(video_path, output_path)
            return output_path

        with tempfile.TemporaryDirectory() as tmpdir:
            # The CLI is driven by a YAML task config file.
            config_path = Path(tmpdir) / "inference_config.yaml"
            config_content = f"""
task1:
  video_path: "{video_path}"
  audio_path: "{audio_path}"
  result_name: "output.mp4"
"""
            config_path.write_text(config_content)

            result_dir = Path(tmpdir) / "results"
            result_dir.mkdir()

            cmd = [
                sys.executable, "-m", "scripts.inference",
                "--version", self.version,
                "--inference_config", str(config_path),
                "--result_dir", str(result_dir),
                "--gpu_id", "0",  # CUDA_VISIBLE_DEVICES already selects the GPU
            ]

            if settings.MUSETALK_USE_FLOAT16:
                cmd.append("--use_float16")

            result = subprocess.run(
                cmd,
                cwd=str(self.musetalk_dir),
                capture_output=True,
                text=True,
                env={**os.environ, "CUDA_VISIBLE_DEVICES": str(settings.MUSETALK_GPU_ID)}
            )

            if result.returncode != 0:
                logger.error(f"MuseTalk CLI 失败: {result.stderr}")
                # Fallback: pass the original video through unchanged.
                shutil.copy(video_path, output_path)
                return output_path

            # Locate the CLI's output file.
            output_files = list(result_dir.rglob("*.mp4"))
            if output_files:
                shutil.copy(output_files[0], output_path)
                logger.info(f"✅ 唇形同步完成: {output_path}")
            else:
                logger.warning("⚠️ 未找到输出文件,使用 Fallback")
                shutil.copy(video_path, output_path)

        return output_path

    async def _remote_generate(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        fps: int
    ) -> str:
        """Call a remote MuseTalk API service and save the returned video.

        Falls back to copying the input video on any network/API failure.
        """
        logger.info(f"📡 调用远程 API: {self.api_url}")

        try:
            async with httpx.AsyncClient(timeout=300.0) as client:
                # Upload both media files in one multipart request.
                with open(video_path, "rb") as vf, open(audio_path, "rb") as af:
                    files = {
                        "video": (Path(video_path).name, vf, "video/mp4"),
                        "audio": (Path(audio_path).name, af, "audio/mpeg"),
                    }
                    data = {"fps": fps}

                    response = await client.post(
                        f"{self.api_url}/lipsync",
                        files=files,
                        data=data
                    )

                if response.status_code == 200:
                    # Response body is the finished video; save it verbatim.
                    with open(output_path, "wb") as f:
                        f.write(response.content)
                    logger.info(f"✅ 远程推理完成: {output_path}")
                    return output_path
                else:
                    raise RuntimeError(f"API 错误: {response.status_code} - {response.text}")

        except Exception as e:
            logger.error(f"远程 API 调用失败: {e}")
            # Fallback: pass the original video through unchanged.
            shutil.copy(video_path, output_path)
            return output_path

    async def check_health(self) -> bool:
        """Report whether the configured lip-sync backend is usable."""
        if self.use_local:
            gpu_ok = self._check_gpu()
            weights_ok = self._check_weights()
            return gpu_ok and weights_ok
        else:
            try:
                async with httpx.AsyncClient(timeout=5.0) as client:
                    response = await client.get(f"{self.api_url}/health")
                    return response.status_code == 200
            except:
                return False
|
||||
71
backend/app/services/publish_service.py
Normal file
71
backend/app/services/publish_service.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
发布服务 (Playwright)
|
||||
"""
|
||||
from playwright.async_api import async_playwright
|
||||
from pathlib import Path
|
||||
import json
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from app.core.config import settings
|
||||
|
||||
class PublishService:
    """Social-platform publishing via Playwright browser automation.

    Cookies captured from a manual login are persisted per platform under
    BASE_DIR/cookies and their presence is used as the "logged in" signal.
    """

    # Supported platforms: id -> display name and creator-studio URL.
    PLATFORMS = {
        "douyin": {"name": "抖音", "url": "https://creator.douyin.com/"},
        "xiaohongshu": {"name": "小红书", "url": "https://creator.xiaohongshu.com/"},
        "weixin": {"name": "微信视频号", "url": "https://channels.weixin.qq.com/"},
        "kuaishou": {"name": "快手", "url": "https://cp.kuaishou.com/"},
        "bilibili": {"name": "B站", "url": "https://member.bilibili.com/platform/upload/video/frame"},
    }

    def __init__(self):
        # Cookie storage lives next to the app code.
        self.cookies_dir = settings.BASE_DIR / "cookies"
        self.cookies_dir.mkdir(exist_ok=True)

    def get_accounts(self):
        """Return one account record per platform.

        "logged_in" is inferred from the presence of a saved cookie file —
        it does not verify the cookies are still valid on the platform.
        """
        accounts = []
        for pid, pinfo in self.PLATFORMS.items():
            cookie_file = self.cookies_dir / f"{pid}_cookies.json"
            accounts.append({
                "platform": pid,
                "name": pinfo["name"],
                "logged_in": cookie_file.exists(),
                "enabled": True
            })
        return accounts

    async def login(self, platform: str):
        """Open a visible browser on the platform's creator page, give the
        user 45 seconds to log in manually, then persist the session cookies.

        Raises ValueError for an unknown platform id. Returns a
        {"success": bool, "message": str} dict.
        """
        if platform not in self.PLATFORMS:
            raise ValueError("Unsupported platform")

        pinfo = self.PLATFORMS[platform]
        logger.info(f"Logging in to {platform}...")

        async with async_playwright() as p:
            # headless=False: the user must interact with the login page.
            browser = await p.chromium.launch(headless=False)
            context = await browser.new_context()
            page = await context.new_page()

            await page.goto(pinfo["url"])
            logger.info("Please login manually in the browser window...")

            # Wait for user input (naive check via title or url change, or explicit timeout)
            # For simplicity in restore, wait for 60s or until manually closed?
            # In a real API, this blocks.
            # We implemented a simplistic wait in the previous iteration.
            try:
                await page.wait_for_timeout(45000)  # Give user 45s to login
                cookies = await context.cookies()
                cookie_path = self.cookies_dir / f"{platform}_cookies.json"
                with open(cookie_path, "w") as f:
                    json.dump(cookies, f)
                return {"success": True, "message": f"Login {platform} successful"}
            except Exception as e:
                return {"success": False, "message": str(e)}
            finally:
                await browser.close()

    async def publish(self, video_path: str, platform: str, title: str, **kwargs):
        """Publish a video to a platform. Currently a MOCK.

        Sleeps 2 seconds and reports success; a real implementation
        requires complex per-platform upload selectors.
        """
        await asyncio.sleep(2)
        return {"success": True, "message": f"Published to {platform} (Mock)", "url": ""}
|
||||
33
backend/app/services/tts_service.py
Normal file
33
backend/app/services/tts_service.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
TTS 服务 (EdgeTTS)
|
||||
"""
|
||||
import edge_tts
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
class TTSService:
    """Text-to-speech backed by Microsoft Edge TTS.

    `VOICES` maps an edge-tts voice id to a human-readable label; it is the
    catalogue exposed by `list_voices` (it is not enforced by
    `generate_audio`).
    """

    VOICES = {
        "zh-CN-YunxiNeural": "云希 (男, 轻松)",
        "zh-CN-YunjianNeural": "云健 (男, 体育)",
        "zh-CN-YunyangNeural": "云扬 (男, 专业)",
        "zh-CN-XiaoxiaoNeural": "晓晓 (女, 活泼)",
        "zh-CN-XiaoyiNeural": "晓伊 (女, 卡通)",
    }

    async def generate_audio(self, text: str, voice: str, output_path: str) -> str:
        """Synthesize `text` with `voice` into `output_path`.

        Returns `output_path` on success; any edge-tts failure is logged
        and re-raised.
        """
        logger.info(f"TTS Generating: {text[:20]}... ({voice})")
        # Ensure the destination directory exists before edge-tts writes.
        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)

        try:
            synthesizer = edge_tts.Communicate(text, voice)
            await synthesizer.save(output_path)
        except Exception as e:
            logger.error(f"TTS Failed: {e}")
            raise
        # Create SUBTITLES (vtt -> srt conversion logic omitted for brevity in restore)
        return output_path

    async def list_voices(self):
        """Return the supported voices as [{"id": ..., "name": ...}, ...]."""
        catalogue = []
        for voice_id, label in self.VOICES.items():
            catalogue.append({"id": voice_id, "name": label})
        return catalogue
|
||||
95
backend/app/services/video_service.py
Normal file
95
backend/app/services/video_service.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
视频合成服务
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
class VideoService:
    """FFmpeg-based composition of a (looped, if needed) video with audio."""

    def __init__(self):
        pass

    def _run_ffmpeg(self, cmd: list) -> bool:
        """Run an ffmpeg command given as an argv list.

        Returns True on exit code 0, False otherwise. The command is
        executed WITHOUT a shell so paths containing spaces or quotes need
        no manual quoting and cannot inject shell syntax (the previous
        string-join + shell=True approach broke on quoted characters).
        """
        logger.debug(f"FFmpeg CMD: {cmd}")
        try:
            # Synchronous call for BackgroundTasks compatibility.
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                encoding='utf-8',
            )
            if result.returncode != 0:
                logger.error(f"FFmpeg Error: {result.stderr}")
                return False
            return True
        except Exception as e:
            logger.error(f"FFmpeg Exception: {e}")
            return False

    def _get_duration(self, file_path: str) -> float:
        """Return the media duration in seconds via ffprobe.

        Returns 0.0 on any failure (ffprobe missing, unreadable file,
        unparsable output) — callers treat 0.0 as "unknown".
        """
        # Argv list, no shell: safe for arbitrary file names.
        cmd = [
            "ffprobe", "-v", "error",
            "-show_entries", "format=duration",
            "-of", "default=noprint_wrappers=1:nokey=1",
            file_path,
        ]
        try:
            # Synchronous call for BackgroundTasks compatibility.
            result = subprocess.run(cmd, capture_output=True, text=True)
            return float(result.stdout.strip())
        except Exception:
            return 0.0

    async def compose(
        self,
        video_path: str,
        audio_path: str,
        output_path: str,
        subtitle_path: Optional[str] = None
    ) -> str:
        """Combine a video track with an audio track into `output_path`.

        If the audio is longer than the video, the video is looped enough
        times to cover it; the result is trimmed to the shorter stream
        (-shortest). `subtitle_path` is currently ignored (subtitles were
        disabled due to font issues). Raises RuntimeError if ffmpeg fails.
        """
        # Ensure output dir exists.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        video_duration = self._get_duration(video_path)
        audio_duration = self._get_duration(audio_path)

        # Loop the video if the audio outlasts it.
        loop_count = 1
        if audio_duration > video_duration and video_duration > 0:
            loop_count = int(audio_duration / video_duration) + 1

        cmd = ["ffmpeg", "-y"]

        # Input video (-stream_loop must precede its -i).
        if loop_count > 1:
            cmd.extend(["-stream_loop", str(loop_count)])
        cmd.extend(["-i", video_path])

        # Input audio.
        cmd.extend(["-i", audio_path])

        # Subtitles intentionally skipped (previous state: font issues).
        # if subtitle_path: ...

        # Re-encode video, encode audio to AAC, trim to the shorter stream.
        cmd.extend(["-c:v", "libx264", "-c:a", "aac", "-shortest"])
        # Video from input 0, audio from input 1.
        cmd.extend(["-map", "0:v", "-map", "1:a"])

        cmd.append(output_path)

        if self._run_ffmpeg(cmd):
            return output_path
        else:
            raise RuntimeError("FFmpeg composition failed")
|
||||
20
backend/requirements.txt
Normal file
20
backend/requirements.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
# ViGent Backend 依赖
|
||||
# MuseTalk 依赖请参考: models/MuseTalk/DEPLOY.md
|
||||
|
||||
fastapi>=0.109.0
|
||||
uvicorn[standard]>=0.27.0
|
||||
python-multipart>=0.0.6
|
||||
pydantic>=2.5.3
|
||||
pydantic-settings>=2.1.0
|
||||
celery>=5.3.6
|
||||
redis>=5.0.1
|
||||
edge-tts>=6.1.9
|
||||
ffmpeg-python>=0.2.0
|
||||
httpx>=0.26.0
|
||||
aiofiles>=23.2.1
|
||||
sqlalchemy>=2.0.25
|
||||
aiosqlite>=0.19.0
|
||||
python-dotenv>=1.0.0
|
||||
loguru>=0.7.2
|
||||
playwright>=1.40.0
|
||||
requests>=2.31.0
|
||||
Reference in New Issue
Block a user