This commit is contained in:
Kevin Wong
2026-02-05 12:03:55 +08:00
parent b2c1042c5c
commit be6a3436bb
75 changed files with 3896 additions and 2900 deletions

View File

View File

View File

@@ -0,0 +1,164 @@
"""
管理员 API用户管理
"""
from fastapi import APIRouter, HTTPException, Depends, status
from pydantic import BaseModel
from typing import Optional, List, Any, cast
from datetime import datetime, timezone, timedelta
from app.core.deps import get_current_admin
from app.core.response import success_response
from app.repositories.sessions import delete_sessions
from app.repositories.users import get_user_by_id, list_users as list_users_repo, update_user
from loguru import logger
router = APIRouter(prefix="/api/admin", tags=["管理"])
class UserListItem(BaseModel):
id: str
phone: str
username: Optional[str]
role: str
is_active: bool
expires_at: Optional[str]
created_at: str
class ActivateRequest(BaseModel):
expires_days: Optional[int] = None # 授权天数None 表示永久
@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
"""获取所有用户列表"""
try:
data = list_users_repo()
return success_response([
UserListItem(
id=u["id"],
phone=u["phone"],
username=u.get("username"),
role=u["role"],
is_active=u["is_active"],
expires_at=u.get("expires_at"),
created_at=u["created_at"]
).model_dump()
for u in data
])
except Exception as e:
logger.error(f"获取用户列表失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="获取用户列表失败"
)
@router.post("/users/{user_id}/activate")
async def activate_user(
user_id: str,
request: ActivateRequest,
admin: dict = Depends(get_current_admin)
):
"""
激活用户
Args:
user_id: 用户 ID
request.expires_days: 授权天数 (None 表示永久)
"""
try:
# 计算过期时间
expires_at = None
if request.expires_days:
expires_at = (datetime.now(timezone.utc) + timedelta(days=request.expires_days)).isoformat()
result = update_user(user_id, {
"is_active": True,
"role": "user",
"expires_at": expires_at
})
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
logger.info(f"管理员 {admin['phone']} 激活用户 {user_id}, 有效期: {request.expires_days or '永久'}")
return success_response(message=f"用户已激活,有效期: {request.expires_days or '永久'}")
except HTTPException:
raise
except Exception as e:
logger.error(f"激活用户失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="激活用户失败"
)
@router.post("/users/{user_id}/deactivate")
async def deactivate_user(
user_id: str,
admin: dict = Depends(get_current_admin)
):
"""停用用户"""
try:
# 不能停用管理员
user = cast(dict[str, Any], get_user_by_id(user_id) or {})
if user.get("role") == "admin":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能停用管理员账号"
)
update_user(user_id, {"is_active": False})
delete_sessions(user_id)
logger.info(f"管理员 {admin['phone']} 停用用户 {user_id}")
return success_response(message="用户已停用")
except HTTPException:
raise
except Exception as e:
logger.error(f"停用用户失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="停用用户失败"
)
@router.post("/users/{user_id}/extend")
async def extend_user(
user_id: str,
request: ActivateRequest,
admin: dict = Depends(get_current_admin)
):
"""延长用户授权期限"""
try:
if not request.expires_days:
# 设为永久
expires_at = None
else:
# 获取当前过期时间
user = cast(dict[str, Any], get_user_by_id(user_id) or {})
if user and user.get("expires_at"):
current_expires = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
base_time = max(current_expires, datetime.now(timezone.utc))
else:
base_time = datetime.now(timezone.utc)
expires_at = (base_time + timedelta(days=request.expires_days)).isoformat()
update_user(user_id, {"expires_at": expires_at})
logger.info(f"管理员 {admin['phone']} 延长用户 {user_id} 授权 {request.expires_days or '永久'}")
return success_response(message=f"授权已延长 {request.expires_days or '永久'}")
except Exception as e:
logger.error(f"延长授权失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="延长授权失败"
)

View File

View File

@@ -0,0 +1,46 @@
"""
AI 相关 API 路由
"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from loguru import logger
from app.services.glm_service import glm_service
from app.core.response import success_response
router = APIRouter(prefix="/api/ai", tags=["AI"])
class GenerateMetaRequest(BaseModel):
"""生成标题标签请求"""
text: str
class GenerateMetaResponse(BaseModel):
"""生成标题标签响应"""
title: str
tags: list[str]
@router.post("/generate-meta")
async def generate_meta(req: GenerateMetaRequest):
"""
AI 生成视频标题和标签
根据口播文案自动生成吸引人的标题和相关标签
"""
if not req.text or not req.text.strip():
raise HTTPException(status_code=400, detail="口播文案不能为空")
try:
logger.info(f"Generating meta for text: {req.text[:50]}...")
result = await glm_service.generate_title_tags(req.text)
return success_response(GenerateMetaResponse(
title=result.get("title", ""),
tags=result.get("tags", [])
).model_dump())
except Exception as e:
logger.error(f"Generate meta failed: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

View File

@@ -0,0 +1,23 @@
from fastapi import APIRouter, Depends
from app.core.deps import get_current_user
from app.services.assets_service import list_styles, list_bgm
from app.core.response import success_response
router = APIRouter()
@router.get("/subtitle-styles")
async def list_subtitle_styles(current_user: dict = Depends(get_current_user)):
return success_response({"styles": list_styles("subtitle")})
@router.get("/title-styles")
async def list_title_styles(current_user: dict = Depends(get_current_user)):
return success_response({"styles": list_styles("title")})
@router.get("/bgm")
async def list_bgm_items(current_user: dict = Depends(get_current_user)):
return success_response({"bgm": list_bgm()})

View File

View File

@@ -0,0 +1,293 @@
"""
认证 API注册、登录、登出、修改密码
"""
from fastapi import APIRouter, HTTPException, Response, status, Request
from pydantic import BaseModel, field_validator
from app.core.security import (
get_password_hash,
verify_password,
create_access_token,
generate_session_token,
set_auth_cookie,
clear_auth_cookie,
decode_access_token
)
from app.repositories.sessions import create_session, delete_sessions
from app.repositories.users import create_user, get_user_by_id, get_user_by_phone, user_exists_by_phone, update_user
from app.core.response import success_response
from loguru import logger
from typing import Optional, Any, cast
import re
router = APIRouter(prefix="/api/auth", tags=["认证"])
class RegisterRequest(BaseModel):
phone: str
password: str
username: Optional[str] = None
@field_validator('phone')
@classmethod
def validate_phone(cls, v):
if not re.match(r'^\d{11}$', v):
raise ValueError('手机号必须是11位数字')
return v
class LoginRequest(BaseModel):
phone: str
password: str
@field_validator('phone')
@classmethod
def validate_phone(cls, v):
if not re.match(r'^\d{11}$', v):
raise ValueError('手机号必须是11位数字')
return v
class ChangePasswordRequest(BaseModel):
old_password: str
new_password: str
@field_validator('new_password')
@classmethod
def validate_new_password(cls, v):
if len(v) < 6:
raise ValueError('新密码长度至少6位')
return v
class UserResponse(BaseModel):
id: str
phone: str
username: Optional[str]
role: str
is_active: bool
expires_at: Optional[str] = None
@router.post("/register")
async def register(request: RegisterRequest):
"""
用户注册
注册后状态为 pending需要管理员激活
"""
try:
if user_exists_by_phone(request.phone):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该手机号已注册"
)
# 创建用户
password_hash = get_password_hash(request.password)
create_user({
"phone": request.phone,
"password_hash": password_hash,
"username": request.username or f"用户{request.phone[-4:]}",
"role": "pending",
"is_active": False
})
logger.info(f"新用户注册: {request.phone}")
return success_response(message="注册成功,请等待管理员审核激活")
except HTTPException:
raise
except Exception as e:
logger.error(f"注册失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="注册失败,请稍后重试"
)
@router.post("/login")
async def login(request: LoginRequest, response: Response):
"""
用户登录
- 验证密码
- 检查是否激活
- 实现"后踢前"单设备登录
"""
try:
user = cast(dict[str, Any], get_user_by_phone(request.phone) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
# 验证密码
if not verify_password(request.password, user["password_hash"]):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
# 检查是否激活
if not user["is_active"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号未激活,请等待管理员审核"
)
# 检查授权是否过期
if user.get("expires_at"):
from datetime import datetime, timezone
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="授权已过期,请联系管理员续期"
)
# 生成新的 session_token (后踢前)
session_token = generate_session_token()
# 删除旧 session插入新 session
delete_sessions(user["id"])
create_session(user["id"], session_token, None)
# 生成 JWT Token
token = create_access_token(user["id"], session_token)
# 设置 HttpOnly Cookie
set_auth_cookie(response, token)
logger.info(f"用户登录: {request.phone}")
return success_response(
data={
"user": UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump()
},
message="登录成功",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"登录失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="登录失败,请稍后重试"
)
@router.post("/logout")
async def logout(response: Response):
"""用户登出"""
clear_auth_cookie(response)
return success_response(message="已登出")
@router.post("/change-password")
async def change_password(request: ChangePasswordRequest, req: Request, response: Response):
"""
修改密码
- 验证当前密码
- 设置新密码
- 重新生成 session token
"""
# 从 Cookie 获取用户
token = req.cookies.get("access_token")
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未登录"
)
token_data = decode_access_token(token)
if not token_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效"
)
try:
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
# 验证当前密码
if not verify_password(request.old_password, user["password_hash"]):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="当前密码错误"
)
# 更新密码
new_password_hash = get_password_hash(request.new_password)
update_user(user["id"], {"password_hash": new_password_hash})
# 生成新的 session token使旧 token 失效
new_session_token = generate_session_token()
delete_sessions(user["id"])
create_session(user["id"], new_session_token, None)
# 生成新的 JWT Token
new_token = create_access_token(user["id"], new_session_token)
set_auth_cookie(response, new_token)
logger.info(f"用户修改密码: {user['phone']}")
return success_response(message="密码修改成功")
except HTTPException:
raise
except Exception as e:
logger.error(f"修改密码失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="修改密码失败,请稍后重试"
)
@router.get("/me")
async def get_me(request: Request):
"""获取当前用户信息"""
# 从 Cookie 获取用户
token = request.cookies.get("access_token")
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未登录"
)
token_data = decode_access_token(token)
if not token_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效"
)
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
return success_response(UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump())

View File

@@ -0,0 +1,221 @@
"""
前端一键扫码登录辅助页面
客户在自己的浏览器中扫码JavaScript自动提取Cookie并上传到服务器
"""
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from app.core.config import settings
router = APIRouter()
@router.get("/login-helper/{platform}", response_class=HTMLResponse)
async def login_helper_page(platform: str, request: Request):
"""
提供一个HTML页面让用户在自己的浏览器中登录平台
登录后JavaScript自动提取Cookie并POST回服务器
"""
platform_urls = {
"bilibili": "https://www.bilibili.com/",
"douyin": "https://creator.douyin.com/",
"xiaohongshu": "https://creator.xiaohongshu.com/"
}
platform_names = {
"bilibili": "B站",
"douyin": "抖音",
"xiaohongshu": "小红书"
}
if platform not in platform_urls:
return "<h1>不支持的平台</h1>"
# 获取服务器地址用于回传Cookie
server_url = str(request.base_url).rstrip('/')
html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{platform_names[platform]} 一键登录</title>
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
margin: 0;
padding: 20px;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}}
.container {{
background: white;
border-radius: 20px;
padding: 50px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
max-width: 700px;
width: 100%;
}}
h1 {{
color: #333;
margin: 0 0 30px 0;
text-align: center;
font-size: 32px;
}}
.step {{
display: flex;
align-items: flex-start;
margin: 25px 0;
padding: 20px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 12px;
border-left: 5px solid #667eea;
}}
.step-number {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
width: 40px;
height: 40px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-weight: bold;
font-size: 20px;
margin-right: 20px;
flex-shrink: 0;
}}
.step-content {{
flex: 1;
}}
.step-title {{
font-weight: 600;
font-size: 18px;
margin-bottom: 8px;
color: #333;
}}
.step-desc {{
color: #666;
line-height: 1.6;
}}
.bookmarklet {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 30px;
border-radius: 10px;
text-decoration: none;
display: inline-block;
font-weight: 600;
font-size: 18px;
margin: 20px 0;
cursor: move;
border: 3px dashed white;
transition: transform 0.2s;
}}
.bookmarklet:hover {{
transform: scale(1.05);
}}
.bookmarklet-container {{
text-align: center;
margin: 30px 0;
padding: 30px;
background: #f8f9fa;
border-radius: 12px;
}}
.instruction {{
font-size: 14px;
color: #666;
margin-top: 10px;
}}
.highlight {{
background: #fff3cd;
padding: 2px 6px;
border-radius: 4px;
font-weight: 600;
}}
.btn {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 15px 40px;
border-radius: 10px;
font-size: 18px;
cursor: pointer;
font-weight: 600;
width: 100%;
margin-top: 20px;
transition: transform 0.2s;
}}
.btn:hover {{
transform: translateY(-2px);
}}
</style>
</head>
<body>
<div class="container">
<h1>🔐 {platform_names[platform]} 一键登录</h1>
<div class="step">
<div class="step-number">1</div>
<div class="step-content">
<div class="step-title">拖拽书签到书签栏</div>
<div class="step-desc">
将下方的"<span class="highlight">保存{platform_names[platform]}登录</span>"按钮拖拽到浏览器书签栏
<br><small>(如果书签栏未显示,按 Ctrl+Shift+B 显示)</small>
</div>
</div>
</div>
<div class="bookmarklet-container">
<a href="javascript:(function(){{var c=document.cookie;if(!c){{alert('请先登录{platform_names[platform]}');return;}}fetch('{server_url}/api/publish/cookies/save/{platform}',{{method:'POST',headers:{{'Content-Type':'application/json'}},body:JSON.stringify({{cookie_string:c}})}}).then(r=>r.json()).then(d=>{{if(d.success){{alert('✅ 登录成功!');window.opener&&window.opener.location.reload();}}else{{alert(''+d.message);}}}}
).catch(e=>alert('提交失败:'+e));}})();"
class="bookmarklet"
onclick="alert('请拖拽此按钮到书签栏,不要点击!'); return false;">
🔖 保存{platform_names[platform]}登录
</a>
<div class="instruction">
⬆️ <strong>拖拽此按钮到浏览器顶部书签栏</strong>
</div>
</div>
<div class="step">
<div class="step-number">2</div>
<div class="step-content">
<div class="step-title">登录 {platform_names[platform]}</div>
<div class="step-desc">
点击下方按钮打开{platform_names[platform]}登录页,扫码登录
</div>
</div>
</div>
<button class="btn" onclick="window.open('{platform_urls[platform]}', 'login_tab')">
🚀 打开{platform_names[platform]}登录页
</button>
<div class="step">
<div class="step-number">3</div>
<div class="step-content">
<div class="step-title">一键保存登录</div>
<div class="step-desc">
登录成功后,点击书签栏的"<span class="highlight">保存{platform_names[platform]}登录</span>"书签
<br>系统会自动提取并保存Cookie完成
</div>
</div>
</div>
<hr style="margin: 40px 0; border: none; border-top: 2px solid #eee;">
<div style="text-align: center; color: #999; font-size: 14px;">
<p>💡 <strong>提示</strong>:书签只需拖拽一次,下次登录直接点击书签即可</p>
<p>🔒 所有数据仅在您的浏览器和服务器之间传输,安全可靠</p>
</div>
</div>
</body>
</html>
"""
return HTMLResponse(content=html_content)

