412 lines
15 KiB
Python
412 lines
15 KiB
Python
"""
|
||
Reference audio management API.

Supports uploading, listing, and deleting reference audio used for Qwen3-TTS voice cloning.
|
||
"""
|
||
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
from pathlib import Path
|
||
from loguru import logger
|
||
import time
|
||
import json
|
||
import subprocess
|
||
import tempfile
|
||
import os
|
||
import re
|
||
|
||
from app.core.deps import get_current_user
|
||
from app.services.storage import storage_service
|
||
|
||
router = APIRouter()

# Supported audio formats
ALLOWED_AUDIO_EXTENSIONS = {
    '.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac',
}

# Storage bucket that holds the reference audio objects
BUCKET_REF_AUDIOS = "ref-audios"
|
||
|
||
|
||
class RefAudioResponse(BaseModel):
    """API representation of a single reference audio item."""

    # Identifier of the audio item
    id: str
    # Name shown to the user
    name: str
    # Signed URL for playback
    path: str
    # Transcript of the reference audio
    ref_text: str
    # Clip length in seconds
    duration_sec: float
    # Creation time as an integer timestamp
    created_at: int
|
||
|
||
|
||
class RefAudioListResponse(BaseModel):
    """Envelope for a list of reference audio items."""

    # The reference audio entries carried by this response
    items: List[RefAudioResponse]
|
||
|
||
|
||
def sanitize_filename(filename: str) -> str:
    """Return *filename* with unsafe characters replaced and length capped.

    Each of ``< > : " / \\ | ? *`` and every whitespace character is replaced
    by an underscore. Results longer than 50 characters are shortened to
    exactly 50 characters, keeping the file suffix when it fits.

    Args:
        filename: Raw file name (or stem) to clean up.

    Returns:
        The cleaned name, never longer than 50 characters.
    """
    max_length = 50
    safe_name = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
    if len(safe_name) <= max_length:
        return safe_name

    ext = Path(safe_name).suffix
    if len(ext) > max_length:
        # The "suffix" alone exceeds the budget (e.g. "a.<60 chars>").
        # Slicing with a negative index would keep almost the whole name and
        # then append the suffix again, so fall back to a hard cut instead.
        return safe_name[:max_length]
    return safe_name[:max_length - len(ext)] + ext
|
||
|
||
|
||
def get_audio_duration(file_path: str) -> float:
    """Return the duration of an audio file in seconds.

    Runs ``ffprobe`` (10 second timeout) and parses its stdout as a float.
    Any failure is logged as a warning and reported as ``0.0``.
    """
    probe_cmd = [
        'ffprobe', '-v', 'quiet', '-show_entries', 'format=duration',
        '-of', 'csv=p=0', file_path,
    ]
    try:
        probe = subprocess.run(
            probe_cmd,
            capture_output=True,
            text=True,
            timeout=10,
        )
        duration_text = probe.stdout.strip()
        return float(duration_text)
    except Exception as e:
        logger.warning(f"获取音频时长失败: {e}")
        return 0.0
|
||
|
||
|
||
def convert_to_wav(input_path: str, output_path: str) -> bool:
    """Convert an audio file to WAV (16 kHz, mono, 16-bit PCM) with ffmpeg.

    Returns ``True`` when ffmpeg exits successfully; any failure (non-zero
    exit, timeout after 60 seconds, missing binary, ...) is logged and
    reported as ``False``.
    """
    ffmpeg_cmd = [
        'ffmpeg', '-y', '-i', input_path,
        '-ar', '16000',          # 16 kHz sample rate
        '-ac', '1',              # mono
        '-acodec', 'pcm_s16le',  # 16-bit PCM
        output_path,
    ]
    try:
        subprocess.run(ffmpeg_cmd, capture_output=True, timeout=60, check=True)
    except Exception as e:
        logger.error(f"音频转换失败: {e}")
        return False
    return True
