278 lines
8.2 KiB
Python
278 lines
8.2 KiB
Python
"""
|
||
认证 API:注册、登录、登出、修改密码
|
||
"""
|
||
from fastapi import APIRouter, HTTPException, Response, status, Request, Depends
|
||
from pydantic import BaseModel, field_validator
|
||
from app.core.security import (
|
||
get_password_hash,
|
||
verify_password,
|
||
create_access_token,
|
||
generate_session_token,
|
||
set_auth_cookie,
|
||
clear_auth_cookie,
|
||
decode_access_token
|
||
)
|
||
from app.repositories.sessions import create_session, delete_sessions
|
||
from app.repositories.users import (
|
||
create_user,
|
||
get_user_by_id,
|
||
get_user_by_phone,
|
||
user_exists_by_phone,
|
||
update_user,
|
||
deactivate_user_if_expired,
|
||
)
|
||
from app.core.deps import get_current_user
|
||
from app.core.response import success_response
|
||
from loguru import logger
|
||
from typing import Optional, Any, cast
|
||
import re
|
||
|
||
# Every endpoint in this module is mounted under /api/auth and grouped under
# the "认证" (authentication) tag in the generated OpenAPI docs.
router = APIRouter(tags=["认证"], prefix="/api/auth")
|
||
|
||
|
||
class RegisterRequest(BaseModel):
    """Request body for ``POST /api/auth/register``.

    Validation performed before the endpoint runs:

    * ``phone`` must be exactly 11 ASCII digits.
    * ``password`` must be at least 6 characters, the same minimum that
      ``ChangePasswordRequest`` enforces for new passwords.
    * ``username`` is optional; the register endpoint derives a default
      display name when it is omitted or empty.
    """

    phone: str
    password: str
    username: Optional[str] = None

    @field_validator('phone')
    @classmethod
    def validate_phone(cls, v):
        """Reject anything that is not exactly 11 ASCII digits."""
        # The previous pattern r'^\d{11}$' also accepted a trailing "\n"
        # ("$" matches just before a final newline) and non-ASCII digits
        # ("\d" is Unicode-aware for str patterns, e.g. full-width digits).
        # That let one number be registered again under a different
        # spelling, slipping past the duplicate-phone check.
        if not re.fullmatch(r'[0-9]{11}', v):
            raise ValueError('手机号必须是11位数字')
        return v

    @field_validator('password')
    @classmethod
    def validate_password(cls, v):
        """Enforce the same minimum length used when changing a password."""
        # Without this, registration accepted empty or 1-character passwords
        # even though a password change requires at least 6 characters.
        if len(v) < 6:
            raise ValueError('密码长度至少6位')
        return v
|
||
|
||
|
||
class LoginRequest(BaseModel):
    """Request body for ``POST /api/auth/login``.

    ``phone`` must be exactly 11 ASCII digits. ``password`` is not validated
    here; the login endpoint only compares it with the stored hash.
    """

    phone: str
    password: str

    @field_validator('phone')
    @classmethod
    def validate_phone(cls, v):
        """Reject anything that is not exactly 11 ASCII digits."""
        # r'^\d{11}$' also accepted a trailing "\n" and non-ASCII Unicode
        # digits; fullmatch against an explicit ASCII class accepts only
        # the literal 11-digit form.
        if not re.fullmatch(r'[0-9]{11}', v):
            raise ValueError('手机号必须是11位数字')
        return v
|
||
|
||
|
||
class ChangePasswordRequest(BaseModel):
    """Request body for ``POST /api/auth/change-password``.

    ``old_password`` is checked against the stored hash by the endpoint;
    ``new_password`` must be at least 6 characters long.
    """

    old_password: str
    new_password: str

    @field_validator('new_password')
    @classmethod
    def validate_new_password(cls, v):
        """Pass the new password through unchanged once it meets the minimum length."""
        if len(v) >= 6:
            return v
        raise ValueError('新密码长度至少6位')
