277 lines
8.8 KiB
Python
277 lines
8.8 KiB
Python
"""
|
|
参考音频管理 API
|
|
支持上传/列表/删除参考音频,用于 Qwen3-TTS 声音克隆
|
|
"""
|
|
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
|
|
from pydantic import BaseModel
|
|
from typing import List, Optional
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
import time
|
|
import json
|
|
import subprocess
|
|
import tempfile
|
|
import os
|
|
import re
|
|
|
|
from app.core.deps import get_current_user
|
|
from app.services.storage import storage_service
|
|
|
|
# Router that the reference-audio endpoints in this module are registered on.
router = APIRouter()
|
|
|
|
# Audio file extensions accepted for upload (lower-case, including the dot).
ALLOWED_AUDIO_EXTENSIONS = {
    '.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac',
}

# Storage bucket that holds the reference audio objects and their metadata.
BUCKET_REF_AUDIOS = "ref-audios"
|
|
|
|
|
|
class RefAudioResponse(BaseModel):
    """Response payload describing a single stored reference audio clip."""

    # Identifier of the clip; the endpoints in this module fill it with the
    # object's storage path.
    id: str
    # Display name: the uploaded filename on upload, the stored object name
    # when listing.
    name: str
    # Signed URL for playback.
    path: str
    # Transcript text that accompanies the clip.
    ref_text: str
    # Clip length in seconds.
    duration_sec: float
    # Creation time as an integer timestamp.
    created_at: int
|
|
|
|
|
|
class RefAudioListResponse(BaseModel):
    """Response payload wrapping a list of reference audio entries."""

    # Entries returned to the client.
    items: List[RefAudioResponse]
|
|
|
|
|
|
def sanitize_filename(filename: str, max_length: int = 50) -> str:
    """Return a storage-safe version of *filename*.

    Characters that are problematic in paths (``< > : " / \\ | ? *``) and all
    whitespace are replaced with underscores. Names longer than *max_length*
    are truncated; the file extension is preserved when it fits.

    Args:
        filename: The name to clean.
        max_length: Maximum length of the returned name (default 50).

    Returns:
        The cleaned name, never longer than *max_length* characters.
    """
    safe_name = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
    if len(safe_name) <= max_length:
        return safe_name

    ext = Path(safe_name).suffix
    if len(ext) >= max_length:
        # The "extension" alone would exceed the limit, so keeping it is
        # impossible. Subtracting its length would give a negative slice
        # bound and a result longer than the input; truncate plainly instead.
        return safe_name[:max(max_length, 0)]
    return safe_name[:max_length - len(ext)] + ext
|
|
|
|
|
|
def get_audio_duration(file_path: str) -> float:
    """Return the duration of an audio file in seconds.

    Runs ``ffprobe`` asking for the ``format=duration`` entry and parses its
    standard output as a float.

    Args:
        file_path: Path of the audio file to probe.

    Returns:
        The duration in seconds, or ``0.0`` when running ffprobe or parsing
        its output fails for any reason (the failure is logged as a warning).
    """
    probe_cmd = [
        'ffprobe', '-v', 'quiet',
        '-show_entries', 'format=duration',
        '-of', 'csv=p=0',
        file_path,
    ]
    try:
        probe = subprocess.run(probe_cmd, capture_output=True, text=True, timeout=10)
        raw_duration = probe.stdout.strip()
        return float(raw_duration)
    except Exception as e:
        # Best effort: callers treat 0.0 as "unknown / too short".
        logger.warning(f"获取音频时长失败: {e}")
        return 0.0
|
|
|
|
|
|
def convert_to_wav(input_path: str, output_path: str) -> bool:
    """Transcode an audio file to 16 kHz, mono, 16-bit PCM WAV using ffmpeg.

    Args:
        input_path: Path of the source audio file.
        output_path: Path of the WAV file to write; ffmpeg is run with ``-y``
            so an existing file is overwritten.

    Returns:
        ``True`` when ffmpeg completes successfully, ``False`` when it fails,
        times out (60 s) or cannot be started (the error is logged).
    """
    ffmpeg_cmd = [
        'ffmpeg', '-y', '-i', input_path,
        '-ar', '16000',          # 16 kHz sample rate
        '-ac', '1',              # mono
        '-acodec', 'pcm_s16le',  # 16-bit PCM
        output_path,
    ]
    try:
        subprocess.run(ffmpeg_cmd, capture_output=True, timeout=60, check=True)
    except Exception as e:
        logger.error(f"音频转换失败: {e}")
        return False
    return True
