# qwen_tts_server.py
"""
|
|
Qwen3-TTS 独立推理服务
|
|
端口: 8009
|
|
GPU: 0
|
|
|
|
启动方式:
|
|
conda activate qwen-tts
|
|
python qwen_tts_server.py
|
|
|
|
PM2 启动:
|
|
pm2 start qwen_tts_server.py --name qwen-tts --interpreter /home/rongye/ProgramFiles/miniconda3/envs/qwen-tts/bin/python
|
|
"""
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
# 设置 GPU
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
|
from fastapi.responses import FileResponse
|
|
from pydantic import BaseModel
|
|
import uvicorn
|
|
|
|
app = FastAPI(title="Qwen3-TTS Voice Clone Service", version="1.0")
|
|
|
|
# 模型路径
|
|
MODEL_PATH = Path(__file__).parent / "checkpoints" / "0.6B-Base"
|
|
|
|
# 全局模型实例
|
|
_model = None
|
|
_model_loaded = False
|
|
|
|
|
|
def load_model():
    """Load the Qwen3-TTS checkpoint onto GPU 0; no-op if already loaded.

    Intended to be called once at service startup. Sets the module-level
    `_model` and `_model_loaded` globals.
    """
    global _model, _model_loaded

    if _model_loaded:
        return

    print("🔄 Loading Qwen3-TTS model...")
    t0 = time.time()

    # Heavy imports are deferred so the process starts quickly and
    # CUDA_VISIBLE_DEVICES (set at module import time) is already in effect.
    import torch
    from qwen_tts import Qwen3TTSModel

    _model = Qwen3TTSModel.from_pretrained(
        str(MODEL_PATH),
        device_map="cuda:0",
        dtype=torch.bfloat16,
    )

    _model_loaded = True
    print(f"✅ Qwen3-TTS model loaded in {time.time() - t0:.1f}s")
|
|
|
class GenerateRequest(BaseModel):
    """JSON request schema for generation.

    NOTE(review): not referenced by any endpoint in this file — /generate
    takes multipart form fields instead. Confirm whether this is still needed.
    """

    text: str
    ref_text: str
    language: str = "Chinese"
|
|
|
|
|
|
class HealthResponse(BaseModel):
    """Response schema for GET /health."""

    # Human-readable service name.
    service: str
    # Model checkpoint identifier.
    model: str
    # True only when the model is loaded AND CUDA is available.
    ready: bool
    # GPU index this service is pinned to (matches CUDA_VISIBLE_DEVICES).
    gpu_id: int
|
|
|
|
|
|
@app.on_event("startup")
|
|
async def startup():
|
|
"""服务启动时预加载模型"""
|
|
try:
|
|
load_model()
|
|
except Exception as e:
|
|
print(f"❌ Model loading failed: {e}")
|
|
|
|
|
|
@app.get("/health", response_model=HealthResponse)
|
|
async def health():
|
|
"""健康检查"""
|
|
gpu_ok = False
|
|
try:
|
|
import torch
|
|
gpu_ok = torch.cuda.is_available()
|
|
except:
|
|
pass
|
|
|
|
return HealthResponse(
|
|
service="Qwen3-TTS Voice Clone",
|
|
model="0.6B-Base",
|
|
ready=_model_loaded and gpu_ok,
|
|
gpu_id=0
|
|
)
|
|
|
|
|
|
@app.post("/generate")
|
|
async def generate(
|
|
ref_audio: UploadFile = File(...),
|
|
text: str = Form(...),
|
|
ref_text: str = Form(...),
|
|
language: str = Form("Chinese")
|
|
):
|
|
"""
|
|
声音克隆生成
|
|
|
|
Args:
|
|
ref_audio: 参考音频文件 (WAV)
|
|
text: 要合成的文本
|
|
ref_text: 参考音频的转写文字
|
|
language: 语言 (Chinese/English/Auto)
|
|
|
|
Returns:
|
|
生成的音频文件 (WAV)
|
|
"""
|
|
if not _model_loaded:
|
|
raise HTTPException(status_code=503, detail="Model not loaded")
|
|
|
|
import soundfile as sf
|
|
|
|
# 保存上传的参考音频到临时文件
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_ref:
|
|
content = await ref_audio.read()
|
|
tmp_ref.write(content)
|
|
ref_audio_path = tmp_ref.name
|
|
|
|
# 生成输出路径
|
|
output_path = tempfile.mktemp(suffix=".wav")
|
|
|
|
try:
|
|
print(f"🎤 Generating: {text[:30]}...")
|
|
print(f"📝 Ref text: {ref_text[:50]}...")
|
|
|
|
start = time.time()
|
|
|
|
wavs, sr = _model.generate_voice_clone(
|
|
text=text,
|
|
language=language,
|
|
ref_audio=ref_audio_path,
|
|
ref_text=ref_text,
|
|
)
|
|
|
|
sf.write(output_path, wavs[0], sr)
|
|
|
|
duration = len(wavs[0]) / sr
|
|
print(f"✅ Generated in {time.time() - start:.1f}s, duration: {duration:.1f}s")
|
|
|
|
# 返回音频文件
|
|
return FileResponse(
|
|
output_path,
|
|
media_type="audio/wav",
|
|
filename="output.wav",
|
|
background=None # 让客户端下载完再删除
|
|
)
|
|
|
|
except Exception as e:
|
|
print(f"❌ Generation failed: {e}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
finally:
|
|
# 清理参考音频临时文件
|
|
try:
|
|
os.unlink(ref_audio_path)
|
|
except:
|
|
pass
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
async def shutdown():
|
|
"""清理临时文件"""
|
|
# 清理 /tmp 中的残留文件
|
|
import glob
|
|
for f in glob.glob("/tmp/tmp*.wav"):
|
|
try:
|
|
os.unlink(f)
|
|
except:
|
|
pass
|
|
|
|
|
|
if __name__ == "__main__":
|
|
uvicorn.run(
|
|
app,
|
|
host="0.0.0.0",
|
|
port=8009,
|
|
log_level="info"
|
|
)
|