Init: 导入源码
This commit is contained in:
29
backend/README.md
Normal file
29
backend/README.md
Normal file
@@ -0,0 +1,29 @@
|
||||
# Backend - FastAPI
|
||||
|
||||
## 简介
|
||||
本目录为后端服务,基于 FastAPI 框架,负责用户、余额、AI 应用等 API。
|
||||
|
||||
## 运行方式
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
python -m venv venv
|
||||
source venv/bin/activate # Windows 下用 venv\Scripts\activate
|
||||
pip install -r requirements.txt
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
## 目录结构
|
||||
```
|
||||
backend/
|
||||
├── app/
|
||||
│ ├── main.py
|
||||
│ ├── models.py
|
||||
│ ├── schemas.py
|
||||
│ ├── database.py
|
||||
│ └── routers/
|
||||
│ ├── users.py
|
||||
│ └── balance.py
|
||||
├── requirements.txt
|
||||
└── README.md
|
||||
```
|
||||
BIN
backend/aiplatform.db
Normal file
BIN
backend/aiplatform.db
Normal file
Binary file not shown.
18
backend/app/create_admin.py
Normal file
18
backend/app/create_admin.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# 文件已删除,admin 创建脚本不再需要。
|
||||
from passlib.context import CryptContext
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
db = database.SessionLocal()
|
||||
|
||||
# 删除已有 admin
|
||||
admin = db.query(models.User).filter(models.User.username == "admin").first()
|
||||
if admin:
|
||||
db.delete(admin)
|
||||
db.commit()
|
||||
|
||||
hashed_password = pwd_context.hash("admin123")
|
||||
admin_user = models.User(username="admin", hashed_password=hashed_password, is_admin=True, is_active=True, balance=100)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.close()
|
||||
print("admin 管理员已插入!")
|
||||
17
backend/app/create_history_table.py
Normal file
17
backend/app/create_history_table.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import sqlite3
|
||||
|
||||
conn = sqlite3.connect("app.db")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS history (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
type VARCHAR,
|
||||
amount FLOAT,
|
||||
desc VARCHAR,
|
||||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print("history 表创建成功!")
|
||||
12
backend/app/database.py
Normal file
12
backend/app/database.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# 统一使用一个数据库文件
|
||||
import os
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{os.path.join(BASE_DIR, '../aiplatform.db')}"
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base = declarative_base()
|
||||
216
backend/app/main.py
Normal file
216
backend/app/main.py
Normal file
@@ -0,0 +1,216 @@
|
||||
from fastapi import FastAPI, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from app.routers import users, balance, apps, history, orders, twitter, twitter_post, news_stock, ai_chatbot # 统一导入所有路由模块
|
||||
from app import models, database
|
||||
from passlib.context import CryptContext
|
||||
from sqlalchemy.orm import Session
|
||||
import requests
|
||||
import logging
|
||||
from openai import OpenAI
|
||||
|
||||
app = FastAPI(title="AI Platform API")
|
||||
|
||||
# 自动建表,确保所有表都存在
|
||||
models.Base.metadata.create_all(bind=database.engine)
|
||||
|
||||
# 加密工具
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# 初始化数据
|
||||
@app.on_event("startup")
|
||||
async def init_db_data():
|
||||
"""在应用启动时初始化必要的数据"""
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
# 1. 创建admin用户(如果不存在)
|
||||
admin_user = db.query(models.User).filter(models.User.username == "admin").first()
|
||||
if not admin_user:
|
||||
print("创建admin用户...")
|
||||
hashed_password = pwd_context.hash("admin123")
|
||||
admin_user = models.User(
|
||||
username="admin",
|
||||
hashed_password=hashed_password,
|
||||
is_admin=True,
|
||||
is_active=True,
|
||||
balance=100
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
print("admin用户创建成功")
|
||||
|
||||
# 2. 创建初始应用(如果应用表为空)
|
||||
app_count = db.query(models.App).count()
|
||||
if app_count == 0:
|
||||
print("创建初始应用...")
|
||||
default_apps = [
|
||||
models.App(name="Twitter推文摘要", desc="输入Twitter用户名,获取最近推文摘要", price=12, status="上架"),
|
||||
models.App(name="Twitter自动发推", desc="输入Twitter用户名,获取摘要并发送到Twitter", price=15, status="上架"),
|
||||
models.App(name="热点新闻选股", desc="分析热点新闻对股票的影响,提供选股建议", price=20, status="上架") # 添加热点新闻选股应用
|
||||
]
|
||||
db.add_all(default_apps)
|
||||
db.commit()
|
||||
print("初始应用创建成功")
|
||||
|
||||
# 3. 创建示例订单数据(如果订单表为空)
|
||||
order_count = db.query(models.Order).count()
|
||||
if order_count == 0:
|
||||
print("创建示例订单数据...")
|
||||
|
||||
# 获取用户和应用
|
||||
users = db.query(models.User).all()
|
||||
apps = db.query(models.App).all()
|
||||
|
||||
if users and apps:
|
||||
# 创建一些示例订单
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
default_orders = [
|
||||
models.Order(
|
||||
user_id=users[0].id,
|
||||
app_id=apps[0].id,
|
||||
type="应用调用",
|
||||
amount=apps[0].price,
|
||||
description="使用Twitter推文摘要服务",
|
||||
status="已完成",
|
||||
created_at=datetime.utcnow() - timedelta(days=5)
|
||||
),
|
||||
models.Order(
|
||||
user_id=users[0].id,
|
||||
app_id=apps[1].id,
|
||||
type="应用调用",
|
||||
amount=apps[1].price,
|
||||
description="使用Twitter自动发推服务",
|
||||
status="已完成",
|
||||
created_at=datetime.utcnow() - timedelta(days=2)
|
||||
)
|
||||
]
|
||||
|
||||
db.add_all(default_orders)
|
||||
db.commit()
|
||||
print("示例订单创建成功")
|
||||
|
||||
# 4. 创建示例历史记录数据(如果历史记录表为空)
|
||||
history_count = db.query(models.History).count()
|
||||
if history_count == 0:
|
||||
print("创建示例历史记录数据...")
|
||||
|
||||
# 获取用户
|
||||
users = db.query(models.User).all()
|
||||
|
||||
if users:
|
||||
# 创建一些示例历史记录
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
default_history = [
|
||||
models.History(
|
||||
user_id=users[0].id,
|
||||
type="recharge",
|
||||
amount=100,
|
||||
desc="账户充值",
|
||||
created_at=datetime.utcnow() - timedelta(days=10)
|
||||
),
|
||||
models.History(
|
||||
user_id=users[0].id,
|
||||
type="consume",
|
||||
amount=-12,
|
||||
desc="使用Twitter推文摘要服务",
|
||||
created_at=datetime.utcnow() - timedelta(days=5)
|
||||
),
|
||||
models.History(
|
||||
user_id=users[0].id,
|
||||
type="consume",
|
||||
amount=-15,
|
||||
desc="使用Twitter自动发推服务",
|
||||
created_at=datetime.utcnow() - timedelta(days=2)
|
||||
)
|
||||
]
|
||||
|
||||
db.add_all(default_history)
|
||||
db.commit()
|
||||
print("示例历史记录创建成功")
|
||||
except Exception as e:
|
||||
print(f"初始化数据时出错: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 配置CORS
|
||||
# 使用统一的跨域配置
|
||||
allowed_origins = [
|
||||
# 生产环境
|
||||
"http://174.129.175.43:3000", # EC2实例1
|
||||
"http://44.206.227.249:3000", # EC2实例2
|
||||
"http://52.91.169.148:3000", # EC2实例3
|
||||
"https://rongye.xyz", # 新域名
|
||||
"http://rongye.xyz", # 新域名(http)
|
||||
|
||||
# 本地开发环境
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://localhost",
|
||||
"http://127.0.0.1",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allowed_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["*"] # 允许前端访问所有响应头
|
||||
)
|
||||
|
||||
# 注册路由 - 不再需要在这里添加/admin前缀,因为已经在router中定义了
|
||||
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(balance.router, prefix="/balance", tags=["balance"])
|
||||
app.include_router(apps.router, prefix="/apps", tags=["apps"])
|
||||
app.include_router(history.router, prefix="/history", tags=["history"])
|
||||
app.include_router(orders.router, prefix="/orders", tags=["orders"])
|
||||
app.include_router(twitter.router, prefix="/twitter", tags=["twitter"])
|
||||
app.include_router(twitter_post.router, prefix="/twitter-post", tags=["twitter_post"])
|
||||
app.include_router(news_stock.router, prefix="/news-stock", tags=["news_stock"]) # 添加热点新闻选股路由
|
||||
app.include_router(ai_chatbot.router, prefix="/ai-chatbot", tags=["ai_chatbot"]) # 添加AI客服路由
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
return {"message": "Welcome to AI Platform API"}
|
||||
|
||||
@app.get("/test-twitter")
|
||||
async def test_twitter():
|
||||
try:
|
||||
# 测试Twitter API连接
|
||||
twitter_api_key = "e3dad005b0e54bdc88c6178a89adec13"
|
||||
twitter_api_url = "https://api.twitterapi.io/twitter/tweet/advanced_search"
|
||||
headers = {"X-API-Key": twitter_api_key}
|
||||
params = {
|
||||
"queryType": "Latest",
|
||||
"query": "from:elonmusk",
|
||||
"count": 5
|
||||
}
|
||||
|
||||
twitter_response = requests.get(twitter_api_url, headers=headers, params=params)
|
||||
|
||||
# 初始化OpenAI客户端
|
||||
openai_client = OpenAI(
|
||||
api_key="sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba",
|
||||
base_url="https://api.deepseek.com"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"twitter_api_status": twitter_response.status_code,
|
||||
"twitter_api_response": twitter_response.json() if twitter_response.status_code == 200 else str(twitter_response.text)[:100],
|
||||
"apis_available": {
|
||||
"twitter_api": True,
|
||||
"openai_api": True
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"type": str(type(e))
|
||||
}
|
||||
53
backend/app/models.py
Normal file
53
backend/app/models.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.database import Base
|
||||
from datetime import datetime
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
username = Column(String, unique=True, index=True)
|
||||
email = Column(String, unique=True, index=True, nullable=True)
|
||||
hashed_password = Column(String)
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_admin = Column(Boolean, default=False)
|
||||
balance = Column(Float, default=0)
|
||||
|
||||
# 添加关系
|
||||
orders = relationship("Order", back_populates="user")
|
||||
|
||||
class App(Base):
|
||||
__tablename__ = "apps"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
desc = Column(String)
|
||||
price = Column(Float)
|
||||
status = Column(String, default="上架") # 上架、下架
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# 添加关系
|
||||
orders = relationship("Order", back_populates="app")
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = "orders"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"))
|
||||
app_id = Column(Integer, ForeignKey("apps.id")) # 添加 app_id 字段
|
||||
type = Column(String, default="应用调用") # 订单类型
|
||||
amount = Column(Float) # 订单金额
|
||||
description = Column(String, nullable=True) # 订单描述
|
||||
status = Column(String) # 订单状态:待支付、已完成、已取消等
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
# 添加关系
|
||||
user = relationship("User", back_populates="orders")
|
||||
app = relationship("App", back_populates="orders")
|
||||
|
||||
class History(Base):
|
||||
__tablename__ = "history"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'))
|
||||
type = Column(String) # 'recharge' or 'consume'
|
||||
amount = Column(Float)
|
||||
desc = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
344
backend/app/routers/admin.py
Normal file
344
backend/app/routers/admin.py
Normal file
@@ -0,0 +1,344 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from app import models, database
|
||||
from pydantic import BaseModel
|
||||
from app.utils_jwt import create_access_token, get_current_user
|
||||
from passlib.context import CryptContext
|
||||
from typing import List, Optional
|
||||
|
||||
# 创建路由器,设置统一的前缀和标签
|
||||
router = APIRouter(
|
||||
prefix="/admin/api",
|
||||
tags=["admin"]
|
||||
)
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# 获取数据库会话
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 管理员登录请求模型
|
||||
class AdminLoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
@router.post("/login")
|
||||
async def admin_login(req: AdminLoginRequest, db: Session = Depends(get_db)):
|
||||
print(f"=== Admin login attempt: {req.username} ===")
|
||||
|
||||
# 特殊处理admin用户,确保它存在且正确设置
|
||||
if req.username == "admin":
|
||||
admin_user = db.query(models.User).filter(models.User.username == "admin").first()
|
||||
if not admin_user:
|
||||
hashed_password = pwd_context.hash("admin123")
|
||||
admin_user = models.User(
|
||||
username="admin",
|
||||
hashed_password=hashed_password,
|
||||
is_admin=True,
|
||||
is_active=True,
|
||||
balance=100
|
||||
)
|
||||
db.add(admin_user)
|
||||
db.commit()
|
||||
db.refresh(admin_user)
|
||||
print("admin用户创建成功")
|
||||
|
||||
user = db.query(models.User).filter(models.User.username == req.username).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="账号或密码错误")
|
||||
|
||||
if not pwd_context.verify(req.password, user.hashed_password):
|
||||
raise HTTPException(status_code=401, detail="账号或密码错误")
|
||||
|
||||
if not user.is_admin:
|
||||
raise HTTPException(status_code=401, detail="该账户没有管理员权限")
|
||||
|
||||
token = create_access_token({"sub": str(user.id), "is_admin": True})
|
||||
print(f"登录成功 - 用户ID: {user.id}, 用户名: {user.username}")
|
||||
return {
|
||||
"token": token,
|
||||
"user": {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"is_admin": True
|
||||
}
|
||||
}
|
||||
|
||||
# 管理员权限依赖
|
||||
async def admin_required(user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
return user
|
||||
|
||||
# 应用管理
|
||||
class AppCreateRequest(BaseModel):
|
||||
name: str
|
||||
desc: str
|
||||
price: float
|
||||
status: str = "上架"
|
||||
|
||||
@router.get("/apps")
|
||||
def get_apps(db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
apps = db.query(models.App).all()
|
||||
return {"code": 0, "msg": "success", "data": [
|
||||
{
|
||||
"id": app.id,
|
||||
"name": app.name,
|
||||
"desc": app.desc,
|
||||
"price": app.price,
|
||||
"status": app.status
|
||||
}
|
||||
for app in apps
|
||||
]}
|
||||
|
||||
@router.post("/apps")
|
||||
def add_app(req: AppCreateRequest, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
if db.query(models.App).filter(models.App.name == req.name).first():
|
||||
return {"code": 1, "msg": "应用已存在"}
|
||||
new_app = models.App(
|
||||
name=req.name, desc=req.desc, price=req.price, status=req.status
|
||||
)
|
||||
db.add(new_app)
|
||||
db.commit()
|
||||
db.refresh(new_app)
|
||||
# 操作日志
|
||||
# db.add(models.Log(action="add_app", detail=f"添加应用 {req.name}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "应用创建成功", "data": {"id": new_app.id}}
|
||||
|
||||
@router.put("/apps/{app_id}")
|
||||
def edit_app(app_id: int, req: AppCreateRequest, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
app = db.query(models.App).filter(models.App.id == app_id).first()
|
||||
if not app:
|
||||
return {"code": 1, "msg": "应用不存在"}
|
||||
app.name = req.name
|
||||
app.desc = req.desc
|
||||
app.price = req.price
|
||||
app.status = req.status
|
||||
db.commit()
|
||||
# db.add(models.Log(action="edit_app", detail=f"修改应用 {app_id}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "应用修改成功"}
|
||||
|
||||
@router.delete("/apps/{app_id}")
|
||||
def delete_app(app_id: int, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
app = db.query(models.App).filter(models.App.id == app_id).first()
|
||||
if not app:
|
||||
return {"code": 1, "msg": "应用不存在"}
|
||||
db.delete(app)
|
||||
db.commit()
|
||||
# db.add(models.Log(action="delete_app", detail=f"删除应用 {app_id}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "应用删除成功"}
|
||||
|
||||
|
||||
# 用户管理
|
||||
@router.get("/users")
|
||||
def get_users(db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
users = db.query(models.User).all()
|
||||
return {"code": 0, "msg": "success", "data": [
|
||||
{
|
||||
"id": u.id,
|
||||
"username": u.username,
|
||||
"email": getattr(u, "email", None),
|
||||
"is_admin": getattr(u, "is_admin", False),
|
||||
"status": "正常" if getattr(u, "is_active", True) else "禁用"
|
||||
}
|
||||
for u in users
|
||||
]}
|
||||
|
||||
@router.post("/users")
|
||||
def add_user(req: UserCreateRequest, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
if db.query(models.User).filter(models.User.username == req.username).first():
|
||||
return {"code": 1, "msg": "用户名已存在"}
|
||||
hashed_password = pwd_context.hash(req.password)
|
||||
new_user = models.User(
|
||||
username=req.username,
|
||||
hashed_password=hashed_password,
|
||||
email=req.email,
|
||||
is_active=True
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
# db.add(models.Log(action="add_user", detail=f"添加用户 {req.username}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "用户创建成功", "data": {"id": new_user.id}}
|
||||
|
||||
@router.put("/users/{user_id}")
|
||||
def edit_user(user_id: int, req: UserCreateRequest, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
return {"code": 1, "msg": "用户不存在"}
|
||||
db_user.username = req.username
|
||||
if req.password:
|
||||
db_user.hashed_password = pwd_context.hash(req.password)
|
||||
if req.email:
|
||||
db_user.email = req.email
|
||||
db.commit()
|
||||
# db.add(models.Log(action="edit_user", detail=f"修改用户 {user_id}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "用户修改成功"}
|
||||
|
||||
@router.delete("/users/{user_id}")
|
||||
def delete_user(user_id: int, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
return {"code": 1, "msg": "用户不存在"}
|
||||
db.delete(db_user)
|
||||
db.commit()
|
||||
# db.add(models.Log(action="delete_user", detail=f"删除用户 {user_id}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "用户删除成功"}
|
||||
|
||||
@router.put("/users/{user_id}/status")
|
||||
def update_user_status(user_id: int, status: str, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
return {"code": 1, "msg": "用户不存在"}
|
||||
db_user.is_active = (status == "正常")
|
||||
db.commit()
|
||||
# db.add(models.Log(action="update_user_status", detail=f"设置用户 {user_id} 状态为 {status}"))
|
||||
# db.commit()
|
||||
return {"code": 0, "msg": "状态更新成功"}
|
||||
|
||||
|
||||
class UserCreateRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
email: Optional[str] = None
|
||||
|
||||
@router.post("/users") # 修改路径,移除重复的admin
|
||||
async def add_user(req: UserCreateRequest, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
if db.query(models.User).filter(models.User.username == req.username).first():
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
|
||||
hashed_password = pwd_context.hash(req.password)
|
||||
new_user = models.User(
|
||||
username=req.username,
|
||||
hashed_password=hashed_password,
|
||||
email=req.email,
|
||||
is_active=True
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
return {"user": {"id": new_user.id, "username": new_user.username}}
|
||||
|
||||
@router.put("/users/{user_id}") # 修改路径,移除重复的admin
|
||||
async def edit_user(
|
||||
user_id: int,
|
||||
req: UserCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_=Depends(admin_required)
|
||||
):
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
db_user.username = req.username
|
||||
if req.password:
|
||||
db_user.hashed_password = pwd_context.hash(req.password)
|
||||
if req.email:
|
||||
db_user.email = req.email
|
||||
|
||||
db.commit()
|
||||
return {"msg": "修改成功"}
|
||||
|
||||
@router.delete("/users/{user_id}") # 修改路径,移除重复的admin
|
||||
async def delete_user(user_id: int, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
db.delete(db_user)
|
||||
db.commit()
|
||||
return {"msg": "删除成功"}
|
||||
|
||||
# 订单管理
|
||||
@router.get("/orders")
|
||||
def get_orders(db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
orders = db.query(models.Order).all()
|
||||
return {"code": 0, "msg": "success", "data": [
|
||||
{
|
||||
"id": order.id,
|
||||
"user_id": order.user.username if order.user else None,
|
||||
"type": order.type,
|
||||
"amount": order.amount,
|
||||
"description": order.description,
|
||||
"created_at": order.created_at.strftime("%Y-%m-%d %H:%M:%S") if order.created_at else None,
|
||||
"status": order.status
|
||||
}
|
||||
for order in orders
|
||||
]}
|
||||
|
||||
@router.get("/orders/{order_id}")
|
||||
def order_detail(order_id: int, db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
order = db.query(models.Order).filter(models.Order.id == order_id).first()
|
||||
if not order:
|
||||
return {"code": 1, "msg": "订单不存在"}
|
||||
return {"code": 0, "msg": "success", "data": {
|
||||
"id": order.id,
|
||||
"user_id": order.user.username if order.user else None,
|
||||
"type": order.type,
|
||||
"amount": order.amount,
|
||||
"description": order.description,
|
||||
"created_at": order.created_at.strftime("%Y-%m-%d %H:%M:%S") if order.created_at else None,
|
||||
"status": order.status
|
||||
}}
|
||||
|
||||
|
||||
# 充值记录
|
||||
@router.get("/finance")
|
||||
def get_finance(db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
finance_records = db.query(models.Finance).all()
|
||||
user_map = {}
|
||||
user_ids = set(record.user_id for record in finance_records)
|
||||
users = db.query(models.User).filter(models.User.id.in_(user_ids)).all()
|
||||
for user in users:
|
||||
user_map[user.id] = user.username
|
||||
return {"code": 0, "msg": "success", "data": [
|
||||
{
|
||||
"id": record.id,
|
||||
"user_id": record.user_id,
|
||||
"username": user_map.get(record.user_id, "未知用户"),
|
||||
"amount": record.amount,
|
||||
"description": record.desc,
|
||||
"created_at": record.created_at.strftime("%Y-%m-%d %H:%M:%S") if record.created_at else None
|
||||
}
|
||||
for record in finance_records
|
||||
]}
|
||||
|
||||
|
||||
# 添加查询历史记录接口(包括充值记录)
|
||||
@router.get("/history")
|
||||
async def get_history(db: Session = Depends(get_db), _=Depends(admin_required)):
|
||||
"""获取所有历史记录,包括充值和消费"""
|
||||
history_records = db.query(models.History).all()
|
||||
|
||||
# 查询用户信息,用于显示用户名
|
||||
user_map = {}
|
||||
user_ids = set(record.user_id for record in history_records)
|
||||
users = db.query(models.User).filter(models.User.id.in_(user_ids)).all()
|
||||
for user in users:
|
||||
user_map[user.id] = user.username
|
||||
|
||||
return {
|
||||
"history": [
|
||||
{
|
||||
"id": record.id,
|
||||
"user_id": record.user_id,
|
||||
"username": user_map.get(record.user_id, "未知用户"),
|
||||
"type": "充值" if record.type == "recharge" else "消费",
|
||||
"amount": record.amount,
|
||||
"description": record.desc,
|
||||
"created_at": record.created_at.strftime("%Y-%m-%d %H:%M:%S") if record.created_at else None
|
||||
}
|
||||
for record in history_records
|
||||
]
|
||||
}
|
||||
641
backend/app/routers/ai_chatbot.py
Normal file
641
backend/app/routers/ai_chatbot.py
Normal file
@@ -0,0 +1,641 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import hashlib
|
||||
import requests
|
||||
import hmac
|
||||
import urllib.parse
|
||||
import http.client
|
||||
from urllib.parse import urlencode
|
||||
import json
|
||||
from openai import OpenAI
|
||||
from pydub import AudioSegment
|
||||
import tempfile
|
||||
import logging
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 阿里云语音服务配置
|
||||
ALIYUN_ACCESS_KEY_ID = "LTAI5t5ZrbKQuuwkmQ1LFCBo"
|
||||
ALIYUN_ACCESS_KEY_SECRET = "2vvspr0HcmmnBFzpXw4iNyLafSgUuN"
|
||||
ALIYUN_APP_KEY = "wlIvC6tOAvQLoQDz"
|
||||
ALIYUN_REGION = "cn-shanghai"
|
||||
ALIYUN_HOST = "nls-gateway-cn-shanghai.aliyuncs.com"
|
||||
|
||||
# DeepSeek配置
|
||||
DEEPSEEK_API_KEY = "sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba"
|
||||
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
|
||||
|
||||
# 音频文件存储目录
|
||||
AUDIO_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "audio")
|
||||
os.makedirs(AUDIO_DIR, exist_ok=True)
|
||||
|
||||
# 全局变量存储token和过期时间
|
||||
token_info = {
|
||||
'token': None,
|
||||
'expire_time': 0
|
||||
}
|
||||
|
||||
# 添加线程池用于异步处理
|
||||
executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
# 简单的内存缓存
|
||||
response_cache = {}
|
||||
cache_expire_time = {}
|
||||
|
||||
def get_signature(secret, text):
|
||||
"""生成签名"""
|
||||
h = hmac.new(secret.encode('utf-8'), text.encode('utf-8'), hashlib.sha1)
|
||||
return base64.b64encode(h.digest()).decode('utf-8')
|
||||
|
||||
def get_aliyun_token(force_refresh=False):
|
||||
"""获取阿里云访问令牌,带缓存和自动刷新"""
|
||||
global token_info
|
||||
|
||||
# 检查token是否有效(提前5分钟刷新)
|
||||
current_time = int(time.time())
|
||||
if not force_refresh and token_info['token'] and token_info['expire_time'] > current_time + 300:
|
||||
logger.info("使用缓存的阿里云Token")
|
||||
return token_info['token']
|
||||
|
||||
logger.info("正在获取阿里云访问令牌..." + (" (强制刷新)" if force_refresh else ""))
|
||||
|
||||
# 构建请求参数
|
||||
params = {
|
||||
"Action": "CreateToken",
|
||||
"Version": "2019-02-28",
|
||||
"Format": "JSON",
|
||||
"AccessKeyId": ALIYUN_ACCESS_KEY_ID,
|
||||
"Timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
"SignatureMethod": "HMAC-SHA1",
|
||||
"SignatureVersion": "1.0",
|
||||
"SignatureNonce": str(int(time.time() * 1000))
|
||||
}
|
||||
|
||||
# 对参数进行排序并生成查询字符串
|
||||
sorted_params = sorted(params.items(), key=lambda x: x[0])
|
||||
query_string = urlencode(dict(sorted_params))
|
||||
|
||||
# 构建待签名字符串
|
||||
string_to_sign = f"POST&%2F&{urllib.parse.quote_plus(query_string)}"
|
||||
|
||||
# 计算签名
|
||||
signature = get_signature(ALIYUN_ACCESS_KEY_SECRET + "&", string_to_sign)
|
||||
params["Signature"] = signature
|
||||
|
||||
try:
|
||||
conn = http.client.HTTPSConnection("nls-meta.cn-shanghai.aliyuncs.com")
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
# 发送请求
|
||||
conn.request("POST", "/", headers=headers, body=urlencode(params))
|
||||
response = conn.getresponse()
|
||||
result = json.loads(response.read().decode('utf-8'))
|
||||
|
||||
if 'Token' in result and 'Id' in result['Token']:
|
||||
# 更新token信息,设置过期时间为55分钟后
|
||||
token_info['token'] = result['Token']['Id']
|
||||
token_info['expire_time'] = int(time.time()) + 3300 # 55分钟
|
||||
logger.info("获取阿里云Token成功")
|
||||
return token_info['token']
|
||||
else:
|
||||
logger.error(f"获取阿里云Token失败: {result}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取阿里云Token时发生异常: {str(e)}")
|
||||
return None
|
||||
finally:
|
||||
if 'conn' in locals():
|
||||
conn.close()
|
||||
|
||||
def aliyun_tts_chinese_text_to_audio(chinese_text, output_audio_path):
|
||||
"""使用阿里云TTS接口进行中文语音合成"""
|
||||
token = get_aliyun_token()
|
||||
if not token:
|
||||
raise Exception("无法获取阿里云访问令牌")
|
||||
|
||||
logger.info(f"正在合成语音: {chinese_text[:50]}...")
|
||||
|
||||
try:
|
||||
conn = http.client.HTTPSConnection(ALIYUN_HOST, timeout=10)
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"X-NLS-Token": token
|
||||
}
|
||||
|
||||
# 使用更自然的音色和参数配置
|
||||
tts_data = {
|
||||
"appkey": ALIYUN_APP_KEY,
|
||||
"text": chinese_text[:100], # 限制长度避免过长
|
||||
"format": "wav",
|
||||
"voice": "aixia", # 使用艾夏音色,符合智能语音交互2.0配置
|
||||
"volume": 50,
|
||||
"speech_rate": 0, # 正常语速
|
||||
"pitch_rate": 0, # 正常语调
|
||||
"sample_rate": 16000
|
||||
}
|
||||
|
||||
conn.request("POST", "/stream/v1/tts",
|
||||
body=json.dumps(tts_data),
|
||||
headers=headers)
|
||||
|
||||
response = conn.getresponse()
|
||||
|
||||
if response.status == 200:
|
||||
# 检查响应类型
|
||||
content_type = response.getheader('Content-Type', '')
|
||||
response_data = response.read()
|
||||
|
||||
if 'audio' in content_type:
|
||||
# 是音频数据
|
||||
with open(output_audio_path, 'wb') as f:
|
||||
f.write(response_data)
|
||||
logger.info(f"语音合成成功,已保存到: {output_audio_path}, 大小: {len(response_data)} bytes")
|
||||
return True
|
||||
else:
|
||||
# 可能是错误信息
|
||||
try:
|
||||
error_info = json.loads(response_data.decode('utf-8'))
|
||||
logger.error(f"阿里云TTS返回错误信息: {json.dumps(error_info, indent=2, ensure_ascii=False)}")
|
||||
except:
|
||||
logger.error(f"阿里云TTS返回非JSON响应: {response_data[:500]}")
|
||||
return False
|
||||
else:
|
||||
error_message = response.read().decode('utf-8')
|
||||
logger.error(f"TTS请求失败: {response.status} {response.reason}, 错误: {error_message}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"TTS合成时发生异常: {str(e)}")
|
||||
return False
|
||||
finally:
|
||||
if 'conn' in locals():
|
||||
conn.close()
|
||||
|
||||
def convert_audio_format(input_file, target_sample_rate=16000, target_channels=1):
|
||||
"""转换音频格式到WAV格式,确保符合阿里云ASR要求"""
|
||||
try:
|
||||
logger.info(f"开始转换音频格式: {input_file}")
|
||||
|
||||
# 检查输入文件是否存在
|
||||
if not os.path.exists(input_file):
|
||||
logger.error(f"输入音频文件不存在: {input_file}")
|
||||
return None
|
||||
|
||||
# 加载音频文件,自动检测格式
|
||||
try:
|
||||
audio = AudioSegment.from_file(input_file)
|
||||
except Exception as e:
|
||||
logger.error(f"无法读取音频文件: {e}")
|
||||
return None
|
||||
|
||||
logger.info(f"原始音频信息 - 时长: {len(audio)}ms, 采样率: {audio.frame_rate}, 声道: {audio.channels}")
|
||||
|
||||
# 检查音频时长(阿里云ASR限制60秒)
|
||||
if len(audio) > 60000: # 60秒 = 60000毫秒
|
||||
logger.warning(f"音频时长({len(audio)/1000:.2f}s)超过60秒限制,将截取前60秒")
|
||||
audio = audio[:60000]
|
||||
|
||||
# 转换为目标格式:16kHz单声道WAV
|
||||
audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels)
|
||||
|
||||
# 确保是16位PCM编码
|
||||
audio = audio.set_sample_width(2) # 2字节 = 16位
|
||||
|
||||
# 生成输出文件名
|
||||
base_name = os.path.splitext(input_file)[0]
|
||||
output_file = f"{base_name}_converted.wav"
|
||||
|
||||
# 导出为WAV格式
|
||||
audio.export(output_file, format="wav", parameters=["-acodec", "pcm_s16le"])
|
||||
|
||||
logger.info(f"音频格式转换成功: {input_file} -> {output_file}")
|
||||
logger.info(f"转换后音频信息 - 时长: {len(audio)}ms, 采样率: {target_sample_rate}, 声道: {target_channels}")
|
||||
|
||||
return output_file
|
||||
except Exception as e:
|
||||
logger.error(f"音频格式转换失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def aliyun_asr_chinese_audio_to_text(audio_path):
|
||||
"""使用阿里云智能语音交互2.0 RESTful API进行中文语音识别"""
|
||||
max_retries = 3
|
||||
|
||||
for attempt in range(max_retries):
|
||||
# 获取token,如果是重试则强制刷新
|
||||
token = get_aliyun_token(force_refresh=(attempt > 0))
|
||||
if not token:
|
||||
logger.error("获取阿里云Token失败,无法继续识别")
|
||||
return ""
|
||||
|
||||
logger.info(f"正在识别音频: {audio_path} (尝试 {attempt + 1}/{max_retries})")
|
||||
|
||||
# 检查音频文件
|
||||
if not os.path.exists(audio_path):
|
||||
logger.error(f"音频文件不存在: {audio_path}")
|
||||
return ""
|
||||
|
||||
# 转换音频格式,确保符合阿里云ASR要求
|
||||
converted_audio = convert_audio_format(audio_path)
|
||||
if not converted_audio:
|
||||
logger.error("音频格式转换失败")
|
||||
return ""
|
||||
|
||||
conn = None
|
||||
try:
|
||||
# 读取转换后的音频文件
|
||||
with open(converted_audio, 'rb') as f:
|
||||
audio_data = f.read()
|
||||
|
||||
logger.info(f"音频文件大小: {len(audio_data)} bytes")
|
||||
|
||||
# 检查音频文件大小
|
||||
if len(audio_data) == 0:
|
||||
logger.error("音频文件为空")
|
||||
continue
|
||||
|
||||
# 使用阿里云智能语音交互2.0的RESTful API
|
||||
conn = http.client.HTTPSConnection(ALIYUN_HOST, timeout=30)
|
||||
|
||||
# 设置正确的请求头
|
||||
headers = {
|
||||
"X-NLS-Token": token,
|
||||
"Content-Type": "application/octet-stream",
|
||||
"Content-Length": str(len(audio_data)),
|
||||
"Host": ALIYUN_HOST
|
||||
}
|
||||
|
||||
# 构建请求参数,按照阿里云智能语音交互2.0文档
|
||||
params = {
|
||||
"appkey": ALIYUN_APP_KEY,
|
||||
"format": "wav",
|
||||
"sample_rate": 16000,
|
||||
"enable_punctuation_prediction": "true",
|
||||
"enable_inverse_text_normalization": "true",
|
||||
"enable_voice_detection": "false"
|
||||
}
|
||||
|
||||
# 构建完整的请求URL
|
||||
query_string = urllib.parse.urlencode(params)
|
||||
full_url = f"/stream/v1/asr?{query_string}"
|
||||
|
||||
logger.info(f"发送ASR请求,URL: {full_url}")
|
||||
logger.info(f"请求头: {headers}")
|
||||
|
||||
# 发送POST请求,直接传输二进制音频数据
|
||||
conn.request("POST", full_url, body=audio_data, headers=headers)
|
||||
|
||||
response = conn.getresponse()
|
||||
response_data = response.read()
|
||||
|
||||
logger.info(f"ASR响应状态: {response.status}")
|
||||
|
||||
if response.status == 200:
|
||||
try:
|
||||
result = json.loads(response_data.decode('utf-8'))
|
||||
logger.info(f"ASR响应结果: {result}")
|
||||
|
||||
# 检查响应状态
|
||||
status = result.get('status', 0)
|
||||
message = result.get('message', '')
|
||||
|
||||
if status == 20000000 and message == 'SUCCESS':
|
||||
transcription = result.get('result', '').strip()
|
||||
if transcription:
|
||||
logger.info(f"语音识别成功: {transcription}")
|
||||
return transcription
|
||||
else:
|
||||
logger.warning("识别成功但结果为空,可能是:1)音频内容为静音 2)语音不清晰 3)语言不匹配")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在重试...")
|
||||
continue
|
||||
return ""
|
||||
else:
|
||||
logger.error(f"ASR识别失败: status={status}, message={message}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在重试...")
|
||||
continue
|
||||
return ""
|
||||
|
||||
except json.JSONDecodeError as json_error:
|
||||
logger.error(f"解析ASR响应JSON失败: {json_error}")
|
||||
logger.info(f"原始响应: {response_data[:1000]}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在重试...")
|
||||
continue
|
||||
return ""
|
||||
|
||||
elif response.status == 401: # 未授权,token可能已过期
|
||||
logger.warning("Token无效或已过期")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在尝试刷新Token并重试...")
|
||||
continue
|
||||
else:
|
||||
logger.error("Token刷新重试次数已用完")
|
||||
return ""
|
||||
|
||||
elif response.status == 40000001:
|
||||
logger.error("身份认证失败,检查Token是否正确或过期")
|
||||
return ""
|
||||
|
||||
elif response.status == 40000003:
|
||||
logger.error("参数无效,检查音频格式和采样率")
|
||||
return ""
|
||||
|
||||
elif response.status == 41010101:
|
||||
logger.error("不支持的采样率,当前仅支持8000Hz和16000Hz")
|
||||
return ""
|
||||
|
||||
else:
|
||||
error_message = response_data.decode('utf-8', errors='ignore')
|
||||
logger.error(f"ASR识别失败: 状态码{response.status}, 错误信息: {error_message}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在重试...")
|
||||
continue
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ASR识别时发生异常: {str(e)}")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info("正在重试...")
|
||||
continue
|
||||
return ""
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if converted_audio and os.path.exists(converted_audio) and converted_audio != audio_path:
|
||||
try:
|
||||
os.remove(converted_audio)
|
||||
logger.info(f"已清理临时文件: {converted_audio}")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 关闭连接
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
# 所有重试都失败了
|
||||
logger.error("所有ASR重试都失败")
|
||||
return ""
|
||||
|
||||
def get_cache_key(question: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return hashlib.md5(question.encode()).hexdigest()
|
||||
|
||||
def get_cached_response(question: str) -> str:
|
||||
"""获取缓存的回答"""
|
||||
cache_key = get_cache_key(question)
|
||||
current_time = time.time()
|
||||
|
||||
# 检查缓存是否存在且未过期(5分钟过期)
|
||||
if (cache_key in response_cache and
|
||||
cache_key in cache_expire_time and
|
||||
cache_expire_time[cache_key] > current_time):
|
||||
logger.info(f"使用缓存回答: {cache_key}")
|
||||
return response_cache[cache_key]
|
||||
|
||||
return None
|
||||
|
||||
def set_cached_response(question: str, answer: str):
|
||||
"""设置缓存回答"""
|
||||
cache_key = get_cache_key(question)
|
||||
response_cache[cache_key] = answer
|
||||
cache_expire_time[cache_key] = time.time() + 300 # 5分钟过期
|
||||
logger.info(f"缓存回答: {cache_key}")
|
||||
|
||||
def get_deepseek_response(question):
|
||||
"""使用DeepSeek API获取简洁回答"""
|
||||
# 首先检查缓存
|
||||
cached_answer = get_cached_response(question)
|
||||
if cached_answer:
|
||||
return cached_answer
|
||||
|
||||
try:
|
||||
client = OpenAI(
|
||||
api_key=DEEPSEEK_API_KEY,
|
||||
base_url=DEEPSEEK_BASE_URL,
|
||||
timeout=8.0 # 设置8秒超时
|
||||
)
|
||||
|
||||
# 针对客服场景优化提示词
|
||||
system_prompt = """你是一个专业的智能客服助手。请遵循以下原则:
|
||||
1. 回答要简洁明了,通常控制在40字以内
|
||||
2. 语气要友好、专业、有帮助
|
||||
3. 如果不确定答案,请诚实说明并建议联系人工客服
|
||||
4. 重点解决客户的实际问题
|
||||
5. 避免冗长的解释,直接给出有用信息"""
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="deepseek-chat",
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": question}
|
||||
],
|
||||
max_tokens=150, # 进一步限制回答长度
|
||||
temperature=0.5 # 降低随机性,提高一致性
|
||||
)
|
||||
|
||||
answer = response.choices[0].message.content.strip()
|
||||
logger.info(f"DeepSeek回答: {answer}")
|
||||
|
||||
# 缓存回答
|
||||
set_cached_response(question, answer)
|
||||
|
||||
return answer
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek API调用失败: {str(e)}")
|
||||
return "抱歉,智能客服暂时繁忙,请稍后重试或联系人工客服。"
|
||||
|
||||
async def get_deepseek_response_async(question):
|
||||
"""异步调用DeepSeek API"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(executor, get_deepseek_response, question)
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
@router.post("/query")
|
||||
async def chat_query(req: ChatRequest):
|
||||
"""AI客服对话接口"""
|
||||
try:
|
||||
question = req.question.strip()
|
||||
if not question:
|
||||
raise HTTPException(status_code=400, detail="问题不能为空")
|
||||
|
||||
# 使用DeepSeek获取回答
|
||||
answer = await get_deepseek_response_async(question)
|
||||
|
||||
# 生成语音文件
|
||||
timestamp = int(time.time() * 1000)
|
||||
audio_filename = f"ai_response_{timestamp}.wav"
|
||||
audio_path = os.path.join(AUDIO_DIR, audio_filename)
|
||||
|
||||
# 语音合成
|
||||
tts_success = aliyun_tts_chinese_text_to_audio(answer, audio_path)
|
||||
|
||||
response_data = {
|
||||
"answer": answer,
|
||||
"audio_url": f"/ai-chatbot/audio/{audio_filename}" if tts_success else None
|
||||
}
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询处理失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="处理请求时发生错误")
|
||||
|
||||
@router.post("/query-text")
|
||||
async def chat_query_text_only(req: ChatRequest):
|
||||
"""AI客服对话接口 - 仅返回文本,快速响应"""
|
||||
try:
|
||||
question = req.question.strip()
|
||||
if not question:
|
||||
raise HTTPException(status_code=400, detail="问题不能为空")
|
||||
|
||||
# 使用DeepSeek获取回答
|
||||
answer = await get_deepseek_response_async(question)
|
||||
|
||||
response_data = {
|
||||
"answer": answer,
|
||||
"timestamp": int(time.time() * 1000) # 用于生成语音时的唯一标识
|
||||
}
|
||||
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"文本查询处理失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="处理请求时发生错误")
|
||||
|
||||
class AudioRequest(BaseModel):
|
||||
text: str
|
||||
timestamp: int
|
||||
|
||||
@router.post("/generate-audio")
|
||||
async def generate_audio(req: AudioRequest):
|
||||
"""异步生成语音文件"""
|
||||
try:
|
||||
text = req.text.strip()
|
||||
if not text:
|
||||
raise HTTPException(status_code=400, detail="文本不能为空")
|
||||
|
||||
# 使用时间戳生成语音文件名
|
||||
audio_filename = f"ai_response_{req.timestamp}.wav"
|
||||
audio_path = os.path.join(AUDIO_DIR, audio_filename)
|
||||
|
||||
# 语音合成
|
||||
tts_success = aliyun_tts_chinese_text_to_audio(text, audio_path)
|
||||
|
||||
if tts_success:
|
||||
return {
|
||||
"success": True,
|
||||
"audio_url": f"/ai-chatbot/audio/{audio_filename}"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "语音生成失败"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"语音生成失败: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
@router.post("/asr")
|
||||
async def speech_recognition(file: UploadFile = File(...)):
|
||||
"""语音识别接口"""
|
||||
try:
|
||||
# 检查文件格式
|
||||
if not file.filename.lower().endswith(('.wav', '.mp3', '.m4a', '.webm', '.ogg')):
|
||||
logger.warning(f"不支持的文件格式: {file.filename}")
|
||||
|
||||
# 保存上传的音频文件
|
||||
timestamp = int(time.time() * 1000)
|
||||
# 保持原始文件扩展名,让pydub自动检测格式
|
||||
original_ext = os.path.splitext(file.filename)[1] if file.filename else '.wav'
|
||||
temp_filename = f"temp_audio_{timestamp}{original_ext}"
|
||||
temp_path = os.path.join(AUDIO_DIR, temp_filename)
|
||||
|
||||
# 保存文件
|
||||
with open(temp_path, "wb") as buffer:
|
||||
content = await file.read()
|
||||
buffer.write(content)
|
||||
|
||||
logger.info(f"接收到音频文件: {file.filename}, 大小: {len(content)} bytes, 临时保存为: {temp_path}")
|
||||
|
||||
# 语音识别
|
||||
transcription = aliyun_asr_chinese_audio_to_text(temp_path)
|
||||
|
||||
# 清理临时文件
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
logger.info(f"已清理临时文件: {temp_path}")
|
||||
|
||||
return {"text": transcription}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"语音识别失败: {str(e)}")
|
||||
# 清理临时文件
|
||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(status_code=500, detail="语音识别失败")
|
||||
|
||||
@router.get("/audio/{filename}")
|
||||
async def get_audio(filename: str):
|
||||
"""获取音频文件"""
|
||||
audio_path = os.path.join(AUDIO_DIR, filename)
|
||||
if os.path.exists(audio_path):
|
||||
return FileResponse(
|
||||
audio_path,
|
||||
media_type="audio/wav",
|
||||
headers={"Content-Disposition": f"inline; filename={filename}"}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail="音频文件不存在")
|
||||
|
||||
# 添加音频文件清理功能
|
||||
def cleanup_old_audio_files():
|
||||
"""清理超过1小时的音频文件"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
for filename in os.listdir(AUDIO_DIR):
|
||||
if filename.endswith('.wav'):
|
||||
file_path = os.path.join(AUDIO_DIR, filename)
|
||||
file_age = current_time - os.path.getctime(file_path)
|
||||
|
||||
# 删除超过1小时的文件
|
||||
if file_age > 3600:
|
||||
try:
|
||||
os.remove(file_path)
|
||||
logger.info(f"已清理过期音频文件: {filename}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理音频文件失败 {filename}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"音频文件清理过程出错: {e}")
|
||||
|
||||
def start_cleanup_timer():
|
||||
"""启动定时清理任务"""
|
||||
cleanup_old_audio_files()
|
||||
# 每30分钟执行一次清理
|
||||
threading.Timer(1800.0, start_cleanup_timer).start()
|
||||
|
||||
# 启动清理任务
|
||||
start_cleanup_timer()
|
||||
147
backend/app/routers/apps.py
Normal file
147
backend/app/routers/apps.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app import models, database
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
from app.utils_jwt import get_current_user
|
||||
|
||||
# 根路径返回所有应用
|
||||
@router.get("/")
|
||||
def root_apps(db: Session = Depends(get_db)):
|
||||
apps = db.query(models.App).all()
|
||||
return {"apps": [
|
||||
{"id": app.id, "name": app.name, "desc": app.desc, "price": app.price, "status": app.status} for app in apps
|
||||
]}
|
||||
|
||||
class AppCreate(BaseModel):
|
||||
name: str
|
||||
desc: str
|
||||
price: float
|
||||
status: str = "上架"
|
||||
|
||||
class AppUpdate(BaseModel):
|
||||
name: str = None
|
||||
desc: str = None
|
||||
price: float = None
|
||||
status: str = None
|
||||
|
||||
# 查询所有应用(用户/前端)
|
||||
@router.get("/list")
|
||||
def list_apps(db: Session = Depends(get_db)):
|
||||
apps = db.query(models.App).all()
|
||||
return {"apps": [
|
||||
{"id": app.id, "name": app.name, "desc": app.desc, "price": app.price, "status": app.status} for app in apps
|
||||
]}
|
||||
|
||||
# 管理员获取全部应用(含下架)
|
||||
@router.get("/all")
|
||||
def list_all_apps(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
apps = db.query(models.App).all()
|
||||
return [
|
||||
{"id": app.id, "name": app.name, "desc": app.desc, "price": app.price, "status": app.status} for app in apps
|
||||
]
|
||||
|
||||
# 新增应用(管理员)
|
||||
@router.post("/add")
|
||||
def add_app(app: AppCreate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
new_app = models.App(**app.dict())
|
||||
db.add(new_app)
|
||||
db.commit()
|
||||
db.refresh(new_app)
|
||||
return {"msg": "添加成功", "app": {"id": new_app.id, "name": new_app.name}}
|
||||
|
||||
# 修改应用(管理员)
|
||||
@router.put("/edit/{app_id}")
|
||||
def edit_app(app_id: int, app: AppUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
db_app = db.query(models.App).filter(models.App.id == app_id).first()
|
||||
if not db_app:
|
||||
raise HTTPException(status_code=404, detail="应用不存在")
|
||||
for field, value in app.dict(exclude_unset=True).items():
|
||||
setattr(db_app, field, value)
|
||||
db.commit()
|
||||
return {"msg": "修改成功"}
|
||||
|
||||
# 删除应用(管理员)
|
||||
@router.delete("/delete/{app_id}")
|
||||
def delete_app(app_id: int, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
db_app = db.query(models.App).filter(models.App.id == app_id).first()
|
||||
if not db_app:
|
||||
raise HTTPException(status_code=404, detail="应用不存在")
|
||||
db.delete(db_app)
|
||||
db.commit()
|
||||
return {"msg": "删除成功"}
|
||||
|
||||
# 用户调用应用(消费)
|
||||
class UseAppRequest(BaseModel):
|
||||
app_id: int
|
||||
|
||||
@router.post("/use")
|
||||
def use_app(req: UseAppRequest, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
# 获取应用信息
|
||||
db_app = db.query(models.App).filter(models.App.id == req.app_id).first()
|
||||
if not db_app:
|
||||
raise HTTPException(status_code=404, detail="应用不存在")
|
||||
|
||||
# 获取最新的用户信息
|
||||
db_user = db.query(models.User).filter(models.User.id == user.id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 检查余额
|
||||
if db_user.balance < db_app.price:
|
||||
raise HTTPException(status_code=400, detail="余额不足,请先充值")
|
||||
|
||||
try:
|
||||
# 扣除余额
|
||||
db_user.balance -= db_app.price
|
||||
|
||||
# 添加消费记录
|
||||
from app.models import History
|
||||
record = History(
|
||||
user_id=db_user.id,
|
||||
type='consume',
|
||||
amount=-db_app.price,
|
||||
desc=f"调用{db_app.name}"
|
||||
)
|
||||
db.add(record)
|
||||
|
||||
# 添加订单记录
|
||||
from app.models import Order
|
||||
order = Order(
|
||||
user_id=db_user.id,
|
||||
app_id=db_app.id,
|
||||
type=db_app.name, # 使用应用名称作为订单类型
|
||||
amount=db_app.price,
|
||||
description=db_app.desc, # 添加应用描述
|
||||
status="已完成" # 使用"已完成"代替"已支付"
|
||||
)
|
||||
db.add(order)
|
||||
|
||||
# 提交事务
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"msg": f"成功调用 {db_app.name}!已扣除{db_app.price}元。",
|
||||
"balance": db_user.balance
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"操作失败:{str(e)}")
|
||||
42
backend/app/routers/balance.py
Normal file
42
backend/app/routers/balance.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app import models, database
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/me")
|
||||
def get_my_balance(user_id: int, db: Session = Depends(get_db)):
|
||||
user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
return {"balance": user.balance}
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class RechargeRequest(BaseModel):
|
||||
amount: float
|
||||
|
||||
from app.utils_jwt import get_current_user
|
||||
|
||||
@router.post("/recharge")
|
||||
def recharge_balance(req: RechargeRequest, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if req.amount <= 0:
|
||||
raise HTTPException(status_code=400, detail="Amount must be positive")
|
||||
# 强制用db查一次,确保user为当前session的持久对象
|
||||
db_user = db.query(models.User).filter(models.User.id == user.id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
db_user.balance += req.amount
|
||||
from app.models import History
|
||||
record = History(user_id=db_user.id, type='recharge', amount=req.amount, desc='余额充值')
|
||||
db.add(record)
|
||||
db.commit()
|
||||
return {"balance": db_user.balance}
|
||||
114
backend/app/routers/history.py
Normal file
114
backend/app/routers/history.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app import models, database
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
from app.utils_jwt import get_current_user
|
||||
from typing import List
|
||||
from app import schemas
|
||||
|
||||
# 根路径返回所有充值记录(管理员)
|
||||
@router.get("/")
|
||||
def root_history(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
|
||||
try:
|
||||
# 只获取充值记录,不显示消费记录
|
||||
records = db.query(models.History).filter(models.History.type == "recharge").all()
|
||||
result = []
|
||||
for r in records:
|
||||
# 获取用户名(如果存在)
|
||||
username = ""
|
||||
user_record = db.query(models.User).filter(models.User.id == r.user_id).first()
|
||||
if user_record:
|
||||
username = user_record.username
|
||||
|
||||
# 安全处理时间格式
|
||||
timestamp_str = ""
|
||||
try:
|
||||
if hasattr(r, 'created_at') and r.created_at:
|
||||
timestamp_str = r.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif hasattr(r, 'timestamp') and r.timestamp:
|
||||
timestamp_str = r.timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
timestamp_str = str(r.created_at or r.timestamp or "")
|
||||
|
||||
result.append({
|
||||
"id": r.id,
|
||||
"user_id": r.user_id,
|
||||
"username": username,
|
||||
"type": r.type,
|
||||
"amount": r.amount,
|
||||
"desc": r.desc,
|
||||
"time": timestamp_str
|
||||
})
|
||||
|
||||
return {"history": result}
|
||||
except Exception as e:
|
||||
print(f"Error in root_history: {str(e)}")
|
||||
return {"history": [], "error": str(e)}
|
||||
|
||||
@router.get("/list")
|
||||
def list_history(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
records = db.query(models.History).filter(models.History.user_id == user.id).order_by(models.History.created_at.desc()).all()
|
||||
return {"history": [
|
||||
{
|
||||
"type": r.type,
|
||||
"amount": r.amount,
|
||||
"desc": r.desc,
|
||||
"time": r.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||||
} for r in records
|
||||
]}
|
||||
|
||||
# 管理员获取全部充值记录
|
||||
@router.get("/all")
|
||||
def list_all_history(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
|
||||
try:
|
||||
# 只获取充值记录,不显示消费记录
|
||||
records = db.query(models.History).filter(models.History.type == "recharge").all()
|
||||
|
||||
result = []
|
||||
for r in records:
|
||||
# 获取用户名(如果存在)
|
||||
username = ""
|
||||
user_record = db.query(models.User).filter(models.User.id == r.user_id).first()
|
||||
if user_record:
|
||||
username = user_record.username
|
||||
|
||||
# 安全处理时间格式
|
||||
timestamp_str = ""
|
||||
try:
|
||||
if hasattr(r, 'created_at') and r.created_at:
|
||||
timestamp_str = r.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif hasattr(r, 'timestamp') and r.timestamp:
|
||||
timestamp_str = r.timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
timestamp_str = str(r.created_at or r.timestamp or "")
|
||||
|
||||
result.append({
|
||||
"id": r.id,
|
||||
"user_id": r.user_id,
|
||||
"username": username,
|
||||
"type": r.type,
|
||||
"amount": r.amount,
|
||||
"desc": r.desc,
|
||||
"time": timestamp_str
|
||||
})
|
||||
|
||||
return {"history": result}
|
||||
except Exception as e:
|
||||
# 记录错误但返回空列表,避免500错误
|
||||
print(f"Error in list_all_history: {str(e)}")
|
||||
return {"history": [], "error": str(e)}
|
||||
420
backend/app/routers/news_stock.py
Normal file
420
backend/app/routers/news_stock.py
Normal file
@@ -0,0 +1,420 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict, Any
|
||||
import os
|
||||
import logging
|
||||
import pandas as pd
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
import requests
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler("news_stock_api.log"), logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger("NewsStockAPI")
|
||||
|
||||
router = APIRouter(
|
||||
tags=["news_stock"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# 请求模型
|
||||
class NewsAnalysisRequest(BaseModel):
|
||||
news_content: str
|
||||
|
||||
# 响应模型
|
||||
class AnalysisResult(BaseModel):
|
||||
industries: List[str]
|
||||
companies: List[dict]
|
||||
analysis_details: str
|
||||
timestamp: str
|
||||
|
||||
# DeepSeek API配置
|
||||
API_KEY = "sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba"
|
||||
BASE_URL = "https://api.deepseek.com"
|
||||
|
||||
class NewsStockAnalyzer:
|
||||
"""热点新闻股票影响分析器 - HTTP请求版本"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化分析器"""
|
||||
self.api_key = API_KEY
|
||||
self.base_url = BASE_URL
|
||||
self.load_company_info()
|
||||
|
||||
def load_company_info(self):
|
||||
"""加载上市公司信息"""
|
||||
try:
|
||||
logger.info("开始加载上市公司信息...")
|
||||
|
||||
# 只使用相对路径,移除硬编码的绝对路径
|
||||
possible_paths = [
|
||||
# 相对于当前文件的路径
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "NewsImpactOnStocks", "上市公司信息表.xlsx"),
|
||||
# 相对于backend目录的路径
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "上市公司信息表.xlsx"),
|
||||
# 相对于项目根目录的路径
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "..", "上市公司信息表.xlsx")
|
||||
]
|
||||
|
||||
company_info_path = None
|
||||
for path in possible_paths:
|
||||
if os.path.exists(path):
|
||||
company_info_path = path
|
||||
break
|
||||
|
||||
if company_info_path:
|
||||
self.company_df = pd.read_excel(company_info_path)
|
||||
logger.info(f"成功加载上市公司信息,共 {len(self.company_df)} 条记录")
|
||||
logger.info(f"加载路径: {company_info_path}")
|
||||
else:
|
||||
logger.warning("未找到上市公司信息表,使用空数据")
|
||||
logger.warning(f"尝试的路径: {possible_paths}")
|
||||
self.company_df = pd.DataFrame()
|
||||
except Exception as e:
|
||||
logger.error(f"加载上市公司信息失败: {e}")
|
||||
self.company_df = pd.DataFrame()
|
||||
|
||||
def call_deepseek_api(self, messages, max_retries=2):
|
||||
"""使用HTTP请求调用DeepSeek API - 简化版本"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(f"调用DeepSeek API,尝试次数: {attempt + 1}/{max_retries}")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=25 # 简化超时设置
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
logger.info(f"API调用成功,返回内容长度: {len(content)}")
|
||||
return content
|
||||
else:
|
||||
logger.error(f"API请求失败: {response.status_code}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用失败: {e}")
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
|
||||
return None
|
||||
|
||||
def analyze_news_impact_on_industries(self, news):
|
||||
"""分析新闻对行业的影响"""
|
||||
logger.info("开始分析新闻对行业的影响...")
|
||||
|
||||
system_prompt = """
|
||||
# 角色
|
||||
你是一位顶尖的证券分析师,拥有深厚的行业知识和敏锐的市场洞察力。能够对热点新闻进行深入分析,判断其对各行业的影响,并给出详细理由。
|
||||
|
||||
## 技能
|
||||
### 技能 1:分析新闻对行业的影响
|
||||
1. 当用户提供热点新闻标题或内容时,确定其中的关键要素;
|
||||
2. 结合最新的市场趋势、经济环境以及各行业发展情况,深入分析该新闻可能对哪些行业产生影响;
|
||||
3. 输出全部可能有被影响的行业,并整合到一句话中。
|
||||
|
||||
## 输出格式
|
||||
请按以下格式输出:
|
||||
|
||||
影响行业:行业1、行业2、行业3...
|
||||
|
||||
影响分析:
|
||||
1. 行业1:[分析理由]
|
||||
2. 行业2:[分析理由]
|
||||
...
|
||||
|
||||
## 限制:
|
||||
- 只分析与新闻相关的行业影响,拒绝回答与新闻无关的问题。
|
||||
- 分析理由要充分、有条理。
|
||||
"""
|
||||
|
||||
user_prompt = f"""
|
||||
下面是用户提供的热点新闻信息,请分析该新闻可能影响的全部行业,并给出详细理由。
|
||||
|
||||
===热点新闻开始===
|
||||
{news}
|
||||
===热点新闻结束===
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 优先使用更快的chat模型
|
||||
analysis = self.call_deepseek_api(messages)
|
||||
|
||||
if analysis:
|
||||
logger.info("行业影响分析完成")
|
||||
|
||||
# 提取行业
|
||||
industries = []
|
||||
industry_pattern = r"影响行业[::](.*?)(?:\n|$)"
|
||||
industry_match = re.search(industry_pattern, analysis)
|
||||
|
||||
if industry_match:
|
||||
industry_text = industry_match.group(1).strip()
|
||||
industries = [ind.strip() for ind in re.split(r'[,,、]', industry_text) if ind.strip()]
|
||||
|
||||
if not industries:
|
||||
industries = self._extract_industries_from_text(analysis)
|
||||
|
||||
return {
|
||||
"industries": industries,
|
||||
"reasons": analysis
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"industries": [],
|
||||
"reasons": "分析过程中出现错误,请检查API配置或网络连接。"
|
||||
}
|
||||
|
||||
def _extract_industries_from_text(self, text):
|
||||
"""从文本中提取可能的行业名称"""
|
||||
common_industries = [
|
||||
"互联网", "金融", "银行", "保险", "证券", "房地产", "医药", "医疗", "健康",
|
||||
"教育", "零售", "消费", "制造", "能源", "电力", "新能源", "汽车", "电子",
|
||||
"半导体", "通信", "传媒", "娱乐", "旅游", "餐饮", "物流", "交通", "航空",
|
||||
"铁路", "船舶", "钢铁", "煤炭", "石油", "化工", "农业", "食品", "饮料",
|
||||
"纺织", "服装", "建筑", "建材", "家电", "软件", "硬件", "人工智能", "云计算",
|
||||
"大数据", "区块链", "物联网", "5G", "军工", "航天", "环保", "新材料"
|
||||
]
|
||||
|
||||
found_industries = []
|
||||
for industry in common_industries:
|
||||
if industry in text:
|
||||
found_industries.append(industry)
|
||||
|
||||
return found_industries
|
||||
|
||||
def search_related_companies(self, industries):
|
||||
"""查找与行业相关的上市公司 - 简化版本"""
|
||||
logger.info(f"开始查找与行业相关的上市公司: {', '.join(industries)}")
|
||||
|
||||
if self.company_df.empty:
|
||||
return []
|
||||
|
||||
# 简化关键词映射
|
||||
search_keywords = set()
|
||||
for industry in industries:
|
||||
# 直接使用行业名称作为关键词
|
||||
search_keywords.add(industry)
|
||||
|
||||
# 添加一些基本的相关词
|
||||
if "汽车" in industry:
|
||||
search_keywords.update(["汽车", "汽车零部件", "新能源汽车"])
|
||||
elif "电池" in industry:
|
||||
search_keywords.update(["电池", "锂电池", "储能"])
|
||||
elif "电子" in industry:
|
||||
search_keywords.update(["电子", "消费电子"])
|
||||
elif "半导体" in industry:
|
||||
search_keywords.update(["半导体", "芯片"])
|
||||
elif "通信" in industry:
|
||||
search_keywords.update(["通信", "5G", "通信设备"])
|
||||
|
||||
logger.info(f"搜索关键词: {', '.join(search_keywords)}")
|
||||
|
||||
# 搜索相关公司
|
||||
company_scores = {}
|
||||
for keyword in search_keywords:
|
||||
try:
|
||||
# 在行业和主营业务中搜索
|
||||
industry_match = self.company_df['IndustryName'].str.contains(keyword, na=False)
|
||||
business_match = self.company_df['MAINBUSSINESS'].str.contains(keyword, na=False)
|
||||
|
||||
matched_companies = self.company_df[industry_match | business_match]
|
||||
|
||||
for _, company in matched_companies.iterrows():
|
||||
symbol = company.get('Symbol')
|
||||
score = company_scores.get(symbol, 0)
|
||||
|
||||
# 简化评分:行业匹配得2分,业务匹配得1分
|
||||
if industry_match.iloc[company.name]:
|
||||
score += 2
|
||||
if business_match.iloc[company.name]:
|
||||
score += 1
|
||||
|
||||
company_scores[symbol] = score
|
||||
except Exception as e:
|
||||
logger.error(f"搜索关键词 '{keyword}' 时出错: {e}")
|
||||
|
||||
if not company_scores:
|
||||
logger.warning("未找到与行业相关的公司")
|
||||
return []
|
||||
|
||||
# 按得分排序,取前10家
|
||||
sorted_companies = sorted(company_scores.items(), key=lambda x: x[1], reverse=True)[:10]
|
||||
|
||||
# 转换为返回格式
|
||||
related_companies = []
|
||||
for symbol, score in sorted_companies:
|
||||
company_matches = self.company_df[self.company_df['Symbol'] == symbol]
|
||||
if len(company_matches) > 0:
|
||||
company_row = company_matches.iloc[0]
|
||||
company_info = {
|
||||
"code": str(company_row.get('Symbol', '')),
|
||||
"name": str(company_row.get('ShortName', '')),
|
||||
"industry": str(company_row.get('IndustryName', '')),
|
||||
"business": str(company_row.get('MAINBUSSINESS', ''))[:200],
|
||||
"score": score
|
||||
}
|
||||
related_companies.append(company_info)
|
||||
|
||||
logger.info(f"找到 {len(related_companies)} 家相关公司")
|
||||
return related_companies
|
||||
|
||||
def analyze_company_impact(self, news, companies):
|
||||
"""分析新闻对具体公司的影响 - 简化版本"""
|
||||
if not companies:
|
||||
return "未找到相关公司"
|
||||
|
||||
try:
|
||||
logger.info(f"开始分析新闻对 {len(companies)} 家公司的影响...")
|
||||
|
||||
# 简化prompt
|
||||
system_prompt = """你是证券分析师,请分析新闻对相关公司的影响。
|
||||
|
||||
输出格式:
|
||||
🎯 公司名称(代码)
|
||||
📈 影响分析:[简述影响]
|
||||
|
||||
要求:每家公司分析不超过50字。"""
|
||||
|
||||
# 只分析前5家公司
|
||||
company_list = []
|
||||
for company in companies[:5]:
|
||||
company_list.append(f"{company['name']}({company['code']})- {company['industry']}")
|
||||
|
||||
user_prompt = f"新闻:{news}\n\n相关公司:\n" + "\n".join(company_list)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
analysis = self.call_deepseek_api(messages)
|
||||
|
||||
if analysis:
|
||||
logger.info(f"公司影响分析完成")
|
||||
return analysis
|
||||
else:
|
||||
return "分析服务暂时不可用"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"分析公司影响时出错: {e}")
|
||||
return "分析过程中出现错误"
|
||||
|
||||
def analyze_news(self, news):
|
||||
"""完整的新闻分析流程 - 简化版本"""
|
||||
try:
|
||||
# 1. 分析行业影响
|
||||
industry_result = self.analyze_news_impact_on_industries(news)
|
||||
industries = industry_result["industries"]
|
||||
industry_analysis = industry_result["reasons"]
|
||||
|
||||
# 2. 查找相关公司
|
||||
companies = self.search_related_companies(industries)
|
||||
|
||||
# 3. 分析公司影响
|
||||
company_analysis = self.analyze_company_impact(news, companies)
|
||||
|
||||
# 4. 返回结果
|
||||
return {
|
||||
"industries": industries,
|
||||
"companies": companies,
|
||||
"analysis_details": f"{industry_analysis}\n\n{company_analysis}",
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"新闻分析失败: {e}")
|
||||
return {
|
||||
"industries": [],
|
||||
"companies": [],
|
||||
"analysis_details": "分析过程中出现错误",
|
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
|
||||
# 创建分析器实例
|
||||
news_analyzer = NewsStockAnalyzer()
|
||||
|
||||
@router.post("/analyze", response_model=AnalysisResult)
|
||||
async def analyze_news(request: NewsAnalysisRequest):
|
||||
"""分析热点新闻对股票的影响"""
|
||||
try:
|
||||
logger.info(f"收到新闻分析请求,内容长度: {len(request.news_content)}")
|
||||
|
||||
if not request.news_content.strip():
|
||||
raise HTTPException(status_code=400, detail="新闻内容不能为空")
|
||||
|
||||
# 执行分析
|
||||
result = news_analyzer.analyze_news(request.news_content)
|
||||
|
||||
# 验证结果数据
|
||||
if not isinstance(result, dict):
|
||||
logger.error("分析结果不是字典类型")
|
||||
raise HTTPException(status_code=500, detail="分析结果格式错误")
|
||||
|
||||
# 确保必要字段存在
|
||||
if "industries" not in result:
|
||||
result["industries"] = []
|
||||
if "companies" not in result:
|
||||
result["companies"] = []
|
||||
if "analysis_details" not in result:
|
||||
result["analysis_details"] = "分析完成"
|
||||
if "timestamp" not in result:
|
||||
result["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
logger.info(f"分析完成,找到 {len(result['industries'])} 个相关行业,{len(result['companies'])} 家相关公司")
|
||||
|
||||
# 创建响应对象,确保类型安全
|
||||
try:
|
||||
response = AnalysisResult(**result)
|
||||
logger.info("成功创建AnalysisResult响应对象")
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"创建响应对象失败: {e}")
|
||||
logger.error(f"结果数据: {result}")
|
||||
# 返回一个最基本的安全响应
|
||||
return AnalysisResult(
|
||||
industries=[],
|
||||
companies=[],
|
||||
analysis_details="分析完成,但响应格式化失败",
|
||||
timestamp=datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"处理新闻分析请求失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=f"服务器内部错误: {str(e)[:100]}")
|
||||
|
||||
@router.get("/test")
|
||||
async def test_news_stock_api():
|
||||
"""测试新闻选股API路由是否正常工作"""
|
||||
return {"status": "ok", "message": "热点新闻选股路由正常工作"}
|
||||
140
backend/app/routers/orders.py
Normal file
140
backend/app/routers/orders.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app import models, database
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from app.utils_jwt import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 根路径返回所有订单(管理员)
|
||||
@router.get("/")
|
||||
def root_orders(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
|
||||
try:
|
||||
# 从历史记录表中获取消费记录
|
||||
consume_records = db.query(models.History).filter(models.History.type == "consume").all()
|
||||
|
||||
result = []
|
||||
for r in consume_records:
|
||||
# 获取用户名(如果存在)
|
||||
username = ""
|
||||
user_record = db.query(models.User).filter(models.User.id == r.user_id).first()
|
||||
if user_record:
|
||||
username = user_record.username
|
||||
|
||||
# 安全处理时间格式
|
||||
timestamp_str = ""
|
||||
try:
|
||||
if hasattr(r, 'created_at') and r.created_at:
|
||||
timestamp_str = r.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif hasattr(r, 'timestamp') and r.timestamp:
|
||||
timestamp_str = r.timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
timestamp_str = str(r.created_at or r.timestamp or "")
|
||||
|
||||
result.append({
|
||||
"id": r.id,
|
||||
"user_id": r.user_id,
|
||||
"username": username,
|
||||
"type": r.type,
|
||||
"amount": abs(r.amount), # 转为正数显示
|
||||
"desc": r.desc,
|
||||
"time": timestamp_str,
|
||||
"status": "已完成" # 默认状态
|
||||
})
|
||||
|
||||
return {"orders": result}
|
||||
except Exception as e:
|
||||
print(f"Error in root_orders: {str(e)}")
|
||||
return {"orders": [], "error": str(e)}
|
||||
|
||||
# 查询当前用户订单
|
||||
@router.get("/my")
|
||||
def my_orders(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
orders = db.query(models.Order).filter(models.Order.user_id == user.id).all()
|
||||
return {"orders": [
|
||||
{"id": o.id, "app_id": o.app_id, "amount": o.amount, "status": o.status, "timestamp": o.timestamp} for o in orders
|
||||
]}
|
||||
|
||||
# 管理员查询所有订单
|
||||
@router.get("/all")
|
||||
def all_orders(db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
|
||||
try:
|
||||
# 从历史记录表中获取消费记录
|
||||
consume_records = db.query(models.History).filter(models.History.type == "consume").all()
|
||||
|
||||
result = []
|
||||
for r in consume_records:
|
||||
# 获取用户名(如果存在)
|
||||
username = ""
|
||||
user_record = db.query(models.User).filter(models.User.id == r.user_id).first()
|
||||
if user_record:
|
||||
username = user_record.username
|
||||
|
||||
# 安全处理时间格式
|
||||
timestamp_str = ""
|
||||
try:
|
||||
if hasattr(r, 'created_at') and r.created_at:
|
||||
timestamp_str = r.created_at.strftime('%Y-%m-%d %H:%M:%S')
|
||||
elif hasattr(r, 'timestamp') and r.timestamp:
|
||||
timestamp_str = r.timestamp.strftime('%Y-%m-%d %H:%M:%S')
|
||||
except:
|
||||
timestamp_str = str(r.created_at or r.timestamp or "")
|
||||
|
||||
result.append({
|
||||
"id": r.id,
|
||||
"user_id": r.user_id,
|
||||
"username": username,
|
||||
"type": r.type,
|
||||
"amount": abs(r.amount), # 转为正数显示
|
||||
"desc": r.desc,
|
||||
"time": timestamp_str,
|
||||
"status": "已完成" # 默认状态
|
||||
})
|
||||
|
||||
return {"orders": result}
|
||||
except Exception as e:
|
||||
print(f"Error in all_orders: {str(e)}")
|
||||
return {"orders": [], "error": str(e)}
|
||||
|
||||
class OrderUpdate(BaseModel):
|
||||
status: Optional[str] = None
|
||||
amount: Optional[float] = None
|
||||
|
||||
# 管理员修改订单
|
||||
@router.put("/orders/{order_id}")
|
||||
def update_order(order_id: int, order: OrderUpdate, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
db_order = db.query(models.Order).filter(models.Order.id == order_id).first()
|
||||
if not db_order:
|
||||
raise HTTPException(status_code=404, detail="订单不存在")
|
||||
for field, value in order.dict(exclude_unset=True).items():
|
||||
setattr(db_order, field, value)
|
||||
db.commit()
|
||||
return {"msg": "修改成功"}
|
||||
|
||||
# 管理员删除订单
|
||||
@router.delete("/orders/{order_id}")
|
||||
def delete_order(order_id: int, db: Session = Depends(get_db), user=Depends(get_current_user)):
|
||||
if not getattr(user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限")
|
||||
db_order = db.query(models.Order).filter(models.Order.id == order_id).first()
|
||||
if not db_order:
|
||||
raise HTTPException(status_code=404, detail="订单不存在")
|
||||
db.delete(db_order)
|
||||
db.commit()
|
||||
return {"msg": "删除成功"}
|
||||
248
backend/app/routers/twitter.py
Normal file
248
backend/app/routers/twitter.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from typing import Optional
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
import json
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler("twitter_api.log"), logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger("TwitterAPI")
|
||||
|
||||
router = APIRouter(
|
||||
tags=["twitter"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# DeepSeek API配置
|
||||
API_KEY = "sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba"
|
||||
BASE_URL = "https://api.deepseek.com"
|
||||
|
||||
# 帮助调试问题的简单测试路由
|
||||
@router.get("/test")
|
||||
async def test_twitter_api():
|
||||
"""测试Twitter API路由是否正常工作"""
|
||||
return {"status": "ok", "message": "Twitter路由正常工作"}
|
||||
|
||||
class TwitterService:
|
||||
def __init__(self):
|
||||
"""初始化Twitter和用户推文总结服务"""
|
||||
try:
|
||||
# 使用HTTP请求方式调用DeepSeek API
|
||||
self.api_key = API_KEY
|
||||
self.base_url = BASE_URL
|
||||
self.twitter_api_key = "e3dad005b0e54bdc88c6178a89adec13"
|
||||
self.twitter_api_url = "https://api.twitterapi.io/twitter/tweet/advanced_search"
|
||||
logger.info("TwitterService 初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化TwitterService失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def call_deepseek_api(self, messages, model="deepseek-chat"):
|
||||
"""使用HTTP请求调用DeepSeek API"""
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
else:
|
||||
logger.error(f"DeepSeek API请求失败: {response.status_code}, {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用DeepSeek API失败: {e}")
|
||||
return None
|
||||
|
||||
def get_user_tweets(self, username):
|
||||
"""获取指定用户的最近推文 - 保持与原始代码相同的逻辑"""
|
||||
username = username.lstrip('@')
|
||||
|
||||
try:
|
||||
logger.info(f"请求Twitter API获取用户 {username} 的推文")
|
||||
|
||||
# 使用与原始代码相同的请求参数
|
||||
url = self.twitter_api_url
|
||||
headers = {"X-API-Key": self.twitter_api_key}
|
||||
params = {
|
||||
"queryType": "Latest",
|
||||
"query": f"from:{username}",
|
||||
"count": 10 # 获取更多推文以生成更好的摘要
|
||||
}
|
||||
|
||||
# 记录请求详情
|
||||
logger.info(f"API请求详情: URL={url}, 参数={params}")
|
||||
|
||||
# 发送请求,保持简单实现
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
|
||||
# 记录响应状态
|
||||
logger.info(f"API响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
data = response.json()
|
||||
# 记录API返回的JSON数据以便调试
|
||||
logger.info(f"API响应数据: {json.dumps(data)[:500]}")
|
||||
|
||||
tweets = data.get("tweets", [])
|
||||
if tweets:
|
||||
tweet_texts = [tweet["text"] for tweet in tweets]
|
||||
logger.info(f"成功获取用户 '{username}' 的 {len(tweet_texts)} 条推文")
|
||||
return {
|
||||
"success": True,
|
||||
"tweets": tweet_texts,
|
||||
"content": "\n".join(tweet_texts)
|
||||
}
|
||||
else:
|
||||
logger.warning(f"没有找到用户 '{username}' 的推文")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"没有找到用户 @{username} 的推文"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"解析API响应失败: {str(e)}")
|
||||
logger.error(f"响应内容: {response.text[:500]}")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"解析Twitter API响应失败: {str(e)}"
|
||||
}
|
||||
else:
|
||||
# 记录错误响应内容以便调试
|
||||
logger.error(f"Twitter API 请求失败,状态码: {response.status_code}")
|
||||
logger.error(f"响应内容: {response.text[:500]}")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"Twitter API请求失败,状态码: {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户推文失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"获取用户推文失败: {str(e)}"
|
||||
}
|
||||
|
||||
def generate_summary(self, tweet_content, username):
|
||||
"""根据推文内容生成总结 - 使用HTTP请求方式"""
|
||||
try:
|
||||
# 记录推文内容长度
|
||||
content_length = len(tweet_content) if tweet_content else 0
|
||||
logger.info(f"生成摘要: 用户={username}, 推文长度={content_length}")
|
||||
|
||||
# 如果推文内容为空,使用模拟数据
|
||||
if not tweet_content:
|
||||
logger.warning(f"推文内容为空,无法生成摘要")
|
||||
return f"无法获取用户 @{username} 的推文数据,请稍后再试。"
|
||||
|
||||
# 计算最大摘要长度
|
||||
max_summary_length = 280 - len(f"{username} 的最近推文摘要:") - 1
|
||||
|
||||
# 使用HTTP请求调用DeepSeek API生成摘要
|
||||
try:
|
||||
logger.info("调用DeepSeek API生成摘要")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant that summarizes tweets accurately"},
|
||||
{"role": "user", "content": f"请总结用户 @{username} 以下推文的主要内容,必须使用简体中文,总结必须以序号(1.、2.、...)分隔每条要点,每条要点简洁明了且以句号或感叹号结尾,总长度不得超过 {max_summary_length} 字符:\n{tweet_content}"}
|
||||
]
|
||||
|
||||
summary_content = self.call_deepseek_api(messages)
|
||||
|
||||
if summary_content:
|
||||
logger.info(f"成功生成摘要,长度: {len(summary_content)}")
|
||||
else:
|
||||
logger.error("DeepSeek API调用失败")
|
||||
return f"生成摘要失败,请稍后再试。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek API调用失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return f"生成摘要失败,请稍后再试。错误: {str(e)[:100]}"
|
||||
|
||||
# 确保返回的内容不超过限制
|
||||
summary = summary_content.strip()
|
||||
if len(summary) > max_summary_length:
|
||||
summary = summary[:max_summary_length]
|
||||
|
||||
# 检查并删除不完整的最后一条要点
|
||||
lines = summary.split('\n')
|
||||
if lines:
|
||||
last_line = lines[-1]
|
||||
# 检查最后一条是否以句号、感叹号或问号结尾
|
||||
if not last_line.endswith(('。', '!', '?', '.')):
|
||||
# 如果不完整,删除最后一条
|
||||
lines.pop()
|
||||
summary = '\n'.join(lines)
|
||||
|
||||
logger.info(f"成功为用户 '{username}' 生成摘要")
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成摘要失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return f"生成摘要失败,请稍后再试。"
|
||||
|
||||
|
||||
# 创建TwitterService实例
|
||||
twitter_service = TwitterService()
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_twitter_summary(username: str):
|
||||
"""获取Twitter用户最近推文的摘要"""
|
||||
try:
|
||||
logger.info(f"收到Twitter摘要请求,用户名: {username}")
|
||||
|
||||
# 获取用户推文
|
||||
tweets_result = twitter_service.get_user_tweets(username)
|
||||
logger.info(f"获取推文结果: {json.dumps(tweets_result)[:200] if isinstance(tweets_result, dict) else '未知结果'}")
|
||||
|
||||
# 如果获取推文失败,直接返回错误信息
|
||||
if not tweets_result.get("success", False):
|
||||
logger.warning(f"获取推文失败: {tweets_result.get('detail', '未知错误')}")
|
||||
return {
|
||||
"detail": tweets_result.get("detail", "获取推文失败")
|
||||
}
|
||||
|
||||
# 生成摘要
|
||||
tweet_content = tweets_result.get("content", "")
|
||||
summary = twitter_service.generate_summary(tweet_content, username)
|
||||
logger.info(f"生成摘要: 长度={len(summary) if summary else 0}")
|
||||
|
||||
# 返回摘要结果 - 确保格式与原始main.py中的返回完全一致
|
||||
return {
|
||||
"username": username,
|
||||
"summary": summary,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理Twitter摘要请求失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
# 与原始main.py中的错误处理一致
|
||||
return {
|
||||
"detail": f"处理请求失败: {str(e)}"
|
||||
}
|
||||
325
backend/app/routers/twitter_post.py
Normal file
325
backend/app/routers/twitter_post.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from typing import Optional
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
import tweepy
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[logging.FileHandler("twitter_post_api.log"), logging.StreamHandler()]
|
||||
)
|
||||
logger = logging.getLogger("TwitterPostAPI")
|
||||
|
||||
router = APIRouter(
|
||||
tags=["twitter_post"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# DeepSeek API配置
|
||||
API_KEY = "sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba"
|
||||
BASE_URL = "https://api.deepseek.com"
|
||||
|
||||
# 请求模型
|
||||
class TwitterPostRequest(BaseModel):
|
||||
username: str
|
||||
post_to_twitter: bool = False
|
||||
|
||||
# 帮助调试问题的简单测试路由
|
||||
@router.get("/test")
|
||||
async def test_twitter_post_api():
|
||||
"""测试Twitter发推API路由是否正常工作"""
|
||||
return {"status": "ok", "message": "Twitter发推路由正常工作"}
|
||||
|
||||
class TwitterPostService:
|
||||
def __init__(self):
|
||||
"""初始化Twitter发推服务"""
|
||||
try:
|
||||
# 使用HTTP请求方式调用DeepSeek API
|
||||
self.api_key = API_KEY
|
||||
self.base_url = BASE_URL
|
||||
|
||||
# Twitter API凭证
|
||||
self.api_key_twitter = "3nt1jN4VvqUaaXGHv9AN5VsTV"
|
||||
self.api_secret = "M2io73S7TzitFiBw825QIq8atyZRljbIDQuTpH39uFZanQ4XFh"
|
||||
self.access_token = "1944636908-prxfjL6OIb56BQjuFTdChrUPh81OjmBbV7pfnWw"
|
||||
self.access_secret = "D5AdCVRvIhGEmTmXA8hL5ciAUxIqNMZ3K3B3YejpqqNKj"
|
||||
self.bearer_token = "AAAAAAAAAAAAAAAAAAAAAGOd0gEAAAAALtv%2BzLsfGLLa5ydUt60ci6J5ce0%3DMBXViJ1NLY4XYdeuMq1xZQ98kHbeGK5lAJoV2j7Ssmcafk8Skn"
|
||||
|
||||
# 初始化Twitter客户端(用于发送)
|
||||
self.twitter_client = tweepy.Client(
|
||||
bearer_token=self.bearer_token,
|
||||
consumer_key=self.api_key_twitter,
|
||||
consumer_secret=self.api_secret,
|
||||
access_token=self.access_token,
|
||||
access_token_secret=self.access_secret
|
||||
)
|
||||
|
||||
# 用于获取推文的API
|
||||
self.twitter_api_key = "e3dad005b0e54bdc88c6178a89adec13"
|
||||
self.twitter_api_url = "https://api.twitterapi.io/twitter/tweet/advanced_search"
|
||||
|
||||
logger.info("TwitterPostService 初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化TwitterPostService失败: {str(e)}")
|
||||
raise
|
||||
|
||||
def call_deepseek_api(self, messages, model="deepseek-chat"):
|
||||
"""使用HTTP请求调用DeepSeek API"""
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": 0.1,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=60
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
else:
|
||||
logger.error(f"DeepSeek API请求失败: {response.status_code}, {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用DeepSeek API失败: {e}")
|
||||
return None
|
||||
|
||||
def get_user_tweets(self, username):
|
||||
"""获取指定用户的最近推文"""
|
||||
username = username.lstrip('@')
|
||||
|
||||
try:
|
||||
logger.info(f"请求Twitter API获取用户 {username} 的推文")
|
||||
|
||||
# 使用与原始代码相同的请求参数
|
||||
url = self.twitter_api_url
|
||||
headers = {"X-API-Key": self.twitter_api_key}
|
||||
params = {
|
||||
"queryType": "Latest",
|
||||
"query": f"from:{username}",
|
||||
"count": 5 # 获取5条推文
|
||||
}
|
||||
|
||||
# 记录请求详情
|
||||
logger.info(f"API请求详情: URL={url}, 参数={params}")
|
||||
|
||||
# 发送请求
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
|
||||
# 记录响应状态
|
||||
logger.info(f"API响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
data = response.json()
|
||||
# 记录API返回的JSON数据以便调试
|
||||
logger.info(f"API响应数据: {json.dumps(data)[:500]}")
|
||||
|
||||
tweets = data.get("tweets", [])
|
||||
if tweets:
|
||||
tweet_texts = [tweet["text"] for tweet in tweets]
|
||||
logger.info(f"成功获取用户 '{username}' 的 {len(tweet_texts)} 条推文")
|
||||
return {
|
||||
"success": True,
|
||||
"tweets": tweet_texts,
|
||||
"content": "\n".join(tweet_texts)
|
||||
}
|
||||
else:
|
||||
logger.warning(f"没有找到用户 '{username}' 的推文")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"没有找到用户 @{username} 的推文"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"解析API响应失败: {str(e)}")
|
||||
logger.error(f"响应内容: {response.text[:500]}")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"解析Twitter API响应失败: {str(e)}"
|
||||
}
|
||||
else:
|
||||
# 记录错误响应内容以便调试
|
||||
logger.error(f"Twitter API 请求失败,状态码: {response.status_code}")
|
||||
logger.error(f"响应内容: {response.text[:500]}")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"Twitter API请求失败,状态码: {response.status_code}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户推文失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"获取用户推文失败: {str(e)}"
|
||||
}
|
||||
|
||||
def generate_summary(self, tweet_content, username):
|
||||
"""根据推文内容生成总结 - 使用HTTP请求方式"""
|
||||
try:
|
||||
# 记录推文内容长度
|
||||
content_length = len(tweet_content) if tweet_content else 0
|
||||
logger.info(f"生成摘要: 用户={username}, 推文长度={content_length}")
|
||||
|
||||
# 如果推文内容为空,返回错误
|
||||
if not tweet_content:
|
||||
logger.warning(f"推文内容为空,无法生成摘要")
|
||||
return f"无法获取用户 @{username} 的推文数据,请稍后再试。"
|
||||
|
||||
# 计算最大摘要长度
|
||||
prefix = f"{username} 的最近推文摘要:"
|
||||
max_tweet_length = 280
|
||||
max_summary_length = max_tweet_length - len(prefix) - 1
|
||||
|
||||
# 使用HTTP请求调用DeepSeek API生成摘要
|
||||
try:
|
||||
logger.info("调用DeepSeek API生成摘要")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant that summarizes tweets accurately"},
|
||||
{"role": "user", "content": f"请总结用户 @{username} 以下推文的主要内容,必须使用简体中文,总结必须以序号(1.、2.、...)分隔每条要点,每条要点简洁明了且以句号或感叹号结尾,总长度不得超过 {max_summary_length} 字符:\n{tweet_content}"}
|
||||
]
|
||||
|
||||
summary_content = self.call_deepseek_api(messages)
|
||||
|
||||
if summary_content:
|
||||
logger.info(f"成功生成摘要,长度: {len(summary_content)}")
|
||||
else:
|
||||
logger.error("DeepSeek API调用失败")
|
||||
return f"生成摘要失败,请稍后再试。"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DeepSeek API调用失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return f"生成摘要失败,请稍后再试。错误: {str(e)[:100]}"
|
||||
|
||||
# 确保返回的内容不超过限制
|
||||
summary = summary_content.strip()
|
||||
if len(summary) > max_summary_length:
|
||||
summary = summary[:max_summary_length]
|
||||
|
||||
# 检查并删除不完整的最后一条要点
|
||||
lines = summary.split('\n')
|
||||
if lines:
|
||||
last_line = lines[-1]
|
||||
# 检查最后一条是否以句号、感叹号或问号结尾
|
||||
if not last_line.endswith(('。', '!', '?', '.')):
|
||||
# 如果不完整,删除最后一条
|
||||
lines.pop()
|
||||
summary = '\n'.join(lines)
|
||||
|
||||
logger.info(f"成功为用户 '{username}' 生成摘要")
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成摘要失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return f"生成摘要失败,请稍后再试。"
|
||||
|
||||
def post_to_twitter(self, summary, username):
|
||||
"""将摘要发布到Twitter"""
|
||||
try:
|
||||
prefix = f"{username} 的最近推文摘要:"
|
||||
tweet_text = f"{prefix}\n{summary}"
|
||||
|
||||
logger.info(f"准备发送推文: {tweet_text[:100]}...")
|
||||
logger.info(f"推文长度: {len(tweet_text)} 字符")
|
||||
|
||||
# 发送推文
|
||||
response = self.twitter_client.create_tweet(text=tweet_text)
|
||||
|
||||
# 获取发送的推文ID
|
||||
tweet_id = response.data['id'] if hasattr(response, 'data') and 'id' in response.data else "未知"
|
||||
logger.info(f"推文发送成功,ID: {tweet_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tweet_id": tweet_id,
|
||||
"tweet_url": f"https://twitter.com/user/status/{tweet_id}" if tweet_id != "未知" else None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"发送推文失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
return {
|
||||
"success": False,
|
||||
"detail": f"发送推文失败: {str(e)}"
|
||||
}
|
||||
|
||||
|
||||
# 创建TwitterPostService实例
|
||||
twitter_post_service = TwitterPostService()
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
async def create_twitter_post(request: TwitterPostRequest):
|
||||
"""获取Twitter用户最近推文的摘要并可选择发布到Twitter"""
|
||||
try:
|
||||
username = request.username
|
||||
post_to_twitter = request.post_to_twitter
|
||||
|
||||
logger.info(f"收到Twitter发推请求,用户名: {username}, 是否发送: {post_to_twitter}")
|
||||
|
||||
# 获取用户推文
|
||||
tweets_result = twitter_post_service.get_user_tweets(username)
|
||||
logger.info(f"获取推文结果: {json.dumps(tweets_result)[:200] if isinstance(tweets_result, dict) else '未知结果'}")
|
||||
|
||||
# 如果获取推文失败,直接返回错误信息
|
||||
if not tweets_result.get("success", False):
|
||||
logger.warning(f"获取推文失败: {tweets_result.get('detail', '未知错误')}")
|
||||
return {
|
||||
"detail": tweets_result.get("detail", "获取推文失败")
|
||||
}
|
||||
|
||||
# 生成摘要
|
||||
tweet_content = tweets_result.get("content", "")
|
||||
summary = twitter_post_service.generate_summary(tweet_content, username)
|
||||
logger.info(f"生成摘要: 长度={len(summary) if summary else 0}")
|
||||
|
||||
# 返回结果对象
|
||||
result = {
|
||||
"username": username,
|
||||
"summary": summary,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
# 如果需要发送到Twitter
|
||||
if post_to_twitter:
|
||||
logger.info(f"准备发送摘要到Twitter...")
|
||||
post_result = twitter_post_service.post_to_twitter(summary, username)
|
||||
|
||||
if post_result.get("success", False):
|
||||
result["tweet_posted"] = True
|
||||
result["tweet_id"] = post_result.get("tweet_id")
|
||||
result["tweet_url"] = post_result.get("tweet_url")
|
||||
else:
|
||||
result["tweet_posted"] = False
|
||||
result["tweet_error"] = post_result.get("detail", "发送推文失败")
|
||||
else:
|
||||
result["tweet_posted"] = False
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理Twitter发推请求失败: {str(e)}")
|
||||
logger.exception("详细错误信息:")
|
||||
# 与原始main.py中的错误处理一致
|
||||
return {
|
||||
"detail": f"处理请求失败: {str(e)}"
|
||||
}
|
||||
177
backend/app/routers/users.py
Normal file
177
backend/app/routers/users.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from app import schemas, models, database
|
||||
from passlib.context import CryptContext
|
||||
|
||||
router = APIRouter()
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
from typing import List
|
||||
|
||||
@router.get("/", response_model=List[schemas.UserOut])
|
||||
def list_users(db: Session = Depends(get_db)):
|
||||
return db.query(models.User).all()
|
||||
|
||||
|
||||
def get_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
@router.post("/register", response_model=schemas.UserOut)
|
||||
def register(user: schemas.UserCreate, db: Session = Depends(get_db)):
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||||
if db_user:
|
||||
print(f"用户名已存在: {user.username}")
|
||||
raise HTTPException(status_code=400, detail="用户名已被注册,请更换用户名")
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
new_user = models.User(
|
||||
username=user.username,
|
||||
hashed_password=hashed_password,
|
||||
balance=user.balance,
|
||||
is_admin=user.is_admin
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
print(f"用户注册成功: {user.username}, 初始余额: {user.balance}, 管理员权限: {user.is_admin}")
|
||||
return new_user
|
||||
except HTTPException:
|
||||
# 已处理的HTTP异常直接抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
print("注册用户出错:", e)
|
||||
raise HTTPException(status_code=500, detail=f"注册失败: {e}")
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from app.utils_jwt import create_access_token, get_current_user
|
||||
|
||||
@router.post("/login")
|
||||
def login(user: schemas.UserCreate, db: Session = Depends(get_db)):
|
||||
# 检查用户是否存在
|
||||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||||
if not db_user:
|
||||
print(f"登录失败 - 用户不存在: {user.username}")
|
||||
raise HTTPException(status_code=400, detail="用户名或密码错误")
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(user.password, db_user.hashed_password):
|
||||
print(f"登录失败 - 密码错误: {user.username}")
|
||||
raise HTTPException(status_code=400, detail="用户名或密码错误")
|
||||
|
||||
# 登录成功,生成token
|
||||
token = create_access_token({"sub": str(db_user.id)})
|
||||
print(f"用户登录成功: {user.username}, 余额: {db_user.balance}")
|
||||
return {"access_token": token, "token_type": "bearer", "user": {"id": db_user.id, "username": db_user.username, "balance": db_user.balance, "is_admin": db_user.is_admin}}
|
||||
|
||||
# 支持OAuth2标准token获取
|
||||
@router.post("/token")
|
||||
def token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
||||
db_user = db.query(models.User).filter(models.User.username == form_data.username).first()
|
||||
if not db_user or not verify_password(form_data.password, db_user.hashed_password):
|
||||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
||||
token = create_access_token({"sub": str(db_user.id)})
|
||||
return {"access_token": token, "token_type": "bearer"}
|
||||
|
||||
|
||||
# 管理员创建用户(包括设置余额和权限)
|
||||
@router.post("/create", response_model=schemas.UserOut)
|
||||
def create_user(user: schemas.UserCreate, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||||
# 检查当前用户是否为管理员
|
||||
if not getattr(current_user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以创建用户")
|
||||
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
db_user = db.query(models.User).filter(models.User.username == user.username).first()
|
||||
if db_user:
|
||||
raise HTTPException(status_code=400, detail="用户名已被注册,请更换用户名")
|
||||
|
||||
hashed_password = get_password_hash(user.password)
|
||||
new_user = models.User(
|
||||
username=user.username,
|
||||
hashed_password=hashed_password,
|
||||
balance=user.balance,
|
||||
is_admin=user.is_admin
|
||||
)
|
||||
db.add(new_user)
|
||||
db.commit()
|
||||
db.refresh(new_user)
|
||||
print(f"管理员创建用户成功: {user.username}, 初始余额: {user.balance}, 管理员权限: {user.is_admin}")
|
||||
return new_user
|
||||
except HTTPException:
|
||||
# 已处理的HTTP异常直接抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
print("创建用户出错:", e)
|
||||
raise HTTPException(status_code=500, detail=f"创建用户失败: {e}")
|
||||
|
||||
|
||||
# 管理员更新用户信息
|
||||
@router.put("/update/{user_id}", response_model=dict)
|
||||
def update_user(user_id: int, user_update: schemas.UserUpdate, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||||
# 检查当前用户是否为管理员
|
||||
if not getattr(current_user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以更新用户信息")
|
||||
|
||||
try:
|
||||
# 查找用户
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 更新用户信息
|
||||
update_data = user_update.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
# 如果字段值不为Null,才更新
|
||||
if value is not None:
|
||||
setattr(db_user, field, value)
|
||||
|
||||
db.commit()
|
||||
print(f"管理员更新用户成功: {db_user.username}, 余额: {db_user.balance}, 管理员权限: {db_user.is_admin}")
|
||||
return {"msg": "更新成功", "id": db_user.id}
|
||||
except HTTPException:
|
||||
# 已处理的HTTP异常直接抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
print("更新用户出错:", e)
|
||||
raise HTTPException(status_code=500, detail=f"更新用户失败: {e}")
|
||||
|
||||
|
||||
# 管理员删除用户
|
||||
@router.delete("/delete/{user_id}", response_model=dict)
|
||||
def delete_user(user_id: int, db: Session = Depends(get_db), current_user=Depends(get_current_user)):
|
||||
# 检查当前用户是否为管理员
|
||||
if not getattr(current_user, "is_admin", False):
|
||||
raise HTTPException(status_code=403, detail="无权限,只有管理员可以删除用户")
|
||||
|
||||
try:
|
||||
# 查找用户
|
||||
db_user = db.query(models.User).filter(models.User.id == user_id).first()
|
||||
if not db_user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 删除用户
|
||||
username = db_user.username # 保存用户名以便记录
|
||||
db.delete(db_user)
|
||||
db.commit()
|
||||
print(f"管理员删除用户成功: {username}")
|
||||
return {"msg": "删除成功"}
|
||||
except HTTPException:
|
||||
# 已处理的HTTP异常直接抛出
|
||||
raise
|
||||
except Exception as e:
|
||||
print("删除用户出错:", e)
|
||||
raise HTTPException(status_code=500, detail=f"删除用户失败: {e}")
|
||||
31
backend/app/schemas.py
Normal file
31
backend/app/schemas.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
balance: float = 0.0
|
||||
is_admin: bool = False
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
username: str = None
|
||||
balance: float = None
|
||||
is_admin: bool = None
|
||||
|
||||
class UserOut(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
balance: float
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class HistoryOut(BaseModel):
|
||||
id: int
|
||||
user_id: int
|
||||
type: str
|
||||
amount: float
|
||||
desc: str
|
||||
timestamp: str
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
14
backend/app/show_admin.py
Normal file
14
backend/app/show_admin.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# 文件已删除,admin 查询脚本不再需要。
|
||||
|
||||
db = database.SessionLocal()
|
||||
admin = db.query(models.User).filter(models.User.username == "admin").first()
|
||||
if admin:
|
||||
print("id:", admin.id)
|
||||
print("username:", admin.username)
|
||||
print("is_admin:", admin.is_admin)
|
||||
print("is_active:", admin.is_active)
|
||||
print("balance:", admin.balance)
|
||||
print("hashed_password:", admin.hashed_password)
|
||||
else:
|
||||
print("admin 用户不存在")
|
||||
db.close()
|
||||
44
backend/app/utils_jwt.py
Normal file
44
backend/app/utils_jwt.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from datetime import datetime, timedelta
|
||||
from jose import JWTError, jwt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from app import models, database
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
SECRET_KEY = "mysecretkey123456" # 生产环境请用更安全的key
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/token")
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta = None):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def get_db():
|
||||
db = database.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)):
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: int = payload.get("sub")
|
||||
if user_id is None:
|
||||
raise credentials_exception
|
||||
except JWTError:
|
||||
raise credentials_exception
|
||||
user = db.query(models.User).filter(models.User.id == int(user_id)).first()
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
||||
167
backend/init_db.py
Normal file
167
backend/init_db.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta
|
||||
from passlib.context import CryptContext
|
||||
|
||||
# 加密工具
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def init_database():
|
||||
"""初始化数据库,确保所有必要的表和数据存在"""
|
||||
# 连接数据库
|
||||
db_path = 'aiplatform.db'
|
||||
|
||||
# 如果数据库文件已存在,先删除
|
||||
if os.path.exists(db_path):
|
||||
print(f"删除现有数据库文件: {db_path}")
|
||||
os.remove(db_path)
|
||||
|
||||
print(f"创建新数据库: {db_path}")
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 创建表(如果不存在)
|
||||
create_tables(cursor)
|
||||
|
||||
# 初始化基础数据
|
||||
init_data(cursor)
|
||||
|
||||
# 提交更改
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
print("数据库初始化完成")
|
||||
|
||||
def create_tables(cursor):
|
||||
"""创建所有必要的表"""
|
||||
print("创建数据库表...")
|
||||
|
||||
# 用户表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT UNIQUE,
|
||||
email TEXT UNIQUE,
|
||||
hashed_password TEXT,
|
||||
is_active INTEGER DEFAULT 1,
|
||||
is_admin INTEGER DEFAULT 0,
|
||||
balance REAL DEFAULT 0
|
||||
)
|
||||
''')
|
||||
|
||||
# 应用表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS apps (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT,
|
||||
desc TEXT,
|
||||
price REAL,
|
||||
status TEXT DEFAULT '上架',
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# 订单表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS orders (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
app_id INTEGER,
|
||||
type TEXT DEFAULT '应用调用',
|
||||
amount REAL,
|
||||
description TEXT,
|
||||
status TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id),
|
||||
FOREIGN KEY (app_id) REFERENCES apps (id)
|
||||
)
|
||||
''')
|
||||
|
||||
# 历史记录表
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS history (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER,
|
||||
type TEXT,
|
||||
amount REAL,
|
||||
desc TEXT,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id)
|
||||
)
|
||||
''')
|
||||
|
||||
print("数据库表创建完成")
|
||||
|
||||
def init_data(cursor):
|
||||
"""初始化基础数据"""
|
||||
print("初始化基础数据...")
|
||||
|
||||
# 1. 创建admin用户
|
||||
print("创建admin用户...")
|
||||
hashed_password = pwd_context.hash("admin123")
|
||||
cursor.execute('''
|
||||
INSERT INTO users (username, hashed_password, is_admin, is_active, balance)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', ("admin", hashed_password, 1, 1, 100))
|
||||
|
||||
# 2. 创建Twitter应用
|
||||
print("创建Twitter应用...")
|
||||
apps = [
|
||||
("Twitter推文摘要", "输入Twitter用户名,获取最近推文摘要", 12.0, "上架", datetime.utcnow()),
|
||||
("Twitter自动发推", "输入Twitter用户名,获取摘要并发送到Twitter", 15.0, "上架", datetime.utcnow())
|
||||
]
|
||||
|
||||
for app in apps:
|
||||
cursor.execute('''
|
||||
INSERT INTO apps (name, desc, price, status, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', app)
|
||||
|
||||
# 获取新创建的应用ID
|
||||
cursor.execute("SELECT id FROM apps WHERE name = 'Twitter推文摘要'")
|
||||
summary_app_id = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT id FROM apps WHERE name = 'Twitter自动发推'")
|
||||
post_app_id = cursor.fetchone()[0]
|
||||
|
||||
# 3. 创建示例订单
|
||||
print("创建示例订单...")
|
||||
# 获取admin用户ID
|
||||
cursor.execute("SELECT id FROM users WHERE username = 'admin'")
|
||||
admin_id = cursor.fetchone()[0]
|
||||
|
||||
days_ago_5 = datetime.utcnow() - timedelta(days=5)
|
||||
days_ago_2 = datetime.utcnow() - timedelta(days=2)
|
||||
|
||||
orders = [
|
||||
(admin_id, summary_app_id, "应用调用", 12.0, "使用Twitter推文摘要服务", "已完成", days_ago_5),
|
||||
(admin_id, post_app_id, "应用调用", 15.0, "使用Twitter自动发推服务", "已完成", days_ago_2)
|
||||
]
|
||||
|
||||
for order in orders:
|
||||
cursor.execute('''
|
||||
INSERT INTO orders (user_id, app_id, type, amount, description, status, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
''', order)
|
||||
|
||||
# 4. 创建历史记录
|
||||
print("创建历史记录...")
|
||||
days_ago_10 = datetime.utcnow() - timedelta(days=10)
|
||||
|
||||
histories = [
|
||||
(admin_id, "recharge", 100.0, "账户充值", days_ago_10),
|
||||
(admin_id, "consume", -12.0, "使用Twitter推文摘要服务", days_ago_5),
|
||||
(admin_id, "consume", -15.0, "使用Twitter自动发推服务", days_ago_2)
|
||||
]
|
||||
|
||||
for history in histories:
|
||||
cursor.execute('''
|
||||
INSERT INTO history (user_id, type, amount, desc, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
''', history)
|
||||
|
||||
# 如果想要手动运行初始化
|
||||
if __name__ == "__main__":
|
||||
print("开始初始化数据库...")
|
||||
init_database()
|
||||
print("数据库初始化完成!")
|
||||
15
backend/requirements.txt
Normal file
15
backend/requirements.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
sqlalchemy
|
||||
pydantic
|
||||
passlib[bcrypt]
|
||||
python-jose
|
||||
python-multipart
|
||||
PyJWT
|
||||
openai
|
||||
tweepy
|
||||
requests
|
||||
pandas
|
||||
openpyxl
|
||||
pydub
|
||||
ffmpeg-python
|
||||
BIN
backend/上市公司信息表.xlsx
Normal file
BIN
backend/上市公司信息表.xlsx
Normal file
Binary file not shown.
Reference in New Issue
Block a user