View File

@@ -0,0 +1,416 @@
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, BackgroundTasks, Depends
from app.core.config import settings
from app.core.deps import get_current_user
from app.core.response import success_response
from app.services.storage import storage_service
import re
import time
import traceback
import os
import aiofiles
from pathlib import Path
from loguru import logger
import asyncio
from pydantic import BaseModel
from typing import Optional
import httpx
router = APIRouter()
class RenameMaterialRequest(BaseModel):
new_name: str
def sanitize_filename(filename: str) -> str:
safe_name = re.sub(r'[<>:"/\\|?*]', '_', filename)
if len(safe_name) > 100:
ext = Path(safe_name).suffix
safe_name = safe_name[:100 - len(ext)] + ext
return safe_name
async def process_and_upload(temp_file_path: str, original_filename: str, content_type: str, user_id: str):
"""Background task to strip multipart headers and upload to Supabase"""
try:
logger.info(f"Processing raw upload: {temp_file_path} for user {user_id}")
# 1. Analyze file to find actual video content (strip multipart boundaries)
# This is a simplified manual parser for a SINGLE file upload.
# Structure:
# --boundary
# Content-Disposition: form-data; name="file"; filename="..."
# Content-Type: video/mp4
# \r\n\r\n
# [DATA]
# \r\n--boundary--
# We need to read the first few KB to find the header end
start_offset = 0
end_offset = 0
boundary = b""
file_size = os.path.getsize(temp_file_path)
with open(temp_file_path, 'rb') as f:
# Read first 4KB to find header
head = f.read(4096)
# Find boundary
first_line_end = head.find(b'\r\n')
if first_line_end == -1:
raise Exception("Could not find boundary in multipart body")
boundary = head[:first_line_end] # e.g. --boundary123
logger.info(f"Detected boundary: {boundary}")
# Find end of headers (\r\n\r\n)
header_end = head.find(b'\r\n\r\n')
if header_end == -1:
raise Exception("Could not find end of multipart headers")
start_offset = header_end + 4
logger.info(f"Video data starts at offset: {start_offset}")
# Find end boundary (read from end of file)
# It should be \r\n + boundary + -- + \r\n
# We seek to end-200 bytes
f.seek(max(0, file_size - 200))
tail = f.read()
# The closing boundary is usually --boundary--
# We look for the last occurrence of the boundary
last_boundary_pos = tail.rfind(boundary)
if last_boundary_pos != -1:
# The data ends before \r\n + boundary
# The tail buffer relative position needs to be converted to absolute
end_pos_in_tail = last_boundary_pos
# We also need to check for the preceding \r\n
if end_pos_in_tail >= 2 and tail[end_pos_in_tail-2:end_pos_in_tail] == b'\r\n':
end_pos_in_tail -= 2
# Absolute end offset
end_offset = (file_size - 200) + last_boundary_pos
# Correction for CRLF before boundary
# Actually, simply: read until (file_size - len(tail) + last_boundary_pos) - 2
end_offset = (max(0, file_size - 200) + last_boundary_pos) - 2
else:
logger.warning("Could not find closing boundary, assuming EOF")
end_offset = file_size
logger.info(f"Video data ends at offset: {end_offset}. Total video size: {end_offset - start_offset}")
# 2. Extract and Upload to Supabase
# Since we have the file on disk, we can just pass the file object (seeked) to upload_file?
# Or if upload_file expects bytes/path, checking storage.py...
# It takes `file_data` (bytes) or file-like?
# supabase-py's `upload` method handles parsing if we pass a file object.
# But we need to pass ONLY the video slice.
# So we create a generator or a sliced file object?
# Simpler: Read the slice into memory if < 1GB? Or copy to new temp file?
# Copying to new temp file is safer for memory.
video_path = temp_file_path + "_video.mp4"
with open(temp_file_path, 'rb') as src, open(video_path, 'wb') as dst:
src.seek(start_offset)
# Copy in chunks
bytes_to_copy = end_offset - start_offset
copied = 0
while copied < bytes_to_copy:
chunk_size = min(1024*1024*10, bytes_to_copy - copied) # 10MB chunks
chunk = src.read(chunk_size)
if not chunk:
break
dst.write(chunk)
copied += len(chunk)
logger.info(f"Extracted video content to {video_path}")
# 3. Upload to Supabase with user isolation
timestamp = int(time.time())
safe_name = re.sub(r'[^a-zA-Z0-9._-]', '', original_filename)
# 使用 user_id 作为目录前缀实现隔离
storage_path = f"{user_id}/{timestamp}_{safe_name}"
# Use storage service (this calls Supabase which might do its own http request)
# We read the cleaned video file
with open(video_path, 'rb') as f:
file_content = f.read() # Still reading into memory for simple upload call, but server has 32GB RAM so ok for 500MB
await storage_service.upload_file(
bucket=storage_service.BUCKET_MATERIALS,
path=storage_path,
file_data=file_content,
content_type=content_type
)
logger.info(f"Upload to Supabase complete: {storage_path}")
# Cleanup
os.remove(temp_file_path)
os.remove(video_path)
return storage_path
except Exception as e:
logger.error(f"Background upload processing failed: {e}\n{traceback.format_exc()}")
raise
@router.post("")
async def upload_material(
request: Request,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
user_id = current_user["id"]
logger.info(f"ENTERED upload_material (Streaming Mode) for user {user_id}. Headers: {request.headers}")
filename = "unknown_video.mp4" # Fallback
content_type = "video/mp4"
# Try to parse filename from header if possible (unreliable in raw stream)
# We will rely on post-processing or client hint
# Frontend sends standard multipart.
# Create temp file
timestamp = int(time.time())
temp_filename = f"upload_{timestamp}.raw"
temp_path = os.path.join("/tmp", temp_filename) # Use /tmp on Linux
# Ensure /tmp exists (it does) but verify paths
if os.name == 'nt': # Local dev
temp_path = f"d:/tmp/{temp_filename}"
os.makedirs("d:/tmp", exist_ok=True)
try:
total_size = 0
last_log = 0
async with aiofiles.open(temp_path, 'wb') as f:
async for chunk in request.stream():
await f.write(chunk)
total_size += len(chunk)
# Log progress every 20MB
if total_size - last_log > 20 * 1024 * 1024:
logger.info(f"Receiving stream... Processed {total_size / (1024*1024):.2f} MB")
last_log = total_size
logger.info(f"Stream reception complete. Total size: {total_size} bytes. Saved to {temp_path}")
if total_size == 0:
raise HTTPException(400, "Received empty body")
# Attempt to extract filename from the saved file's first bytes?
# Or just accept it as "uploaded_video.mp4" for now to prove it works.
# We can try to regex the header in the file content we just wrote.
# Implemented in background task to return success immediately.
# Wait, if we return immediately, the user's UI might not show the file yet?
# The prompt says "Wait for upload".
# But to avoid User Waiting Timeout, maybe returning early is better?
# NO, user expects the file to be in the list.
# So we Must await the processing.
# But "Processing" (Strip + Upload to Supabase) takes time.
# Receiving took time.
# If we await Supabase upload, does it timeout?
# Supabase upload is outgoing. Usually faster/stable.
# Let's await the processing to ensure "List Materials" shows it.
# We need to extract the filename for the list.
# Quick extract filename from first 4kb
with open(temp_path, 'rb') as f:
head = f.read(4096).decode('utf-8', errors='ignore')
match = re.search(r'filename="([^"]+)"', head)
if match:
filename = match.group(1)
logger.info(f"Extracted filename from body: {filename}")
# Run processing sync (in await)
storage_path = await process_and_upload(temp_path, filename, content_type, user_id)
# Get signed URL (it exists now)
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_MATERIALS,
path=storage_path
)
size_mb = total_size / (1024 * 1024) # Approximate (includes headers)
# 从 storage_path 提取显示名
display_name = storage_path.split('/')[-1] # 去掉 user_id 前缀
if '_' in display_name:
parts = display_name.split('_', 1)
if parts[0].isdigit():
display_name = parts[1]
return success_response({
"id": storage_path,
"name": display_name,
"path": signed_url,
"size_mb": size_mb,
"type": "video"
})
except Exception as e:
error_msg = f"Streaming upload failed: {str(e)}"
detail_msg = f"Exception: {repr(e)}\nArgs: {e.args}\n{traceback.format_exc()}"
logger.error(error_msg + "\n" + detail_msg)
# Write to debug file
try:
with open("debug_upload.log", "a") as logf:
logf.write(f"\n--- Error at {time.ctime()} ---\n")
logf.write(detail_msg)
logf.write("\n-----------------------------\n")
except:
pass
if os.path.exists(temp_path):
try:
os.remove(temp_path)
except:
pass
raise HTTPException(500, f"Upload failed. Check server logs. Error: {str(e)}")
@router.get("")
async def list_materials(current_user: dict = Depends(get_current_user)):
user_id = current_user["id"]
try:
# 只列出当前用户目录下的文件
files_obj = await storage_service.list_files(
bucket=storage_service.BUCKET_MATERIALS,
path=user_id
)
semaphore = asyncio.Semaphore(8)
async def build_item(f):
name = f.get('name')
if not name or name == '.emptyFolderPlaceholder':
return None
display_name = name
if '_' in name:
parts = name.split('_', 1)
if parts[0].isdigit():
display_name = parts[1]
full_path = f"{user_id}/{name}"
async with semaphore:
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_MATERIALS,
path=full_path
)
metadata = f.get('metadata', {})
size = metadata.get('size', 0)
created_at_str = f.get('created_at', '')
created_at = 0
if created_at_str:
from datetime import datetime
try:
dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = int(dt.timestamp())
except Exception:
pass
return {
"id": full_path,
"name": display_name,
"path": signed_url,
"size_mb": size / (1024 * 1024),
"type": "video",
"created_at": created_at
}
tasks = [build_item(f) for f in files_obj]
results = await asyncio.gather(*tasks, return_exceptions=True)
materials = []
for item in results:
if not item:
continue
if isinstance(item, Exception):
logger.warning(f"Material signed url build failed: {item}")
continue
materials.append(item)
materials.sort(key=lambda x: x['id'], reverse=True)
return success_response({"materials": materials})
except Exception as e:
logger.error(f"List materials failed: {e}")
return success_response({"materials": []}, message="获取素材失败")
@router.delete("/{material_id:path}")
async def delete_material(material_id: str, current_user: dict = Depends(get_current_user)):
user_id = current_user["id"]
# 验证 material_id 属于当前用户
if not material_id.startswith(f"{user_id}/"):
raise HTTPException(403, "无权删除此素材")
try:
await storage_service.delete_file(
bucket=storage_service.BUCKET_MATERIALS,
path=material_id
)
return success_response(message="素材已删除")
except Exception as e:
raise HTTPException(500, f"删除失败: {str(e)}")
@router.put("/{material_id:path}")
async def rename_material(
material_id: str,
payload: RenameMaterialRequest,
current_user: dict = Depends(get_current_user)
):
user_id = current_user["id"]
if not material_id.startswith(f"{user_id}/"):
raise HTTPException(403, "无权重命名此素材")
new_name_raw = payload.new_name.strip() if payload.new_name else ""
if not new_name_raw:
raise HTTPException(400, "新名称不能为空")
old_name = material_id.split("/", 1)[1]
old_ext = Path(old_name).suffix
base_name = Path(new_name_raw).stem if Path(new_name_raw).suffix else new_name_raw
safe_base = sanitize_filename(base_name).strip()
if not safe_base:
raise HTTPException(400, "新名称无效")
new_filename = f"{safe_base}{old_ext}"
prefix = None
if "_" in old_name:
maybe_prefix, _ = old_name.split("_", 1)
if maybe_prefix.isdigit():
prefix = maybe_prefix
if prefix:
new_filename = f"{prefix}_{new_filename}"
new_path = f"{user_id}/{new_filename}"
try:
if new_path != material_id:
await storage_service.move_file(
bucket=storage_service.BUCKET_MATERIALS,
from_path=material_id,
to_path=new_path
)
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_MATERIALS,
path=new_path
)
display_name = new_filename
if "_" in new_filename:
parts = new_filename.split("_", 1)
if parts[0].isdigit():
display_name = parts[1]
return success_response({
"id": new_path,
"name": display_name,
"path": signed_url,
}, message="重命名成功")
except Exception as e:
raise HTTPException(500, f"重命名失败: {str(e)}")

