# cosyvoice_server.py
"""
|
||
CosyVoice 3.0 声音克隆服务
|
||
端口: 8010
|
||
GPU: 0
|
||
|
||
启动方式:
|
||
conda activate cosyvoice
|
||
python cosyvoice_server.py
|
||
|
||
PM2 启动:
|
||
pm2 start run_cosyvoice.sh --name vigent2-cosyvoice
|
||
"""
|
||
import os
|
||
import sys
|
||
import tempfile
|
||
import time
|
||
import asyncio
|
||
from pathlib import Path
|
||
|
||
# 设置 GPU
|
||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||
|
||
# CosyVoice 需要 Matcha-TTS 子模块
|
||
SCRIPT_DIR = Path(__file__).parent
|
||
sys.path.append(str(SCRIPT_DIR / "third_party" / "Matcha-TTS"))
|
||
|
||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel
|
||
import uvicorn
|
||
|
||
app = FastAPI(title="CosyVoice 3.0 Voice Clone Service", version="1.0")
|
||
|
||
# Directory containing the pretrained CosyVoice model weights.
MODEL_DIR = SCRIPT_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"

# Global model instance, populated once by load_model().
_model = None
_model_loaded = False
# Set to True after an inference timeout; the process then refuses work
# until PM2 restarts it (see _schedule_force_exit).
_poisoned = False

# Serializes GPU inference: only one request may run the model at a time.
_inference_lock = asyncio.Lock()
def _schedule_force_exit(reason: str, delay_sec: float = 1.5):
    """Hard-exit the process after a short delay so PM2 restarts it.

    Args:
        reason: Human-readable explanation logged right before exiting.
        delay_sec: Grace period (seconds) before the forced exit.
    """
    import threading

    def _terminate():
        # Give the in-flight response a moment to flush, then die hard;
        # os._exit skips interpreter cleanup so nothing can hang.
        time.sleep(delay_sec)
        print(f"💥 Force exiting process: {reason}")
        os._exit(1)

    watchdog = threading.Thread(target=_terminate, daemon=True)
    watchdog.start()
def load_model():
    """Load the CosyVoice model once; later calls are no-ops.

    Sets the module-level _model and _model_loaded globals.
    """
    global _model, _model_loaded

    if _model_loaded:
        return

    print(f"🔄 Loading CosyVoice 3.0 model from {MODEL_DIR}...")
    started_at = time.time()

    # Imported lazily so the module can be imported without the weights
    # or the cosyvoice package being present.
    from cosyvoice.cli.cosyvoice import AutoModel

    _model = AutoModel(model_dir=str(MODEL_DIR), fp16=True)
    _model_loaded = True

    print(f"✅ CosyVoice 3.0 model loaded in {time.time() - started_at:.1f}s")
class HealthResponse(BaseModel):
    """Response schema for the /health endpoint."""

    service: str  # Human-readable service name.
    model: str    # Identifier of the model being served.
    ready: bool   # True when model loaded, CUDA available, and not poisoned.
    gpu_id: int   # CUDA device index this service is pinned to.
def _startup_selftest():
    """Startup self-test: run one short inference to verify the GPU pipeline.

    Synthesizes a two-character text against 0.5 s of silence used as the
    reference audio.

    Returns:
        True when inference yields at least one "tts_speech" segment,
        False on any failure (the error is logged, never raised).
    """
    import torch

    print("🔍 Running startup self-test inference...")
    start = time.time()

    test_text = "你好"
    # Use a stretch of silence as the reference audio (0.5 s @ 24 kHz).
    ref_audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            ref_audio_path = tmp.name
            import torchaudio
            silence = torch.zeros(1, 12000)  # 0.5s @ 24kHz
            torchaudio.save(ref_audio_path, silence, 24000)

        # CosyVoice3 prompt format: system preamble + reference transcript.
        # (Plain string — the original used an f-string with no placeholders.)
        prompt_text = "You are a helpful assistant.<|endofprompt|>你好"
        results = list(_model.inference_zero_shot(
            test_text,
            prompt_text,
            ref_audio_path,
            stream=False,
            text_frontend=True,
        ))
        if not results:
            raise RuntimeError("Self-test returned empty results")

        segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
        if not segments:
            raise RuntimeError("Self-test returned no tts_speech segments")

        torch.cuda.empty_cache()
        print(f"✅ Self-test passed in {time.time() - start:.1f}s "
              f"(output shape: {segments[0].shape})")
        return True
    except Exception as e:
        print(f"❌ Self-test FAILED: {e}")
        import traceback
        traceback.print_exc()
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; cache cleanup is best-effort.
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        return False
    finally:
        # Always remove the temporary reference file.
        if ref_audio_path:
            try:
                os.unlink(ref_audio_path)
            except OSError:
                pass
