118 lines
4.5 KiB
Python
118 lines
4.5 KiB
Python
# omni_client.py
# -*- coding: utf-8 -*-
import os
import base64
import asyncio
import threading
from typing import AsyncGenerator, Dict, Any, List, Optional, Tuple

from openai import OpenAI

# ===== OpenAI-compatible endpoint (DashScope compatible mode) =====
# SECURITY FIX: a real-looking API key used to be hardcoded here as the
# os.getenv() fallback. That leaked the secret into source control AND made
# the emptiness check below unreachable (the fallback was always truthy).
# The key must now be supplied via the environment; rotate the leaked key.
API_KEY = os.getenv("DASHSCOPE_API_KEY")
if not API_KEY:
    raise RuntimeError("未设置 DASHSCOPE_API_KEY")

QWEN_MODEL = "qwen-omni-turbo"

# Synchronous OpenAI-compatible client pointed at DashScope's
# compatible-mode base URL (used from a worker thread in stream_chat).
oai_client = OpenAI(
    api_key=API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
|
||
|
||
class OmniStreamPiece:
    """Unified incremental payload handed to consumers of the stream.

    A piece carries a text delta and/or a base64-encoded audio chunk;
    either field may be ``None``, and both may be present at once.
    """

    def __init__(self, text_delta: Optional[str] = None, audio_b64: Optional[str] = None):
        # Incremental text fragment from the model (None if this piece has no text).
        self.text_delta = text_delta
        # Base64-encoded audio chunk from the model (None if this piece has no audio).
        self.audio_b64 = audio_b64

    def __repr__(self) -> str:
        # Improvement: debuggable repr. Audio payloads can be huge, so only
        # their length is shown rather than the full base64 string.
        audio_desc = None if self.audio_b64 is None else f"<{len(self.audio_b64)} b64 chars>"
        return f"{type(self).__name__}(text_delta={self.text_delta!r}, audio_b64={audio_desc})"
|
||
|
||
async def stream_chat(
    content_list: List[Dict[str, Any]],
    voice: str = "Cherry",
    audio_format: str = "wav",
) -> AsyncGenerator[OmniStreamPiece, None]:
    """Run one streaming Omni-Turbo ChatCompletions turn.

    Parameters:
        content_list: OpenAI-style multimodal ``content`` list for the user
            message (e.g. ``image_url`` / ``text`` parts).
        voice: TTS voice name passed to the API's ``audio`` option.
        audio_format: Audio container format requested from the API.

    Yields:
        OmniStreamPiece instances with ``text_delta`` and/or ``audio_b64``
        set, as increments arrive from the API.

    Raises:
        Exception: re-raises whatever the worker thread's API call raised.

    Day 13 fix: the OpenAI client here is synchronous, so the blocking API
    call is run in a dedicated thread and results are handed to the event
    loop through an asyncio.Queue — the loop is never blocked.
    """
    # Bridge between the worker thread (producer) and this coroutine
    # (consumer). Items are OmniStreamPiece, an Exception, or the None
    # end-of-stream sentinel.
    queue: asyncio.Queue = asyncio.Queue()
    # Captured here so the worker thread can schedule thread-safe puts
    # onto THIS loop via call_soon_threadsafe.
    loop = asyncio.get_running_loop()

    def _sync_stream():
        """Run the blocking streaming API call in a separate thread."""
        try:
            # Day 21 tweak: system prompt keeps answers short — the
            # guide-glasses use case needs fast, terse replies.
            # (Prompt text is runtime data; kept verbatim.)
            system_prompt = """你是一个视障辅助AI助手,安装在智能导盲眼镜上。
请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_list}
            ]

            completion = oai_client.chat.completions.create(
                model=QWEN_MODEL,
                messages=messages,
                # Request both text and synthesized speech in one stream.
                modalities=["text", "audio"],
                audio={"voice": voice, "format": audio_format},
                stream=True,
                stream_options={"include_usage": True},
            )

            for chunk in completion:
                text_delta: Optional[str] = None
                audio_b64: Optional[str] = None

                # Usage-only chunks have no choices; getattr guards keep us
                # tolerant of differing chunk shapes across SDK versions.
                if getattr(chunk, "choices", None):
                    c0 = chunk.choices[0]
                    delta = getattr(c0, "delta", None)
                    # Incremental text.
                    if delta and getattr(delta, "content", None):
                        piece = delta.content
                        if piece:
                            text_delta = piece
                    # Incremental audio; the payload may be a dict or an
                    # object depending on SDK version, hence the dual access.
                    if delta and getattr(delta, "audio", None):
                        aud = delta.audio
                        audio_b64 = aud.get("data") if isinstance(aud, dict) else getattr(aud, "data", None)
                    # Fallback: some responses carry audio on the final
                    # message object rather than on the delta.
                    if audio_b64 is None:
                        msg = getattr(c0, "message", None)
                        if msg and getattr(msg, "audio", None):
                            ma = msg.audio
                            audio_b64 = ma.get("data") if isinstance(ma, dict) else getattr(ma, "data", None)

                if (text_delta is not None) or (audio_b64 is not None):
                    # Thread-safe hand-off: schedule the put on the event
                    # loop's thread rather than touching the queue directly.
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        OmniStreamPiece(text_delta=text_delta, audio_b64=audio_b64)
                    )
        except Exception as e:
            # Forward the exception object itself; the consumer re-raises it.
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            # Always send the end-of-stream sentinel so the consumer loop
            # terminates even after an error.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    # Daemon thread so a lingering API call cannot block interpreter exit.
    thread = threading.Thread(target=_sync_stream, daemon=True)
    thread.start()

    # Consume the queue asynchronously until the sentinel arrives.
    while True:
        item = await queue.get()
        if item is None:
            # End of stream.
            break
        if isinstance(item, Exception):
            # Surface the worker thread's failure to the caller.
            raise item
        yield item
|
||
|