178 lines
7.2 KiB
Python
178 lines
7.2 KiB
Python
from fastapi import APIRouter, Depends, HTTPException
|
||
from sqlalchemy.orm import Session
|
||
from app import schemas, models, database
|
||
from passlib.context import CryptContext
|
||
|
||
router = APIRouter()
|
||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||
|
||
def get_db():
|
||
db = database.SessionLocal()
|
||
try:
|
||
yield db
|
||
finally:
|
||
db.close()
|
||
|
||
from typing import List
|
||
|
||
@router.get("/", response_model=List[schemas.UserOut])
|
||
def list_users(db: Session = Depends(get_db)):
|
||
return db.query(models.User).all()
|
||
|
||
|
||
def get_password_hash(password):
|
||
return pwd_context.hash(password)
|
||
|
||
@router.post("/register", response_model=schemas.UserOut)
|
||
def register(user: schemas.UserCreate, db: Session = Depends(get_db)):
|
||
try:
|
||
# 检查用户名是否已存在
|
||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||
if db_user:
|
||
print(f"用户名已存在: {user.username}")
|
||
raise HTTPException(status_code=400, detail="用户名已被注册,请更换用户名")
|
||
|
||
hashed_password = get_password_hash(user.password)
|
||
new_user = models.User(
|
||
username=user.username,
|
||
hashed_password=hashed_password,
|
||
balance=user.balance,
|
||
is_admin=user.is_admin
|
||
)
|
||
db.add(new_user)
|
||
db.commit()
|
||
db.refresh(new_user)
|
||
print(f"用户注册成功: {user.username}, 初始余额: {user.balance}, 管理员权限: {user.is_admin}")
|
||
return new_user
|
||
except HTTPException:
|
||
# 已处理的HTTP异常直接抛出
|
||
raise
|
||
except Exception as e:
|
||
print("注册用户出错:", e)
|
||
raise HTTPException(status_code=500, detail=f"注册失败: {e}")
|
||
|
||
|
||
def verify_password(plain_password, hashed_password):
|
||
return pwd_context.verify(plain_password, hashed_password)
|
||
|
||
from fastapi.security import OAuth2PasswordRequestForm
|
||
from app.utils_jwt import create_access_token, get_current_user
|
||
|
||
@router.post("/login")
|
||
def login(user: schemas.UserCreate, db: Session = Depends(get_db)):
|
||
# 检查用户是否存在
|
||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||
if not db_user:
|
||
print(f"登录失败 - 用户不存在: {user.username}")
|
||
raise HTTPException(status_code=400, detail="用户名或密码错误")
|
||
|
||
# 验证密码
|
||
if not verify_password(user.password, db_user.hashed_password):
|
||
print(f"登录失败 - 密码错误: {user.username}")
|
||
raise HTTPException(status_code=400, detail="用户名或密码错误")
|
||
|
||
# 登录成功,生成token
|
||
token = create_access_token({"sub": str(db_user.id)})
|
||
print(f"用户登录成功: {user.username}, 余额: {db_user.balance}")
|
||
return {"access_token": token, "token_type": "bearer", "user": {"id": db_user.id, "username": db_user.username, "balance": db_user.balance, "is_admin": db_user.is_admin}}
|
||
|
||
# 支持OAuth2标准token获取
|
||
@router.post("/token")
|
||
def token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||
db_user = db.query(models.User).filter(models.User.username == form_data.username).first()
|
||
if not db_user or not verify_password(form_data.password, db_user.hashed_password):
|
||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
||
token = create_access_token({"sub": str(db_user.id)})
|
||
return {"access_token": token, "token_type": "bearer"}
|
||
|
||
|
||
# 管理员创建用户(包括设置余额和权限)
|
||
@router.post("/create", response_model=schemas.UserOut)
|
||
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||
# 检查当前用户是否为管理员
|
||
if not getattr(current_user, "is_admin", False):
|
||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以创建用户")
|
||
|
||
try:
|
||
# 检查用户名是否已存在
|
||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||
if db_user:
|
||
raise HTTPException(status_code=400, detail="用户名已被注册,请更换用户名")
|
||
|
||
hashed_password = get_password_hash(user.password)
|
||
new_user = models.User(
|
||
username=user.username,
|
||
hashed_password=hashed_password,
|
||
balance=user.balance,
|
||
is_admin=user.is_admin
|
||
)
|
||
db.add(new_user)
|
||
db.commit()
|
||
db.refresh(new_user)
|
||
print(f"管理员创建用户成功: {user.username}, 初始余额: {user.balance}, 管理员权限: {user.is_admin}")
|
||
return new_user
|
||
except HTTPException:
|
||
# 已处理的HTTP异常直接抛出
|
||
raise
|
||
except Exception as e:
|
||
print("创建用户出错:", e)
|
||
raise HTTPException(status_code=500, detail=f"创建用户失败: {e}")
|
||
|
||
|
||
# 管理员更新用户信息
|
||
@router.put("/update/{user_id}", response_model=dict)
|
||
def update_user(user_id: int, user_update: schemas.UserUpdate, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||
# 检查当前用户是否为管理员
|
||
if not getattr(current_user, "is_admin", False):
|
||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以更新用户信息")
|
||
|
||
try:
|
||
# 查找用户
|
||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||
if not db_user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
# 更新用户信息
|
||
update_data = user_update.dict(exclude_unset=True)
|
||
for field, value in update_data.items():
|
||
# 如果字段值不为Null,才更新
|
||
if value is not None:
|
||
setattr(db_user, field, value)
|
||
|
||
db.commit()
|
||
print(f"管理员更新用户成功: {db_user.username}, 余额: {db_user.balance}, 管理员权限: {db_user.is_admin}")
|
||
return {"msg": "更新成功", "id": db_user.id}
|
||
except HTTPException:
|
||
# 已处理的HTTP异常直接抛出
|
||
raise
|
||
except Exception as e:
|
||
print("更新用户出错:", e)
|
||
raise HTTPException(status_code=500, detail=f"更新用户失败: {e}")
|
||
|
||
|
||
# 管理员删除用户
|
||
@router.delete("/delete/{user_id}", response_model=dict)
|
||
def delete_user(user_id: int, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||
# 检查当前用户是否为管理员
|
||
if not getattr(current_user, "is_admin", False):
|
||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以删除用户")
|
||
|
||
try:
|
||
# 查找用户
|
||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||
if not db_user:
|
||
raise HTTPException(status_code=404, detail="用户不存在")
|
||
|
||
# 删除用户
|
||
username = db_user.username # 保存用户名以便记录
|
||
db.delete(db_user)
|
||
db.commit()
|
||
print(f"管理员删除用户成功: {username}")
|
||
return {"msg": "删除成功"}
|
||
except HTTPException:
|
||
# 已处理的HTTP异常直接抛出
|
||
raise
|
||
except Exception as e:
|
||
print("删除用户出错:", e)
|
||
raise HTTPException(status_code=500, detail=f"删除用户失败: {e}")
|