# glm_client.py
# -*- coding: utf-8 -*-
"""
GLM-4.6v-Flash LLM client - Day 22

Uses the official zai-sdk with the glm-4.6v-flash model.
"""
import asyncio
import os
from datetime import datetime
from typing import AsyncGenerator, Optional

from zai import ZhipuAiClient
# API configuration.
# The key is read from the environment at import time so the module fails
# fast when it is missing instead of erroring on the first API call.
API_KEY = os.getenv("GLM_API_KEY")
if not API_KEY:
    raise RuntimeError("未设置 GLM_API_KEY 环境变量,请在 .env 中配置")
MODEL = "glm-4.6v-flash"  # upgraded to glm-4.6v-flash (vision-capable)

# Maps datetime.weekday() (0 = Monday) to Chinese weekday names.
WEEKDAY_MAP = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
def get_system_prompt() -> str:
    """Build the system prompt on demand so it always carries the current time."""
    now = datetime.now()
    weekday_name = WEEKDAY_MAP[now.weekday()]
    time_part = now.strftime("%H:%M")
    date_part = now.strftime("%Y年%m月%d日")
    return (
        "你是一个视障辅助AI助手,安装在智能导盲眼镜上。\n"
        f"当前时间:{time_part}\n"
        f"今天日期:{date_part} {weekday_name}\n"
        "\n"
        "请用极简短的语言回答,每次回答不超过2-3句话。\n"
        "避免冗长解释,只提供最关键的信息。\n"
        "语气友好但简洁。"
    )
