406 lines
15 KiB
Python
406 lines
15 KiB
Python
import re
|
||
import os
|
||
import time
|
||
import json
|
||
import hashlib
|
||
import asyncio
|
||
import subprocess
|
||
import tempfile
|
||
import unicodedata
|
||
from pathlib import Path
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
from loguru import logger
|
||
|
||
from app.services.storage import storage_service
|
||
from app.modules.ref_audios.schemas import RefAudioResponse, RefAudioListResponse
|
||
|
||
# File extensions accepted for uploaded reference audio; the upload path
# compares the lower-cased suffix of the incoming filename against this set.
ALLOWED_AUDIO_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac'}
|
||
# Storage bucket holding the reference-audio WAV files and their JSON metadata.
BUCKET_REF_AUDIOS = "ref-audios"
|
||
|
||
|
||
def sanitize_filename(filename: str) -> str:
    """Reduce *filename* to an ASCII-only name that is safe to use in a storage key.

    The name is NFKD-normalized and stripped of non-ASCII characters; every run
    of characters outside ``[A-Za-z0-9._-]`` becomes a single underscore, and
    leading/trailing ``.``, ``_`` and ``-`` are removed.  If nothing survives
    (e.g. a purely CJK or emoji name), a stable ``audio_<md5 prefix>`` name is
    used instead.  The result is capped at 50 characters, keeping the suffix
    when it fits.

    Args:
        filename: Arbitrary user-supplied file name (or stem).

    Returns:
        A non-empty ASCII string of at most 50 characters.
    """
    max_len = 50

    normalized = unicodedata.normalize("NFKD", filename)
    ascii_name = normalized.encode("ascii", "ignore").decode("ascii")
    safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", ascii_name).strip("._-")

    # Names made only of CJK/emoji characters end up empty here; fall back to a
    # stable hash of the original name so the storage key is never invalid.
    if not safe_name:
        digest = hashlib.md5(filename.encode("utf-8")).hexdigest()[:12]
        safe_name = f"audio_{digest}"

    if len(safe_name) > max_len:
        ext = Path(safe_name).suffix
        if len(ext) >= max_len:
            # The "suffix" alone would blow the limit (a negative slice index
            # would make the result longer, not shorter): hard-truncate.
            safe_name = safe_name[:max_len]
        else:
            safe_name = safe_name[:max_len - len(ext)] + ext
    return safe_name
|
||
|
||
|
||
def _get_audio_duration(file_path: str) -> float:
    """Return the duration of an audio file in seconds, as reported by ffprobe.

    Any failure (ffprobe unavailable, timeout after 10 seconds, output that is
    not a number) is logged as a warning and reported as ``0.0``.
    """
    probe_cmd = [
        'ffprobe', '-v', 'quiet',
        '-show_entries', 'format=duration',
        '-of', 'csv=p=0',
        file_path,
    ]
    try:
        completed = subprocess.run(
            probe_cmd, capture_output=True, text=True, timeout=10
        )
        seconds = float(completed.stdout.strip())
    except Exception as exc:
        logger.warning(f"获取音频时长失败: {exc}")
        return 0.0
    return seconds
|
||
|
||
|
||
def _find_silence_cut_point(file_path: str, max_duration: float) -> float:
    """Pick a cut position at a silence boundary at or before *max_duration*.

    Runs ffmpeg's ``silencedetect`` filter (threshold -30 dB, minimum silence
    0.3 s) and returns the latest ``silence_end`` timestamp that lies between
    3 seconds and *max_duration*.

    Args:
        file_path: Audio file to analyse.
        max_duration: Upper bound, in seconds, for the cut position.

    Returns:
        The chosen cut position in seconds, or *max_duration* when no suitable
        silence is found or detection fails (failures are logged as warnings).
    """
    try:
        result = subprocess.run(
            ['ffmpeg', '-i', file_path, '-af',
             'silencedetect=noise=-30dB:d=0.3', '-f', 'null', '-'],
            capture_output=True, text=True, timeout=30
        )
        # ffmpeg writes the filter's report to stderr.
        ends = [float(m) for m in re.findall(r'silence_end:\s*([\d.]+)', result.stderr)]
        # Keep only boundaries that leave at least 3 seconds of audio.
        candidates = [t for t in ends if 3.0 <= t <= max_duration]
        if candidates:
            # max() rather than "last element": do not depend on output order.
            cut = max(candidates)
            logger.info(f"Found silence cut point at {cut:.1f}s (max={max_duration}s)")
            return cut
    except Exception as e:
        logger.warning(f"Silence detection failed: {e}")
    return max_duration
