Files
ViGent2/backend/app/services/whisper_service.py
Kevin Wong b74bacb0b5 更新
2026-01-29 17:54:43 +08:00

177 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
字幕对齐服务
使用 faster-whisper 生成字级别时间戳
"""
import asyncio
import json
import re
import threading
from pathlib import Path
from typing import Optional

from loguru import logger
# Process-wide cache for the faster-whisper model. It is populated lazily by
# WhisperService._load_model (via `global`) and shared by every
# WhisperService instance in the process.
_whisper_model = None
def split_word_to_chars(word: str, start: float, end: float) -> list:
    """
    Split a word into single characters, interpolating timestamps linearly.

    The interval ``[start, end]`` is divided evenly among the non-whitespace
    characters of ``word``. Every non-whitespace character is kept (CJK,
    Latin letters, digits, punctuation); only whitespace is dropped.

    Args:
        word: Word text.
        start: Word start time.
        end: Word end time.

    Returns:
        A list of dicts with keys ``word``/``start``/``end``, one per kept
        character, with times rounded to 3 decimal places. Returns an empty
        list if ``word`` contains no non-whitespace characters.
    """
    # Drop whitespace only (equivalent to filtering on ``c.strip()``).
    chars = [c for c in word if not c.isspace()]
    if not chars:
        return []

    if len(chars) == 1:
        # Round like the multi-character path so output precision is uniform.
        return [{"word": chars[0], "start": round(start, 3), "end": round(end, 3)}]

    step = (end - start) / len(chars)
    # Boundaries are computed from the same expression, so each character's
    # end equals the next character's start exactly.
    return [
        {
            "word": char,
            "start": round(start + i * step, 3),
            "end": round(start + (i + 1) * step, 3),
        }
        for i, char in enumerate(chars)
    ]
class WhisperService:
    """Subtitle alignment service backed by faster-whisper.

    Transcribes an audio file and produces per-character timestamps.
    """

    # Serializes the one-time model load so that concurrent executor threads
    # do not each construct their own model instance.
    _model_lock = threading.Lock()

    def __init__(
        self,
        model_size: str = "large-v3",
        device: str = "cuda",
        compute_type: str = "float16",
    ):
        """
        Args:
            model_size: Model name passed to ``WhisperModel``.
            device: Device passed to ``WhisperModel``.
            compute_type: Compute type passed to ``WhisperModel``.
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type

    def _load_model(self):
        """Lazily load the faster-whisper model and return it.

        The model is cached in the module-level ``_whisper_model`` and shared
        by every ``WhisperService`` instance. NOTE: the cache is not keyed on
        configuration. Whichever instance loads first decides
        ``model_size``/``device``/``compute_type`` for all of them.
        """
        global _whisper_model
        if _whisper_model is None:
            with self._model_lock:
                # Re-check under the lock: another thread may have finished
                # loading while this one was waiting.
                if _whisper_model is None:
                    from faster_whisper import WhisperModel
                    logger.info(f"Loading faster-whisper model: {self.model_size} on {self.device}")
                    _whisper_model = WhisperModel(
                        self.model_size,
                        device=self.device,
                        compute_type=self.compute_type,
                    )
                    logger.info("faster-whisper model loaded")
        return _whisper_model

    @staticmethod
    def _segment_to_dict(segment) -> dict:
        """Convert one transcription segment into the caption dict format.

        Each word is stripped of surrounding whitespace. Non-empty words are
        split into single characters with ``split_word_to_chars``.
        """
        words = []
        for word_info in segment.words or []:
            word_text = word_info.word.strip()
            if word_text:
                words.extend(split_word_to_chars(word_text, word_info.start, word_info.end))
        return {
            "text": segment.text.strip(),
            "start": segment.start,
            "end": segment.end,
            "words": words,
        }

    def _transcribe_sync(self, audio_path: str) -> dict:
        """Blocking transcription; ``align`` runs this in a worker thread."""
        model = self._load_model()
        logger.info(f"Transcribing audio: {audio_path}")
        segments_iter, info = model.transcribe(
            audio_path,
            language="zh",
            word_timestamps=True,  # request word-level timestamps
            vad_filter=True,  # enable VAD filtering of silence
        )
        logger.info(f"Detected language: {info.language} (prob: {info.language_probability:.2f})")

        segments = []
        for segment in segments_iter:
            seg_data = self._segment_to_dict(segment)
            # Only keep segments that produced at least one character.
            if seg_data["words"]:
                segments.append(seg_data)
        return {"segments": segments}

    @staticmethod
    def _write_json(result: dict, output_path: str) -> None:
        """Write ``result`` as UTF-8 JSON, creating parent directories as needed."""
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

    async def align(
        self,
        audio_path: str,
        text: str,
        output_path: Optional[str] = None
    ) -> dict:
        """
        Transcribe audio and produce per-character timestamps.

        Args:
            audio_path: Path to the audio file.
            text: Original script text. It is currently unused; the result
                comes entirely from the Whisper transcription.
            output_path: Optional path of a JSON file to write the result to.

        Returns:
            ``{"segments": [...]}``. Each segment has ``text``, ``start``,
            ``end`` and ``words``, where ``words`` is a list of single-character
            ``{"word", "start", "end"}`` dicts.
        """
        loop = asyncio.get_running_loop()
        # Transcription blocks, so run it in the default thread pool to keep
        # the event loop responsive.
        result = await loop.run_in_executor(None, self._transcribe_sync, audio_path)

        if output_path:
            # The file write is blocking I/O as well; keep it off the loop.
            await loop.run_in_executor(None, self._write_json, result, output_path)
            logger.info(f"Captions saved to: {output_path}")

        return result

    async def check_health(self) -> dict:
        """Report whether faster-whisper is importable, plus the configuration.

        ``model_loaded`` tells whether the shared model has already been
        loaded into memory.
        """
        try:
            from faster_whisper import WhisperModel  # noqa: F401  (availability probe)
        except ImportError:
            return {
                "ready": False,
                "error": "faster-whisper not installed"
            }
        return {
            "ready": True,
            "model_size": self.model_size,
            "device": self.device,
            "backend": "faster-whisper",
            "model_loaded": _whisper_model is not None,
        }
# Module-level shared service instance, created with the default configuration.
whisper_service = WhisperService()