This commit is contained in:
Kevin Wong
2026-01-29 12:16:41 +08:00
parent 4a3dd2b225
commit 661a8f357c
18 changed files with 2092 additions and 80 deletions

View File

@@ -0,0 +1,276 @@
"""
参考音频管理 API
支持上传/列表/删除参考音频,用于 Qwen3-TTS 声音克隆
"""
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends
from pydantic import BaseModel
from typing import List, Optional
from pathlib import Path
from loguru import logger
import time
import json
import subprocess
import tempfile
import os
import re
from app.core.deps import get_current_user
from app.services.storage import storage_service
router = APIRouter()
# 支持的音频格式
ALLOWED_AUDIO_EXTENSIONS = {'.wav', '.mp3', '.m4a', '.webm', '.ogg', '.flac', '.aac'}
# 参考音频 bucket
BUCKET_REF_AUDIOS = "ref-audios"
class RefAudioResponse(BaseModel):
id: str
name: str
path: str # signed URL for playback
ref_text: str
duration_sec: float
created_at: int
class RefAudioListResponse(BaseModel):
items: List[RefAudioResponse]
def sanitize_filename(filename: str) -> str:
"""清理文件名,移除特殊字符"""
safe_name = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
if len(safe_name) > 50:
ext = Path(safe_name).suffix
safe_name = safe_name[:50 - len(ext)] + ext
return safe_name
def get_audio_duration(file_path: str) -> float:
"""获取音频时长 (秒)"""
try:
result = subprocess.run(
['ffprobe', '-v', 'quiet', '-show_entries', 'format=duration',
'-of', 'csv=p=0', file_path],
capture_output=True, text=True, timeout=10
)
return float(result.stdout.strip())
except Exception as e:
logger.warning(f"获取音频时长失败: {e}")
return 0.0
def convert_to_wav(input_path: str, output_path: str) -> bool:
"""将音频转换为 WAV 格式 (16kHz, mono)"""
try:
subprocess.run([
'ffmpeg', '-y', '-i', input_path,
'-ar', '16000', # 16kHz 采样率
'-ac', '1', # 单声道
'-acodec', 'pcm_s16le', # 16-bit PCM
output_path
], capture_output=True, timeout=60, check=True)
return True
except Exception as e:
logger.error(f"音频转换失败: {e}")
return False
@router.post("", response_model=RefAudioResponse)
async def upload_ref_audio(
file: UploadFile = File(...),
ref_text: str = Form(...),
user: dict = Depends(get_current_user)
):
"""
上传参考音频
- file: 音频文件 (支持 wav, mp3, m4a, webm 等)
- ref_text: 参考音频的转写文字 (必填)
"""
user_id = user["id"]
# 验证文件扩展名
ext = Path(file.filename).suffix.lower()
if ext not in ALLOWED_AUDIO_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"不支持的音频格式: {ext}。支持的格式: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}"
)
# 验证 ref_text
if not ref_text or len(ref_text.strip()) < 2:
raise HTTPException(status_code=400, detail="参考文字不能为空")
try:
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
content = await file.read()
tmp_input.write(content)
tmp_input_path = tmp_input.name
# 转换为 WAV 格式
tmp_wav_path = tmp_input_path + ".wav"
if ext != '.wav':
if not convert_to_wav(tmp_input_path, tmp_wav_path):
raise HTTPException(status_code=500, detail="音频格式转换失败")
else:
# 即使是 wav 也要标准化格式
convert_to_wav(tmp_input_path, tmp_wav_path)
# 获取音频时长
duration = get_audio_duration(tmp_wav_path)
if duration < 1.0:
raise HTTPException(status_code=400, detail="音频时长过短,至少需要 1 秒")
if duration > 60.0:
raise HTTPException(status_code=400, detail="音频时长过长,最多 60 秒")
# 生成存储路径
timestamp = int(time.time())
safe_name = sanitize_filename(Path(file.filename).stem)
storage_path = f"{user_id}/{timestamp}_{safe_name}.wav"
# 上传 WAV 文件到 Supabase
with open(tmp_wav_path, 'rb') as f:
wav_data = f.read()
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=storage_path,
file_data=wav_data,
content_type="audio/wav"
)
# 上传元数据 JSON
metadata = {
"ref_text": ref_text.strip(),
"original_filename": file.filename,
"duration_sec": duration,
"created_at": timestamp
}
metadata_path = f"{user_id}/{timestamp}_{safe_name}.json"
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=metadata_path,
file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
content_type="application/json"
)
# 获取签名 URL
signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)
# 清理临时文件
os.unlink(tmp_input_path)
if os.path.exists(tmp_wav_path):
os.unlink(tmp_wav_path)
return RefAudioResponse(
id=storage_path,
name=file.filename,
path=signed_url,
ref_text=ref_text.strip(),
duration_sec=duration,
created_at=timestamp
)
except HTTPException:
raise
except Exception as e:
logger.error(f"上传参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"上传失败: {str(e)}")
@router.get("", response_model=RefAudioListResponse)
async def list_ref_audios(user: dict = Depends(get_current_user)):
"""列出当前用户的所有参考音频"""
user_id = user["id"]
try:
# 列出用户目录下的文件
files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
# 过滤出 .wav 文件并获取对应的 metadata
items = []
for f in files:
name = f.get("name", "")
if not name.endswith(".wav"):
continue
storage_path = f"{user_id}/{name}"
# 尝试读取 metadata
metadata_name = name.replace(".wav", ".json")
metadata_path = f"{user_id}/{metadata_name}"
ref_text = ""
duration_sec = 0.0
created_at = 0
try:
# 获取 metadata 内容
metadata_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
import httpx
async with httpx.AsyncClient() as client:
resp = await client.get(metadata_url)
if resp.status_code == 200:
metadata = resp.json()
ref_text = metadata.get("ref_text", "")
duration_sec = metadata.get("duration_sec", 0.0)
created_at = metadata.get("created_at", 0)
except Exception as e:
logger.warning(f"读取 metadata 失败: {e}")
# 从文件名提取时间戳
try:
created_at = int(name.split("_")[0])
except:
pass
# 获取音频签名 URL
signed_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, storage_path)
items.append(RefAudioResponse(
id=storage_path,
name=name,
path=signed_url,
ref_text=ref_text,
duration_sec=duration_sec,
created_at=created_at
))
# 按创建时间倒序排列
items.sort(key=lambda x: x.created_at, reverse=True)
return RefAudioListResponse(items=items)
except Exception as e:
logger.error(f"列出参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
@router.delete("/{audio_id:path}")
async def delete_ref_audio(audio_id: str, user: dict = Depends(get_current_user)):
"""删除参考音频"""
user_id = user["id"]
# 安全检查:确保只能删除自己的文件
if not audio_id.startswith(f"{user_id}/"):
raise HTTPException(status_code=403, detail="无权删除此文件")
try:
# 删除 WAV 文件
await storage_service.delete_file(BUCKET_REF_AUDIOS, audio_id)
# 删除 metadata JSON
metadata_path = audio_id.replace(".wav", ".json")
try:
await storage_service.delete_file(BUCKET_REF_AUDIOS, metadata_path)
except:
pass # metadata 可能不存在
return {"success": True, "message": "删除成功"}
except Exception as e:
logger.error(f"删除参考音频失败: {e}")
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")

View File

@@ -11,6 +11,7 @@ import os
from app.services.tts_service import TTSService
from app.services.video_service import VideoService
from app.services.lipsync_service import LipSyncService
from app.services.voice_clone_service import voice_clone_service
from app.services.storage import storage_service
from app.core.config import settings
from app.core.deps import get_current_user
@@ -21,6 +22,10 @@ class GenerateRequest(BaseModel):
text: str
voice: str = "zh-CN-YunxiNeural"
material_path: str
# 声音克隆模式新增字段
tts_mode: str = "edgetts" # "edgetts" | "voiceclone"
ref_audio_id: Optional[str] = None # 参考音频 storage path
ref_text: Optional[str] = None # 参考音频的转写文字
tasks = {} # In-memory task store
@@ -95,13 +100,42 @@ async def _process_video_generation(task_id: str, req: GenerateRequest, user_id:
await _download_material(req.material_path, input_material_path)
# 1. TTS - 进度 5% -> 25%
tasks[task_id]["message"] = "正在生成语音 (TTS)..."
tasks[task_id]["message"] = "正在生成语音..."
tasks[task_id]["progress"] = 10
tts = TTSService()
audio_path = temp_dir / f"{task_id}_audio.mp3"
audio_path = temp_dir / f"{task_id}_audio.wav"
temp_files.append(audio_path)
await tts.generate_audio(req.text, req.voice, str(audio_path))
if req.tts_mode == "voiceclone":
# 声音克隆模式
if not req.ref_audio_id or not req.ref_text:
raise ValueError("声音克隆模式需要提供参考音频和参考文字")
tasks[task_id]["message"] = "正在下载参考音频..."
# 从 Supabase 下载参考音频
ref_audio_local = temp_dir / f"{task_id}_ref.wav"
temp_files.append(ref_audio_local)
ref_audio_url = await storage_service.get_signed_url(
bucket="ref-audios",
path=req.ref_audio_id
)
await _download_material(ref_audio_url, ref_audio_local)
tasks[task_id]["message"] = "正在克隆声音 (Qwen3-TTS)..."
await voice_clone_service.generate_audio(
text=req.text,
ref_audio_path=str(ref_audio_local),
ref_text=req.ref_text,
output_path=str(audio_path),
language="Chinese"
)
else:
# EdgeTTS 模式 (默认)
tasks[task_id]["message"] = "正在生成语音 (EdgeTTS)..."
tts = TTSService()
await tts.generate_audio(req.text, req.voice, str(audio_path))
tts_time = time.time() - start_time
print(f"[Pipeline] TTS completed in {tts_time:.1f}s")
@@ -217,6 +251,12 @@ async def lipsync_health():
return await lipsync.check_health()
@router.get("/voiceclone/health")
async def voiceclone_health():
"""获取声音克隆服务健康状态"""
return await voice_clone_service.check_health()
@router.get("/generated")
async def list_generated_videos(current_user: dict = Depends(get_current_user)):
"""从 Storage 读取当前用户生成的视频列表"""

View File

@@ -2,7 +2,7 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from app.core import config
from app.api import materials, videos, publish, login_helper, auth, admin
from app.api import materials, videos, publish, login_helper, auth, admin, ref_audios
from loguru import logger
import os
@@ -55,6 +55,7 @@ app.include_router(publish.router, prefix="/api/publish", tags=["Publish"])
app.include_router(login_helper.router, prefix="/api", tags=["LoginHelper"])
app.include_router(auth.router) # /api/auth
app.include_router(admin.router) # /api/admin
app.include_router(ref_audios.router, prefix="/api/ref-audios", tags=["RefAudios"])
@app.on_event("startup")

View File

@@ -16,6 +16,26 @@ class StorageService:
self.supabase: Client = get_supabase()
self.BUCKET_MATERIALS = "materials"
self.BUCKET_OUTPUTS = "outputs"
self.BUCKET_REF_AUDIOS = "ref-audios"
# 确保所有 bucket 存在
self._ensure_buckets()
def _ensure_buckets(self):
"""确保所有必需的 bucket 存在"""
buckets = [self.BUCKET_MATERIALS, self.BUCKET_OUTPUTS, self.BUCKET_REF_AUDIOS]
try:
existing = self.supabase.storage.list_buckets()
existing_names = {b.name for b in existing} if existing else set()
for bucket_name in buckets:
if bucket_name not in existing_names:
try:
self.supabase.storage.create_bucket(bucket_name, options={"public": True})
logger.info(f"Created bucket: {bucket_name}")
except Exception as e:
# 可能已存在,忽略错误
logger.debug(f"Bucket {bucket_name} creation skipped: {e}")
except Exception as e:
logger.warning(f"Failed to ensure buckets: {e}")
def _convert_to_public_url(self, url: str) -> str:
"""将内部 URL 转换为公网可访问的 URL"""

View File

@@ -0,0 +1,110 @@
"""
声音克隆服务
通过 HTTP 调用 Qwen3-TTS 独立服务 (端口 8009)
"""
import httpx
from pathlib import Path
from typing import Optional
from loguru import logger
from app.core.config import settings
# Qwen3-TTS 服务地址
QWEN_TTS_URL = "http://localhost:8009"
class VoiceCloneService:
"""声音克隆服务 - 调用 Qwen3-TTS HTTP API"""
def __init__(self):
self.base_url = QWEN_TTS_URL
# 健康状态缓存
self._health_cache: Optional[dict] = None
self._health_cache_time: float = 0
async def generate_audio(
self,
text: str,
ref_audio_path: str,
ref_text: str,
output_path: str,
language: str = "Chinese"
) -> str:
"""
使用声音克隆生成语音
Args:
text: 要合成的文本
ref_audio_path: 参考音频本地路径
ref_text: 参考音频的转写文字
output_path: 输出 wav 路径
language: 语言 (Chinese/English/Auto)
Returns:
输出文件路径
"""
logger.info(f"🎤 Voice Clone: {text[:30]}...")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 读取参考音频
with open(ref_audio_path, "rb") as f:
ref_audio_data = f.read()
# 调用 Qwen3-TTS 服务
timeout = httpx.Timeout(300.0) # 5分钟超时
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(
f"{self.base_url}/generate",
files={"ref_audio": ("ref.wav", ref_audio_data, "audio/wav")},
data={
"text": text,
"ref_text": ref_text,
"language": language
}
)
response.raise_for_status()
# 保存返回的音频
with open(output_path, "wb") as f:
f.write(response.content)
logger.info(f"✅ Voice clone saved: {output_path}")
return output_path
except httpx.HTTPStatusError as e:
logger.error(f"Qwen3-TTS API error: {e.response.status_code} - {e.response.text}")
raise RuntimeError(f"声音克隆服务错误: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"Qwen3-TTS connection error: {e}")
raise RuntimeError("无法连接声音克隆服务,请检查服务是否启动")
async def check_health(self) -> dict:
"""健康检查"""
import time
# 5分钟缓存
now = time.time()
if self._health_cache and (now - self._health_cache_time) < 300:
return self._health_cache
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{self.base_url}/health")
response.raise_for_status()
self._health_cache = response.json()
self._health_cache_time = now
return self._health_cache
except Exception as e:
logger.warning(f"Qwen3-TTS health check failed: {e}")
return {
"service": "Qwen3-TTS Voice Clone",
"model": "0.6B-Base",
"ready": False,
"gpu_id": 0,
"error": str(e)
}
# 单例
voice_clone_service = VoiceCloneService()