166 lines
6.5 KiB
Python
166 lines
6.5 KiB
Python
from fastapi import FastAPI, HTTPException
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
from app.core import config
|
||
from app.core.response import error_response
|
||
# 直接从 modules 导入路由,消除 api 转发层
|
||
from app.modules.materials.router import router as materials_router
|
||
from app.modules.videos.router import router as videos_router
|
||
from app.modules.publish.router import router as publish_router
|
||
from app.modules.login_helper.router import router as login_helper_router
|
||
from app.modules.auth.router import router as auth_router
|
||
from app.modules.admin.router import router as admin_router
|
||
from app.modules.ref_audios.router import router as ref_audios_router
|
||
from app.modules.ai.router import router as ai_router
|
||
from app.modules.tools.router import router as tools_router
|
||
from app.modules.assets.router import router as assets_router
|
||
from loguru import logger
|
||
import os
|
||
|
||
# Application-wide settings object; read further down in this module for the
# CORS origins, storage directories and the bootstrap admin account.
settings = config.settings

# The FastAPI application. Middleware, exception handlers, static mounts and
# routers are all registered on this instance later in the module.
app = FastAPI(
    title="ViGent TalkingHead Agent",
)
|
||
|
||
from fastapi import Request
|
||
from fastapi.exceptions import RequestValidationError
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
import time
|
||
import traceback
|
||
|
||
class LoggingMiddleware(BaseHTTPMiddleware):
    """Log every HTTP request: its start, its end (status + duration) and failures.

    Request headers are logged at DEBUG level after the values of
    ``SENSITIVE_HEADERS`` have been masked. Exceptions raised further down the
    stack are logged with their traceback and duration, then re-raised
    unchanged so the application's exception handlers still build the response.
    """

    # Header names (lower-case) whose values are masked before being logged.
    SENSITIVE_HEADERS = {'authorization', 'cookie', 'set-cookie', 'x-api-key', 'api-key'}

    def _sanitize_headers(self, headers: dict) -> dict:
        """Return a copy of ``headers`` that is safe to write to the log.

        Header names are matched case-insensitively against ``SENSITIVE_HEADERS``.
        A sensitive value longer than 8 characters is reduced to its first 8
        characters followed by ``...[<length> chars]``; a shorter sensitive value
        becomes ``[REDACTED]``. All other headers are copied unchanged.
        """
        sanitized = {}
        for key, value in headers.items():
            if key.lower() not in self.SENSITIVE_HEADERS:
                sanitized[key] = value
            elif len(value) > 8:
                # Show the first 8 characters plus a length marker instead of
                # the full value.
                sanitized[key] = value[:8] + "..." + f"[{len(value)} chars]"
            else:
                sanitized[key] = "[REDACTED]"
        return sanitized

    async def dispatch(self, request: Request, call_next):
        """Run the rest of the stack for ``request``, timing and logging it.

        Args:
            request: The incoming request.
            call_next: Starlette callback that invokes the downstream app.

        Returns:
            The downstream response, unchanged.

        Raises:
            Whatever the downstream app raised, re-raised as-is after logging.
        """
        # perf_counter is monotonic, so measured durations cannot be skewed by
        # system clock adjustments the way differences of time.time() can.
        start_time = time.perf_counter()
        # NOTE(review): request.url includes the query string; confirm no
        # endpoint receives secrets (tokens, keys) as query parameters.
        logger.info(f"START Request: {request.method} {request.url}")
        logger.debug(f"HEADERS: {self._sanitize_headers(dict(request.headers))}")

        try:
            response = await call_next(request)
        except Exception as e:
            process_time = time.perf_counter() - start_time
            logger.error(
                f"EXCEPTION during request {request.method} {request.url} "
                f"- Duration: {process_time:.2f}s: {e}\n{traceback.format_exc()}"
            )
            # Bare raise re-raises the active exception with its traceback intact.
            raise

        # NOTE(review): call_next returns once the downstream app has started
        # the response, so for streamed bodies this duration presumably excludes
        # body streaming time — confirm if used for latency metrics.
        process_time = time.perf_counter() - start_time
        logger.info(
            f"END Request: {request.method} {request.url} "
            f"- Status: {response.status_code} - Duration: {process_time:.2f}s"
        )
        return response