|
||
|
||
|
||
def _convert_to_wav(input_path: str, output_path: str, max_duration: float = 0) -> bool:
    """Transcode an audio file to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    When *max_duration* is positive, only the first *max_duration* seconds are
    kept and the final 0.1 second is faded out so the cut does not click.

    Returns:
        ``True`` on success; ``False`` if ffmpeg fails, times out (60 s) or
        cannot be started.  Failures are logged as errors.
    """
    try:
        trim_args = []
        if max_duration > 0:
            fade_start = max(0, max_duration - 0.1)
            trim_args = [
                '-t', str(max_duration),
                '-af', f'afade=t=out:st={fade_start}:d=0.1',
            ]
        ffmpeg_cmd = [
            'ffmpeg', '-y', '-i', input_path,
            *trim_args,
            '-ar', '16000', '-ac', '1', '-acodec', 'pcm_s16le', output_path,
        ]
        subprocess.run(ffmpeg_cmd, capture_output=True, timeout=60, check=True)
    except Exception as exc:
        logger.error(f"音频转换失败: {exc}")
        return False
    return True
|
||
|
||
|
||
async def upload_ref_audio(file, ref_text: str, user_id: str) -> dict:
    """Upload a reference audio clip: transcode, trim, transcribe and store it.

    The upload is converted to WAV, trimmed to at most 10 seconds (preferring a
    silence boundary), transcribed with Whisper, and stored in the
    reference-audio bucket next to a JSON metadata file.

    Args:
        file: Uploaded file object exposing ``filename`` and an awaitable
            ``read()``.
        ref_text: Caller-supplied transcript; replaced by the Whisper
            transcription whenever that succeeds and is non-empty.
        user_id: Owner of the clip; used as the storage key prefix.

    Returns:
        The ``RefAudioResponse`` for the stored clip, dumped to a dict.

    Raises:
        ValueError: Missing filename, unsupported extension, clip shorter than
            one second, or no transcript available.
        RuntimeError: ffmpeg conversion or trimming failed.
    """
    if not file.filename:
        raise ValueError("文件名无效")
    filename = file.filename

    ext = Path(filename).suffix.lower()
    if ext not in ALLOWED_AUDIO_EXTENSIONS:
        # sorted(): set iteration order is arbitrary, keep the message stable.
        supported = ', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))
        raise ValueError(f"不支持的音频格式: {ext}。支持的格式: {supported}")

    # Read the upload before creating any temp file so a failed read leaves
    # nothing behind on disk.
    content = await file.read()

    tmp_input = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
    tmp_input_path = tmp_input.name
    converted_path = tmp_input_path + ".wav"
    trimmed_path = tmp_input_path + "_trimmed.wav"

    try:
        with tmp_input:
            tmp_input.write(content)

        # Convert to WAV.
        if not _convert_to_wav(tmp_input_path, converted_path):
            raise RuntimeError("音频格式转换失败")
        tmp_wav_path = converted_path

        # Measure the duration.
        duration = _get_audio_duration(tmp_wav_path)
        if duration < 1.0:
            raise ValueError("音频时长过短,至少需要 1 秒")

        # Clips longer than 10 seconds are cut at a silence point
        # (CosyVoice works best with 3-10 second references).
        MAX_REF_DURATION = 10.0
        if duration > MAX_REF_DURATION:
            cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
            logger.info(f"Ref audio {duration:.1f}s > {MAX_REF_DURATION}s, trimming at {cut_point:.1f}s")
            if not _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
                raise RuntimeError("音频截取失败")
            tmp_wav_path = trimmed_path
            duration = _get_audio_duration(tmp_wav_path)

        # Transcribe the clip automatically; on failure keep the caller's text.
        try:
            from app.services.whisper_service import whisper_service
            transcribed = await whisper_service.transcribe(tmp_wav_path)
            if transcribed.strip():
                ref_text = transcribed.strip()
                logger.info(f"Auto-transcribed ref audio: {ref_text[:50]}...")
        except Exception as e:
            logger.warning(f"Auto-transcribe failed: {e}")

        if not ref_text or not ref_text.strip():
            raise ValueError("无法识别音频内容,请确保音频包含清晰的语音")

        # Build the storage key first: duplicate detection needs it.
        timestamp = int(time.time())
        safe_name = sanitize_filename(Path(filename).stem)
        storage_path = f"{user_id}/{timestamp}_{safe_name}.wav"

        # Count earlier uploads of the same name.  Objects are stored as
        # "<timestamp>_<sanitized stem>.wav", so comparing against the raw
        # upload name alone misses any name that was altered by sanitizing or
        # had a non-.wav extension; the raw-name suffix match is kept as well
        # so keys that embed the original file name still count.
        # Trade-off: an upload with the same stem but a different extension
        # is now also counted as a duplicate.
        existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
        raw_suffix = f"_{filename}"
        key_pattern = re.compile(rf"^\d+_{re.escape(safe_name)}\.wav$")
        dup_count = 0
        for entry in existing_files:
            stored_name = entry.get('name') or ''
            if stored_name.endswith(raw_suffix) or key_pattern.match(stored_name):
                dup_count += 1

        final_display_name = filename
        if dup_count > 0:
            name_stem = Path(filename).stem
            name_ext = Path(filename).suffix
            final_display_name = f"{name_stem}({dup_count}){name_ext}"

        # Upload the WAV file.
        with open(tmp_wav_path, 'rb') as f:
            wav_data = f.read()

        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=storage_path,
            file_data=wav_data,
            content_type="audio/wav"
        )

        # Upload the metadata JSON next to it.
        metadata = {
            "ref_text": ref_text.strip(),
            "original_filename": final_display_name,
            "duration_sec": duration,
            "created_at": timestamp
        }
        metadata_path = f"{user_id}/{timestamp}_{safe_name}.json"
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        # Signed URL for the stored audio.
        signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

        # NOTE(review): the response carries the raw upload name while the
        # stored metadata carries the de-duplicated display name; confirm
        # which one clients expect before changing either.
        return RefAudioResponse(
            id=storage_path,
            name=filename,
            path=signed_url,
            ref_text=ref_text.strip(),
            duration_sec=duration,
            created_at=timestamp
        ).model_dump()

    finally:
        # Remove every temp file this call may have produced, including the
        # trimmed WAV, whether or not the upload succeeded.
        for leftover in (tmp_input_path, converted_path, trimmed_path):
            try:
                os.unlink(leftover)
            except FileNotFoundError:
                pass
            except OSError as e:
                logger.warning(f"Failed to remove temp file {leftover}: {e}")
