345 lines
13 KiB
Python
345 lines
13 KiB
Python
"""
Silero VAD: server-side voice activity detection.

Modelled on the xiaozhi-esp32-server implementation.
"""
|
||
import torch
|
||
import numpy as np
|
||
import os
|
||
import collections # Day 23: For VAD lookback buffer
|
||
import time
|
||
|
||
# Lazily loaded Silero VAD model shared by the whole process.
_vad_model = None
# True once a load has been attempted, whether it succeeded or failed.
_model_loaded = False


def get_vad_model():
    """Return the shared Silero VAD model, loading it on first use.

    Lookup order:
      1. ``model/snakers4_silero-vad`` next to this file (bundled copy).
      2. An existing torch.hub checkout of ``snakers4/silero-vad`` (loaded
         with ``source="local"`` so no GitHub update check is made).
      3. A regular ``torch.hub.load`` from GitHub.

    Returns:
        The loaded model, or ``None`` if loading failed. A failure is
        remembered, so later calls return ``None`` without retrying.
    """
    global _vad_model, _model_loaded

    if _model_loaded:
        return _vad_model

    try:
        bundled_dir = os.path.join(os.path.dirname(__file__), "model", "snakers4_silero-vad")
        if os.path.exists(bundled_dir):
            print(f"[VAD] 从本地加载 Silero VAD: {bundled_dir}")
            load_source = {"repo_or_dir": bundled_dir, "source": "local"}
        else:
            # Prefer an existing hub checkout so we do not hit GitHub on every
            # start. Ask torch for the real hub directory (it honours
            # TORCH_HOME / torch.hub.set_dir) and keep the historical
            # hard-coded location as a second candidate.
            cache_name = "snakers4_silero-vad_master"
            candidates = (
                os.path.join(torch.hub.get_dir(), cache_name),
                os.path.expanduser("~/.cache/torch/hub/" + cache_name),
            )
            cache_dir = next((path for path in candidates if os.path.exists(path)), None)
            if cache_dir is not None:
                print(f"[VAD] 使用 torch hub 缓存: {cache_dir}")
                load_source = {"repo_or_dir": cache_dir, "source": "local"}
            else:
                # No cached checkout: download from GitHub.
                print("[VAD] 从 torch.hub 下载 Silero VAD...")
                load_source = {"repo_or_dir": "snakers4/silero-vad"}

        # The hub entry point returns (model, utils); only the model is kept.
        _vad_model, _ = torch.hub.load(
            model="silero_vad",
            force_reload=False,
            **load_source,
        )

        _model_loaded = True
        print("[VAD] Silero VAD 模型加载成功")
        return _vad_model
    except Exception as e:
        # Best effort: callers fall back to energy-based detection when the
        # model is unavailable.
        print(f"[VAD] Silero VAD 加载失败: {e}")
        _model_loaded = True  # do not retry on every call
        return None
