更新
This commit is contained in:
0
backend/app/modules/__init__.py
Normal file
0
backend/app/modules/__init__.py
Normal file
0
backend/app/modules/admin/__init__.py
Normal file
0
backend/app/modules/admin/__init__.py
Normal file
164
backend/app/modules/admin/router.py
Normal file
164
backend/app/modules/admin/router.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
管理员 API:用户管理
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Any, cast
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from app.core.deps import get_current_admin
|
||||
from app.core.response import success_response
|
||||
from app.repositories.sessions import delete_sessions
|
||||
from app.repositories.users import get_user_by_id, list_users as list_users_repo, update_user
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["管理"])
|
||||
|
||||
|
||||
class UserListItem(BaseModel):
|
||||
id: str
|
||||
phone: str
|
||||
username: Optional[str]
|
||||
role: str
|
||||
is_active: bool
|
||||
expires_at: Optional[str]
|
||||
created_at: str
|
||||
|
||||
|
||||
class ActivateRequest(BaseModel):
    """Payload for activating or extending a user's authorization."""
    expires_days: Optional[int] = None  # number of days granted; None means permanent
|
||||
|
||||
|
||||
@router.get("/users")
async def list_users(admin: dict = Depends(get_current_admin)):
    """Return every registered user as a list of serialized records (admin only)."""
    try:
        items = []
        for row in list_users_repo():
            item = UserListItem(
                id=row["id"],
                phone=row["phone"],
                username=row.get("username"),
                role=row["role"],
                is_active=row["is_active"],
                expires_at=row.get("expires_at"),
                created_at=row["created_at"],
            )
            items.append(item.model_dump())
        return success_response(items)
    except Exception as e:
        logger.error(f"获取用户列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取用户列表失败"
        )
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/activate")
async def activate_user(
    user_id: str,
    request: ActivateRequest,
    admin: dict = Depends(get_current_admin)
):
    """
    Activate a user account (admin only).

    Args:
        user_id: target user id
        request.expires_days: authorization period in days; None (or 0) means permanent

    Raises:
        HTTPException: 404 if the user does not exist, 500 on unexpected errors.
    """
    try:
        # Compute the expiry timestamp; None keeps the account permanently active.
        expires_at = None
        if request.expires_days:
            expires_at = (datetime.now(timezone.utc) + timedelta(days=request.expires_days)).isoformat()

        result = update_user(user_id, {
            "is_active": True,
            "role": "user",
            "expires_at": expires_at
        })

        if not result:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        # Human-readable validity for logs and the response message
        # (fixes the nonsensical "有效期: 永久 天" wording for permanent grants).
        validity = f"{request.expires_days} 天" if request.expires_days else "永久"
        logger.info(f"管理员 {admin['phone']} 激活用户 {user_id}, 有效期: {validity}")

        return success_response(message=f"用户已激活,有效期: {validity}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"激活用户失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="激活用户失败"
        )
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/deactivate")
async def deactivate_user(
    user_id: str,
    admin: dict = Depends(get_current_admin)
):
    """Deactivate a user account and revoke its sessions; admins cannot be deactivated."""
    try:
        # Guard: administrator accounts must never be deactivated.
        target = cast(dict[str, Any], get_user_by_id(user_id) or {})
        if target.get("role") == "admin":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不能停用管理员账号"
            )

        update_user(user_id, {"is_active": False})
        # Drop every live session so the user is logged out immediately.
        delete_sessions(user_id)

        logger.info(f"管理员 {admin['phone']} 停用用户 {user_id}")
        return success_response(message="用户已停用")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"停用用户失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="停用用户失败"
        )
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/extend")
async def extend_user(
    user_id: str,
    request: ActivateRequest,
    admin: dict = Depends(get_current_admin)
):
    """
    Extend a user's authorization period (admin only).

    request.expires_days days are added on top of the later of "now" and the
    current expiry; None makes the authorization permanent.

    Raises:
        HTTPException: 404 if the user does not exist, 500 on unexpected errors.
    """
    try:
        # Fetch the target first so a missing user yields 404 instead of a silent no-op.
        user = cast(dict[str, Any], get_user_by_id(user_id) or {})
        if not user:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        if not request.expires_days:
            # No day count supplied: grant permanent authorization.
            expires_at = None
        else:
            if user.get("expires_at"):
                # Extend from the current expiry if it is still in the future,
                # otherwise from now (an already-expired grant must not shorten
                # the extension).
                current_expires = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
                base_time = max(current_expires, datetime.now(timezone.utc))
            else:
                base_time = datetime.now(timezone.utc)

            expires_at = (base_time + timedelta(days=request.expires_days)).isoformat()

        update_user(user_id, {"expires_at": expires_at})

        # Human-readable validity (fixes the "延长 ... 永久 天" wording for permanent grants).
        validity = f"{request.expires_days} 天" if request.expires_days else "永久"
        logger.info(f"管理员 {admin['phone']} 延长用户 {user_id} 授权: {validity}")

        return success_response(message=f"授权已延长: {validity}")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"延长授权失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="延长授权失败"
        )
|
||||
0
backend/app/modules/ai/__init__.py
Normal file
0
backend/app/modules/ai/__init__.py
Normal file
46
backend/app/modules/ai/router.py
Normal file
46
backend/app/modules/ai/router.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
AI 相关 API 路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
|
||||
from app.services.glm_service import glm_service
|
||||
from app.core.response import success_response
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/ai", tags=["AI"])
|
||||
|
||||
|
||||
class GenerateMetaRequest(BaseModel):
    """Request body: the spoken-script text to derive a title and tags from."""
    text: str
|
||||
|
||||
|
||||
class GenerateMetaResponse(BaseModel):
    """Response body: AI-generated video title and tag list."""
    title: str
    tags: list[str]
|
||||
|
||||
|
||||
@router.post("/generate-meta")
async def generate_meta(req: GenerateMetaRequest):
    """
    Generate a video title and tags from a spoken script via the GLM service.

    Raises:
        HTTPException: 400 for an empty script, 500 when generation fails.
    """
    if not req.text or not req.text.strip():
        raise HTTPException(status_code=400, detail="口播文案不能为空")

    try:
        logger.info(f"Generating meta for text: {req.text[:50]}...")
        result = await glm_service.generate_title_tags(req.text)
        return success_response(GenerateMetaResponse(
            title=result.get("title", ""),
            tags=result.get("tags", [])
        ).model_dump())
    except Exception as e:
        # Log the full error server-side; do NOT echo str(e) to the client —
        # internal exception text can leak implementation details / secrets.
        logger.error(f"Generate meta failed: {e}")
        raise HTTPException(status_code=500, detail="生成标题标签失败")
|
||||
0
backend/app/modules/assets/__init__.py
Normal file
0
backend/app/modules/assets/__init__.py
Normal file
23
backend/app/modules/assets/router.py
Normal file
23
backend/app/modules/assets/router.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.services.assets_service import list_styles, list_bgm
|
||||
from app.core.response import success_response
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/subtitle-styles")
async def list_subtitle_styles(current_user: dict = Depends(get_current_user)):
    """Return the available subtitle style presets (login required)."""
    styles = list_styles("subtitle")
    return success_response({"styles": styles})
|
||||
|
||||
|
||||
@router.get("/title-styles")
async def list_title_styles(current_user: dict = Depends(get_current_user)):
    """Return the available title style presets (login required)."""
    styles = list_styles("title")
    return success_response({"styles": styles})
|
||||
|
||||
|
||||
@router.get("/bgm")
async def list_bgm_items(current_user: dict = Depends(get_current_user)):
    """Return the available background-music assets (login required)."""
    tracks = list_bgm()
    return success_response({"bgm": tracks})
|
||||
0
backend/app/modules/auth/__init__.py
Normal file
0
backend/app/modules/auth/__init__.py
Normal file
293
backend/app/modules/auth/router.py
Normal file
293
backend/app/modules/auth/router.py
Normal file
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
认证 API:注册、登录、登出、修改密码
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Response, status, Request
|
||||
from pydantic import BaseModel, field_validator
|
||||
from app.core.security import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
generate_session_token,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
decode_access_token
|
||||
)
|
||||
from app.repositories.sessions import create_session, delete_sessions
|
||||
from app.repositories.users import create_user, get_user_by_id, get_user_by_phone, user_exists_by_phone, update_user
|
||||
from app.core.response import success_response
|
||||
from loguru import logger
|
||||
from typing import Optional, Any, cast
|
||||
import re
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
    """Registration payload: 11-digit phone number, password, optional display name."""
    phone: str
    password: str
    username: Optional[str] = None

    @field_validator('phone')
    @classmethod
    def validate_phone(cls, v):
        # Mainland-style mobile number: exactly 11 digits.
        if not re.match(r'^\d{11}$', v):
            raise ValueError('手机号必须是11位数字')
        return v

    @field_validator('password')
    @classmethod
    def validate_password(cls, v):
        # Enforce the same minimum length that ChangePasswordRequest applies to
        # new passwords; previously a user could register with a password that
        # the change-password endpoint would reject.
        if len(v) < 6:
            raise ValueError('密码长度至少6位')
        return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Login payload: 11-digit phone number and password."""
    phone: str
    password: str

    @field_validator('phone')
    @classmethod
    def validate_phone(cls, v):
        # Same phone-format rule as RegisterRequest: exactly 11 digits.
        if not re.match(r'^\d{11}$', v):
            raise ValueError('手机号必须是11位数字')
        return v
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
    """Password-change payload; the new password must be at least 6 characters."""
    old_password: str
    new_password: str

    @field_validator('new_password')
    @classmethod
    def validate_new_password(cls, v):
        if len(v) < 6:
            raise ValueError('新密码长度至少6位')
        return v
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Public user profile returned by auth endpoints (never exposes the password hash)."""
    id: str
    phone: str
    username: Optional[str]
    role: str
    is_active: bool
    # ISO-8601 authorization expiry; None means permanent.
    expires_at: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/register")
async def register(request: RegisterRequest):
    """
    Register a new account.

    The account is created inactive with role "pending"; an administrator
    must activate it before the user can log in.

    Raises:
        HTTPException: 400 if the phone is already registered, 500 on unexpected errors.
    """
    try:
        # NOTE(review): check-then-create is not atomic — two concurrent
        # registrations with the same phone could race; confirm the users
        # table has a unique constraint on phone.
        if user_exists_by_phone(request.phone):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="该手机号已注册"
            )

        # Store only the password hash, never the plaintext password.
        password_hash = get_password_hash(request.password)

        create_user({
            "phone": request.phone,
            "password_hash": password_hash,
            # Default display name: "用户" + last four digits of the phone number.
            "username": request.username or f"用户{request.phone[-4:]}",
            "role": "pending",
            "is_active": False
        })

        logger.info(f"新用户注册: {request.phone}")

        return success_response(message="注册成功,请等待管理员审核激活")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"注册失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="注册失败,请稍后重试"
        )
|
||||
|
||||
|
||||
@router.post("/login")
async def login(request: LoginRequest, response: Response):
    """
    Log a user in.

    Flow: verify password -> require an activated account -> reject expired
    authorizations -> rotate the session token (single-device, "last login
    wins") -> issue a JWT in an HttpOnly cookie.

    Raises:
        HTTPException: 401 for bad credentials, 403 for inactive/expired
        accounts, 500 on unexpected errors.
    """
    try:
        user = cast(dict[str, Any], get_user_by_phone(request.phone) or {})
        if not user:
            # Same message as the wrong-password case so callers cannot probe
            # which phone numbers are registered.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="手机号或密码错误"
            )

        # Verify the password against the stored hash.
        if not verify_password(request.password, user["password_hash"]):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="手机号或密码错误"
            )

        # The account must have been activated by an administrator.
        if not user["is_active"]:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号未激活,请等待管理员审核"
            )

        # Reject logins whose time-limited authorization has lapsed.
        if user.get("expires_at"):
            from datetime import datetime, timezone
            expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
            if datetime.now(timezone.utc) > expires_at:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="授权已过期,请联系管理员续期"
                )

        # Single-device policy: a fresh session token invalidates earlier logins.
        session_token = generate_session_token()

        # Drop any previous sessions, then record the new one.
        delete_sessions(user["id"])
        create_session(user["id"], session_token, None)

        # Bind the session token into the JWT.
        token = create_access_token(user["id"], session_token)

        # Deliver the JWT via an HttpOnly cookie (not readable by page JS).
        set_auth_cookie(response, token)

        logger.info(f"用户登录: {request.phone}")

        return success_response(
            data={
                "user": UserResponse(
                    id=user["id"],
                    phone=user["phone"],
                    username=user.get("username"),
                    role=user["role"],
                    is_active=user["is_active"],
                    expires_at=user.get("expires_at")
                ).model_dump()
            },
            message="登录成功",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"登录失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败,请稍后重试"
        )