|
||
|
||
|
||
async def list_ref_audios(user_id: str) -> dict:
    """List all reference audio clips stored for a user, newest first.

    For every ``.wav`` object under the user's prefix, the sibling ``.json``
    metadata is fetched (best effort) to obtain the transcript, duration,
    creation time and display name.

    Args:
        user_id: Owner whose storage prefix is listed.

    Returns:
        A ``RefAudioListResponse`` dumped to a dict, items sorted by
        ``created_at`` descending.
    """
    files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
    wav_files = [f for f in files if f.get("name", "").endswith(".wav")]

    if not wav_files:
        return RefAudioListResponse(items=[]).model_dump()

    async def fetch_audio_info(f):
        name = f.get("name", "")
        storage_path = f"{user_id}/{name}"
        # Swap only the trailing ".wav" (guaranteed by the filter above);
        # str.replace would also rewrite a ".wav" in the middle of the name.
        metadata_name = name[:-len(".wav")] + ".json"
        metadata_path = f"{user_id}/{metadata_name}"

        ref_text = ""
        duration_sec = 0.0
        created_at = 0
        original_filename = ""

        try:
            metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(metadata_url)
                if resp.status_code == 200:
                    metadata = resp.json()
                    ref_text = metadata.get("ref_text", "")
                    duration_sec = metadata.get("duration_sec", 0.0)
                    created_at = metadata.get("created_at", 0)
                    original_filename = metadata.get("original_filename", "")
        except Exception as e:
            logger.debug(f"读取 metadata 失败: {e}")

        # No usable timestamp from metadata (fetch failed, non-200 response,
        # or field missing): fall back to the "<timestamp>_" key prefix.
        if not created_at:
            try:
                created_at = int(name.split("_")[0])
            except ValueError:
                pass

        signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

        display_name = original_filename if original_filename else name
        if not display_name or display_name == name:
            # Without a stored display name, show the key minus its
            # numeric "<digits>_" prefix.
            match = re.match(r'^\d+_(.+)$', name)
            if match:
                display_name = match.group(1)

        return RefAudioResponse(
            id=storage_path,
            name=display_name,
            path=signed_url,
            ref_text=ref_text,
            duration_sec=duration_sec,
            created_at=created_at
        )

    items = await asyncio.gather(*[fetch_audio_info(f) for f in wav_files])
    items = sorted(items, key=lambda x: x.created_at, reverse=True)

    return RefAudioListResponse(items=items).model_dump()
