Files
NaviGlassServer/server_vad.py
2025-12-31 15:42:30 +08:00

327 lines
12 KiB
Python

"""
Silero VAD 服务器端语音活动检测
参考 xiaozhi-esp32-server 实现
"""
import torch
import numpy as np
import os
import collections # Day 23: For VAD lookback buffer
import time
# Module-level cache for the lazily loaded Silero VAD model.
# _vad_model holds the loaded model; it stays None until get_vad_model()
# loads it, and remains None if loading fails.
_vad_model = None
# Set to True once a load has been attempted (successfully or not), so that
# get_vad_model() never retries.
_model_loaded = False
def get_vad_model():
    """Return the shared Silero VAD model, loading it on the first call.

    Lookup order:
      1. A copy bundled next to this file under ``model/snakers4_silero-vad``.
      2. An existing torch hub checkout (``snakers4_silero-vad_master``),
         searched first in ``torch.hub.get_dir()`` and then in the legacy
         ``~/.cache/torch/hub`` location. Loading it with ``source="local"``
         avoids contacting GitHub on every start.
      3. A download through ``torch.hub.load``.

    The outcome is cached in module globals. A failed load is cached as well,
    so later calls return ``None`` immediately instead of retrying.

    Returns:
        The loaded model, or ``None`` if loading failed.
    """
    global _vad_model, _model_loaded
    if _model_loaded:
        return _vad_model

    def _load_local(repo_dir):
        # torch.hub.load returns (model, utils) for silero_vad; only the
        # model is kept.
        model, _ = torch.hub.load(
            repo_or_dir=repo_dir,
            source="local",
            model="silero_vad",
            force_reload=False,
        )
        return model

    try:
        # Prefer the copy shipped alongside this module.
        model_dir = os.path.join(os.path.dirname(__file__), "model", "snakers4_silero-vad")
        if os.path.exists(model_dir):
            print(f"[VAD] 从本地加载 Silero VAD: {model_dir}")
            _vad_model = _load_local(model_dir)
        else:
            # Honour a relocated hub directory (TORCH_HOME / XDG_CACHE_HOME /
            # torch.hub.set_dir) before falling back to the historical
            # hard-coded path.
            cache_name = "snakers4_silero-vad_master"
            candidates = (
                os.path.join(torch.hub.get_dir(), cache_name),
                os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master"),
            )
            cache_dir = next((path for path in candidates if os.path.exists(path)), None)
            if cache_dir is not None:
                print(f"[VAD] 使用 torch hub 缓存: {cache_dir}")
                _vad_model = _load_local(cache_dir)
            else:
                # No cached checkout: fetch from the network.
                print("[VAD] 从 torch.hub 下载 Silero VAD...")
                _vad_model, _ = torch.hub.load(
                    repo_or_dir='snakers4/silero-vad',
                    model='silero_vad',
                    force_reload=False,
                )
        _model_loaded = True
        print("[VAD] Silero VAD 模型加载成功")
        return _vad_model
    except Exception as e:
        # Best effort: callers fall back to energy detection when the model
        # is unavailable, so report the failure and remember it.
        print(f"[VAD] Silero VAD 加载失败: {e}")
        _vad_model = None
        _model_loaded = True  # do not retry on every call
        return None
