# sensevoice_asr.py
# -*- coding: utf-8 -*-
"""
SenseVoice local ASR module - Day 21

Modeled on the non-streaming implementation in xiaozhi-esp32-server.

Features:
- Non-streaming recognition (transcribe only after the utterance has finished)
- Built-in VAD automatic segmentation
- Whole-sentence output, no character-by-character "popping"
"""

import asyncio
import os
import time
import traceback
from typing import Optional, Tuple

import numpy as np
import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

# Model location - configurable through an environment variable, with a
# fallback relative to this file.
# IMPORTANT: FunASR needs a *directory* path (one containing config.yaml and
# model.pt), not the path of a single file.
# Local model directory (relative to this file).
_LOCAL_MODEL_DIR = os.path.join(os.path.dirname(__file__), "model", "SenseVoiceSmall")

# Read the override once; an unset variable becomes "" and os.path.exists("")
# is False, so the first branch is skipped in that case.
_ENV_MODEL_PATH = os.getenv("SENSEVOICE_MODEL_PATH", "")

# Pick the first usable source: env var > local directory > online download.
if os.path.exists(_ENV_MODEL_PATH):
    MODEL_PATH = _ENV_MODEL_PATH
elif os.path.isdir(_LOCAL_MODEL_DIR) and os.path.exists(os.path.join(_LOCAL_MODEL_DIR, "model.pt")):
    MODEL_PATH = _LOCAL_MODEL_DIR
else:
    # FunASR model identifier; per the original author's note the first run
    # downloads it automatically into ~/.cache.
    MODEL_PATH = "iic/SenseVoiceSmall"

# Inference device - configured the same way as the main program.
# Note: the server selects the physical GPU through CUDA_VISIBLE_DEVICES=1, so
# inside the process the code always addresses it as cuda:0.
if torch.cuda.is_available():
    # Chained `or` instead of getenv defaults, so a variable that is set but
    # empty falls through to the next candidate rather than yielding "".
    DEVICE = os.getenv("SENSEVOICE_DEVICE") or os.getenv("AIGLASS_DEVICE") or "cuda:0"
else:
    # Without CUDA the environment overrides are ignored.
    DEVICE = "cpu"

# Global model instance: None until _load_model() has built it, then reused.
_model: Optional[AutoModel] = None
# Held by init_sensevoice() while loading, so concurrent coroutines do not
# start loading the model at the same time.
_model_lock = asyncio.Lock()


def _load_model():
    """Build the SenseVoice model once and return the shared instance.

    The first call constructs a FunASR ``AutoModel`` from ``MODEL_PATH`` on
    ``DEVICE`` and stores it in the module-level ``_model``; later calls
    return that stored instance without reloading.

    This function is synchronous and takes no lock itself;
    ``init_sensevoice()`` runs it in a worker thread while holding
    ``_model_lock``.

    Returns:
        The loaded ``AutoModel`` instance.
    """
    global _model
    if _model is not None:
        return _model

    print(f"[SenseVoice] 正在加载模型: {MODEL_PATH}")
    print(f"[SenseVoice] 使用设备: {DEVICE}")

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments during the load.
    started = time.perf_counter()

    # NOTE(review): no vad_model is passed alongside vad_kwargs - confirm the
    # VAD limit below actually takes effect.
    # NOTE(review): hub="hf" is combined with the "iic/..." fallback
    # identifier - confirm that identifier resolves on that hub.
    _model = AutoModel(
        model=MODEL_PATH,
        vad_kwargs={"max_single_segment_time": 30000},  # VAD segment cap: 30 s
        disable_update=True,
        hub="hf",  # following xiaozhi-esp32-server
        device=DEVICE,
    )

    print(f"[SenseVoice] 模型加载完成,耗时: {time.perf_counter() - started:.2f}s")
    return _model