|
||
|
||
|
||
async def delete_ref_audio(audio_id: str, user_id: str) -> None:
    """Delete a reference audio object and, best effort, its metadata JSON.

    Args:
        audio_id: Storage path of the audio; must start with ``"<user_id>/"``.
        user_id: Caller's id, used for the ownership check.

    Raises:
        PermissionError: *audio_id* is not under the caller's prefix.
    """
    if not audio_id.startswith(f"{user_id}/"):
        raise PermissionError("无权删除此文件")

    await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)

    # The metadata file shares the key with the ".wav" suffix swapped for
    # ".json".  Only the trailing suffix is replaced; an id without it has no
    # derivable metadata path.
    if not audio_id.endswith(".wav"):
        return
    metadata_path = audio_id[:-len(".wav")] + ".json"

    try:
        await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
    except Exception as e:
        # Best effort: a failed metadata delete must not fail the audio
        # delete.  "except Exception" (not a bare except) lets task
        # cancellation propagate.
        logger.warning(f"Failed to delete metadata {metadata_path}: {e}")
|
||
|
||
|
||
async def rename_ref_audio(audio_id: str, new_name: str, user_id: str) -> dict:
    """Rename a reference audio by updating the display name in its metadata.

    The audio object itself is not moved; only ``original_filename`` in the
    sibling ``.json`` metadata is rewritten.  If the existing metadata cannot
    be read, a fresh metadata document is created.

    Args:
        audio_id: Storage path of the audio; must start with ``"<user_id>/"``
            and end with ``".wav"``.
        new_name: New display name; ``".wav"`` is appended when it has no
            extension.
        user_id: Caller's id, used for the ownership check.

    Returns:
        ``{"name": <final display name>}``.

    Raises:
        PermissionError: *audio_id* is not under the caller's prefix.
        ValueError: *new_name* is blank, or *audio_id* is not a ``.wav`` key.
    """
    if not audio_id.startswith(f"{user_id}/"):
        raise PermissionError("无权修改此文件")

    new_name = new_name.strip()
    if not new_name:
        raise ValueError("新名称不能为空")

    # Without this guard the metadata path would equal the audio path for a
    # non-.wav id, and the JSON upload below would overwrite the audio object.
    if not audio_id.endswith(".wav"):
        raise ValueError("无效的音频 ID")

    if not Path(new_name).suffix:
        new_name += ".wav"

    # Swap only the trailing ".wav"; str.replace would also rewrite a ".wav"
    # occurring earlier in the key.
    metadata_path = audio_id[:-len(".wav")] + ".json"

    # Download the existing metadata.
    try:
        metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
        async with httpx.AsyncClient() as client:
            resp = await client.get(metadata_url)
            if resp.status_code != 200:
                raise RuntimeError(f"Failed to fetch metadata: {resp.status_code}")
            metadata = resp.json()
        if not isinstance(metadata, dict):
            raise TypeError(f"Unexpected metadata payload: {type(metadata).__name__}")
    except Exception as e:
        logger.warning(f"无法读取元数据: {e}, 将创建新的元数据")
        metadata = {
            "ref_text": "",
            "duration_sec": 0.0,
            "created_at": int(time.time()),
            "original_filename": new_name
        }

    # Update the display name and overwrite the stored metadata.
    metadata["original_filename"] = new_name
    await storage_service.upload_file(
        bucket=BUCKET_REF_AUDIOS,
        path=metadata_path,
        file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
        content_type="application/json"
    )

    return {"name": new_name}
