Files
NaviGlassServer/server_vad.py
2026-01-06 17:15:06 +08:00

345 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Silero VAD 服务器端语音活动检测
参考 xiaozhi-esp32-server 实现
"""
import torch
import numpy as np
import os
import collections # Day 23: For VAD lookback buffer
import time
# Module-level cache for the Silero VAD model, filled in lazily by get_vad_model().
# `_model_loaded` becomes True after the first load attempt (success or failure),
# so get_vad_model() never attempts the load a second time.
_vad_model, _model_loaded = None, False
def get_vad_model():
    """Return the shared Silero VAD model, loading it on the first call.

    Lookup order:
      1. A vendored checkout at ``<this file's dir>/model/snakers4_silero-vad``.
      2. An existing torch.hub cache checkout ``<torch hub dir>/snakers4_silero-vad_master``,
         loaded with ``source="local"`` so no GitHub request is made.
      3. A ``torch.hub`` download of ``snakers4/silero-vad``.

    The outcome is memoised in the module globals ``_vad_model`` / ``_model_loaded``.
    A failed load is memoised too, so later calls return ``None`` without retrying.

    Returns:
        The loaded model object, or ``None`` if loading raised an exception.
    """
    global _vad_model, _model_loaded
    if _model_loaded:
        return _vad_model

    try:
        local_dir = os.path.join(os.path.dirname(__file__), "model", "snakers4_silero-vad")
        # torch.hub.get_dir() honours TORCH_HOME / XDG_CACHE_HOME and torch.hub.set_dir();
        # a hard-coded ~/.cache/torch/hub path misses the cache whenever those are set
        # and silently falls through to the network download.
        cache_dir = os.path.join(torch.hub.get_dir(), "snakers4_silero-vad_master")

        if os.path.exists(local_dir):
            print(f"[VAD] 从本地加载 Silero VAD: {local_dir}")
            _vad_model, _ = torch.hub.load(
                repo_or_dir=local_dir,
                source="local",
                model="silero_vad",
                force_reload=False,
            )
        elif os.path.exists(cache_dir):
            # Prefer the cached checkout so startup does not hit GitHub for updates.
            print(f"[VAD] 使用 torch hub 缓存: {cache_dir}")
            _vad_model, _ = torch.hub.load(
                repo_or_dir=cache_dir,
                source="local",
                model="silero_vad",
                force_reload=False,
            )
        else:
            # No local copy and no cache: download via torch.hub.
            print("[VAD] 从 torch.hub 下载 Silero VAD...")
            _vad_model, _ = torch.hub.load(
                repo_or_dir='snakers4/silero-vad',
                model='silero_vad',
                force_reload=False,
            )
        _model_loaded = True
        print("[VAD] Silero VAD 模型加载成功")
        return _vad_model
    except Exception as e:
        # Deliberate best-effort: callers fall back to energy VAD when this returns None.
        print(f"[VAD] Silero VAD 加载失败: {e}")
        _model_loaded = True  # remember the failure so we do not retry on every call
        return None
class SileroVAD:
    """
    Server-side voice activity detection built on Silero VAD.

    Feed 16-bit PCM incrementally through :meth:`process`. The detector reports
    when an utterance starts and ends. When an utterance ends, it returns the
    audio collected for it, including up to ``pre_speech_buffer.maxlen`` frames
    of lookback captured before the start was detected.

    Detection is suspended while TTS playback is active (reference-counted via
    :meth:`set_tts_playing`) and for ``tts_cooldown_ms`` after the last playback
    ends. If the Silero model could not be loaded, a simple RMS-energy detector
    is used instead, subject to the same TTS gating.
    """

    # RMS level (int16 units) above which the energy fallback treats a packet as voice.
    FALLBACK_RMS_THRESHOLD = 500

    def __init__(self,
                 threshold: float = 0.5,       # Day 23: lowered again (was 0.7)
                 threshold_low: float = 0.3,   # Day 23: lowered again (was 0.4)
                 min_silence_ms: int = 800,    # Day 23: longer silence window (was 600)
                 min_speech_ms: int = 300,     # Day 23: lower minimum speech (was 500)
                 sample_rate: int = 16000):
        """
        Initialise the detector and load (or reuse) the shared Silero model.

        Args:
            threshold: speech probability at or above which a frame counts as voice.
            threshold_low: probability at or below which a frame counts as silence;
                frames in between keep the current speaking state (hysteresis).
            min_silence_ms: silence (ms) after the last voiced frame that ends
                an utterance.
            min_speech_ms: minimum utterance length (ms). Shorter utterances are
                discarded instead of being reported.
            sample_rate: PCM sample rate in Hz, passed through to the model.
        """
        self.model = get_vad_model()  # None if loading failed -> energy fallback
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.sample_rate = sample_rate

        # Streaming state
        self.audio_buffer = bytearray()   # bytes not yet forming a full model frame
        self.is_speaking = False
        self.last_speech_time = 0         # wall-clock ms of the last voiced frame
        self.speech_start_time = 0        # wall-clock ms when the utterance started
        self.speech_audio = bytearray()   # audio of the current utterance

        # TTS playback state: VAD is paused while playback is active.
        # Day 28: reference-counted so overlapping playbacks are handled.
        self.tts_playing_count = 0
        self.tts_end_time = 0             # wall-clock ms when the last playback ended
        self.tts_cooldown_ms = 500        # keep ignoring audio this long after TTS ends

        # Sliding vote window over per-frame voice decisions.
        self.voice_window = []
        self.window_size = 5              # number of recent frames considered
        self.frame_threshold = 3          # voiced frames in the window needed to count as speech

        # Day 23+28: pre-speech (lookback) buffer so the start of a word is not cut off.
        # Day 28: 24 frames (768 ms) to capture longer onsets such as "室内导航"
        # and keep the ASR from dropping the first syllables.
        self.pre_speech_buffer = collections.deque(maxlen=24)

        print(f"[VAD] 初始化: threshold={threshold}, threshold_low={threshold_low}, "
              f"min_silence_ms={min_silence_ms}, min_speech_ms={min_speech_ms}")

    def reset(self):
        """Return the detector to its idle initial state, including TTS gating and model state."""
        self._clear_stream_state()
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.tts_playing_count = 0
        self.tts_end_time = 0

    def reset_tts_state(self):
        """Force the TTS reference count to zero (hard reset), resuming detection."""
        self.tts_playing_count = 0
        print("[VAD] 强制重置 TTS 状态 (VAD 恢复)")

    def set_tts_playing(self, playing: bool):
        """
        Register the start (``True``) or end (``False``) of one TTS playback.

        Calls are reference-counted. Detection stays paused until every started
        playback has ended, and the cooldown is measured from the final end.
        Unmatched ``False`` calls are ignored. When the first playback starts,
        any utterance in progress is discarded and all buffered audio history
        is dropped.
        """
        if not playing:
            if self.tts_playing_count <= 0:
                return  # already idle; ignore the unmatched stop
            self.tts_playing_count -= 1
            if self.tts_playing_count == 0:
                self.tts_end_time = time.time() * 1000
                print("[VAD] TTS 完全结束,等待冷却期...")
            return

        self.tts_playing_count += 1
        if self.tts_playing_count > 1:
            return  # already paused by an earlier playback

        print("[VAD] TTS 开始播放,暂停 VAD 检测")
        interrupted = self.is_speaking
        # Audio captured before playback is stale once it ends. Leftover partial
        # frames, lookback frames and window votes must not leak into the first
        # utterance after TTS, so clear them even when no recording was active.
        self._clear_stream_state()
        if interrupted:
            print("[VAD] TTS 播放打断语音录制")

    def process(self, audio_bytes: bytes) -> dict:
        """
        Feed a packet of PCM 16-bit audio and report speech boundaries.

        Args:
            audio_bytes: PCM 16-bit samples (assumed mono; confirm with the caller).
                Packets of any length are accepted. Incomplete model frames are
                carried over to the next call.

        Returns:
            dict with keys:
              'speech_started': bool   - an utterance started during this call
              'speech_ended':   bool   - an utterance ended during this call
              'is_speaking':    bool   - speaking state after this call
              'speech_audio':   bytes | None - the full utterance when it ended
        """
        result = {
            'speech_started': False,
            'speech_ended': False,
            'is_speaking': self.is_speaking,
            'speech_audio': None,
        }

        # TTS gating applies to both the model path and the energy fallback.
        if self._tts_gate_active():
            return result

        if self.model is None:
            return self._fallback_energy_vad(audio_bytes, result)

        self.audio_buffer.extend(audio_bytes)
        frame_bytes = self._frame_bytes()
        while len(self.audio_buffer) >= frame_bytes:
            chunk = bytes(self.audio_buffer[:frame_bytes])
            del self.audio_buffer[:frame_bytes]  # in place: no re-copy of the tail per frame

            speech_prob = self._speech_probability(chunk)
            # Day 23: debug logging to diagnose low-volume / microphone issues
            if speech_prob > 0.3:
                print(f"[VAD DEBUG] Prob: {speech_prob:.3f}")

            # Dual-threshold hysteresis: the in-between band keeps the current state.
            if speech_prob >= self.threshold:
                is_voice = True
            elif speech_prob <= self.threshold_low:
                is_voice = False
            else:
                is_voice = self.is_speaking

            self.voice_window.append(is_voice)
            if len(self.voice_window) > self.window_size:
                self.voice_window.pop(0)
            has_voice = self.voice_window.count(True) >= self.frame_threshold

            now = time.time() * 1000
            if has_voice:
                if self.is_speaking:
                    self.speech_audio.extend(chunk)
                else:
                    self._start_speech(now, chunk, result)
                self.last_speech_time = now
            elif self.is_speaking:
                # Possibly just a short pause: keep collecting until silence is long enough.
                self.speech_audio.extend(chunk)
                self._end_speech_if_silent(now, result)

            # Record the frame for lookback only after it has been consumed above,
            # so a speech start never sees (and duplicates) its own trigger frame.
            self.pre_speech_buffer.append(chunk)

        result['is_speaking'] = self.is_speaking
        return result

    def _frame_bytes(self) -> int:
        """Bytes per model frame: 256 samples at 8 kHz, otherwise 512 (32 ms at 16 kHz)."""
        samples = 256 if self.sample_rate == 8000 else 512
        return samples * 2  # 16-bit samples

    def _tts_gate_active(self) -> bool:
        """Return True while TTS is playing or within ``tts_cooldown_ms`` of its end."""
        if self.tts_playing_count > 0:
            return True
        elapsed = time.time() * 1000 - self.tts_end_time
        return self.tts_end_time > 0 and elapsed < self.tts_cooldown_ms

    def _speech_probability(self, chunk: bytes) -> float:
        """Run the Silero model on one int16 frame and return its speech probability."""
        samples = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
        with torch.no_grad():
            return self.model(torch.from_numpy(samples), self.sample_rate).item()

    def _start_speech(self, now: float, chunk: bytes, result: dict) -> None:
        """Enter the speaking state and seed the utterance with lookback plus the trigger frame."""
        self.is_speaking = True
        self.speech_start_time = now
        self.speech_audio.clear()
        result['speech_started'] = True
        print("[VAD] 🎤 Speech started")
        # Day 23: prepend lookback so the beginning of the utterance is not lost.
        if self.pre_speech_buffer:
            for prev_chunk in self.pre_speech_buffer:
                self.speech_audio.extend(prev_chunk)
            print(f"[VAD] Recovered {len(self.pre_speech_buffer)} chunks ({len(self.pre_speech_buffer)*32}ms) from history")
        self.speech_audio.extend(chunk)

    def _end_speech_if_silent(self, now: float, result: dict) -> None:
        """
        Close the current utterance once silence has lasted ``min_silence_ms``.

        The utterance duration runs from start to ``now``, trailing silence
        included. Utterances shorter than ``min_speech_ms`` are dropped.
        Otherwise ``result`` receives ``speech_ended=True`` and the audio.
        """
        if now - self.last_speech_time < self.min_silence_ms:
            return
        self.is_speaking = False
        speech_duration = now - self.speech_start_time
        if speech_duration >= self.min_speech_ms:
            result['speech_ended'] = True
            result['speech_audio'] = bytes(self.speech_audio)
            print(f"[VAD] 🔇 Speech ended, duration={speech_duration:.0f}ms, "
                  f"audio_size={len(self.speech_audio)} bytes")
        else:
            print(f"[VAD] 语音太短 ({speech_duration:.0f}ms), 忽略")
        self.speech_audio.clear()

    def _clear_stream_state(self) -> None:
        """Drop any in-progress utterance plus all buffered audio, votes and model state."""
        self.is_speaking = False
        self.speech_audio.clear()
        self.audio_buffer.clear()
        self.voice_window.clear()
        self.pre_speech_buffer.clear()
        if self.model is not None:
            self.model.reset_states()

    def _fallback_energy_vad(self, audio_bytes: bytes, result: dict) -> dict:
        """
        Crude RMS-energy VAD used when the Silero model is unavailable.

        Each packet is classified as a whole; there is no framing, vote window
        or lookback. A trailing odd byte is ignored for the energy estimate,
        and an empty packet counts as silence.
        """
        usable = len(audio_bytes) - len(audio_bytes) % 2
        samples = np.frombuffer(audio_bytes[:usable], dtype=np.int16).astype(np.float32)
        rms = float(np.sqrt(np.mean(samples ** 2))) if samples.size else 0.0
        is_voice = rms > self.FALLBACK_RMS_THRESHOLD

        now = time.time() * 1000
        if is_voice:
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = now
                self.speech_audio.clear()
                result['speech_started'] = True
            self.last_speech_time = now
            self.speech_audio.extend(audio_bytes)
        elif self.is_speaking:
            self.speech_audio.extend(audio_bytes)
            self._end_speech_if_silent(now, result)

        result['is_speaking'] = self.is_speaking
        return result
# Process-wide SileroVAD singleton; get_server_vad() creates it on first use.
_global_vad: "SileroVAD | None" = None
def get_server_vad() -> SileroVAD:
    """Return the process-wide SileroVAD, constructing it on first use.

    The instance is built with the default constructor arguments and cached in
    the module-level ``_global_vad``; subsequent calls return that same object.
    """
    global _global_vad
    if _global_vad is not None:
        return _global_vad
    _global_vad = SileroVAD()
    return _global_vad
def reset_server_vad():
    """Reset the shared server VAD's state; a no-op if it has not been created yet."""
    vad = _global_vad
    if vad is not None:
        vad.reset()