142 lines
4.3 KiB
Python
142 lines
4.3 KiB
Python
"""
|
|
依赖注入模块:认证和用户获取
|
|
"""
|
|
from typing import Optional
|
|
from fastapi import Request, HTTPException, Depends, status
|
|
from app.core.security import decode_access_token, TokenData
|
|
from app.core.supabase import get_supabase
|
|
from loguru import logger
|
|
|
|
|
|
async def get_token_from_cookie(request: Request) -> Optional[str]:
    """Read the access token from the request's ``access_token`` cookie.

    Returns:
        The raw cookie value, or ``None`` when the request has no such cookie.
    """
    cookie_jar = request.cookies
    token = cookie_jar.get("access_token")
    return token
|
|
|
|
|
|
async def get_current_user_optional(
    request: Request
) -> Optional[dict]:
    """
    Return the current user, or ``None`` when the request is not validly authenticated.

    This is the non-raising counterpart of ``get_current_user``. It delegates to
    that dependency, so both apply exactly the same checks:
    - token decoding
    - single-device session validation against ``user_sessions``
    - the ``users`` lookup
    - the ``expires_at`` authorization check

    Every ``HTTPException`` raised by ``get_current_user`` is converted into ``None``.
    That includes the 500 it raises after logging an unexpected error.
    (Previously this function skipped the expiry check, so accounts with an
    expired authorization were still treated as logged in here.)

    Returns:
        The ``users`` row as a dict, or ``None``.
    """
    try:
        return await get_current_user(request)
    except HTTPException as exc:
        # A 403 from get_current_user means the session was invalidated (e.g. a
        # login on another device) or the authorization expired. Keep a warning
        # for those cases, as the original session check did. Anonymous
        # requests (401) stay silent.
        if exc.status_code == status.HTTP_403_FORBIDDEN:
            logger.warning(f"可选认证未通过: {exc.detail}")
        return None
|
|
|
|
|
|
def _authorization_expired(raw_expires_at) -> bool:
    """Return True if ``raw_expires_at`` (a ``users.expires_at`` value) lies in the past.

    Accepts a ``datetime`` or an ISO-8601 string.
    - A ``Z`` suffix is rewritten to ``+00:00``.
    - Fractional seconds are normalized to six digits, because
      ``datetime.fromisoformat`` before Python 3.11 only accepts 3 or 6 digits,
      while PostgreSQL trims trailing zeros (e.g. ``"...:56.12+00:00"``).
    - Naive values are treated as UTC, so the comparison with an aware "now" does
      not raise ``TypeError``.

    Raises:
        ValueError: if the value is not a parseable ISO-8601 timestamp.
    """
    import re
    from datetime import datetime, timezone

    if isinstance(raw_expires_at, datetime):
        expires_at = raw_expires_at
    else:
        text = str(raw_expires_at).strip().replace("Z", "+00:00")
        text = re.sub(
            r"\.(\d+)",
            lambda m: "." + m.group(1)[:6].ljust(6, "0"),
            text,
            count=1,
        )
        expires_at = datetime.fromisoformat(text)

    if expires_at.tzinfo is None:
        # NOTE(review): assumes naive timestamps are stored in UTC; confirm the column type.
        expires_at = expires_at.replace(tzinfo=timezone.utc)

    return datetime.now(timezone.utc) > expires_at


async def get_current_user(
    request: Request
) -> dict:
    """
    Return the currently authenticated user (login required).

    Steps:
    1. Read the ``access_token`` cookie and decode it.
    2. Confirm that the (user_id, session_token) pair still exists in
       ``user_sessions`` (single-device login).
    3. Load the ``users`` row.
    4. Reject the user if their ``expires_at`` is in the past.

    Raises:
        HTTPException 401: no cookie, invalid/expired token, or the user row does not exist
        HTTPException 403: session invalidated (possibly logged in elsewhere) or authorization expired
        HTTPException 500: any unexpected error while querying Supabase or parsing the user row
    """
    token = await get_token_from_cookie(request)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录,请先登录"
        )

    token_data = decode_access_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效或已过期"
        )

    # NOTE(review): `.execute()` is called without `await`, so the Supabase client
    # is synchronous and these queries block the event loop inside this `async def`.
    # Consider fastapi.concurrency.run_in_threadpool or an async client.
    try:
        supabase = get_supabase()

        # Validate the session token (single-device login).
        session_result = supabase.table("user_sessions").select("*").eq(
            "user_id", token_data.user_id
        ).eq(
            "session_token", token_data.session_token
        ).execute()

        if not session_result.data:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="会话已失效,请重新登录(可能已在其他设备登录)"
            )

        # Fetch the user row. `.single()` makes PostgREST answer with an error
        # (not empty data) when no row matches, which surfaced as a 500 and made
        # the "user not found" branch unreachable. `.limit(1)` returns an empty
        # list instead.
        user_rows = supabase.table("users").select("*").eq(
            "id", token_data.user_id
        ).limit(1).execute().data
        user = user_rows[0] if user_rows else None

        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在"
            )

        # Check whether the account's authorization has expired.
        if user.get("expires_at") and _authorization_expired(user["expires_at"]):
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="授权已过期,请联系管理员续期"
            )

        return user

    except HTTPException:
        # Deliberate auth failures propagate unchanged.
        raise
    except Exception as e:
        # Log with traceback so unexpected Supabase/parsing failures are debuggable.
        logger.exception(f"获取用户信息失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="服务器错误"
        ) from e
|
|
|
|
|
|
async def get_current_admin(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Return the authenticated user only if their ``role`` is ``"admin"``.

    Builds on ``get_current_user``, so all of its 401/403/500 errors apply first.

    Raises:
        HTTPException 403: the authenticated user is not an administrator
    """
    is_admin = current_user.get("role") == "admin"
    if is_admin:
        return current_user

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="需要管理员权限"
    )
|