@app.on_event("startup")
|
||
async def startup():
|
||
"""服务启动时预加载模型并自检推理"""
|
||
try:
|
||
load_model()
|
||
except Exception as e:
|
||
print(f"❌ Model loading failed: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return
|
||
|
||
# 自检推理 — 失败则标记为不可用
|
||
global _model_loaded
|
||
if not _startup_selftest():
|
||
_model_loaded = False
|
||
print("⚠️ Self-test failed, marking service as NOT ready")
|
||
|
||
|
||
@app.get("/health", response_model=HealthResponse)
|
||
async def health():
|
||
"""健康检查"""
|
||
gpu_ok = False
|
||
try:
|
||
import torch
|
||
gpu_ok = torch.cuda.is_available()
|
||
except:
|
||
pass
|
||
|
||
return HealthResponse(
|
||
service="CosyVoice 3.0 Voice Clone",
|
||
model="Fun-CosyVoice3-0.5B",
|
||
ready=_model_loaded and gpu_ok and not _poisoned,
|
||
gpu_id=0
|
||
)
|
||
|
||
|
||
@app.post("/generate")
|
||
async def generate(
|
||
ref_audio: UploadFile = File(...),
|
||
text: str = Form(...),
|
||
ref_text: str = Form(...),
|
||
language: str = Form("Chinese"),
|
||
speed: float = Form(1.0),
|
||
):
|
||
"""
|
||
声音克隆生成
|
||
|
||
Args:
|
||
ref_audio: 参考音频文件 (WAV)
|
||
text: 要合成的文本
|
||
ref_text: 参考音频的转写文字
|
||
language: 语言(兼容参数,CosyVoice 自动检测语言)
|
||
|
||
Returns:
|
||
生成的音频文件 (WAV)
|
||
"""
|
||
global _poisoned
|
||
|
||
if not _model_loaded:
|
||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||
|
||
if _poisoned:
|
||
raise HTTPException(status_code=503, detail="Service poisoned after timeout, waiting for restart")
|
||
|
||
if _inference_lock.locked():
|
||
raise HTTPException(status_code=429, detail="GPU busy, please retry later")
|
||
|
||
import torch
|
||
import torchaudio
|
||
|
||
# 保存上传的参考音频到临时文件
|
||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_ref:
|
||
content = await ref_audio.read()
|
||
tmp_ref.write(content)
|
||
ref_audio_path = tmp_ref.name
|
||
|
||
# 参考音频过长时自动截取前 10 秒(CosyVoice 建议 3-10 秒)
|
||
MAX_REF_SEC = 10
|
||
try:
|
||
info = torchaudio.info(ref_audio_path)
|
||
ref_dur = info.num_frames / info.sample_rate
|
||
if ref_dur > MAX_REF_SEC:
|
||
print(f"✂️ Ref audio too long ({ref_dur:.1f}s), trimming to {MAX_REF_SEC}s")
|
||
wav, sr = torchaudio.load(ref_audio_path, num_frames=int(info.sample_rate * MAX_REF_SEC))
|
||
torchaudio.save(ref_audio_path, wav, sr)
|
||
except Exception as e:
|
||
print(f"⚠️ Could not check ref audio duration: {e}")
|
||
|
||
output_path = tempfile.mktemp(suffix=".wav")
|
||
|
||
try:
|
||
async with _inference_lock:
|
||
print(f"🎤 Generating: {text[:50]}... ({len(text)} chars)")
|
||
print(f"📝 Ref text: {ref_text[:50]}...")
|
||
print(f"🌐 Language: {language}")
|
||
print(f"⚡ Speed: {speed}")
|
||
|
||
start = time.time()
|
||
|
||
# 超时保护:基础60秒 + 每字符2秒,上限300秒
|
||
timeout_sec = min(60 + len(text) * 2, 300)
|
||
|
||
# CosyVoice3 的 prompt_text 格式
|
||
prompt_text = f"You are a helpful assistant.<|endofprompt|>{ref_text}"
|
||
|
||
def _do_inference():
|
||
"""在线程池中执行推理"""
|
||
results = list(_model.inference_zero_shot(
|
||
text,
|
||
prompt_text,
|
||
ref_audio_path,
|
||
stream=False,
|
||
speed=speed,
|
||
text_frontend=True,
|
||
))
|
||
if not results:
|
||
raise RuntimeError("CosyVoice returned empty results")
|
||
|
||
segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
|
||
if not segments:
|
||
raise RuntimeError("CosyVoice returned no tts_speech segments")
|
||
|
||
if len(segments) == 1:
|
||
merged = segments[0]
|
||
else:
|
||
gap = torch.zeros((segments[0].shape[0], int(_model.sample_rate * 0.05)), dtype=segments[0].dtype)
|
||
parts = [segments[0]]
|
||
for seg in segments[1:]:
|
||
parts.append(gap)
|
||
parts.append(seg)
|
||
merged = torch.cat(parts, dim=-1)
|
||
|
||
return merged, _model.sample_rate
|
||
|
||
try:
|
||
speech, sr = await asyncio.wait_for(
|
||
asyncio.to_thread(_do_inference),
|
||
timeout=timeout_sec,
|
||
)
|
||
except asyncio.TimeoutError:
|
||
_poisoned = True
|
||
print(f"⏰ Generation timed out after {timeout_sec}s for {len(text)} chars — service POISONED")
|
||
torch.cuda.empty_cache()
|
||
_schedule_force_exit("generation timeout")
|
||
raise HTTPException(status_code=500, detail=f"生成超时({timeout_sec}s),请缩短文本后重试")
|
||
|
||
torch.cuda.empty_cache()
|
||
|
||
torchaudio.save(output_path, speech, sr)
|
||
|
||
duration = speech.shape[-1] / sr
|
||
print(f"✅ Generated in {time.time() - start:.1f}s, duration: {duration:.1f}s")
|
||
|
||
return FileResponse(
|
||
output_path,
|
||
media_type="audio/wav",
|
||
filename="output.wav",
|
||
background=None
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
|
||
except Exception as e:
|
||
print(f"❌ Generation failed: {e}")
|
||
try:
|
||
torch.cuda.empty_cache()
|
||
except:
|
||
pass
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
finally:
|
||
try:
|
||
os.unlink(ref_audio_path)
|
||
except:
|
||
pass
|
||
|
||
|
||
@app.on_event("shutdown")
|
||
async def shutdown():
|
||
"""清理临时文件"""
|
||
import glob
|
||
for f in glob.glob("/tmp/tmp*.wav"):
|
||
try:
|
||
os.unlink(f)
|
||
except:
|
||
pass
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the API server directly (PM2 wraps this via run_cosyvoice.sh).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8010,
        log_level="info"
    )