Files
NaviGlassServer/audio_stream.py
2025-12-31 15:42:30 +08:00

299 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# audio_stream.py
# -*- coding: utf-8 -*-
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Optional, Set, List, Tuple, Any, Dict
from fastapi import Request
from fastapi.responses import StreamingResponse
# ===== Downlink WAV stream base parameters =====
STREAM_SR = 8000 # Sample rate in Hz; changed to 8 kHz, which the ESP32 supports
# Channel count written into the /stream.wav header (mono)
STREAM_CH = 1
# Sample width in bytes (2 -> 16-bit PCM)
STREAM_SW = 2
# NOTE(review): despite the `_16K` suffix this is derived from STREAM_SR (8000 Hz),
# and it omits STREAM_CH, which only works because the stream is mono.
BYTES_PER_20MS_16K = STREAM_SR * STREAM_SW * 20 // 1000 # 320 B per 20 ms at 8 kHz
# ===== Day 13: TTS buffer queue =====
# While the WebSocket is disconnected, TTS audio is cached here and sent after reconnecting
TTS_BUFFER_MAX_SECONDS = 30 # cache at most 30 seconds of audio
TTS_BUFFER_MAX_BYTES = 16000 * 2 * TTS_BUFFER_MAX_SECONDS # 16 kHz * 2 bytes * 30 s = ~960 KB
tts_audio_buffer: deque = deque() # each element is (timestamp, pcm16k_bytes)
# Running total of the payload bytes currently held in tts_audio_buffer
tts_buffer_total_bytes = 0
# Day 13: WebSocket reference dedicated to TTS
# Saved before AI processing starts so the `finally` block of ws_audio cannot clear it
tts_websocket = None
def set_tts_websocket(ws):
    """Remember *ws* as the WebSocket dedicated to outgoing TTS audio.

    The value is stored in the module-level ``tts_websocket`` variable;
    passing ``None`` clears the saved reference.
    """
    global tts_websocket
    tts_websocket = ws
def get_tts_websocket():
    """Return the WebSocket that TTS audio should be sent to, or ``None``.

    The reference saved in the module-level ``tts_websocket`` variable takes
    priority. When none is saved, fall back to the ``esp32_audio_ws``
    attribute of the ``app_main`` module, but only if that module has already
    been loaded.

    Returns:
        The WebSocket object, or ``None`` when no reference is available.
    """
    if tts_websocket is not None:
        return tts_websocket

    # Day 15 fix: never `import app_main` here, because that would re-run its
    # top-level code. Look the already-loaded module up in sys.modules instead.
    import sys

    app_main = sys.modules.get("app_main")
    # getattr with a default covers both "module not loaded" (None) and
    # "attribute missing" without a bare `except:` that would also swallow
    # KeyboardInterrupt / SystemExit.
    return getattr(app_main, "esp32_audio_ws", None)
# ===== Master switch for the AI playback task =====
# The asyncio task currently registered as the AI speech task, or None when there is none
current_ai_task: Optional[asyncio.Task] = None
async def cancel_current_ai():
    """Cancel the currently registered AI speech task and wait for it to exit.

    ``current_ai_task`` is reset to ``None`` before anything is awaited. If
    the task is missing or already finished, nothing else happens.
    """
    global current_ai_task
    pending, current_ai_task = current_ai_task, None
    if not pending or pending.done():
        return

    pending.cancel()
    try:
        await pending
    except (asyncio.CancelledError, Exception):
        # Whatever the task ends with (cancellation or its own error) is
        # intentionally ignored; we only need it to have stopped.
        pass
def is_playing_now() -> bool:
    """Return True while an AI speech task is registered and not yet finished."""
    task = current_ai_task
    if task is None:
        return False
    return not task.done()
# ===== /stream.wav connection management =====
@dataclass(frozen=True)
class StreamClient:
    """One connected /stream.wav listener.

    Frozen, so instances are hashable and can be kept in a set.
    """

    # Audio pieces waiting to be streamed to this client.
    q: asyncio.Queue
    # Set to tell this client's streaming generator to stop.
    abort_event: asyncio.Event
# Registry of the currently connected /stream.wav clients
stream_clients: "Set[StreamClient]" = set()
STREAM_QUEUE_MAX = 96 # small per-client buffer, to avoid a backlog building up
def _wav_header_unknown_size(sr=16000, ch=1, sw=2) -> bytes:
import struct
byte_rate = sr * ch * sw
block_align = ch * sw
data_size = 0x7FFFFFF0
riff_size = 36 + data_size
return struct.pack(
"<4sI4s4sIHHIIHH4sI",
b"RIFF", riff_size, b"WAVE",
b"fmt ", 16,
1, ch, sr, byte_rate, block_align, sw * 8,
b"data", data_size
)
async def hard_reset_audio(reason: str = ""):
    """One-shot reset: cancel the current AI task.

    Note: HTTP /stream.wav connections are no longer dropped here, because
    the Avaota F1 uses that channel to play TTS audio.

    Args:
        reason: optional text; when non-empty it is printed after the reset.
    """
    # Day 14: only the AI task is cancelled. The Avaota F1 HTTP TTS client
    # needs its long-lived connection; closing it would make the client miss
    # all later TTS audio.
    await cancel_current_ai()

    if not reason:
        return
    print(f"[HARD-RESET] {reason}")