View File

View File

@@ -0,0 +1,141 @@
"""
发布管理 API (支持用户认证)
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
from loguru import logger
from app.services.publish_service import PublishService
from app.core.response import success_response
router = APIRouter()
publish_service = PublishService()
class PublishRequest(BaseModel):
"""Video publish request model"""
video_path: str
platform: str
title: str
tags: List[str] = []
description: str = ""
publish_time: Optional[datetime] = None
class PublishResponse(BaseModel):
"""Video publish response model"""
success: bool
message: str
platform: str
url: Optional[str] = None
# Supported platforms for validation
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
def _get_user_id(request: Request) -> Optional[str]:
"""从请求中获取用户 ID (兼容未登录场景)"""
try:
from app.core.security import decode_access_token
token = request.cookies.get("access_token")
if token:
token_data = decode_access_token(token)
if token_data:
return token_data.user_id
except Exception:
pass
return None
@router.post("")
async def publish_video(request: PublishRequest, req: Request, background_tasks: BackgroundTasks):
"""发布视频到指定平台"""
# Validate platform
if request.platform not in SUPPORTED_PLATFORMS:
raise HTTPException(
status_code=400,
detail=f"不支持的平台: {request.platform}。支持的平台: {', '.join(SUPPORTED_PLATFORMS)}"
)
# 获取用户 ID (可选)
user_id = _get_user_id(req)
try:
result = await publish_service.publish(
video_path=request.video_path,
platform=request.platform,
title=request.title,
tags=request.tags,
description=request.description,
publish_time=request.publish_time,
user_id=user_id
)
message = result.get("message", "")
return success_response(result, message=message)
except Exception as e:
logger.error(f"发布失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/platforms")
async def list_platforms():
return success_response({"platforms": [{**pinfo, "id": pid} for pid, pinfo in publish_service.PLATFORMS.items()]})
@router.get("/accounts")
async def list_accounts(req: Request):
user_id = _get_user_id(req)
return success_response({"accounts": publish_service.get_accounts(user_id)})
@router.post("/login/{platform}")
async def login_platform(platform: str, req: Request):
"""触发平台QR码登录"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
result = await publish_service.login(platform, user_id)
message = result.get("message", "")
return success_response(result, message=message)
@router.post("/logout/{platform}")
async def logout_platform(platform: str, req: Request):
"""注销平台登录"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
result = publish_service.logout(platform, user_id)
message = result.get("message", "")
return success_response(result, message=message)
@router.get("/login/status/{platform}")
async def get_login_status(platform: str, req: Request):
"""检查登录状态 (优先检查活跃的扫码会话)"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
result = publish_service.get_login_session_status(platform, user_id)
message = result.get("message", "")
return success_response(result, message=message)
@router.post("/cookies/save/{platform}")
async def save_platform_cookie(platform: str, cookie_data: dict, req: Request):
"""
保存从客户端浏览器提取的Cookie
Args:
platform: 平台ID
cookie_data: {"cookie_string": "document.cookie的内容"}
"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
cookie_string = cookie_data.get("cookie_string", "")
if not cookie_string:
raise HTTPException(status_code=400, detail="cookie_string 不能为空")
user_id = _get_user_id(req)
result = await publish_service.save_cookie_string(platform, cookie_string, user_id)
message = result.get("message", "")
return success_response(result, message=message)

View File

@@ -0,0 +1,416 @@
"""
参考音频管理 API
支持上传/列表/删除参考音频,用于 Qwen3-TTS 声音克隆
"""
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
from pydantic import BaseModel
from typing import List, Optional
from pathlib import Path
from loguru import logger
import time
import json
import subprocess
import tempfile
import os
import re
from app.core.deps import get_current_user
from app.services.storage import storage_service
from app.core.response import success_response
router = APIRouter()
# 支持的音频格式
ALLOWED_AUDIO_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac'}
# 参考音频 bucket
BUCKET_REF_AUDIOS = "ref-audios"
class RefAudioResponse(BaseModel):
id: str
name: str
path: str # signed URL for playback
ref_text: str
duration_sec: float
created_at: int
class RefAudioListResponse(BaseModel):
items: List[RefAudioResponse]
def sanitize_filename(filename: str) -> str:
"""清理文件名,移除特殊字符"""
safe_name = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
if len(safe_name) > 50:
ext = Path(safe_name).suffix
safe_name = safe_name[:50 - len(ext)] + ext
return safe_name
def get_audio_duration(file_path: str) -> float:
"""获取音频时长 (秒)"""
try:
result = subprocess.run(
['ffprobe', '-v', 'quiet', '-show_entries', 'format=duration',
'-of', 'csv=p=0', file_path],
capture_output=True, text=True, timeout=10
)
return float(result.stdout.strip())
except Exception as e:
logger.warning(f"获取音频时长失败: {e}")
return 0.0
def convert_to_wav(input_path: str, output_path: str) -> bool:
"""将音频转换为 WAV 格式 (16kHz, mono)"""
try:
subprocess.run([
'ffmpeg', '-y', '-i', input_path,
'-ar', '16000', # 16kHz 采样率
'-ac', '1', # 单声道
'-acodec', 'pcm_s16le', # 16-bit PCM
output_path
], capture_output=True, timeout=60, check=True)
return True
except Exception as e:
logger.error(f"音频转换失败: {e}")
return False
@router.post("")
async def upload_ref_audio(
file: UploadFile = File(...),
ref_text: str = Form(...),
user: dict = Depends(get_current_user)
):
"""
上传参考音频
- file: 音频文件 (支持 wav, mp3, m4a, webm 等)
- ref_text: 参考音频的转写文字 (必填)
"""
user_id = user["id"]
if not file.filename:
raise HTTPException(status_code=400, detail="文件名无效")
filename = file.filename
# 验证文件扩展名
ext = Path(filename).suffix.lower()
if ext not in ALLOWED_AUDIO_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"不支持的音频格式: {ext}。支持的格式: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}"
)
# 验证 ref_text
if not ref_text or len(ref_text.strip()) < 2:
raise HTTPException(status_code=400, detail="参考文字不能为空")
try:
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
content = await file.read()
tmp_input.write(content)
tmp_input_path = tmp_input.name
# 转换为 WAV 格式
tmp_wav_path = tmp_input_path + ".wav"
if ext != '.wav':
if not convert_to_wav(tmp_input_path, tmp_wav_path):
raise HTTPException(status_code=500, detail="音频格式转换失败")
else:
# 即使是 wav 也要标准化格式
convert_to_wav(tmp_input_path, tmp_wav_path)
# 获取音频时长
duration = get_audio_duration(tmp_wav_path)
if duration < 1.0:
raise HTTPException(status_code=400, detail="音频时长过短,至少需要 1 秒")
if duration > 60.0:
raise HTTPException(status_code=400, detail="音频时长过长,最多 60 秒")
# 3. 处理重名逻辑 (Friendly Display Name)
original_name = filename
# 获取用户现有的所有参考音频列表 (为了检查文件名冲突)
# 注意: 这种列表方式在文件极多时性能一般,但考虑到单用户参考音频数量有限,目前可行
existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
existing_names = set()
# 预加载所有现有的 display name
# 这里需要并发请求 metadata 可能会慢,优化: 仅检查 metadata 文件并解析
# 简易方案: 仅在 metadata 中读取 original_filename
# 但 list_files 返回的是 name我们需要 metadata
# 考虑到性能,这里使用一种妥协方案:
# 我们不做全量检查,而是简单的检查:如果用户上传 myvoice.wav
# 我们看看有没有 (timestamp)_myvoice.wav 这种其实并不能准确判断 display name 是否冲突
#
# 正确做法: 应该有个数据库表存 metadata。但目前是无数据库设计。
#
# 改用简单方案:
# 既然我们无法快速获取所有 display name
# 我们暂时只处理 "在新上传时original_filename 保持原样"
# 但用户希望 "如果在列表中看到重复的,自动加(1)"
#
# 鉴于无数据库架构的限制,要在上传时知道"已有的 display name" 成本太高(需遍历下载所有json)。
#
# 💡 替代方案:
# 我们不检查旧的。我们只保证**存储**唯一。
# 对于用户提到的 "新上传的文件名后加个数字" -> 这通常是指 "另存为" 的逻辑。
# 既然用户现在的痛点是 "显示了时间戳太丑",而我已经去掉了时间戳显示。
# 那么如果用户上传两个 "TEST.wav",列表里就会有两个 "TEST.wav" (但时间不同)。
# 这其实是可以接受的。
#
# 但如果用户强求 "自动重命名":
# 我们可以在这里做一个轻量级的 "同名检测"
# 检查有没有 *_{original_name} 的文件存在。
# 如果 storage 里已经有 123_abc.wav, 456_abc.wav
# 我们可以认为 abc.wav 已经存在。
dup_count = 0
search_suffix = f"_{original_name}" # 比如 _test.wav
for f in existing_files:
fname = f.get('name', '')
if fname.endswith(search_suffix):
dup_count += 1
final_display_name = original_name
if dup_count > 0:
name_stem = Path(original_name).stem
name_ext = Path(original_name).suffix
final_display_name = f"{name_stem}({dup_count}){name_ext}"
# 生成存储路径 (唯一ID)
timestamp = int(time.time())
safe_name = sanitize_filename(Path(filename).stem)
storage_path = f"{user_id}/{timestamp}_{safe_name}.wav"
# 上传 WAV 文件到 Supabase
with open(tmp_wav_path, 'rb') as f:
wav_data = f.read()
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=storage_path,
file_data=wav_data,
content_type="audio/wav"
)
# 上传元数据 JSON
metadata = {
"ref_text": ref_text.strip(),
"original_filename": final_display_name, # 这里的名字如果有重复会自动加(1)
"duration_sec": duration,
"created_at": timestamp
}
metadata_path = f"{user_id}/{timestamp}_{safe_name}.json"
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=metadata_path,
file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
content_type="application/json"
)
# 获取签名 URL
signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)
# 清理临时文件
os.unlink(tmp_input_path)
if os.path.exists(tmp_wav_path):
os.unlink(tmp_wav_path)
return success_response(RefAudioResponse(
id=storage_path,
name=filename,
path=signed_url,
ref_text=ref_text.strip(),
duration_sec=duration,
created_at=timestamp
).model_dump())
except HTTPException:
raise
except Exception as e:
logger.error(f"上传参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.get("")
async def list_ref_audios(user: dict = Depends(get_current_user)):
"""列出当前用户的所有参考音频"""
user_id = user["id"]
try:
# 列出用户目录下的文件
files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
# 过滤出 .wav 文件并获取对应的 metadata
items = []
for f in files:
name = f.get("name", "")
if not name.endswith(".wav"):
continue
storage_path = f"{user_id}/{name}"
# 尝试读取 metadata
metadata_name = name.replace(".wav", ".json")
metadata_path = f"{user_id}/{metadata_name}"
ref_text = ""
duration_sec = 0.0
created_at = 0
original_filename = ""
try:
# 获取 metadata 内容
metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
import httpx
async with httpx.AsyncClient() as client:
resp = await client.get(metadata_url)
if resp.status_code == 200:
metadata = resp.json()
ref_text = metadata.get("ref_text", "")
duration_sec = metadata.get("duration_sec", 0.0)
created_at = metadata.get("created_at", 0)
original_filename = metadata.get("original_filename", "")
except Exception as e:
logger.warning(f"读取 metadata 失败: {e}")
# 从文件名提取时间戳
try:
created_at = int(name.split("_")[0])
except:
pass
# 获取音频签名 URL
signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)
# 优先显示原始文件名 (去掉时间戳前缀)
display_name = original_filename if original_filename else name
# 如果原始文件名丢失,尝试从现有文件名中通过正则去掉时间戳
if not display_name or display_name == name:
# 匹配 "1234567890_filename.wav"
match = re.match(r'^\d+_(.+)$', name)
if match:
display_name = match.group(1)
items.append(RefAudioResponse(
id=storage_path,
name=display_name,
path=signed_url,
ref_text=ref_text,
duration_sec=duration_sec,
created_at=created_at
))
# 按创建时间倒序排列
items.sort(key=lambda x: x.created_at, reverse=True)
return success_response(RefAudioListResponse(items=items).model_dump())
except Exception as e:
logger.error(f"列出参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
@router.delete("/{audio_id:path}")
async def delete_ref_audio(audio_id: str, user: dict = Depends(get_current_user)):
"""删除参考音频"""
user_id = user["id"]
# 安全检查:确保只能删除自己的文件
if not audio_id.startswith(f"{user_id}/"):
raise HTTPException(status_code=403, detail="无权删除此文件")
try:
# 删除 WAV 文件
await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)
# 删除 metadata JSON
metadata_path = audio_id.replace(".wav", ".json")
try:
await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
except:
pass # metadata 可能不存在
return success_response(message="删除成功")
except Exception as e:
logger.error(f"删除参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
class RenameRequest(BaseModel):
new_name: str
@router.put("/{audio_id:path}")
async def rename_ref_audio(
audio_id: str,
request: RenameRequest,
user: dict = Depends(get_current_user)
):
"""重命名参考音频 (修改 metadata 中的 display name)"""
user_id = user["id"]
# 安全检查
if not audio_id.startswith(f"{user_id}/"):
raise HTTPException(status_code=403, detail="无权修改此文件")
new_name = request.new_name.strip()
if not new_name:
raise HTTPException(status_code=400, detail="新名称不能为空")
# 确保新名称有后缀 (保留原后缀或添加 .wav)
if not Path(new_name).suffix:
new_name += ".wav"
try:
# 1. 下载现有的 metadata
metadata_path = audio_id.replace(".wav", ".json")
try:
# 获取已有的 JSON
import httpx
metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
if not metadata_url:
# 如果 json 不存在,则需要新建一个基础的
raise Exception("Metadata not found")
async with httpx.AsyncClient() as client:
resp = await client.get(metadata_url)
if resp.status_code == 200:
metadata = resp.json()
else:
raise Exception(f"Failed to fetch metadata: {resp.status_code}")
except Exception as e:
logger.warning(f"无法读取元数据: {e}, 将创建新的元数据")
# 兜底:如果读取失败,构建最小元数据
metadata = {
"ref_text": "", # 可能丢失
"duration_sec": 0.0,
"created_at": int(time.time()),
"original_filename": new_name
}
# 2. 更新 original_filename
metadata["original_filename"] = new_name
# 3. 覆盖上传 metadata
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=metadata_path,
file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
content_type="application/json"
)
return success_response({"name": new_name}, message="重命名成功")
except Exception as e:
logger.error(f"重命名失败: {e}")
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")

View File

View File

@@ -0,0 +1,407 @@
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from typing import Optional, Any, cast
import asyncio
import shutil
import os
import time
from pathlib import Path
from loguru import logger
import traceback
import re
import json
import requests
from urllib.parse import unquote
from app.services.whisper_service import whisper_service
from app.services.glm_service import glm_service
from app.core.response import success_response
router = APIRouter()
@router.post("/extract-script")
async def extract_script_tool(
file: Optional[UploadFile] = File(None),
url: Optional[str] = Form(None),
rewrite: bool = Form(True)
):
"""
独立文案提取工具
支持上传视频/音频 OR 输入视频链接 -> 提取文字 -> (可选) AI洗稿
"""
if not file and not url:
raise HTTPException(400, "必须提供文件或视频链接")
temp_path = None
try:
timestamp = int(time.time())
temp_dir = Path("/tmp")
if os.name == 'nt':
temp_dir = Path("d:/tmp")
temp_dir.mkdir(parents=True, exist_ok=True)
# 1. 获取/保存文件
loop = asyncio.get_event_loop()
if file:
filename = file.filename
if not filename:
raise HTTPException(400, "文件名无效")
safe_filename = Path(filename).name.replace(" ", "_")
temp_path = temp_dir / f"tool_extract_{timestamp}_{safe_filename}"
# 文件 I/O 放入线程池
await loop.run_in_executor(None, lambda: shutil.copyfileobj(file.file, open(temp_path, "wb")))
logger.info(f"Tool processing upload file: {temp_path}")
else:
if not url:
raise HTTPException(400, "必须提供视频链接")
url_value: str = url
# URL 下载逻辑
# 自动提取文案中的链接 (支持 Douyin/Bilibili 等分享文案)
url_match = re.search(r'https?://[^\s]+', url_value)
if url_match:
extracted_url = url_match.group(0)
logger.info(f"Extracted URL from text: {extracted_url}")
url_value = extracted_url
logger.info(f"Tool downloading URL: {url_value}")
# 封装 yt-dlp 下载函数 (Blocking)
def _download_yt_dlp():
import yt_dlp
logger.info("Attempting download with yt-dlp...")
ydl_opts = {
'format': 'bestaudio/best',
'outtmpl': str(temp_dir / f"tool_download_{timestamp}_%(id)s.%(ext)s"),
'quiet': True,
'no_warnings': True,
'http_headers': {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Referer': 'https://www.douyin.com/',
}
}
with yt_dlp.YoutubeDL() as ydl_raw:
ydl: Any = ydl_raw
ydl.params.update(ydl_opts)
info = ydl.extract_info(url_value, download=True)
if 'requested_downloads' in info:
downloaded_file = info['requested_downloads'][0]['filepath']
else:
ext = info.get('ext', 'mp4')
id = info.get('id')
downloaded_file = str(temp_dir / f"tool_download_{timestamp}_{id}.{ext}")
return Path(downloaded_file)
# 先尝试 yt-dlp (Run in Executor)
try:
temp_path = await loop.run_in_executor(None, _download_yt_dlp)
logger.info(f"yt-dlp downloaded to: {temp_path}")
except Exception as e:
logger.warning(f"yt-dlp download failed: {e}. Trying manual Douyin fallback...")
# 失败则尝试手动解析 (Douyin Fallback)
if "douyin" in url_value:
manual_path = await download_douyin_manual(url_value, temp_dir, timestamp)
if manual_path:
temp_path = manual_path
logger.info(f"Manual Douyin fallback successful: {temp_path}")
else:
raise HTTPException(400, f"视频下载失败。yt-dlp 报错: {str(e)}")
elif "bilibili" in url_value:
manual_path = await download_bilibili_manual(url_value, temp_dir, timestamp)
if manual_path:
temp_path = manual_path
logger.info(f"Manual Bilibili fallback successful: {temp_path}")
else:
raise HTTPException(400, f"视频下载失败。yt-dlp 报错: {str(e)}")
else:
raise HTTPException(400, f"视频下载失败: {str(e)}")
if not temp_path or not temp_path.exists():
raise HTTPException(400, "文件获取失败")
# 1.5 安全转换: 强制转为 WAV (16k)
import subprocess
audio_path = temp_dir / f"extract_audio_{timestamp}.wav"
def _convert_audio():
try:
convert_cmd = [
'ffmpeg',
'-i', str(temp_path),
'-vn', # 忽略视频
'-acodec', 'pcm_s16le',
'-ar', '16000', # Whisper 推荐采样率
'-ac', '1', # 单声道
'-y', # 覆盖
str(audio_path)
]
# 捕获 stderr
subprocess.run(convert_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return True
except subprocess.CalledProcessError as e:
error_log = e.stderr.decode('utf-8', errors='ignore') if e.stderr else str(e)
logger.error(f"FFmpeg check/convert failed: {error_log}")
# 检查是否为 HTML
head = b""
try:
with open(temp_path, 'rb') as f:
head = f.read(100)
except: pass
if b'<!DOCTYPE html' in head or b'<html' in head:
raise ValueError("HTML_DETECTED")
raise ValueError("CONVERT_FAILED")
# 执行转换 (Run in Executor)
try:
await loop.run_in_executor(None, _convert_audio)
logger.info(f"Converted to WAV: {audio_path}")
target_path = audio_path
except ValueError as ve:
if str(ve) == "HTML_DETECTED":
raise HTTPException(400, "下载的文件是网页而非视频,请重试或手动上传。")
else:
raise HTTPException(400, "下载的文件已损坏或格式无法识别。")
# 2. 提取文案 (Whisper)
script = await whisper_service.transcribe(str(target_path))
# 3. AI 洗稿 (GLM)
rewritten = None
if rewrite:
if script and len(script.strip()) > 0:
logger.info("Rewriting script...")
rewritten = await glm_service.rewrite_script(script)
else:
logger.warning("No script extracted, skipping rewrite")
return success_response({
"original_script": script,
"rewritten_script": rewritten
})
except HTTPException as he:
raise he
except Exception as e:
logger.error(f"Tool extract failed: {e}")
logger.error(traceback.format_exc())
# Friendly error message
msg = str(e)
if "Fresh cookies" in msg:
msg = "下载失败:目标平台开启了反爬验证,请过段时间重试或直接上传视频文件。"
raise HTTPException(500, f"提取失败: {msg}")
finally:
# 清理临时文件
if temp_path and temp_path.exists():
try:
os.remove(temp_path)
logger.info(f"Cleaned up temp file: {temp_path}")
except Exception as e:
logger.warning(f"Failed to cleanup temp file {temp_path}: {e}")
async def download_douyin_manual(url: str, temp_dir: Path, timestamp: int) -> Optional[Path]:
"""
手动下载抖音视频 (Fallback logic - Ported from SuperIPAgent/douyinDownloader)
使用特定的 User Profile URL 和硬编码 Cookie 绕过反爬
"""
logger.info(f"[SuperIPAgent] Starting download for: {url}")
try:
# 1. 提取 Modal ID (支持短链跳转)
headers = {
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
}
# 如果是短链或重定向
resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
final_url = resp.url
logger.info(f"[SuperIPAgent] Final URL: {final_url}")
modal_id = None
match = re.search(r'/video/(\d+)', final_url)
if match:
modal_id = match.group(1)
if not modal_id:
logger.error("[SuperIPAgent] Could not extract modal_id")
return None
logger.info(f"[SuperIPAgent] Extracted modal_id: {modal_id}")
# 2. 构造特定请求 URL (Copy from SuperIPAgent)
# 使用特定用户的 Profile 页 + modal_id 参数,配合特定 Cookie
target_url = f"https://www.douyin.com/user/MS4wLjABAAAAN_s_hups7LD0N4qnrM3o2gI0vuG3pozNaEolz2_py3cHTTrpVr1Z4dukFD9SOlwY?from_tab_name=main&modal_id={modal_id}"
# 3. 使用硬编码 Cookie (Copy from SuperIPAgent)
headers_with_cookie = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"cookie": "douyin.com; device_web_cpu_core=10; device_web_memory_size=8; __ac_nonce=06760391f00b9b51264ae; __ac_signature=_02B4Z6wo00f019a5ceAAAIDAhEZR-X3jjWfWmXVAAJLXd4; ttwid=1%7C7MTKBSMsP4eOv9h5NAh8p0E-NYIud09ftNmB0mjLpWc%7C1734359327%7C8794abeabbd47447e1f56e5abc726be089f2a0344d6343b5f75f23e7b0f0028f; UIFID_TEMP=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff1396912bcb2af71efee56a14a2a9f37b74010d0a0413795262f6d4afe02a032ac7ab; s_v_web_id=verify_m4r4ribr_c7krmY1z_WoeI_43po_ATpO_I4o8U1bex2D7; hevc_supported=true; home_can_add_dy_2_desktop=%220%22; dy_swidth=2560; dy_sheight=1440; stream_recommend_feed_params=%22%7B%5C%22cookie_enabled%5C%22%3Atrue%2C%5C%22screen_width%5C%22%3A2560%2C%5C%22screen_height%5C%22%3A1440%2C%5C%22browser_online%5C%22%3Atrue%2C%5C%22cpu_core_num%5C%22%3A10%2C%5C%22device_memory%5C%22%3A8%2C%5C%22downlink%5C%22%3A10%2C%5C%22effective_type%5C%22%3A%5C%224g%5C%22%2C%5C%22round_trip_time%5C%22%3A50%7D%22; strategyABtestKey=%221734359328.577%22; csrf_session_id=2f53aed9aa6974e83aa9a1014180c3a4; fpk1=U2FsdGVkX1/IpBh0qdmlKAVhGyYHgur4/VtL9AReZoeSxadXn4juKvsakahRGqjxOPytHWspYoBogyhS/V6QSw==; fpk2=0845b309c7b9b957afd9ecf775a4c21f; passport_csrf_token=d80e0c5b2fa2328219856be5ba7e671e; passport_csrf_token_default=d80e0c5b2fa2328219856be5ba7e671e; odin_tt=3c891091d2eb0f4718c1d5645bc4a0017032d4d5aa989decb729e9da2ad570918cbe5e9133dc6b145fa8c758de98efe32ff1f81aa0d611e838cc73ab08ef7d3f6adf66ab4d10e8372ddd628f94f16b8e; volume_info=%7B%22isUserMute%22%3Afalse%2C%22isMute%22%3Afalse%2C%22volume%22%3A0.5%7D; bd_ticket_guard_client_web_domain=2; FORCE_LOGIN=%7B%22videoConsumedRemainSeconds%22%3A180%7D; UIFID=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff139655a3c2b735923234f371c699560c657923fd3d6c5b63ab7bb9b83423b6cb4787e2ce66a7fbc4ecb24c8570f520fe6de068bbb95115023c0c6c1b6ee31b49fb7e3996fb8349f43a3fd8b7a61cd9e18e8fe65eb6a7c13de4c0960d84e344b644725db3eb2fa6b7caf821de1b50527979f2; is_dash_user=1; biz_trace_id=b57a241f; bd_ticket_guard_client_data=eyJiZC10aWNrZXQtZ3VhcmQtdmVyc2lvbiI6MiwiYmQtdGlja2V0LWd1YXJkLWl0ZXJhdGlvbi12ZXJzaW9uIjoxLCJiZC10aWNrZXQtZ3VhcmQtcmVlLXB1YmxpYy1rZXkiOiJCTEo2R0lDalVoWW1XcHpGOFdrN0Vrc0dXcCtaUzNKY1g4NGNGY2k0TTl1TEowNjdUb21mbFU5aDdvWVBGamhNRWNRQWtKdnN1MnM3RmpTWnlJQXpHMjA9IiwiYmQtdGlja2V0LWd1YXJkLXdlYi12ZXJzaW9uIjoyfQ%3D%3D; download_guide=%221%2F20241216%2F0%22; sdk_source_info=7e276470716a68645a606960273f276364697660272927676c715a6d6069756077273f276364697660272927666d776a68605a607d71606b766c6a6b5a7666776c7571273f275e58272927666a6b766a69605a696c6061273f27636469766027292762696a6764695a7364776c6467696076273f275e5827292771273f273d33323131333c3036313632342778; bit_env=RiOY4jzzpxZoVCl6zdVSVhVRjdwHRTxqcqWdqMBZLPGjMdB4Tax1kAELHNTVAAh72KuhumewE4Lq6f0-VJ2UpJrkrhSxoPw9LUb3zQrq1OSwbeSPHkRlRgRQvO89sItdGUyq1oFr0XyRCnMYG87KSeWyc4x0czGR0o50hTDoDLG5rJVoRcdQOLvjiAegsqyytKF59sPX_QM9qffK2SqYsg0hCggURc_AI6kguDDE5DvG0bnyz1utw4z1eEnIoLrkGDqzqBZj4dOAr0BVU6ofbsS-pOQ2u2PM1dLP9FlBVBlVaqYVgHJeSLsR5k76BRTddUjTb4zEilVIEwAMJWGN4I1BxVt6fC9B5tBQpuT0lj3n3eKXCKXZsd8FrEs5_pbfDsxV-e_WMiXI2ff4qxiTC0U73sfo9OpicKICtZjdq8qsHxJuu6wVR36zvXeL2Wch5C6MzprNvkivv0l8nbh2mSgy1nabZr3dmU6NcR-Bg3Q3xTWUlR9aAUmpopC-cNuXjgLpT-Lw1AYGilSUnCvosth1Gfypq-b0MpgmdSDgTrQ%3D; gulu_source_res=eyJwX2luIjoiMDhjOGQ3ZTJiODQyNjZkZWI5Y2VkMGJiODNlNmY1ZWY0ZjMyNTE2ZmYyZjAzNDMzZjI0OWU1Y2Q1NTczNTk5NyJ9; passport_auth_mix_state=hp9bc3dgb1tm5wd8p82zawus27g0e3ue; IsDouyinActive=false",
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
}
logger.info(f"[SuperIPAgent] Requesting page with Cookie...")
# 必须 verify=False 否则有些环境会报错
response = requests.get(target_url, headers=headers_with_cookie, timeout=10)
# 4. 解析 RENDER_DATA
content_match = re.findall(r'<script id="RENDER_DATA" type="application/json">(.*?)</script>', response.text)
if not content_match:
# 尝试解码后再查找?或者结构变了
# 再尝试找 SSR_HYDRATED_DATA
if "SSR_HYDRATED_DATA" in response.text:
content_match = re.findall(r'<script id="SSR_HYDRATED_DATA" type="application/json">(.*?)</script>', response.text)
if not content_match:
logger.error(f"[SuperIPAgent] Could not find RENDER_DATA in page (len={len(response.text)})")
return None
content = unquote(content_match[0])
try:
data = json.loads(content)
except:
logger.error("[SuperIPAgent] JSON decode failed")
return None
# 5. 提取视频流
video_url = None
try:
# 路径通常是: app -> videoDetail -> video -> bitRateList -> playAddr -> src
if "app" in data and "videoDetail" in data["app"]:
info = data["app"]["videoDetail"]["video"]
if "bitRateList" in info and info["bitRateList"]:
video_url = info["bitRateList"][0]["playAddr"][0]["src"]
elif "playAddr" in info and info["playAddr"]:
video_url = info["playAddr"][0]["src"]
except Exception as e:
logger.error(f"[SuperIPAgent] Path extraction failed: {e}")
if not video_url:
logger.error("[SuperIPAgent] No video_url found")
return None
if video_url.startswith("//"):
video_url = "https:" + video_url
logger.info(f"[SuperIPAgent] Found video URL: {video_url[:50]}...")
# 6. 下载 (带 Header)
temp_path = temp_dir / f"douyin_manual_{timestamp}.mp4"
download_headers = {
'Referer': 'https://www.douyin.com/',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
}
dl_resp = requests.get(video_url, headers=download_headers, stream=True, timeout=60)
if dl_resp.status_code == 200:
with open(temp_path, 'wb') as f:
for chunk in dl_resp.iter_content(chunk_size=1024):
f.write(chunk)
logger.info(f"[SuperIPAgent] Downloaded successfully: {temp_path}")
return temp_path
else:
logger.error(f"[SuperIPAgent] Download failed: {dl_resp.status_code}")
return None
except Exception as e:
logger.error(f"[SuperIPAgent] Logic failed: {e}")
return None
async def download_bilibili_manual(url: str, temp_dir: Path, timestamp: int) -> Optional[Path]:
"""
手动下载 Bilibili 视频 (Fallback logic - Playwright Version)
B站通常音视频分离这里只提取音频即可因为只需要文案
"""
from playwright.async_api import async_playwright
logger.info(f"[Playwright] Starting Bilibili download for: {url}")
playwright = None
browser = None
try:
playwright = await async_playwright().start()
# Launch browser (ensure chromium is installed: playwright install chromium)
browser = await playwright.chromium.launch(headless=True, args=['--no-sandbox', '--disable-setuid-sandbox'])
# Mobile User Agent often gives single stream?
# But Bilibili mobile web is tricky. Desktop is fine.
context = await browser.new_context(
user_agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
page = await context.new_page()
# Intercept audio responses?
# Bilibili streams are usually .m4s
# But finding the initial state is easier.
logger.info("[Playwright] Navigating to Bilibili...")
await page.goto(url, timeout=45000)
# Wait for video element (triggers loading)
try:
await page.wait_for_selector('video', timeout=15000)
except:
logger.warning("[Playwright] Video selector timeout")
# 1. Try extracting from __playinfo__
# window.__playinfo__ contains dash streams
playinfo = await page.evaluate("window.__playinfo__")
audio_url = None
if playinfo and "data" in playinfo and "dash" in playinfo["data"]:
dash = playinfo["data"]["dash"]
if "audio" in dash and dash["audio"]:
audio_url = dash["audio"][0]["baseUrl"]
logger.info(f"[Playwright] Found audio stream in __playinfo__: {audio_url[:50]}...")
# 2. If playinfo fails, try extracting video src (sometimes it's a blob, which we can't fetch easily without interception)
# But interception is complex. Let's try requests with Referer if we have URL.
if not audio_url:
logger.warning("[Playwright] Could not find audio in __playinfo__")
return None
# Download the audio stream
temp_path = temp_dir / f"bilibili_audio_{timestamp}.m4s" # usually m4s
try:
api_request = context.request
headers = {
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
"Referer": "https://www.bilibili.com/"
}
logger.info(f"[Playwright] Downloading audio stream...")
response = await api_request.get(audio_url, headers=headers)
if response.status == 200:
body = await response.body()
with open(temp_path, 'wb') as f:
f.write(body)
logger.info(f"[Playwright] Downloaded successfully: {temp_path}")
return temp_path
else:
logger.error(f"[Playwright] API Request failed: {response.status}")
return None
except Exception as e:
logger.error(f"[Playwright] Download logic error: {e}")
return None
except Exception as e:
logger.error(f"[Playwright] Bilibili download failed: {e}")
return None
finally:
if browser:
await browser.close()
if playwright:
await playwright.stop()

View File

View File

@@ -0,0 +1,57 @@
from fastapi import APIRouter, BackgroundTasks, Depends
import uuid
from app.core.deps import get_current_user
from app.core.response import success_response
from .schemas import GenerateRequest
from .task_store import create_task, get_task, list_tasks
from .workflow import process_video_generation, get_lipsync_health, get_voiceclone_health
from .service import list_generated_videos, delete_generated_video
router = APIRouter()
@router.post("/generate")
async def generate_video(
req: GenerateRequest,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
user_id = current_user["id"]
task_id = str(uuid.uuid4())
create_task(task_id, user_id)
background_tasks.add_task(process_video_generation, task_id, req, user_id)
return success_response({"task_id": task_id})
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
return success_response(get_task(task_id))
@router.get("/tasks")
async def list_tasks_view():
return success_response({"tasks": list_tasks()})
@router.get("/lipsync/health")
async def lipsync_health():
return success_response(await get_lipsync_health())
@router.get("/voiceclone/health")
async def voiceclone_health():
return success_response(await get_voiceclone_health())
@router.get("/generated")
async def list_generated(current_user: dict = Depends(get_current_user)):
return success_response(await list_generated_videos(current_user["id"]))
@router.delete("/generated/{video_id}")
async def delete_generated(video_id: str, current_user: dict = Depends(get_current_user)):
result = await delete_generated_video(current_user["id"], video_id)
return success_response(result, message="视频已删除")

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import Optional
class GenerateRequest(BaseModel):
text: str
voice: str = "zh-CN-YunxiNeural"
material_path: str
tts_mode: str = "edgetts"
ref_audio_id: Optional[str] = None
ref_text: Optional[str] = None
title: Optional[str] = None
enable_subtitles: bool = True
subtitle_style_id: Optional[str] = None
title_style_id: Optional[str] = None
subtitle_font_size: Optional[int] = None
title_font_size: Optional[int] = None
bgm_id: Optional[str] = None
bgm_volume: Optional[float] = 0.2

View File

@@ -0,0 +1,87 @@
from fastapi import HTTPException
import asyncio
from pathlib import Path
from loguru import logger
from app.services.storage import storage_service
async def list_generated_videos(user_id: str) -> dict:
"""从 Storage 读取当前用户生成的视频列表"""
try:
files_obj = await storage_service.list_files(
bucket=storage_service.BUCKET_OUTPUTS,
path=user_id
)
semaphore = asyncio.Semaphore(8)
async def build_item(f):
name = f.get("name")
if not name or name == ".emptyFolderPlaceholder":
return None
if not name.endswith("_output.mp4"):
return None
video_id = Path(name).stem
full_path = f"{user_id}/{name}"
async with semaphore:
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_OUTPUTS,
path=full_path
)
metadata = f.get("metadata", {})
size = metadata.get("size", 0)
created_at_str = f.get("created_at", "")
created_at = 0
if created_at_str:
from datetime import datetime
try:
dt = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
created_at = int(dt.timestamp())
except Exception:
pass
return {
"id": video_id,
"name": name,
"path": signed_url,
"size_mb": size / (1024 * 1024),
"created_at": created_at
}
tasks = [build_item(f) for f in files_obj]
results = await asyncio.gather(*tasks, return_exceptions=True)
videos = []
for item in results:
if not item:
continue
if isinstance(item, Exception):
logger.warning(f"Signed url build failed: {item}")
continue
videos.append(item)
videos.sort(key=lambda x: x.get("created_at", ""), reverse=True)
return {"videos": videos}
except Exception as e:
logger.error(f"List generated videos failed: {e}")
return {"videos": []}
async def delete_generated_video(user_id: str, video_id: str) -> dict:
"""删除生成的视频"""
try:
storage_path = f"{user_id}/{video_id}.mp4"
await storage_service.delete_file(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path
)
return {"video_id": video_id}
except Exception as e:
raise HTTPException(500, f"删除失败: {str(e)}")

View File

@@ -0,0 +1,118 @@
from typing import Any, Dict, List
import json
from loguru import logger
from app.core.config import settings
try:
import redis
except Exception: # pragma: no cover - optional dependency
redis = None
class InMemoryTaskStore:
def __init__(self) -> None:
self._tasks: Dict[str, Dict[str, Any]] = {}
def create(self, task_id: str, user_id: str) -> Dict[str, Any]:
task = {
"status": "pending",
"task_id": task_id,
"progress": 0,
"user_id": user_id,
}
self._tasks[task_id] = task
return task
def get(self, task_id: str) -> Dict[str, Any]:
return self._tasks.get(task_id, {"status": "not_found"})
def list(self) -> List[Dict[str, Any]]:
return list(self._tasks.values())
def update(self, task_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
task = self._tasks.get(task_id)
if not task:
task = {"status": "pending", "task_id": task_id}
self._tasks[task_id] = task
task.update(updates)
return task
class RedisTaskStore:
def __init__(self, client: "redis.Redis") -> None:
self._client = client
self._index_key = "vigent:tasks:index"
def _key(self, task_id: str) -> str:
return f"vigent:tasks:{task_id}"
def create(self, task_id: str, user_id: str) -> Dict[str, Any]:
task = {
"status": "pending",
"task_id": task_id,
"progress": 0,
"user_id": user_id,
}
self._client.set(self._key(task_id), json.dumps(task, ensure_ascii=False))
self._client.sadd(self._index_key, task_id)
return task
def get(self, task_id: str) -> Dict[str, Any]:
raw = self._client.get(self._key(task_id))
if not raw:
return {"status": "not_found"}
return json.loads(raw)
def list(self) -> List[Dict[str, Any]]:
task_ids = list(self._client.smembers(self._index_key) or [])
if not task_ids:
return []
keys = [self._key(task_id) for task_id in task_ids]
raw_items = self._client.mget(keys)
tasks = []
for raw in raw_items:
if raw:
try:
tasks.append(json.loads(raw))
except Exception:
continue
return tasks
def update(self, task_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
task = self.get(task_id)
if task.get("status") == "not_found":
task = {"status": "pending", "task_id": task_id}
task.update(updates)
self._client.set(self._key(task_id), json.dumps(task, ensure_ascii=False))
self._client.sadd(self._index_key, task_id)
return task
def _build_task_store():
if redis is None:
logger.warning("Redis not available, using in-memory task store")
return InMemoryTaskStore()
try:
client = redis.Redis.from_url(settings.REDIS_URL, decode_responses=True)
client.ping()
logger.info("Using Redis task store")
return RedisTaskStore(client)
except Exception as e:
logger.warning(f"Redis connection failed, using in-memory task store: {e}")
return InMemoryTaskStore()
task_store = _build_task_store()
def create_task(task_id: str, user_id: str) -> Dict[str, Any]:
return task_store.create(task_id, user_id)
def get_task(task_id: str) -> Dict[str, Any]:
return task_store.get(task_id)
def list_tasks() -> List[Dict[str, Any]]:
return task_store.list()

View File

@@ -0,0 +1,328 @@
from typing import Optional, Any
from pathlib import Path
import time
import traceback
import httpx
from loguru import logger
from app.core.config import settings
from app.services.tts_service import TTSService
from app.services.video_service import VideoService
from app.services.lipsync_service import LipSyncService
from app.services.voice_clone_service import voice_clone_service
from app.services.assets_service import (
get_style,
get_default_style,
resolve_bgm_path,
prepare_style_for_remotion,
)
from app.services.storage import storage_service
from app.services.whisper_service import whisper_service
from app.services.remotion_service import remotion_service
from .schemas import GenerateRequest
from .task_store import task_store
_lipsync_service: Optional[LipSyncService] = None
_lipsync_ready: Optional[bool] = None
_lipsync_last_check: float = 0
def _get_lipsync_service() -> LipSyncService:
"""获取或创建 LipSync 服务实例(单例模式,避免重复初始化)"""
global _lipsync_service
if _lipsync_service is None:
_lipsync_service = LipSyncService()
return _lipsync_service
async def _check_lipsync_ready(force: bool = False) -> bool:
"""检查 LipSync 是否就绪带缓存5分钟内不重复检查"""
global _lipsync_ready, _lipsync_last_check
now = time.time()
if not force and _lipsync_ready is not None and (now - _lipsync_last_check) < 300:
return bool(_lipsync_ready)
lipsync = _get_lipsync_service()
health = await lipsync.check_health()
_lipsync_ready = health.get("ready", False)
_lipsync_last_check = now
print(f"[LipSync] Health check: ready={_lipsync_ready}")
return bool(_lipsync_ready)
async def _download_material(path_or_url: str, temp_path: Path):
"""下载素材到临时文件 (流式下载,节省内存)"""
if path_or_url.startswith("http"):
timeout = httpx.Timeout(None)
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", path_or_url) as resp:
resp.raise_for_status()
with open(temp_path, "wb") as f:
async for chunk in resp.aiter_bytes():
f.write(chunk)
else:
src = Path(path_or_url)
if not src.is_absolute():
src = settings.BASE_DIR.parent / path_or_url
if src.exists():
import shutil
shutil.copy(src, temp_path)
else:
raise FileNotFoundError(f"Material not found: {path_or_url}")
def _update_task(task_id: str, **updates: Any) -> None:
task_store.update(task_id, updates)
async def process_video_generation(task_id: str, req: GenerateRequest, user_id: str):
temp_files = []
try:
start_time = time.time()
_update_task(task_id, status="processing", progress=5, message="正在下载素材...")
temp_dir = settings.UPLOAD_DIR / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
input_material_path = temp_dir / f"{task_id}_input.mp4"
temp_files.append(input_material_path)
await _download_material(req.material_path, input_material_path)
_update_task(task_id, message="正在生成语音...", progress=10)
audio_path = temp_dir / f"{task_id}_audio.wav"
temp_files.append(audio_path)
if req.tts_mode == "voiceclone":
if not req.ref_audio_id or not req.ref_text:
raise ValueError("声音克隆模式需要提供参考音频和参考文字")
_update_task(task_id, message="正在下载参考音频...")
ref_audio_local = temp_dir / f"{task_id}_ref.wav"
temp_files.append(ref_audio_local)
ref_audio_url = await storage_service.get_signed_url(
bucket="ref-audios",
path=req.ref_audio_id
)
await _download_material(ref_audio_url, ref_audio_local)
_update_task(task_id, message="正在克隆声音 (Qwen3-TTS)...")
await voice_clone_service.generate_audio(
text=req.text,
ref_audio_path=str(ref_audio_local),
ref_text=req.ref_text,
output_path=str(audio_path),
language="Chinese"
)
else:
_update_task(task_id, message="正在生成语音 (EdgeTTS)...")
tts = TTSService()
await tts.generate_audio(req.text, req.voice, str(audio_path))
tts_time = time.time() - start_time
print(f"[Pipeline] TTS completed in {tts_time:.1f}s")
_update_task(task_id, progress=25)
_update_task(task_id, message="正在合成唇形 (LatentSync)...", progress=30)
lipsync = _get_lipsync_service()
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
temp_files.append(lipsync_video_path)
lipsync_start = time.time()
is_ready = await _check_lipsync_ready()
if is_ready:
print(f"[LipSync] Starting LatentSync inference...")
_update_task(task_id, progress=35, message="正在运行 LatentSync 推理...")
await lipsync.generate(str(input_material_path), str(audio_path), str(lipsync_video_path))
else:
print(f"[LipSync] LatentSync not ready, copying original video")
_update_task(task_id, message="唇形同步不可用,使用原始视频...")
import shutil
shutil.copy(str(input_material_path), lipsync_video_path)
lipsync_time = time.time() - lipsync_start
print(f"[Pipeline] LipSync completed in {lipsync_time:.1f}s")
_update_task(task_id, progress=80)
captions_path = None
if req.enable_subtitles:
_update_task(task_id, message="正在生成字幕 (Whisper)...", progress=82)
captions_path = temp_dir / f"{task_id}_captions.json"
temp_files.append(captions_path)
try:
await whisper_service.align(
audio_path=str(audio_path),
text=req.text,
output_path=str(captions_path)
)
print(f"[Pipeline] Whisper alignment completed")
except Exception as e:
logger.warning(f"Whisper alignment failed, skipping subtitles: {e}")
captions_path = None
_update_task(task_id, progress=85)
video = VideoService()
final_audio_path = audio_path
if req.bgm_id:
_update_task(task_id, message="正在合成背景音乐...", progress=86)
bgm_path = resolve_bgm_path(req.bgm_id)
if bgm_path:
mix_output_path = temp_dir / f"{task_id}_audio_mix.wav"
temp_files.append(mix_output_path)
volume = req.bgm_volume if req.bgm_volume is not None else 0.2
volume = max(0.0, min(float(volume), 1.0))
try:
video.mix_audio(
voice_path=str(audio_path),
bgm_path=str(bgm_path),
output_path=str(mix_output_path),
bgm_volume=volume
)
final_audio_path = mix_output_path
except Exception as e:
logger.warning(f"BGM mix failed, fallback to voice only: {e}")
else:
logger.warning(f"BGM not found: {req.bgm_id}")
use_remotion = (captions_path and captions_path.exists()) or req.title
subtitle_style = None
title_style = None
if req.enable_subtitles:
subtitle_style = get_style("subtitle", req.subtitle_style_id) or get_default_style("subtitle")
if req.title:
title_style = get_style("title", req.title_style_id) or get_default_style("title")
if req.subtitle_font_size and req.enable_subtitles:
if subtitle_style is None:
subtitle_style = {}
subtitle_style["font_size"] = int(req.subtitle_font_size)
if req.title_font_size and req.title:
if title_style is None:
title_style = {}
title_style["font_size"] = int(req.title_font_size)
if use_remotion:
subtitle_style = prepare_style_for_remotion(
subtitle_style,
temp_dir,
f"{task_id}_subtitle_font"
)
title_style = prepare_style_for_remotion(
title_style,
temp_dir,
f"{task_id}_title_font"
)
final_output_local_path = temp_dir / f"{task_id}_output.mp4"
temp_files.append(final_output_local_path)
if use_remotion:
_update_task(task_id, message="正在合成视频 (Remotion)...", progress=87)
composed_video_path = temp_dir / f"{task_id}_composed.mp4"
temp_files.append(composed_video_path)
await video.compose(str(lipsync_video_path), str(final_audio_path), str(composed_video_path))
remotion_health = await remotion_service.check_health()
if remotion_health.get("ready"):
try:
def on_remotion_progress(percent):
mapped = 87 + int(percent * 0.08)
_update_task(task_id, progress=mapped)
await remotion_service.render(
video_path=str(composed_video_path),
output_path=str(final_output_local_path),
captions_path=str(captions_path) if captions_path else None,
title=req.title,
title_duration=3.0,
fps=25,
enable_subtitles=req.enable_subtitles,
subtitle_style=subtitle_style,
title_style=title_style,
on_progress=on_remotion_progress
)
print(f"[Pipeline] Remotion render completed")
except Exception as e:
logger.warning(f"Remotion render failed, using FFmpeg fallback: {e}")
import shutil
shutil.copy(str(composed_video_path), final_output_local_path)
else:
logger.warning(f"Remotion not ready: {remotion_health.get('error')}, using FFmpeg")
import shutil
shutil.copy(str(composed_video_path), final_output_local_path)
else:
_update_task(task_id, message="正在合成最终视频...", progress=90)
await video.compose(str(lipsync_video_path), str(final_audio_path), str(final_output_local_path))
total_time = time.time() - start_time
_update_task(task_id, message="正在上传结果...", progress=95)
storage_path = f"{user_id}/{task_id}_output.mp4"
with open(final_output_local_path, "rb") as f:
file_data = f.read()
await storage_service.upload_file(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path,
file_data=file_data,
content_type="video/mp4"
)
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path
)
print(f"[Pipeline] Total generation time: {total_time:.1f}s")
_update_task(
task_id,
status="completed",
progress=100,
message=f"生成完成!耗时 {total_time:.0f}",
output=storage_path,
download_url=signed_url,
)
except Exception as e:
_update_task(
task_id,
status="failed",
message=f"错误: {str(e)}",
error=traceback.format_exc(),
)
logger.error(f"Generate video failed: {e}")
finally:
for f in temp_files:
try:
if f.exists():
f.unlink()
except Exception as e:
print(f"Error cleaning up {f}: {e}")
async def get_lipsync_health():
lipsync = _get_lipsync_service()
return await lipsync.check_health()
async def get_voiceclone_health():
return await voice_clone_service.check_health()