class SileroVAD:
    """Server-side voice activity detector built on Silero VAD.

    Incoming 16-bit PCM audio is cut into 512-sample frames, each frame is
    scored by the model, and a small sliding window of per-frame decisions
    drives a speaking / not-speaking state machine. When no model is
    available a simple RMS-energy detector is used instead.
    """

    def __init__(self,
                 threshold: float = 0.5,      # Day 23: lowered again (was 0.7)
                 threshold_low: float = 0.3,  # Day 23: lowered again (was 0.4)
                 min_silence_ms: int = 800,   # Day 23: longer silence (was 600)
                 min_speech_ms: int = 300,    # Day 23: shorter minimum speech (was 500)
                 sample_rate: int = 16000):
        """Initialise the detector.

        Args:
            threshold: Speech probability at or above which a frame counts
                as voice.
            threshold_low: Speech probability at or below which a frame
                counts as silence; values in between keep the current state.
            min_silence_ms: Silence (milliseconds) after which an utterance
                is considered finished.
            min_speech_ms: Minimum utterance length (milliseconds) for it to
                be reported.
            sample_rate: Sample rate passed to the model.
        """
        self.model = get_vad_model()
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.sample_rate = sample_rate

        # Detection state.
        self.audio_buffer = bytearray()   # bytes not yet cut into full frames
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.speech_audio = bytearray()   # audio of the utterance in progress

        # TTS playback state: detection is paused while TTS is playing.
        self.tts_playing = False
        self.tts_end_time = 0             # when TTS last stopped (ms)
        self.tts_cooldown_ms = 500        # wait this long after TTS before detecting

        # Sliding window of per-frame voice decisions.
        self.voice_window = []
        self.window_size = 5              # number of frames kept in the window
        self.frame_threshold = 3          # voiced frames in the window needed for "voice"

        # Day 23: pre-speech (lookback) buffer so the start of a word is not
        # cut off. 10 frames of 32 ms each, roughly 300 ms.
        self.pre_speech_buffer = collections.deque(maxlen=10)

        print(f"[VAD] 初始化: threshold={threshold}, threshold_low={threshold_low}, "
              f"min_silence_ms={min_silence_ms}, min_speech_ms={min_speech_ms}")

    def reset(self):
        """Reset all detection state, including the model's internal state."""
        self.audio_buffer.clear()
        self.speech_audio.clear()
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.voice_window.clear()
        # Drop the lookback history too; otherwise audio from before the
        # reset would be prepended to the next utterance.
        self.pre_speech_buffer.clear()
        self.tts_playing = False
        self.tts_end_time = 0
        if self.model:
            self.model.reset_states()

    def set_tts_playing(self, playing: bool):
        """Record whether TTS is currently playing.

        When playback stops, the stop time is stored so that ``process`` can
        apply the cooldown. When playback starts, any recording in progress
        is discarded.

        Args:
            playing: True when TTS playback starts, False when it ends.
        """
        self.tts_playing = playing
        if not playing:
            # TTS finished: remember when, for the cooldown check.
            self.tts_end_time = time.time() * 1000
            print("[VAD] TTS 结束,等待冷却期...")
            return

        print("[VAD] TTS 开始播放,暂停 VAD 检测")
        # Frames are dropped while TTS plays, so lookback audio captured
        # before playback would not be contiguous with speech detected
        # afterwards.
        self.pre_speech_buffer.clear()
        # TTS interrupts a recording that is in progress.
        if self.is_speaking:
            self.is_speaking = False
            self.speech_audio.clear()
            self.voice_window.clear()
            print("[VAD] TTS 播放打断语音录制")

    def process(self, audio_bytes: bytes) -> dict:
        """Feed audio into the detector and report state changes.

        Args:
            audio_bytes: 16-bit PCM audio data.

        Returns:
            dict with keys:
                'speech_started': True if an utterance began during this call.
                'speech_ended': True if an utterance ended during this call.
                'is_speaking': whether an utterance is currently in progress.
                'speech_audio': the complete utterance as bytes when
                    'speech_ended' is True, otherwise None.
        """
        result = {
            'speech_started': False,
            'speech_ended': False,
            'is_speaking': self.is_speaking,
            'speech_audio': None,
        }

        if self.model is None:
            # No model: use simple energy detection.
            # NOTE(review): this path is taken before the TTS checks below,
            # so TTS playback does not pause the fallback - confirm intended.
            return self._fallback_energy_vad(audio_bytes, result)

        # Skip detection entirely while TTS is playing.
        if self.tts_playing:
            return result

        # TTS just finished: wait out the cooldown.
        now_ms = time.time() * 1000
        if self.tts_end_time > 0 and (now_ms - self.tts_end_time) < self.tts_cooldown_ms:
            return result

        self.audio_buffer.extend(audio_bytes)

        # Silero VAD needs 512 samples per call (32 ms at 16 kHz).
        chunk_size = 512 * 2  # 512 samples * 2 bytes

        while len(self.audio_buffer) >= chunk_size:
            chunk = self.audio_buffer[:chunk_size]
            # Remove the consumed frame in place instead of rebuilding the
            # whole buffer for every frame.
            del self.audio_buffer[:chunk_size]

            # Convert to the float32 tensor the model expects.
            audio_float32 = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
            audio_tensor = torch.from_numpy(audio_float32)

            with torch.no_grad():
                speech_prob = self.model(audio_tensor, self.sample_rate).item()

            # Day 23: debug logging to diagnose low volume / microphone issues.
            if speech_prob > 0.3:
                print(f"[VAD DEBUG] Prob: {speech_prob:.3f}")

            # Two-threshold decision; in between, keep the current state.
            if speech_prob >= self.threshold:
                is_voice = True
            elif speech_prob <= self.threshold_low:
                is_voice = False
            else:
                is_voice = self.is_speaking

            # Update the sliding window of recent decisions.
            self.voice_window.append(is_voice)
            if len(self.voice_window) > self.window_size:
                self.voice_window.pop(0)
            has_voice = self.voice_window.count(True) >= self.frame_threshold

            now_ms = time.time() * 1000
            if has_voice:
                if not self.is_speaking:
                    # Utterance starts.
                    self.is_speaking = True
                    self.speech_start_time = now_ms
                    self.speech_audio.clear()
                    result['speech_started'] = True
                    print("[VAD] 🎤 Speech started")
                    # Day 23: prepend the lookback frames to recover the
                    # start of speech. The current frame is not in the
                    # lookback yet, so it is appended exactly once below.
                    if self.pre_speech_buffer:
                        for prev_chunk in self.pre_speech_buffer:
                            self.speech_audio.extend(prev_chunk)
                        print(f"[VAD] Recovered {len(self.pre_speech_buffer)} chunks ({len(self.pre_speech_buffer)*32}ms) from history")
                self.last_speech_time = now_ms
                self.speech_audio.extend(chunk)
            elif self.is_speaking:
                # Still collecting audio (possibly just a short pause).
                self.speech_audio.extend(chunk)
                silence_duration = now_ms - self.last_speech_time
                speech_duration = now_ms - self.speech_start_time
                if silence_duration >= self.min_silence_ms:
                    # Silence lasted long enough: the utterance is over.
                    self.is_speaking = False
                    if speech_duration >= self.min_speech_ms:
                        result['speech_ended'] = True
                        result['speech_audio'] = bytes(self.speech_audio)
                        print(f"[VAD] 🔇 Speech ended, duration={speech_duration:.0f}ms, "
                              f"audio_size={len(self.speech_audio)} bytes")
                    else:
                        print(f"[VAD] 语音太短 ({speech_duration:.0f}ms), 忽略")
                    self.speech_audio.clear()

            # Record the frame as history only after it has been handled, so
            # a speech start never sees its own frame in the lookback.
            self.pre_speech_buffer.append(chunk)

        result['is_speaking'] = self.is_speaking
        return result

    def _fallback_energy_vad(self, audio_bytes: bytes, result: dict) -> dict:
        """Energy-based detection used when the model is unavailable.

        Treats the whole of ``audio_bytes`` as one frame and compares its RMS
        amplitude against a fixed threshold.

        Args:
            audio_bytes: 16-bit PCM audio data.
            result: Result dict created by ``process``; updated in place.

        Returns:
            The same ``result`` dict.
        """
        # Only whole 16-bit samples contribute to the energy; an empty packet
        # counts as silence instead of producing NaN.
        usable = len(audio_bytes) - (len(audio_bytes) % 2)
        if usable:
            audio_int16 = np.frombuffer(audio_bytes[:usable], dtype=np.int16)
            rms = np.sqrt(np.mean(audio_int16.astype(np.float32) ** 2))
        else:
            rms = 0.0

        # Simple fixed threshold.
        threshold = 500
        is_voice = rms > threshold
        current_time = time.time() * 1000

        if is_voice:
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = current_time
                self.speech_audio.clear()
                result['speech_started'] = True
            self.last_speech_time = current_time
            self.speech_audio.extend(audio_bytes)
        elif self.is_speaking:
            self.speech_audio.extend(audio_bytes)
            silence_duration = current_time - self.last_speech_time
            if silence_duration >= self.min_silence_ms:
                self.is_speaking = False
                speech_duration = current_time - self.speech_start_time
                if speech_duration >= self.min_speech_ms:
                    result['speech_ended'] = True
                    result['speech_audio'] = bytes(self.speech_audio)
                self.speech_audio.clear()

        result['is_speaking'] = self.is_speaking
        return result
# Process-wide SileroVAD instance; stays None until get_server_vad() creates it.
_global_vad = None
def get_server_vad() -> SileroVAD:
    """Return the global VAD instance, creating it on first use."""
    global _global_vad
    if _global_vad is not None:
        return _global_vad
    # First call: build the detector with its default settings and cache it.
    _global_vad = SileroVAD()
    return _global_vad
def reset_server_vad():
    """Reset the global VAD's state; does nothing if it was never created."""
    vad = _global_vad
    if not vad:
        return
    vad.reset()