# (file metadata: 312 lines, 12 KiB, Python)
# audio_stream.py
|
||
# -*- coding: utf-8 -*-
|
||
import asyncio
|
||
from collections import deque
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Set, List, Tuple, Any, Dict
|
||
from fastapi import Request
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
# ===== Downstream WAV stream base parameters =====
STREAM_SR = 8000   # 8 kHz — changed so the ESP32 client can play it
STREAM_CH = 1      # mono
STREAM_SW = 2      # sample width in bytes (16-bit PCM)
# Bytes in one 20 ms frame at STREAM_SR: 8000 * 2 * 20 / 1000 = 320 bytes.
# NOTE(review): the "_16K" suffix is historical — the value tracks STREAM_SR (8 kHz).
BYTES_PER_20MS_16K = STREAM_SR * STREAM_SW * 20 // 1000  # 320B (8kHz)

# ===== Day 13: TTS buffering queue =====
# While the WebSocket is disconnected, TTS audio is cached here and
# flushed to the client after it reconnects.
TTS_BUFFER_MAX_SECONDS = 30  # cache at most 30 seconds of audio
TTS_BUFFER_MAX_BYTES = 16000 * 2 * TTS_BUFFER_MAX_SECONDS  # 16kHz * 2 bytes * 30s = ~960KB
tts_audio_buffer: deque = deque()  # each element is (timestamp, pcm16k_bytes)
tts_buffer_total_bytes = 0         # running byte total of tts_audio_buffer

# Day 13: dedicated TTS WebSocket reference.
# Saved before AI processing starts, so the ws_audio finally-block
# clearing its own global cannot cut off TTS output.
tts_websocket = None
|
||
|
||
def set_tts_websocket(ws):
    """Remember *ws* as the WebSocket dedicated to TTS audio output."""
    global tts_websocket
    tts_websocket = ws
|
||
|
||
def get_tts_websocket():
    """Return the WebSocket used for TTS output, or None.

    Prefers the reference saved via set_tts_websocket(); falls back to the
    already-loaded app_main module's esp32_audio_ws global, if present.
    """
    global tts_websocket
    if tts_websocket is not None:
        return tts_websocket
    # Day 15 fix: do NOT `import app_main` here — importing would re-execute
    # its module top-level code. Look up the already-loaded module via
    # sys.modules instead.
    try:
        import sys
        if 'app_main' in sys.modules:
            # Fix: narrow except + getattr default — a missing esp32_audio_ws
            # attribute now cleanly yields None instead of relying on a bare
            # `except:` to swallow the AttributeError.
            return getattr(sys.modules['app_main'], 'esp32_audio_ws', None)
    except Exception:
        pass
    return None
|
||
|
||
|
||
# ===== Master switch for the AI playback task =====
current_ai_task: Optional[asyncio.Task] = None


async def cancel_current_ai():
    """Cancel the current LLM speech task, if any, and await its exit."""
    global current_ai_task
    task, current_ai_task = current_ai_task, None
    if task is None or task.done():
        return
    task.cancel()
    try:
        await task
    except (asyncio.CancelledError, Exception):
        # Swallow the cancellation and anything raised while unwinding;
        # callers only care that the task is gone.
        pass
|
||
|
||
def is_playing_now() -> bool:
    """True while an AI playback task exists and has not finished."""
    task = current_ai_task
    if task is None:
        return False
    return not task.done()
|
||
|
||
# ===== /stream.wav connection management =====
@dataclass(frozen=True)
class StreamClient:
    # Per-connection queue of outgoing PCM chunks (bounded by STREAM_QUEUE_MAX).
    q: asyncio.Queue
    # Set to tell this connection's streaming generator to stop and disconnect.
    abort_event: asyncio.Event


# All currently connected /stream.wav clients.
stream_clients: "Set[StreamClient]" = set()
STREAM_QUEUE_MAX = 96  # small buffer — avoid building up a backlog
|
||
|
||
def _wav_header_unknown_size(sr=16000, ch=1, sw=2) -> bytes:
|
||
import struct
|
||
byte_rate = sr * ch * sw
|
||
block_align = ch * sw
|
||
data_size = 0x7FFFFFF0
|
||
riff_size = 36 + data_size
|
||
return struct.pack(
|
||
"<4sI4s4sIHHIIHH4sI",
|
||
b"RIFF", riff_size, b"WAVE",
|
||
b"fmt ", 16,
|
||
1, ch, sr, byte_rate, block_align, sw * 8,
|
||
b"data", data_size
|
||
)
|
||
|
||
async def hard_reset_audio(reason: str = ""):
    """
    Full teardown: cancel the current AI playback task.

    NOTE: the HTTP /stream.wav connections are deliberately NOT dropped —
    the Avaota F1 plays TTS over that channel and needs the long-lived
    connection; disconnecting would make it miss subsequent TTS audio (Day 14).
    """
    # Cancel whatever AI task is currently speaking.
    await cancel_current_ai()

    # Day 28: force-reset the VAD TTS state, otherwise a cancelled task can
    # leave its counters non-zero and freeze the VAD.
    try:
        # Look up the already-loaded module to avoid a circular import.
        import sys
        vad_module = sys.modules.get('server_vad')
        vad_getter = getattr(vad_module, 'get_server_vad', None) if vad_module else None
        vad = vad_getter() if vad_getter else None
        if vad:
            vad.reset_tts_state()
    except Exception as e:
        print(f"[HARD-RESET] 重置 VAD 状态失败: {e}")

    # Log the trigger, if one was given.
    if reason:
        print(f"[HARD-RESET] {reason}")
