177 lines
5.1 KiB
Python
177 lines
5.1 KiB
Python
"""
|
||
Subtitle alignment service.

Uses faster-whisper to generate character-level timestamps.
|
||
"""
|
||
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
from loguru import logger
|
||
|
||
# Module-level cache for the loaded faster-whisper model
|
||
_whisper_model = None
|
||
|
||
|
||
def split_word_to_chars(word: str, start: float, end: float) -> list:
    """
    Split a word into single characters with linearly interpolated timestamps.

    Whitespace characters are dropped; every other character (CJK, Latin,
    digits, punctuation, ...) is kept. The ``start``..``end`` interval is
    divided evenly among the remaining characters.

    Args:
        word: Text of the word to split.
        start: Start time of the word.
        end: End time of the word.

    Returns:
        A list with one dict per kept character, each holding the keys
        ``word`` / ``start`` / ``end``. Timestamps are rounded to 3 decimal
        places. The list is empty when ``word`` has no non-whitespace
        characters.
    """
    # Drop whitespace only; no other filtering is applied.
    chars = [c for c in word if c.strip()]
    if not chars:
        return []

    if len(chars) == 1:
        # Round here as well so single-character entries have the same
        # precision as the interpolated ones below.
        return [{"word": chars[0], "start": round(start, 3), "end": round(end, 3)}]

    # Linear interpolation: every character gets an equal share of the span.
    char_duration = (end - start) / len(chars)
    return [
        {
            "word": char,
            "start": round(start + i * char_duration, 3),
            "end": round(start + (i + 1) * char_duration, 3),
        }
        for i, char in enumerate(chars)
    ]
|
||
|
||
|
||
class WhisperService:
    """Subtitle alignment service (based on faster-whisper)."""

    def __init__(
        self,
        model_size: str = "large-v3",
        device: str = "cuda",
        compute_type: str = "float16",
    ):
        """
        Store the model configuration; the model itself is loaded lazily.

        Args:
            model_size: Model identifier passed to ``WhisperModel``.
            device: Device passed to ``WhisperModel``.
            compute_type: Compute type passed to ``WhisperModel``.
        """
        self.model_size = model_size
        self.device = device
        self.compute_type = compute_type

    def _load_model(self):
        """
        Lazily load the faster-whisper model and return it.

        The model is kept in the module-level ``_whisper_model`` and reused
        by later calls. The cache is only checked for ``None``; it is not
        keyed by ``model_size`` / ``device`` / ``compute_type``, so whichever
        configuration loads first is the one every caller gets.
        """
        global _whisper_model

        if _whisper_model is None:
            # Imported here so the module can be imported without
            # faster-whisper being installed.
            from faster_whisper import WhisperModel

            logger.info(f"Loading faster-whisper model: {self.model_size} on {self.device}")
            _whisper_model = WhisperModel(
                self.model_size,
                device=self.device,
                compute_type=self.compute_type,
            )
            logger.info("faster-whisper model loaded")

        return _whisper_model

    def _transcribe(self, audio_path: str) -> dict:
        """Blocking transcription of ``audio_path``; returns ``{"segments": [...]}``."""
        model = self._load_model()

        logger.info(f"Transcribing audio: {audio_path}")

        # Transcribe and request per-word timestamps.
        segments_iter, info = model.transcribe(
            audio_path,
            language="zh",
            word_timestamps=True,  # enable word-level timestamps
            vad_filter=True,  # enable VAD to filter out silence
        )

        logger.info(f"Detected language: {info.language} (prob: {info.language_probability:.2f})")

        segments = []
        for segment in segments_iter:
            # Split every non-blank word into characters with interpolated
            # timestamps.
            words = []
            for word_info in segment.words or ():
                word_text = word_info.word.strip()
                if word_text:
                    words.extend(
                        split_word_to_chars(word_text, word_info.start, word_info.end)
                    )

            # Only keep segments that actually contain characters.
            if words:
                segments.append(
                    {
                        "text": segment.text.strip(),
                        "start": segment.start,
                        "end": segment.end,
                        "words": words,
                    }
                )

        return {"segments": segments}

    @staticmethod
    def _save_result(result: dict, output_path: str) -> None:
        """Write ``result`` as UTF-8 JSON to ``output_path``, creating parent directories."""
        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

    async def align(
        self,
        audio_path: str,
        text: str,
        output_path: Optional[str] = None
    ) -> dict:
        """
        Transcribe an audio file and produce character-level timestamps.

        Args:
            audio_path: Path of the audio file.
            text: Original text. Kept for reference only; it is currently
                unused and the result comes from the whisper transcription.
            output_path: Optional path of a JSON file to write the result to.

        Returns:
            A dict ``{"segments": [...]}`` where every segment has the keys
            ``text`` / ``start`` / ``end`` / ``words`` and ``words`` is the
            list of per-character entries from ``split_word_to_chars``.
        """
        import asyncio

        # Run the blocking transcription in the default thread pool.
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, self._transcribe, audio_path)

        # Save to file when requested.
        if output_path:
            self._save_result(result, output_path)
            logger.info(f"Captions saved to: {output_path}")

        return result

    async def check_health(self) -> dict:
        """
        Check the health of the service.

        Only verifies that ``faster_whisper`` can be imported; the model is
        not loaded.

        Returns:
            ``{"ready": True, ...}`` with the configured model size, device
            and backend name, or ``{"ready": False, "error": ...}`` when the
            package cannot be imported.
        """
        try:
            from faster_whisper import WhisperModel  # noqa: F401
        except ImportError:
            return {
                "ready": False,
                "error": "faster-whisper not installed"
            }

        return {
            "ready": True,
            "model_size": self.model_size,
            "device": self.device,
            "backend": "faster-whisper"
        }
|
||
|
||
|
||
# Global service instance
|
||
whisper_service = WhisperService()
|