217 lines
8.2 KiB
Python
217 lines
8.2 KiB
Python
from fastapi import FastAPI, Depends
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import HTMLResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from app.routers import users, balance, apps, history, orders, twitter, twitter_post, news_stock, ai_chatbot # 统一导入所有路由模块
|
||
from app import models, database
|
||
from passlib.context import CryptContext
|
||
from sqlalchemy.orm import Session
|
||
import requests
|
||
import logging
|
||
from openai import OpenAI
|
||
|
||
app = FastAPI(title="AI Platform API")
|
||
|
||
# 自动建表,确保所有表都存在
|
||
models.Base.metadata.create_all(bind=database.engine)
|
||
|
||
# 加密工具
|
||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||
|
||
# 初始化数据
|
||
@app.on_event("startup")
|
||
async def init_db_data():
|
||
"""在应用启动时初始化必要的数据"""
|
||
db = database.SessionLocal()
|
||
try:
|
||
# 1. 创建admin用户(如果不存在)
|
||
admin_user = db.query(models.User).filter(models.User.username == "admin").first()
|
||
if not admin_user:
|
||
print("创建admin用户...")
|
||
hashed_password = pwd_context.hash("admin123")
|
||
admin_user = models.User(
|
||
username="admin",
|
||
hashed_password=hashed_password,
|
||
is_admin=True,
|
||
is_active=True,
|
||
balance=100
|
||
)
|
||
db.add(admin_user)
|
||
db.commit()
|
||
db.refresh(admin_user)
|
||
print("admin用户创建成功")
|
||
|
||
# 2. 创建初始应用(如果应用表为空)
|
||
app_count = db.query(models.App).count()
|
||
if app_count == 0:
|
||
print("创建初始应用...")
|
||
default_apps = [
|
||
models.App(name="Twitter推文摘要", desc="输入Twitter用户名,获取最近推文摘要", price=12, status="上架"),
|
||
models.App(name="Twitter自动发推", desc="输入Twitter用户名,获取摘要并发送到Twitter", price=15, status="上架"),
|
||
models.App(name="热点新闻选股", desc="分析热点新闻对股票的影响,提供选股建议", price=20, status="上架") # 添加热点新闻选股应用
|
||
]
|
||
db.add_all(default_apps)
|
||
db.commit()
|
||
print("初始应用创建成功")
|
||
|
||
# 3. 创建示例订单数据(如果订单表为空)
|
||
order_count = db.query(models.Order).count()
|
||
if order_count == 0:
|
||
print("创建示例订单数据...")
|
||
|
||
# 获取用户和应用
|
||
users = db.query(models.User).all()
|
||
apps = db.query(models.App).all()
|
||
|
||
if users and apps:
|
||
# 创建一些示例订单
|
||
from datetime import datetime, timedelta
|
||
|
||
default_orders = [
|
||
models.Order(
|
||
user_id=users[0].id,
|
||
app_id=apps[0].id,
|
||
type="应用调用",
|
||
amount=apps[0].price,
|
||
description="使用Twitter推文摘要服务",
|
||
status="已完成",
|
||
created_at=datetime.utcnow() - timedelta(days=5)
|
||
),
|
||
models.Order(
|
||
user_id=users[0].id,
|
||
app_id=apps[1].id,
|
||
type="应用调用",
|
||
amount=apps[1].price,
|
||
description="使用Twitter自动发推服务",
|
||
status="已完成",
|
||
created_at=datetime.utcnow() - timedelta(days=2)
|
||
)
|
||
]
|
||
|
||
db.add_all(default_orders)
|
||
db.commit()
|
||
print("示例订单创建成功")
|
||
|
||
# 4. 创建示例历史记录数据(如果历史记录表为空)
|
||
history_count = db.query(models.History).count()
|
||
if history_count == 0:
|
||
print("创建示例历史记录数据...")
|
||
|
||
# 获取用户
|
||
users = db.query(models.User).all()
|
||
|
||
if users:
|
||
# 创建一些示例历史记录
|
||
from datetime import datetime, timedelta
|
||
|
||
default_history = [
|
||
models.History(
|
||
user_id=users[0].id,
|
||
type="recharge",
|
||
amount=100,
|
||
desc="账户充值",
|
||
created_at=datetime.utcnow() - timedelta(days=10)
|
||
),
|
||
models.History(
|
||
user_id=users[0].id,
|
||
type="consume",
|
||
amount=-12,
|
||
desc="使用Twitter推文摘要服务",
|
||
created_at=datetime.utcnow() - timedelta(days=5)
|
||
),
|
||
models.History(
|
||
user_id=users[0].id,
|
||
type="consume",
|
||
amount=-15,
|
||
desc="使用Twitter自动发推服务",
|
||
created_at=datetime.utcnow() - timedelta(days=2)
|
||
)
|
||
]
|
||
|
||
db.add_all(default_history)
|
||
db.commit()
|
||
print("示例历史记录创建成功")
|
||
except Exception as e:
|
||
print(f"初始化数据时出错: {e}")
|
||
finally:
|
||
db.close()
|
||
|
||
# 配置CORS
|
||
# 使用统一的跨域配置
|
||
allowed_origins = [
|
||
# 生产环境
|
||
"http://174.129.175.43:3000", # EC2实例1
|
||
"http://44.206.227.249:3000", # EC2实例2
|
||
"http://52.91.169.148:3000", # EC2实例3
|
||
"https://rongye.xyz", # 新域名
|
||
"http://rongye.xyz", # 新域名(http)
|
||
|
||
# 本地开发环境
|
||
"http://localhost:3000",
|
||
"http://127.0.0.1:3000",
|
||
"http://localhost",
|
||
"http://127.0.0.1",
|
||
]
|
||
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=allowed_origins,
|
||
allow_credentials=True,
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
expose_headers=["*"] # 允许前端访问所有响应头
|
||
)
|
||
|
||
# 注册路由 - 不再需要在这里添加/admin前缀,因为已经在router中定义了
|
||
|
||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||
app.include_router(balance.router, prefix="/balance", tags=["balance"])
|
||
app.include_router(apps.router, prefix="/apps", tags=["apps"])
|
||
app.include_router(history.router, prefix="/history", tags=["history"])
|
||
app.include_router(orders.router, prefix="/orders", tags=["orders"])
|
||
app.include_router(twitter.router, prefix="/twitter", tags=["twitter"])
|
||
app.include_router(twitter_post.router, prefix="/twitter-post", tags=["twitter_post"])
|
||
app.include_router(news_stock.router, prefix="/news-stock", tags=["news_stock"]) # 添加热点新闻选股路由
|
||
app.include_router(ai_chatbot.router, prefix="/ai-chatbot", tags=["ai_chatbot"]) # 添加AI客服路由
|
||
|
||
@app.get("/")
|
||
async def root():
|
||
return {"message": "Welcome to AI Platform API"}
|
||
|
||
@app.get("/test-twitter")
|
||
async def test_twitter():
|
||
try:
|
||
# 测试Twitter API连接
|
||
twitter_api_key = "e3dad005b0e54bdc88c6178a89adec13"
|
||
twitter_api_url = "https://api.twitterapi.io/twitter/tweet/advanced_search"
|
||
headers = {"X-API-Key": twitter_api_key}
|
||
params = {
|
||
"queryType": "Latest",
|
||
"query": "from:elonmusk",
|
||
"count": 5
|
||
}
|
||
|
||
twitter_response = requests.get(twitter_api_url, headers=headers, params=params)
|
||
|
||
# 初始化OpenAI客户端
|
||
openai_client = OpenAI(
|
||
api_key="sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba",
|
||
base_url="https://api.deepseek.com"
|
||
)
|
||
|
||
return {
|
||
"status": "ok",
|
||
"twitter_api_status": twitter_response.status_code,
|
||
"twitter_api_response": twitter_response.json() if twitter_response.status_code == 200 else str(twitter_response.text)[:100],
|
||
"apis_available": {
|
||
"twitter_api": True,
|
||
"openai_api": True
|
||
}
|
||
}
|
||
except Exception as e:
|
||
return {
|
||
"status": "error",
|
||
"message": str(e),
|
||
"type": str(type(e))
|
||
}
|