|
|
|
|
|
|
@router.post("", response_model=RefAudioResponse)
async def upload_ref_audio(
    file: UploadFile = File(...),
    ref_text: str = Form(...),
    user: dict = Depends(get_current_user)
):
    """
    Upload a reference audio clip for the current user.

    The upload is normalised to WAV with ``convert_to_wav``, must be between
    1 and 60 seconds long, and is stored in the reference-audio bucket next
    to a JSON metadata object that carries the transcript.

    - file: audio file (wav, mp3, m4a, webm, ... see ALLOWED_AUDIO_EXTENSIONS)
    - ref_text: transcript of the reference audio (required)

    Returns a ``RefAudioResponse`` whose ``id`` is the storage path of the WAV
    object and whose ``path`` is a signed URL for it.

    Raises ``HTTPException`` 400 for an unsupported extension, a missing or
    too short transcript, or an out-of-range duration; 500 when conversion or
    storage fails.
    """
    user_id = user["id"]

    # The filename can be None when the client sends no filename; treat it as
    # empty so the request is rejected with a 400 below instead of a TypeError.
    original_name = file.filename or ""

    # Validate the file extension.
    ext = Path(original_name).suffix.lower()
    if ext not in ALLOWED_AUDIO_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            # sorted() keeps the message stable; set order varies per process.
            detail=f"不支持的音频格式: {ext}。支持的格式: {', '.join(sorted(ALLOWED_AUDIO_EXTENSIONS))}"
        )

    # Validate the transcript.
    transcript = ref_text.strip() if ref_text else ""
    if len(transcript) < 2:
        raise HTTPException(status_code=400, detail="参考文字不能为空")

    # Tracked outside the try so the finally clause can always clean up.
    tmp_input_path = None
    tmp_wav_path = None

    # NOTE(review): convert_to_wav / get_audio_duration run blocking
    # subprocesses inside this async handler; consider moving them to a
    # worker thread if request latency becomes a problem.
    try:
        # Spool the upload to a temporary file that ffmpeg can read.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
            tmp_input_path = tmp_input.name
            tmp_input.write(await file.read())

        # Always normalise, even for .wav input, and fail loudly if it does
        # not work. Previously a failed normalisation of a .wav upload was
        # ignored and surfaced as a misleading "too short" error.
        tmp_wav_path = tmp_input_path + ".wav"
        if not convert_to_wav(tmp_input_path, tmp_wav_path):
            raise HTTPException(status_code=500, detail="音频格式转换失败")

        # Enforce the duration limits.
        duration = get_audio_duration(tmp_wav_path)
        if duration < 1.0:
            raise HTTPException(status_code=400, detail="音频时长过短,至少需要 1 秒")
        if duration > 60.0:
            raise HTTPException(status_code=400, detail="音频时长过长,最多 60 秒")

        # Build the storage paths: <user_id>/<timestamp>_<name>.{wav,json}
        timestamp = int(time.time())
        safe_name = sanitize_filename(Path(original_name).stem)
        object_stem = f"{user_id}/{timestamp}_{safe_name}"
        storage_path = f"{object_stem}.wav"
        metadata_path = f"{object_stem}.json"

        # Upload the WAV data.
        with open(tmp_wav_path, 'rb') as f:
            wav_data = f.read()

        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=storage_path,
            file_data=wav_data,
            content_type="audio/wav"
        )

        # Upload the metadata JSON next to it.
        metadata = {
            "ref_text": transcript,
            "original_filename": original_name,
            "duration_sec": duration,
            "created_at": timestamp
        }
        await storage_service.upload_file(
            bucket=BUCKET_REF_AUDIOS,
            path=metadata_path,
            file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
            content_type="application/json"
        )

        # Signed URL so the client can play the clip back.
        signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

        return RefAudioResponse(
            id=storage_path,
            name=original_name,
            path=signed_url,
            ref_text=transcript,
            duration_sec=duration,
            created_at=timestamp
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"上传参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}") from e
    finally:
        # Remove the temporary files on every exit path; the original only
        # cleaned up after a fully successful upload.
        for tmp_path in (tmp_input_path, tmp_wav_path):
            if tmp_path is None:
                continue
            try:
                os.unlink(tmp_path)
            except FileNotFoundError:
                pass  # e.g. conversion failed before writing the output
            except OSError as cleanup_error:
                logger.warning(f"清理临时文件失败: {cleanup_error}")
