Files
AI-Website/backend/app/routers/ai_chatbot.py
2026-01-09 09:48:57 +08:00

641 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
import os
import time
import base64
import hashlib
import requests
import hmac
import urllib.parse
import http.client
from urllib.parse import urlencode
import json
from openai import OpenAI
from pydub import AudioSegment
import tempfile
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import threading
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
router = APIRouter()
# 阿里云语音服务配置
ALIYUN_ACCESS_KEY_ID = "LTAI5t5ZrbKQuuwkmQ1LFCBo"
ALIYUN_ACCESS_KEY_SECRET = "2vvspr0HcmmnBFzpXw4iNyLafSgUuN"
ALIYUN_APP_KEY = "wlIvC6tOAvQLoQDz"
ALIYUN_REGION = "cn-shanghai"
ALIYUN_HOST = "nls-gateway-cn-shanghai.aliyuncs.com"
# DeepSeek配置
DEEPSEEK_API_KEY = "sk-8a121704a9bc4ec6a5ab0ae16e0bc0ba"
DEEPSEEK_BASE_URL = "https://api.deepseek.com"
# 音频文件存储目录
AUDIO_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "audio")
os.makedirs(AUDIO_DIR, exist_ok=True)
# 全局变量存储token和过期时间
token_info = {
'token': None,
'expire_time': 0
}
# 添加线程池用于异步处理
executor = ThreadPoolExecutor(max_workers=4)
# 简单的内存缓存
response_cache = {}
cache_expire_time = {}
def get_signature(secret, text):
"""生成签名"""
h = hmac.new(secret.encode('utf-8'), text.encode('utf-8'), hashlib.sha1)
return base64.b64encode(h.digest()).decode('utf-8')
def get_aliyun_token(force_refresh=False):
"""获取阿里云访问令牌,带缓存和自动刷新"""
global token_info
# 检查token是否有效提前5分钟刷新
current_time = int(time.time())
if not force_refresh and token_info['token'] and token_info['expire_time'] > current_time + 300:
logger.info("使用缓存的阿里云Token")
return token_info['token']
logger.info("正在获取阿里云访问令牌..." + (" (强制刷新)" if force_refresh else ""))
# 构建请求参数
params = {
"Action": "CreateToken",
"Version": "2019-02-28",
"Format": "JSON",
"AccessKeyId": ALIYUN_ACCESS_KEY_ID,
"Timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"SignatureMethod": "HMAC-SHA1",
"SignatureVersion": "1.0",
"SignatureNonce": str(int(time.time() * 1000))
}
# 对参数进行排序并生成查询字符串
sorted_params = sorted(params.items(), key=lambda x: x[0])
query_string = urlencode(dict(sorted_params))
# 构建待签名字符串
string_to_sign = f"POST&%2F&{urllib.parse.quote_plus(query_string)}"
# 计算签名
signature = get_signature(ALIYUN_ACCESS_KEY_SECRET + "&", string_to_sign)
params["Signature"] = signature
try:
conn = http.client.HTTPSConnection("nls-meta.cn-shanghai.aliyuncs.com")
headers = {"Content-Type": "application/x-www-form-urlencoded"}
# 发送请求
conn.request("POST", "/", headers=headers, body=urlencode(params))
response = conn.getresponse()
result = json.loads(response.read().decode('utf-8'))
if 'Token' in result and 'Id' in result['Token']:
# 更新token信息设置过期时间为55分钟后
token_info['token'] = result['Token']['Id']
token_info['expire_time'] = int(time.time()) + 3300 # 55分钟
logger.info("获取阿里云Token成功")
return token_info['token']
else:
logger.error(f"获取阿里云Token失败: {result}")
return None
except Exception as e:
logger.error(f"获取阿里云Token时发生异常: {str(e)}")
return None
finally:
if 'conn' in locals():
conn.close()
def aliyun_tts_chinese_text_to_audio(chinese_text, output_audio_path):
"""使用阿里云TTS接口进行中文语音合成"""
token = get_aliyun_token()
if not token:
raise Exception("无法获取阿里云访问令牌")
logger.info(f"正在合成语音: {chinese_text[:50]}...")
try:
conn = http.client.HTTPSConnection(ALIYUN_HOST, timeout=10)
headers = {
"Content-Type": "application/json",
"X-NLS-Token": token
}
# 使用更自然的音色和参数配置
tts_data = {
"appkey": ALIYUN_APP_KEY,
"text": chinese_text[:100], # 限制长度避免过长
"format": "wav",
"voice": "aixia", # 使用艾夏音色符合智能语音交互2.0配置
"volume": 50,
"speech_rate": 0, # 正常语速
"pitch_rate": 0, # 正常语调
"sample_rate": 16000
}
conn.request("POST", "/stream/v1/tts",
body=json.dumps(tts_data),
headers=headers)
response = conn.getresponse()
if response.status == 200:
# 检查响应类型
content_type = response.getheader('Content-Type', '')
response_data = response.read()
if 'audio' in content_type:
# 是音频数据
with open(output_audio_path, 'wb') as f:
f.write(response_data)
logger.info(f"语音合成成功,已保存到: {output_audio_path}, 大小: {len(response_data)} bytes")
return True
else:
# 可能是错误信息
try:
error_info = json.loads(response_data.decode('utf-8'))
logger.error(f"阿里云TTS返回错误信息: {json.dumps(error_info, indent=2, ensure_ascii=False)}")
except:
logger.error(f"阿里云TTS返回非JSON响应: {response_data[:500]}")
return False
else:
error_message = response.read().decode('utf-8')
logger.error(f"TTS请求失败: {response.status} {response.reason}, 错误: {error_message}")
return False
except Exception as e:
logger.error(f"TTS合成时发生异常: {str(e)}")
return False
finally:
if 'conn' in locals():
conn.close()
def convert_audio_format(input_file, target_sample_rate=16000, target_channels=1):
"""转换音频格式到WAV格式确保符合阿里云ASR要求"""
try:
logger.info(f"开始转换音频格式: {input_file}")
# 检查输入文件是否存在
if not os.path.exists(input_file):
logger.error(f"输入音频文件不存在: {input_file}")
return None
# 加载音频文件,自动检测格式
try:
audio = AudioSegment.from_file(input_file)
except Exception as e:
logger.error(f"无法读取音频文件: {e}")
return None
logger.info(f"原始音频信息 - 时长: {len(audio)}ms, 采样率: {audio.frame_rate}, 声道: {audio.channels}")
# 检查音频时长阿里云ASR限制60秒
if len(audio) > 60000: # 60秒 = 60000毫秒
logger.warning(f"音频时长({len(audio)/1000:.2f}s)超过60秒限制将截取前60秒")
audio = audio[:60000]
# 转换为目标格式16kHz单声道WAV
audio = audio.set_frame_rate(target_sample_rate).set_channels(target_channels)
# 确保是16位PCM编码
audio = audio.set_sample_width(2) # 2字节 = 16位
# 生成输出文件名
base_name = os.path.splitext(input_file)[0]
output_file = f"{base_name}_converted.wav"
# 导出为WAV格式
audio.export(output_file, format="wav", parameters=["-acodec", "pcm_s16le"])
logger.info(f"音频格式转换成功: {input_file} -> {output_file}")
logger.info(f"转换后音频信息 - 时长: {len(audio)}ms, 采样率: {target_sample_rate}, 声道: {target_channels}")
return output_file
except Exception as e:
logger.error(f"音频格式转换失败: {str(e)}")
return None
def aliyun_asr_chinese_audio_to_text(audio_path):
"""使用阿里云智能语音交互2.0 RESTful API进行中文语音识别"""
max_retries = 3
for attempt in range(max_retries):
# 获取token如果是重试则强制刷新
token = get_aliyun_token(force_refresh=(attempt > 0))
if not token:
logger.error("获取阿里云Token失败无法继续识别")
return ""
logger.info(f"正在识别音频: {audio_path} (尝试 {attempt + 1}/{max_retries})")
# 检查音频文件
if not os.path.exists(audio_path):
logger.error(f"音频文件不存在: {audio_path}")
return ""
# 转换音频格式确保符合阿里云ASR要求
converted_audio = convert_audio_format(audio_path)
if not converted_audio:
logger.error("音频格式转换失败")
return ""
conn = None
try:
# 读取转换后的音频文件
with open(converted_audio, 'rb') as f:
audio_data = f.read()
logger.info(f"音频文件大小: {len(audio_data)} bytes")
# 检查音频文件大小
if len(audio_data) == 0:
logger.error("音频文件为空")
continue
# 使用阿里云智能语音交互2.0的RESTful API
conn = http.client.HTTPSConnection(ALIYUN_HOST, timeout=30)
# 设置正确的请求头
headers = {
"X-NLS-Token": token,
"Content-Type": "application/octet-stream",
"Content-Length": str(len(audio_data)),
"Host": ALIYUN_HOST
}
# 构建请求参数按照阿里云智能语音交互2.0文档
params = {
"appkey": ALIYUN_APP_KEY,
"format": "wav",
"sample_rate": 16000,
"enable_punctuation_prediction": "true",
"enable_inverse_text_normalization": "true",
"enable_voice_detection": "false"
}
# 构建完整的请求URL
query_string = urllib.parse.urlencode(params)
full_url = f"/stream/v1/asr?{query_string}"
logger.info(f"发送ASR请求URL: {full_url}")
logger.info(f"请求头: {headers}")
# 发送POST请求直接传输二进制音频数据
conn.request("POST", full_url, body=audio_data, headers=headers)
response = conn.getresponse()
response_data = response.read()
logger.info(f"ASR响应状态: {response.status}")
if response.status == 200:
try:
result = json.loads(response_data.decode('utf-8'))
logger.info(f"ASR响应结果: {result}")
# 检查响应状态
status = result.get('status', 0)
message = result.get('message', '')
if status == 20000000 and message == 'SUCCESS':
transcription = result.get('result', '').strip()
if transcription:
logger.info(f"语音识别成功: {transcription}")
return transcription
else:
logger.warning("识别成功但结果为空可能是1)音频内容为静音 2)语音不清晰 3)语言不匹配")
if attempt < max_retries - 1:
logger.info("正在重试...")
continue
return ""
else:
logger.error(f"ASR识别失败: status={status}, message={message}")
if attempt < max_retries - 1:
logger.info("正在重试...")
continue
return ""
except json.JSONDecodeError as json_error:
logger.error(f"解析ASR响应JSON失败: {json_error}")
logger.info(f"原始响应: {response_data[:1000]}")
if attempt < max_retries - 1:
logger.info("正在重试...")
continue
return ""
elif response.status == 401: # 未授权token可能已过期
logger.warning("Token无效或已过期")
if attempt < max_retries - 1:
logger.info("正在尝试刷新Token并重试...")
continue
else:
logger.error("Token刷新重试次数已用完")
return ""
elif response.status == 40000001:
logger.error("身份认证失败检查Token是否正确或过期")
return ""
elif response.status == 40000003:
logger.error("参数无效,检查音频格式和采样率")
return ""
elif response.status == 41010101:
logger.error("不支持的采样率当前仅支持8000Hz和16000Hz")
return ""
else:
error_message = response_data.decode('utf-8', errors='ignore')
logger.error(f"ASR识别失败: 状态码{response.status}, 错误信息: {error_message}")
if attempt < max_retries - 1:
logger.info("正在重试...")
continue
return ""
except Exception as e:
logger.error(f"ASR识别时发生异常: {str(e)}")
if attempt < max_retries - 1:
logger.info("正在重试...")
continue
return ""
finally:
# 清理临时文件
if converted_audio and os.path.exists(converted_audio) and converted_audio != audio_path:
try:
os.remove(converted_audio)
logger.info(f"已清理临时文件: {converted_audio}")
except:
pass
# 关闭连接
if conn:
conn.close()
# 所有重试都失败了
logger.error("所有ASR重试都失败")
return ""
def get_cache_key(question: str) -> str:
"""生成缓存键"""
return hashlib.md5(question.encode()).hexdigest()
def get_cached_response(question: str) -> str:
"""获取缓存的回答"""
cache_key = get_cache_key(question)
current_time = time.time()
# 检查缓存是否存在且未过期5分钟过期
if (cache_key in response_cache and
cache_key in cache_expire_time and
cache_expire_time[cache_key] > current_time):
logger.info(f"使用缓存回答: {cache_key}")
return response_cache[cache_key]
return None
def set_cached_response(question: str, answer: str):
"""设置缓存回答"""
cache_key = get_cache_key(question)
response_cache[cache_key] = answer
cache_expire_time[cache_key] = time.time() + 300 # 5分钟过期
logger.info(f"缓存回答: {cache_key}")
def get_deepseek_response(question):
"""使用DeepSeek API获取简洁回答"""
# 首先检查缓存
cached_answer = get_cached_response(question)
if cached_answer:
return cached_answer
try:
client = OpenAI(
api_key=DEEPSEEK_API_KEY,
base_url=DEEPSEEK_BASE_URL,
timeout=8.0 # 设置8秒超时
)
# 针对客服场景优化提示词
system_prompt = """你是一个专业的智能客服助手。请遵循以下原则:
1. 回答要简洁明了通常控制在40字以内
2. 语气要友好、专业、有帮助
3. 如果不确定答案,请诚实说明并建议联系人工客服
4. 重点解决客户的实际问题
5. 避免冗长的解释,直接给出有用信息"""
response = client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": question}
],
max_tokens=150, # 进一步限制回答长度
temperature=0.5 # 降低随机性,提高一致性
)
answer = response.choices[0].message.content.strip()
logger.info(f"DeepSeek回答: {answer}")
# 缓存回答
set_cached_response(question, answer)
return answer
except Exception as e:
logger.error(f"DeepSeek API调用失败: {str(e)}")
return "抱歉,智能客服暂时繁忙,请稍后重试或联系人工客服。"
async def get_deepseek_response_async(question):
"""异步调用DeepSeek API"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, get_deepseek_response, question)
class ChatRequest(BaseModel):
question: str
@router.post("/query")
async def chat_query(req: ChatRequest):
"""AI客服对话接口"""
try:
question = req.question.strip()
if not question:
raise HTTPException(status_code=400, detail="问题不能为空")
# 使用DeepSeek获取回答
answer = await get_deepseek_response_async(question)
# 生成语音文件
timestamp = int(time.time() * 1000)
audio_filename = f"ai_response_{timestamp}.wav"
audio_path = os.path.join(AUDIO_DIR, audio_filename)
# 语音合成
tts_success = aliyun_tts_chinese_text_to_audio(answer, audio_path)
response_data = {
"answer": answer,
"audio_url": f"/ai-chatbot/audio/{audio_filename}" if tts_success else None
}
return response_data
except Exception as e:
logger.error(f"查询处理失败: {str(e)}")
raise HTTPException(status_code=500, detail="处理请求时发生错误")
@router.post("/query-text")
async def chat_query_text_only(req: ChatRequest):
"""AI客服对话接口 - 仅返回文本,快速响应"""
try:
question = req.question.strip()
if not question:
raise HTTPException(status_code=400, detail="问题不能为空")
# 使用DeepSeek获取回答
answer = await get_deepseek_response_async(question)
response_data = {
"answer": answer,
"timestamp": int(time.time() * 1000) # 用于生成语音时的唯一标识
}
return response_data
except Exception as e:
logger.error(f"文本查询处理失败: {str(e)}")
raise HTTPException(status_code=500, detail="处理请求时发生错误")
class AudioRequest(BaseModel):
text: str
timestamp: int
@router.post("/generate-audio")
async def generate_audio(req: AudioRequest):
"""异步生成语音文件"""
try:
text = req.text.strip()
if not text:
raise HTTPException(status_code=400, detail="文本不能为空")
# 使用时间戳生成语音文件名
audio_filename = f"ai_response_{req.timestamp}.wav"
audio_path = os.path.join(AUDIO_DIR, audio_filename)
# 语音合成
tts_success = aliyun_tts_chinese_text_to_audio(text, audio_path)
if tts_success:
return {
"success": True,
"audio_url": f"/ai-chatbot/audio/{audio_filename}"
}
else:
return {
"success": False,
"error": "语音生成失败"
}
except Exception as e:
logger.error(f"语音生成失败: {str(e)}")
return {
"success": False,
"error": str(e)
}
@router.post("/asr")
async def speech_recognition(file: UploadFile = File(...)):
"""语音识别接口"""
try:
# 检查文件格式
if not file.filename.lower().endswith(('.wav', '.mp3', '.m4a', '.webm', '.ogg')):
logger.warning(f"不支持的文件格式: {file.filename}")
# 保存上传的音频文件
timestamp = int(time.time() * 1000)
# 保持原始文件扩展名让pydub自动检测格式
original_ext = os.path.splitext(file.filename)[1] if file.filename else '.wav'
temp_filename = f"temp_audio_{timestamp}{original_ext}"
temp_path = os.path.join(AUDIO_DIR, temp_filename)
# 保存文件
with open(temp_path, "wb") as buffer:
content = await file.read()
buffer.write(content)
logger.info(f"接收到音频文件: {file.filename}, 大小: {len(content)} bytes, 临时保存为: {temp_path}")
# 语音识别
transcription = aliyun_asr_chinese_audio_to_text(temp_path)
# 清理临时文件
if os.path.exists(temp_path):
os.remove(temp_path)
logger.info(f"已清理临时文件: {temp_path}")
return {"text": transcription}
except Exception as e:
logger.error(f"语音识别失败: {str(e)}")
# 清理临时文件
if 'temp_path' in locals() and os.path.exists(temp_path):
try:
os.remove(temp_path)
except:
pass
raise HTTPException(status_code=500, detail="语音识别失败")
@router.get("/audio/{filename}")
async def get_audio(filename: str):
"""获取音频文件"""
audio_path = os.path.join(AUDIO_DIR, filename)
if os.path.exists(audio_path):
return FileResponse(
audio_path,
media_type="audio/wav",
headers={"Content-Disposition": f"inline; filename={filename}"}
)
else:
raise HTTPException(status_code=404, detail="音频文件不存在")
# 添加音频文件清理功能
def cleanup_old_audio_files():
"""清理超过1小时的音频文件"""
try:
current_time = time.time()
for filename in os.listdir(AUDIO_DIR):
if filename.endswith('.wav'):
file_path = os.path.join(AUDIO_DIR, filename)
file_age = current_time - os.path.getctime(file_path)
# 删除超过1小时的文件
if file_age > 3600:
try:
os.remove(file_path)
logger.info(f"已清理过期音频文件: {filename}")
except Exception as e:
logger.error(f"清理音频文件失败 {filename}: {e}")
except Exception as e:
logger.error(f"音频文件清理过程出错: {e}")
def start_cleanup_timer():
"""启动定时清理任务"""
cleanup_old_audio_files()
# 每30分钟执行一次清理
threading.Timer(1800.0, start_cleanup_timer).start()
# 启动清理任务
start_cleanup_timer()