|
||
|
||
|
||
class SileroVAD:
    """
    Server-side voice activity detector built on Silero VAD.

    Consumes a stream of 16-bit PCM audio through ``process`` and reports when
    an utterance starts and ends. Detection is paused while TTS playback is
    active (reference counted) and for a short cooldown afterwards. When the
    Silero model could not be loaded, a simple RMS energy detector is used
    instead.
    """

    def __init__(self,
                 threshold: float = 0.5,       # Day 23: lowered again (was 0.7)
                 threshold_low: float = 0.3,   # Day 23: lowered again (was 0.4)
                 min_silence_ms: int = 800,    # Day 23: longer silence (was 600)
                 min_speech_ms: int = 300,     # Day 23: shorter minimum speech (was 500)
                 sample_rate: int = 16000):
        """
        Initialise the detector.

        Args:
            threshold: speech probability at or above which a frame is voice.
            threshold_low: speech probability at or below which a frame is
                silence; in between, the current speaking state is kept.
            min_silence_ms: silence (milliseconds, wall clock) after which an
                utterance is considered finished.
            min_speech_ms: minimum utterance length (milliseconds, wall clock)
                for it to be reported rather than discarded.
            sample_rate: sample rate passed to the model.
        """
        self.model = get_vad_model()
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.sample_rate = sample_rate

        # Streaming state. Timestamps are time.time() in milliseconds.
        self.audio_buffer = bytearray()   # bytes not yet forming a full frame
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.speech_audio = bytearray()   # audio of the utterance in progress

        # TTS playback state: detection is paused while TTS is playing.
        # Day 28: reference counted so overlapping playbacks are handled.
        self.tts_playing_count = 0
        self.tts_end_time = 0             # when the last playback finished
        self.tts_cooldown_ms = 500        # wait this long after TTS before detecting

        # Sliding window of per-frame voice decisions.
        self.voice_window = []
        self.window_size = 5              # number of frames kept in the window
        self.frame_threshold = 3          # voiced frames needed to count as speech

        # Day 23+28: pre-speech (lookback) buffer to avoid clipping the start
        # of words. Day 28: 24 frames (768 ms) to keep longer openings so the
        # ASR does not drop leading syllables.
        self.pre_speech_buffer = collections.deque(maxlen=24)

        print(f"[VAD] 初始化: threshold={threshold}, threshold_low={threshold_low}, "
              f"min_silence_ms={min_silence_ms}, min_speech_ms={min_speech_ms}")

    def reset(self):
        """Reset all detector state, including the TTS playback counter."""
        self.audio_buffer.clear()
        self.speech_audio.clear()
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.voice_window.clear()
        self.tts_playing_count = 0
        self.tts_end_time = 0
        # Explicit None check (as in process) rather than relying on the
        # model object's truthiness.
        if self.model is not None:
            self.model.reset_states()
        self.pre_speech_buffer.clear()

    def reset_tts_state(self):
        """Force the TTS playback counter back to zero (hard reset)."""
        self.tts_playing_count = 0
        print("[VAD] 强制重置 TTS 状态 (VAD 恢复)")

    def set_tts_playing(self, playing: bool):
        """
        Record that a TTS playback started or finished (reference counted).

        Args:
            playing: True when a playback starts, False when one finishes.
                The first start aborts any utterance being recorded; the last
                finish starts the cooldown period. A finish without a
                matching start is ignored.
        """
        if playing:
            self.tts_playing_count += 1
            if self.tts_playing_count == 1:
                print("[VAD] TTS 开始播放,暂停 VAD 检测")
                # Abort an utterance that is being recorded when TTS starts.
                if self.is_speaking:
                    self.is_speaking = False
                    self.speech_audio.clear()
                    self.voice_window.clear()
                    # Day 23: drop the lookback audio as well.
                    self.pre_speech_buffer.clear()
                    # Day 28: reset the model's internal state.
                    if self.model is not None:
                        self.model.reset_states()
                    print("[VAD] TTS 播放打断语音录制")
        elif self.tts_playing_count > 0:
            self.tts_playing_count -= 1
            if self.tts_playing_count == 0:
                # Last playback finished: remember when, for the cooldown.
                self.tts_end_time = time.time() * 1000
                print("[VAD] TTS 完全结束,等待冷却期...")

    def process(self, audio_bytes: bytes) -> dict:
        """
        Feed audio into the detector.

        Args:
            audio_bytes: 16-bit PCM audio of any length; incomplete frames are
                buffered until enough data has arrived.

        Returns:
            dict: {
                'speech_started': bool,  # an utterance started in this call
                'speech_ended': bool,    # an utterance ended in this call
                'is_speaking': bool,     # speaking state after this call
                'speech_audio': bytes,   # full utterance when it ended, else None
            }
            While TTS is playing or cooling down the input is ignored and the
            result only reflects the current speaking state.
        """
        result = {
            'speech_started': False,
            'speech_ended': False,
            'is_speaking': self.is_speaking,
            'speech_audio': None,
        }

        # The TTS pause applies to both detectors, so it is checked before
        # choosing between the model and the energy fallback.
        if self.tts_playing_count > 0:
            return result

        # TTS just finished: wait for the cooldown to elapse.
        now_ms = time.time() * 1000
        if self.tts_end_time > 0 and (now_ms - self.tts_end_time) < self.tts_cooldown_ms:
            return result

        if self.model is None:
            # No model available: use simple energy detection.
            return self._fallback_energy_vad(audio_bytes, result)

        self.audio_buffer.extend(audio_bytes)

        # Silero VAD takes fixed 32 ms frames: 512 samples at 16 kHz,
        # 256 samples at 8 kHz. Two bytes per sample.
        frame_samples = 256 if self.sample_rate == 8000 else 512
        chunk_size = frame_samples * 2

        while len(self.audio_buffer) >= chunk_size:
            chunk = bytes(self.audio_buffer[:chunk_size])
            # Drop the consumed frame in place instead of re-copying the
            # remainder of the buffer on every iteration.
            del self.audio_buffer[:chunk_size]
            self._process_chunk(chunk, result)

        result['is_speaking'] = self.is_speaking
        return result

    def _speech_probability(self, chunk: bytes) -> float:
        """Return the model's speech probability for one PCM frame."""
        samples = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
        with torch.no_grad():
            return self.model(torch.from_numpy(samples), self.sample_rate).item()

    def _process_chunk(self, chunk: bytes, result: dict) -> None:
        """Run detection on one frame and update state and ``result``."""
        speech_prob = self._speech_probability(chunk)
        # Day 23: debug logging to diagnose low volume / microphone issues.
        if speech_prob > 0.3:
            print(f"[VAD DEBUG] Prob: {speech_prob:.3f}")

        # Two thresholds with hysteresis: between them keep the current state.
        if speech_prob >= self.threshold:
            is_voice = True
        elif speech_prob <= self.threshold_low:
            is_voice = False
        else:
            is_voice = self.is_speaking

        # Sliding window vote over the most recent frames.
        self.voice_window.append(is_voice)
        if len(self.voice_window) > self.window_size:
            self.voice_window.pop(0)
        has_voice = self.voice_window.count(True) >= self.frame_threshold

        # The lookback buffer always receives the current frame.
        self.pre_speech_buffer.append(chunk)

        current_time = time.time() * 1000  # milliseconds

        if has_voice:
            if not self.is_speaking:
                # Speech starts.
                self.is_speaking = True
                self.speech_start_time = current_time
                self.speech_audio.clear()
                result['speech_started'] = True
                print("[VAD] 🎤 Speech started")

                # Day 23: seed the utterance with the lookback audio to
                # recover the start of speech. The lookback already ends with
                # the current frame, so the frame is not appended again.
                for prev_chunk in self.pre_speech_buffer:
                    self.speech_audio.extend(prev_chunk)
                print(f"[VAD] Recovered {len(self.pre_speech_buffer)} chunks ({len(self.pre_speech_buffer)*32}ms) from history")
            else:
                self.speech_audio.extend(chunk)

            self.last_speech_time = current_time

        elif self.is_speaking:
            # Still collecting audio (this may just be a short pause).
            self.speech_audio.extend(chunk)

            silence_duration = current_time - self.last_speech_time
            speech_duration = current_time - self.speech_start_time

            if silence_duration >= self.min_silence_ms:
                # Speech ends.
                self.is_speaking = False

                # Only report utterances that are long enough.
                if speech_duration >= self.min_speech_ms:
                    result['speech_ended'] = True
                    result['speech_audio'] = bytes(self.speech_audio)
                    print(f"[VAD] 🔇 Speech ended, duration={speech_duration:.0f}ms, "
                          f"audio_size={len(self.speech_audio)} bytes")
                else:
                    print(f"[VAD] 语音太短 ({speech_duration:.0f}ms), 忽略")

                self.speech_audio.clear()

    def _fallback_energy_vad(self, audio_bytes: bytes, result: dict) -> dict:
        """
        Simple RMS energy detection, used when no model is available.

        Args:
            audio_bytes: 16-bit PCM audio; treated as one frame.
            result: result dict created by ``process``; updated and returned.
        """
        audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16)
        if audio_int16.size:
            rms = np.sqrt(np.mean(audio_int16.astype(np.float32) ** 2))
        else:
            # Empty input counts as silence (np.mean of nothing is NaN).
            rms = 0.0

        # Fixed energy threshold on raw int16 amplitude.
        threshold = 500
        is_voice = rms > threshold

        current_time = time.time() * 1000

        if is_voice:
            if not self.is_speaking:
                self.is_speaking = True
                self.speech_start_time = current_time
                self.speech_audio.clear()
                result['speech_started'] = True

            self.last_speech_time = current_time
            self.speech_audio.extend(audio_bytes)

        elif self.is_speaking:
            self.speech_audio.extend(audio_bytes)

            silence_duration = current_time - self.last_speech_time
            if silence_duration >= self.min_silence_ms:
                self.is_speaking = False
                speech_duration = current_time - self.speech_start_time

                if speech_duration >= self.min_speech_ms:
                    result['speech_ended'] = True
                    result['speech_audio'] = bytes(self.speech_audio)

                self.speech_audio.clear()

        result['is_speaking'] = self.is_speaking
        return result
|
||
|
||
|
||
# Process-wide VAD singleton, created on first use by get_server_vad().
_global_vad = None


def get_server_vad() -> SileroVAD:
    """Return the shared SileroVAD instance, creating it on the first call."""
    global _global_vad
    if _global_vad is not None:
        return _global_vad
    _global_vad = SileroVAD()
    return _global_vad
|
||
|
||
|
||
def reset_server_vad():
    """Reset the shared VAD instance's state, if one has been created."""
    vad = _global_vad
    if not vad:
        # Nothing has been created yet, so there is nothing to reset.
        return
    vad.reset()
|