|
|
|
|
|
|
@router.get("", response_model=RefAudioListResponse)
async def list_ref_audios(user: dict = Depends(get_current_user)):
    """
    List all reference audio clips of the current user, newest first.

    Every ``.wav`` object in the user's folder of the reference-audio bucket
    becomes one entry. Transcript, duration and creation time are read from
    the sibling ``.json`` metadata object when it can be fetched; otherwise
    defaults are used and the creation time is taken from the leading
    ``<timestamp>_`` part of the object name when present.

    Raises ``HTTPException`` 500 when listing or URL signing fails.
    """
    user_id = user["id"]
    wav_suffix = ".wav"

    try:
        import httpx

        # List the objects in the user's folder.
        files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)

        items = []
        # One shared client for all metadata downloads instead of one per file.
        async with httpx.AsyncClient() as client:
            for entry in files or []:
                name = entry.get("name") or ""
                if not name.endswith(wav_suffix):
                    continue

                storage_path = f"{user_id}/{name}"
                # Swap only the trailing suffix. str.replace would rewrite
                # every ".wav" in the name (e.g. "x.wav.wav") and miss the
                # metadata object.
                metadata_path = f"{user_id}/{name[:-len(wav_suffix)]}.json"

                ref_text = ""
                duration_sec = 0.0
                created_at = 0

                # Metadata is best effort: a missing or broken JSON must not
                # hide the audio entry itself.
                try:
                    metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
                    resp = await client.get(metadata_url)
                    if resp.status_code == 200:
                        metadata = resp.json()
                        # "or" guards against explicit nulls in the JSON.
                        ref_text = metadata.get("ref_text") or ""
                        duration_sec = metadata.get("duration_sec") or 0.0
                        created_at = metadata.get("created_at") or 0
                except Exception as e:
                    logger.warning(f"读取 metadata 失败: {e}")

                # Fall back to the timestamp in the object name whenever the
                # metadata gave none (fetch error, non-200 response or a
                # missing field).
                if not created_at:
                    try:
                        created_at = int(name.split("_")[0])
                    except ValueError:
                        created_at = 0

                # Signed URL for playback.
                signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)

                items.append(RefAudioResponse(
                    id=storage_path,
                    name=name,
                    path=signed_url,
                    ref_text=ref_text,
                    duration_sec=duration_sec,
                    created_at=created_at
                ))

        # Newest first.
        items.sort(key=lambda x: x.created_at, reverse=True)

        return RefAudioListResponse(items=items)

    except Exception as e:
        logger.error(f"列出参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}") from e
|
|
|
|
|
|
@router.delete("/{audio_id:path}")
async def delete_ref_audio(audio_id: str, user: dict = Depends(get_current_user)):
    """
    Delete one of the current user's reference audio objects.

    ``audio_id`` is the storage path of the object and must lie inside the
    caller's own ``<user_id>/`` folder. When it names a ``.wav`` object the
    sibling ``.json`` metadata object is removed as well (best effort).

    Returns ``{"success": True, "message": ...}`` on success.

    Raises ``HTTPException`` 403 when the path is outside the caller's folder
    and 500 when the storage deletion fails.
    """
    user_id = user["id"]

    # Ownership check: only paths inside the caller's folder may be deleted.
    if not audio_id.startswith(f"{user_id}/"):
        raise HTTPException(status_code=403, detail="无权删除此文件")

    # The prefix test alone is satisfied by "<user_id>/../<other>/x.wav";
    # reject traversal and empty segments so the path cannot leave the folder.
    if any(segment in ("", ".", "..") for segment in audio_id.split("/")):
        raise HTTPException(status_code=403, detail="无权删除此文件")

    try:
        # Delete the audio object itself.
        await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)

        # Delete the metadata JSON. Only the trailing suffix is swapped:
        # str.replace would rewrite every ".wav" in the path, and for a
        # non-wav id it would return the same path and delete it twice.
        wav_suffix = ".wav"
        if audio_id.endswith(wav_suffix):
            metadata_path = audio_id[:-len(wav_suffix)] + ".json"
            try:
                await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
            except Exception as e:
                # The metadata may not exist; log instead of failing. Unlike a
                # bare except, this no longer swallows task cancellation.
                logger.warning(f"删除 metadata 失败: {e}")

        return {"success": True, "message": "删除成功"}

    except Exception as e:
        logger.error(f"删除参考音频失败: {e}")
        raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}") from e
|