|
||
|
||
async def flush_tts_buffer(ws) -> int:
    """
    Day 13: flush the TTS cache — send all buffered audio over *ws*.

    Returns the number of bytes actually sent.

    Fix: chunks that could not be sent (disconnect detected or a send error
    mid-flush) are now put back at the FRONT of the buffer instead of being
    silently discarded, so a later reconnect can still deliver them — that is
    the whole point of the reconnect buffer.
    """
    global tts_audio_buffer, tts_buffer_total_bytes
    from starlette.websockets import WebSocketState

    if not tts_audio_buffer:
        return 0

    total_sent = 0
    sent_count = 0  # number of queued items delivered successfully
    items_to_send = list(tts_audio_buffer)
    tts_audio_buffer.clear()
    tts_buffer_total_bytes = 0

    try:
        for _, audio_data in items_to_send:
            # Stop early if the socket is visibly no longer connected.
            if hasattr(ws, 'client_state') and ws.client_state != WebSocketState.CONNECTED:
                print(f"[TTS->WS] ⚠️ WebSocket disconnected while flushing buffer")
                break
            await ws.send_bytes(audio_data)
            total_sent += len(audio_data)
            sent_count += 1

        if total_sent > 0:
            duration = total_sent / (16000 * 2)  # buffer holds 16 kHz mono PCM16
            print(f"[TTS->WS] 📤 Flushed {total_sent} bytes ({duration:.1f}s) of cached TTS audio")
    except Exception as e:
        print(f"[TTS->WS] ❌ Error flushing buffer: {e}")
    finally:
        # Re-queue the unsent remainder in original order, ahead of anything
        # that may have been buffered while we were flushing.
        unsent = items_to_send[sent_count:]
        if unsent:
            tts_audio_buffer.extendleft(reversed(unsent))
            tts_buffer_total_bytes += sum(len(a) for _, a in unsent)

    return total_sent
|
||
|
||
async def broadcast_pcm16_realtime(pcm16: bytes):
    """
    Fan one 16 kHz PCM16 TTS chunk out to all audio sinks.

    Day 14 optimization: the WebSocket send happens immediately; the HTTP
    20 ms-paced broadcast runs as a background task so it cannot stall the
    WebSocket TTS path.

    Side effects: records the chunk via sync_recorder, and may append to or
    drain the module-level TTS reconnect buffer.
    """
    # Record the audio before fan-out (whole chunk, avoids fragmented recordings).
    try:
        import sync_recorder
        sync_recorder.record_audio(pcm16, text="[Omni对话]")
    except Exception:
        pass  # best-effort: recording must never break playback

    # Day 13: also send to the WebSocket client (Avaota F1).
    global tts_audio_buffer, tts_buffer_total_bytes
    import time as _time

    try:
        # Fix: removed the unused `import audioop` that used to sit here.
        # audioop is gone in Python 3.13+, so the import raised
        # ModuleNotFoundError, which the enclosing `except Exception` swallowed
        # — silently disabling the entire WebSocket send path.

        # Day 13: resolve the socket via get_tts_websocket() — prefer the saved
        # reference so ws_audio's finally-block clearing its global can't cut us off.
        ws = get_tts_websocket()

        # Day 21: the input is already 16 kHz (app_main converts 24k -> 16k),
        # so no resampling is needed here.
        pcm16k = pcm16

        sent_ok = False
        if ws is not None:
            try:
                # Day 13 fix: do not gate on client_state — it can be stale and
                # would wrongly divert audio into the buffer. Just try to send.

                # First drain anything buffered while disconnected.
                while tts_audio_buffer:
                    _, buffered_audio = tts_audio_buffer[0]
                    await ws.send_bytes(buffered_audio)
                    # Fix: pop only AFTER a successful send — previously the
                    # chunk was popped first and lost if send_bytes raised.
                    tts_audio_buffer.popleft()
                    tts_buffer_total_bytes -= len(buffered_audio)
                    if not getattr(broadcast_pcm16_realtime, '_flush_logged', False):
                        print(f"[TTS->WS] 📤 Flushing buffered TTS audio...")
                        broadcast_pcm16_realtime._flush_logged = True

                # Send the current chunk.
                await ws.send_bytes(pcm16k)
                sent_ok = True

                if len(pcm16k) > 320:
                    print(f"[TTS->WS] 📤 Sent {len(pcm16k)} bytes (16kHz) to Avaota")

                # Sending works again — reset the one-shot warning flags.
                broadcast_pcm16_realtime._ws_warned = False
                broadcast_pcm16_realtime._buffer_warned = False
                broadcast_pcm16_realtime._flush_logged = False
            except Exception as send_err:
                # Send failed; the current chunk is buffered below.
                if not getattr(broadcast_pcm16_realtime, '_send_err_warned', False):
                    print(f"[TTS->WS] ❌ Send error: {send_err}, will buffer")
                    broadcast_pcm16_realtime._send_err_warned = True

        # No WebSocket, or the send failed: cache the chunk for a later flush.
        if not sent_ok:
            tts_audio_buffer.append((_time.time(), pcm16k))
            tts_buffer_total_bytes += len(pcm16k)

            # Enforce the cap by discarding the oldest chunks first.
            while tts_buffer_total_bytes > TTS_BUFFER_MAX_BYTES and tts_audio_buffer:
                _, old_audio = tts_audio_buffer.popleft()
                tts_buffer_total_bytes -= len(old_audio)

            if not getattr(broadcast_pcm16_realtime, '_buffer_warned', False):
                buffer_secs = tts_buffer_total_bytes / (16000 * 2)
                print(f"[TTS->WS] 📦 Buffering TTS audio ({buffer_secs:.1f}s cached), will send when reconnected")
                broadcast_pcm16_realtime._buffer_warned = True

    except Exception:
        pass  # best-effort: TTS fan-out must never raise into the caller

    # Day 14: run the HTTP 20 ms paced broadcast in the background so the next
    # Omni chunk can be processed without waiting for the pacing loop.
    if stream_clients:
        task = asyncio.create_task(_http_pacing_broadcast(pcm16))
        # Fix: keep a strong reference — asyncio holds only weak refs to tasks,
        # so an un-referenced background task can be garbage-collected mid-flight.
        bg = getattr(broadcast_pcm16_realtime, '_bg_tasks', None)
        if bg is None:
            bg = set()
            broadcast_pcm16_realtime._bg_tasks = bg
        bg.add(task)
        task.add_done_callback(bg.discard)