async def flush_tts_buffer(ws) -> int:
    """Day 13: send every cached TTS chunk to *ws* and report how much went out.

    The cache is snapshotted and emptied before sending. Chunks that could
    not be delivered (socket reported as disconnected, or a send raised) are
    put back at the front of ``tts_audio_buffer`` in their original order, so
    they are retried on the next flush instead of being dropped.

    Args:
        ws: object with an async ``send_bytes`` method. If it exposes
            ``client_state``, that is compared against starlette's
            ``WebSocketState.CONNECTED`` before each send.

    Returns:
        Number of payload bytes successfully sent.
    """
    global tts_audio_buffer, tts_buffer_total_bytes

    if not tts_audio_buffer:
        return 0

    # Only sockets that expose client_state need the starlette enum.
    connected_state = None
    if hasattr(ws, 'client_state'):
        from starlette.websockets import WebSocketState
        connected_state = WebSocketState.CONNECTED

    pending = list(tts_audio_buffer)
    tts_audio_buffer.clear()
    tts_buffer_total_bytes = 0

    total_sent = 0
    delivered = 0  # number of leading entries of `pending` that were sent
    try:
        for _, audio_data in pending:
            if connected_state is not None and ws.client_state != connected_state:
                print(f"[TTS->WS] ⚠️ WebSocket disconnected while flushing buffer")
                break
            await ws.send_bytes(audio_data)
            total_sent += len(audio_data)
            delivered += 1
        if total_sent > 0:
            duration = total_sent / (16000 * 2)
            print(f"[TTS->WS] 📤 Flushed {total_sent} bytes ({duration:.1f}s) of cached TTS audio")
    except Exception as e:
        print(f"[TTS->WS] ❌ Error flushing buffer: {e}")
    finally:
        # Re-queue what was not delivered, ahead of anything buffered while
        # this coroutine was awaiting, so playback order is preserved.
        unsent = pending[delivered:]
        if unsent:
            tts_audio_buffer.extendleft(reversed(unsent))
            tts_buffer_total_bytes += sum(len(chunk) for _, chunk in unsent)
    return total_sent
async def broadcast_pcm16_realtime(pcm16: bytes):
    """Fan one chunk of TTS PCM16 audio out to the WebSocket and HTTP listeners.

    Day 14 optimisation: the WebSocket send happens immediately, while the
    20 ms paced HTTP broadcast runs as a background task so that pacing never
    blocks WebSocket TTS delivery.

    Steps, in order:
      1. Pass the whole chunk to ``sync_recorder.record_audio`` (best effort;
         any failure is ignored so playback is unaffected).
      2. Send to the socket returned by ``get_tts_websocket()``: first any
         backlog in ``tts_audio_buffer``, then this chunk. If there is no
         socket or a send fails, this chunk is appended to the backlog, which
         is trimmed to ``TTS_BUFFER_MAX_BYTES`` by dropping the oldest data.
      3. If ``stream_clients`` is non-empty, schedule
         ``_http_pacing_broadcast`` for the same bytes.

    Args:
        pcm16: raw PCM16 audio. Per the Day 21 note the caller is expected to
            deliver 16 kHz audio already, so it is forwarded unchanged.
            NOTE(review): confirm against app_main.py.
    """
    global tts_audio_buffer, tts_buffer_total_bytes
    import time as _time

    # Record before distribution so the recorder sees the chunk unsplit.
    try:
        import sync_recorder
        sync_recorder.record_audio(pcm16, text="[Omni对话]")
    except Exception:
        pass  # fail silently; recording must never affect playback

    try:
        # Day 13: prefer the saved reference so the `finally` block of
        # ws_audio clearing its global does not cut TTS off.
        ws = get_tts_websocket()

        # Day 21: no resampling here (the old `import audioop` was unused and,
        # being removed in Python 3.13, would have aborted this whole section).
        pcm16k = pcm16
        sent_ok = False

        if ws is not None:
            try:
                # Day 13 fix: client_state is not consulted because it can be
                # inaccurate; simply try to send. Backlog goes out first.
                while tts_audio_buffer:
                    queued = tts_audio_buffer.popleft()
                    tts_buffer_total_bytes -= len(queued[1])
                    try:
                        await ws.send_bytes(queued[1])
                    except BaseException:
                        # Put the chunk back (also on cancellation) so a
                        # failed send cannot drop buffered audio.
                        tts_audio_buffer.appendleft(queued)
                        tts_buffer_total_bytes += len(queued[1])
                        raise
                    if not getattr(broadcast_pcm16_realtime, '_flush_logged', False):
                        print(f"[TTS->WS] 📤 Flushing buffered TTS audio...")
                        broadcast_pcm16_realtime._flush_logged = True

                # Now the current chunk.
                await ws.send_bytes(pcm16k)
                sent_ok = True
                if len(pcm16k) > 320:
                    print(f"[TTS->WS] 📤 Sent {len(pcm16k)} bytes (16kHz) to Avaota")

                # A successful send re-arms every one-shot log message,
                # including the send-error one.
                broadcast_pcm16_realtime._ws_warned = False
                broadcast_pcm16_realtime._buffer_warned = False
                broadcast_pcm16_realtime._flush_logged = False
                broadcast_pcm16_realtime._send_err_warned = False
            except Exception as send_err:
                # Send failed; the current chunk is buffered below.
                if not getattr(broadcast_pcm16_realtime, '_send_err_warned', False):
                    print(f"[TTS->WS] ❌ Send error: {send_err}, will buffer")
                    broadcast_pcm16_realtime._send_err_warned = True

        if not sent_ok:
            # No socket, or the send failed: keep the chunk for later.
            tts_audio_buffer.append((_time.time(), pcm16k))
            tts_buffer_total_bytes += len(pcm16k)

            # Enforce the size cap by discarding the oldest chunks.
            while tts_buffer_total_bytes > TTS_BUFFER_MAX_BYTES and tts_audio_buffer:
                _, old_audio = tts_audio_buffer.popleft()
                tts_buffer_total_bytes -= len(old_audio)

            if not getattr(broadcast_pcm16_realtime, '_buffer_warned', False):
                buffer_secs = tts_buffer_total_bytes / (16000 * 2)
                print(f"[TTS->WS] 📦 Buffering TTS audio ({buffer_secs:.1f}s cached), will send when reconnected")
                broadcast_pcm16_realtime._buffer_warned = True
    except Exception:
        pass  # deliberately best-effort: WebSocket trouble must not stop HTTP delivery

    # Day 14: pace HTTP clients in the background so the next Omni audio
    # chunk can be handled right away.
    if stream_clients:
        pacing_task = asyncio.create_task(_http_pacing_broadcast(pcm16))
        # The event loop keeps only weak references to tasks; hold a strong
        # one until completion so the task cannot be garbage-collected mid-run.
        live_tasks = getattr(broadcast_pcm16_realtime, '_pacing_tasks', None)
        if live_tasks is None:
            live_tasks = broadcast_pcm16_realtime._pacing_tasks = set()
        live_tasks.add(pacing_task)
        pacing_task.add_done_callback(live_tasks.discard)
