327 lines
12 KiB
Python
327 lines
12 KiB
Python
"""
|
|
Silero VAD server-side voice activity detection (Silero VAD 服务器端语音活动检测).

Based on the xiaozhi-esp32-server implementation (参考 xiaozhi-esp32-server 实现).
|
|
"""
|
|
import torch
|
|
import numpy as np
|
|
import os
|
|
import collections # Day 23: For VAD lookback buffer
|
|
import time
|
|
|
|
# Process-wide Silero VAD model, filled lazily by get_vad_model().
_vad_model = None

# Set once a load attempt has finished (success or failure) so that a failed
# load is not retried on every call.
_model_loaded = False

# torch.hub repository of Silero VAD, and the directory name torch.hub uses
# when it caches that repository's default branch.
_SILERO_HUB_REPO = "snakers4/silero-vad"
_SILERO_HUB_CACHE_NAME = "snakers4_silero-vad_master"


def _hub_cache_dir():
    """Return the directory where torch.hub caches the Silero VAD repository.

    Uses ``torch.hub.get_dir()`` so ``torch.hub.set_dir``, ``$TORCH_HOME`` and
    ``$XDG_CACHE_HOME`` are honoured. With none of them set, this resolves to
    ``~/.cache/torch/hub/snakers4_silero-vad_master``.
    """
    return os.path.join(torch.hub.get_dir(), _SILERO_HUB_CACHE_NAME)


def _load_local_silero(repo_dir):
    """Load the ``silero_vad`` entry point from a local torch.hub repo checkout."""
    model, _utils = torch.hub.load(
        repo_or_dir=repo_dir,
        source="local",
        model="silero_vad",
        force_reload=False,
    )
    return model


def get_vad_model():
    """Return the shared Silero VAD model, loading it on the first call.

    Sources are tried in this order:
      1. a copy bundled next to this file under ``model/snakers4_silero-vad``;
      2. the torch.hub cache (avoids contacting GitHub for update checks);
      3. a download through torch.hub.

    Returns:
        The loaded model, or None if loading failed. A failure is printed and
        remembered: later calls return None without retrying, so callers must
        handle a None model.
    """
    global _vad_model, _model_loaded

    if _model_loaded:
        return _vad_model

    try:
        bundled_dir = os.path.join(os.path.dirname(__file__), "model", "snakers4_silero-vad")
        cache_dir = _hub_cache_dir()
        if os.path.exists(bundled_dir):
            print(f"[VAD] 从本地加载 Silero VAD: {bundled_dir}")
            _vad_model = _load_local_silero(bundled_dir)
        elif os.path.exists(cache_dir):
            # Prefer the cache so startup does not hit GitHub for update checks.
            print(f"[VAD] 使用 torch hub 缓存: {cache_dir}")
            _vad_model = _load_local_silero(cache_dir)
        else:
            # No cache yet: download through torch.hub.
            print("[VAD] 从 torch.hub 下载 Silero VAD...")
            _vad_model, _ = torch.hub.load(
                repo_or_dir=_SILERO_HUB_REPO,
                model="silero_vad",
                force_reload=False,
            )

        _model_loaded = True
        print("[VAD] Silero VAD 模型加载成功")
        return _vad_model

    except Exception as e:
        # Deliberate best-effort: log and let callers fall back rather than crash.
        print(f"[VAD] Silero VAD 加载失败: {e}")
        _model_loaded = True  # avoid re-attempting the load on every call
        return None
