133 lines
4.0 KiB
Python
133 lines
4.0 KiB
Python
"""
|
|
依赖注入模块:认证和用户获取
|
|
"""
|
|
from typing import Optional, Any, Dict, cast
|
|
from fastapi import Request, HTTPException, Depends, status
|
|
from app.core.security import decode_access_token
|
|
from app.repositories.sessions import get_session, delete_sessions
|
|
from app.repositories.users import get_user_by_id, deactivate_user_if_expired
|
|
from loguru import logger
|
|
|
|
|
|
async def get_token_from_cookie(request: Request) -> Optional[str]:
    """Return the value of the ``access_token`` cookie, or ``None`` if it is absent."""
    cookies = request.cookies
    return cookies.get("access_token")
|
|
|
|
|
|
async def get_current_user_optional(
    request: Request
) -> Optional[Dict[str, Any]]:
    """
    Resolve the current user without requiring a login.

    Args:
        request: Incoming request; the token is read from its
            ``access_token`` cookie.

    Returns:
        The value returned by ``get_user_by_id`` for the token's user, or
        ``None`` when any of the following holds:

        * the cookie is missing or the token cannot be decoded;
        * ``get_session`` finds no session for the token's
          ``user_id`` / ``session_token`` pair;
        * ``deactivate_user_if_expired`` returns a truthy value, or the user
          record's ``is_active`` field is falsy (in both cases the user's
          sessions are deleted first);
        * an unexpected exception occurs during the lookups (it is logged
          with its traceback and swallowed on purpose, so optional-auth
          endpoints keep working).
    """
    token = await get_token_from_cookie(request)
    if not token:
        return None

    token_data = decode_access_token(token)
    if not token_data:
        return None

    # Check that the session token is still valid (single-device login check).
    try:
        session = get_session(token_data.user_id, token_data.session_token)
        if not session:
            logger.warning(f"Session token 无效: user_id={token_data.user_id}")
            return None

        user = cast(Optional[Dict[str, Any]], get_user_by_id(token_data.user_id))

        # Short-circuit order matters: the expiry check runs first, and the
        # is_active flag is only consulted when that check returned falsy.
        if user and (deactivate_user_if_expired(user) or not user.get("is_active")):
            delete_sessions(token_data.user_id)
            return None

        return user
    except Exception as e:
        # logger.exception records the traceback as well as the message;
        # plain logger.error with str(e) made lookup failures hard to diagnose.
        logger.exception(f"获取用户信息失败: {e}")
        return None
|
|
|
|
|
|
async def get_current_user(
    request: Request
) -> Dict[str, Any]:
    """
    Resolve the current user; the caller must be logged in.

    Args:
        request: Incoming request; the token is read from its
            ``access_token`` cookie.

    Returns:
        The user record returned by ``get_user_by_id``.

    Raises:
        HTTPException 401: No token cookie, the token cannot be decoded,
            or no user exists for the token's ``user_id``.
        HTTPException 403: No matching session was found,
            ``deactivate_user_if_expired`` returned a truthy value, or the
            user's ``is_active`` field is falsy. In the latter two cases the
            user's sessions are deleted before raising.
        HTTPException 500: Any unexpected error during the lookups; the
            original exception is logged with its traceback and chained as
            the cause.
    """
    token = await get_token_from_cookie(request)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录,请先登录"
        )

    token_data = decode_access_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        )

    try:
        session = get_session(token_data.user_id, token_data.session_token)
        if not session:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="会话已失效,请重新登录(可能已在其他设备登录)"
            )

        user = get_user_by_id(token_data.user_id)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在"
            )
        user = cast(Dict[str, Any], user)

        if deactivate_user_if_expired(user):
            delete_sessions(token_data.user_id)
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="会员已到期,请续费"
            )

        if not user.get("is_active"):
            delete_sessions(token_data.user_id)
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号已停用"
            )

        return user
    except HTTPException:
        # Deliberate auth failures raised above must reach the client unchanged.
        raise
    except Exception as e:
        # Keep the traceback in the log and chain the cause, so the generic
        # 500 response can still be traced back to the real failure.
        logger.exception(f"获取用户信息失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器错误"
        ) from e
|
|
|
|
|
|
async def get_current_admin(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Return the current user only if their ``role`` is ``"admin"``.

    Args:
        current_user: User record supplied by the ``get_current_user``
            dependency.

    Returns:
        The same ``current_user`` mapping, unchanged.

    Raises:
        HTTPException 403: The user's ``role`` field is not ``"admin"``.
    """
    is_admin = current_user.get("role") == "admin"
    if is_admin:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="需要管理员权限"
    )
|