Files
ViGent2/models/CosyVoice/cosyvoice_server.py
Kevin Wong 0e3502c6f0 更新
2026-02-27 16:11:34 +08:00

329 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
CosyVoice 3.0 声音克隆服务
端口: 8010
GPU: 0
启动方式:
conda activate cosyvoice
python cosyvoice_server.py
PM2 启动:
pm2 start run_cosyvoice.sh --name vigent2-cosyvoice
"""
import os
import sys
import tempfile
import time
import asyncio
from pathlib import Path
# Pin this service to GPU 0. Must run before torch/CUDA is first imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# CosyVoice requires the Matcha-TTS submodule on the import path.
SCRIPT_DIR = Path(__file__).parent
sys.path.append(str(SCRIPT_DIR / "third_party" / "Matcha-TTS"))
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
import uvicorn
app = FastAPI(title="CosyVoice 3.0 Voice Clone Service", version="1.0")
# Local directory holding the pretrained Fun-CosyVoice3-0.5B checkpoint.
MODEL_DIR = SCRIPT_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"
# Global model instance, populated once by load_model() at startup.
_model = None
_model_loaded = False
# Set to True after an inference timeout; /generate then refuses requests
# until the process is force-exited and restarted by the supervisor (PM2).
_poisoned = False
# GPU inference lock: serializes requests so only one inference runs at a time.
_inference_lock = asyncio.Lock()
def _schedule_force_exit(reason: str, delay_sec: float = 1.5):
    """Hard-exit the process after a short delay so PM2 restarts it immediately.

    A daemon thread sleeps for ``delay_sec`` seconds, logs ``reason`` and then
    calls ``os._exit(1)``, bypassing normal interpreter shutdown.
    """
    import threading

    def _terminate():
        time.sleep(delay_sec)
        print(f"💥 Force exiting process: {reason}")
        os._exit(1)

    worker = threading.Thread(target=_terminate, daemon=True)
    worker.start()
def load_model():
    """Load the CosyVoice 3.0 checkpoint into module globals; no-op if loaded."""
    global _model, _model_loaded
    if _model_loaded:
        return
    print(f"🔄 Loading CosyVoice 3.0 model from {MODEL_DIR}...")
    t0 = time.time()
    # Imported lazily: cosyvoice pulls in torch/CUDA, which is slow at import time.
    from cosyvoice.cli.cosyvoice import AutoModel
    _model = AutoModel(model_dir=str(MODEL_DIR), fp16=True)
    _model_loaded = True
    elapsed = time.time() - t0
    print(f"✅ CosyVoice 3.0 model loaded in {elapsed:.1f}s")
class HealthResponse(BaseModel):
    """Response schema for GET /health."""
    service: str  # human-readable service name
    model: str  # model checkpoint identifier
    ready: bool  # True when model loaded, CUDA available and not poisoned
    gpu_id: int  # CUDA device index this service is pinned to
def _startup_selftest():
    """Run one short inference at startup to verify the GPU inference path.

    Synthesizes "你好" against a 0.5 s silent reference clip and checks that
    the model yields at least one ``tts_speech`` segment.

    Returns:
        bool: True if the test inference produced audio, False otherwise.
    """
    import torch
    print("🔍 Running startup self-test inference...")
    start = time.time()
    test_text = "你好"
    # Use 0.5 s of silence @ 24 kHz as the reference audio.
    ref_audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            ref_audio_path = tmp.name
        import torchaudio
        silence = torch.zeros(1, 12000)  # 0.5 s @ 24 kHz
        torchaudio.save(ref_audio_path, silence, 24000)
        # CosyVoice3 prompt format: system prompt + transcript of the reference.
        # (plain string — the original used an f-string with no placeholders)
        prompt_text = "You are a helpful assistant.<|endofprompt|>你好"
        results = list(_model.inference_zero_shot(
            test_text,
            prompt_text,
            ref_audio_path,
            stream=False,
            text_frontend=True,
        ))
        if not results:
            raise RuntimeError("Self-test returned empty results")
        segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
        if not segments:
            raise RuntimeError("Self-test returned no tts_speech segments")
        torch.cuda.empty_cache()
        print(f"✅ Self-test passed in {time.time() - start:.1f}s "
              f"(output shape: {segments[0].shape})")
        return True
    except Exception as e:
        print(f"❌ Self-test FAILED: {e}")
        import traceback
        traceback.print_exc()
        # Best-effort GPU cleanup; never mask the original failure.
        # (narrowed from a bare `except:` that also swallowed SystemExit)
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        return False
    finally:
        if ref_audio_path:
            # Remove the temporary reference clip; ignore "already gone".
            try:
                os.unlink(ref_audio_path)
            except OSError:
                pass
@app.on_event("startup")
async def startup():
    """Preload the model at service startup, then run a self-test inference."""
    global _model_loaded
    try:
        load_model()
    except Exception as exc:
        print(f"❌ Model loading failed: {exc}")
        import traceback
        traceback.print_exc()
        return
    # Self-test inference — a failure marks the service as not ready.
    ok = _startup_selftest()
    if not ok:
        _model_loaded = False
        print("⚠️ Self-test failed, marking service as NOT ready")
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health check.

    ``ready`` is True only when the model is loaded, CUDA is available and the
    service has not been poisoned by a previous inference timeout.
    """
    gpu_ok = False
    try:
        import torch
        gpu_ok = torch.cuda.is_available()
    except Exception:
        # torch missing or CUDA probe failed — report GPU as unavailable.
        # (narrowed from a bare `except:` that also swallowed SystemExit)
        pass
    return HealthResponse(
        service="CosyVoice 3.0 Voice Clone",
        model="Fun-CosyVoice3-0.5B",
        ready=_model_loaded and gpu_ok and not _poisoned,
        gpu_id=0
    )