|
|
|
|
|
|
class SileroVAD:
    """
    Server-side voice activity detection built on Silero VAD.

    Raw 16-bit PCM is fed through ``process()``. It reports when an utterance
    starts and, after ``min_silence_ms`` of trailing silence, returns the whole
    utterance. Up to ``lookback_frames`` idle frames captured just before the
    start are prepended so the first syllable is not clipped. Detection is
    paused while TTS is playing and for ``tts_cooldown_ms`` afterwards. When the
    Silero model is unavailable (``self.model is None``), a simple RMS-energy
    detector is used instead.

    NOTE: silence timing uses the wall clock at the moment audio is processed,
    so it assumes audio arrives roughly in real time. Audio delivered in one
    large burst is all stamped with (almost) the same time.
    """

    def __init__(self,
                 threshold: float = 0.5,       # Day 23: lowered again (was 0.7)
                 threshold_low: float = 0.3,   # Day 23: lowered again (was 0.4)
                 min_silence_ms: int = 800,    # Day 23: longer silence (was 600)
                 min_speech_ms: int = 300,     # Day 23: shorter minimum speech (was 500)
                 sample_rate: int = 16000):
        """
        Initialize the VAD.

        Args:
            threshold: speech probability at or above which a frame is voice.
            threshold_low: probability at or below which a frame is silence.
                Frames in between keep the current speaking state (hysteresis).
            min_silence_ms: trailing silence (ms) after which an utterance ends.
            min_speech_ms: minimum voiced time (ms) an utterance must contain to
                be returned; shorter utterances are discarded.
            sample_rate: PCM sample rate in Hz. Frames are 512 samples at
                16000 Hz and 256 samples otherwise (Silero's 8000 Hz window).
        """
        self.model = get_vad_model()
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.sample_rate = sample_rate

        # TTS gating: wait this long (ms) after TTS ends before detecting again.
        self.tts_cooldown_ms = 500

        # Sliding-window smoothing: speech needs at least frame_threshold voiced
        # frames among the last window_size frames.
        self.window_size = 5
        self.frame_threshold = 3

        # Day 23: pre-speech lookback to avoid clipping the start of words.
        # 10 frames x 32 ms ~= 320 ms.
        self.lookback_frames = 10

        self._init_state()

        print(f"[VAD] 初始化: threshold={threshold}, threshold_low={threshold_low}, "
              f"min_silence_ms={min_silence_ms}, min_speech_ms={min_speech_ms}")

    def _init_state(self):
        """(Re)create all per-stream state; configuration attributes are untouched."""
        self.audio_buffer = bytearray()   # received bytes not yet forming a full frame
        self.is_speaking = False
        self.last_speech_time = 0         # wall-clock ms of the last voiced frame
        self.speech_start_time = 0        # wall-clock ms when the utterance started
        self.speech_audio = bytearray()   # audio of the utterance in progress
        self._voiced_ms = 0.0             # voiced audio (ms) in the current utterance
        self.tts_playing = False
        self.tts_end_time = 0             # wall-clock ms when TTS last ended (0 = never)
        self.voice_window = []            # most recent per-frame voice decisions
        # Most recent idle frames that belong to no utterance (pre-roll).
        self.pre_speech_buffer = collections.deque(maxlen=self.lookback_frames)

    def _frame_samples(self) -> int:
        """Samples per Silero frame: 512 at 16 kHz, 256 at any other (8 kHz) rate."""
        return 512 if self.sample_rate == 16000 else 256

    def reset(self):
        """Reset all detection state, including the model's internal state."""
        self._init_state()
        if self.model is not None:
            self.model.reset_states()

    def set_tts_playing(self, playing: bool):
        """Record whether TTS audio is currently playing.

        Starting playback aborts any utterance in progress and drops all
        buffered audio, so pre-TTS sound cannot leak into the next utterance.
        Stopping playback starts the cooldown period.
        """
        self.tts_playing = playing
        if not playing:
            # TTS finished: remember when, so process() can enforce the cooldown.
            self.tts_end_time = time.time() * 1000
            print("[VAD] TTS 结束,等待冷却期...")
            return

        print("[VAD] TTS 开始播放,暂停 VAD 检测")
        if self.is_speaking:
            self.is_speaking = False
            self.speech_audio.clear()
            print("[VAD] TTS 播放打断语音录制")

        # Audio received during playback is dropped, so the stream is
        # discontinuous; discard everything captured before playback.
        self.voice_window.clear()
        self.pre_speech_buffer.clear()
        self.audio_buffer.clear()
        if self.model is not None:
            self.model.reset_states()

    def process(self, audio_bytes: bytes) -> dict:
        """
        Process a piece of incoming audio.

        Args:
            audio_bytes: 16-bit PCM audio (decoded as numpy int16).

        Returns:
            dict: {
                'speech_started': bool,  # an utterance started during this call
                'speech_ended': bool,    # an utterance ended during this call
                'is_speaking': bool,     # whether an utterance is in progress
                'speech_audio': bytes,   # the complete utterance if one ended, else None
            }

        Audio received while TTS is playing or during the post-TTS cooldown is
        dropped. With the Silero model, bytes that do not fill a whole frame are
        buffered until the next call.
        """
        result = {
            'speech_started': False,
            'speech_ended': False,
            'is_speaking': self.is_speaking,
            'speech_audio': None,
        }

        # Gate both detectors on TTS so the assistant's own voice is ignored.
        if self.tts_playing:
            return result
        current_time = time.time() * 1000
        if self.tts_end_time > 0 and (current_time - self.tts_end_time) < self.tts_cooldown_ms:
            return result

        if self.model is None:
            # No model: use simple energy detection.
            return self._fallback_energy_vad(audio_bytes, result)

        self.audio_buffer.extend(audio_bytes)
        frame_bytes = self._frame_samples() * 2  # int16 -> 2 bytes per sample
        while len(self.audio_buffer) >= frame_bytes:
            chunk = bytes(self.audio_buffer[:frame_bytes])
            # In-place front deletion avoids re-copying the rest of the buffer.
            del self.audio_buffer[:frame_bytes]
            self._process_frame(chunk, result)

        result['is_speaking'] = self.is_speaking
        return result

    def _process_frame(self, chunk: bytes, result: dict):
        """Run Silero on one frame and advance the speech state machine."""
        samples = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
        with torch.no_grad():
            speech_prob = self.model(torch.from_numpy(samples), self.sample_rate).item()
        # Day 23: debug logging to diagnose low-volume / microphone issues
        if speech_prob > 0.3:
            print(f"[VAD DEBUG] Prob: {speech_prob:.3f}")

        # Dual threshold with hysteresis: in-between probabilities keep the state.
        if speech_prob >= self.threshold:
            is_voice = True
        elif speech_prob <= self.threshold_low:
            is_voice = False
        else:
            is_voice = self.is_speaking

        self.voice_window.append(is_voice)
        if len(self.voice_window) > self.window_size:
            self.voice_window.pop(0)
        has_voice = self.voice_window.count(True) >= self.frame_threshold

        now = time.time() * 1000  # ms
        if has_voice:
            if not self.is_speaking:
                self._start_speech(now, result)
            self.last_speech_time = now
            self._voiced_ms += len(samples) * 1000 / self.sample_rate
            self.speech_audio.extend(chunk)
        elif self.is_speaking:
            # Possibly a short pause: keep recording until silence is long enough.
            self.speech_audio.extend(chunk)
            self._end_speech_if_silent(now, result)
        else:
            # Idle frame: keep it as possible pre-roll for the next utterance.
            self.pre_speech_buffer.append(chunk)

    def _start_speech(self, now: float, result: dict):
        """Enter the speaking state, seeding the utterance with buffered pre-roll.

        The frame that triggered the start is NOT added here; the caller
        appends it after this returns.
        """
        self.is_speaking = True
        self.speech_start_time = now
        self._voiced_ms = 0.0
        self.speech_audio.clear()
        result['speech_started'] = True
        print("[VAD] 🎤 Speech started")

        # Day 23: prepend the lookback frames to recover the start of speech,
        # then drop them so they can never be emitted twice.
        if self.pre_speech_buffer:
            recovered = len(self.pre_speech_buffer)
            for prev_chunk in self.pre_speech_buffer:
                self.speech_audio.extend(prev_chunk)
            self.pre_speech_buffer.clear()
            recovered_ms = recovered * self._frame_samples() * 1000 // self.sample_rate
            print(f"[VAD] Recovered {recovered} chunks ({recovered_ms}ms) from history")

    def _end_speech_if_silent(self, now: float, result: dict):
        """End the utterance once silence has lasted at least ``min_silence_ms``.

        The utterance is returned through ``result`` only if it contains at
        least ``min_speech_ms`` of voiced audio; otherwise it is discarded.
        """
        if now - self.last_speech_time < self.min_silence_ms:
            return
        self.is_speaking = False

        # Count voiced audio only. Time elapsed since the start would always
        # include the >= min_silence_ms tail and make this check vacuous.
        speech_duration = self._voiced_ms
        if speech_duration >= self.min_speech_ms:
            result['speech_ended'] = True
            result['speech_audio'] = bytes(self.speech_audio)
            print(f"[VAD] 🔇 Speech ended, duration={speech_duration:.0f}ms, "
                  f"audio_size={len(self.speech_audio)} bytes")
        else:
            print(f"[VAD] 语音太短 ({speech_duration:.0f}ms), 忽略")

        self.speech_audio.clear()

    def _fallback_energy_vad(self, audio_bytes: bytes, result: dict) -> dict:
        """Simple RMS-energy detection, used when the Silero model is unavailable.

        Each call's audio is classified as a whole: it is voiced when its RMS
        amplitude (int16 scale) exceeds 500. A trailing odd byte is ignored for
        the energy estimate, and empty input leaves the state unchanged.
        """
        n_samples = len(audio_bytes) // 2
        if n_samples == 0:
            return result
        samples = np.frombuffer(audio_bytes, dtype=np.int16, count=n_samples)
        rms = np.sqrt(np.mean(samples.astype(np.float32) ** 2))
        is_voice = rms > 500  # simple fixed threshold

        now = time.time() * 1000
        if is_voice:
            if not self.is_speaking:
                self._start_speech(now, result)
            self.last_speech_time = now
            self._voiced_ms += n_samples * 1000 / self.sample_rate
            self.speech_audio.extend(audio_bytes)
        elif self.is_speaking:
            self.speech_audio.extend(audio_bytes)
            self._end_speech_if_silent(now, result)

        result['is_speaking'] = self.is_speaking
        return result
|
|
|
|
|
|
# Shared, lazily created VAD instance used by the two helpers below.
_global_vad = None


def get_server_vad() -> SileroVAD:
    """Return the process-wide SileroVAD, constructing it on first use."""
    global _global_vad
    if _global_vad is not None:
        return _global_vad
    _global_vad = SileroVAD()
    return _global_vad


def reset_server_vad():
    """Reset the shared VAD's state; does nothing if it was never created."""
    vad = _global_vad
    if vad:
        vad.reset()
|