"""
字幕对齐服务

使用 faster-whisper 生成字级别时间戳
"""
|
||
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Optional, List
|
||
from loguru import logger
|
||
|
||
# Module-level cache for the loaded Whisper model (lazily initialized,
# shared by every WhisperService instance in this process).
_whisper_model = None

# Punctuation marks that end a subtitle line (CJK full-width + ASCII).
SENTENCE_PUNCTUATION = set('。!?,、;:,.!?;:')
# Maximum number of characters per subtitle line (for Chinese; Western
# languages use a larger budget chosen in WhisperService.align).
MAX_CHARS_PER_LINE = 12
|
||
|
||
|
||
def split_word_to_chars(word: str, start: float, end: float) -> list:
    """
    Split a word into per-character tokens with linearly interpolated timestamps.

    A leading space on the word (Whisper emits English words like " Hello") is
    preserved on the first token so English subtitles can be rebuilt verbatim.
    Runs of ASCII alphanumerics stay together as one token; every other
    character (CJK, punctuation) becomes its own token.

    Args:
        word: Word text (possibly with a leading space).
        start: Word start time in seconds.
        end: Word end time in seconds.

    Returns:
        List of dicts, each carrying ``word``/``start``/``end``.
    """
    # Normalize any leading whitespace to a single space and remember it
    # (common in English Whisper output such as " Hello").
    had_leading_space = bool(word) and not word[0].strip()
    if had_leading_space:
        word = word.lstrip()

    pieces = []
    run = ""               # current run of consecutive ASCII alphanumerics
    space_pending = False  # a gap was seen; prepend " " to the next piece

    for ch in word:
        if not ch.strip():
            # Whitespace: close the current ASCII run, remember the gap.
            if run:
                pieces.append(run)
                run = ""
            if pieces:  # only once something precedes it (no duplicate lead space)
                space_pending = True
        elif ch.isascii() and ch.isalnum():
            if space_pending and not run:
                # Carry the inter-word space into the new English word.
                run = " "
                space_pending = False
            run += ch
        else:
            # CJK character or punctuation: flush the run, emit standalone.
            if run:
                pieces.append(run)
                run = ""
            pieces.append((" " if space_pending else "") + ch)
            space_pending = False

    if run:
        pieces.append(run)

    if not pieces:
        return []

    lead = " " if had_leading_space else ""

    if len(pieces) == 1:
        return [{"word": lead + pieces[0], "start": start, "end": end}]

    # Spread the word's duration evenly across the pieces.
    step = (end - start) / len(pieces)
    timed = []
    for i, piece in enumerate(pieces):
        timed.append({
            "word": (lead + piece) if i == 0 else piece,
            "start": round(start + i * step, 3),
            "end": round(start + (i + 1) * step, 3),
        })
    return timed
|
||
|
||
|
||
def split_segment_to_lines(words: List[dict], max_chars: int = MAX_CHARS_PER_LINE) -> List[dict]:
    """
    Break a long run of timed characters into subtitle lines.

    A line ends when a sentence-punctuation token appears or when the
    accumulated text reaches ``max_chars``.

    Args:
        words: Character entries, each a dict with ``word``/``start``/``end``.
        max_chars: Maximum number of characters per line.

    Returns:
        List of segment dicts with ``text``/``start``/``end``/``words``.
    """
    if not words:
        return []

    lines: List[dict] = []
    pending: List[dict] = []
    pending_text = ""

    for entry in words:
        token = entry["word"]
        pending.append(entry)
        pending_text += token

        # Break on sentence punctuation, or once the line is full.
        if token in SENTENCE_PUNCTUATION or len(pending_text) >= max_chars:
            lines.append({
                "text": pending_text.strip(),
                "start": pending[0]["start"],
                "end": pending[-1]["end"],
                "words": list(pending),
            })
            pending = []
            pending_text = ""

    # Whatever is left over becomes the final line.
    if pending:
        lines.append({
            "text": pending_text.strip(),
            "start": pending[0]["start"],
            "end": pending[-1]["end"],
            "words": list(pending),
        })

    return lines