async def _http_pacing_broadcast(pcm16: bytes):
"""
Day 14: HTTP 客户端的 20ms 节拍广播(独立后台任务)
原来嵌入在 broadcast_pcm16_realtime 中,会阻塞 WebSocket 发送
"""
loop = asyncio.get_event_loop()
next_tick = loop.time()
off = 0
while off < len(pcm16):
take = min(BYTES_PER_20MS_16K, len(pcm16) - off)
piece = pcm16[off:off + take]
dead: List[StreamClient] = []
# Day 14 调试:确认 stream_clients 状态
if len(stream_clients) > 0 and off == 0:
print(f"[TTS->HTTP] 📤 Sending to {len(stream_clients)} HTTP stream client(s)")
for sc in list(stream_clients):
if sc.abort_event.is_set():
dead.append(sc)
continue
try:
if sc.q.full():
try: sc.q.get_nowait()
except Exception: pass
sc.q.put_nowait(piece)
except Exception:
dead.append(sc)
for sc in dead:
try: stream_clients.discard(sc)
except Exception: pass
next_tick += 0.020
now = loop.time()
if now < next_tick:
await asyncio.sleep(next_tick - now)
else:
next_tick = now
off += take
# ===== FastAPI route registrar =====
def register_stream_route(app):
    """Register the ``GET /stream.wav`` endpoint on *app*.

    Each new request aborts and forgets every previously registered client,
    then registers itself in ``stream_clients`` and streams a WAV header
    followed by whatever audio pieces are put on its queue.

    Args:
        app: the application object; only its ``get`` route decorator is used.
    """

    @app.get("/stream.wav")
    async def stream_wav(_: Request):
        # Enforce a single connection: shut every older stream down first.
        for old in list(stream_clients):
            try:
                old.abort_event.set()
            except Exception:
                pass
        stream_clients.clear()

        q: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=STREAM_QUEUE_MAX)
        abort_event = asyncio.Event()
        sc = StreamClient(q=q, abort_event=abort_event)
        stream_clients.add(sc)

        async def gen():
            try:
                # The header is yielded inside the try so that a generator
                # closed at this very first yield still unregisters its client
                # instead of leaving a stale entry in stream_clients.
                yield _wav_header_unknown_size(STREAM_SR, STREAM_CH, STREAM_SW)
                while True:
                    if abort_event.is_set():
                        break
                    try:
                        # Short timeout so abort_event is re-checked regularly
                        # even when no audio arrives.
                        chunk = await asyncio.wait_for(q.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue
                    if abort_event.is_set():
                        break
                    if chunk is None:
                        # None is the end-of-stream sentinel.
                        break
                    if chunk:
                        yield chunk
            finally:
                stream_clients.discard(sc)

        return StreamingResponse(gen(), media_type="audio/wav")