async def init_sensevoice() -> None:
    """Initialise SenseVoice asynchronously (call at server start-up).

    Runs the blocking ``_load_model()`` in a worker thread so the event loop
    is not blocked, holding ``_model_lock`` so that concurrent callers do not
    load the model in parallel. Calling it again after a successful load is
    cheap because ``_load_model()`` returns the existing instance.
    """
    async with _model_lock:
        await asyncio.to_thread(_load_model)
    print("[SenseVoice] 初始化完成")


async def recognize(pcm_data: bytes, sample_rate: int = 16000) -> str:
    """Transcribe a buffer of raw PCM audio.

    Args:
        pcm_data: PCM 16-bit audio data (bytes).
        sample_rate: Sample rate, default 16000. It is accepted for interface
            compatibility but is not forwarded to the model.
            NOTE(review): the raw bytes are presumably interpreted as 16 kHz
            by FunASR - confirm before passing audio at another rate.

    Returns:
        The recognised text with surrounding whitespace stripped, or ``""``
        when the input is missing or too short, the model returns no text, or
        recognition raises.
    """
    # Validate first so that empty or too-short input never triggers a model
    # load. 640 bytes is 20 ms of 16-bit mono audio at 16 kHz.
    min_bytes = 640
    if not pcm_data or len(pcm_data) < min_bytes:
        return ""

    # A 16-bit sample is two bytes; drop a dangling odd byte (for example from
    # a chunk boundary) rather than failing the whole utterance on it.
    if len(pcm_data) % 2:
        pcm_data = pcm_data[:-1]

    # Lazy initialisation for callers that did not run init_sensevoice().
    if _model is None:
        await init_sensevoice()

    try:
        started = time.perf_counter()

        # Run inference in a worker thread to avoid blocking the event loop.
        # [Day 22 fix] language changed from "auto" to "zh" so speech is not
        # misrecognised as Korean or another language.
        result = await asyncio.to_thread(
            _model.generate,
            input=pcm_data,
            cache={},
            language="zh",  # fixed to Chinese to avoid "auto" misdetection
            use_itn=True,
            batch_size_s=60,
        )

        # Post-process the text of the first result entry.
        if result and "text" in result[0]:
            text = await asyncio.to_thread(
                rich_transcription_postprocess,
                result[0]["text"],
            )
            elapsed = time.perf_counter() - started
            print(f"[SenseVoice] 识别耗时: {elapsed:.3f}s | 结果: {text}")
            return text.strip()

        print("[SenseVoice] 识别结果为空")
        return ""

    except Exception as e:
        # Best-effort API: report the failure and return an empty result
        # instead of propagating into the caller's audio pipeline.
        print(f"[SenseVoice] 识别失败: {e}")
        traceback.print_exc()
        return ""


async def recognize_from_file(file_path: str) -> str:
    """Transcribe audio from a file.

    Args:
        file_path: Path of the audio file, handed to the model as its input.

    Returns:
        The recognised text with surrounding whitespace stripped, or ``""``
        when the path is empty, the model returns no text, or recognition
        raises.
    """
    # An empty path can never be recognised; skip the model load entirely.
    if not file_path:
        return ""

    # Lazy initialisation for callers that did not run init_sensevoice().
    if _model is None:
        await init_sensevoice()

    try:
        started = time.perf_counter()

        # Run inference in a worker thread to avoid blocking the event loop.
        result = await asyncio.to_thread(
            _model.generate,
            input=file_path,
            cache={},
            language="zh",  # [Day 22 fix] fixed to Chinese
            use_itn=True,
            batch_size_s=60,
        )

        if result and "text" in result[0]:
            text = await asyncio.to_thread(
                rich_transcription_postprocess,
                result[0]["text"],
            )
            elapsed = time.perf_counter() - started
            print(f"[SenseVoice] 文件识别耗时: {elapsed:.3f}s | 结果: {text}")
            return text.strip()

        return ""

    except Exception as e:
        # Same best-effort handling as recognize(): report with a traceback,
        # then return an empty result.
        print(f"[SenseVoice] 文件识别失败: {e}")
        traceback.print_exc()
        return ""