|
||
|
||
|
||
async def _http_pacing_broadcast(pcm16: bytes):
|
||
"""
|
||
Day 14: HTTP 客户端的 20ms 节拍广播(独立后台任务)
|
||
原来嵌入在 broadcast_pcm16_realtime 中,会阻塞 WebSocket 发送
|
||
"""
|
||
loop = asyncio.get_event_loop()
|
||
next_tick = loop.time()
|
||
off = 0
|
||
while off < len(pcm16):
|
||
take = min(BYTES_PER_20MS_16K, len(pcm16) - off)
|
||
piece = pcm16[off:off + take]
|
||
|
||
dead: List[StreamClient] = []
|
||
# Day 14 调试:确认 stream_clients 状态
|
||
if len(stream_clients) > 0 and off == 0:
|
||
print(f"[TTS->HTTP] 📤 Sending to {len(stream_clients)} HTTP stream client(s)")
|
||
for sc in list(stream_clients):
|
||
if sc.abort_event.is_set():
|
||
dead.append(sc)
|
||
continue
|
||
try:
|
||
if sc.q.full():
|
||
try: sc.q.get_nowait()
|
||
except Exception: pass
|
||
sc.q.put_nowait(piece)
|
||
except Exception:
|
||
dead.append(sc)
|
||
for sc in dead:
|
||
try: stream_clients.discard(sc)
|
||
except Exception: pass
|
||
|
||
next_tick += 0.020
|
||
now = loop.time()
|
||
if now < next_tick:
|
||
await asyncio.sleep(next_tick - now)
|
||
else:
|
||
next_tick = now
|
||
off += take
|
||
|
||
# ===== FastAPI route registrar =====
def register_stream_route(app):
    """Register GET /stream.wav on *app*: an endless paced WAV stream."""
    @app.get("/stream.wav")
    async def stream_wav(_: Request):
        # -- Enforce a single (or very few) connection(s): abort every
        #    previously registered client before accepting this one. --
        for sc in list(stream_clients):
            try: sc.abort_event.set()
            except Exception: pass
        stream_clients.clear()

        # Fresh bounded queue + abort flag for this connection.
        q: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=STREAM_QUEUE_MAX)
        abort_event = asyncio.Event()
        sc = StreamClient(q=q, abort_event=abort_event)
        stream_clients.add(sc)

        async def gen():
            # WAV header first (near-maximal data size: stream of unknown length).
            yield _wav_header_unknown_size(STREAM_SR, STREAM_CH, STREAM_SW)
            try:
                while True:
                    if abort_event.is_set():
                        break
                    try:
                        # Poll with a timeout so an abort is noticed promptly
                        # even when no audio is flowing.
                        chunk = await asyncio.wait_for(q.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue
                    if abort_event.is_set():
                        break
                    if chunk is None:
                        # None is the end-of-stream sentinel.
                        break
                    if chunk:
                        yield chunk
            finally:
                # Always deregister this client when the generator closes.
                stream_clients.discard(sc)

        return StreamingResponse(gen(), media_type="audio/wav")