# Install the request-logging middleware on the application.
app.add_middleware(LoggingMiddleware)
|
||
|
||
|
||
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Render request-validation failures in the project's error envelope.

    Responds with HTTP 422 and ``error_response("参数校验失败", 422, data=...)``,
    where ``data`` is the JSON-encoded list from ``exc.errors()``.
    """
    # exc.errors() may contain objects the JSON renderer cannot serialise (for
    # example the exception instance stored under "ctx" when a custom validator
    # raises). Encode it the same way FastAPI's built-in handler does so that
    # rendering the 422 response cannot itself fail with a 500.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content=error_response("参数校验失败", 422, data=jsonable_encoder(exc.errors())),
    )
|
||
|
||
|
||
# Register against Starlette's base HTTPException (fastapi.HTTPException is a
# subclass of it) so HTTP errors raised by Starlette itself — e.g. 404 for an
# unknown route or missing static file, 405 for a wrong method — also use the
# project's error envelope instead of Starlette's default {"detail": ...} body.
from starlette.exceptions import HTTPException as StarletteHTTPException


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP errors in the project's error envelope.

    A string ``detail`` becomes the message with no ``data``; any other detail
    (dict, list, ...) is passed through as ``data`` under the generic message
    "请求失败". The exception's status code and headers are preserved.
    """
    detail = exc.detail
    if isinstance(detail, str):
        message, data = detail, None
    else:
        message, data = "请求失败", detail
    return JSONResponse(
        status_code=exc.status_code,
        content=error_response(message, exc.status_code, data=data),
        # getattr guards against Starlette versions whose HTTPException has no
        # headers attribute.
        headers=getattr(exc, "headers", None),
    )