|
||||
|
||||
|
||||
@router.post("/logout")
async def logout(req: Request, response: Response):
    """
    Log the current user out.

    Clears the auth cookie AND best-effort deletes the server-side session,
    so the old JWT cannot be replayed after logout (previously only the
    cookie was cleared and the session stayed valid).
    """
    token = req.cookies.get("access_token")
    if token:
        token_data = decode_access_token(token)
        if token_data:
            try:
                delete_sessions(token_data.user_id)
            except Exception as e:
                # Best-effort: a failed session cleanup must not block logout.
                logger.warning(f"登出时清理会话失败: {e}")
    clear_auth_cookie(response)
    return success_response(message="已登出")
|
||||
|
||||
|
||||
@router.post("/change-password")
async def change_password(request: ChangePasswordRequest, req: Request, response: Response):
    """
    Change the current user's password.

    Verifies the current password, stores the new hash, then rotates the
    session token so every previously issued JWT stops working, and re-issues
    a fresh JWT cookie so this browser stays logged in.

    Raises:
        HTTPException: 401 when not logged in / token invalid / user missing,
        400 for a wrong current password, 500 on unexpected errors.
    """
    # Identify the caller from the HttpOnly auth cookie.
    token = req.cookies.get("access_token")
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录"
        )

    token_data = decode_access_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效"
        )

    try:
        user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在"
            )

        # Require the current password before allowing a change.
        if not verify_password(request.old_password, user["password_hash"]):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="当前密码错误"
            )

        # Persist the new hash.
        new_password_hash = get_password_hash(request.new_password)
        update_user(user["id"], {"password_hash": new_password_hash})

        # Rotate the session token: old JWTs embed the old token and become invalid.
        new_session_token = generate_session_token()

        delete_sessions(user["id"])
        create_session(user["id"], new_session_token, None)

        # Re-issue a JWT so the current browser session stays logged in.
        new_token = create_access_token(user["id"], new_session_token)
        set_auth_cookie(response, new_token)

        logger.info(f"用户修改密码: {user['phone']}")

        return success_response(message="密码修改成功")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"修改密码失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="修改密码失败,请稍后重试"
        )
|
||||
|
||||
|
||||
@router.get("/me")
async def get_me(request: Request):
    """Return the profile of the currently authenticated user."""
    # Resolve the caller from the HttpOnly auth cookie.
    cookie_token = request.cookies.get("access_token")
    if not cookie_token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录"
        )

    token_data = decode_access_token(cookie_token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效"
        )

    record = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
    if not record:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )

    profile = UserResponse(
        id=record["id"],
        phone=record["phone"],
        username=record.get("username"),
        role=record["role"],
        is_active=record["is_active"],
        expires_at=record.get("expires_at")
    )
    return success_response(profile.model_dump())