# Module-level client and dialogue state.
_client = None  # lazily-created ZhipuAiClient singleton (see _get_client)
_conversation_history = []  # alternating user/assistant messages, oldest first
MAX_HISTORY_TURNS = 5  # keep only the most recent 5 turns (1 turn = user + assistant)
def _get_client() -> ZhipuAiClient:
    """Return the shared Zhipu AI client, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = ZhipuAiClient(api_key=API_KEY)
    return _client
def clear_conversation_history():
    """Reset the in-memory dialogue history to an empty list."""
    global _conversation_history
    # Rebind (rather than mutate) so any stale references are dropped too.
    _conversation_history = []
    print("[GLM] 对话历史已清除")
async def chat(user_message: str, image_base64: Optional[str] = None) -> str:
    """
    Send one turn to GLM-4.6v-Flash with rolling conversation memory.

    Args:
        user_message: The user's text message.
        image_base64: Optional Base64-encoded JPEG image for multimodal input.

    Returns:
        The assistant's reply text, "" when the API returns no choices, or a
        fallback apology string when every retry fails.
    """
    global _conversation_history
    client = _get_client()

    # Build the user message: a multimodal content list when an image is
    # attached, plain text otherwise.
    if image_base64:
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
        ]
    else:
        user_content = user_message

    # Append the user message, then trim history (1 turn = 1 user + 1 assistant).
    _conversation_history.append({"role": "user", "content": user_content})
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # Regenerate the system prompt each call so it carries the current time.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    # Day 22: retry with exponential backoff on rate-limit errors.
    max_retries = 3
    retry_delay = 1  # seconds; doubled after each rate-limited attempt

    for attempt in range(max_retries):
        try:
            # The SDK call is blocking, so run it off the event loop.
            # Per the official docs the `thinking` parameter is required even
            # for vision models; disable it explicitly to reduce latency.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model=MODEL,
                messages=messages,
                thinking={"type": "disabled"},
            )

            if response.choices and len(response.choices) > 0:
                ai_reply = response.choices[0].message.content.strip()
                # Record the assistant reply so the turn is complete.
                _conversation_history.append({"role": "assistant", "content": ai_reply})
                print(f"[GLM] 回复: {ai_reply[:50]}..." if len(ai_reply) > 50 else f"[GLM] 回复: {ai_reply}")
                return ai_reply

            # FIX: empty `choices` previously returned "" while leaving the
            # unanswered user message dangling in history, unbalancing the
            # turn-pair trimming. Remove it before returning.
            if _conversation_history and _conversation_history[-1]["role"] == "user":
                _conversation_history.pop()
            return ""

        except Exception as e:
            error_str = str(e)
            # Rate-limit errors: HTTP 429 or Zhipu error code 1305.
            if "429" in error_str or "1305" in error_str or "请求过多" in error_str:
                if attempt < max_retries - 1:
                    print(f"[GLM] 速率限制,{retry_delay}秒后重试... (尝试 {attempt + 1}/{max_retries})")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                    continue

            print(f"[GLM] 调用失败: {e}")
            import traceback
            traceback.print_exc()
            break

    # All retries failed: drop the pending user message so history stays in
    # user/assistant pairs.
    if _conversation_history and _conversation_history[-1]["role"] == "user":
        _conversation_history.pop()
    return "抱歉,我暂时无法回答。"
async def chat_stream(user_message: str, image_base64: Optional[str] = None) -> AsyncGenerator[str, None]:
    """
    Streaming chat (token-by-token) with GLM-4.6v-Flash.

    Args:
        user_message: The user's text message.
        image_base64: Optional Base64-encoded JPEG image for multimodal input.

    Yields:
        Fragments of the assistant's reply; on total failure, a single
        fallback apology string.
    """
    global _conversation_history
    client = _get_client()

    # Build the user message: a multimodal content list when an image is
    # attached, plain text otherwise.
    if image_base64:
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
        ]
    else:
        user_content = user_message

    # Append the user message, then trim history (1 turn = 2 messages).
    _conversation_history.append({"role": "user", "content": user_content})
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # System prompt is regenerated per call to embed the current time.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    full_response = ""

    try:
        # Day 22: streaming call with retry; rate-limit AND transient network
        # errors are retried (unlike the non-streaming path, which only
        # retries rate limits).
        max_retries = 3
        retry_delay = 1  # seconds

        response = None
        for attempt in range(max_retries):
            try:
                # Per the official docs the `thinking` parameter is required
                # even for vision models; disable it to reduce latency.
                response = await asyncio.to_thread(
                    client.chat.completions.create,
                    model=MODEL,
                    messages=messages,
                    thinking={"type": "disabled"},
                    stream=True,
                )
                break  # success: leave the retry loop
            except Exception as e:
                error_str = str(e)
                if attempt < max_retries - 1:
                    if "429" in error_str or "1305" in error_str or "请求过多" in error_str:
                        print(f"[GLM] (流式) 速率限制,{retry_delay}秒后重试... ({attempt + 1}/{max_retries})")
                    else:
                        print(f"[GLM] (流式) 连接错误: {e},重试... ({attempt + 1}/{max_retries})")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # backoff now applied on both retry paths
                    continue
                raise  # last attempt failed; propagate to the outer handler

        # FIX: the SDK stream is a blocking iterator. Iterating it directly
        # (`for chunk in response:`) performed network reads on the event
        # loop, stalling all other coroutines between chunks. Pull each chunk
        # in a worker thread instead.
        _done = object()
        stream_iter = iter(response)
        while True:
            chunk = await asyncio.to_thread(next, stream_iter, _done)
            if chunk is _done:
                break
            # Guard against chunks with an empty `choices` list.
            if chunk.choices and chunk.choices[0].delta.content:
                text = chunk.choices[0].delta.content
                full_response += text
                yield text

        if full_response:
            # Record the assembled reply so the turn is complete.
            _conversation_history.append({"role": "assistant", "content": full_response})
            print(f"[GLM] 流式完成: {full_response[:50]}..." if len(full_response) > 50 else f"[GLM] 流式完成: {full_response}")
        elif _conversation_history and _conversation_history[-1]["role"] == "user":
            # FIX: no content came back — drop the unanswered user message so
            # history stays in user/assistant pairs.
            _conversation_history.pop()

    except Exception as e:
        print(f"[GLM] 流式调用失败: {e}")
        import traceback
        traceback.print_exc()
        # Remove the user message added above; the turn never completed.
        if _conversation_history and _conversation_history[-1]["role"] == "user":
            _conversation_history.pop()
        yield "抱歉,我暂时无法回答。"