|
||
|
||
|
||
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Fallback for any exception that no more specific handler claimed.

    Always answers HTTP 500 with a generic message and never includes ``exc``
    in the response body.
    """
    # NOTE(review): the exception is not logged here; it is presumably already
    # logged by LoggingMiddleware while propagating. Also confirm whether this
    # 500 response reaches the browser with CORS headers.
    status = 500
    body = error_response("服务器内部错误", status)
    return JSONResponse(content=body, status_code=status)
|
||
|
||
# CORS: allowed origins come from settings.CORS_ORIGINS, which is either "*" or
# a comma-separated list of origins. Surrounding whitespace and empty entries
# are ignored, so "https://a.com, https://b.com," works as intended. An empty
# setting yields an empty list, i.e. no cross-origin requests are allowed.
cors_origins = [origin.strip() for origin in settings.CORS_ORIGINS.split(",") if origin.strip()]
if "*" in cors_origins:
    # Any wildcard entry means "allow every origin". Combining "*" with explicit
    # origins and credentials can make Starlette reflect any caller's origin
    # with credentials allowed, so collapse to a plain wildcard instead.
    cors_origins = ["*"]
# Browsers reject credentialed responses that use the "*" wildcard origin, so
# credentials are only enabled for an explicit origin list.
allow_credentials = cors_origins != ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
# Create the storage directories before StaticFiles is pointed at them below
# (StaticFiles checks by default that its directory exists when constructed).
for _storage_dir in (settings.UPLOAD_DIR, settings.OUTPUT_DIR):
    _storage_dir.mkdir(parents=True, exist_ok=True)
(settings.UPLOAD_DIR / "materials").mkdir(exist_ok=True)
settings.ASSETS_DIR.mkdir(parents=True, exist_ok=True)

# Serve generated outputs, uploads and assets as static files.
# NOTE(review): these mounts have no authentication in this module, so any file
# in these directories is readable by anyone who knows its URL — confirm intent.
for _url_path, _storage_dir, _mount_name in (
    ("/outputs", settings.OUTPUT_DIR, "outputs"),
    ("/uploads", settings.UPLOAD_DIR, "uploads"),
    ("/assets", settings.ASSETS_DIR, "assets"),
):
    app.mount(_url_path, StaticFiles(directory=str(_storage_dir)), name=_mount_name)
# Drop the loop variables so they do not linger as module attributes.
del _url_path, _storage_dir, _mount_name
|
||
|
||
# Register the API routers, in this exact order (route matching follows
# registration order). An empty options dict means the router is included
# without a prefix or tags here; per the original notes, auth, admin and ai
# routers are served under /api/auth, /api/admin and /api/ai respectively.
_ROUTER_REGISTRY = (
    (materials_router, {"prefix": "/api/materials", "tags": ["Materials"]}),
    (videos_router, {"prefix": "/api/videos", "tags": ["Videos"]}),
    (publish_router, {"prefix": "/api/publish", "tags": ["Publish"]}),
    (login_helper_router, {"prefix": "/api", "tags": ["LoginHelper"]}),
    (auth_router, {}),  # /api/auth
    (admin_router, {}),  # /api/admin
    (ref_audios_router, {"prefix": "/api/ref-audios", "tags": ["RefAudios"]}),
    (ai_router, {}),  # /api/ai
    (tools_router, {"prefix": "/api/tools", "tags": ["Tools"]}),
    (assets_router, {"prefix": "/api/assets", "tags": ["Assets"]}),
)
for _router, _options in _ROUTER_REGISTRY:
    app.include_router(_router, **_options)
# Drop the helper names so they do not linger as module attributes.
del _ROUTER_REGISTRY, _router, _options
|
||
|
||
|
||
@app.on_event("startup")
async def init_admin():
    """Create the bootstrap admin account when the service starts.

    Reads ``settings.ADMIN_PHONE`` and ``settings.ADMIN_PASSWORD``:

    * if either is empty, the step is skipped with a warning;
    * if a user with that phone already exists, nothing is changed;
    * otherwise an active, non-expiring user with username ``"Admin"`` and role
      ``"admin"`` is created with the hashed password.

    Any failure is logged with its traceback and swallowed, so a broken admin
    bootstrap does not prevent the service from starting.
    """
    admin_phone = settings.ADMIN_PHONE
    admin_password = settings.ADMIN_PASSWORD

    if not admin_phone or not admin_password:
        logger.warning("未配置 ADMIN_PHONE 和 ADMIN_PASSWORD,跳过管理员初始化")
        return

    try:
        from app.core.security import get_password_hash
        from app.repositories.users import create_user, user_exists_by_phone

        # NOTE(review): these repository calls are synchronous and run on the
        # event loop. Check-then-create is also not atomic: if several worker
        # processes each run this hook, concurrent startups may race and the
        # losing create presumably raises into the except branch below.
        if user_exists_by_phone(admin_phone):
            logger.info(f"管理员账号已存在: {admin_phone}")
            return

        create_user({
            "phone": admin_phone,
            "password_hash": get_password_hash(admin_password),
            "username": "Admin",
            "role": "admin",
            "is_active": True,
            "expires_at": None,  # never expires
        })

        logger.success(f"管理员账号已创建: {admin_phone}")
    except Exception as e:
        # Deliberate best-effort: do not crash startup. logger.exception logs at
        # ERROR level and attaches the traceback so the cause is diagnosable.
        logger.exception(f"初始化管理员失败: {e}")
|
||
|
||
|
||
@app.get("/health")
def health():
    """Health-check endpoint.

    Always returns ``{"status": "ok"}`` without touching any other service.
    """
    payload = {"status": "ok"}
    return payload
|