|
||
|
||
|
||
async def retranscribe_ref_audio(audio_id: str, user_id: str) -> dict:
    """Re-transcribe a stored reference audio and refresh its metadata.

    Downloads the audio, trims it to at most 10 seconds when longer
    (re-uploading the trimmed version over the original), runs Whisper on it
    and writes the new ``ref_text`` / ``duration_sec`` into the metadata JSON.
    Intended for migrating older data.

    Args:
        audio_id: Storage path of the audio; must start with ``"<user_id>/"``
            and end with ``".wav"``.
        user_id: Caller's id, used for the ownership check.

    Returns:
        ``{"ref_text": ..., "duration_sec": ...}`` for the (possibly trimmed)
        clip.

    Raises:
        PermissionError: *audio_id* is not under the caller's prefix.
        ValueError: *audio_id* is not a ``.wav`` key, or Whisper produced no
            text.
    """
    if not audio_id.startswith(f"{user_id}/"):
        raise PermissionError("无权修改此文件")

    # Without this guard the metadata path would equal the audio path for a
    # non-.wav id, and the metadata upload would overwrite the audio object.
    if not audio_id.endswith(".wav"):
        raise ValueError("无效的音频 ID")
    # Swap only the trailing ".wav"; str.replace would also rewrite a ".wav"
    # occurring earlier in the key.
    metadata_path = audio_id[:-len(".wav")] + ".json"

    # Download the audio into a temp file.
    audio_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, audio_id)
    tmp_wav_path = None
    trimmed_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp_wav_path = tmp.name
            # NOTE(review): all timeouts are disabled for this download,
            # presumably to tolerate slow transfers; confirm no bound is wanted.
            timeout = httpx.Timeout(None)
            async with httpx.AsyncClient(timeout=timeout) as client:
                async with client.stream("GET", audio_url) as resp:
                    resp.raise_for_status()
                    async for chunk in resp.aiter_bytes():
                        tmp.write(chunk)

        # Longer than 10 seconds: trim and remember to re-upload the audio.
        MAX_REF_DURATION = 10.0
        duration = _get_audio_duration(tmp_wav_path)
        transcribe_path = tmp_wav_path
        need_reupload = False

        if duration > MAX_REF_DURATION:
            cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
            logger.info(f"Retranscribe: trimming {audio_id} from {duration:.1f}s at {cut_point:.1f}s")
            trimmed_path = tmp_wav_path + "_trimmed.wav"
            # If trimming fails, carry on with the untrimmed download.
            if _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
                transcribe_path = trimmed_path
                duration = _get_audio_duration(trimmed_path)
                need_reupload = True

        # Whisper transcription.
        from app.services.whisper_service import whisper_service
        transcribed = await whisper_service.transcribe(transcribe_path)
        if not transcribed or not transcribed.strip():
            raise ValueError("无法识别音频内容")

        ref_text = transcribed.strip()
        logger.info(f"Re-transcribed ref audio {audio_id}: {ref_text[:50]}...")

        # Overwrite the stored audio with the trimmed version.
        if need_reupload and trimmed_path:
            with open(trimmed_path, "rb") as f:
                await storage_service.upload_file(
                    bucket=BUCKET_REF_AUDIOS, path=audio_id,
                    file_data=f.read(), content_type="audio/wav",
                )
            logger.info(f"Re-uploaded trimmed audio: {audio_id} ({duration:.1f}s)")

        # Load the existing metadata (best effort) and update it.
        try:
            meta_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
            async with httpx.AsyncClient(timeout=5.0) as client:
                resp = await client.get(meta_url)
                if resp.status_code != 200:
                    raise RuntimeError(f"status {resp.status_code}")
                metadata = resp.json()
            if not isinstance(metadata, dict):
                raise TypeError(f"unexpected metadata payload: {type(metadata).__name__}")
        except Exception:
            # Unreadable metadata: start from an empty document.
            metadata = {}

        metadata["ref_text"] = ref_text
        metadata["duration_sec"] = duration
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        return {"ref_text": ref_text, "duration_sec": duration}
    finally:
        if tmp_wav_path and os.path.exists(tmp_wav_path):
            os.unlink(tmp_wav_path)
        if trimmed_path and os.path.exists(trimmed_path):
            os.unlink(trimmed_path)
|