|
||
|
||
|
||
class UserResponse(BaseModel):
    """Public view of a user account, returned by the login and ``/me`` endpoints.

    The field order below is the key order of ``model_dump()`` and therefore
    of the JSON payload clients receive. It does not include ``password_hash``.
    """

    id: str  # copied from user["id"]
    phone: str
    username: Optional[str]  # key is required, but the value may be None
    role: str  # registration stores "pending" here for new accounts
    is_active: bool  # registration creates accounts with False; login rejects inactive accounts
    expires_at: Optional[str] = None  # NOTE(review): assumes the repository returns a string or None; a datetime would presumably fail str validation — confirm
|
||
|
||
|
||
@router.post("/register")
async def register(request: RegisterRequest):
    """Register a new user account.

    The account is created in a pending state (``role="pending"``,
    ``is_active=False``) and cannot log in until an administrator activates it.

    Args:
        request: Validated registration payload.

    Returns:
        A ``success_response`` telling the user to wait for activation.

    Raises:
        HTTPException: 400 if the phone number is already registered; 500 on
            any unexpected failure (details go to the log, not the client).
    """
    try:
        # NOTE(review): check-then-insert is not atomic, so two concurrent
        # requests for the same phone can both pass this check. If the users
        # table has a unique constraint on phone, the loser surfaces as a 500
        # below — confirm against the schema.
        if user_exists_by_phone(request.phone):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="该手机号已注册"
            )

        # Create the user; only the hash of the password is stored.
        password_hash = get_password_hash(request.password)

        create_user({
            "phone": request.phone,
            "password_hash": password_hash,
            # Default display name: "用户" followed by the phone's last 4 digits.
            "username": request.username or f"用户{request.phone[-4:]}",
            "role": "pending",
            "is_active": False
        })

        logger.info(f"新用户注册: {request.phone}")

        return success_response(message="注册成功,请等待管理员审核激活")
    except HTTPException:
        # Deliberate client-facing errors pass through untouched.
        raise
    except Exception as e:
        # logger.exception also records the traceback; logging only the
        # message made unexpected failures hard to diagnose.
        logger.exception(f"注册失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="注册失败,请稍后重试"
        ) from e
|
||
|
||
|
||
@router.post("/login")
async def login(request: LoginRequest, response: Response):
    """Log a user in and bind the account to this device.

    - verify the password
    - reject an expired membership (the account is deactivated and its
      sessions removed) or an account that has not been activated yet
    - "last login wins" single-device login: every existing session of the
      user is deleted before the new one is created

    Args:
        request: Validated phone/password payload.
        response: Receives the HttpOnly auth cookie.

    Returns:
        A ``success_response`` whose ``data["user"]`` is a dumped ``UserResponse``.

    Raises:
        HTTPException: 401 for an unknown phone or a wrong password (same
            message for both); 403 for an expired or not-yet-activated
            account; 500 on unexpected failures.
    """
    try:
        user = cast(dict[str, Any], get_user_by_phone(request.phone) or {})
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="手机号或密码错误"
            )

        # Verify the password; the error message deliberately matches the
        # unknown-phone case above.
        if not verify_password(request.password, user["password_hash"]):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="手机号或密码错误"
            )

        # Deactivate the account automatically once its authorization has
        # expired, and drop any sessions it still holds.
        if deactivate_user_if_expired(user):
            delete_sessions(user["id"])
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="会员已到期,请续费"
            )

        # Accounts stay inactive until an administrator approves them.
        if not user["is_active"]:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号未激活,请等待管理员审核"
            )

        # New session token ("last login wins"): delete the old sessions,
        # then insert the new one.
        session_token = generate_session_token()
        delete_sessions(user["id"])
        create_session(user["id"], session_token, None)

        # The JWT carries the session token, which ties it to this session.
        token = create_access_token(user["id"], session_token)

        # Deliver the token as an HttpOnly cookie.
        set_auth_cookie(response, token)

        logger.info(f"用户登录: {request.phone}")

        return success_response(
            data={
                "user": UserResponse(
                    id=user["id"],
                    phone=user["phone"],
                    username=user.get("username"),
                    role=user["role"],
                    is_active=user["is_active"],
                    expires_at=user.get("expires_at")
                ).model_dump()
            },
            message="登录成功",
        )
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the log; the client only sees a generic 500.
        logger.exception(f"登录失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败,请稍后重试"
        ) from e