|
||||
0
backend/app/modules/login_helper/__init__.py
Normal file
0
backend/app/modules/login_helper/__init__.py
Normal file
221
backend/app/modules/login_helper/router.py
Normal file
221
backend/app/modules/login_helper/router.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
前端一键扫码登录辅助页面
|
||||
客户在自己的浏览器中扫码,JavaScript自动提取Cookie并上传到服务器
|
||||
"""
|
||||
from fastapi import APIRouter, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from app.core.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/login-helper/{platform}", response_class=HTMLResponse)
async def login_helper_page(platform: str, request: Request):
    """
    Serve an HTML helper page that lets the user log in to a platform in
    their OWN browser; a bookmarklet then extracts document.cookie and POSTs
    it back to this server's cookie-save endpoint.

    Args:
        platform: one of "bilibili", "douyin", "xiaohongshu".

    Returns:
        An HTMLResponse with the helper page, or a plain error page for an
        unsupported platform.
    """

    # Login entry URL per supported platform.
    platform_urls = {
        "bilibili": "https://www.bilibili.com/",
        "douyin": "https://creator.douyin.com/",
        "xiaohongshu": "https://creator.xiaohongshu.com/"
    }

    # Display names used throughout the page.
    platform_names = {
        "bilibili": "B站",
        "douyin": "抖音",
        "xiaohongshu": "小红书"
    }

    if platform not in platform_urls:
        return "<h1>不支持的平台</h1>"

    # Server base URL — the bookmarklet POSTs the cookie back here.
    server_url = str(request.base_url).rstrip('/')

    # NOTE: this is an f-string, so literal CSS/JS braces are doubled ({{ }}).
    html_content = f"""
    <!DOCTYPE html>
    <html lang="zh-CN">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>{platform_names[platform]} 一键登录</title>
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                margin: 0;
                padding: 20px;
                min-height: 100vh;
                display: flex;
                align-items: center;
                justify-content: center;
            }}
            .container {{
                background: white;
                border-radius: 20px;
                padding: 50px;
                box-shadow: 0 20px 60px rgba(0,0,0,0.3);
                max-width: 700px;
                width: 100%;
            }}
            h1 {{
                color: #333;
                margin: 0 0 30px 0;
                text-align: center;
                font-size: 32px;
            }}
            .step {{
                display: flex;
                align-items: flex-start;
                margin: 25px 0;
                padding: 20px;
                background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
                border-radius: 12px;
                border-left: 5px solid #667eea;
            }}
            .step-number {{
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                width: 40px;
                height: 40px;
                border-radius: 50%;
                display: flex;
                align-items: center;
                justify-content: center;
                font-weight: bold;
                font-size: 20px;
                margin-right: 20px;
                flex-shrink: 0;
            }}
            .step-content {{
                flex: 1;
            }}
            .step-title {{
                font-weight: 600;
                font-size: 18px;
                margin-bottom: 8px;
                color: #333;
            }}
            .step-desc {{
                color: #666;
                line-height: 1.6;
            }}
            .bookmarklet {{
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                padding: 15px 30px;
                border-radius: 10px;
                text-decoration: none;
                display: inline-block;
                font-weight: 600;
                font-size: 18px;
                margin: 20px 0;
                cursor: move;
                border: 3px dashed white;
                transition: transform 0.2s;
            }}
            .bookmarklet:hover {{
                transform: scale(1.05);
            }}
            .bookmarklet-container {{
                text-align: center;
                margin: 30px 0;
                padding: 30px;
                background: #f8f9fa;
                border-radius: 12px;
            }}
            .instruction {{
                font-size: 14px;
                color: #666;
                margin-top: 10px;
            }}
            .highlight {{
                background: #fff3cd;
                padding: 2px 6px;
                border-radius: 4px;
                font-weight: 600;
            }}
            .btn {{
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                border: none;
                padding: 15px 40px;
                border-radius: 10px;
                font-size: 18px;
                cursor: pointer;
                font-weight: 600;
                width: 100%;
                margin-top: 20px;
                transition: transform 0.2s;
            }}
            .btn:hover {{
                transform: translateY(-2px);
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>🔐 {platform_names[platform]} 一键登录</h1>

            <div class="step">
                <div class="step-number">1</div>
                <div class="step-content">
                    <div class="step-title">拖拽书签到书签栏</div>
                    <div class="step-desc">
                        将下方的"<span class="highlight">保存{platform_names[platform]}登录</span>"按钮拖拽到浏览器书签栏
                        <br><small>(如果书签栏未显示,按 Ctrl+Shift+B 显示)</small>
                    </div>
                </div>
            </div>

            <div class="bookmarklet-container">
                <a href="javascript:(function(){{var c=document.cookie;if(!c){{alert('请先登录{platform_names[platform]}');return;}}fetch('{server_url}/api/publish/cookies/save/{platform}',{{method:'POST',headers:{{'Content-Type':'application/json'}},body:JSON.stringify({{cookie_string:c}})}}).then(r=>r.json()).then(d=>{{if(d.success){{alert('✅ 登录成功!');window.opener&&window.opener.location.reload();}}else{{alert('❌ '+d.message);}}}}).catch(e=>alert('提交失败:'+e));}})();"
                   class="bookmarklet"
                   onclick="alert('请拖拽此按钮到书签栏,不要点击!'); return false;">
                    🔖 保存{platform_names[platform]}登录
                </a>
                <div class="instruction">
                    ⬆️ <strong>拖拽此按钮到浏览器顶部书签栏</strong>
                </div>
            </div>

            <div class="step">
                <div class="step-number">2</div>
                <div class="step-content">
                    <div class="step-title">登录 {platform_names[platform]}</div>
                    <div class="step-desc">
                        点击下方按钮打开{platform_names[platform]}登录页,扫码登录
                    </div>
                </div>
            </div>

            <button class="btn" onclick="window.open('{platform_urls[platform]}', 'login_tab')">
                🚀 打开{platform_names[platform]}登录页
            </button>

            <div class="step">
                <div class="step-number">3</div>
                <div class="step-content">
                    <div class="step-title">一键保存登录</div>
                    <div class="step-desc">
                        登录成功后,点击书签栏的"<span class="highlight">保存{platform_names[platform]}登录</span>"书签
                        <br>系统会自动提取并保存Cookie,完成!
                    </div>
                </div>
            </div>

            <hr style="margin: 40px 0; border: none; border-top: 2px solid #eee;">

            <div style="text-align: center; color: #999; font-size: 14px;">
                <p>💡 <strong>提示</strong>:书签只需拖拽一次,下次登录直接点击书签即可</p>
                <p>🔒 所有数据仅在您的浏览器和服务器之间传输,安全可靠</p>
            </div>
        </div>
    </body>
    </html>
    """

    return HTMLResponse(content=html_content)
|
||||
0
backend/app/modules/materials/__init__.py
Normal file
0
backend/app/modules/materials/__init__.py
Normal file
416
backend/app/modules/materials/router.py
Normal file
416
backend/app/modules/materials/router.py
Normal file
@@ -0,0 +1,416 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, BackgroundTasks, Depends
|
||||
from app.core.config import settings
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.response import success_response
|
||||
from app.services.storage import storage_service
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import os
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import asyncio
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import httpx
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RenameMaterialRequest(BaseModel):
    """Payload for renaming an uploaded material."""
    new_name: str
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
    """
    Make a filename safe for storage paths.

    Replaces characters that are illegal on common filesystems with "_" and
    truncates the name to at most 100 characters while preserving the
    extension (when the extension itself fits).

    Args:
        filename: the original, possibly unsafe, filename.

    Returns:
        A sanitized filename of at most 100 characters.
    """
    safe_name = re.sub(r'[<>:"/\\|?*]', '_', filename)
    if len(safe_name) > 100:
        ext = Path(safe_name).suffix
        if len(ext) >= 100:
            # Pathological "extension" longer than the whole budget:
            # the original code would negative-slice here; just hard-truncate.
            return safe_name[:100]
        safe_name = safe_name[:100 - len(ext)] + ext
    return safe_name
|
||||
|
||||
async def process_and_upload(temp_file_path: str, original_filename: str, content_type: str, user_id: str):
    """
    Strip the multipart envelope from a raw uploaded body and upload the
    enclosed video to Supabase storage.

    Hand-rolled parser for a SINGLE-file multipart body of the shape:
        --boundary
        Content-Disposition: form-data; name="file"; filename="..."
        Content-Type: video/mp4
        <CRLF><CRLF>
        [DATA]
        <CRLF>--boundary--

    Args:
        temp_file_path: path of the raw request body saved to disk.
        original_filename: client-supplied filename (sanitized before use).
        content_type: MIME type to store with the object.
        user_id: owner id; used as the storage path prefix for isolation.

    Returns:
        The storage path of the uploaded object.

    Raises:
        Exception: if the multipart boundary/headers cannot be located, or
        the upload fails (logged with traceback, then re-raised).
    """
    try:
        logger.info(f"Processing raw upload: {temp_file_path} for user {user_id}")

        start_offset = 0
        end_offset = 0
        boundary = b""

        file_size = os.path.getsize(temp_file_path)

        with open(temp_file_path, 'rb') as f:
            # The boundary is the first line of the body; 4KB is plenty for
            # the boundary plus the part headers.
            head = f.read(4096)

            first_line_end = head.find(b'\r\n')
            if first_line_end == -1:
                raise Exception("Could not find boundary in multipart body")

            boundary = head[:first_line_end]  # e.g. b"--boundary123"
            logger.info(f"Detected boundary: {boundary}")

            # Part headers end at the first blank line (\r\n\r\n); the
            # payload starts right after it.
            header_end = head.find(b'\r\n\r\n')
            if header_end == -1:
                raise Exception("Could not find end of multipart headers")

            start_offset = header_end + 4
            logger.info(f"Video data starts at offset: {start_offset}")

            # Locate the closing boundary by scanning the last 200 bytes.
            # NOTE(review): assumes the closing "--boundary--" lies entirely
            # within the final 200 bytes — confirm for very small bodies.
            f.seek(max(0, file_size - 200))
            tail = f.read()

            last_boundary_pos = tail.rfind(boundary)
            if last_boundary_pos != -1:
                # Data ends just before the CRLF that precedes the boundary.
                end_pos_in_tail = last_boundary_pos
                # NOTE(review): end_pos_in_tail is adjusted but never used —
                # the absolute end_offset below subtracts 2 unconditionally.
                if end_pos_in_tail >= 2 and tail[end_pos_in_tail-2:end_pos_in_tail] == b'\r\n':
                    end_pos_in_tail -= 2

                # NOTE(review): first assignment lacks the max(0, ...) guard
                # and is immediately overwritten by the second — dead store.
                end_offset = (file_size - 200) + last_boundary_pos
                end_offset = (max(0, file_size - 200) + last_boundary_pos) - 2
            else:
                logger.warning("Could not find closing boundary, assuming EOF")
                end_offset = file_size

        logger.info(f"Video data ends at offset: {end_offset}. Total video size: {end_offset - start_offset}")

        # Copy the payload slice into its own file so we never hold the whole
        # video in memory during extraction.
        video_path = temp_file_path + "_video.mp4"
        with open(temp_file_path, 'rb') as src, open(video_path, 'wb') as dst:
            src.seek(start_offset)
            bytes_to_copy = end_offset - start_offset
            copied = 0
            while copied < bytes_to_copy:
                chunk_size = min(1024*1024*10, bytes_to_copy - copied)  # 10 MB chunks
                chunk = src.read(chunk_size)
                if not chunk:
                    break
                dst.write(chunk)
                copied += len(chunk)

        logger.info(f"Extracted video content to {video_path}")

        # Upload to Supabase, namespaced per user for isolation.
        timestamp = int(time.time())
        safe_name = re.sub(r'[^a-zA-Z0-9._-]', '', original_filename)
        # user_id directory prefix isolates each user's materials.
        storage_path = f"{user_id}/{timestamp}_{safe_name}"

        # NOTE(review): the cleaned video is read fully into memory for the
        # upload call — acceptable for the expected file sizes, but a
        # streaming upload would be safer for very large files.
        with open(video_path, 'rb') as f:
            file_content = f.read()
        await storage_service.upload_file(
            bucket=storage_service.BUCKET_MATERIALS,
            path=storage_path,
            file_data=file_content,
            content_type=content_type
        )

        logger.info(f"Upload to Supabase complete: {storage_path}")

        # Remove both the raw body and the extracted slice.
        os.remove(temp_file_path)
        os.remove(video_path)

        return storage_path

    except Exception as e:
        logger.error(f"Background upload processing failed: {e}\n{traceback.format_exc()}")
        raise
|
||||
|
||||
|
||||
@router.post("")
async def upload_material(
    request: Request,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Upload a video material via raw request streaming.

    The client sends a standard multipart body; it is streamed to a temp file
    to avoid buffering large uploads in memory. The original filename is then
    extracted from the multipart preamble and the file is processed/uploaded
    synchronously via `process_and_upload` so the material appears in the
    list immediately after this request returns.

    Returns: material descriptor (id, name, signed path, size_mb, type).
    Raises: HTTPException 400 on empty body, 500 on processing failure.
    """
    user_id = current_user["id"]
    # NOTE: do not log full headers — they carry auth cookies.
    logger.info(
        f"ENTERED upload_material (Streaming Mode) for user {user_id}. "
        f"Content-Length: {request.headers.get('content-length')}"
    )

    filename = "unknown_video.mp4"  # Fallback if multipart parsing fails
    content_type = "video/mp4"

    # Create a unique temp file for the raw body.
    timestamp = int(time.time())
    temp_filename = f"upload_{timestamp}.raw"
    temp_path = os.path.join("/tmp", temp_filename)  # Use /tmp on Linux
    if os.name == 'nt':  # Local dev on Windows
        temp_path = f"d:/tmp/{temp_filename}"
        os.makedirs("d:/tmp", exist_ok=True)

    try:
        total_size = 0
        last_log = 0

        # Stream the body to disk chunk by chunk.
        async with aiofiles.open(temp_path, 'wb') as f:
            async for chunk in request.stream():
                await f.write(chunk)
                total_size += len(chunk)

                # Log progress every 20MB
                if total_size - last_log > 20 * 1024 * 1024:
                    logger.info(f"Receiving stream... Processed {total_size / (1024*1024):.2f} MB")
                    last_log = total_size

        logger.info(f"Stream reception complete. Total size: {total_size} bytes. Saved to {temp_path}")

        if total_size == 0:
            raise HTTPException(400, "Received empty body")

        # Extract the original filename from the multipart preamble (first 4KB).
        with open(temp_path, 'rb') as f:
            head = f.read(4096).decode('utf-8', errors='ignore')
            match = re.search(r'filename="([^"]+)"', head)
            if match:
                filename = match.group(1)
                # Fixed: original logged a literal "(unknown)" instead of the name.
                logger.info(f"Extracted filename from body: {filename}")

        # Process synchronously (strip + upload to Supabase) so the material
        # is visible to "List Materials" as soon as we respond.
        storage_path = await process_and_upload(temp_path, filename, content_type, user_id)

        # The object exists now — produce a playable signed URL.
        signed_url = await storage_service.get_signed_url(
            bucket=storage_service.BUCKET_MATERIALS,
            path=storage_path
        )

        size_mb = total_size / (1024 * 1024)  # Approximate (includes multipart headers)

        # Derive the display name: drop the "<timestamp>_" prefix if present.
        display_name = storage_path.split('/')[-1]
        if '_' in display_name:
            parts = display_name.split('_', 1)
            if parts[0].isdigit():
                display_name = parts[1]

        return success_response({
            "id": storage_path,
            "name": display_name,
            "path": signed_url,
            "size_mb": size_mb,
            "type": "video"
        })

    except HTTPException:
        # Fixed: propagate client errors (e.g. empty body) as-is instead of
        # masking them as 500s; still clean up the partial temp file.
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
        raise
    except Exception as e:
        error_msg = f"Streaming upload failed: {str(e)}"
        detail_msg = f"Exception: {repr(e)}\nArgs: {e.args}\n{traceback.format_exc()}"
        logger.error(error_msg + "\n" + detail_msg)

        # Best-effort debug trace to a local file.
        try:
            with open("debug_upload.log", "a") as logf:
                logf.write(f"\n--- Error at {time.ctime()} ---\n")
                logf.write(detail_msg)
                logf.write("\n-----------------------------\n")
        except OSError:
            pass

        # Best-effort temp-file cleanup (success path is cleaned inside
        # process_and_upload).
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                pass
        raise HTTPException(500, f"Upload failed. Check server logs. Error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("")
async def list_materials(current_user: dict = Depends(get_current_user)):
    """List the current user's uploaded materials with signed playback URLs.

    Signed URLs are fetched concurrently (capped by a semaphore). On any
    top-level failure the endpoint degrades to an empty list rather than 500.
    """
    from datetime import datetime  # hoisted: was imported inside the per-file coroutine

    user_id = current_user["id"]
    try:
        # Only list files under the current user's directory.
        files_obj = await storage_service.list_files(
            bucket=storage_service.BUCKET_MATERIALS,
            path=user_id
        )
        # Cap concurrent signed-URL requests against the storage backend.
        semaphore = asyncio.Semaphore(8)

        async def build_item(f):
            """Build one material entry, or None for placeholder files."""
            name = f.get('name')
            if not name or name == '.emptyFolderPlaceholder':
                return None
            # Strip the "<timestamp>_" prefix for display.
            display_name = name
            if '_' in name:
                parts = name.split('_', 1)
                if parts[0].isdigit():
                    display_name = parts[1]
            full_path = f"{user_id}/{name}"
            async with semaphore:
                signed_url = await storage_service.get_signed_url(
                    bucket=storage_service.BUCKET_MATERIALS,
                    path=full_path
                )
            metadata = f.get('metadata', {})
            size = metadata.get('size', 0)
            created_at_str = f.get('created_at', '')
            created_at = 0
            if created_at_str:
                try:
                    dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
                    created_at = int(dt.timestamp())
                except (ValueError, TypeError):
                    pass  # unparseable timestamp — fall back to 0
            return {
                "id": full_path,
                "name": display_name,
                "path": signed_url,
                "size_mb": size / (1024 * 1024),
                "type": "video",
                "created_at": created_at
            }

        tasks = [build_item(f) for f in files_obj]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        materials = []
        for item in results:
            if not item:
                continue
            if isinstance(item, Exception):
                logger.warning(f"Material signed url build failed: {item}")
                continue
            materials.append(item)
        # Ids start with "<user_id>/<timestamp>_", so reverse id-sort ≈ newest first.
        materials.sort(key=lambda x: x['id'], reverse=True)
        return success_response({"materials": materials})
    except Exception as e:
        logger.error(f"List materials failed: {e}")
        # Degrade gracefully: empty list instead of an error status.
        return success_response({"materials": []}, message="获取素材失败")
|
||||
|
||||
|
||||
@router.delete("/{material_id:path}")
async def delete_material(material_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the current user's materials from storage."""
    owner_prefix = f"{current_user['id']}/"
    # Ownership check: the storage path must live under the caller's directory.
    if not material_id.startswith(owner_prefix):
        raise HTTPException(403, "无权删除此素材")
    try:
        await storage_service.delete_file(
            bucket=storage_service.BUCKET_MATERIALS,
            path=material_id
        )
    except Exception as e:
        raise HTTPException(500, f"删除失败: {str(e)}")
    return success_response(message="素材已删除")
|
||||
|
||||
|
||||
@router.put("/{material_id:path}")
async def rename_material(
    material_id: str,
    payload: RenameMaterialRequest,
    current_user: dict = Depends(get_current_user)
):
    """Rename a material by moving it to a new storage path.

    Keeps the original extension and any leading numeric timestamp prefix,
    sanitizes the requested base name, moves the object in storage, and
    returns a fresh signed URL plus the prefix-stripped display name.
    """
    user_id = current_user["id"]
    # Ownership check: only paths under the caller's own directory.
    if not material_id.startswith(f"{user_id}/"):
        raise HTTPException(403, "无权重命名此素材")

    new_name_raw = payload.new_name.strip() if payload.new_name else ""
    if not new_name_raw:
        raise HTTPException(400, "新名称不能为空")

    # Preserve the original extension; any extension the caller typed is
    # discarded (only its stem is kept).
    old_name = material_id.split("/", 1)[1]
    old_ext = Path(old_name).suffix
    base_name = Path(new_name_raw).stem if Path(new_name_raw).suffix else new_name_raw
    safe_base = sanitize_filename(base_name).strip()
    if not safe_base:
        raise HTTPException(400, "新名称无效")

    new_filename = f"{safe_base}{old_ext}"

    # Keep the numeric timestamp prefix (e.g. "1712345678_") so that
    # timestamp-based sorting and display-name stripping keep working.
    prefix = None
    if "_" in old_name:
        maybe_prefix, _ = old_name.split("_", 1)
        if maybe_prefix.isdigit():
            prefix = maybe_prefix
    if prefix:
        new_filename = f"{prefix}_{new_filename}"

    new_path = f"{user_id}/{new_filename}"
    try:
        # No-op if the sanitized result equals the current path.
        if new_path != material_id:
            await storage_service.move_file(
                bucket=storage_service.BUCKET_MATERIALS,
                from_path=material_id,
                to_path=new_path
            )

        signed_url = await storage_service.get_signed_url(
            bucket=storage_service.BUCKET_MATERIALS,
            path=new_path
        )

        # Display name = filename without the timestamp prefix.
        display_name = new_filename
        if "_" in new_filename:
            parts = new_filename.split("_", 1)
            if parts[0].isdigit():
                display_name = parts[1]

        return success_response({
            "id": new_path,
            "name": display_name,
            "path": signed_url,
        }, message="重命名成功")
    except Exception as e:
        raise HTTPException(500, f"重命名失败: {str(e)}")
|
||||
|
||||
|
||||
|
||||
0
backend/app/modules/publish/__init__.py
Normal file
0
backend/app/modules/publish/__init__.py
Normal file
141
backend/app/modules/publish/router.py
Normal file
141
backend/app/modules/publish/router.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
发布管理 API (支持用户认证)
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from app.services.publish_service import PublishService
|
||||
from app.core.response import success_response
|
||||
|
||||
router = APIRouter()
|
||||
publish_service = PublishService()
|
||||
|
||||
class PublishRequest(BaseModel):
    """Video publish request model"""
    # Path of the rendered video to publish — presumably a local or storage
    # path understood by PublishService.publish; confirm against the service.
    video_path: str
    # Target platform id; the route validates it against SUPPORTED_PLATFORMS.
    platform: str
    # Video title shown on the platform.
    title: str
    # Optional tag list; defaults to no tags.
    tags: List[str] = []
    # Optional description text.
    description: str = ""
    # Optional scheduled publish time; presumably immediate when None —
    # TODO confirm in PublishService.
    publish_time: Optional[datetime] = None
||||
|
||||
class PublishResponse(BaseModel):
    """Video publish response model"""
    # NOTE(review): not referenced by any route in this file — kept as the
    # documented response shape.
    # Whether the publish succeeded.
    success: bool
    # Human-readable status message.
    message: str
    # Platform the video was published to.
    platform: str
    # URL of the published video, when the platform returns one.
    url: Optional[str] = None
||||
|
||||
# Supported platforms for validation.
# NOTE(review): presumably mirrors publish_service.PLATFORMS keys — confirm.
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> Optional[str]:
    """Extract the user id from the access-token cookie, or None when absent/invalid."""
    try:
        from app.core.security import decode_access_token

        cookie_token = request.cookies.get("access_token")
        if not cookie_token:
            return None
        decoded = decode_access_token(cookie_token)
        return decoded.user_id if decoded else None
    except Exception:
        # Anonymous / expired sessions are expected — treat as "no user".
        return None
|
||||
|
||||
|
||||
@router.post("")
async def publish_video(request: PublishRequest, req: Request, background_tasks: BackgroundTasks):
    """Publish a video to the requested platform."""
    # Reject unknown platforms up front with a helpful message.
    if request.platform not in SUPPORTED_PLATFORMS:
        supported = ', '.join(SUPPORTED_PLATFORMS)
        raise HTTPException(
            status_code=400,
            detail=f"不支持的平台: {request.platform}。支持的平台: {supported}"
        )

    # Optional user context (anonymous publishing is allowed).
    user_id = _get_user_id(req)

    try:
        result = await publish_service.publish(
            video_path=request.video_path,
            platform=request.platform,
            title=request.title,
            tags=request.tags,
            description=request.description,
            publish_time=request.publish_time,
            user_id=user_id,
        )
    except Exception as e:
        logger.error(f"发布失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    return success_response(result, message=result.get("message", ""))
|
||||
|
||||
@router.get("/platforms")
async def list_platforms():
    """Return the supported publish platforms, each tagged with its id."""
    platforms = []
    for pid, pinfo in publish_service.PLATFORMS.items():
        entry = dict(pinfo)
        entry["id"] = pid
        platforms.append(entry)
    return success_response({"platforms": platforms})
|
||||
|
||||
@router.get("/accounts")
async def list_accounts(req: Request):
    """Return the stored platform accounts for the (possibly anonymous) caller."""
    accounts = publish_service.get_accounts(_get_user_id(req))
    return success_response({"accounts": accounts})
|
||||
|
||||
@router.post("/login/{platform}")
async def login_platform(platform: str, req: Request):
    """Start a QR-code login flow for the given platform."""
    # Unknown platform → 400 before any service work.
    if platform not in SUPPORTED_PLATFORMS:
        raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")

    result = await publish_service.login(platform, _get_user_id(req))
    return success_response(result, message=result.get("message", ""))
|
||||
|
||||
@router.post("/logout/{platform}")
async def logout_platform(platform: str, req: Request):
    """Log the caller out of the given platform."""
    if platform not in SUPPORTED_PLATFORMS:
        raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")

    result = publish_service.logout(platform, _get_user_id(req))
    return success_response(result, message=result.get("message", ""))
|
||||
|
||||
@router.get("/login/status/{platform}")
async def get_login_status(platform: str, req: Request):
    """Check login state, preferring any active QR-scan session."""
    if platform not in SUPPORTED_PLATFORMS:
        raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")

    status_result = publish_service.get_login_session_status(platform, _get_user_id(req))
    return success_response(status_result, message=status_result.get("message", ""))
|
||||
|
||||
@router.post("/cookies/save/{platform}")
async def save_platform_cookie(platform: str, cookie_data: dict, req: Request):
    """
    Save a cookie string extracted from the client's browser.

    Args:
        platform: platform id
        cookie_data: {"cookie_string": "<contents of document.cookie>"}
    """
    if platform not in SUPPORTED_PLATFORMS:
        raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")

    cookie_string = cookie_data.get("cookie_string", "")
    if not cookie_string:
        raise HTTPException(status_code=400, detail="cookie_string 不能为空")

    result = await publish_service.save_cookie_string(
        platform, cookie_string, _get_user_id(req)
    )
    return success_response(result, message=result.get("message", ""))
|
||||
0
backend/app/modules/ref_audios/__init__.py
Normal file
0
backend/app/modules/ref_audios/__init__.py
Normal file
416
backend/app/modules/ref_audios/router.py
Normal file
416
backend/app/modules/ref_audios/router.py
Normal file
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
参考音频管理 API
|
||||
支持上传/列表/删除参考音频,用于 Qwen3-TTS 声音克隆
|
||||
"""
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import time
|
||||
import json
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.services.storage import storage_service
|
||||
from app.core.response import success_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Audio formats accepted for upload (everything is re-encoded to 16 kHz WAV).
ALLOWED_AUDIO_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac'}

# Storage bucket for reference audios; the "<name>.json" metadata sidecars
# are stored in the same bucket alongside each clip.
BUCKET_REF_AUDIOS = "ref-audios"
|
||||
|
||||
|
||||
class RefAudioResponse(BaseModel):
    """A single reference-audio entry returned to the client."""
    # Storage path "<user_id>/<timestamp>_<name>.wav" — doubles as the id.
    id: str
    # Display name (original filename, possibly "(n)"-suffixed on duplicates).
    name: str
    path: str  # signed URL for playback
    # Transcript of the reference audio (used for voice cloning).
    ref_text: str
    # Clip length in seconds, as measured by ffprobe (0.0 if unknown).
    duration_sec: float
    # Upload time as a Unix timestamp in seconds.
    created_at: int
||||
|
||||
|
||||
class RefAudioListResponse(BaseModel):
    """List wrapper returned by the reference-audio collection endpoint."""
    # Entries, sorted newest-first by the list endpoint.
    items: List[RefAudioResponse]
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
    """Sanitize a filename: replace unsafe/whitespace chars with '_' and cap
    the total length at 50 characters while preserving the extension."""
    cleaned = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
    if len(cleaned) <= 50:
        return cleaned
    ext = Path(cleaned).suffix
    return cleaned[:50 - len(ext)] + ext
|
||||
|
||||
|
||||
def get_audio_duration(file_path: str) -> float:
    """Return the audio duration in seconds via ffprobe, or 0.0 on any failure."""
    probe_cmd = [
        'ffprobe', '-v', 'quiet', '-show_entries', 'format=duration',
        '-of', 'csv=p=0', file_path,
    ]
    try:
        proc = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        return float(proc.stdout.strip())
    except Exception as e:
        # Missing binary, timeout, or unparseable output all fall back to 0.0.
        logger.warning(f"获取音频时长失败: {e}")
        return 0.0
|
||||
|
||||
|
||||
def convert_to_wav(input_path: str, output_path: str) -> bool:
    """Convert an audio file to 16 kHz mono 16-bit PCM WAV; True on success."""
    ffmpeg_cmd = [
        'ffmpeg', '-y', '-i', input_path,
        '-ar', '16000',           # 16 kHz sample rate
        '-ac', '1',               # mono
        '-acodec', 'pcm_s16le',   # 16-bit PCM
        output_path,
    ]
    try:
        subprocess.run(ffmpeg_cmd, capture_output=True, timeout=60, check=True)
    except Exception as e:
        logger.error(f"音频转换失败: {e}")
        return False
    return True
|
||||
|
||||
|
||||
@router.post("")
async def upload_ref_audio(
    file: UploadFile = File(...),
    ref_text: str = Form(...),
    user: dict = Depends(get_current_user)
):
    """
    Upload a reference audio clip for voice cloning.

    - file: audio file (wav, mp3, m4a, webm, ...)
    - ref_text: transcript of the reference audio (required)

    The clip is normalized to 16 kHz mono WAV, validated for duration
    (1–60 s), and stored as "<user_id>/<timestamp>_<name>.wav" together
    with a sidecar JSON carrying the transcript and display name.
    """
    user_id = user["id"]

    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名无效")
    filename = file.filename

    # Validate the extension before touching disk.
    ext = Path(filename).suffix.lower()
    if ext not in ALLOWED_AUDIO_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的音频格式: {ext}。支持的格式: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}"
        )

    if not ref_text or len(ref_text.strip()) < 2:
        raise HTTPException(status_code=400, detail="参考文字不能为空")

    tmp_input_path = None
    tmp_wav_path = None
    try:
        # Persist the upload to a temp file for ffmpeg/ffprobe.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
            content = await file.read()
            tmp_input.write(content)
            tmp_input_path = tmp_input.name

        # Normalize to 16 kHz mono WAV (wav inputs are re-encoded too).
        tmp_wav_path = tmp_input_path + ".wav"
        if ext != '.wav':
            if not convert_to_wav(tmp_input_path, tmp_wav_path):
                raise HTTPException(status_code=500, detail="音频格式转换失败")
        else:
            convert_to_wav(tmp_input_path, tmp_wav_path)

        # Duration sanity checks.
        duration = get_audio_duration(tmp_wav_path)
        if duration < 1.0:
            raise HTTPException(status_code=400, detail="音频时长过短,至少需要 1 秒")
        if duration > 60.0:
            raise HTTPException(status_code=400, detail="音频时长过长,最多 60 秒")

        # Duplicate display-name handling. There is no database, so we
        # heuristically count existing objects ending in "_<original_name>"
        # and suffix the display name with "(n)".
        original_name = filename
        existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)

        dup_count = 0
        search_suffix = f"_{original_name}"  # e.g. "_test.wav"
        for f in existing_files:
            fname = f.get('name', '')
            if fname.endswith(search_suffix):
                dup_count += 1

        final_display_name = original_name
        if dup_count > 0:
            name_stem = Path(original_name).stem
            name_ext = Path(original_name).suffix
            final_display_name = f"{name_stem}({dup_count}){name_ext}"

        # Unique storage path keyed by upload timestamp.
        timestamp = int(time.time())
        safe_name = sanitize_filename(Path(filename).stem)
        storage_path = f"{user_id}/{timestamp}_{safe_name}.wav"

        # Upload the normalized WAV to storage.
        with open(tmp_wav_path, 'rb') as f:
            wav_data = f.read()

        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=storage_path,
            file_data=wav_data,
            content_type="audio/wav"
        )

        # Upload the sidecar metadata JSON (transcript + display name).
        metadata = {
            "ref_text": ref_text.strip(),
            "original_filename": final_display_name,  # "(n)"-suffixed on duplicates
            "duration_sec": duration,
            "created_at": timestamp
        }
        metadata_path = f"{user_id}/{timestamp}_{safe_name}.json"
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

        # Fixed: return the same display name the list endpoint will show
        # (the original returned the raw filename, diverging from metadata).
        return success_response(RefAudioResponse(
            id=storage_path,
            name=final_display_name,
            path=signed_url,
            ref_text=ref_text.strip(),
            duration_sec=duration,
            created_at=timestamp
        ).model_dump())

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"上传参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
    finally:
        # Fixed: always remove temp files — the original leaked them whenever
        # an HTTPException was raised part-way through (bad duration, etc.).
        for tmp in (tmp_input_path, tmp_wav_path):
            if tmp and os.path.exists(tmp):
                try:
                    os.unlink(tmp)
                except OSError:
                    pass
|
||||
|
||||
|
||||
@router.get("")
async def list_ref_audios(user: dict = Depends(get_current_user)):
    """List the current user's reference audios, newest first.

    Metadata sidecars are fetched over HTTP through a single shared client
    (the original created a new AsyncClient per file). Entries whose metadata
    cannot be read fall back to the timestamp encoded in the filename.
    """
    user_id = user["id"]

    try:
        files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)

        import httpx  # hoisted out of the per-file loop

        items = []
        # One client for all metadata fetches instead of one per file.
        async with httpx.AsyncClient() as client:
            for f in files:
                name = f.get("name", "")
                if not name.endswith(".wav"):
                    continue  # skip the .json sidecars themselves

                storage_path = f"{user_id}/{name}"
                metadata_path = f"{user_id}/{name.replace('.wav', '.json')}"

                ref_text = ""
                duration_sec = 0.0
                created_at = 0
                original_filename = ""

                try:
                    metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
                    resp = await client.get(metadata_url)
                    if resp.status_code == 200:
                        metadata = resp.json()
                        ref_text = metadata.get("ref_text", "")
                        duration_sec = metadata.get("duration_sec", 0.0)
                        created_at = metadata.get("created_at", 0)
                        original_filename = metadata.get("original_filename", "")
                except Exception as e:
                    logger.warning(f"读取 metadata 失败: {e}")
                    # Fall back to the timestamp prefix in the filename.
                    try:
                        created_at = int(name.split("_")[0])
                    except ValueError:  # fixed: was a bare except
                        pass

                signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

                # Prefer the stored display name; else strip the timestamp prefix.
                display_name = original_filename if original_filename else name
                if not display_name or display_name == name:
                    match = re.match(r'^\d+_(.+)$', name)
                    if match:
                        display_name = match.group(1)

                items.append(RefAudioResponse(
                    id=storage_path,
                    name=display_name,
                    path=signed_url,
                    ref_text=ref_text,
                    duration_sec=duration_sec,
                    created_at=created_at
                ))

        # Newest first.
        items.sort(key=lambda x: x.created_at, reverse=True)

        return success_response(RefAudioListResponse(items=items).model_dump())

    except Exception as e:
        logger.error(f"列出参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/{audio_id:path}")
async def delete_ref_audio(audio_id: str, user: dict = Depends(get_current_user)):
    """Delete a reference audio and its metadata sidecar."""
    user_id = user["id"]

    # Ownership check: only files under the caller's own directory.
    if not audio_id.startswith(f"{user_id}/"):
        raise HTTPException(status_code=403, detail="无权删除此文件")

    try:
        # Delete the WAV object.
        await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)

        # Best-effort delete of the metadata sidecar (it may not exist).
        metadata_path = audio_id.replace(".wav", ".json")
        try:
            await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
        except Exception:  # fixed: was a bare except
            pass

        return success_response(message="删除成功")

    except Exception as e:
        logger.error(f"删除参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
||||
|
||||
|
||||
class RenameRequest(BaseModel):
    # New display name for the reference audio; ".wav" is appended by the
    # rename endpoint when no extension is given.
    new_name: str
|
||||
|
||||
|
||||
@router.put("/{audio_id:path}")
async def rename_ref_audio(
    audio_id: str,
    request: RenameRequest,
    user: dict = Depends(get_current_user)
):
    """Rename a reference audio by updating the display name in its metadata
    sidecar; the stored WAV object itself is not moved."""
    user_id = user["id"]

    # Ownership check: only files under the caller's own directory.
    if not audio_id.startswith(f"{user_id}/"):
        raise HTTPException(status_code=403, detail="无权修改此文件")

    new_name = request.new_name.strip()
    if not new_name:
        raise HTTPException(status_code=400, detail="新名称不能为空")

    # Ensure the new name carries an extension (default to .wav).
    if not Path(new_name).suffix:
        new_name += ".wav"

    try:
        # 1. Fetch the existing metadata sidecar.
        metadata_path = audio_id.replace(".wav", ".json")
        try:
            # Fetch the current JSON via a signed URL.
            import httpx
            metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
            if not metadata_url:
                # No sidecar yet — fall through to build a fresh one below.
                raise Exception("Metadata not found")

            async with httpx.AsyncClient() as client:
                resp = await client.get(metadata_url)
                if resp.status_code == 200:
                    metadata = resp.json()
                else:
                    raise Exception(f"Failed to fetch metadata: {resp.status_code}")

        except Exception as e:
            logger.warning(f"无法读取元数据: {e}, 将创建新的元数据")
            # Fallback: build minimal metadata when the read fails.
            metadata = {
                "ref_text": "",  # possibly lost — no transcript recoverable here
                "duration_sec": 0.0,
                "created_at": int(time.time()),
                "original_filename": new_name
            }

        # 2. Update the display name.
        metadata["original_filename"] = new_name

        # 3. Overwrite the metadata sidecar in place.
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        return success_response({"name": new_name}, message="重命名成功")

    except Exception as e:
        logger.error(f"重命名失败: {e}")
        raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
|
||||
0
backend/app/modules/tools/__init__.py
Normal file
0
backend/app/modules/tools/__init__.py
Normal file
407
backend/app/modules/tools/router.py
Normal file
407
backend/app/modules/tools/router.py
Normal file
@@ -0,0 +1,407 @@
|
||||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
|
||||
from typing import Optional, Any, cast
|
||||
import asyncio
|
||||
import shutil
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import traceback
|
||||
import re
|
||||
import json
|
||||
import requests
|
||||
from urllib.parse import unquote
|
||||
|
||||
from app.services.whisper_service import whisper_service
|
||||
from app.services.glm_service import glm_service
|
||||
from app.core.response import success_response
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/extract-script")
|
||||
async def extract_script_tool(
|
||||
file: Optional[UploadFile] = File(None),
|
||||
url: Optional[str] = Form(None),
|
||||
rewrite: bool = Form(True)
|
||||
):
|
||||
"""
|
||||
独立文案提取工具
|
||||
支持上传视频/音频 OR 输入视频链接 -> 提取文字 -> (可选) AI洗稿
|
||||
"""
|
||||
if not file and not url:
|
||||
raise HTTPException(400, "必须提供文件或视频链接")
|
||||
|
||||
temp_path = None
|
||||
try:
|
||||
timestamp = int(time.time())
|
||||
temp_dir = Path("/tmp")
|
||||
if os.name == 'nt':
|
||||
temp_dir = Path("d:/tmp")
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 1. 获取/保存文件
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
if file:
|
||||
filename = file.filename
|
||||
if not filename:
|
||||
raise HTTPException(400, "文件名无效")
|
||||
safe_filename = Path(filename).name.replace(" ", "_")
|
||||
temp_path = temp_dir / f"tool_extract_{timestamp}_{safe_filename}"
|
||||
# 文件 I/O 放入线程池
|
||||
await loop.run_in_executor(None, lambda: shutil.copyfileobj(file.file, open(temp_path, "wb")))
|
||||
logger.info(f"Tool processing upload file: {temp_path}")
|
||||
else:
|
||||
if not url:
|
||||
raise HTTPException(400, "必须提供视频链接")
|
||||
url_value: str = url
|
||||
# URL 下载逻辑
|
||||
# 自动提取文案中的链接 (支持 Douyin/Bilibili 等分享文案)
|
||||
url_match = re.search(r'https?://[^\s]+', url_value)
|
||||
if url_match:
|
||||
extracted_url = url_match.group(0)
|
||||
logger.info(f"Extracted URL from text: {extracted_url}")
|
||||
url_value = extracted_url
|
||||
|
||||
logger.info(f"Tool downloading URL: {url_value}")
|
||||
|
||||
# 封装 yt-dlp 下载函数 (Blocking)
|
||||
def _download_yt_dlp():
|
||||
import yt_dlp
|
||||
logger.info("Attempting download with yt-dlp...")
|
||||
|
||||
ydl_opts = {
|
||||
'format': 'bestaudio/best',
|
||||
'outtmpl': str(temp_dir / f"tool_download_{timestamp}_%(id)s.%(ext)s"),
|
||||
'quiet': True,
|
||||
'no_warnings': True,
|
||||
'http_headers': {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
|
||||
'Referer': 'https://www.douyin.com/',
|
||||
}
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL() as ydl_raw:
|
||||
ydl: Any = ydl_raw
|
||||
ydl.params.update(ydl_opts)
|
||||
info = ydl.extract_info(url_value, download=True)
|
||||
if 'requested_downloads' in info:
|
||||
downloaded_file = info['requested_downloads'][0]['filepath']
|
||||
else:
|
||||
ext = info.get('ext', 'mp4')
|
||||
id = info.get('id')
|
||||
downloaded_file = str(temp_dir / f"tool_download_{timestamp}_{id}.{ext}")
|
||||
|
||||
return Path(downloaded_file)
|
||||
|
||||
# 先尝试 yt-dlp (Run in Executor)
|
||||
try:
|
||||
temp_path = await loop.run_in_executor(None, _download_yt_dlp)
|
||||
logger.info(f"yt-dlp downloaded to: {temp_path}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"yt-dlp download failed: {e}. Trying manual Douyin fallback...")
|
||||
|
||||
# 失败则尝试手动解析 (Douyin Fallback)
|
||||
if "douyin" in url_value:
|
||||
manual_path = await download_douyin_manual(url_value, temp_dir, timestamp)
|
||||
if manual_path:
|
||||
temp_path = manual_path
|
||||
logger.info(f"Manual Douyin fallback successful: {temp_path}")
|
||||
else:
|
||||
raise HTTPException(400, f"视频下载失败。yt-dlp 报错: {str(e)}")
|
||||
elif "bilibili" in url_value:
|
||||
manual_path = await download_bilibili_manual(url_value, temp_dir, timestamp)
|
||||
if manual_path:
|
||||
temp_path = manual_path
|
||||
logger.info(f"Manual Bilibili fallback successful: {temp_path}")
|
||||
else:
|
||||
raise HTTPException(400, f"视频下载失败。yt-dlp 报错: {str(e)}")
|
||||
else:
|
||||
raise HTTPException(400, f"视频下载失败: {str(e)}")
|
||||
|
||||
if not temp_path or not temp_path.exists():
|
||||
raise HTTPException(400, "文件获取失败")
|
||||
|
||||
# 1.5 安全转换: 强制转为 WAV (16k)
|
||||
import subprocess
|
||||
audio_path = temp_dir / f"extract_audio_{timestamp}.wav"
|
||||
|
||||
def _convert_audio():
|
||||
try:
|
||||
convert_cmd = [
|
||||
'ffmpeg',
|
||||
'-i', str(temp_path),
|
||||
'-vn', # 忽略视频
|
||||
'-acodec', 'pcm_s16le',
|
||||
'-ar', '16000', # Whisper 推荐采样率
|
||||
'-ac', '1', # 单声道
|
||||
'-y', # 覆盖
|
||||
str(audio_path)
|
||||
]
|
||||
# 捕获 stderr
|
||||
subprocess.run(convert_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
return True
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_log = e.stderr.decode('utf-8', errors='ignore') if e.stderr else str(e)
|
||||
logger.error(f"FFmpeg check/convert failed: {error_log}")
|
||||
# 检查是否为 HTML
|
||||
head = b""
|
||||
try:
|
||||
with open(temp_path, 'rb') as f:
|
||||
head = f.read(100)
|
||||
except: pass
|
||||
if b'<!DOCTYPE html' in head or b'<html' in head:
|
||||
raise ValueError("HTML_DETECTED")
|
||||
raise ValueError("CONVERT_FAILED")
|
||||
|
||||
# 执行转换 (Run in Executor)
|
||||
try:
|
||||
await loop.run_in_executor(None, _convert_audio)
|
||||
logger.info(f"Converted to WAV: {audio_path}")
|
||||
target_path = audio_path
|
||||
except ValueError as ve:
|
||||
if str(ve) == "HTML_DETECTED":
|
||||
raise HTTPException(400, "下载的文件是网页而非视频,请重试或手动上传。")
|
||||
else:
|
||||
raise HTTPException(400, "下载的文件已损坏或格式无法识别。")
|
||||
|
||||
# 2. 提取文案 (Whisper)
|
||||
script = await whisper_service.transcribe(str(target_path))
|
||||
|
||||
# 3. AI 洗稿 (GLM)
|
||||
rewritten = None
|
||||
if rewrite:
|
||||
if script and len(script.strip()) > 0:
|
||||
logger.info("Rewriting script...")
|
||||
rewritten = await glm_service.rewrite_script(script)
|
||||
else:
|
||||
logger.warning("No script extracted, skipping rewrite")
|
||||
|
||||
return success_response({
|
||||
"original_script": script,
|
||||
"rewritten_script": rewritten
|
||||
})
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
logger.error(f"Tool extract failed: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Friendly error message
|
||||
msg = str(e)
|
||||
if "Fresh cookies" in msg:
|
||||
msg = "下载失败:目标平台开启了反爬验证,请过段时间重试或直接上传视频文件。"
|
||||
|
||||
raise HTTPException(500, f"提取失败: {msg}")
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if temp_path and temp_path.exists():
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
logger.info(f"Cleaned up temp file: {temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup temp file {temp_path}: {e}")
|
||||
|
||||
|
||||
async def download_douyin_manual(url: str, temp_dir: Path, timestamp: int) -> Optional[Path]:
    """
    Manually download a Douyin video (fallback logic ported from SuperIPAgent/douyinDownloader).

    Uses a specific user-profile URL plus a hard-coded cookie to bypass anti-scraping.
    Returns the local .mp4 path on success, or None on any failure (this function
    never raises; all errors are logged and swallowed so the caller can decide).
    """
    logger.info(f"[SuperIPAgent] Starting download for: {url}")

    try:
        # 1. Extract the modal id (short links / redirects are followed first).
        headers = {
            "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
        }

        # Follow redirects so short links resolve to the canonical /video/<id> URL.
        resp = requests.get(url, headers=headers, allow_redirects=True, timeout=10)
        final_url = resp.url
        logger.info(f"[SuperIPAgent] Final URL: {final_url}")

        modal_id = None
        match = re.search(r'/video/(\d+)', final_url)
        if match:
            modal_id = match.group(1)

        if not modal_id:
            logger.error("[SuperIPAgent] Could not extract modal_id")
            return None

        logger.info(f"[SuperIPAgent] Extracted modal_id: {modal_id}")

        # 2. Build the specific request URL (copied from SuperIPAgent):
        #    a fixed user's profile page + modal_id parameter, paired with the cookie below.
        target_url = f"https://www.douyin.com/user/MS4wLjABAAAAN_s_hups7LD0N4qnrM3o2gI0vuG3pozNaEolz2_py3cHTTrpVr1Z4dukFD9SOlwY?from_tab_name=main&modal_id={modal_id}"

        # 3. Hard-coded cookie (copied from SuperIPAgent).
        #    NOTE(review): this session cookie is time-limited and will eventually expire —
        #    confirm whether it needs periodic refresh or should move to configuration.
        headers_with_cookie = {
            "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
            "cookie": "douyin.com; device_web_cpu_core=10; device_web_memory_size=8; __ac_nonce=06760391f00b9b51264ae; __ac_signature=_02B4Z6wo00f019a5ceAAAIDAhEZR-X3jjWfWmXVAAJLXd4; ttwid=1%7C7MTKBSMsP4eOv9h5NAh8p0E-NYIud09ftNmB0mjLpWc%7C1734359327%7C8794abeabbd47447e1f56e5abc726be089f2a0344d6343b5f75f23e7b0f0028f; UIFID_TEMP=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff1396912bcb2af71efee56a14a2a9f37b74010d0a0413795262f6d4afe02a032ac7ab; s_v_web_id=verify_m4r4ribr_c7krmY1z_WoeI_43po_ATpO_I4o8U1bex2D7; hevc_supported=true; home_can_add_dy_2_desktop=%220%22; dy_swidth=2560; dy_sheight=1440; stream_recommend_feed_params=%22%7B%5C%22cookie_enabled%5C%22%3Atrue%2C%5C%22screen_width%5C%22%3A2560%2C%5C%22screen_height%5C%22%3A1440%2C%5C%22browser_online%5C%22%3Atrue%2C%5C%22cpu_core_num%5C%22%3A10%2C%5C%22device_memory%5C%22%3A8%2C%5C%22downlink%5C%22%3A10%2C%5C%22effective_type%5C%22%3A%5C%224g%5C%22%2C%5C%22round_trip_time%5C%22%3A50%7D%22; strategyABtestKey=%221734359328.577%22; csrf_session_id=2f53aed9aa6974e83aa9a1014180c3a4; fpk1=U2FsdGVkX1/IpBh0qdmlKAVhGyYHgur4/VtL9AReZoeSxadXn4juKvsakahRGqjxOPytHWspYoBogyhS/V6QSw==; fpk2=0845b309c7b9b957afd9ecf775a4c21f; passport_csrf_token=d80e0c5b2fa2328219856be5ba7e671e; passport_csrf_token_default=d80e0c5b2fa2328219856be5ba7e671e; odin_tt=3c891091d2eb0f4718c1d5645bc4a0017032d4d5aa989decb729e9da2ad570918cbe5e9133dc6b145fa8c758de98efe32ff1f81aa0d611e838cc73ab08ef7d3f6adf66ab4d10e8372ddd628f94f16b8e; volume_info=%7B%22isUserMute%22%3Afalse%2C%22isMute%22%3Afalse%2C%22volume%22%3A0.5%7D; bd_ticket_guard_client_web_domain=2; FORCE_LOGIN=%7B%22videoConsumedRemainSeconds%22%3A180%7D; UIFID=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff139655a3c2b735923234f371c699560c657923fd3d6c5b63ab7bb9b83423b6cb4787e2ce66a7fbc4ecb24c8570f520fe6de068bbb95115023c0c6c1b6ee31b49fb7e3996fb8349f43a3fd8b7a61cd9e18e8fe65eb6a7c13de4c0960d84e344b644725db3eb2fa6b7caf821de1b50527979f2; is_dash_user=1; biz_trace_id=b57a241f; bd_ticket_guard_client_data=eyJiZC10aWNrZXQtZ3VhcmQtdmVyc2lvbiI6MiwiYmQtdGlja2V0LWd1YXJkLWl0ZXJhdGlvbi12ZXJzaW9uIjoxLCJiZC10aWNrZXQtZ3VhcmQtcmVlLXB1YmxpYy1rZXkiOiJCTEo2R0lDalVoWW1XcHpGOFdrN0Vrc0dXcCtaUzNKY1g4NGNGY2k0TTl1TEowNjdUb21mbFU5aDdvWVBGamhNRWNRQWtKdnN1MnM3RmpTWnlJQXpHMjA9IiwiYmQtdGlja2V0LWd1YXJkLXdlYi12ZXJzaW9uIjoyfQ%3D%3D; download_guide=%221%2F20241216%2F0%22; sdk_source_info=7e276470716a68645a606960273f276364697660272927676c715a6d6069756077273f276364697660272927666d776a68605a607d71606b766c6a6b5a7666776c7571273f275e58272927666a6b766a69605a696c6061273f27636469766027292762696a6764695a7364776c6467696076273f275e5827292771273f273d33323131333c3036313632342778; bit_env=RiOY4jzzpxZoVCl6zdVSVhVRjdwHRTxqcqWdqMBZLPGjMdB4Tax1kAELHNTVAAh72KuhumewE4Lq6f0-VJ2UpJrkrhSxoPw9LUb3zQrq1OSwbeSPHkRlRgRQvO89sItdGUyq1oFr0XyRCnMYG87KSeWyc4x0czGR0o50hTDoDLG5rJVoRcdQOLvjiAegsqyytKF59sPX_QM9qffK2SqYsg0hCggURc_AI6kguDDE5DvG0bnyz1utw4z1eEnIoLrkGDqzqBZj4dOAr0BVU6ofbsS-pOQ2u2PM1dLP9FlBVBlVaqYVgHJeSLsR5k76BRTddUjTb4zEilVIEwAMJWGN4I1BxVt6fC9B5tBQpuT0lj3n3eKXCKXZsd8FrEs5_pbfDsxV-e_WMiXI2ff4qxiTC0U73sfo9OpicKICtZjdq8qsHxJuu6wVR36zvXeL2Wch5C6MzprNvkivv0l8nbh2mSgy1nabZr3dmU6NcR-Bg3Q3xTWUlR9aAUmpopC-cNuXjgLpT-Lw1AYGilSUnCvosth1Gfypq-b0MpgmdSDgTrQ%3D; gulu_source_res=eyJwX2luIjoiMDhjOGQ3ZTJiODQyNjZkZWI5Y2VkMGJiODNlNmY1ZWY0ZjMyNTE2ZmYyZjAzNDMzZjI0OWU1Y2Q1NTczNTk5NyJ9; passport_auth_mix_state=hp9bc3dgb1tm5wd8p82zawus27g0e3ue; IsDouyinActive=false",
            "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
        }

        logger.info(f"[SuperIPAgent] Requesting page with Cookie...")
        # NOTE(review): the original comment said verify=False is required in some
        # environments, but the call below does NOT pass verify=False — confirm
        # whether TLS verification should actually be disabled here.
        response = requests.get(target_url, headers=headers_with_cookie, timeout=10)

        # 4. Parse RENDER_DATA (URL-encoded JSON embedded in the page).
        content_match = re.findall(r'<script id="RENDER_DATA" type="application/json">(.*?)</script>', response.text)
        if not content_match:
            # Page structure may have changed; fall back to SSR_HYDRATED_DATA if present.
            if "SSR_HYDRATED_DATA" in response.text:
                content_match = re.findall(r'<script id="SSR_HYDRATED_DATA" type="application/json">(.*?)</script>', response.text)

        if not content_match:
            logger.error(f"[SuperIPAgent] Could not find RENDER_DATA in page (len={len(response.text)})")
            return None

        content = unquote(content_match[0])
        try:
            data = json.loads(content)
        except:
            logger.error("[SuperIPAgent] JSON decode failed")
            return None

        # 5. Extract the video stream URL.
        video_url = None
        try:
            # Usual path: app -> videoDetail -> video -> bitRateList -> playAddr -> src
            if "app" in data and "videoDetail" in data["app"]:
                info = data["app"]["videoDetail"]["video"]
                if "bitRateList" in info and info["bitRateList"]:
                    video_url = info["bitRateList"][0]["playAddr"][0]["src"]
                elif "playAddr" in info and info["playAddr"]:
                    video_url = info["playAddr"][0]["src"]
        except Exception as e:
            logger.error(f"[SuperIPAgent] Path extraction failed: {e}")

        if not video_url:
            logger.error("[SuperIPAgent] No video_url found")
            return None

        # Stream URLs are often protocol-relative ("//...").
        if video_url.startswith("//"):
            video_url = "https:" + video_url

        logger.info(f"[SuperIPAgent] Found video URL: {video_url[:50]}...")

        # 6. Download the stream (Referer header is required by the CDN).
        temp_path = temp_dir / f"douyin_manual_{timestamp}.mp4"
        download_headers = {
            'Referer': 'https://www.douyin.com/',
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
        }

        dl_resp = requests.get(video_url, headers=download_headers, stream=True, timeout=60)
        if dl_resp.status_code == 200:
            with open(temp_path, 'wb') as f:
                for chunk in dl_resp.iter_content(chunk_size=1024):
                    f.write(chunk)

            logger.info(f"[SuperIPAgent] Downloaded successfully: {temp_path}")
            return temp_path
        else:
            logger.error(f"[SuperIPAgent] Download failed: {dl_resp.status_code}")
            return None

    except Exception as e:
        logger.error(f"[SuperIPAgent] Logic failed: {e}")
        return None
||||
|
||||
async def download_bilibili_manual(url: str, temp_dir: Path, timestamp: int) -> Optional[Path]:
    """
    Manually download a Bilibili video (fallback logic, Playwright version).

    Bilibili usually serves separate audio/video (DASH) streams; since only the
    transcript is needed downstream, this extracts just the audio stream.
    Returns the local audio file path (.m4s) on success, or None on any failure.
    """
    from playwright.async_api import async_playwright

    logger.info(f"[Playwright] Starting Bilibili download for: {url}")

    playwright = None
    browser = None
    try:
        playwright = await async_playwright().start()
        # Launch browser (ensure chromium is installed: playwright install chromium)
        browser = await playwright.chromium.launch(headless=True, args=['--no-sandbox', '--disable-setuid-sandbox'])

        # Desktop UA: a mobile UA can sometimes yield a single stream, but
        # Bilibili's mobile web is tricky — desktop works reliably.
        context = await browser.new_context(
            user_agent="Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
        )

        page = await context.new_page()

        # Bilibili streams are .m4s segments; reading the page's initial state
        # (window.__playinfo__) is simpler than intercepting network responses.
        logger.info("[Playwright] Navigating to Bilibili...")
        await page.goto(url, timeout=45000)

        # Wait for the <video> element — this triggers stream loading.
        try:
            await page.wait_for_selector('video', timeout=15000)
        except:
            logger.warning("[Playwright] Video selector timeout")

        # 1. Extract DASH stream info from window.__playinfo__.
        playinfo = await page.evaluate("window.__playinfo__")

        audio_url = None

        if playinfo and "data" in playinfo and "dash" in playinfo["data"]:
            dash = playinfo["data"]["dash"]
            if "audio" in dash and dash["audio"]:
                audio_url = dash["audio"][0]["baseUrl"]
                logger.info(f"[Playwright] Found audio stream in __playinfo__: {audio_url[:50]}...")

        # 2. If __playinfo__ is missing, the <video> src is usually a blob: URL,
        # which cannot be fetched without response interception — give up here.
        if not audio_url:
            logger.warning("[Playwright] Could not find audio in __playinfo__")
            return None

        # Download the audio stream through the browser context's request API
        # (reuses the page's cookies/session).
        temp_path = temp_dir / f"bilibili_audio_{timestamp}.m4s"  # usually m4s

        try:
            api_request = context.request
            headers = {
                "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
                "Referer": "https://www.bilibili.com/"
            }

            logger.info(f"[Playwright] Downloading audio stream...")
            response = await api_request.get(audio_url, headers=headers)

            if response.status == 200:
                body = await response.body()
                with open(temp_path, 'wb') as f:
                    f.write(body)

                logger.info(f"[Playwright] Downloaded successfully: {temp_path}")
                return temp_path
            else:
                logger.error(f"[Playwright] API Request failed: {response.status}")
                return None

        except Exception as e:
            logger.error(f"[Playwright] Download logic error: {e}")
            return None

    except Exception as e:
        logger.error(f"[Playwright] Bilibili download failed: {e}")
        return None
    finally:
        # Always tear down the browser and the playwright driver, even on early return.
        if browser:
            await browser.close()
        if playwright:
            await playwright.stop()
0
backend/app/modules/videos/__init__.py
Normal file
0
backend/app/modules/videos/__init__.py
Normal file
57
backend/app/modules/videos/router.py
Normal file
57
backend/app/modules/videos/router.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends
|
||||
import uuid
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.response import success_response
|
||||
|
||||
from .schemas import GenerateRequest
|
||||
from .task_store import create_task, get_task, list_tasks
|
||||
from .workflow import process_video_generation, get_lipsync_health, get_voiceclone_health
|
||||
from .service import list_generated_videos, delete_generated_video
|
||||
|
||||
|
||||
# Video-generation endpoints; mounted by the application with its own prefix.
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/generate")
async def generate_video(
    req: GenerateRequest,
    background_tasks: BackgroundTasks,
    current_user: dict = Depends(get_current_user)
):
    """Create a generation task and run the heavy pipeline in the background."""
    owner_id = current_user["id"]
    new_task_id = str(uuid.uuid4())
    create_task(new_task_id, owner_id)
    # The pipeline runs after the HTTP response has been sent.
    background_tasks.add_task(process_video_generation, new_task_id, req, owner_id)
    return success_response({"task_id": new_task_id})
||||
|
||||
|
||||
@router.get("/tasks/{task_id}")
async def get_task_status(task_id: str):
    """Return the stored status record for one task.

    NOTE(review): unlike /generate, this endpoint has no auth dependency, so any
    caller can read any task by id — confirm whether that is intentional.
    """
    task = get_task(task_id)
    return success_response(task)
||||
|
||||
|
||||
@router.get("/tasks")
async def list_tasks_view():
    """Return all known tasks.

    NOTE(review): no auth and no per-user filtering here — confirm whether
    exposing every user's tasks is intentional.
    """
    all_tasks = list_tasks()
    return success_response({"tasks": all_tasks})
||||
|
||||
|
||||
@router.get("/lipsync/health")
async def lipsync_health():
    """Report the LipSync backend's health status."""
    health = await get_lipsync_health()
    return success_response(health)
||||
|
||||
|
||||
@router.get("/voiceclone/health")
async def voiceclone_health():
    """Report the voice-clone backend's health status."""
    health = await get_voiceclone_health()
    return success_response(health)
||||
|
||||
|
||||
@router.get("/generated")
async def list_generated(current_user: dict = Depends(get_current_user)):
    """List the authenticated user's generated videos."""
    listing = await list_generated_videos(current_user["id"])
    return success_response(listing)
||||
|
||||
|
||||
@router.delete("/generated/{video_id}")
async def delete_generated(video_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the authenticated user's generated videos."""
    result = await delete_generated_video(current_user["id"], video_id)
    return success_response(result, message="视频已删除")
||||
19
backend/app/modules/videos/schemas.py
Normal file
19
backend/app/modules/videos/schemas.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Request body for POST /generate (video generation pipeline)."""
    # Script text to be spoken in the generated video.
    text: str
    # EdgeTTS voice id; used only when tts_mode == "edgetts".
    voice: str = "zh-CN-YunxiNeural"
    # Source material video: local path or URL.
    material_path: str
    # Speech backend: "edgetts" (default) or "voiceclone".
    tts_mode: str = "edgetts"
    # Reference audio id and its transcript; both required when tts_mode == "voiceclone".
    ref_audio_id: Optional[str] = None
    ref_text: Optional[str] = None
    # Optional title overlay for the final video.
    title: Optional[str] = None
    # Subtitle rendering toggle and style overrides.
    enable_subtitles: bool = True
    subtitle_style_id: Optional[str] = None
    title_style_id: Optional[str] = None
    subtitle_font_size: Optional[int] = None
    title_font_size: Optional[int] = None
    # Background music selection and mix volume (clamped to 0.0–1.0 downstream).
    bgm_id: Optional[str] = None
    bgm_volume: Optional[float] = 0.2
||||
87
backend/app/modules/videos/service.py
Normal file
87
backend/app/modules/videos/service.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from fastapi import HTTPException
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from app.services.storage import storage_service
|
||||
|
||||
|
||||
async def list_generated_videos(user_id: str) -> dict:
    """List the given user's generated videos from storage.

    Returns {"videos": [...]} sorted newest-first by Unix timestamp; on any
    failure the function logs and returns an empty list instead of raising
    (listing is best-effort).
    """
    # Hoisted out of build_item: the original re-imported datetime per file.
    from datetime import datetime

    try:
        files_obj = await storage_service.list_files(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=user_id
        )

        # Cap concurrent signed-URL requests against storage.
        semaphore = asyncio.Semaphore(8)

        async def build_item(f):
            """Build one listing entry, or None for placeholder/non-output files."""
            name = f.get("name")
            if not name or name == ".emptyFolderPlaceholder":
                return None

            # Only finished pipeline outputs are listed.
            if not name.endswith("_output.mp4"):
                return None

            video_id = Path(name).stem
            full_path = f"{user_id}/{name}"

            async with semaphore:
                signed_url = await storage_service.get_signed_url(
                    bucket=storage_service.BUCKET_OUTPUTS,
                    path=full_path
                )

            metadata = f.get("metadata", {})
            size = metadata.get("size", 0)
            created_at_str = f.get("created_at", "")
            created_at = 0
            if created_at_str:
                try:
                    dt = datetime.fromisoformat(created_at_str.replace("Z", "+00:00"))
                    created_at = int(dt.timestamp())
                except Exception:
                    pass  # keep 0 when the timestamp is unparseable

            return {
                "id": video_id,
                "name": name,
                "path": signed_url,
                "size_mb": size / (1024 * 1024),
                "created_at": created_at
            }

        tasks = [build_item(f) for f in files_obj]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        videos = []
        for item in results:
            if not item:
                continue
            if isinstance(item, Exception):
                logger.warning(f"Signed url build failed: {item}")
                continue
            videos.append(item)

        # BUG FIX: the default was "" while created_at is an int; mixing str and
        # int in a sort key raises TypeError on Python 3. Use 0 to match.
        videos.sort(key=lambda x: x.get("created_at", 0), reverse=True)
        return {"videos": videos}

    except Exception as e:
        logger.error(f"List generated videos failed: {e}")
        return {"videos": []}
|
||||
|
||||
|
||||
async def delete_generated_video(user_id: str, video_id: str) -> dict:
    """Remove one generated video from storage.

    Raises HTTPException(500) when the storage deletion fails.
    """
    try:
        # Outputs live under <user_id>/<video_id>.mp4 in the outputs bucket.
        target_path = f"{user_id}/{video_id}.mp4"
        await storage_service.delete_file(
            bucket=storage_service.BUCKET_OUTPUTS,
            path=target_path
        )
    except Exception as e:
        raise HTTPException(500, f"删除失败: {str(e)}")
    return {"video_id": video_id}
|
||||
118
backend/app/modules/videos/task_store.py
Normal file
118
backend/app/modules/videos/task_store.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Any, Dict, List
|
||||
import json
|
||||
|
||||
from loguru import logger
|
||||
from app.core.config import settings
|
||||
|
||||
try:
|
||||
import redis
|
||||
except Exception: # pragma: no cover - optional dependency
|
||||
redis = None
|
||||
|
||||
|
||||
class InMemoryTaskStore:
    """Process-local task store; used when Redis is unavailable."""

    def __init__(self) -> None:
        # task_id -> mutable task record
        self._tasks: Dict[str, Dict[str, Any]] = {}

    def create(self, task_id: str, user_id: str) -> Dict[str, Any]:
        """Register a new pending task and return its record."""
        record: Dict[str, Any] = {
            "status": "pending",
            "task_id": task_id,
            "progress": 0,
            "user_id": user_id,
        }
        self._tasks[task_id] = record
        return record

    def get(self, task_id: str) -> Dict[str, Any]:
        """Return the stored record, or a not_found marker."""
        try:
            return self._tasks[task_id]
        except KeyError:
            return {"status": "not_found"}

    def list(self) -> List[Dict[str, Any]]:
        """Return every known task record."""
        return [record for record in self._tasks.values()]

    def update(self, task_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
        """Merge updates into the record, creating a pending stub if absent."""
        record = self._tasks.setdefault(
            task_id, {"status": "pending", "task_id": task_id}
        )
        record.update(updates)
        return record
|
||||
|
||||
|
||||
class RedisTaskStore:
    """Task store backed by Redis: JSON records plus a set indexing the ids."""

    def __init__(self, client: "redis.Redis") -> None:
        self._client = client
        self._index_key = "vigent:tasks:index"

    def _key(self, task_id: str) -> str:
        # One Redis key per task record.
        return f"vigent:tasks:{task_id}"

    def _save(self, task_id: str, task: Dict[str, Any]) -> None:
        """Persist the record and register its id in the index set."""
        self._client.set(self._key(task_id), json.dumps(task, ensure_ascii=False))
        self._client.sadd(self._index_key, task_id)

    def create(self, task_id: str, user_id: str) -> Dict[str, Any]:
        """Register a new pending task and return its record."""
        task: Dict[str, Any] = {
            "status": "pending",
            "task_id": task_id,
            "progress": 0,
            "user_id": user_id,
        }
        self._save(task_id, task)
        return task

    def get(self, task_id: str) -> Dict[str, Any]:
        """Fetch and decode a record, or a not_found marker."""
        raw = self._client.get(self._key(task_id))
        return json.loads(raw) if raw else {"status": "not_found"}

    def list(self) -> List[Dict[str, Any]]:
        """Return all decodable records referenced by the index set."""
        task_ids = list(self._client.smembers(self._index_key) or [])
        if not task_ids:
            return []
        raw_items = self._client.mget([self._key(tid) for tid in task_ids])
        decoded: List[Dict[str, Any]] = []
        for raw in raw_items:
            if not raw:
                continue
            try:
                decoded.append(json.loads(raw))
            except Exception:
                continue  # skip corrupt entries rather than failing the listing
        return decoded

    def update(self, task_id: str, updates: Dict[str, Any]) -> Dict[str, Any]:
        """Merge updates into the record, creating a pending stub if absent."""
        task = self.get(task_id)
        if task.get("status") == "not_found":
            task = {"status": "pending", "task_id": task_id}
        task.update(updates)
        self._save(task_id, task)
        return task
|
||||
|
||||
|
||||
def _build_task_store():
    """Choose the task-store backend: Redis when reachable, else in-memory."""
    if redis is None:
        # The optional redis package failed to import.
        logger.warning("Redis not available, using in-memory task store")
        return InMemoryTaskStore()
    try:
        conn = redis.Redis.from_url(settings.REDIS_URL, decode_responses=True)
        conn.ping()  # fail fast when the server is unreachable
        logger.info("Using Redis task store")
        return RedisTaskStore(conn)
    except Exception as e:
        logger.warning(f"Redis connection failed, using in-memory task store: {e}")
        return InMemoryTaskStore()
|
||||
|
||||
|
||||
# Single module-level store instance shared by the thin wrappers below.
task_store = _build_task_store()


def create_task(task_id: str, user_id: str) -> Dict[str, Any]:
    """Register a new pending task in the shared store."""
    return task_store.create(task_id, user_id)


def get_task(task_id: str) -> Dict[str, Any]:
    """Fetch a task record ({"status": "not_found"} when unknown)."""
    return task_store.get(task_id)


def list_tasks() -> List[Dict[str, Any]]:
    """Return all task records."""
    return task_store.list()
|
||||
328
backend/app/modules/videos/workflow.py
Normal file
328
backend/app/modules/videos/workflow.py
Normal file
@@ -0,0 +1,328 @@
|
||||
from typing import Optional, Any
|
||||
from pathlib import Path
|
||||
import time
|
||||
import traceback
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import settings
|
||||
from app.services.tts_service import TTSService
|
||||
from app.services.video_service import VideoService
|
||||
from app.services.lipsync_service import LipSyncService
|
||||
from app.services.voice_clone_service import voice_clone_service
|
||||
from app.services.assets_service import (
|
||||
get_style,
|
||||
get_default_style,
|
||||
resolve_bgm_path,
|
||||
prepare_style_for_remotion,
|
||||
)
|
||||
from app.services.storage import storage_service
|
||||
from app.services.whisper_service import whisper_service
|
||||
from app.services.remotion_service import remotion_service
|
||||
|
||||
from .schemas import GenerateRequest
|
||||
from .task_store import task_store
|
||||
|
||||
|
||||
# Lazily-created LipSync service singleton (see _get_lipsync_service).
_lipsync_service: Optional[LipSyncService] = None
# Cached readiness result; None means "never checked".
_lipsync_ready: Optional[bool] = None
# time.time() of the last health check, for the 5-minute cache window.
_lipsync_last_check: float = 0
|
||||
|
||||
|
||||
def _get_lipsync_service() -> LipSyncService:
    """Return the process-wide LipSync service instance (lazy singleton)."""
    global _lipsync_service
    if _lipsync_service is not None:
        return _lipsync_service
    # First use: build the service once and reuse it afterwards.
    _lipsync_service = LipSyncService()
    return _lipsync_service
|
||||
|
||||
|
||||
async def _check_lipsync_ready(force: bool = False) -> bool:
    """Check whether the LipSync backend is ready.

    The result is cached for 5 minutes; pass force=True to bypass the cache.
    """
    global _lipsync_ready, _lipsync_last_check

    now = time.time()
    # Serve the cached answer while it is fresh.
    if not force and _lipsync_ready is not None and (now - _lipsync_last_check) < 300:
        return bool(_lipsync_ready)

    lipsync = _get_lipsync_service()
    health = await lipsync.check_health()
    _lipsync_ready = health.get("ready", False)
    _lipsync_last_check = now
    # CONSISTENCY FIX: use loguru like the rest of this module instead of print().
    logger.info(f"[LipSync] Health check: ready={_lipsync_ready}")
    return bool(_lipsync_ready)
|
||||
|
||||
|
||||
async def _download_material(path_or_url: str, temp_path: Path):
    """Fetch a material into temp_path: stream over HTTP, or copy a local file.

    Raises FileNotFoundError when a local path does not exist.
    """
    if not path_or_url.startswith("http"):
        source = Path(path_or_url)
        if not source.is_absolute():
            # Relative paths are resolved against the project root.
            source = settings.BASE_DIR.parent / path_or_url
        if not source.exists():
            raise FileNotFoundError(f"Material not found: {path_or_url}")
        import shutil
        shutil.copy(source, temp_path)
        return

    # Stream to disk so large downloads never sit fully in memory.
    async with httpx.AsyncClient(timeout=httpx.Timeout(None)) as client:
        async with client.stream("GET", path_or_url) as resp:
            resp.raise_for_status()
            with open(temp_path, "wb") as out:
                async for chunk in resp.aiter_bytes():
                    out.write(chunk)
|
||||
|
||||
|
||||
def _update_task(task_id: str, **updates: Any) -> None:
    """Merge keyword updates into the task record in the shared store."""
    task_store.update(task_id, updates)
|
||||
|
||||
|
||||
async def process_video_generation(task_id: str, req: GenerateRequest, user_id: str):
|
||||
temp_files = []
|
||||
try:
|
||||
start_time = time.time()
|
||||
_update_task(task_id, status="processing", progress=5, message="正在下载素材...")
|
||||
|
||||
temp_dir = settings.UPLOAD_DIR / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
input_material_path = temp_dir / f"{task_id}_input.mp4"
|
||||
temp_files.append(input_material_path)
|
||||
|
||||
await _download_material(req.material_path, input_material_path)
|
||||
|
||||
_update_task(task_id, message="正在生成语音...", progress=10)
|
||||
|
||||
audio_path = temp_dir / f"{task_id}_audio.wav"
|
||||
temp_files.append(audio_path)
|
||||
|
||||
if req.tts_mode == "voiceclone":
|
||||
if not req.ref_audio_id or not req.ref_text:
|
||||
raise ValueError("声音克隆模式需要提供参考音频和参考文字")
|
||||
|
||||
_update_task(task_id, message="正在下载参考音频...")
|
||||
|
||||
ref_audio_local = temp_dir / f"{task_id}_ref.wav"
|
||||
temp_files.append(ref_audio_local)
|
||||
|
||||
ref_audio_url = await storage_service.get_signed_url(
|
||||
bucket="ref-audios",
|
||||
path=req.ref_audio_id
|
||||
)
|
||||
await _download_material(ref_audio_url, ref_audio_local)
|
||||
|
||||
_update_task(task_id, message="正在克隆声音 (Qwen3-TTS)...")
|
||||
await voice_clone_service.generate_audio(
|
||||
text=req.text,
|
||||
ref_audio_path=str(ref_audio_local),
|
||||
ref_text=req.ref_text,
|
||||
output_path=str(audio_path),
|
||||
language="Chinese"
|
||||
)
|
||||
else:
|
||||
_update_task(task_id, message="正在生成语音 (EdgeTTS)...")
|
||||
tts = TTSService()
|
||||
await tts.generate_audio(req.text, req.voice, str(audio_path))
|
||||
|
||||
tts_time = time.time() - start_time
|
||||
print(f"[Pipeline] TTS completed in {tts_time:.1f}s")
|
||||
_update_task(task_id, progress=25)
|
||||
|
||||
_update_task(task_id, message="正在合成唇形 (LatentSync)...", progress=30)
|
||||
|
||||
lipsync = _get_lipsync_service()
|
||||
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
|
||||
temp_files.append(lipsync_video_path)
|
||||
|
||||
lipsync_start = time.time()
|
||||
is_ready = await _check_lipsync_ready()
|
||||
|
||||
if is_ready:
|
||||
print(f"[LipSync] Starting LatentSync inference...")
|
||||
_update_task(task_id, progress=35, message="正在运行 LatentSync 推理...")
|
||||
await lipsync.generate(str(input_material_path), str(audio_path), str(lipsync_video_path))
|
||||
else:
|
||||
print(f"[LipSync] LatentSync not ready, copying original video")
|
||||
_update_task(task_id, message="唇形同步不可用,使用原始视频...")
|
||||
import shutil
|
||||
shutil.copy(str(input_material_path), lipsync_video_path)
|
||||
|
||||
lipsync_time = time.time() - lipsync_start
|
||||
print(f"[Pipeline] LipSync completed in {lipsync_time:.1f}s")
|
||||
_update_task(task_id, progress=80)
|
||||
|
||||
captions_path = None
|
||||
if req.enable_subtitles:
|
||||
_update_task(task_id, message="正在生成字幕 (Whisper)...", progress=82)
|
||||
|
||||
captions_path = temp_dir / f"{task_id}_captions.json"
|
||||
temp_files.append(captions_path)
|
||||
|
||||
try:
|
||||
await whisper_service.align(
|
||||
audio_path=str(audio_path),
|
||||
text=req.text,
|
||||
output_path=str(captions_path)
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Whisper alignment failed, skipping subtitles: {e}")
|
||||
captions_path = None
|
||||
|
||||
_update_task(task_id, progress=85)
|
||||
|
||||
video = VideoService()
|
||||
final_audio_path = audio_path
|
||||
if req.bgm_id:
|
||||
_update_task(task_id, message="正在合成背景音乐...", progress=86)
|
||||
|
||||
bgm_path = resolve_bgm_path(req.bgm_id)
|
||||
if bgm_path:
|
||||
mix_output_path = temp_dir / f"{task_id}_audio_mix.wav"
|
||||
temp_files.append(mix_output_path)
|
||||
volume = req.bgm_volume if req.bgm_volume is not None else 0.2
|
||||
volume = max(0.0, min(float(volume), 1.0))
|
||||
try:
|
||||
video.mix_audio(
|
||||
voice_path=str(audio_path),
|
||||
bgm_path=str(bgm_path),
|
||||
output_path=str(mix_output_path),
|
||||
bgm_volume=volume
|
||||
)
|
||||
final_audio_path = mix_output_path
|
||||
except Exception as e:
|
||||
logger.warning(f"BGM mix failed, fallback to voice only: {e}")
|
||||
else:
|
||||
logger.warning(f"BGM not found: {req.bgm_id}")
|
||||
|
||||
use_remotion = (captions_path and captions_path.exists()) or req.title
|
||||
|
||||
subtitle_style = None
|
||||
title_style = None
|
||||
if req.enable_subtitles:
|
||||
subtitle_style = get_style("subtitle", req.subtitle_style_id) or get_default_style("subtitle")
|
||||
if req.title:
|
||||
title_style = get_style("title", req.title_style_id) or get_default_style("title")
|
||||
|
||||
if req.subtitle_font_size and req.enable_subtitles:
|
||||
if subtitle_style is None:
|
||||
subtitle_style = {}
|
||||
subtitle_style["font_size"] = int(req.subtitle_font_size)
|
||||
|
||||
if req.title_font_size and req.title:
|
||||
if title_style is None:
|
||||
title_style = {}
|
||||
title_style["font_size"] = int(req.title_font_size)
|
||||
|
||||
if use_remotion:
|
||||
subtitle_style = prepare_style_for_remotion(
|
||||
subtitle_style,
|
||||
temp_dir,
|
||||
f"{task_id}_subtitle_font"
|
||||
)
|
||||
title_style = prepare_style_for_remotion(
|
||||
title_style,
|
||||
temp_dir,
|
||||
f"{task_id}_title_font"
|
||||
)
|
||||
|
||||
final_output_local_path = temp_dir / f"{task_id}_output.mp4"
|
||||
temp_files.append(final_output_local_path)
|
||||
|
||||
if use_remotion:
|
||||
_update_task(task_id, message="正在合成视频 (Remotion)...", progress=87)
|
||||
|
||||
composed_video_path = temp_dir / f"{task_id}_composed.mp4"
|
||||
temp_files.append(composed_video_path)
|
||||
|
||||
await video.compose(str(lipsync_video_path), str(final_audio_path), str(composed_video_path))
|
||||
|
||||
remotion_health = await remotion_service.check_health()
|
||||
if remotion_health.get("ready"):
|
||||
try:
|
||||
def on_remotion_progress(percent):
|
||||
mapped = 87 + int(percent * 0.08)
|
||||
_update_task(task_id, progress=mapped)
|
||||
|
||||
await remotion_service.render(
|
||||
video_path=str(composed_video_path),
|
||||
output_path=str(final_output_local_path),
|
||||
captions_path=str(captions_path) if captions_path else None,
|
||||
title=req.title,
|
||||
title_duration=3.0,
|
||||
fps=25,
|
||||
enable_subtitles=req.enable_subtitles,
|
||||
subtitle_style=subtitle_style,
|
||||
title_style=title_style,
|
||||
on_progress=on_remotion_progress
|
||||
)
|
||||
print(f"[Pipeline] Remotion render completed")
|
||||
except Exception as e:
|
||||
logger.warning(f"Remotion render failed, using FFmpeg fallback: {e}")
|
||||
import shutil
|
||||
shutil.copy(str(composed_video_path), final_output_local_path)
|
||||
else:
|
||||
logger.warning(f"Remotion not ready: {remotion_health.get('error')}, using FFmpeg")
|
||||
import shutil
|
||||
shutil.copy(str(composed_video_path), final_output_local_path)
|
||||
else:
|
||||
_update_task(task_id, message="正在合成最终视频...", progress=90)
|
||||
|
||||
await video.compose(str(lipsync_video_path), str(final_audio_path), str(final_output_local_path))
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
_update_task(task_id, message="正在上传结果...", progress=95)
|
||||
|
||||
storage_path = f"{user_id}/{task_id}_output.mp4"
|
||||
with open(final_output_local_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
await storage_service.upload_file(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=storage_path,
|
||||
file_data=file_data,
|
||||
content_type="video/mp4"
|
||||
)
|
||||
|
||||
signed_url = await storage_service.get_signed_url(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=storage_path
|
||||
)
|
||||
|
||||
print(f"[Pipeline] Total generation time: {total_time:.1f}s")
|
||||
|
||||
_update_task(
|
||||
task_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
message=f"生成完成!耗时 {total_time:.0f} 秒",
|
||||
output=storage_path,
|
||||
download_url=signed_url,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_update_task(
|
||||
task_id,
|
||||
status="failed",
|
||||
message=f"错误: {str(e)}",
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
logger.error(f"Generate video failed: {e}")
|
||||
finally:
|
||||
for f in temp_files:
|
||||
try:
|
||||
if f.exists():
|
||||
f.unlink()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {f}: {e}")
|
||||
|
||||
|
||||
async def get_lipsync_health():
    """Probe the lip-sync backend and return its health report.

    Resolves the service lazily via the module factory so that the
    backend is only instantiated when a health check is requested.
    """
    service = _get_lipsync_service()
    report = await service.check_health()
    return report
|
||||
|
||||
|
||||
async def get_voiceclone_health():
    """Return the health report of the voice-clone service.

    Delegates directly to the module-level ``voice_clone_service``;
    the shape of the returned report is defined by that service's
    ``check_health`` implementation (defined elsewhere in this module).
    """
    return await voice_clone_service.check_health()
|
||||
Reference in New Issue
Block a user