|
||
|
||
|
||
class WhisperService:
|
||
"""字幕对齐服务(基于 faster-whisper)"""
|
||
|
||
def __init__(
|
||
self,
|
||
model_size: str = "large-v3",
|
||
device: str = "cuda",
|
||
compute_type: str = "float16",
|
||
):
|
||
self.model_size = model_size
|
||
self.device = device
|
||
self.compute_type = compute_type
|
||
|
||
def _load_model(self):
|
||
"""懒加载 faster-whisper 模型"""
|
||
global _whisper_model
|
||
|
||
if _whisper_model is None:
|
||
from faster_whisper import WhisperModel
|
||
|
||
logger.info(f"Loading faster-whisper model: {self.model_size} on {self.device}")
|
||
_whisper_model = WhisperModel(
|
||
self.model_size,
|
||
device=self.device,
|
||
compute_type=self.compute_type
|
||
)
|
||
logger.info("faster-whisper model loaded")
|
||
|
||
return _whisper_model
|
||
|
||
async def align(
|
||
self,
|
||
audio_path: str,
|
||
text: str,
|
||
output_path: Optional[str] = None,
|
||
language: str = "zh",
|
||
original_text: Optional[str] = None,
|
||
) -> dict:
|
||
"""
|
||
对音频进行转录,生成字级别时间戳
|
||
|
||
Args:
|
||
audio_path: 音频文件路径
|
||
text: 原始文本(用于参考,但实际使用 whisper 转录结果)
|
||
output_path: 可选,输出 JSON 文件路径
|
||
language: 语言代码 (zh/en 等)
|
||
original_text: 原始文案。非空时,Whisper 仅用于检测总时间范围,
|
||
字幕文字用此原文替换(解决语言不匹配问题)
|
||
|
||
Returns:
|
||
包含字级别时间戳的字典
|
||
"""
|
||
import asyncio
|
||
|
||
# 英文等西文需要更大的每行字数
|
||
max_chars = 40 if language != "zh" else MAX_CHARS_PER_LINE
|
||
|
||
def _do_transcribe():
|
||
model = self._load_model()
|
||
|
||
logger.info(f"Transcribing audio: {audio_path}")
|
||
|
||
# 转录并获取字级别时间戳
|
||
segments_iter, info = model.transcribe(
|
||
audio_path,
|
||
language=language,
|
||
word_timestamps=True, # 启用字级别时间戳
|
||
vad_filter=True, # 启用 VAD 过滤静音
|
||
)
|
||
|
||
logger.info(f"Detected language: {info.language} (prob: {info.language_probability:.2f})")
|
||
|
||
# 收集 Whisper 转录结果(始终需要,用于获取时间范围)
|
||
all_segments = []
|
||
whisper_first_start = None
|
||
whisper_last_end = None
|
||
for segment in segments_iter:
|
||
all_words = []
|
||
if segment.words:
|
||
for word_info in segment.words:
|
||
word_text = word_info.word
|
||
if word_text.strip():
|
||
if whisper_first_start is None:
|
||
whisper_first_start = word_info.start
|
||
whisper_last_end = word_info.end
|
||
chars = split_word_to_chars(
|
||
word_text,
|
||
word_info.start,
|
||
word_info.end
|
||
)
|
||
all_words.extend(chars)
|
||
|
||
if all_words:
|
||
line_segments = split_segment_to_lines(all_words, max_chars)
|
||
all_segments.extend(line_segments)
|
||
|
||
# 如果提供了 original_text,用原文替换 Whisper 转录文字,保留语音节奏
|
||
if original_text and original_text.strip() and whisper_first_start is not None:
|
||
# 收集 Whisper 逐字时间戳(保留真实语音节奏)
|
||
whisper_chars = []
|
||
for seg in all_segments:
|
||
whisper_chars.extend(seg.get("words", []))
|
||
|
||
# 用原文字符 + Whisper 节奏生成新的时间戳
|
||
orig_chars = split_word_to_chars(
|
||
original_text.strip(),
|
||
whisper_first_start,
|
||
whisper_last_end
|
||
)
|
||
|
||
if orig_chars and len(whisper_chars) >= 2:
|
||
# 将原文字符按比例映射到 Whisper 的时间节奏上
|
||
n_w = len(whisper_chars)
|
||
n_o = len(orig_chars)
|
||
w_starts = [c["start"] for c in whisper_chars]
|
||
w_final_end = whisper_chars[-1]["end"]
|
||
|
||
logger.info(
|
||
f"Using original_text for subtitles (len={len(original_text)}), "
|
||
f"rhythm-mapping {n_o} orig chars onto {n_w} Whisper chars, "
|
||
f"time range: {whisper_first_start:.2f}-{whisper_last_end:.2f}s"
|
||
)
|
||
|
||
remapped = []
|
||
for i, oc in enumerate(orig_chars):
|
||
# 原文第 i 个字符对应 Whisper 时间线的位置
|
||
pos = (i / n_o) * n_w
|
||
idx = min(int(pos), n_w - 1)
|
||
frac = pos - idx
|
||
t_start = (
|
||
w_starts[idx] + frac * (w_starts[idx + 1] - w_starts[idx])
|
||
if idx < n_w - 1
|
||
else w_starts[idx] + frac * (w_final_end - w_starts[idx])
|
||
)
|
||
|
||
# 结束时间 = 下一个字符的开始时间
|
||
pos_next = ((i + 1) / n_o) * n_w
|
||
idx_n = min(int(pos_next), n_w - 1)
|
||
frac_n = pos_next - idx_n
|
||
t_end = (
|
||
w_starts[idx_n] + frac_n * (w_starts[idx_n + 1] - w_starts[idx_n])
|
||
if idx_n < n_w - 1
|
||
else w_starts[idx_n] + frac_n * (w_final_end - w_starts[idx_n])
|
||
)
|
||
|
||
remapped.append({
|
||
"word": oc["word"],
|
||
"start": round(t_start, 3),
|
||
"end": round(t_end, 3),
|
||
})
|
||
|
||
all_segments = split_segment_to_lines(remapped, max_chars)
|
||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments (rhythm-mapped)")
|
||
elif orig_chars:
|
||
# Whisper 字符不足,退回线性插值
|
||
all_segments = split_segment_to_lines(orig_chars, max_chars)
|
||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments (linear fallback)")
|
||
|
||
logger.info(f"Generated {len(all_segments)} subtitle segments")
|
||
return {"segments": all_segments}
|
||
|
||
# 在线程池中执行
|
||
loop = asyncio.get_event_loop()
|
||
result = await loop.run_in_executor(None, _do_transcribe)
|
||
|
||
# 保存到文件
|
||
if output_path:
|
||
output_file = Path(output_path)
|
||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||
with open(output_file, "w", encoding="utf-8") as f:
|
||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||
logger.info(f"Captions saved to: {output_path}")
|
||
|
||
return result
|
||
|
||
async def transcribe(self, audio_path: str, language: str | None = None) -> str:
|
||
"""
|
||
仅转录文本(用于提取文案)
|
||
|
||
Args:
|
||
audio_path: 音频/视频文件路径
|
||
language: 语言代码,None 表示自动检测
|
||
|
||
Returns:
|
||
纯文本内容
|
||
"""
|
||
import asyncio
|
||
|
||
def _do_transcribe_text():
|
||
model = self._load_model()
|
||
logger.info(f"Extracting script from: {audio_path}")
|
||
|
||
# 转录 (无需字级时间戳)
|
||
segments_iter, _ = model.transcribe(
|
||
audio_path,
|
||
language=language,
|
||
word_timestamps=False,
|
||
vad_filter=True,
|
||
)
|
||
|
||
text_parts = []
|
||
for segment in segments_iter:
|
||
text_parts.append(segment.text.strip())
|
||
|
||
full_text = " ".join(text_parts)
|
||
logger.info(f"Extracted text length: {len(full_text)}")
|
||
return full_text
|
||
|
||
# 在线程池中执行
|
||
loop = asyncio.get_event_loop()
|
||
result = await loop.run_in_executor(None, _do_transcribe_text)
|
||
return result
|
||
|
||
async def check_health(self) -> dict:
|
||
"""检查服务健康状态"""
|
||
try:
|
||
from faster_whisper import WhisperModel
|
||
return {
|
||
"ready": True,
|
||
"model_size": self.model_size,
|
||
"device": self.device,
|
||
"backend": "faster-whisper"
|
||
}
|
||
except ImportError:
|
||
return {
|
||
"ready": False,
|
||
"error": "faster-whisper not installed"
|
||
}
|
||
|
||
|
||
# Global shared service instance (module-level singleton)
whisper_service = WhisperService()
|