|
||
|
||
|
||
@router.post("/logout")
async def logout(response: Response):
    """Log the caller out by clearing the auth cookie on ``response``.

    NOTE(review): only the cookie is cleared. The server-side session row
    created at login is left in place, so a copied JWT presumably keeps
    working until it expires or a later login / password change rotates the
    session. Consider deleting the session here as well, matching it against
    the token's own session so that a stale cookie cannot end a newer login.
    """
    clear_auth_cookie(response)
    farewell = success_response(message="已登出")
    return farewell
|
||
|
||
|
||
@router.post("/change-password")
async def change_password(
    request: ChangePasswordRequest,
    req: Request,
    response: Response,
    current_user: dict = Depends(get_current_user),
):
    """Change the logged-in user's password and rotate their session.

    - authenticate the caller via ``get_current_user`` (the same dependency
      ``/me`` uses)
    - verify ``old_password`` against the stored hash
    - store the hash of ``new_password``
    - delete every existing session and issue a new session token plus JWT
      cookie, so tokens tied to the old session stop matching

    Args:
        request: Validated old/new password payload.
        req: Raw request. Kept for signature compatibility; authentication is
            now done by the ``current_user`` dependency.
        response: Receives the refreshed auth cookie.
        current_user: Authenticated user dict resolved by ``get_current_user``.

    Raises:
        HTTPException: whatever ``get_current_user`` raises for an
            unauthenticated caller; 401 if the user no longer exists; 400 if
            ``old_password`` is wrong; 500 on unexpected failures.
    """
    # Authentication used to decode the access_token cookie by hand with
    # decode_access_token; nothing here compared the token's session against
    # the sessions table. A token from a device that was kicked out by a
    # later login could therefore still change the password and mint itself
    # a fresh session, logging out the legitimate device. The shared
    # dependency presumably performs the session check the single-device
    # design relies on — confirm in app.core.deps.
    try:
        # Re-read the user so password_hash is available.
        # NOTE(review): the dict from get_current_user may not include it — confirm.
        user = cast(dict[str, Any], get_user_by_id(current_user["id"]) or {})
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="用户不存在"
            )

        if not verify_password(request.old_password, user["password_hash"]):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="当前密码错误"
            )

        new_password_hash = get_password_hash(request.new_password)
        update_user(user["id"], {"password_hash": new_password_hash})

        # Rotate the session so previously issued tokens no longer match.
        new_session_token = generate_session_token()
        delete_sessions(user["id"])
        create_session(user["id"], new_session_token, None)

        # Give the current client a fresh JWT bound to the new session.
        new_token = create_access_token(user["id"], new_session_token)
        set_auth_cookie(response, new_token)

        logger.info(f"用户修改密码: {user['phone']}")

        return success_response(message="密码修改成功")
    except HTTPException:
        raise
    except Exception as e:
        # Keep the traceback in the log; the client only sees a generic 500.
        logger.exception(f"修改密码失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="修改密码失败,请稍后重试"
        ) from e
|
||
|
||
|
||
@router.get("/me")
async def get_me(user: dict = Depends(get_current_user)):
    """Return the profile of the currently authenticated user.

    The caller is resolved by ``get_current_user``; only the fields declared
    on ``UserResponse`` are exposed.
    """
    profile = UserResponse(
        id=user["id"],
        phone=user["phone"],
        username=user.get("username"),
        role=user["role"],
        is_active=user["is_active"],
        expires_at=user.get("expires_at"),
    )
    payload = profile.model_dump()
    return success_response(payload)
|