|
||
|
||
|
||
@router.post("", response_model=RefAudioResponse)
async def upload_ref_audio(
    file: UploadFile = File(...),
    ref_text: str = Form(...),
    user: dict = Depends(get_current_user)
):
    """
    Upload a reference audio clip.

    - file: audio file (wav, mp3, m4a, webm, ... see ALLOWED_AUDIO_EXTENSIONS)
    - ref_text: transcript of the reference audio (required)

    The upload is converted to WAV with ``convert_to_wav``, must last between
    1 and 60 seconds, and is stored as ``{user_id}/{timestamp}_{safe_name}.wav``
    together with a sibling ``.json`` metadata object.

    Returns:
        RefAudioResponse whose ``id`` is the storage path, ``name`` is the
        display name written to the metadata, and ``path`` is a signed URL.

    Raises:
        HTTPException: 400 for an unsupported extension, a too-short
            transcript, or a duration outside 1-60 seconds; 500 when the
            conversion, the storage calls, or any other step fails.
    """
    user_id = user["id"]

    # Validate the extension. A missing filename yields an empty extension and
    # is rejected here instead of crashing in Path().
    original_name = file.filename or ""
    ext = Path(original_name).suffix.lower()
    if ext not in ALLOWED_AUDIO_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的音频格式: {ext}。支持的格式: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
        )

    # Validate ref_text
    if not ref_text or len(ref_text.strip()) < 2:
        raise HTTPException(status_code=400, detail="参考文字不能为空")

    # Tracked outside the try block so the finally clause can always clean up.
    tmp_input_path = None
    tmp_wav_path = None

    try:
        # Spool the upload to a temporary file for ffmpeg/ffprobe.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
            tmp_input_path = tmp_input.name
            tmp_input.write(await file.read())

        # Always run the conversion: WAV uploads are normalised to the same
        # format as everything else, and a failed conversion is reported as
        # such rather than surfacing later as a bogus "too short" error.
        tmp_wav_path = tmp_input_path + ".wav"
        if not convert_to_wav(tmp_input_path, tmp_wav_path):
            raise HTTPException(status_code=500, detail="音频格式转换失败")

        # Check the clip length
        duration = get_audio_duration(tmp_wav_path)
        if duration < 1.0:
            raise HTTPException(status_code=400, detail="音频时长过短,至少需要 1 秒")
        if duration > 60.0:
            raise HTTPException(status_code=400, detail="音频时长过长,最多 60 秒")

        # Storage paths: the timestamp prefix keeps stored objects apart.
        timestamp = int(time.time())
        safe_name = sanitize_filename(Path(original_name).stem)
        storage_path = f"{user_id}/{timestamp}_{safe_name}.wav"
        metadata_path = f"{user_id}/{timestamp}_{safe_name}.json"

        # Friendly display name. There is no database, and reading every
        # metadata JSON just to learn the existing display names would be too
        # expensive, so duplicates are detected from the stored object names:
        # an earlier upload with the same sanitised stem is stored as
        # "<digits>_<safe_name>.wav". The match is on the stored name (not on
        # the raw upload name), because that is the only form that exists in
        # the bucket. This is a heuristic: numbering can repeat after deletes.
        existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
        stored_name_pattern = re.compile(rf"\d+_{re.escape(safe_name)}\.wav")
        dup_count = sum(
            1
            for f in existing_files or []
            if stored_name_pattern.fullmatch(f.get('name') or '')
        )

        final_display_name = original_name
        if dup_count > 0:
            name_stem = Path(original_name).stem
            name_ext = Path(original_name).suffix
            final_display_name = f"{name_stem}({dup_count}){name_ext}"

        # Upload the WAV file
        with open(tmp_wav_path, 'rb') as wav_file:
            wav_data = wav_file.read()

        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=storage_path,
            file_data=wav_data,
            content_type="audio/wav"
        )

        # Upload the metadata JSON
        metadata = {
            "ref_text": ref_text.strip(),
            "original_filename": final_display_name,  # gets "(n)" appended on duplicates
            "duration_sec": duration,
            "created_at": timestamp
        }
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        # Signed URL for playback
        signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

        return RefAudioResponse(
            id=storage_path,
            # Same name that was written to the metadata, so the upload
            # response and later list responses agree.
            name=final_display_name,
            path=signed_url,
            ref_text=ref_text.strip(),
            duration_sec=duration,
            created_at=timestamp
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"上传参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
    finally:
        # Remove the temporary files on success and on every error path.
        for tmp_path in (tmp_input_path, tmp_wav_path):
            if tmp_path and os.path.exists(tmp_path):
                try:
                    os.unlink(tmp_path)
                except OSError as cleanup_error:
                    logger.warning(f"清理临时文件失败: {cleanup_error}")
|
||
|
||
|
||
@router.get("", response_model=RefAudioListResponse)
async def list_ref_audios(user: dict = Depends(get_current_user)):
    """List every reference audio item of the current user.

    Each ``.wav`` object under the user's folder becomes one item. Its sibling
    ``.json`` metadata object is fetched through a signed URL; when that
    cannot be read the item is still returned with empty defaults and a
    creation time taken from the numeric prefix of the file name.

    Returns:
        RefAudioListResponse with the items sorted newest first.

    Raises:
        HTTPException: 500 when listing the files or signing a URL fails.
    """
    user_id = user["id"]

    try:
        # httpx is imported lazily inside the handler.
        import httpx

        # List the objects in the user's folder
        files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)

        items = []
        # One client for the whole listing instead of one per file.
        async with httpx.AsyncClient() as client:
            for f in files or []:
                name = f.get("name") or ""
                if not name.endswith(".wav"):
                    continue

                storage_path = f"{user_id}/{name}"

                # Swap only the trailing ".wav": str.replace would also rewrite
                # a ".wav" that appears inside the stem.
                metadata_path = f"{user_id}/{name[:-len('.wav')]}.json"

                # Best-effort metadata read; failures only degrade the item.
                metadata = None
                try:
                    metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
                    resp = await client.get(metadata_url)
                    if resp.status_code == 200:
                        loaded = resp.json()
                        if isinstance(loaded, dict):
                            metadata = loaded
                except Exception as e:
                    logger.warning(f"读取 metadata 失败: {e}")

                ref_text = ""
                duration_sec = 0.0
                created_at = 0
                original_filename = ""

                if metadata is not None:
                    # "or" guards against explicit nulls in the JSON.
                    ref_text = metadata.get("ref_text") or ""
                    duration_sec = metadata.get("duration_sec") or 0.0
                    created_at = metadata.get("created_at") or 0
                    original_filename = metadata.get("original_filename") or ""
                else:
                    # No usable metadata (error or non-200 response): recover
                    # the creation time from the "<timestamp>_" name prefix.
                    try:
                        created_at = int(name.split("_")[0])
                    except ValueError:
                        pass

                # Signed URL for the audio itself
                signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

                # Prefer the original file name; otherwise strip the timestamp
                # prefix from the stored name ("1234567890_filename.wav").
                display_name = original_filename if original_filename else name
                if not display_name or display_name == name:
                    match = re.match(r'^\d+_(.+)$', name)
                    if match:
                        display_name = match.group(1)

                items.append(RefAudioResponse(
                    id=storage_path,
                    name=display_name,
                    path=signed_url,
                    ref_text=ref_text,
                    duration_sec=duration_sec,
                    created_at=created_at
                ))

        # Newest first
        items.sort(key=lambda x: x.created_at, reverse=True)

        return RefAudioListResponse(items=items)

    except Exception as e:
        logger.error(f"列出参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
|
||
|
||
|
||
@router.delete("/{audio_id:path}")
async def delete_ref_audio(audio_id: str, user: dict = Depends(get_current_user)):
    """Delete a reference audio object and its metadata JSON.

    Args:
        audio_id: Storage path of the audio, expected to start with the
            caller's ``{user_id}/`` prefix.

    Returns:
        ``{"success": True, "message": ...}`` once the audio is deleted.
        Failure to delete the metadata object is logged and ignored.

    Raises:
        HTTPException: 403 when the path is not inside the caller's folder;
            500 when deleting the audio object fails.
    """
    user_id = user["id"]

    # Ownership check: users may only delete their own files. ".." segments
    # are refused as well, because such an id passes the prefix test while
    # naming a location outside the user's folder (whether the storage
    # backend would resolve it is not visible from here).
    if not audio_id.startswith(f"{user_id}/") or ".." in audio_id.split("/"):
        raise HTTPException(status_code=403, detail="无权删除此文件")

    try:
        # Delete the WAV object
        await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)

        # Delete the sibling metadata JSON. Only the trailing ".wav" is
        # swapped; str.replace would also rewrite a ".wav" inside the stem.
        if audio_id.endswith(".wav"):
            metadata_path = audio_id[:-len(".wav")] + ".json"
            try:
                await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
            except Exception as e:
                # The metadata may not exist; keep this best-effort but visible.
                logger.warning(f"删除 metadata 失败 (已忽略): {e}")

        return {"success": True, "message": "删除成功"}

    except Exception as e:
        logger.error(f"删除参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
|
||
|
||
|
||
class RenameRequest(BaseModel):
    """Request body carrying the new name for a rename."""

    # The requested new name
    new_name: str
|
||
|
||
|
||
@router.put("/{audio_id:path}")
async def rename_ref_audio(
    audio_id: str,
    request: RenameRequest,
    user: dict = Depends(get_current_user)
):
    """Rename a reference audio by updating the display name in its metadata.

    The stored audio object is not touched; only ``original_filename`` in the
    sibling ``.json`` metadata object is rewritten. When the existing
    metadata cannot be read, a minimal replacement is written instead.

    Args:
        audio_id: Storage path of the audio (``{user_id}/....wav``).
        request: Body carrying ``new_name``; ``.wav`` is appended when the
            name has no suffix.

    Returns:
        ``{"success": True, "name": <final name>}``.

    Raises:
        HTTPException: 403 when the path is not inside the caller's folder;
            400 for an empty name or an id that is not a ``.wav`` path;
            500 when writing the metadata fails.
    """
    user_id = user["id"]

    # Ownership check. ".." segments are refused too, since such an id passes
    # the prefix test while naming a location outside the user's folder.
    if not audio_id.startswith(f"{user_id}/") or ".." in audio_id.split("/"):
        raise HTTPException(status_code=403, detail="无权修改此文件")

    new_name = request.new_name.strip()
    if not new_name:
        raise HTTPException(status_code=400, detail="新名称不能为空")

    # Without a trailing ".wav" the metadata path would equal the audio path
    # and the JSON upload below would overwrite that object.
    if not audio_id.endswith(".wav"):
        raise HTTPException(status_code=400, detail="无效的音频 ID")

    # Make sure the new name has a suffix (keep the given one or add .wav)
    if not Path(new_name).suffix:
        new_name += ".wav"

    try:
        # 1. Fetch the existing metadata. Only the trailing ".wav" is swapped;
        #    str.replace would also rewrite a ".wav" inside the stem.
        metadata_path = audio_id[:-len(".wav")] + ".json"
        try:
            import httpx
            metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
            if not metadata_url:
                raise Exception("Metadata not found")

            async with httpx.AsyncClient() as client:
                resp = await client.get(metadata_url)
                if resp.status_code != 200:
                    raise Exception(f"Failed to fetch metadata: {resp.status_code}")
                metadata = resp.json()

            if not isinstance(metadata, dict):
                raise Exception("Metadata is not a JSON object")

        except Exception as e:
            logger.warning(f"无法读取元数据: {e}, 将创建新的元数据")
            # Fallback: build minimal metadata. The creation time is taken
            # from the "<timestamp>_" prefix of the stored name when present,
            # so the item keeps its place in time-sorted listings.
            # NOTE(review): a transient read failure also lands here and the
            # rewrite then drops the stored ref_text - confirm this is wanted.
            file_name = audio_id.rsplit("/", 1)[-1]
            ts_match = re.match(r'^(\d+)_', file_name)
            metadata = {
                "ref_text": "",  # may be lost
                "duration_sec": 0.0,
                "created_at": int(ts_match.group(1)) if ts_match else int(time.time()),
                "original_filename": new_name
            }

        # 2. Update the display name
        metadata["original_filename"] = new_name

        # 3. Overwrite the metadata object
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        return {"success": True, "name": new_name}

    except Exception as e:
        logger.error(f"重命名失败: {e}")
        raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
|