@app.post("/generate")
async def generate(
    ref_audio: UploadFile = File(...),
    text: str = Form(...),
    ref_text: str = Form(...),
    language: str = Form("Chinese"),
    speed: float = Form(1.0),
):
    """
    Voice-clone synthesis.

    Args:
        ref_audio: reference audio file (WAV)
        text: text to synthesize
        ref_text: transcript of the reference audio
        language: compatibility parameter (CosyVoice auto-detects language)
        speed: speed multiplier forwarded to the model

    Returns:
        the generated audio file (WAV)

    Raises:
        HTTPException: 503 when not ready or poisoned, 429 when the GPU is
            busy, 500 on generation failure or timeout.
    """
    global _poisoned
    if not _model_loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if _poisoned:
        raise HTTPException(status_code=503, detail="Service poisoned after timeout, waiting for restart")
    # Best-effort fast-fail; the lock below still serializes actual GPU access.
    if _inference_lock.locked():
        raise HTTPException(status_code=429, detail="GPU busy, please retry later")
    import torch
    import torchaudio
    # Persist the uploaded reference audio to a temporary file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_ref:
        content = await ref_audio.read()
        tmp_ref.write(content)
        ref_audio_path = tmp_ref.name
    # Trim overlong reference audio to the first 10 s (CosyVoice suggests 3-10 s).
    MAX_REF_SEC = 10
    try:
        info = torchaudio.info(ref_audio_path)
        ref_dur = info.num_frames / info.sample_rate
        if ref_dur > MAX_REF_SEC:
            print(f"✂️ Ref audio too long ({ref_dur:.1f}s), trimming to {MAX_REF_SEC}s")
            wav, sr = torchaudio.load(ref_audio_path, num_frames=int(info.sample_rate * MAX_REF_SEC))
            torchaudio.save(ref_audio_path, wav, sr)
    except Exception as e:
        print(f"⚠️ Could not check ref audio duration: {e}")
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp; close the
    # fd immediately — torchaudio.save() writes to the path, not the fd.
    out_fd, output_path = tempfile.mkstemp(suffix=".wav")
    os.close(out_fd)
    try:
        async with _inference_lock:
            print(f"🎤 Generating: {text[:50]}... ({len(text)} chars)")
            print(f"📝 Ref text: {ref_text[:50]}...")
            print(f"🌐 Language: {language}")
            print(f"⚡ Speed: {speed}")
            start = time.time()
            # Timeout guard: 60 s base + 2 s per character, capped at 300 s.
            timeout_sec = min(60 + len(text) * 2, 300)
            # CosyVoice3 prompt_text format.
            prompt_text = f"You are a helpful assistant.<|endofprompt|>{ref_text}"

            def _do_inference():
                """Run the blocking model inference in a worker thread."""
                results = list(_model.inference_zero_shot(
                    text,
                    prompt_text,
                    ref_audio_path,
                    stream=False,
                    speed=speed,
                    text_frontend=True,
                ))
                if not results:
                    raise RuntimeError("CosyVoice returned empty results")
                segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
                if not segments:
                    raise RuntimeError("CosyVoice returned no tts_speech segments")
                if len(segments) == 1:
                    merged = segments[0]
                else:
                    # Join segments with a 50 ms silence gap between them.
                    gap = torch.zeros((segments[0].shape[0], int(_model.sample_rate * 0.05)), dtype=segments[0].dtype)
                    parts = [segments[0]]
                    for seg in segments[1:]:
                        parts.append(gap)
                        parts.append(seg)
                    merged = torch.cat(parts, dim=-1)
                return merged, _model.sample_rate

            try:
                speech, sr = await asyncio.wait_for(
                    asyncio.to_thread(_do_inference),
                    timeout=timeout_sec,
                )
            except asyncio.TimeoutError:
                # The worker thread may still hold the GPU — poison the service
                # and force-exit so PM2 restarts us with a clean GPU state.
                _poisoned = True
                print(f"⏰ Generation timed out after {timeout_sec}s for {len(text)} chars — service POISONED")
                torch.cuda.empty_cache()
                _schedule_force_exit("generation timeout")
                raise HTTPException(status_code=500, detail=f"生成超时({timeout_sec}s),请缩短文本后重试")
            torch.cuda.empty_cache()
            torchaudio.save(output_path, speech, sr)
            duration = speech.shape[-1] / sr
            print(f"✅ Generated in {time.time() - start:.1f}s, duration: {duration:.1f}s")
        return FileResponse(
            output_path,
            media_type="audio/wav",
            filename="output.wav",
            background=None
        )
    except HTTPException:
        # The output file was never returned — remove it before propagating.
        try:
            os.unlink(output_path)
        except OSError:
            pass
        raise
    except Exception as e:
        print(f"❌ Generation failed: {e}")
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
        # Do not leak the (unreturned) output file on failure.
        try:
            os.unlink(output_path)
        except OSError:
            pass
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the uploaded reference clip.
        try:
            os.unlink(ref_audio_path)
        except OSError:
            pass
@app.on_event("shutdown")
async def shutdown():
    """Best-effort cleanup of leftover temporary WAV files on shutdown.

    NOTE(review): the /tmp/tmp*.wav glob also matches temp files created by
    other processes — consider a dedicated temp directory for this service.
    """
    import glob
    for path in glob.glob("/tmp/tmp*.wav"):
        try:
            os.unlink(path)
        except OSError:
            # Already removed or not deletable — ignore.
            # (narrowed from a bare `except:` that also swallowed SystemExit)
            pass
if __name__ == "__main__":
    # Serve on all interfaces; port 8010 is this service's assigned port
    # (see the module docstring for conda/PM2 launch instructions).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8010,
        log_level="info"
    )